# Source code for aiothrift.protocol

import asyncio
import struct

from thriftpy.protocol.exc import TProtocolException
from thriftpy.thrift import TType

# Selects the version half (upper 16 bits) of a strict message header word.
VERSION_MASK = -0x10000
# Binary protocol version 1 marker: 0x80010000 interpreted as a signed i32.
VERSION_1 = -0x7FFF0000
# Selects the message-type byte (low 8 bits) of a strict message header word.
TYPE_MASK = 0xFF


def pack_i8(byte):
    """Encode *byte* as one signed big-endian byte."""
    encoded = struct.pack(">b", byte)
    return encoded


def pack_i16(i16):
    """Encode *i16* as a signed 16-bit big-endian integer."""
    encoded = struct.pack(">h", i16)
    return encoded


def pack_i32(i32):
    """Encode *i32* as a signed 32-bit big-endian integer."""
    encoded = struct.pack(">i", i32)
    return encoded


def pack_i64(i64):
    """Encode *i64* as a signed 64-bit big-endian integer."""
    encoded = struct.pack(">q", i64)
    return encoded


def pack_double(dub):
    """Encode *dub* as an IEEE-754 double in big-endian byte order."""
    encoded = struct.pack(">d", dub)
    return encoded


def pack_string(string):
    """Encode the byte string *string* as an i32 big-endian length prefix
    followed by the raw bytes."""
    length = len(string)
    layout = "!i{}s".format(length)
    return struct.pack(layout, length, string)


def unpack_i8(buffer):
    """Decode a one-byte buffer as a signed 8-bit integer."""
    (value,) = struct.unpack(">b", buffer)
    return value


def unpack_i16(buffer):
    """Decode a two-byte big-endian buffer as a signed 16-bit integer."""
    (value,) = struct.unpack(">h", buffer)
    return value


def unpack_i32(buffer):
    """Decode a four-byte big-endian buffer as a signed 32-bit integer."""
    (value,) = struct.unpack(">i", buffer)
    return value


def unpack_i64(buffer):
    """Decode an eight-byte big-endian buffer as a signed 64-bit integer."""
    (value,) = struct.unpack(">q", buffer)
    return value


def unpack_double(buffer):
    """Decode an eight-byte big-endian buffer as an IEEE-754 double."""
    (value,) = struct.unpack(">d", buffer)
    return value


def write_message_begin(writer, name, ttype, seqid, strict=True):
    """Write a binary-protocol message header to *writer*.

    Strict layout:     i32(VERSION_1 | ttype), string(name), i32(seqid)
    Non-strict layout: string(name), i8(ttype), i32(seqid)

    *name* is UTF-8 encoded and written with an i32 length prefix.
    """
    if not strict:
        # Legacy header: the name comes first and the type is a single byte.
        writer.write(pack_string(name.encode('utf-8')))
        writer.write(pack_i8(ttype))
    else:
        # Versioned header: version and type share the leading i32 word.
        writer.write(pack_i32(VERSION_1 | ttype))
        writer.write(pack_string(name.encode('utf-8')))

    writer.write(pack_i32(seqid))


def write_field_begin(writer, ttype, fid):
    """Write a struct field header: i8 field type followed by i16 field id."""
    header = pack_i8(ttype) + pack_i16(fid)
    writer.write(header)


def write_field_stop(writer):
    """Write the STOP byte that terminates a struct's field list."""
    stop_marker = pack_i8(TType.STOP)
    writer.write(stop_marker)


def write_list_begin(writer, etype, size):
    """Write a list/set header: i8 element type followed by i32 element count."""
    writer.write(b"".join((pack_i8(etype), pack_i32(size))))


def write_map_begin(writer, ktype, vtype, size):
    """Write a map header: i8 key type, i8 value type, then i32 entry count."""
    writer.write(b"".join((pack_i8(ktype), pack_i8(vtype), pack_i32(size))))


