123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171 |
# mautrix-instagram - A Matrix-Instagram puppeting bridge.
# Copyright (C) 2020 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
- from typing import Any, Union, List, Dict, Optional
- import io
- from .type import TType
class ThriftWriter(io.BytesIO):
    """In-memory byte buffer with helpers for writing Thrift-style fields.

    Integers are written as zigzag-encoded varints and field headers use
    deltas from the previously written field ID, which matches the layout of
    the Thrift compact protocol (NOTE(review): confirm against the reader
    on the other end).

    The writer keeps track of the last field ID written in the current struct
    (``prev_field_id``) and a stack of the values saved for enclosing structs.
    """

    # ID of the most recently written field in the struct currently being written.
    prev_field_id: int
    # Saved ``prev_field_id`` values of the enclosing structs (innermost last).
    stack: List[int]

    def __init__(self, *args, **kwargs) -> None:
        """Create the buffer; all arguments are passed through to ``io.BytesIO``."""
        super().__init__(*args, **kwargs)
        self.prev_field_id = 0
        self.stack = []

    def _push_stack(self) -> None:
        """Save the current field ID and start counting from zero for a nested struct."""
        self.stack.append(self.prev_field_id)
        self.prev_field_id = 0

    def _pop_stack(self) -> None:
        """Restore the field ID saved by the matching ``_push_stack()``.

        Does nothing when the stack is empty (i.e. when closing the outermost struct).
        """
        if self.stack:
            self.prev_field_id = self.stack.pop()

    def _write_byte(self, byte: Union[int, TType]) -> None:
        """Append a single byte to the buffer.

        ``bytes([byte])`` is used for the conversion, so ``byte`` must be an
        integer in ``range(256)``; anything else raises ``ValueError``/``TypeError``.
        Passing a ``TType`` member directly presumably relies on it being
        int-compatible -- TODO confirm.
        """
        self.write(bytes([byte]))

    @staticmethod
    def _to_zigzag(val: int, bits: int) -> int:
        """Zigzag-encode ``val`` as a signed integer that is ``bits`` wide.

        Small negative and small positive numbers both map to small
        non-negative numbers (0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...).
        """
        return (val << 1) ^ (val >> (bits - 1))

    def _write_varint(self, val: int) -> None:
        """Write ``val`` as a varint: 7 bits per byte, least significant group first.

        Every byte except the last one has its high bit set.
        """
        while True:
            byte = val & ~0x7f
            if byte == 0:
                # Everything that is left fits in 7 bits: final byte.
                self._write_byte(val)
                break
            elif byte == -128:
                # Only reachable for negative input (-128..-1 remaining): the
                # arithmetic shift below would never reach zero, so terminate
                # with a zero byte instead of looping forever.
                self._write_byte(0)
                break
            else:
                # Low 7 bits plus the continuation flag (the 0x80 bit is always set
                # by the OR, so masking with 0xff is equivalent to masking with 0x7f).
                self._write_byte((val & 0xff) | 0x80)
                val = val >> 7

    def _write_word(self, val: int) -> None:
        """Write ``val`` as a zigzag varint using a 16-bit zigzag width."""
        self._write_varint(self._to_zigzag(val, 16))

    def _write_int(self, val: int) -> None:
        """Write ``val`` as a zigzag varint using a 32-bit zigzag width."""
        self._write_varint(self._to_zigzag(val, 32))

    def _write_long(self, val: int) -> None:
        """Write ``val`` as a zigzag varint using a 64-bit zigzag width."""
        self._write_varint(self._to_zigzag(val, 64))

    def write_field_begin(self, field_id: int, ttype: TType) -> None:
        """Write the header of a struct field and remember ``field_id``.

        When the field ID is 1 to 15 higher than the previous one, the header is
        a single byte with the delta in the high nibble and the type in the low
        nibble. Otherwise the type byte is followed by the full field ID as a
        16-bit zigzag varint.
        """
        ttype_val = ttype.value
        delta = field_id - self.prev_field_id
        if 0 < delta < 16:
            self._write_byte((delta << 4) | ttype_val)
        else:
            self._write_byte(ttype_val)
            self._write_word(field_id)
        self.prev_field_id = field_id

    def write_map(self, field_id: int, key_type: TType, value_type: TType, val: Dict[Any, Any]
                  ) -> None:
        """Write a map field.

        An empty map is written as a single zero byte. A non-empty map is
        written as the varint entry count, one byte holding the key type (high
        nibble) and value type (low nibble), then each key/value pair via
        ``write_val()``.

        Raises:
            ValueError: if ``write_val()`` doesn't support the key or value type.
        """
        self.write_field_begin(field_id, TType.MAP)
        # This must test the argument: the previous ``if not map`` checked the
        # ``map`` builtin, which is always truthy, so empty maps got a stray
        # type byte after the zero size.
        if not val:
            self._write_byte(0)
            return
        self._write_varint(len(val))
        self._write_byte(((key_type.value & 0xf) << 4) | (value_type.value & 0xf))
        for key, value in val.items():
            self.write_val(None, key_type, key)
            self.write_val(None, value_type, value)

    def write_string_direct(self, val: Union[str, bytes]) -> None:
        """Write a length-prefixed byte string without a field header.

        ``str`` values are encoded as UTF-8 first; the varint length prefix is
        the number of bytes, not characters.
        """
        if isinstance(val, str):
            val = val.encode("utf-8")
        self._write_varint(len(val))
        self.write(val)

    def write_stop(self) -> None:
        """Write the STOP marker and restore the enclosing struct's field ID, if any."""
        self._write_byte(TType.STOP.value)
        self._pop_stack()

    def write_int8(self, field_id: int, val: int) -> None:
        """Write a BYTE field; ``val`` is written as one raw byte (see ``_write_byte()``)."""
        self.write_field_begin(field_id, TType.BYTE)
        self._write_byte(val)

    def write_int16(self, field_id: int, val: int) -> None:
        """Write an I16 field as a zigzag varint."""
        self.write_field_begin(field_id, TType.I16)
        self._write_word(val)

    def write_int32(self, field_id: int, val: int) -> None:
        """Write an I32 field as a zigzag varint."""
        self.write_field_begin(field_id, TType.I32)
        self._write_int(val)

    def write_int64(self, field_id: int, val: int) -> None:
        """Write an I64 field as a zigzag varint."""
        self.write_field_begin(field_id, TType.I64)
        self._write_long(val)

    def write_list(self, field_id: int, item_type: TType, val: List[Any]) -> None:
        """Write a list field.

        Lists with fewer than 15 items store the length in the high nibble of
        the header byte. Longer lists use ``0xf`` in the high nibble followed by
        the length as a varint. Items are written via ``write_val()`` without
        field headers.

        Raises:
            ValueError: if ``write_val()`` doesn't support ``item_type``
                (which includes BOOL, as there is no field header to carry it).
        """
        self.write_field_begin(field_id, TType.LIST)
        if len(val) < 0x0f:
            self._write_byte((len(val) << 4) | item_type.value)
        else:
            self._write_byte(0xf0 | item_type.value)
            self._write_varint(len(val))
        for item in val:
            self.write_val(None, item_type, item)

    def write_struct_begin(self, field_id: int) -> None:
        """Write a STRUCT field header and start a fresh field ID scope for its contents."""
        self.write_field_begin(field_id, TType.STRUCT)
        self._push_stack()

    def write_val(self, field_id: Optional[int], ttype: TType, val: Any) -> None:
        """Write a single primitive value, with a field header if ``field_id`` is given.

        Booleans have no payload: the value is carried by the field type
        (TRUE/FALSE) in the header, so they require a ``field_id``.

        Raises:
            ValueError: if ``ttype`` is BOOL and ``field_id`` is None, or if
                ``ttype`` is not one of BOOL, BYTE, I16, I32, I64 or BINARY.
        """
        if ttype == TType.BOOL:
            if field_id is None:
                raise ValueError("booleans can only be in structs")
            self.write_field_begin(field_id, TType.TRUE if val else TType.FALSE)
            return
        if field_id is not None:
            self.write_field_begin(field_id, ttype)
        if ttype == TType.BYTE:
            self._write_byte(val)
        elif ttype == TType.I16:
            self._write_word(val)
        elif ttype == TType.I32:
            self._write_int(val)
        elif ttype == TType.I64:
            self._write_long(val)
        elif ttype == TType.BINARY:
            self.write_string_direct(val)
        else:
            raise ValueError(f"{ttype} is not supported by write_val()")

    def write_struct(self, obj: Any) -> None:
        """Write every non-None field of ``obj`` as described by ``obj.thrift_spec``.

        ``thrift_spec`` is expected to map field IDs to
        ``(field_type, field_name, inner_type)`` tuples. ``inner_type`` is the
        item type for lists and sets, and a pair of ``(type, _)`` tuples for the
        key and value of maps.

        Nested structs are wrapped in a struct header and STOP marker, but no
        STOP marker is written for ``obj`` itself; the caller has to call
        ``write_stop()`` when the outermost struct is complete.
        """
        for field_id, (field_type, field_name, inner_type) in obj.thrift_spec.items():
            val = getattr(obj, field_name, None)
            if val is None:
                # Unset fields are simply omitted from the output.
                continue
            if field_type in (TType.BOOL, TType.BYTE, TType.I16, TType.I32, TType.I64,
                              TType.BINARY):
                self.write_val(field_id, field_type, val)
            elif field_type in (TType.LIST, TType.SET):
                # Sets go through write_list() and therefore get a LIST header.
                self.write_list(field_id, inner_type, val)
            elif field_type == TType.MAP:
                (key_type, _), (value_type, _) = inner_type
                self.write_map(field_id, key_type, value_type, val)
            elif field_type == TType.STRUCT:
                self.write_struct_begin(field_id)
                self.write_struct(val)
                self.write_stop()
            # Fields of any other type are silently skipped.
|