"""Request routing helpers, URL/event-stream regexes and small containers."""

__all__ = (
    "ENDPOINT_METHOD",
    "STR_FMT_RE",
    "URL_TYPES",
    "IP_RE",
    "MSG_RE_BYTES",
    "MSG_RE_TEXT",
    "URL",
    "LRU",
    "LRUItem",
    "get_url_args",
    "get_data_fields",
    "route",
)

from asyncio import sleep
from collections import deque
from functools import wraps
from inspect import Parameter, signature
from logging import getLogger
from re import compile as re_compile
from time import time
from typing import Any, Optional, TYPE_CHECKING

from httpx._urls import URL as _URL
from orjson import loads
from pydantic import BaseModel, Field
from yarl import URL as UR

from .models import Entity

if TYPE_CHECKING:
    from .http import SubRouter
else:
    SubRouter = "SubRouter"

logger = getLogger(__name__)

# Matches names starting with get/set/create/delete. The lookahead +
# backreference pair emulates an atomic group.
ENDPOINT_METHOD = re_compile(r"^(?=((?:get|set|create|delete)\w+))\1")

# Matches str.format-style placeholders: ``{name}`` or ``{name:type}``.
# Group 2 is the name, group 3 the optional type (None when absent). ``}`` is
# excluded from the name so ``{a}/{b}`` yields two placeholders rather than a
# single greedy ``a}/{b`` match.
STR_FMT_RE = re_compile(r"(?=(\{([^:}]+)(?::([^}]+))?\}))\1")

# Converters selectable through the ``type`` part of ``{name:type}``.
URL_TYPES = {"str": str, "int": int}

# Finds a dotted hostname ending in an alphabetic label, or a dotted-quad
# IPv4 address; group 1 is the host.
# NOTE(review): the optional trailing digit run has no ``:`` separator, so a
# ``host:port`` suffix is not captured — confirm this is intended.
IP_RE = re_compile(
    r"(?=(?:(?<=[^0-9])|^)((?:[a-z0-9]+\.)*[a-z0-9]+\.[a-z]{2,}(?:[0-9]{2,5})?|(?:[0-9]{,3}\.){3}[0-9]{,3}))\1"
)

# Match one event-stream message: either the literal ``: hi`` line or an
# ``id: ...`` / ``data: ...`` record.
# NOTE(review): the named-group names were missing (``(?P^...`` is rejected by
# ``re.compile``, so the module could not be imported). They are reconstructed
# as ``hello``/``id``/``data``; group numbering is unchanged — confirm the names
# against callers that use ``match.group(<name>)``.
MSG_RE_BYTES = re_compile(
    rb"(?=((?P<hello>^: hi\n\n$)|^id:\s(?P<id>[0-9]+:\d*?)\ndata:(?P<data>[^$]+)\n\n))\1"
)
MSG_RE_TEXT = re_compile(
    r"(?=((?P<hello>^: hi\n\n$)|^id:\s(?P<id>[0-9]+:\d*?)\ndata:(?P<data>[^$]+)\n\n))\1"
)


def get_url_args(url):
    """Return the placeholders of an endpoint template and their converters.

    ``"resource/{id}"`` gives ``{"id": str}`` and ``"{n:int}"`` gives
    ``{"n": int}``. A placeholder without a type, or with a format spec that
    is not a ``URL_TYPES`` key, uses ``str``.

    :param url: endpoint template containing ``str.format`` placeholders.
    :returns: mapping of placeholder name to converter callable.
    """
    # Group 3 is None when no ``:type`` is present, which ``.get`` maps to str.
    return {
        m.group(2): URL_TYPES.get(m.group(3), str)
        for m in STR_FMT_RE.finditer(url)
    }


