Source code for lexrpc.client

"""XRPC client implementation.

TODO:

* asyncio support for subscription websockets
"""
import copy
from io import BytesIO, IOBase
import json
import logging
from urllib.parse import urljoin

import dag_cbor
import requests
import simple_websocket

from .base import Base, NSID_SEGMENT_RE

logger = logging.getLogger(__name__)

# XRPC server used when Client() is constructed without an address.
DEFAULT_PDS = 'https://bsky.social/'

# Base headers for every request; Client.call() merges the constructor's and
# the per-call headers on top of these.
DEFAULT_HEADERS = {'User-Agent': 'lexrpc (https://lexrpc.readthedocs.io/)'}

# Session methods: their responses are stored as Client.session.
LOGIN_NSID = 'com.atproto.server.createSession'
REFRESH_NSID = 'com.atproto.server.refreshSession'

# Values of the ``error`` field in a failed response that make Client.call()
# refresh the session and retry the request.
TOKEN_ERRORS = ('AccountNotFound', 'AuthenticationRequired', 'ExpiredToken',
                'InvalidToken', 'TokenRequired')


class _NsidClient():
    """Internal helper class to implement dynamic attribute-based method calls.

    eg ``client.com.example.my_method(...)``
    """
    client = None
    nsid = None

    def __init__(self, client, nsid):
        assert client and nsid
        self.client = client
        self.nsid = nsid

    def __getattr__(self, attr):
        segment = attr.replace('_', '-')
        if NSID_SEGMENT_RE.match(segment):
            return _NsidClient(self.client, f'{self.nsid}.{segment}')

        return getattr(super(), attr)

    def __call__(self, *args, **kwargs):
        return self.client.call(self.nsid, *args, **kwargs)


