read.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. # mautrix-instagram - A Matrix-Instagram puppeting bridge.
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import List
  17. import io
  18. from .type import TType
  19. class ThriftReader(io.BytesIO):
  20. prev_field_id: int
  21. stack: List[int]
  22. def __init__(self, *args, **kwargs) -> None:
  23. super().__init__(*args, **kwargs)
  24. self.prev_field_id = 0
  25. self.stack = []
  26. def _push_stack(self) -> None:
  27. self.stack.append(self.prev_field_id)
  28. self.prev_field_id = 0
  29. def _pop_stack(self) -> None:
  30. if self.stack:
  31. self.prev_field_id = self.stack.pop()
  32. def _read_byte(self, signed: bool = False) -> int:
  33. return int.from_bytes(self.read(1), "big", signed=signed)
  34. @staticmethod
  35. def _from_zigzag(val: int) -> int:
  36. return (val >> 1) ^ -(val & 1)
  37. def read_small_int(self) -> int:
  38. return self._from_zigzag(self.read_varint())
  39. def read_varint(self) -> int:
  40. shift = 0
  41. result = 0
  42. while True:
  43. byte = self._read_byte()
  44. result |= (byte & 0x7f) << shift
  45. if (byte & 0x80) == 0:
  46. break
  47. shift += 7
  48. return result
  49. def read_field(self) -> TType:
  50. byte = self._read_byte()
  51. if byte == 0:
  52. return TType.STOP
  53. delta = (byte & 0xf0) >> 4
  54. if delta == 0:
  55. self.prev_field_id = self._from_zigzag(self.read_varint())
  56. else:
  57. self.prev_field_id += delta
  58. return TType(byte & 0x0f)