def _render_endpoint(endpoint, values):
    """Fill the placeholders of ``endpoint`` from ``values``.

    ``{name:type}`` placeholders whose type is a ``URL_TYPES`` key are reduced
    to ``{name}`` first: the type only selects the converter, and
    ``str.format`` would reject it as a format spec. Other placeholders are
    left untouched.
    """
    template = STR_FMT_RE.sub(
        lambda m: f"{{{m.group(2)}}}" if m.group(3) in URL_TYPES else m.group(0),
        endpoint,
    )
    return template.format(**values)


def _param_converter(fn, param_name):
    """Return the callable used to coerce a value for ``param_name`` of ``fn``.

    Unannotated parameters use ``str``; class annotations are used directly;
    an annotation whose ``_name`` is ``"Optional"`` uses its first argument,
    or ``str`` when that argument is itself a typing generic. Any other
    annotation object is returned as-is.
    """
    annotations = fn.__annotations__
    if param_name not in annotations:
        return str
    anno = annotations[param_name]
    if isinstance(anno, type):
        return anno
    # ``getattr``: not every annotation object has ``_name`` (e.g. a string).
    if getattr(anno, "_name", None) == "Optional":
        inner = anno.__args__[0]
        return str if hasattr(inner, "_name") else inner
    return anno


def get_data_fields(fn, data, args, kwargs) -> dict[str, Any]:
    """Fill ``data`` with converted values for the parameters of ``fn``.

    ``self`` is skipped. Positional-only parameters consume ``args`` in order
    (each value goes through ``str`` and then the parameter's converter); once
    ``args`` is exhausted they are handled like the other parameters, which
    take their value from ``kwargs`` (popping it) or, failing that, from a
    truthy default.

    :param fn: function whose signature describes the fields.
    :param data: dict to fill; mutated in place and returned.
    :param args: positional values for positional-only parameters.
    :param kwargs: keyword values; consumed entries are popped.
    :returns: ``data``.
    :raises TypeError: if a parameter without a default has no value in
        ``args``, ``kwargs`` or ``data``.
    """
    for param_name, param in signature(fn).parameters.items():
        if param_name == "self":
            continue
        convert = _param_converter(fn, param_name)

        if param.kind is Parameter.POSITIONAL_ONLY and args:
            data[param_name] = convert(str(args[0]))
            args = args[1:]
            continue

        required = param.default is param.empty
        if required and param_name not in kwargs and param_name not in data:
            raise TypeError(
                f"Missing required argument {param_name} for {fn.__name__}"
            )
        if param_name in kwargs:
            data[param_name] = convert(kwargs.pop(param_name))
        elif not required and param.default:
            # Only truthy defaults are sent; None/False/0/"" defaults are
            # omitted. A required value already present in ``data`` is kept.
            data[param_name] = convert(param.default)
    return data


def _resolve_url_args(endpoint, kwargs, params, data):
    """Pop and convert the value of every placeholder in ``endpoint``.

    Each placeholder is popped from ``kwargs``, ``params`` and ``data`` (in
    that order) so it is not also sent as a query/form field; a non-None value
    from a later dict overrides an earlier one.

    :returns: mapping of placeholder name to converted value.
    :raises ValueError: if a placeholder has no non-None value in any dict.
    """
    url_args = get_url_args(endpoint)
    for name, convert in url_args.items():
        value = None
        for source in (kwargs, params, data):
            found = source.pop(name, None)
            if found is not None:
                value = found
        if value is None:
            raise ValueError(f"Missing required argument {name}")
        url_args[name] = convert(value)
    return url_args


def _normalize_name(obj):
    """Snake-case ``obj["metadata"]["name"]`` in place (if present); return ``obj``."""
    if "metadata" in obj and "name" in obj["metadata"]:
        obj["metadata"]["name"] = (
            obj["metadata"]["name"].replace(" ", "_").replace("-", "_").lower()
        )
    return obj


