write.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. # mautrix-instagram - A Matrix-Instagram puppeting bridge.
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import Any
  18. import io
  19. from .type import TType
  20. class ThriftWriter(io.BytesIO):
  21. prev_field_id: int
  22. stack: list[int]
  23. def __init__(self, *args, **kwargs) -> None:
  24. super().__init__(*args, **kwargs)
  25. self.prev_field_id = 0
  26. self.stack = []
  27. def _push_stack(self) -> None:
  28. self.stack.append(self.prev_field_id)
  29. self.prev_field_id = 0
  30. def _pop_stack(self) -> None:
  31. if self.stack:
  32. self.prev_field_id = self.stack.pop()
  33. def _write_byte(self, byte: int | TType) -> None:
  34. self.write(bytes([byte]))
  35. @staticmethod
  36. def _to_zigzag(val: int, bits: int) -> int:
  37. return (val << 1) ^ (val >> (bits - 1))
  38. def _write_varint(self, val: int) -> None:
  39. while True:
  40. byte = val & ~0x7F
  41. if byte == 0:
  42. self._write_byte(val)
  43. break
  44. elif byte == -128:
  45. self._write_byte(0)
  46. break
  47. else:
  48. self._write_byte((val & 0xFF) | 0x80)
  49. val = val >> 7
  50. def _write_word(self, val: int) -> None:
  51. self._write_varint(self._to_zigzag(val, 16))
  52. def _write_int(self, val: int) -> None:
  53. self._write_varint(self._to_zigzag(val, 32))
  54. def _write_long(self, val: int) -> None:
  55. self._write_varint(self._to_zigzag(val, 64))
  56. def write_field_begin(self, field_id: int, ttype: TType) -> None:
  57. ttype_val = ttype.value
  58. delta = field_id - self.prev_field_id
  59. if 0 < delta < 16:
  60. self._write_byte((delta << 4) | ttype_val)
  61. else:
  62. self._write_byte(ttype_val)
  63. self._write_word(field_id)
  64. self.prev_field_id = field_id
  65. def write_map(
  66. self, field_id: int, key_type: TType, value_type: TType, val: dict[Any, Any]
  67. ) -> None:
  68. self.write_field_begin(field_id, TType.MAP)
  69. if not map:
  70. self._write_byte(0)
  71. return
  72. self._write_varint(len(val))
  73. self._write_byte(((key_type.value & 0xF) << 4) | (value_type.value & 0xF))
  74. for key, value in val.items():
  75. self.write_val(None, key_type, key)
  76. self.write_val(None, value_type, value)
  77. def write_string_direct(self, val: str | bytes) -> None:
  78. if isinstance(val, str):
  79. val = val.encode("utf-8")
  80. self._write_varint(len(val))
  81. self.write(val)
  82. def write_stop(self) -> None:
  83. self._write_byte(TType.STOP.value)
  84. self._pop_stack()
  85. def write_int8(self, field_id: int, val: int) -> None:
  86. self.write_field_begin(field_id, TType.BYTE)
  87. self._write_byte(val)
  88. def write_int16(self, field_id: int, val: int) -> None:
  89. self.write_field_begin(field_id, TType.I16)
  90. self._write_word(val)
  91. def write_int32(self, field_id: int, val: int) -> None:
  92. self.write_field_begin(field_id, TType.I32)
  93. self._write_int(val)
  94. def write_int64(self, field_id: int, val: int) -> None:
  95. self.write_field_begin(field_id, TType.I64)
  96. self._write_long(val)
  97. def write_list(self, field_id: int, item_type: TType, val: list[Any]) -> None:
  98. self.write_field_begin(field_id, TType.LIST)
  99. if len(val) < 0x0F:
  100. self._write_byte((len(val) << 4) | item_type.value)
  101. else:
  102. self._write_byte(0xF0 | item_type.value)
  103. self._write_varint(len(val))
  104. for item in val:
  105. self.write_val(None, item_type, item)
  106. def write_struct_begin(self, field_id: int) -> None:
  107. self.write_field_begin(field_id, TType.STRUCT)
  108. self._push_stack()
  109. def write_val(self, field_id: int | None, ttype: TType, val: Any) -> None:
  110. if ttype == TType.BOOL:
  111. if field_id is None:
  112. raise ValueError("booleans can only be in structs")
  113. self.write_field_begin(field_id, TType.TRUE if val else TType.FALSE)
  114. return
  115. if field_id is not None:
  116. self.write_field_begin(field_id, ttype)
  117. if ttype == TType.BYTE:
  118. self._write_byte(val)
  119. elif ttype == TType.I16:
  120. self._write_word(val)
  121. elif ttype == TType.I32:
  122. self._write_int(val)
  123. elif ttype == TType.I64:
  124. self._write_long(val)
  125. elif ttype == TType.BINARY:
  126. self.write_string_direct(val)
  127. else:
  128. raise ValueError(f"{ttype} is not supported by write_val()")
  129. def write_struct(self, obj: Any) -> None:
  130. for field_id in iter(obj.thrift_spec):
  131. field_type, field_name, inner_type = obj.thrift_spec[field_id]
  132. val = getattr(obj, field_name, None)
  133. if val is None:
  134. continue
  135. start = len(self.getvalue())
  136. if field_type in (
  137. TType.BOOL,
  138. TType.BYTE,
  139. TType.I16,
  140. TType.I32,
  141. TType.I64,
  142. TType.BINARY,
  143. ):
  144. self.write_val(field_id, field_type, val)
  145. elif field_type in (TType.LIST, TType.SET):
  146. self.write_list(field_id, inner_type, val)
  147. elif field_type == TType.MAP:
  148. (key_type, _), (value_type, _) = inner_type
  149. self.write_map(field_id, key_type, value_type, val)
  150. elif field_type == TType.STRUCT:
  151. self.write_struct_begin(field_id)
  152. self.write_struct(val)
  153. self.write_stop()