class Client(Base):
    """XRPC client.

    Calling ``com.atproto.server.createSession`` will store the returned session
    and include its access token in subsequent requests. If a request fails with
    one of :data:`TOKEN_ERRORS` (eg ``ExpiredToken``) and we have a refresh
    token stored, the access token will be refreshed with
    ``com.atproto.server.refreshSession`` and then the original request will be
    retried once.

    Attributes:
      address (str): base URL of XRPC server, eg ``https://bsky.social/``
      session (dict): ``createSession`` response with ``accessJwt``,
        ``refreshJwt``, ``handle``, and ``did``
      headers (dict): HTTP headers to include in every request
      session_callback (callable or None): called with the new session dict
        whenever ``createSession`` or ``refreshSession`` returns
    """

    def __init__(self, address=DEFAULT_PDS, access_token=None,
                 refresh_token=None, headers=None, session_callback=None,
                 **kwargs):
        """Constructor.

        Args:
          address (str): base URL of XRPC server, eg ``https://bsky.social/``
          access_token (str): optional, will be sent in ``Authorization`` header
          refresh_token (str): optional; used to refresh access token
          headers (dict): optional, HTTP headers to include in every request
          session_callback (callable, dict => None): called when a new session
            is created with new access and refresh tokens. This callable is
            passed one positional argument, the dict JSON output from
            ``com.atproto.server.createSession`` or
            ``com.atproto.server.refreshSession``.
          kwargs: passed through to :class:`Base`

        Raises:
          jsonschema.SchemaError: if any schema is invalid
        """
        super().__init__(**kwargs)

        logger.debug(f'Using server at {address}')
        assert address.startswith(('http://', 'https://')), \
            f"{address} doesn't start with http:// or https://"
        self.address = address
        self.headers = headers or {}

        self.session = {}
        if access_token or refresh_token:
            self.session.update({
                'accessJwt': access_token,
                'refreshJwt': refresh_token,
            })
        self.session_callback = session_callback

    def __getattr__(self, attr):
        """Starts an attribute-based method call, eg ``client.com.example...``.

        Only invoked for attributes that normal lookup didn't find.

        Raises:
          AttributeError: if ``attr`` is a dunder name or doesn't match
            ``NSID_SEGMENT_RE``
        """
        # Dunder lookups come from Python protocols (copy, pickle, inspect),
        # never from NSIDs.
        if not (attr.startswith('__') and attr.endswith('__')) \
                and NSID_SEGMENT_RE.match(attr):
            return _NsidClient(self, attr)

        # Raise directly instead of ``getattr(super(), attr)``, which leaked the
        # super proxy's own attributes (eg ``__self__``) and reported errors as
        # coming from a 'super' object.
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{attr}'")

    def call(self, nsid, input=None, headers=None, **params):
        """Makes a remote XRPC method call.

        Args:
          nsid (str): method NSID
          input (dict, bytes, or file-like object): input body, optional for
            subscriptions. Objects with a ``read()`` method are read fully
            into memory before sending.
          headers (dict): HTTP headers to include in this request. Overrides
            any headers passed to the constructor. The ``Authorization`` header
            is set from the stored session when one exists.
          params: optional method parameters

        Returns:
          dict or generator iterator: for queries and procedures, decoded JSON
          object, or None if the method has no output. For subscriptions,
          generator of messages from server.

        Raises:
          NotImplementedError: if the given NSID is not found in any of the
            loaded lexicons
          jsonschema.ValidationError: if the parameters, input, or returned
            output don't validate against the method's schemas
          requests.RequestException: if the connection or HTTP request to the
            remote server failed
        """
        return self._call(nsid, input=input, headers=headers or {},
                          params=params, allow_refresh=True)

    def _call(self, nsid, *, input, headers, params, allow_refresh):
        """Implements :meth:`call`.

        Args:
          nsid (str): method NSID
          input: request body; see :meth:`call`
          headers (dict): per-request HTTP headers, never None
          params (dict): method parameters
          allow_refresh (bool): whether a token error may trigger a session
            refresh followed by a retry. The retry passes False so that a
            persistent token error can't recurse forever.
        """
        def loggable(val):
            return f'{len(val)} bytes' if isinstance(val, bytes) else val

        logger.debug(f'{nsid}: {params} {loggable(input)}')

        # strip null params, validate params and input, then encode params
        params = {k: v for k, v in params.items() if v is not None}
        params = self._maybe_validate(nsid, 'parameters', params)
        params_str = self.encode_params(params)

        method_type = self._get_def(nsid)['type']
        # NOTE(review): this validates 'input' only for subscriptions, which
        # don't send a body, and never for queries/procedures, despite the
        # comment above and the Raises section of call(). Possibly an inverted
        # condition; confirm against Base._maybe_validate before changing.
        if method_type == 'subscription':
            input = self._maybe_validate(nsid, 'input', input)

        req_headers = {
            **DEFAULT_HEADERS,
            'Content-Type': 'application/json',
            **self.headers,
            **headers,
        }
        # copy taken before auth is added so the token never reaches the logs
        log_headers = copy.copy(req_headers)

        # auth: refreshSession authenticates with the refresh token, everything
        # else with the access token
        token = (self.session.get('refreshJwt') if nsid == REFRESH_NSID
                 else self.session.get('accessJwt'))
        if token:
            req_headers['Authorization'] = f'Bearer {token}'
            log_headers['Authorization'] = '...'

        # run method
        url = urljoin(self.address, f'/xrpc/{nsid}')
        if params_str:
            url += f'?{params_str}'

        # event stream
        if method_type == 'subscription':
            return self._subscribe(url)

        # query or procedure
        fn = requests.get if method_type == 'query' else requests.post

        # buffer binary inputs in memory. ideally we'd stream instead, but if we
        # have to refresh our token below, we need to seek the stream back to the
        # beginning, and not all streams are seekable, eg requests.Request.raw
        if isinstance(input, IOBase) or hasattr(input, 'read'):
            input = input.read()

        logger.debug(f'Running requests.{fn.__name__} {url} {loggable(input)} {params_str} {log_headers}')
        resp = fn(
            url,
            json=input if input and isinstance(input, dict) else None,
            data=input if input and not isinstance(input, dict) else None,
            headers=req_headers,
        )

        output = None
        content_type = resp.headers.get('Content-Type', '').split(';')[0]
        if content_type == 'application/json' and resp.content:
            output = resp.json()

        if not resp.ok:
            logger.debug(f'Got: {resp.text}')

        if nsid in (LOGIN_NSID, REFRESH_NSID):
            # auth: store the new session, or drop it if login/refresh failed
            if resp.ok:
                logger.info(f'Logged in as {output.get("did")}, storing session')
            else:
                logger.info('Login failed, nulling out session')
                output = {}
            self.session = output
            if self.session_callback:
                self.session_callback(output)

        elif (not resp.ok and allow_refresh
              and self.session.get('refreshJwt')
              and isinstance(output, dict)
              and output.get('error') in TOKEN_ERRORS):
            # token expired: refresh it, then retry exactly once. Pass the
            # caller's headers, not req_headers, so the stale Authorization
            # header isn't carried into the retry.
            self.call(REFRESH_NSID)
            return self._call(nsid, input=input, headers=headers,
                              params=params, allow_refresh=False)

        resp.raise_for_status()

        output = self._maybe_validate(nsid, 'output', output)
        return output

    def _subscribe(self, url):
        """Connects to a subscription websocket, yields the returned messages.

        Each websocket message is decoded as two concatenated DAG-CBOR objects,
        a header and a payload, and yielded as a ``(header, payload)`` tuple.
        Iteration ends when the server closes the connection.
        """
        ws = simple_websocket.Client(url)
        try:
            while True:
                buf = BytesIO(ws.receive())
                header = dag_cbor.decode(buf, allow_concat=True)
                payload = dag_cbor.decode(buf)
                yield (header, payload)
        except simple_websocket.ConnectionClosed as cc:
            logger.debug(cc)
        finally:
            # release the socket when the consumer stops iterating early
            # (generator.close(), garbage collection) or decoding fails.
            # close() may raise ConnectionClosed if the connection is already
            # closed, which is expected here.
            try:
                ws.close()
            except simple_websocket.ConnectionClosed:
                pass