from __future__ import annotations

import json
from json import JSONDecodeError
from re import search
from typing import (
    Any,
    Dict,
    Generic,
    Iterable,
    List,
    Literal,
    NamedTuple,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from httpx import Headers, QueryParams
from httpx import Response as RequestResponse
from pydantic import BaseModel

try:
    from typing import Self
except ImportError:
    from typing_extensions import Self

try:
    # >= 2.0.0
    from pydantic import field_validator
except ImportError:
    # < 2.0.0
    from pydantic import validator as field_validator

from .types import CountMethod, Filters, RequestMethod, ReturnMethod
from .utils import AsyncClient, SyncClient, get_origin_and_cast, sanitize_param


class QueryArgs(NamedTuple):
    """The HTTP method, query params, headers and JSON body of one request."""

    # groups the method, json, headers and params for a query in a single object
    method: RequestMethod
    params: QueryParams
    headers: Headers
    # insert/upsert store a list of rows here; select/update/delete a dict
    json: Union[Dict[Any, Any], List[Any]]


def _unique_columns(json: List[Dict]) -> str:
    """Return the double-quoted, comma-separated union of keys of all rows.

    Keys are emitted in first-seen order (rows scanned in sequence). A plain
    ``set`` would order them by string hash, which changes from one Python
    process to the next, making the generated ``columns`` parameter unstable.
    """
    unique_keys = dict.fromkeys(key for row in json for key in row)
    return ",".join(f'"{key}"' for key in unique_keys)


def _cleaned_columns(columns: Tuple[str, ...]) -> str:
    """Join column names with commas, dropping whitespace outside double quotes.

    The ``quoted`` flag is shared across all columns, so an unbalanced quote
    in one column also affects how whitespace in the following ones is kept.
    """
    quoted = False
    cleaned = []

    for column in columns:
        clean_column = ""
        for char in column:
            if char.isspace() and not quoted:
                continue
            if char == '"':
                quoted = not quoted
            clean_column += char
        cleaned.append(clean_column)

    return ",".join(cleaned)


def _prefer_return_and_count(
    returning: ReturnMethod, count: Optional[CountMethod]
) -> List[str]:
    """Start a ``Prefer`` header list: ``return=...`` then optional ``count=...``."""
    prefer_headers = [f"return={returning}"]
    if count:
        prefer_headers.append(f"count={count}")
    return prefer_headers


def pre_select(
    *columns: str,
    count: Optional[CountMethod] = None,
    head: Optional[bool] = None,
) -> QueryArgs:
    """Build the request arguments for a SELECT.

    Args:
        *columns: Column names to select; all columns (``*``) when empty.
        count: If set, sent as ``Prefer: count=<count>``.
        head: Use ``HEAD`` instead of ``GET`` when truthy.
    """
    method = RequestMethod.HEAD if head else RequestMethod.GET
    cleaned_columns = _cleaned_columns(columns or ("*",))
    params = QueryParams({"select": cleaned_columns})
    headers = Headers({"Prefer": f"count={count}"}) if count else Headers()
    return QueryArgs(method, params, headers, {})


def pre_insert(
    json: Union[dict, list],
    *,
    count: Optional[CountMethod],
    returning: ReturnMethod,
    upsert: bool,
    default_to_null: bool = True,
) -> QueryArgs:
    """Build the request arguments for an INSERT (a ``POST``).

    Args:
        json: One row (dict) or several rows (list of dicts).
        count: If set, adds ``count=<count>`` to ``Prefer``.
        returning: Sent as ``return=<returning>`` in ``Prefer``.
        upsert: Adds ``resolution=merge-duplicates`` to ``Prefer``.
        default_to_null: When False, adds ``missing=default`` to ``Prefer``.
    """
    prefer_headers = _prefer_return_and_count(returning, count)
    if upsert:
        prefer_headers.append("resolution=merge-duplicates")
    if not default_to_null:
        prefer_headers.append("missing=default")
    headers = Headers({"Prefer": ",".join(prefer_headers)})
    # For a list of rows, send the union of their keys as 'columns'.
    query_params = {}
    if isinstance(json, list):
        query_params = {"columns": _unique_columns(json)}
    return QueryArgs(RequestMethod.POST, QueryParams(query_params), headers, json)


def pre_upsert(
    json: Union[dict, list],
    *,
    count: Optional[CountMethod],
    returning: ReturnMethod,
    ignore_duplicates: bool,
    on_conflict: str = "",
    default_to_null: bool = True,
) -> QueryArgs:
    """Build the request arguments for an UPSERT (a ``POST``).

    Args:
        json: One row (dict) or several rows (list of dicts).
        count: If set, adds ``count=<count>`` to ``Prefer``.
        returning: Sent as ``return=<returning>`` in ``Prefer``.
        ignore_duplicates: Choose ``resolution=ignore-duplicates`` instead of
            ``resolution=merge-duplicates``.
        on_conflict: If non-empty, sent as the ``on_conflict`` query param.
        default_to_null: When False, adds ``missing=default`` to ``Prefer``.
    """
    query_params = {}
    prefer_headers = _prefer_return_and_count(returning, count)
    resolution = "ignore" if ignore_duplicates else "merge"
    prefer_headers.append(f"resolution={resolution}-duplicates")
    if not default_to_null:
        prefer_headers.append("missing=default")
    headers = Headers({"Prefer": ",".join(prefer_headers)})
    if on_conflict:
        query_params["on_conflict"] = on_conflict
    # For a list of rows, send the union of their keys as 'columns'.
    if isinstance(json, list):
        query_params["columns"] = _unique_columns(json)
    return QueryArgs(RequestMethod.POST, QueryParams(query_params), headers, json)


def pre_update(
    json: dict,
    *,
    count: Optional[CountMethod],
    returning: ReturnMethod,
) -> QueryArgs:
    """Build the request arguments for an UPDATE (a ``PATCH`` with ``json`` body)."""
    prefer_headers = _prefer_return_and_count(returning, count)
    headers = Headers({"Prefer": ",".join(prefer_headers)})
    return QueryArgs(RequestMethod.PATCH, QueryParams(), headers, json)


def pre_delete(
    *,
    count: Optional[CountMethod],
    returning: ReturnMethod,
) -> QueryArgs:
    """Build the request arguments for a DELETE (empty body)."""
    prefer_headers = _prefer_return_and_count(returning, count)
    headers = Headers({"Prefer": ",".join(prefer_headers)})
    return QueryArgs(RequestMethod.DELETE, QueryParams(), headers, {})


_ReturnT = TypeVar("_ReturnT")
# NOTE: APIResponse.data is annotated as List[_ReturnT], while
# SingleAPIResponse below overrides it with a bare _ReturnT. rpc() calls do
# not necessarily return lists, so which of the two response types a call
# produces matters for the shape of ``data``.
# https://github.com/supabase-community/postgrest-py/issues/200
class APIResponse(BaseModel, Generic[_ReturnT]):
    """A decoded PostgREST response body plus an optional total row count."""

    data: List[_ReturnT]
    """The data returned by the query."""
    count: Optional[int] = None
    """The total from the ``Content-Range`` header when a count was requested."""

    @field_validator("data")
    @classmethod
    def raise_when_api_error(cls: Type[Self], value: Any) -> Any:
        """Reject a dict carrying a truthy ``message`` key as ``data``.

        Raises:
            ValueError: If ``value`` looks like an API error object.
        """
        if isinstance(value, dict) and value.get("message"):
            raise ValueError("You are passing an API error to the data field.")
        return value

    @staticmethod
    def _get_count_from_content_range_header(
        content_range_header: str,
    ) -> Optional[int]:
        """Parse the part after ``/`` in a ``Content-Range`` value (``0-9/42``).

        Returns None when there is no ``/`` part or when that part is not an
        integer (e.g. ``0-9/*``) instead of raising ``ValueError``.
        """
        content_range = content_range_header.split("/")
        if len(content_range) < 2:
            return None
        try:
            return int(content_range[1])
        except ValueError:
            return None

    @staticmethod
    def _is_count_in_prefer_header(prefer_header: str) -> bool:
        """Return True if ``prefer_header`` contains ``count=<CountMethod value>``."""
        pattern = f"count=({'|'.join([cm.value for cm in CountMethod])})"
        return bool(search(pattern, prefer_header))

    @classmethod
    def _get_count_from_http_request_response(
        cls: Type[Self],
        request_response: RequestResponse,
    ) -> Optional[int]:
        """Extract the count, but only if the request's ``Prefer`` asked for one."""
        prefer_header: Optional[str] = request_response.request.headers.get("prefer")
        if not prefer_header:
            return None
        is_count_in_prefer_header = cls._is_count_in_prefer_header(prefer_header)
        content_range_header: Optional[str] = request_response.headers.get(
            "content-range"
        )
        return (
            cls._get_count_from_content_range_header(content_range_header)
            if (is_count_in_prefer_header and content_range_header)
            else None
        )

    @classmethod
    def from_http_request_response(
        cls: Type[Self], request_response: RequestResponse
    ) -> Self:
        """Build a response from an httpx response.

        The body is decoded as JSON; if that fails, the raw text is used, or
        ``[]`` when the body is empty.
        """
        count = cls._get_count_from_http_request_response(request_response)
        try:
            data = request_response.json()
        except JSONDecodeError:
            data = request_response.text if len(request_response.text) > 0 else []
        # the type-ignore here is as pydantic needs us to pass the type parameter
        # here explicitly, but pylance already knows that cls is correctly parametrized
        return cls[_ReturnT](data=data, count=count)  # type: ignore

    @classmethod
    def from_dict(cls: Type[Self], dict: Dict[str, Any]) -> Self:
        """Build a response from a dict with exactly ``data``, ``count``, ``error``.

        NOTE(review): the model declares no ``error`` field, so that value is
        presumably discarded by pydantic's extra-field handling — confirm.
        The key check uses ``assert`` and is skipped under ``python -O``.
        """
        keys = dict.keys()
        assert len(keys) == 3 and "data" in keys and "count" in keys and "error" in keys
        return cls[_ReturnT](  # type: ignore
            data=dict.get("data"), count=dict.get("count"), error=dict.get("error")
        )


class SingleAPIResponse(APIResponse[_ReturnT], Generic[_ReturnT]):
    """An :class:`APIResponse` whose ``data`` is a single value, not a list."""

    data: _ReturnT  # type: ignore
    """The data returned by the query."""

    @classmethod
    def from_http_request_response(
        cls: Type[Self], request_response: RequestResponse
    ) -> Self:
        """Build a response from an httpx response (same decoding as the parent)."""
        count = cls._get_count_from_http_request_response(request_response)
        try:
            data = request_response.json()
        except JSONDecodeError:
            data = request_response.text if len(request_response.text) > 0 else []
        return cls[_ReturnT](data=data, count=count)  # type: ignore

    @classmethod
    def from_dict(cls: Type[Self], dict: Dict[str, Any]) -> Self:
        """Build a response from a dict with exactly ``data``, ``count``, ``error``."""
        keys = dict.keys()
        assert len(keys) == 3 and "data" in keys and "count" in keys and "error" in keys
        return cls[_ReturnT](  # type: ignore
            data=dict.get("data"), count=dict.get("count"), error=dict.get("error")
        )


class BaseFilterRequestBuilder(Generic[_ReturnT]):
    """Accumulates PostgREST filter query params on ``self.params``.

    Every filter method mutates the builder and returns ``self`` so calls
    can be chained.
    """

    def __init__(
        self,
        session: Union[AsyncClient, SyncClient],
        headers: Headers,
        params: QueryParams,
    ) -> None:
        self.session = session
        self.headers = headers
        self.params = params
        self.negate_next = False

    @property
    def not_(self: Self) -> Self:
        """Whether the filter applied next should be negated."""
        self.negate_next = True
        return self

    def filter(self: Self, column: str, operator: str, criteria: str) -> Self:
        """Apply filters on a query.

        Adds ``<column>=<operator>.<criteria>`` to the query params, prefixing
        the operator with ``Filters.NOT`` once if :attr:`not_` was used.

        Args:
            column: The name of the column to apply a filter on
            operator: The operator to use while filtering
            criteria: The value to filter by
        """
        if self.negate_next is True:
            self.negate_next = False
            operator = f"{Filters.NOT}.{operator}"
        key, val = sanitize_param(column), f"{operator}.{criteria}"
        self.params = self.params.add(key, val)
        return self

    def eq(self: Self, column: str, value: Any) -> Self:
        """An 'equal to' filter.

        Args:
            column: The name of the column to apply a filter on
            value: The value to filter by
        """
        return self.filter(column, Filters.EQ, value)

    def neq(self: Self, column: str, value: Any) -> Self:
        """A 'not equal to' filter

        Args:
            column: The name of the column to apply a filter on
            value: The value to filter by
        """
        return self.filter(column, Filters.NEQ, value)

    def gt(self: Self, column: str, value: Any) -> Self:
        """A 'greater than' filter

        Args:
            column: The name of the column to apply a filter on
            value: The value to filter by
        """
        return self.filter(column, Filters.GT, value)

    def gte(self: Self, column: str, value: Any) -> Self:
        """A 'greater than or equal to' filter

        Args:
            column: The name of the column to apply a filter on
            value: The value to filter by
        """
        return self.filter(column, Filters.GTE, value)

    def lt(self: Self, column: str, value: Any) -> Self:
        """A 'less than' filter

        Args:
            column: The name of the column to apply a filter on
            value: The value to filter by
        """
        return self.filter(column, Filters.LT, value)

    def lte(self: Self, column: str, value: Any) -> Self:
        """A 'less than or equal to' filter

        Args:
            column: The name of the column to apply a filter on
            value: The value to filter by
        """
        return self.filter(column, Filters.LTE, value)

    def is_(self: Self, column: str, value: Any) -> Self:
        """An 'is' filter

        ``None`` is sent as ``null`` and booleans as lowercase ``true`` /
        ``false`` (``str(True)`` would otherwise produce ``True``).

        Args:
            column: The name of the column to apply a filter on
            value: The value to filter by
        """
        if value is None:
            value = "null"
        elif isinstance(value, bool):
            value = "true" if value else "false"
        return self.filter(column, Filters.IS, value)

    def like(self: Self, column: str, pattern: str) -> Self:
        """A 'LIKE' filter, to use for pattern matching.

        Args:
            column: The name of the column to apply a filter on
            pattern: The pattern to filter by
        """
        return self.filter(column, Filters.LIKE, pattern)

    def like_all_of(self: Self, column: str, pattern: str) -> Self:
        """A 'LIKE' filter that must match all of several patterns.

        ``pattern`` is wrapped in ``{}`` before sending, so it is expected to
        be a comma-separated list of patterns.

        Args:
            column: The name of the column to apply a filter on
            pattern: The pattern(s) to filter by
        """
        return self.filter(column, Filters.LIKE_ALL, f"{{{pattern}}}")

    def like_any_of(self: Self, column: str, pattern: str) -> Self:
        """A 'LIKE' filter that may match any of several patterns.

        ``pattern`` is wrapped in ``{}`` before sending, so it is expected to
        be a comma-separated list of patterns.

        Args:
            column: The name of the column to apply a filter on
            pattern: The pattern(s) to filter by
        """
        return self.filter(column, Filters.LIKE_ANY, f"{{{pattern}}}")

    def ilike_all_of(self: Self, column: str, pattern: str) -> Self:
        """A case-insensitive 'ILIKE' filter that must match all patterns.

        ``pattern`` is wrapped in ``{}`` before sending.

        Args:
            column: The name of the column to apply a filter on
            pattern: The pattern(s) to filter by
        """
        return self.filter(column, Filters.ILIKE_ALL, f"{{{pattern}}}")

    def ilike_any_of(self: Self, column: str, pattern: str) -> Self:
        """A case-insensitive 'ILIKE' filter that may match any pattern.

        ``pattern`` is wrapped in ``{}`` before sending.

        Args:
            column: The name of the column to apply a filter on
            pattern: The pattern(s) to filter by
        """
        return self.filter(column, Filters.ILIKE_ANY, f"{{{pattern}}}")

    def ilike(self: Self, column: str, pattern: str) -> Self:
        """An 'ILIKE' filter, to use for pattern matching (case insensitive).

        Args:
            column: The name of the column to apply a filter on
            pattern: The pattern to filter by
        """
        return self.filter(column, Filters.ILIKE, pattern)

    def or_(self: Self, filters: str, reference_table: Optional[str] = None) -> Self:
        """An 'or' filter

        Args:
            filters: The filters to use, following PostgREST syntax
            reference_table: Set this to filter on referenced tables instead of the parent table
        """
        key = f"{sanitize_param(reference_table)}.or" if reference_table else "or"
        self.params = self.params.add(key, f"({filters})")
        return self

    def fts(self: Self, column: str, query: Any) -> Self:
        """Full-text search filter using ``Filters.FTS``."""
        return self.filter(column, Filters.FTS, query)

    def plfts(self: Self, column: str, query: Any) -> Self:
        """Full-text search filter using ``Filters.PLFTS``."""
        return self.filter(column, Filters.PLFTS, query)

    def phfts(self: Self, column: str, query: Any) -> Self:
        """Full-text search filter using ``Filters.PHFTS``."""
        return self.filter(column, Filters.PHFTS, query)

    def wfts(self: Self, column: str, query: Any) -> Self:
        """Full-text search filter using ``Filters.WFTS``."""
        return self.filter(column, Filters.WFTS, query)

    def in_(self: Self, column: str, values: Iterable[Any]) -> Self:
        """Match rows whose ``column`` is one of ``values`` (sent as ``(a,b,...)``).

        Each value is passed through ``sanitize_param`` first.
        """
        joined = ",".join(map(sanitize_param, values))
        return self.filter(column, Filters.IN, f"({joined})")

    def cs(self: Self, column: str, values: Iterable[Any]) -> Self:
        """'Contains' filter; ``values`` are stringified and sent as ``{a,b,...}``."""
        joined = ",".join(map(str, values))
        return self.filter(column, Filters.CS, f"{{{joined}}}")

    def cd(self: Self, column: str, values: Iterable[Any]) -> Self:
        """'Contained by' filter; ``values`` are stringified and sent as ``{a,b,...}``."""
        joined = ",".join(map(str, values))
        return self.filter(column, Filters.CD, f"{{{joined}}}")

    def contains(
        self: Self, column: str, value: Union[Iterable[Any], str, Dict[Any, Any]]
    ) -> Self:
        """'Contains' filter for ranges (str), arrays (iterable) or JSON (dict).

        Strings are sent as-is, non-dict iterables as ``{a,b,...}`` with each
        element stringified, and anything else as JSON.
        """
        if isinstance(value, str):
            # range types can be inclusive '[', ']' or exclusive '(', ')' so just
            # keep it simple and accept a string
            return self.filter(column, Filters.CS, value)
        if not isinstance(value, dict) and isinstance(value, Iterable):
            # Expected to be some type of iterable; str() lets non-str elements
            # (e.g. ints) through instead of raising TypeError in join.
            stringified_values = ",".join(map(str, value))
            return self.filter(column, Filters.CS, f"{{{stringified_values}}}")

        return self.filter(column, Filters.CS, json.dumps(value))

    def contained_by(
        self: Self, column: str, value: Union[Iterable[Any], str, Dict[Any, Any]]
    ) -> Self:
        """'Contained by' filter for ranges (str), arrays (iterable) or JSON (dict).

        Encoding rules are the same as :meth:`contains`.
        """
        if isinstance(value, str):
            # range
            return self.filter(column, Filters.CD, value)
        if not isinstance(value, dict) and isinstance(value, Iterable):
            stringified_values = ",".join(map(str, value))
            return self.filter(column, Filters.CD, f"{{{stringified_values}}}")
        return self.filter(column, Filters.CD, json.dumps(value))

    def ov(self: Self, column: str, value: Iterable[Any]) -> Self:
        """'Overlaps' filter; same encoding rules as :meth:`contains`."""
        if isinstance(value, str):
            # range types can be inclusive '[', ']' or exclusive '(', ')' so just
            # keep it simple and accept a string
            return self.filter(column, Filters.OV, value)
        if not isinstance(value, dict) and isinstance(value, Iterable):
            # Expected to be some type of iterable
            stringified_values = ",".join(map(str, value))
            return self.filter(column, Filters.OV, f"{{{stringified_values}}}")
        return self.filter(column, Filters.OV, json.dumps(value))

    def sl(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Range filter using ``Filters.SL``; ``range`` is sent as ``(lo,hi)``."""
        return self.filter(column, Filters.SL, f"({range[0]},{range[1]})")

    def sr(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Range filter using ``Filters.SR``; ``range`` is sent as ``(lo,hi)``."""
        return self.filter(column, Filters.SR, f"({range[0]},{range[1]})")

    def nxl(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Range filter using ``Filters.NXL``; ``range`` is sent as ``(lo,hi)``."""
        return self.filter(column, Filters.NXL, f"({range[0]},{range[1]})")

    def nxr(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Range filter using ``Filters.NXR``; ``range`` is sent as ``(lo,hi)``."""
        return self.filter(column, Filters.NXR, f"({range[0]},{range[1]})")

    def adj(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Range filter using ``Filters.ADJ``; ``range`` is sent as ``(lo,hi)``."""
        return self.filter(column, Filters.ADJ, f"({range[0]},{range[1]})")

    def range_gt(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Alias for :meth:`sr`."""
        return self.sr(column, range)

    def range_gte(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Alias for :meth:`nxl`."""
        return self.nxl(column, range)

    def range_lt(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Alias for :meth:`sl`."""
        return self.sl(column, range)

    def range_lte(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Alias for :meth:`nxr`."""
        return self.nxr(column, range)

    def range_adjacent(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Alias for :meth:`adj`."""
        return self.adj(column, range)

    def overlaps(self: Self, column: str, values: Iterable[Any]) -> Self:
        """Alias for :meth:`ov`."""
        return self.ov(column, values)

    def match(self: Self, query: Dict[str, Any]) -> Self:
        """Apply an :meth:`eq` filter for every column/value pair in ``query``.

        Raises:
            ValueError: If ``query`` is empty.
        """
        updated_query = self

        if not query:
            raise ValueError(
                "query dictionary should contain at least one key-value pair"
            )

        for key, value in query.items():
            updated_query = self.eq(key, value)

        return updated_query


class BaseSelectRequestBuilder(BaseFilterRequestBuilder[_ReturnT]):
    """Filter builder extended with plan, ordering and pagination modifiers."""

    def __init__(
        self,
        session: Union[AsyncClient, SyncClient],
        headers: Headers,
        params: QueryParams,
    ) -> None:
        # Generic[T] is an instance of typing._GenericAlias, so doing Generic[T].__init__
        # tries to call _GenericAlias.__init__ - which is the wrong method
        # The __origin__ attribute of the _GenericAlias is the actual class
        get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__(
            self, session, headers, params
        )

    def explain(
        self: Self,
        analyze: bool = False,
        verbose: bool = False,
        settings: bool = False,
        buffers: bool = False,
        wal: bool = False,
        format: Literal["text", "json"] = "text",
    ) -> Self:
        """Request the query plan via ``Accept: application/vnd.pgrst.plan+<format>``.

        Each truthy flag is added to the ``options=`` list, joined with ``|``
        in the order analyze, verbose, settings, buffers, wal.

        Args:
            analyze: Include the ``analyze`` option.
            verbose: Include the ``verbose`` option.
            settings: Include the ``settings`` option.
            buffers: Include the ``buffers`` option.
            wal: Include the ``wal`` option.
            format: Plan media-type suffix, ``"text"`` or ``"json"``.
        """
        # Explicit mapping instead of locals() introspection, so adding a local
        # variable to this method cannot leak into the options list.
        flags = {
            "analyze": analyze,
            "verbose": verbose,
            "settings": settings,
            "buffers": buffers,
            "wal": wal,
        }
        options_str = "|".join(name for name, enabled in flags.items() if enabled)
        self.headers["Accept"] = (
            f"application/vnd.pgrst.plan+{format}; options={options_str}"
        )
        return self

    def order(
        self: Self,
        column: str,
        *,
        desc: bool = False,
        nullsfirst: bool = False,
        foreign_table: Optional[str] = None,
    ) -> Self:
        """Sort the returned rows in some specific order.

        Repeated calls append to the existing ``order`` parameter, so earlier
        calls take precedence as sort keys.

        Args:
            column: The column to order by
            desc: Whether the rows should be ordered in descending order or not.
            nullsfirst: nullsfirst
            foreign_table: Foreign table name whose results are to be ordered.
        .. versionchanged:: 0.10.3
           Allow ordering results for foreign tables with the foreign_table parameter.
        """
        # NOTE(review): with foreign_table this emits ``order=<table>(<column>)``
        # on the top-level ``order`` param, i.e. it presumably sorts the parent
        # rows by a related column rather than sorting the embedded rows
        # (which would be ``<table>.order=<column>``) — confirm intent.
        target = f"{foreign_table}({column})" if foreign_table else column
        new_order_parameter = (
            f"{target}{'.desc' if desc else ''}{'.nullsfirst' if nullsfirst else ''}"
        )

        existing_order_parameter = self.params.get("order")
        if existing_order_parameter:
            self.params = self.params.remove("order")
            new_order_parameter = f"{existing_order_parameter},{new_order_parameter}"

        self.params = self.params.add(
            "order",
            new_order_parameter,
        )
        return self

    def limit(self: Self, size: int, *, foreign_table: Optional[str] = None) -> Self:
        """Limit the number of rows returned by a query.

        Args:
            size: The number of rows to be returned
            foreign_table: Foreign table name to limit
        .. versionchanged:: 0.10.3
           Allow limiting results returned for foreign tables with the foreign_table parameter.
        """
        self.params = self.params.add(
            f"{foreign_table}.limit" if foreign_table else "limit",
            size,
        )
        return self

    def offset(self: Self, size: int) -> Self:
        """Set the starting row index returned by a query.

        Args:
            size: The number of the row to start at
        """
        self.params = self.params.add(
            "offset",
            size,
        )
        return self

    def range(
        self: Self, start: int, end: int, foreign_table: Optional[str] = None
    ) -> Self:
        """Return rows ``start`` through ``end`` inclusive.

        Sets ``offset=start`` and ``limit=end - start + 1``, prefixed with
        ``<foreign_table>.`` when given.
        """
        self.params = self.params.add(
            f"{foreign_table}.offset" if foreign_table else "offset", start
        )
        self.params = self.params.add(
            f"{foreign_table}.limit" if foreign_table else "limit",
            end - start + 1,
        )
        return self


class BaseRPCRequestBuilder(BaseSelectRequestBuilder[_ReturnT]):
    """Select builder for RPC calls, adding column selection and output modes."""

    def __init__(
        self,
        session: Union[AsyncClient, SyncClient],
        headers: Headers,
        params: QueryParams,
    ) -> None:
        # Generic[T] is an instance of typing._GenericAlias, so doing Generic[T].__init__
        # tries to call _GenericAlias.__init__ - which is the wrong method
        # The __origin__ attribute of the _GenericAlias is the actual class
        get_origin_and_cast(BaseSelectRequestBuilder[_ReturnT]).__init__(
            self, session, headers, params
        )

    def select(
        self,
        *columns: str,
    ) -> Self:
        """Run a SELECT query.

        Adds the cleaned ``select`` param and sets
        ``Prefer: return=representation``.

        Args:
            *columns: The names of the columns to fetch.
        Returns:
            This builder (``self``), for chaining.
        """
        select_params = pre_select(*columns, count=None).params
        self.params = self.params.add("select", select_params.get("select"))
        self.headers["Prefer"] = "return=representation"
        return self

    def single(self) -> Self:
        """Specify that the query will only return a single row in response.

        .. caution::
            The API will raise an error if the query returned more than one row.
        """
        self.headers["Accept"] = "application/vnd.pgrst.object+json"
        return self

    def maybe_single(self) -> Self:
        """Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error."""
        # NOTE(review): sets the same Accept header as single(); any tolerance
        # for zero rows must be implemented elsewhere (not visible here) — confirm.
        self.headers["Accept"] = "application/vnd.pgrst.object+json"
        return self

    def csv(self) -> Self:
        """Specify that the query must retrieve data as a single CSV string."""
        self.headers["Accept"] = "text/csv"
        return self