def _build_return(fn, ret_data):
    """Shape a decoded response according to the ``return`` annotation of ``fn``.

    Without a ``return`` annotation the decoded body is returned unchanged.
    Otherwise a top-level ``"data"`` entry is unwrapped first; a list becomes a
    list of model instances, while a single object is validated against the
    annotation and returned as a (name-normalised) dict.
    """
    annotations = fn.__annotations__
    if "return" not in annotations:
        return ret_data
    ret_type = annotations["return"]
    if "data" in ret_data:
        ret_data = ret_data["data"]
    if isinstance(ret_data, list) or isinstance(ret_type, list):
        # ``list[Model]`` -> ``Model``; a bare model class is used as-is.
        cls = ret_type.__args__[0] if hasattr(ret_type, "__args__") else ret_type
        return [cls(**_normalize_name(item)) for item in ret_data]
    _normalize_name(ret_data)
    # NOTE(review): the model is built only to validate the payload and the
    # dict itself is returned — confirm callers expect a dict here.
    ret_type(**ret_data)
    return ret_data


def route(method, endpoint) -> Any:
    """Decorator factory turning a stub method into a bridge HTTP request.

    The decorated function's body is never executed; only its signature and
    annotations are used:

    * its parameters are filled into request data by :func:`get_data_fields`;
    * ``{name}`` placeholders in ``endpoint`` are filled from the keyword
      arguments, ``params`` or ``data`` (see :func:`get_url_args`);
    * positional :class:`Entity` arguments are serialised (excluding ``id``
      and unset/``None`` fields) and merged into the JSON body;
    * its ``return`` annotation, if any, shapes the decoded response.

    :param method: HTTP method handed to ``self._client``.
    :param endpoint: endpoint template appended to the bridge URL and
        ``base_uri``.
    :returns: a decorator producing an ``async`` method. With an
        ``Accept: text/event-stream`` header that method returns the client's
        ``stream(...)`` result; ``{}`` is returned for an immediate 404 or a
        body that cannot be decoded as JSON.
    """

    def wrapped(fn):
        @wraps(fn)
        async def sub_wrap(
            self: SubRouter,
            *args,
            base_uri=None,
            content: Optional[bytes] = None,
            data: Optional[dict[str, str]] = None,
            params: Optional[dict[str, str]] = None,
            **kwargs,
        ):
            params = params or {}
            data = data or {}
            if "headers" in kwargs:
                headers = self._headers | kwargs.pop("headers")
            else:
                headers = self._headers

            # Entities form the JSON body. The remaining positional values keep
            # their call order because positional-only parameters consume them
            # in order (a set would reorder them and reject unhashable values).
            ents = [arg for arg in args if isinstance(arg, Entity)]
            args = tuple(arg for arg in args if not isinstance(arg, Entity))
            json = {}
            for ent in ents:
                json |= loads(
                    ent.json(exclude_unset=True, exclude_none=True, exclude={"id"})
                )

            data = get_data_fields(fn, data | kwargs, args, kwargs)
            url_args = _resolve_url_args(endpoint, kwargs, params, data)

            match_bridge = IP_RE.search(self._bridge_host)
            if not match_bridge:
                raise ValueError(f"Invalid bridge ip {self._bridge_host}")
            # ``base_uri`` defaults to None, which must not become a literal
            # "None" path segment.
            base_path = "" if base_uri is None else f"{base_uri}".lstrip("/")
            url_base = f"https://{match_bridge.group(1)}/" + base_path
            if url_args:
                new_endpoint = URL(url_base) / _render_endpoint(endpoint, url_args)
            else:
                new_endpoint = URL(url_base) / endpoint

            if headers and headers.get("Accept", "") == "text/event-stream":
                return self._client.stream(
                    method,
                    new_endpoint,
                    content=content,
                    data=data,
                    params=params,
                    headers=headers,
                )

            # Only forward the request parts that are non-empty.
            request_kwargs = {}
            if data:
                request_kwargs["data"] = data
            if content:
                request_kwargs["content"] = content
            if params:
                request_kwargs["params"] = params
            if json:
                request_kwargs["json"] = json

            resp = await self._client.request(
                method, new_endpoint, headers=headers, **request_kwargs
            )
            if resp.status_code in (200, 207):
                ret_data = resp.json()
            elif resp.status_code == 404:
                return {}
            else:
                # Rate limited: wait before each retry rather than after it.
                while resp.status_code == 429:
                    await sleep(1)
                    resp = await self._client.request(
                        method, new_endpoint, headers=headers, **request_kwargs
                    )
                try:
                    ret_data = resp.json()
                except Exception:
                    # Deliberate best-effort: log (with traceback) and return
                    # an empty result. Two placeholders, two arguments.
                    logger.exception(
                        "Failed to parse response as json <%s> %s",
                        resp,
                        resp.content,
                    )
                    return {}
            return _build_return(fn, ret_data)

        return sub_wrap

    return wrapped