def write_val(writer, ttype, val, spec=None):
    """Serialize *val* as Thrift type *ttype* and write it to *writer*.

    :param writer: object exposing ``write(bytes)``.
    :param ttype: ``TType`` constant describing *val*.
    :param val: the value to encode.  For STRING, ``bytes``/``bytearray``
        are written as-is and anything else (normally ``str``) is UTF-8
        encoded first.
    :param spec: type details for containers.  For LIST/SET: the element
        ``TType`` or an ``(element_ttype, element_spec)`` tuple.  For MAP: a
        pair ``(key_spec, value_spec)`` where each side is a ``TType`` int or
        a ``(ttype, spec)`` tuple.  STRUCT values describe their own layout
        through ``val.thrift_spec``.
    """
    if ttype == TType.BOOL:
        writer.write(pack_i8(1 if val else 0))

    elif ttype == TType.BYTE:
        writer.write(pack_i8(val))

    elif ttype == TType.I16:
        writer.write(pack_i16(val))

    elif ttype == TType.I32:
        writer.write(pack_i32(val))

    elif ttype == TType.I64:
        writer.write(pack_i64(val))

    elif ttype == TType.DOUBLE:
        writer.write(pack_double(val))

    elif ttype == TType.STRING:
        # Binary payloads are written verbatim; only text needs encoding.
        # bytearray has no .encode(), so it must be recognised here too.
        if not isinstance(val, (bytes, bytearray)):
            val = val.encode('utf-8')
        writer.write(pack_string(val))

    elif ttype == TType.SET or ttype == TType.LIST:
        if isinstance(spec, tuple):
            e_type, t_spec = spec[0], spec[1]
        else:
            e_type, t_spec = spec, None

        write_list_begin(writer, e_type, len(val))
        for e_val in val:
            write_val(writer, e_type, e_val, t_spec)

    elif ttype == TType.MAP:
        # Each side of the spec is either a bare TType or (ttype, spec).
        if isinstance(spec[0], int):
            k_type, k_spec = spec[0], None
        else:
            k_type, k_spec = spec[0]

        if isinstance(spec[1], int):
            v_type, v_spec = spec[1], None
        else:
            v_type, v_spec = spec[1]

        write_map_begin(writer, k_type, v_type, len(val))
        for k_val, v_val in val.items():
            write_val(writer, k_type, k_val, k_spec)
            write_val(writer, v_type, v_val, v_spec)

    elif ttype == TType.STRUCT:
        for fid, f_spec in val.thrift_spec.items():
            # Field specs are (type, name, required) or
            # (type, name, container_spec, required).
            if len(f_spec) == 3:
                f_type, f_name, _f_req = f_spec
                f_container_spec = None
            else:
                f_type, f_name, f_container_spec, _f_req = f_spec

            v = getattr(val, f_name)
            # Unset (None) fields are omitted from the wire entirely.
            if v is None:
                continue

            write_field_begin(writer, f_type, fid)
            write_val(writer, f_type, v, f_container_spec)
        write_field_stop(writer)

    # NOTE(review): any other ttype is silently not written, which would
    # desynchronize the stream from the peer; confirm callers never pass one.


@asyncio.coroutine
def read_message_begin(reader, strict=True):
    """Read a binary-protocol message header from *reader*.

    A negative leading i32 marks a versioned (strict) header that also
    carries the message type.  A non-negative one is the name length of a
    legacy (non-strict) header, which is rejected when *strict* is true.

    :returns: ``(name, message_type, seqid)``
    :raises TProtocolException: ``BAD_VERSION`` for an unknown version or,
        in strict mode, for a header without a version.
    """
    first_word = unpack_i32((yield from reader.readexactly(4)))

    if first_word >= 0:
        # Legacy header: first_word is the length of the method name.
        if strict:
            raise TProtocolException(type=TProtocolException.BAD_VERSION,
                                     message='No protocol version header')
        raw_name = yield from reader.readexactly(first_word)
        name = raw_name.decode('utf-8')
        message_type = unpack_i8((yield from reader.readexactly(1)))
    else:
        if (first_word & VERSION_MASK) != VERSION_1:
            raise TProtocolException(
                type=TProtocolException.BAD_VERSION,
                message='Bad version in read_message_begin: %d' % first_word)
        name_length = unpack_i32((yield from reader.readexactly(4)))
        raw_name = yield from reader.readexactly(name_length)
        name = raw_name.decode('utf-8')
        message_type = first_word & TYPE_MASK

    seqid = unpack_i32((yield from reader.readexactly(4)))
    return name, message_type, seqid


@asyncio.coroutine
def read_field_begin(reader):
    """Read a struct field header from *reader*.

    :returns: ``(field_type, field_id)``.  The STOP marker carries no id on
        the wire, so ``0`` is returned in its place.
    """
    field_type = unpack_i8((yield from reader.readexactly(1)))
    if field_type != TType.STOP:
        field_id = unpack_i16((yield from reader.readexactly(2)))
        return field_type, field_id
    return field_type, 0


@asyncio.coroutine
def read_list_begin(reader):
    """Read a list/set header: i8 element type followed by i32 size.

    :returns: ``(element_type, size)``
    """
    type_byte = yield from reader.readexactly(1)
    size_word = yield from reader.readexactly(4)
    return unpack_i8(type_byte), unpack_i32(size_word)


