""" This module contains the Connector class. Every data fetching action should begin with instantiating this Connector class. """ import math import sys from asyncio import as_completed from typing import Any, Awaitable, Dict, Optional, Set, Tuple, Union from warnings import warn import pandas as pd from aiohttp import ClientSession from aiohttp.client_reqrep import ClientResponse from jinja2 import Environment, StrictUndefined, UndefinedError from jsonpath_ng import parse as jparse from .config_manager import initialize_path from .errors import InvalidParameterError, RequestError, UniversalParameterOverridden from .implicit_database import ImplicitDatabase, ImplicitTable from .info import info from .ref import Ref from .schema import ( FieldDef, FieldDefUnion, OffsetPaginationDef, PagePaginationDef, SeekPaginationDef, TokenLocation, TokenPaginationDef, ) from .throttler import OrderedThrottler, ThrottleSession [docs]class Connector: # pylint: disable=too-many-instance-attributes """This is the main class of the connector component. Initialize Connector class as the example code. Parameters ---------- config_path The path to the config. It can be hosted, e.g. "yelp", or from local filesystem, e.g. "./yelp" _auth: Optional[Dict[str, Any]] = None The parameters for authentication, e.g. OAuth2 _concurrency: int = 5 The concurrency setting. By default it is 1 reqs/sec. update: bool = True Force update the config file even if the local version exists. **kwargs Parameters that shared by different queries. Example ------- >>> from dataprep.connector import Connector >>> dc = Connector("yelp", _auth={"access_token": access_token}) """ _impdb: ImplicitDatabase # Varibles that used across different queries, can be overriden by query _vars: Dict[str, Any] _auth: Dict[str, Any] # storage for authorization _storage: Dict[str, Any] _concurrency: int _update: bool _jenv: Environment def __init__( self, config_path: str, *, update: bool = False, _auth: Optional[Dict[str, Any]] = None, _concurrency: int = 1, **kwargs: Any, ) -> None: path = initialize_path(config_path, update) self._impdb = ImplicitDatabase(path) self._vars = kwargs self._auth = _auth or {} self._storage = {} self._concurrency = _concurrency self._update = update self._jenv = Environment(undefined=StrictUndefined) self._throttler = OrderedThrottler(_concurrency) [docs] async def query( # pylint: disable=too-many-locals self, table: str, *, _q: Optional[str] = None, _auth: Optional[Dict[str, Any]] = None, _count: Optional[int] = None, **where: Any, ) -> Union[Awaitable[pd.DataFrame], pd.DataFrame]: """ Query the API to get a table. Parameters ---------- table The table name. _q: Optional[str] = None Search string to be matched in the response. _auth: Optional[Dict[str, Any]] = None The parameters for authentication. Usually the authentication parameters should be defined when instantiating the Connector. In case some tables have different authentication options, a different authentication parameter can be defined here. This parameter will override the one from Connector if passed. _count: Optional[int] = None Count of returned records. **where The additional parameters required for the query. """ allowed_params: Set[str] = set() for key, val in self._impdb.tables[table].config.request.params.items(): if isinstance(val, FieldDef): if isinstance(val.from_key, list): allowed_params.update(val.from_key) elif isinstance(val.from_key, str): allowed_params.add(val.from_key) else: allowed_params.add(key) else: allowed_params.add(key) allowed_params.update(self._impdb.tables[table].config.request.url_path_params()) for key in where: if key not in allowed_params: raise InvalidParameterError(key) return await self._query_imp(table, where, _auth=_auth, _q=_q, _count=_count) [docs] def info(self) -> None: """Show the basic information and provide guidance for users to issue queries.""" info(self._impdb.name) async def _query_imp( # pylint: disable=too-many-locals,too-many-branches,too-many-statements self, table: str, kwargs: Dict[str, Any], *, _auth: Optional[Dict[str, Any]] = None, _count: Optional[int] = None, _q: Optional[str] = None, ) -> pd.DataFrame: if table not in self._impdb.tables: raise ValueError(f"No such table {table} in {self._impdb.name}") itable = self._impdb.tables[table] reqconf = itable.config.request if reqconf.pagination is None and _count is not None: print( f"ignoring _count since {table} has no pagination settings", file=sys.stderr, ) if _count is not None and _count <= 0: raise RuntimeError("_count should be larger than 0") async with ClientSession() as client: throttler = self._throttler.session() if reqconf.pagination is None or _count is None: df = await self._fetch( itable, kwargs, _client=client, _throttler=throttler, _auth=_auth, _q=_q, ) return df pagdef = reqconf.pagination # pagination begins max_per_page = pagdef.max_count total = _count n_page = math.ceil(total / max_per_page) if isinstance(pagdef, SeekPaginationDef): last_id = 0 dfs = [] # No way to parallelize for seek type for i in range(n_page): count = min(total - i * max_per_page, max_per_page) df = await self._fetch( itable, kwargs, _client=client, _throttler=throttler, _page=i, _auth=_auth, _q=_q, _limit=count, _anchor=last_id - 1, ) if df is None: raise NotImplementedError if len(df) == 0: # The API returns empty for this page, maybe we've reached the end break cid = df.columns.get_loc(pagdef.seek_id) # type: ignore last_id = int(df.iloc[-1, cid]) - 1 # type: ignore dfs.append(df) elif isinstance(pagdef, TokenPaginationDef): next_token = None dfs = [] # No way to parallelize for seek type for i in range(n_page): count = min(total - i * max_per_page, max_per_page) df, resp = await self._fetch( # type: ignore itable, kwargs, _client=client, _throttler=throttler, _page=i, _auth=_auth, _q=_q, _limit=count, _anchor=next_token, _raw=True, ) if pagdef.token_location == TokenLocation.Header: next_token = resp.headers[pagdef.token_accessor] elif pagdef.token_location == TokenLocation.Body: # only json body implemented token_expr = jparse(pagdef.token_accessor) (token_elem,) = token_expr.find(await resp.json()) next_token = token_elem.value dfs.append(df) elif isinstance(pagdef, (OffsetPaginationDef, PagePaginationDef)): resps_coros = [] allowed_page = Ref(n_page) for i in range(n_page): count = min(total - i * max_per_page, max_per_page) if pagdef.type == "offset": anchor = i * max_per_page elif pagdef.type == "page": anchor = i + 1 else: raise ValueError(f"Unknown pagination type {pagdef.type}") resps_coros.append( self._fetch( itable, kwargs, _client=client, _throttler=throttler, _page=i, _allowed_page=allowed_page, _auth=_auth, _q=_q, _limit=count, _anchor=anchor, ) ) dfs = [] for resp_coro in as_completed(resps_coros): df = await resp_coro if df is not None: dfs.append(df) else: raise NotImplementedError df = pd.concat(dfs, axis=0).reset_index(drop=True) return df async def _fetch( # pylint: disable=too-many-locals,too-many-branches,too-many-statements self, table: ImplicitTable, kwargs: Dict[str, Any], *, _client: ClientSession, _throttler: ThrottleSession, _page: int = 0, _allowed_page: Optional[Ref[int]] = None, _limit: Optional[int] = None, _anchor: Optional[Any] = None, _auth: Optional[Dict[str, Any]] = None, _q: Optional[str] = None, _raw: bool = False, ) -> Union[Optional[pd.DataFrame], Tuple[Optional[pd.DataFrame], ClientResponse]]: reqdef = table.config.request method = reqdef.method req_data: Dict[str, Dict[str, Any]] = { "headers": {}, "params": {}, "cookies": {}, } merged_vars = {**self._vars, **kwargs} if reqdef.authorization is not None: reqdef.authorization.build(req_data, _auth or self._auth, self._storage) if reqdef.body is not None: # TODO: do we support binary body? instantiated_fields = populate_field(reqdef.body.content, self._jenv, merged_vars) if reqdef.body.ctype == "application/x-www-form-urlencoded": req_data["data"] = instantiated_fields elif reqdef.body.ctype == "application/json": req_data["json"] = instantiated_fields else: raise NotImplementedError(reqdef.body.ctype) if reqdef.pagination is not None and _limit is not None: pagdef = reqdef.pagination limit_key = pagdef.limit_key if isinstance(pagdef, SeekPaginationDef): anchor = pagdef.seek_key elif isinstance(pagdef, OffsetPaginationDef): anchor = pagdef.offset_key elif isinstance(pagdef, PagePaginationDef): anchor = pagdef.page_key elif isinstance(pagdef, TokenPaginationDef): anchor = pagdef.token_key else: raise ValueError(f"Unknown pagination type {pagdef.type}.") if limit_key in req_data["params"]: raise UniversalParameterOverridden(limit_key, "_limit") req_data["params"][limit_key] = _limit if anchor in req_data["params"]: raise UniversalParameterOverridden(anchor, "_offset") if _anchor is not None: req_data["params"][anchor] = _anchor if _q is not None: if reqdef.search is None: raise ValueError("_q specified but the API does not support custom search.") searchdef = reqdef.search search_key = searchdef.key if search_key in req_data["params"]: raise UniversalParameterOverridden(search_key, "_q") req_data["params"][search_key] = _q for key in ["headers", "params", "cookies"]: field_def = getattr(reqdef, key, None) if field_def is not None: instantiated_fields = populate_field( field_def, self._jenv, merged_vars, ) for ikey in instantiated_fields: if ikey in req_data[key]: warn( f"Query parameter {ikey}={req_data[key][ikey]}" " is overriden by {ikey}={instantiated_fields[ikey]}", RuntimeWarning, ) req_data[key].update(**instantiated_fields) for key in ["headers", "params", "cookies"]: field_def = getattr(reqdef, key, None) if field_def is not None: validate_fields(field_def, req_data[key]) url = reqdef.populate_url(merged_vars) await _throttler.acquire(_page) if _allowed_page is not None and int(_allowed_page) <= _page: # cancel current throttler counter since the request is not sent out _throttler.release() return None async with _client.request( method=method, url=url, headers=req_data["headers"], params=req_data["params"], json=req_data.get("json"), data=req_data.get("data"), cookies=req_data["cookies"], ) as resp: if resp.status != 200: raise RequestError(status_code=resp.status, message=await resp.text()) content = await resp.text() df = table.from_response(content) if len(df) == 0 and _allowed_page is not None and _page is not None: _allowed_page.set(_page) df = None if _raw: return df, resp else: return df [docs]def validate_fields(fields: Dict[str, FieldDefUnion], data: Dict[str, Any]) -> None: """Check required fields are provided.""" for key, def_ in fields.items(): to_key = key if isinstance(def_, bool): required = def_ if required and to_key not in data: raise KeyError(f"'{to_key}' is required but not provided") elif isinstance(def_, str): pass else: to_key = def_.to_key or to_key required = def_.required if required and to_key not in data: raise KeyError(f"'{to_key}' is required but not provided") [docs]def populate_field( # pylint: disable=too-many-branches fields: Dict[str, FieldDefUnion], jenv: Environment, params: Dict[str, Any], ) -> Dict[str, str]: """Populate a dict based on the fields definition and provided vars.""" ret: Dict[str, str] = {} for key, def_ in fields.items(): to_key = key if isinstance(def_, bool): value = params.get(to_key) remove_if_empty = False elif isinstance(def_, str): # is a template tmplt = jenv.from_string(def_) value = tmplt.render(**params) remove_if_empty = False else: template = def_.template remove_if_empty = def_.remove_if_empty to_key = def_.to_key or to_key if template is None: value = params.get(to_key) else: tmplt = jenv.from_string(template) try: value = tmplt.render(**params) except UndefinedError: value = "" # This empty string will be removed if `remove_if_empty` is True if value is not None: str_value = str(value) if not remove_if_empty or str_value: if to_key in ret: warn( f"{to_key}={ret[to_key]} overriden by {to_key}={str_value}", RuntimeWarning, ) ret[to_key] = str_value continue return ret