class LRUItem(BaseModel):
    """Wrapper pairing a stored value with the time it was created."""

    # Whole-second Unix timestamp taken when the item is created.
    access_time: int = Field(default_factory=lambda: int(time()))
    # The wrapped value; defaults to a shared placeholder object.
    value: Any = object()

    # NOTE(review): not a Python protocol hook; nothing in this module calls it.
    def __id__(self):
        return id(self.value)

    def __hash__(self):
        # Hash like the wrapped value so entries can live in a set.
        return hash(self.value)


class LRU(set):
    """A set of :class:`LRUItem` wrappers bounded to ``maxsize`` entries.

    Values are wrapped in :class:`LRUItem` on insertion. When room is needed,
    the entries with the oldest ``access_time`` are evicted first; because
    ``access_time`` has one-second resolution, the choice among entries
    created in the same second is arbitrary.
    """

    def __init__(self, maxsize, /, *items):
        super().__init__()
        self.maxsize = maxsize
        # NOTE(review): not used by this class; kept in case callers read it.
        self.items = deque(maxlen=maxsize)
        # ``add`` does the wrapping, so the raw values are passed through.
        for item in items[:maxsize]:
            self.add(item)

    def _evict(self, incoming):
        """Remove the oldest entries so ``incoming`` more fit within ``maxsize``."""
        excess = len(self) + incoming - self.maxsize
        if excess > 0:
            oldest_first = sorted(self, key=lambda entry: entry.access_time)
            self.difference_update(oldest_first[:excess])

    def add(self, item):
        """Wrap ``item`` in an :class:`LRUItem` and insert it, evicting if full."""
        self._evict(1)
        super().add(LRUItem(value=item))

    def pop(self):
        """Remove an arbitrary entry and return its unwrapped value."""
        return super().pop().value

    def remove(self, item):
        """Remove every entry wrapping ``item``.

        :raises KeyError: if no entry wraps ``item``.
        """
        matches = [entry for entry in self if entry.value == item]
        if not matches:
            raise KeyError(item)
        self.difference_update(matches)

    def extend(self, *items):
        """Wrap and insert ``items``, evicting the oldest entries to make room.

        Like the constructor, at most the first ``maxsize`` items are kept.
        """
        items = items[: self.maxsize]
        self._evict(len(items))
        self.update(LRUItem(value=item) for item in items)


class URL(_URL):
    """``httpx`` URL subclass whose ``/`` and ``//`` operators append a path."""

    def __truediv__(self, other):
        """Return a new :class:`URL` with ``other`` appended as a path segment.

        Leading slashes are stripped from ``other`` and the join is delegated
        to yarl's ``/`` operator.
        """
        segment = str(other).lstrip("/")
        try:
            return URL(str(UR(f"{self}") / segment))
        except NameError:
            # NOTE(review): ``UR`` is imported unconditionally at module level,
            # so this fallback looks unreachable — confirm before relying on it.
            return URL(f"{self}{segment}")

    # ``//`` behaves exactly like ``/``.
    __floordiv__ = __truediv__