@asyncio.coroutine
def read_map_begin(reader):
    """Read a map header: i8 key type, i8 value type, then i32 size.

    :returns: ``(key_type, value_type, size)``
    """
    key_byte = yield from reader.readexactly(1)
    value_byte = yield from reader.readexactly(1)
    size_word = yield from reader.readexactly(4)
    return unpack_i8(key_byte), unpack_i8(value_byte), unpack_i32(size_word)


@asyncio.coroutine
def read_val(reader, ttype, spec=None, decode_response=True):
    """Read one value of Thrift type *ttype* from *reader*.

    :param reader: stream providing a ``readexactly(n)`` coroutine.
    :param ttype: ``TType`` constant of the value to read.
    :param spec: for LIST/SET, the element ``TType`` or an
        ``(element_ttype, element_spec)`` tuple.  For MAP, a pair
        ``(key_spec, value_spec)`` whose sides are a ``TType`` int or a
        ``(ttype, spec)`` tuple.  For STRUCT, the class to instantiate.
    :param decode_response: when true, STRING payloads are returned as
        ``str`` if they are valid UTF-8; otherwise raw ``bytes`` are returned.
    :returns: the decoded value.  A LIST/SET or MAP whose wire element types
        disagree with *spec* is consumed and returned empty (``[]``/``{}``).
        An unhandled *ttype* reads nothing and yields ``None``.
    """
    if ttype == TType.BOOL:
        raw = yield from reader.readexactly(1)
        return bool(unpack_i8(raw))
    if ttype == TType.BYTE:
        raw = yield from reader.readexactly(1)
        return unpack_i8(raw)
    if ttype == TType.I16:
        raw = yield from reader.readexactly(2)
        return unpack_i16(raw)
    if ttype == TType.I32:
        raw = yield from reader.readexactly(4)
        return unpack_i32(raw)
    if ttype == TType.I64:
        raw = yield from reader.readexactly(8)
        return unpack_i64(raw)
    if ttype == TType.DOUBLE:
        raw = yield from reader.readexactly(8)
        return unpack_double(raw)

    if ttype == TType.STRING:
        length = unpack_i32((yield from reader.readexactly(4)))
        payload = yield from reader.readexactly(length)
        # STRING and BINARY share one wire type, so decoding is best-effort:
        # payloads that are not valid UTF-8 are handed back as raw bytes.
        if not decode_response:
            return payload
        try:
            return payload.decode('utf-8')
        except UnicodeDecodeError:
            return payload

    if ttype == TType.SET or ttype == TType.LIST:
        if isinstance(spec, tuple):
            elem_type, elem_spec = spec[0], spec[1]
        else:
            elem_type, elem_spec = spec, None

        wire_type, count = yield from read_list_begin(reader)
        if wire_type != elem_type:
            # Element type disagrees with the IDL: consume and discard.
            for _ in range(count):
                yield from skip(reader, wire_type)
            return []

        items = []
        for _ in range(count):
            item = yield from read_val(reader, elem_type, elem_spec,
                                       decode_response)
            items.append(item)
        return items

    if ttype == TType.MAP:
        key_part = spec[0]
        if isinstance(key_part, int):
            key_type, key_spec = key_part, None
        else:
            key_type, key_spec = key_part

        value_part = spec[1]
        if isinstance(value_part, int):
            value_type, value_spec = value_part, None
        else:
            value_type, value_spec = value_part

        wire_key, wire_value, count = yield from read_map_begin(reader)
        if wire_key != key_type or wire_value != value_type:
            # Key/value types disagree with the IDL: consume and discard.
            for _ in range(count):
                yield from skip(reader, wire_key)
                yield from skip(reader, wire_value)
            return {}

        mapping = {}
        for _ in range(count):
            key = yield from read_val(reader, key_type, key_spec,
                                      decode_response)
            value = yield from read_val(reader, value_type, value_spec,
                                        decode_response)
            mapping[key] = value
        return mapping

    if ttype == TType.STRUCT:
        instance = spec()
        yield from read_struct(reader, instance, decode_response)
        return instance

    # Unhandled type: nothing was read.
    return None


@asyncio.coroutine
def read_struct(reader, obj, decode_response=True):
    """Populate *obj* in place from a serialized struct on *reader*.

    Fields are read until the STOP marker.  Fields whose id is absent from
    ``obj.thrift_spec``, or whose wire type differs from the spec'd type,
    are skipped.  All other fields are decoded with ``read_val`` and set as
    attributes on *obj*.  Returns ``None``.
    """
    while True:
        wire_type, field_id = yield from read_field_begin(reader)
        if wire_type == TType.STOP:
            return

        if field_id not in obj.thrift_spec:
            # Unknown field (e.g. from a newer IDL): discard it.
            yield from skip(reader, wire_type)
            continue

        field_spec = obj.thrift_spec[field_id]
        # Field specs are (type, name, required) or
        # (type, name, container_spec, required).
        if len(field_spec) == 3:
            spec_type, attr_name, _required = field_spec
            container_spec = None
        else:
            spec_type, attr_name, container_spec, _required = field_spec

        if wire_type == spec_type:
            value = yield from read_val(reader, wire_type, container_spec,
                                        decode_response)
            setattr(obj, attr_name, value)
        else:
            # Wire type disagrees with the spec'd type; discard the field.
            yield from skip(reader, wire_type)


