diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/postgrest/base_request_builder.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/postgrest/base_request_builder.py | 685 |
1 files changed, 685 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/postgrest/base_request_builder.py b/.venv/lib/python3.12/site-packages/postgrest/base_request_builder.py new file mode 100644 index 00000000..7b5ab4b7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/postgrest/base_request_builder.py @@ -0,0 +1,685 @@ +from __future__ import annotations + +import json +from json import JSONDecodeError +from re import search +from typing import ( + Any, + Dict, + Generic, + Iterable, + List, + Literal, + NamedTuple, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +from httpx import Headers, QueryParams +from httpx import Response as RequestResponse +from pydantic import BaseModel + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + +try: + # >= 2.0.0 + from pydantic import field_validator +except ImportError: + # < 2.0.0 + from pydantic import validator as field_validator + +from .types import CountMethod, Filters, RequestMethod, ReturnMethod +from .utils import AsyncClient, SyncClient, get_origin_and_cast, sanitize_param + + +class QueryArgs(NamedTuple): + # groups the method, json, headers and params for a query in a single object + method: RequestMethod + params: QueryParams + headers: Headers + json: Dict[Any, Any] + + +def _unique_columns(json: List[Dict]): + unique_keys = {key for row in json for key in row.keys()} + columns = ",".join([f'"{k}"' for k in unique_keys]) + return columns + + +def _cleaned_columns(columns: Tuple[str, ...]) -> str: + quoted = False + cleaned = [] + + for column in columns: + clean_column = "" + for char in column: + if char.isspace() and not quoted: + continue + if char == '"': + quoted = not quoted + clean_column += char + cleaned.append(clean_column) + + return ",".join(cleaned) + + +def pre_select( + *columns: str, + count: Optional[CountMethod] = None, + head: Optional[bool] = None, +) -> QueryArgs: + method = RequestMethod.HEAD if head else RequestMethod.GET + cleaned_columns = _cleaned_columns(columns or "*") + params = QueryParams({"select": cleaned_columns}) + + headers = Headers({"Prefer": f"count={count}"}) if count else Headers() + return QueryArgs(method, params, headers, {}) + + +def pre_insert( + json: Union[dict, list], + *, + count: Optional[CountMethod], + returning: ReturnMethod, + upsert: bool, + default_to_null: bool = True, +) -> QueryArgs: + prefer_headers = [f"return={returning}"] + if count: + prefer_headers.append(f"count={count}") + if upsert: + prefer_headers.append("resolution=merge-duplicates") + if not default_to_null: + prefer_headers.append("missing=default") + headers = Headers({"Prefer": ",".join(prefer_headers)}) + # Adding 'columns' query parameters + query_params = {} + if isinstance(json, list): + query_params = {"columns": _unique_columns(json)} + return QueryArgs(RequestMethod.POST, QueryParams(query_params), headers, json) + + +def pre_upsert( + json: Union[dict, list], + *, + count: Optional[CountMethod], + returning: ReturnMethod, + ignore_duplicates: bool, + on_conflict: str = "", + default_to_null: bool = True, +) -> QueryArgs: + query_params = {} + prefer_headers = [f"return={returning}"] + if count: + prefer_headers.append(f"count={count}") + resolution = "ignore" if ignore_duplicates else "merge" + prefer_headers.append(f"resolution={resolution}-duplicates") + if not default_to_null: + prefer_headers.append("missing=default") + headers = Headers({"Prefer": ",".join(prefer_headers)}) + if on_conflict: + query_params["on_conflict"] = on_conflict + # Adding 'columns' query parameters + if isinstance(json, list): + query_params["columns"] = _unique_columns(json) + return QueryArgs(RequestMethod.POST, QueryParams(query_params), headers, json) + + +def pre_update( + json: dict, + *, + count: Optional[CountMethod], + returning: ReturnMethod, +) -> QueryArgs: + prefer_headers = [f"return={returning}"] + if count: + prefer_headers.append(f"count={count}") + headers = Headers({"Prefer": ",".join(prefer_headers)}) + return QueryArgs(RequestMethod.PATCH, QueryParams(), headers, json) + + +def pre_delete( + *, + count: Optional[CountMethod], + returning: ReturnMethod, +) -> QueryArgs: + prefer_headers = [f"return={returning}"] + if count: + prefer_headers.append(f"count={count}") + headers = Headers({"Prefer": ",".join(prefer_headers)}) + return QueryArgs(RequestMethod.DELETE, QueryParams(), headers, {}) + + +_ReturnT = TypeVar("_ReturnT") + + +# the APIResponse.data is marked as _ReturnT instead of list[_ReturnT] +# as it is also returned in the case of rpc() calls; and rpc calls do not +# necessarily return lists. +# https://github.com/supabase-community/postgrest-py/issues/200 +class APIResponse(BaseModel, Generic[_ReturnT]): + data: List[_ReturnT] + """The data returned by the query.""" + count: Optional[int] = None + """The number of rows returned.""" + + @field_validator("data") + @classmethod + def raise_when_api_error(cls: Type[Self], value: Any) -> Any: + if isinstance(value, dict) and value.get("message"): + raise ValueError("You are passing an API error to the data field.") + return value + + @staticmethod + def _get_count_from_content_range_header( + content_range_header: str, + ) -> Optional[int]: + content_range = content_range_header.split("/") + return None if len(content_range) < 2 else int(content_range[1]) + + @staticmethod + def _is_count_in_prefer_header(prefer_header: str) -> bool: + pattern = f"count=({'|'.join([cm.value for cm in CountMethod])})" + return bool(search(pattern, prefer_header)) + + @classmethod + def _get_count_from_http_request_response( + cls: Type[Self], + request_response: RequestResponse, + ) -> Optional[int]: + prefer_header: Optional[str] = request_response.request.headers.get("prefer") + if not prefer_header: + return None + is_count_in_prefer_header = cls._is_count_in_prefer_header(prefer_header) + content_range_header: Optional[str] = request_response.headers.get( + "content-range" + ) + return ( + cls._get_count_from_content_range_header(content_range_header) + if (is_count_in_prefer_header and content_range_header) + else None + ) + + @classmethod + def from_http_request_response( + cls: Type[Self], request_response: RequestResponse + ) -> Self: + count = cls._get_count_from_http_request_response(request_response) + try: + data = request_response.json() + except JSONDecodeError: + data = request_response.text if len(request_response.text) > 0 else [] + # the type-ignore here is as pydantic needs us to pass the type parameter + # here explicitly, but pylance already knows that cls is correctly parametrized + return cls[_ReturnT](data=data, count=count) # type: ignore + + @classmethod + def from_dict(cls: Type[Self], dict: Dict[str, Any]) -> Self: + keys = dict.keys() + assert len(keys) == 3 and "data" in keys and "count" in keys and "error" in keys + return cls[_ReturnT]( # type: ignore + data=dict.get("data"), count=dict.get("count"), error=dict.get("error") + ) + + +class SingleAPIResponse(APIResponse[_ReturnT], Generic[_ReturnT]): + data: _ReturnT # type: ignore + """The data returned by the query.""" + + @classmethod + def from_http_request_response( + cls: Type[Self], request_response: RequestResponse + ) -> Self: + count = cls._get_count_from_http_request_response(request_response) + try: + data = request_response.json() + except JSONDecodeError: + data = request_response.text if len(request_response.text) > 0 else [] + return cls[_ReturnT](data=data, count=count) # type: ignore + + @classmethod + def from_dict(cls: Type[Self], dict: Dict[str, Any]) -> Self: + keys = dict.keys() + assert len(keys) == 3 and "data" in keys and "count" in keys and "error" in keys + return cls[_ReturnT]( # type: ignore + data=dict.get("data"), count=dict.get("count"), error=dict.get("error") + ) + + +class BaseFilterRequestBuilder(Generic[_ReturnT]): + def __init__( + self, + session: Union[AsyncClient, SyncClient], + headers: Headers, + params: QueryParams, + ) -> None: + self.session = session + self.headers = headers + self.params = params + self.negate_next = False + + @property + def not_(self: Self) -> Self: + """Whether the filter applied next should be negated.""" + self.negate_next = True + return self + + def filter(self: Self, column: str, operator: str, criteria: str) -> Self: + """Apply filters on a query. + + Args: + column: The name of the column to apply a filter on + operator: The operator to use while filtering + criteria: The value to filter by + """ + if self.negate_next is True: + self.negate_next = False + operator = f"{Filters.NOT}.{operator}" + key, val = sanitize_param(column), f"{operator}.{criteria}" + self.params = self.params.add(key, val) + return self + + def eq(self: Self, column: str, value: Any) -> Self: + """An 'equal to' filter. + + Args: + column: The name of the column to apply a filter on + value: The value to filter by + """ + return self.filter(column, Filters.EQ, value) + + def neq(self: Self, column: str, value: Any) -> Self: + """A 'not equal to' filter + + Args: + column: The name of the column to apply a filter on + value: The value to filter by + """ + return self.filter(column, Filters.NEQ, value) + + def gt(self: Self, column: str, value: Any) -> Self: + """A 'greater than' filter + + Args: + column: The name of the column to apply a filter on + value: The value to filter by + """ + return self.filter(column, Filters.GT, value) + + def gte(self: Self, column: str, value: Any) -> Self: + """A 'greater than or equal to' filter + + Args: + column: The name of the column to apply a filter on + value: The value to filter by + """ + return self.filter(column, Filters.GTE, value) + + def lt(self: Self, column: str, value: Any) -> Self: + """A 'less than' filter + + Args: + column: The name of the column to apply a filter on + value: The value to filter by + """ + return self.filter(column, Filters.LT, value) + + def lte(self: Self, column: str, value: Any) -> Self: + """A 'less than or equal to' filter + + Args: + column: The name of the column to apply a filter on + value: The value to filter by + """ + return self.filter(column, Filters.LTE, value) + + def is_(self: Self, column: str, value: Any) -> Self: + """An 'is' filter + + Args: + column: The name of the column to apply a filter on + value: The value to filter by + """ + if value is None: + value = "null" + return self.filter(column, Filters.IS, value) + + def like(self: Self, column: str, pattern: str) -> Self: + """A 'LIKE' filter, to use for pattern matching. + + Args: + column: The name of the column to apply a filter on + pattern: The pattern to filter by + """ + return self.filter(column, Filters.LIKE, pattern) + + def like_all_of(self: Self, column: str, pattern: str) -> Self: + """A 'LIKE' filter, to use for pattern matching. + + Args: + column: The name of the column to apply a filter on + pattern: The pattern to filter by + """ + + return self.filter(column, Filters.LIKE_ALL, f"{{{pattern}}}") + + def like_any_of(self: Self, column: str, pattern: str) -> Self: + """A 'LIKE' filter, to use for pattern matching. + + Args: + column: The name of the column to apply a filter on + pattern: The pattern to filter by + """ + + return self.filter(column, Filters.LIKE_ANY, f"{{{pattern}}}") + + def ilike_all_of(self: Self, column: str, pattern: str) -> Self: + """A 'ILIKE' filter, to use for pattern matching (case insensitive). + + Args: + column: The name of the column to apply a filter on + pattern: The pattern to filter by + """ + + return self.filter(column, Filters.ILIKE_ALL, f"{{{pattern}}}") + + def ilike_any_of(self: Self, column: str, pattern: str) -> Self: + """A 'ILIKE' filter, to use for pattern matching (case insensitive). + + Args: + column: The name of the column to apply a filter on + pattern: The pattern to filter by + """ + + return self.filter(column, Filters.ILIKE_ANY, f"{{{pattern}}}") + + def ilike(self: Self, column: str, pattern: str) -> Self: + """An 'ILIKE' filter, to use for pattern matching (case insensitive). + + Args: + column: The name of the column to apply a filter on + pattern: The pattern to filter by + """ + return self.filter(column, Filters.ILIKE, pattern) + + def or_(self: Self, filters: str, reference_table: Optional[str] = None) -> Self: + """An 'or' filter + + Args: + filters: The filters to use, following PostgREST syntax + reference_table: Set this to filter on referenced tables instead of the parent table + """ + key = f"{sanitize_param(reference_table)}.or" if reference_table else "or" + self.params = self.params.add(key, f"({filters})") + return self + + def fts(self: Self, column: str, query: Any) -> Self: + return self.filter(column, Filters.FTS, query) + + def plfts(self: Self, column: str, query: Any) -> Self: + return self.filter(column, Filters.PLFTS, query) + + def phfts(self: Self, column: str, query: Any) -> Self: + return self.filter(column, Filters.PHFTS, query) + + def wfts(self: Self, column: str, query: Any) -> Self: + return self.filter(column, Filters.WFTS, query) + + def in_(self: Self, column: str, values: Iterable[Any]) -> Self: + values = map(sanitize_param, values) + values = ",".join(values) + return self.filter(column, Filters.IN, f"({values})") + + def cs(self: Self, column: str, values: Iterable[Any]) -> Self: + values = ",".join(values) + return self.filter(column, Filters.CS, f"{{{values}}}") + + def cd(self: Self, column: str, values: Iterable[Any]) -> Self: + values = ",".join(values) + return self.filter(column, Filters.CD, f"{{{values}}}") + + def contains( + self: Self, column: str, value: Union[Iterable[Any], str, Dict[Any, Any]] + ) -> Self: + if isinstance(value, str): + # range types can be inclusive '[', ']' or exclusive '(', ')' so just + # keep it simple and accept a string + return self.filter(column, Filters.CS, value) + if not isinstance(value, dict) and isinstance(value, Iterable): + # Expected to be some type of iterable + stringified_values = ",".join(value) + return self.filter(column, Filters.CS, f"{{{stringified_values}}}") + + return self.filter(column, Filters.CS, json.dumps(value)) + + def contained_by( + self: Self, column: str, value: Union[Iterable[Any], str, Dict[Any, Any]] + ) -> Self: + if isinstance(value, str): + # range + return self.filter(column, Filters.CD, value) + if not isinstance(value, dict) and isinstance(value, Iterable): + stringified_values = ",".join(value) + return self.filter(column, Filters.CD, f"{{{stringified_values}}}") + return self.filter(column, Filters.CD, json.dumps(value)) + + def ov(self: Self, column: str, value: Iterable[Any]) -> Self: + if isinstance(value, str): + # range types can be inclusive '[', ']' or exclusive '(', ')' so just + # keep it simple and accept a string + return self.filter(column, Filters.OV, value) + if not isinstance(value, dict) and isinstance(value, Iterable): + # Expected to be some type of iterable + stringified_values = ",".join(value) + return self.filter(column, Filters.OV, f"{{{stringified_values}}}") + return self.filter(column, Filters.OV, json.dumps(value)) + + def sl(self: Self, column: str, range: Tuple[int, int]) -> Self: + return self.filter(column, Filters.SL, f"({range[0]},{range[1]})") + + def sr(self: Self, column: str, range: Tuple[int, int]) -> Self: + return self.filter(column, Filters.SR, f"({range[0]},{range[1]})") + + def nxl(self: Self, column: str, range: Tuple[int, int]) -> Self: + return self.filter(column, Filters.NXL, f"({range[0]},{range[1]})") + + def nxr(self: Self, column: str, range: Tuple[int, int]) -> Self: + return self.filter(column, Filters.NXR, f"({range[0]},{range[1]})") + + def adj(self: Self, column: str, range: Tuple[int, int]) -> Self: + return self.filter(column, Filters.ADJ, f"({range[0]},{range[1]})") + + def range_gt(self: Self, column: str, range: Tuple[int, int]) -> Self: + return self.sr(column, range) + + def range_gte(self: Self, column: str, range: Tuple[int, int]) -> Self: + return self.nxl(column, range) + + def range_lt(self: Self, column: str, range: Tuple[int, int]) -> Self: + return self.sl(column, range) + + def range_lte(self: Self, column: str, range: Tuple[int, int]) -> Self: + return self.nxr(column, range) + + def range_adjacent(self: Self, column: str, range: Tuple[int, int]) -> Self: + return self.adj(column, range) + + def overlaps(self: Self, column: str, values: Iterable[Any]) -> Self: + return self.ov(column, values) + + def match(self: Self, query: Dict[str, Any]) -> Self: + updated_query = self + + if not query: + raise ValueError( + "query dictionary should contain at least one key-value pair" + ) + + for key, value in query.items(): + updated_query = self.eq(key, value) + + return updated_query + + +class BaseSelectRequestBuilder(BaseFilterRequestBuilder[_ReturnT]): + def __init__( + self, + session: Union[AsyncClient, SyncClient], + headers: Headers, + params: QueryParams, + ) -> None: + # Generic[T] is an instance of typing._GenericAlias, so doing Generic[T].__init__ + # tries to call _GenericAlias.__init__ - which is the wrong method + # The __origin__ attribute of the _GenericAlias is the actual class + get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__( + self, session, headers, params + ) + + def explain( + self: Self, + analyze: bool = False, + verbose: bool = False, + settings: bool = False, + buffers: bool = False, + wal: bool = False, + format: Literal["text", "json"] = "text", + ) -> Self: + options = [ + key + for key, value in locals().items() + if key not in ["self", "format"] and value + ] + options_str = "|".join(options) + self.headers["Accept"] = ( + f"application/vnd.pgrst.plan+{format}; options={options_str}" + ) + return self + + def order( + self: Self, + column: str, + *, + desc: bool = False, + nullsfirst: bool = False, + foreign_table: Optional[str] = None, + ) -> Self: + """Sort the returned rows in some specific order. + + Args: + column: The column to order by + desc: Whether the rows should be ordered in descending order or not. + nullsfirst: nullsfirst + foreign_table: Foreign table name whose results are to be ordered. + .. versionchanged:: 0.10.3 + Allow ordering results for foreign tables with the foreign_table parameter. + """ + + new_order_parameter = ( + f"{foreign_table + '(' if foreign_table else ''}{column}{')' if foreign_table else ''}" + f"{'.desc' if desc else ''}{'.nullsfirst' if nullsfirst else ''}" + ) + + existing_order_parameter = self.params.get("order") + if existing_order_parameter: + self.params = self.params.remove("order") + new_order_parameter = f"{existing_order_parameter},{new_order_parameter}" + + self.params = self.params.add( + "order", + new_order_parameter, + ) + return self + + def limit(self: Self, size: int, *, foreign_table: Optional[str] = None) -> Self: + """Limit the number of rows returned by a query. + + Args: + size: The number of rows to be returned + foreign_table: Foreign table name to limit + .. versionchanged:: 0.10.3 + Allow limiting results returned for foreign tables with the foreign_table parameter. + """ + self.params = self.params.add( + f"{foreign_table}.limit" if foreign_table else "limit", + size, + ) + return self + + def offset(self: _FilterT, size: int) -> _FilterT: + """Set the starting row index returned by a query. + Args: + size: The number of the row to start at + """ + self.params = self.params.add( + "offset", + size, + ) + return self + + def range( + self: Self, start: int, end: int, foreign_table: Optional[str] = None + ) -> Self: + self.params = self.params.add( + f"{foreign_table}.offset" if foreign_table else "offset", start + ) + self.params = self.params.add( + f"{foreign_table}.limit" if foreign_table else "limit", + end - start + 1, + ) + return self + + +class BaseRPCRequestBuilder(BaseSelectRequestBuilder[_ReturnT]): + def __init__( + self, + session: Union[AsyncClient, SyncClient], + headers: Headers, + params: QueryParams, + ) -> None: + # Generic[T] is an instance of typing._GenericAlias, so doing Generic[T].__init__ + # tries to call _GenericAlias.__init__ - which is the wrong method + # The __origin__ attribute of the _GenericAlias is the actual class + get_origin_and_cast(BaseSelectRequestBuilder[_ReturnT]).__init__( + self, session, headers, params + ) + + def select( + self, + *columns: str, + ) -> Self: + """Run a SELECT query. + + Args: + *columns: The names of the columns to fetch. + Returns: + :class:`BaseSelectRequestBuilder` + """ + method, params, headers, json = pre_select(*columns, count=None) + self.params = self.params.add("select", params.get("select")) + self.headers["Prefer"] = "return=representation" + return self + + def single(self) -> Self: + """Specify that the query will only return a single row in response. + + .. caution:: + The API will raise an error if the query returned more than one row. + """ + self.headers["Accept"] = "application/vnd.pgrst.object+json" + return self + + def maybe_single(self) -> Self: + """Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error.""" + self.headers["Accept"] = "application/vnd.pgrst.object+json" + return self + + def csv(self) -> Self: + """Specify that the query must retrieve data as a single CSV string.""" + self.headers["Accept"] = "text/csv" + return self |