"""Source code for aiothrift.connection."""

import asyncio
import functools

import async_timeout
from thriftpy.thrift import TMessageType

from .protocol import TBinaryProtocol
from .util import args2kwargs
from .errors import ConnectionClosedError, ThriftAppError
from .log import logger


@asyncio.coroutine
def create_connection(service, address=('127.0.0.1', 6000), *,
                      protocol_cls=TBinaryProtocol, timeout=None, loop=None):
    """Create a thrift connection.

    This function is a :ref:`coroutine <coroutine>`.

    Open a connection to the thrift server given by the ``address`` argument.

    :param service: a thrift service object
    :param address: a (host, port) tuple
    :param protocol_cls: protocol type, default is :class:`TBinaryProtocol`
    :param timeout: if specified, an rpc call that takes longer than
        `timeout` raises `asyncio.TimeoutError`
    :param loop: :class:`Eventloop <asyncio.AbstractEventLoop>` instance;
        if not specified, the default loop is used.
    :return: newly created :class:`ThriftConnection` instance.
    """
    host, port = address
    # Forward ``loop`` only when the caller gave one: the parameter of
    # asyncio.open_connection is deprecated since Python 3.8 and removed in
    # 3.10, where passing even ``loop=None`` raises TypeError.
    open_kwargs = {} if loop is None else {'loop': loop}
    reader, writer = yield from asyncio.open_connection(
        host, port, **open_kwargs)
    try:
        iprotocol = protocol_cls(reader)
        oprotocol = protocol_cls(writer)
        return ThriftConnection(service, iprot=iprotocol, oprot=oprotocol,
                                address=address, loop=loop, timeout=timeout)
    except Exception:
        # Setup failed after the socket was opened: close it instead of
        # leaking it, then let the caller see the original error.
        writer.close()
        raise
class ThriftConnection:
    """Thrift Connection.

    Holds an input/output protocol pair and exposes every api of
    ``service`` as a method (see :meth:`_init_rpc_apis`) that delegates to
    :meth:`execute`.
    """

    def __init__(self, service, *, iprot, oprot, address, loop=None,
                 timeout=None):
        self.service = service
        self._reader = iprot.trans
        self._writer = oprot.trans
        self._loop = loop
        self.timeout = timeout
        self.address = address
        self.closed = False
        self._oprot = oprot
        self._iprot = iprot
        self._seqid = 0
        self._init_rpc_apis()

    def _init_rpc_apis(self):
        """Create a method on the connection for each api of the service.

        Every api name in ``self.service.thrift_services`` becomes an
        attribute bound to ``functools.partial(self.execute, api)``. If an api
        name conflicts with an existing attribute of the connection object,
        it is skipped with a warning; such an api can still be called through
        the :meth:`execute` method.
        """
        for api in self.service.thrift_services:
            if not hasattr(self, api):
                setattr(self, api, functools.partial(self.execute, api))
            else:
                # Logger.warn is a deprecated alias of Logger.warning. The
                # message points at execute(), the method that actually exists.
                logger.warning(
                    'api name {0} is conflicted with connection attribute '
                    '{0}, while you can still call this api by '
                    '`execute("{0}")`'.format(api))

    def __repr__(self):
        return '<ThriftConnection to {}>'.format(self.address)

    @asyncio.coroutine
    def execute(self, api, *args, **kwargs):
        """Execute an rpc call by api name.

        This function is a :ref:`coroutine <coroutine>`.

        :param api: api name defined in thrift file
        :param args: positional arguments passed to api function
        :param kwargs: keyword arguments passed to api function
        :return: result of this rpc call (``None`` for oneway apis)
        :raises: :class:`~asyncio.TimeoutError` if this task has exceeded
            the `timeout`; the connection is closed first.
        :raises: :class:`ThriftAppError` if thrift response is an exception
            defined in thrift.
        :raises: :class:`ConnectionClosedError`: if the connection is already
            closed, or the server has closed this connection.
        """
        if self.closed:
            raise ConnectionClosedError('Connection closed')

        try:
            with async_timeout.timeout(self.timeout):
                args_cls = getattr(self.service, api + '_args')
                # Positional arguments are mapped to field names using the
                # thrift spec; they take precedence over keyword arguments.
                kwargs.update(args2kwargs(args_cls.thrift_spec, *args))
                result_cls = getattr(self.service, api + '_result')

                # Thrift sequence ids are i32 on the wire: wrap around
                # instead of overflowing after 2**31 - 1 calls.
                if self._seqid >= 0x7fffffff:
                    self._seqid = 0
                self._seqid += 1

                self._oprot.write_message_begin(api, TMessageType.CALL,
                                                self._seqid)
                call_args = args_cls()
                for name, value in kwargs.items():
                    setattr(call_args, name, value)
                call_args.write(self._oprot)
                self._oprot.write_message_end()
                yield from self._oprot.trans.drain()

                # Oneway calls get no response from the server.
                if not getattr(result_cls, 'oneway', False):
                    result = yield from self._recv(api)
                    return result
        except asyncio.TimeoutError:
            # A late response would desynchronise the stream; drop it.
            self.close()
            raise
        except ConnectionError as e:
            self.close()
            logger.debug('connection error {}'.format(str(e)))
            raise ConnectionClosedError(
                'the server has closed this connection') from e
        except asyncio.IncompleteReadError as e:
            self.close()
            raise ConnectionClosedError('Server connection has closed') from e

    @asyncio.coroutine
    def _recv(self, api):
        """Receive and decode the response to ``api`` from the thrift server.

        This function is a :ref:`coroutine <coroutine>`.

        :return: the ``success`` value of the result struct, or ``None`` for
            a void api.
        :raises: :class:`ThriftAppError` on an out-of-sequence response (the
            connection is closed), or when the server sent an EXCEPTION
            message, or ``MISSING_RESULT`` when a non-void api returned
            nothing.
        :raises: any exception declared in the api's ``throws`` clause that
            the server set on the result struct.
        """
        fname, mtype, rseqid = yield from self._iprot.read_message_begin()
        if rseqid != self._seqid:
            # transport should be closed if bad seq happened
            self.close()
            raise ThriftAppError(ThriftAppError.BAD_SEQUENCE_ID,
                                 fname + ' failed: out of sequence response')

        if mtype == TMessageType.EXCEPTION:
            x = ThriftAppError()
            yield from self._iprot.read_struct(x)
            yield from self._iprot.read_message_end()
            raise x

        result = getattr(self.service, api + '_result')()
        yield from self._iprot.read_struct(result)
        yield from self._iprot.read_message_end()

        if hasattr(result, 'success') and result.success is not None:
            return result.success

        # void api without throws
        if len(result.thrift_spec) == 0:
            return

        # check throws
        for k, v in result.__dict__.items():
            if k != 'success' and v:
                raise v

        if hasattr(result, 'success'):
            raise ThriftAppError(ThriftAppError.MISSING_RESULT)

    def close(self):
        """Close the output transport and mark the connection as closed."""
        self._writer.close()
        self.closed = True