@asyncio.coroutine
def skip(reader, ftype):
    """Consume and discard one value of Thrift type *ftype* from *reader*.

    Containers and structs are skipped recursively, using the element and
    field types found on the wire.  An unhandled *ftype* consumes nothing.
    """
    if ftype == TType.BOOL or ftype == TType.BYTE:
        yield from reader.readexactly(1)

    elif ftype == TType.I16:
        yield from reader.readexactly(2)

    elif ftype == TType.I32:
        yield from reader.readexactly(4)

    elif ftype == TType.I64 or ftype == TType.DOUBLE:
        yield from reader.readexactly(8)

    elif ftype == TType.STRING:
        # The length prefix must be awaited before it can be decoded;
        # readexactly() returns a coroutine, not bytes.
        data = yield from reader.readexactly(4)
        yield from reader.readexactly(unpack_i32(data))

    elif ftype == TType.SET or ftype == TType.LIST:
        v_type, sz = yield from read_list_begin(reader)
        for _ in range(sz):
            yield from skip(reader, v_type)

    elif ftype == TType.MAP:
        # read_map_begin is a coroutine and must be driven with yield from.
        k_type, v_type, sz = yield from read_map_begin(reader)
        for _ in range(sz):
            yield from skip(reader, k_type)
            yield from skip(reader, v_type)

    elif ftype == TType.STRUCT:
        while True:
            f_type, _fid = yield from read_field_begin(reader)
            if f_type == TType.STOP:
                break
            yield from skip(reader, f_type)


class TProtocol:
    """Base class for thrift protocols.

    Subclasses should implement some of the protocol methods.  Currently
    only :class:`TBinaryProtocol` is implemented.  Every method defined
    here is a stub that does nothing and returns ``None``.

    :param trans: transport the protocol reads from and writes to.
    :param strict_read: stored for subclasses; as used by
        :class:`TBinaryProtocol`, it requires a versioned header on read.
    :param strict_write: stored for subclasses; as used by
        :class:`TBinaryProtocol`, it emits a versioned header on write.
    :param decode_response: stored for subclasses; as used by
        :class:`TBinaryProtocol`, it UTF-8 decodes STRING payloads on read.
    """

    def __init__(self, trans, strict_read=True, strict_write=True,
                 decode_response=True):
        self.trans = trans
        self.strict_read = strict_read
        self.strict_write = strict_write
        self.decode_response = decode_response

    def skip(self, ttype):
        pass

    @asyncio.coroutine
    def read_message_begin(self):
        pass

    @asyncio.coroutine
    def read_message_end(self):
        pass

    def write_message_begin(self, name, ttype, seqid):
        pass

    def write_message_end(self):
        pass

    @asyncio.coroutine
    def read_struct(self, obj):
        pass

    def write_struct(self, obj):
        pass
class TBinaryProtocol(TProtocol):
    """Binary implementation of the Thrift protocol driver."""

    @asyncio.coroutine
    def skip(self, ttype):
        """Consume and discard one value of type *ttype* from the transport.

        This is a coroutine and must be driven with ``yield from``/``await``.
        """
        # The module-level skip() is a coroutine; calling it without
        # ``yield from`` would create it but never consume any bytes.
        yield from skip(self.trans, ttype)

    @asyncio.coroutine
    def read_message_begin(self):
        """Read a message header; returns ``(name, type, seqid)``."""
        api, ttype, seqid = yield from read_message_begin(
            self.trans, strict=self.strict_read)
        return api, ttype, seqid

    def write_message_begin(self, name, ttype, seqid):
        """Write a message header, versioned when ``strict_write`` is set."""
        write_message_begin(self.trans, name, ttype, seqid,
                            strict=self.strict_write)

    @asyncio.coroutine
    def read_struct(self, obj):
        """Populate *obj* from the transport; returns what read_struct returns."""
        data = yield from read_struct(self.trans, obj, self.decode_response)
        return data

    def write_struct(self, obj):
        """Serialize the struct instance *obj* to the transport."""
        write_val(self.trans, TType.STRUCT, obj)