# mautrix-instagram - A Matrix-Instagram puppeting bridge.
# Copyright (C) 2020 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Any, Union, List, Dict, Optional
import io

from .type import TType
  19. class ThriftWriter(io.BytesIO):
  20. prev_field_id: int
  21. stack: List[int]
  22. def __init__(self, *args, **kwargs) -> None:
  23. super().__init__(*args, **kwargs)
  24. self.prev_field_id = 0
  25. self.stack = []
  26. def _push_stack(self) -> None:
  27. self.stack.append(self.prev_field_id)
  28. self.prev_field_id = 0
  29. def _pop_stack(self) -> None:
  30. if self.stack:
  31. self.prev_field_id = self.stack.pop()
  32. def _write_byte(self, byte: Union[int, TType]) -> None:
  33. self.write(bytes([byte]))
  34. @staticmethod
  35. def _to_zigzag(val: int, bits: int) -> int:
  36. return (val << 1) ^ (val >> (bits - 1))
  37. def _write_varint(self, val: int) -> None:
  38. while True:
  39. byte = val & ~0x7f
  40. if byte == 0:
  41. self._write_byte(val)
  42. break
  43. elif byte == -128:
  44. self._write_byte(0)
  45. break
  46. else:
  47. self._write_byte((val & 0xff) | 0x80)
  48. val = val >> 7
  49. def _write_word(self, val: int) -> None:
  50. self._write_varint(self._to_zigzag(val, 16))
  51. def _write_int(self, val: int) -> None:
  52. self._write_varint(self._to_zigzag(val, 32))
  53. def _write_long(self, val: int) -> None:
  54. self._write_varint(self._to_zigzag(val, 64))
  55. def write_field_begin(self, field_id: int, ttype: TType) -> None:
  56. ttype_val = ttype.value
  57. delta = field_id - self.prev_field_id
  58. if 0 < delta < 16:
  59. self._write_byte((delta << 4) | ttype_val)
  60. else:
  61. self._write_byte(ttype_val)
  62. self._write_word(field_id)
  63. self.prev_field_id = field_id
  64. def write_map(self, field_id: int, key_type: TType, value_type: TType, val: Dict[Any, Any]
  65. ) -> None:
  66. self.write_field_begin(field_id, TType.MAP)
  67. if not map:
  68. self._write_byte(0)
  69. return
  70. self._write_varint(len(val))
  71. self._write_byte(((key_type.value & 0xf) << 4) | (value_type.value & 0xf))
  72. for key, value in val.items():
  73. self.write_val(None, key_type, key)
  74. self.write_val(None, value_type, value)
  75. def write_string_direct(self, val: Union[str, bytes]) -> None:
  76. if isinstance(val, str):
  77. val = val.encode("utf-8")
  78. self._write_varint(len(val))
  79. self.write(val)
  80. def write_stop(self) -> None:
  81. self._write_byte(TType.STOP.value)
  82. self._pop_stack()
  83. def write_int8(self, field_id: int, val: int) -> None:
  84. self.write_field_begin(field_id, TType.BYTE)
  85. self._write_byte(val)
  86. def write_int16(self, field_id: int, val: int) -> None:
  87. self.write_field_begin(field_id, TType.I16)
  88. self._write_word(val)
  89. def write_int32(self, field_id: int, val: int) -> None:
  90. self.write_field_begin(field_id, TType.I32)
  91. self._write_int(val)
  92. def write_int64(self, field_id: int, val: int) -> None:
  93. self.write_field_begin(field_id, TType.I64)
  94. self._write_long(val)
  95. def write_list(self, field_id: int, item_type: TType, val: List[Any]) -> None:
  96. self.write_field_begin(field_id, TType.LIST)
  97. if len(val) < 0x0f:
  98. self._write_byte((len(val) << 4) | item_type.value)
  99. else:
  100. self._write_byte(0xf0 | item_type.value)
  101. self._write_varint(len(val))
  102. for item in val:
  103. self.write_val(None, item_type, item)
  104. def write_struct_begin(self, field_id: int) -> None:
  105. self.write_field_begin(field_id, TType.STRUCT)
  106. self._push_stack()
  107. def write_val(self, field_id: Optional[int], ttype: TType, val: Any) -> None:
  108. if ttype == TType.BOOL:
  109. if field_id is None:
  110. raise ValueError("booleans can only be in structs")
  111. self.write_field_begin(field_id, TType.TRUE if val else TType.FALSE)
  112. return
  113. if field_id is not None:
  114. self.write_field_begin(field_id, ttype)
  115. if ttype == TType.BYTE:
  116. self._write_byte(val)
  117. elif ttype == TType.I16:
  118. self._write_word(val)
  119. elif ttype == TType.I32:
  120. self._write_int(val)
  121. elif ttype == TType.I64:
  122. self._write_long(val)
  123. elif ttype == TType.BINARY:
  124. self.write_string_direct(val)
  125. else:
  126. raise ValueError(f"{ttype} is not supported by write_val()")
  127. def write_struct(self, obj: Any) -> None:
  128. for field_id in iter(obj.thrift_spec):
  129. field_type, field_name, inner_type = obj.thrift_spec[field_id]
  130. val = getattr(obj, field_name, None)
  131. if val is None:
  132. continue
  133. start = len(self.getvalue())
  134. if field_type in (TType.BOOL, TType.BYTE, TType.I16, TType.I32, TType.I64,
  135. TType.BINARY):
  136. self.write_val(field_id, field_type, val)
  137. elif field_type in (TType.LIST, TType.SET):
  138. self.write_list(field_id, inner_type, val)
  139. elif field_type == TType.MAP:
  140. (key_type, _), (value_type, _) = inner_type
  141. self.write_map(field_id, key_type, value_type, val)
  142. elif field_type == TType.STRUCT:
  143. self.write_struct_begin(field_id)
  144. self.write_struct(val)
  145. self.write_stop()