from __future__ import annotations
import json
from json import JSONDecodeError
from re import search
from typing import (
Any,
Dict,
Generic,
Iterable,
List,
Literal,
NamedTuple,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from httpx import Headers, QueryParams
from httpx import Response as RequestResponse
from pydantic import BaseModel
try:
from typing import Self
except ImportError:
from typing_extensions import Self
try:
# >= 2.0.0
from pydantic import field_validator
except ImportError:
# < 2.0.0
from pydantic import validator as field_validator
from .types import CountMethod, Filters, RequestMethod, ReturnMethod
from .utils import AsyncClient, SyncClient, get_origin_and_cast, sanitize_param
class QueryArgs(NamedTuple):
    """Group the method, params, headers and JSON body of a query in one object."""

    # HTTP method the request is sent with
    method: RequestMethod
    # URL query-string parameters
    params: QueryParams
    # request headers (the pre_* helpers put their "Prefer" directives here)
    headers: Headers
    # request body
    # NOTE(review): pre_insert/pre_upsert forward their `json` argument
    # unchanged, and that argument may be a list rather than a dict
    json: Dict[Any, Any]
def _unique_columns(json: List[Dict]):
unique_keys = {key for row in json for key in row.keys()}
columns = ",".join([f'"{k}"' for k in unique_keys])
return columns
def _cleaned_columns(columns: Tuple[str, ...]) -> str:
quoted = False
cleaned = []
for column in columns:
clean_column = ""
for char in column:
if char.isspace() and not quoted:
continue
if char == '"':
quoted = not quoted
clean_column += char
cleaned.append(clean_column)
return ",".join(cleaned)
def pre_select(
    *columns: str,
    count: Optional[CountMethod] = None,
    head: Optional[bool] = None,
) -> QueryArgs:
    """Assemble the request arguments for a select query.

    Args:
        *columns: The columns to select; when none are given, ``*`` is used.
        count: If truthy, sent as ``count=<value>`` in the ``Prefer`` header.
        head: If truthy the HEAD method is used, otherwise GET.

    Returns:
        The method, query parameters, headers and an empty JSON body.
    """
    if head:
        http_method = RequestMethod.HEAD
    else:
        http_method = RequestMethod.GET

    select_value = _cleaned_columns(columns or "*")
    query = QueryParams({"select": select_value})

    if count:
        request_headers = Headers({"Prefer": f"count={count}"})
    else:
        request_headers = Headers()

    return QueryArgs(http_method, query, request_headers, {})
def pre_insert(
    json: Union[dict, list],
    *,
    count: Optional[CountMethod],
    returning: ReturnMethod,
    upsert: bool,
    default_to_null: bool = True,
) -> QueryArgs:
    """Assemble the request arguments for an insert query.

    Args:
        json: The row (dict) or rows (list) to send as the request body.
        count: If truthy, adds ``count=<value>`` to the ``Prefer`` header.
        returning: Sent as ``return=<value>`` in the ``Prefer`` header.
        upsert: If true, adds ``resolution=merge-duplicates``.
        default_to_null: If false, adds ``missing=default``.

    Returns:
        A POST :class:`QueryArgs`; when ``json`` is a list, a ``columns``
        query parameter listing the keys of its rows is included.
    """
    preferences = [f"return={returning}"]
    optional_directives = (
        (count, f"count={count}"),
        (upsert, "resolution=merge-duplicates"),
        (not default_to_null, "missing=default"),
    )
    preferences.extend(
        directive for enabled, directive in optional_directives if enabled
    )
    request_headers = Headers({"Prefer": ",".join(preferences)})

    # The 'columns' query parameter is only sent for bulk (list) payloads.
    if isinstance(json, list):
        columns_param = {"columns": _unique_columns(json)}
    else:
        columns_param = {}

    return QueryArgs(
        RequestMethod.POST, QueryParams(columns_param), request_headers, json
    )
def pre_upsert(
    json: Union[dict, list],
    *,
    count: Optional[CountMethod],
    returning: ReturnMethod,
    ignore_duplicates: bool,
    on_conflict: str = "",
    default_to_null: bool = True,
) -> QueryArgs:
    """Assemble the request arguments for an upsert query.

    Args:
        json: The row (dict) or rows (list) to send as the request body.
        count: If truthy, adds ``count=<value>`` to the ``Prefer`` header.
        returning: Sent as ``return=<value>`` in the ``Prefer`` header.
        ignore_duplicates: Selects ``resolution=ignore-duplicates`` when true,
            ``resolution=merge-duplicates`` otherwise.
        on_conflict: If non-empty, sent as the ``on_conflict`` query parameter.
        default_to_null: If false, adds ``missing=default``.

    Returns:
        A POST :class:`QueryArgs`; when ``json`` is a list, a ``columns``
        query parameter listing the keys of its rows is included.
    """
    preferences = [f"return={returning}"]
    if count:
        preferences.append(f"count={count}")
    if ignore_duplicates:
        preferences.append("resolution=ignore-duplicates")
    else:
        preferences.append("resolution=merge-duplicates")
    if not default_to_null:
        preferences.append("missing=default")
    request_headers = Headers({"Prefer": ",".join(preferences)})

    upsert_params = {}
    if on_conflict:
        upsert_params["on_conflict"] = on_conflict
    # The 'columns' query parameter is only sent for bulk (list) payloads.
    if isinstance(json, list):
        upsert_params["columns"] = _unique_columns(json)

    return QueryArgs(
        RequestMethod.POST, QueryParams(upsert_params), request_headers, json
    )
def pre_update(
    json: dict,
    *,
    count: Optional[CountMethod],
    returning: ReturnMethod,
) -> QueryArgs:
    """Assemble the request arguments for an update query.

    Args:
        json: The values to send as the request body.
        count: If truthy, adds ``count=<value>`` to the ``Prefer`` header.
        returning: Sent as ``return=<value>`` in the ``Prefer`` header.

    Returns:
        A PATCH :class:`QueryArgs` with no query parameters.
    """
    preferences = [f"return={returning}"]
    if count:
        preferences += [f"count={count}"]
    prefer_value = ",".join(preferences)
    return QueryArgs(
        RequestMethod.PATCH,
        QueryParams(),
        Headers({"Prefer": prefer_value}),
        json,
    )
def pre_delete(
    *,
    count: Optional[CountMethod],
    returning: ReturnMethod,
) -> QueryArgs:
    """Assemble the request arguments for a delete query.

    Args:
        count: If truthy, adds ``count=<value>`` to the ``Prefer`` header.
        returning: Sent as ``return=<value>`` in the ``Prefer`` header.

    Returns:
        A DELETE :class:`QueryArgs` with no query parameters and an empty body.
    """
    preferences = [f"return={returning}"]
    if count:
        preferences += [f"count={count}"]
    prefer_value = ",".join(preferences)
    return QueryArgs(
        RequestMethod.DELETE,
        QueryParams(),
        Headers({"Prefer": prefer_value}),
        {},
    )
_ReturnT = TypeVar("_ReturnT")
# APIResponse.data is typed as List[_ReturnT]; SingleAPIResponse overrides it
# with a bare _ReturnT, because some calls (e.g. rpc()) do not necessarily
# return lists.
# https://github.com/supabase-community/postgrest-py/issues/200
class APIResponse(BaseModel, Generic[_ReturnT]):
    """Model holding the payload of a response and, optionally, a row count.

    Instances are normally built with :meth:`from_http_request_response`
    (from an HTTP response) or :meth:`from_dict` (from a plain mapping).
    """

    data: List[_ReturnT]
    """The data returned by the query."""
    count: Optional[int] = None
    """The number of rows returned."""

    @field_validator("data")
    @classmethod
    def raise_when_api_error(cls: Type[Self], value: Any) -> Any:
        """Reject an API error payload passed as ``data``.

        Raises:
            ValueError: If ``value`` is a dict with a truthy ``"message"`` entry.
        """
        if isinstance(value, dict) and value.get("message"):
            raise ValueError("You are passing an API error to the data field.")
        return value

    @staticmethod
    def _get_count_from_content_range_header(
        content_range_header: str,
    ) -> Optional[int]:
        """Extract the total from a ``<range>/<total>`` Content-Range value.

        Returns:
            The integer after the first ``/``, or ``None`` when the header has
            no ``/`` or the total is not a number (e.g. ``0-9/*``, which marks
            an unknown total).
        """
        content_range = content_range_header.split("/")
        if len(content_range) < 2:
            return None
        try:
            return int(content_range[1])
        except ValueError:
            # An unknown total is reported as "*"; treat it as "no count"
            # rather than failing the whole response parse.
            return None

    @staticmethod
    def _is_count_in_prefer_header(prefer_header: str) -> bool:
        """Return whether ``prefer_header`` contains ``count=<CountMethod value>``."""
        pattern = f"count=({'|'.join([cm.value for cm in CountMethod])})"
        return bool(search(pattern, prefer_header))

    @classmethod
    def _get_count_from_http_request_response(
        cls: Type[Self],
        request_response: RequestResponse,
    ) -> Optional[int]:
        """Return the count reported by the response, if one was requested.

        A count is only read when the *request's* ``prefer`` header asked for
        one and the response carries a ``content-range`` header; otherwise
        ``None`` is returned.
        """
        prefer_header: Optional[str] = request_response.request.headers.get("prefer")
        if not prefer_header:
            return None
        is_count_in_prefer_header = cls._is_count_in_prefer_header(prefer_header)
        content_range_header: Optional[str] = request_response.headers.get(
            "content-range"
        )
        return (
            cls._get_count_from_content_range_header(content_range_header)
            if (is_count_in_prefer_header and content_range_header)
            else None
        )

    @classmethod
    def from_http_request_response(
        cls: Type[Self], request_response: RequestResponse
    ) -> Self:
        """Build a response object from an HTTP response.

        The body is parsed as JSON; if that fails, the raw text is used as the
        data, or an empty list when the body is empty.
        """
        count = cls._get_count_from_http_request_response(request_response)
        try:
            data = request_response.json()
        except JSONDecodeError:
            data = request_response.text if len(request_response.text) > 0 else []
        # the type-ignore here is as pydantic needs us to pass the type parameter
        # here explicitly, but pylance already knows that cls is correctly parametrized
        return cls[_ReturnT](data=data, count=count)  # type: ignore

    @classmethod
    def from_dict(cls: Type[Self], dict: Dict[str, Any]) -> Self:
        """Build a response object from a mapping.

        Args:
            dict: Must contain exactly the keys ``data``, ``count`` and ``error``.

        Raises:
            AssertionError: If the keys do not match (note that assertions are
                skipped when Python runs with ``-O``).
        """
        keys = dict.keys()
        assert len(keys) == 3 and "data" in keys and "count" in keys and "error" in keys
        # NOTE(review): ``error`` is forwarded although this model declares no
        # ``error`` field; how it is treated depends on the pydantic model
        # configuration -- confirm.
        return cls[_ReturnT](  # type: ignore
            data=dict.get("data"), count=dict.get("count"), error=dict.get("error")
        )
class SingleAPIResponse(APIResponse[_ReturnT], Generic[_ReturnT]):
    """Variant of :class:`APIResponse` whose ``data`` is a single ``_ReturnT``."""

    data: _ReturnT  # type: ignore
    """The data returned by the query."""

    @classmethod
    def from_http_request_response(
        cls: Type[Self], request_response: RequestResponse
    ) -> Self:
        """Build a response object from an HTTP response.

        The body is parsed as JSON; if that fails, the raw text is used as the
        data, or an empty list when the body is empty.
        """
        row_count = cls._get_count_from_http_request_response(request_response)
        try:
            payload = request_response.json()
        except JSONDecodeError:
            body_text = request_response.text
            payload = body_text if body_text else []
        return cls[_ReturnT](data=payload, count=row_count)  # type: ignore

    @classmethod
    def from_dict(cls: Type[Self], dict: Dict[str, Any]) -> Self:
        """Build a response object from a mapping.

        Args:
            dict: Must contain exactly the keys ``data``, ``count`` and ``error``.
        """
        expected_keys = ("data", "count", "error")
        present_keys = dict.keys()
        assert len(present_keys) == 3 and all(
            name in present_keys for name in expected_keys
        )
        field_values = {name: dict.get(name) for name in expected_keys}
        return cls[_ReturnT](**field_values)  # type: ignore
class BaseFilterRequestBuilder(Generic[_ReturnT]):
    """Base for request builders that accumulate filter query parameters.

    Every filter method records its condition in ``self.params`` and returns
    the builder, so calls can be chained.
    """

    def __init__(
        self,
        session: Union[AsyncClient, SyncClient],
        headers: Headers,
        params: QueryParams,
    ) -> None:
        """Store the session, headers and query parameters of the request."""
        self.session = session
        self.headers = headers
        self.params = params
        # Set by ``not_``; consumed (and reset) by the next ``filter`` call.
        self.negate_next = False

    @property
    def not_(self: Self) -> Self:
        """Mark the next filter applied through :meth:`filter` as negated."""
        self.negate_next = True
        return self

    def filter(self: Self, column: str, operator: str, criteria: str) -> Self:
        """Apply filters on a query.

        Adds ``<column>=<operator>.<criteria>`` to the query parameters; the
        operator is prefixed with ``Filters.NOT`` when ``not_`` was used.

        Args:
            column: The name of the column to apply a filter on
            operator: The operator to use while filtering
            criteria: The value to filter by
        """
        if self.negate_next is True:
            self.negate_next = False
            operator = f"{Filters.NOT}.{operator}"
        key, val = sanitize_param(column), f"{operator}.{criteria}"
        self.params = self.params.add(key, val)
        return self

    def _container_filter(self: Self, column: str, operator: str, value: Any) -> Self:
        """Shared formatting for :meth:`contains`, :meth:`contained_by`, :meth:`ov`.

        * ``str`` values are passed through untouched (range types can be
          inclusive '[', ']' or exclusive '(', ')', so a string is accepted
          as-is).
        * Other non-dict iterables are joined with commas and wrapped in
          braces; their elements must be strings.
        * Anything else (notably dicts) is serialised with ``json.dumps``.
        """
        if isinstance(value, str):
            return self.filter(column, operator, value)
        if not isinstance(value, dict) and isinstance(value, Iterable):
            stringified_values = ",".join(value)
            return self.filter(column, operator, f"{{{stringified_values}}}")
        return self.filter(column, operator, json.dumps(value))

    def _range_filter(
        self: Self, column: str, operator: str, bounds: Tuple[int, int]
    ) -> Self:
        """Apply ``operator`` with ``bounds`` rendered as ``(<first>,<second>)``."""
        return self.filter(column, operator, f"({bounds[0]},{bounds[1]})")

    def eq(self: Self, column: str, value: Any) -> Self:
        """An 'equal to' filter.

        Args:
            column: The name of the column to apply a filter on
            value: The value to filter by
        """
        return self.filter(column, Filters.EQ, value)

    def neq(self: Self, column: str, value: Any) -> Self:
        """A 'not equal to' filter

        Args:
            column: The name of the column to apply a filter on
            value: The value to filter by
        """
        return self.filter(column, Filters.NEQ, value)

    def gt(self: Self, column: str, value: Any) -> Self:
        """A 'greater than' filter

        Args:
            column: The name of the column to apply a filter on
            value: The value to filter by
        """
        return self.filter(column, Filters.GT, value)

    def gte(self: Self, column: str, value: Any) -> Self:
        """A 'greater than or equal to' filter

        Args:
            column: The name of the column to apply a filter on
            value: The value to filter by
        """
        return self.filter(column, Filters.GTE, value)

    def lt(self: Self, column: str, value: Any) -> Self:
        """A 'less than' filter

        Args:
            column: The name of the column to apply a filter on
            value: The value to filter by
        """
        return self.filter(column, Filters.LT, value)

    def lte(self: Self, column: str, value: Any) -> Self:
        """A 'less than or equal to' filter

        Args:
            column: The name of the column to apply a filter on
            value: The value to filter by
        """
        return self.filter(column, Filters.LTE, value)

    def is_(self: Self, column: str, value: Any) -> Self:
        """An 'is' filter

        Args:
            column: The name of the column to apply a filter on
            value: The value to filter by; ``None`` is sent as ``null``
        """
        if value is None:
            value = "null"
        return self.filter(column, Filters.IS, value)

    def like(self: Self, column: str, pattern: str) -> Self:
        """A 'LIKE' filter, to use for pattern matching.

        Args:
            column: The name of the column to apply a filter on
            pattern: The pattern to filter by
        """
        return self.filter(column, Filters.LIKE, pattern)

    def like_all_of(self: Self, column: str, pattern: str) -> Self:
        """A ``Filters.LIKE_ALL`` filter, to use for pattern matching.

        Args:
            column: The name of the column to apply a filter on
            pattern: The pattern text, sent wrapped in braces as ``{pattern}``
                (presumably a comma-separated list of patterns -- confirm)
        """
        return self.filter(column, Filters.LIKE_ALL, f"{{{pattern}}}")

    def like_any_of(self: Self, column: str, pattern: str) -> Self:
        """A ``Filters.LIKE_ANY`` filter, to use for pattern matching.

        Args:
            column: The name of the column to apply a filter on
            pattern: The pattern text, sent wrapped in braces as ``{pattern}``
                (presumably a comma-separated list of patterns -- confirm)
        """
        return self.filter(column, Filters.LIKE_ANY, f"{{{pattern}}}")

    def ilike_all_of(self: Self, column: str, pattern: str) -> Self:
        """A ``Filters.ILIKE_ALL`` filter, to use for pattern matching (case insensitive).

        Args:
            column: The name of the column to apply a filter on
            pattern: The pattern text, sent wrapped in braces as ``{pattern}``
                (presumably a comma-separated list of patterns -- confirm)
        """
        return self.filter(column, Filters.ILIKE_ALL, f"{{{pattern}}}")

    def ilike_any_of(self: Self, column: str, pattern: str) -> Self:
        """A ``Filters.ILIKE_ANY`` filter, to use for pattern matching (case insensitive).

        Args:
            column: The name of the column to apply a filter on
            pattern: The pattern text, sent wrapped in braces as ``{pattern}``
                (presumably a comma-separated list of patterns -- confirm)
        """
        return self.filter(column, Filters.ILIKE_ANY, f"{{{pattern}}}")

    def ilike(self: Self, column: str, pattern: str) -> Self:
        """An 'ILIKE' filter, to use for pattern matching (case insensitive).

        Args:
            column: The name of the column to apply a filter on
            pattern: The pattern to filter by
        """
        return self.filter(column, Filters.ILIKE, pattern)

    def or_(self: Self, filters: str, reference_table: Optional[str] = None) -> Self:
        """An 'or' filter

        Unlike the other filters this writes the parameter directly, so a
        pending ``not_`` is neither applied nor cleared here.

        Args:
            filters: The filters to use, following PostgREST syntax
            reference_table: Set this to filter on referenced tables instead of the parent table
        """
        key = f"{sanitize_param(reference_table)}.or" if reference_table else "or"
        self.params = self.params.add(key, f"({filters})")
        return self

    def fts(self: Self, column: str, query: Any) -> Self:
        """Apply the ``Filters.FTS`` operator to ``column`` with ``query``."""
        return self.filter(column, Filters.FTS, query)

    def plfts(self: Self, column: str, query: Any) -> Self:
        """Apply the ``Filters.PLFTS`` operator to ``column`` with ``query``."""
        return self.filter(column, Filters.PLFTS, query)

    def phfts(self: Self, column: str, query: Any) -> Self:
        """Apply the ``Filters.PHFTS`` operator to ``column`` with ``query``."""
        return self.filter(column, Filters.PHFTS, query)

    def wfts(self: Self, column: str, query: Any) -> Self:
        """Apply the ``Filters.WFTS`` operator to ``column`` with ``query``."""
        return self.filter(column, Filters.WFTS, query)

    def in_(self: Self, column: str, values: Iterable[Any]) -> Self:
        """Apply the ``Filters.IN`` operator.

        Each value is passed through ``sanitize_param``; the results are
        joined by commas and wrapped in parentheses.
        """
        joined_values = ",".join(map(sanitize_param, values))
        return self.filter(column, Filters.IN, f"({joined_values})")

    def cs(self: Self, column: str, values: Iterable[Any]) -> Self:
        """Apply the ``Filters.CS`` operator with ``values`` as ``{v1,v2,...}``.

        The values are combined with ``str.join`` and must therefore be strings.
        """
        joined_values = ",".join(values)
        return self.filter(column, Filters.CS, f"{{{joined_values}}}")

    def cd(self: Self, column: str, values: Iterable[Any]) -> Self:
        """Apply the ``Filters.CD`` operator with ``values`` as ``{v1,v2,...}``.

        The values are combined with ``str.join`` and must therefore be strings.
        """
        joined_values = ",".join(values)
        return self.filter(column, Filters.CD, f"{{{joined_values}}}")

    def contains(
        self: Self, column: str, value: Union[Iterable[Any], str, Dict[Any, Any]]
    ) -> Self:
        """Apply the ``Filters.CS`` operator.

        Args:
            column: The name of the column to apply a filter on
            value: A string (sent as-is), an iterable of strings (sent as
                ``{a,b,...}``) or a dict (sent as JSON)
        """
        return self._container_filter(column, Filters.CS, value)

    def contained_by(
        self: Self, column: str, value: Union[Iterable[Any], str, Dict[Any, Any]]
    ) -> Self:
        """Apply the ``Filters.CD`` operator.

        Args:
            column: The name of the column to apply a filter on
            value: A string (sent as-is), an iterable of strings (sent as
                ``{a,b,...}``) or a dict (sent as JSON)
        """
        return self._container_filter(column, Filters.CD, value)

    def ov(self: Self, column: str, value: Iterable[Any]) -> Self:
        """Apply the ``Filters.OV`` operator.

        Args:
            column: The name of the column to apply a filter on
            value: A string (sent as-is), an iterable of strings (sent as
                ``{a,b,...}``) or a dict (sent as JSON)
        """
        return self._container_filter(column, Filters.OV, value)

    def sl(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Apply the ``Filters.SL`` operator with ``range`` as ``(start,end)``."""
        return self._range_filter(column, Filters.SL, range)

    def sr(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Apply the ``Filters.SR`` operator with ``range`` as ``(start,end)``."""
        return self._range_filter(column, Filters.SR, range)

    def nxl(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Apply the ``Filters.NXL`` operator with ``range`` as ``(start,end)``."""
        return self._range_filter(column, Filters.NXL, range)

    def nxr(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Apply the ``Filters.NXR`` operator with ``range`` as ``(start,end)``."""
        return self._range_filter(column, Filters.NXR, range)

    def adj(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Apply the ``Filters.ADJ`` operator with ``range`` as ``(start,end)``."""
        return self._range_filter(column, Filters.ADJ, range)

    def range_gt(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Alias of :meth:`sr`."""
        return self.sr(column, range)

    def range_gte(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Alias of :meth:`nxl`."""
        return self.nxl(column, range)

    def range_lt(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Alias of :meth:`sl`."""
        return self.sl(column, range)

    def range_lte(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Alias of :meth:`nxr`."""
        return self.nxr(column, range)

    def range_adjacent(self: Self, column: str, range: Tuple[int, int]) -> Self:
        """Alias of :meth:`adj`."""
        return self.adj(column, range)

    def overlaps(self: Self, column: str, values: Iterable[Any]) -> Self:
        """Alias of :meth:`ov`."""
        return self.ov(column, values)

    def match(self: Self, query: Dict[str, Any]) -> Self:
        """Apply an :meth:`eq` filter for every key/value pair of ``query``.

        Raises:
            ValueError: If ``query`` is empty.
        """
        if not query:
            raise ValueError(
                "query dictionary should contain at least one key-value pair"
            )
        updated_query = self
        for key, value in query.items():
            # eq() returns the builder it was called on; keep whatever it
            # hands back so the result of the final call is what we return.
            updated_query = self.eq(key, value)
        return updated_query
class BaseSelectRequestBuilder(BaseFilterRequestBuilder[_ReturnT]):
    """Filter builder extended with explain, ordering and pagination modifiers."""

    def __init__(
        self,
        session: Union[AsyncClient, SyncClient],
        headers: Headers,
        params: QueryParams,
    ) -> None:
        # Generic[T] is an instance of typing._GenericAlias, so doing Generic[T].__init__
        # tries to call _GenericAlias.__init__ - which is the wrong method
        # The __origin__ attribute of the _GenericAlias is the actual class
        get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__(
            self, session, headers, params
        )

    def explain(
        self: Self,
        analyze: bool = False,
        verbose: bool = False,
        settings: bool = False,
        buffers: bool = False,
        wal: bool = False,
        format: Literal["text", "json"] = "text",
    ) -> Self:
        """Request the execution plan of the query through the ``Accept`` header.

        Sets ``Accept`` to
        ``application/vnd.pgrst.plan+<format>; options=<enabled flags>``, where
        the enabled flags are joined with ``|``.

        Args:
            analyze: Include the ``analyze`` option.
            verbose: Include the ``verbose`` option.
            settings: Include the ``settings`` option.
            buffers: Include the ``buffers`` option.
            wal: Include the ``wal`` option.
            format: Plan format placed in the media type.
        """
        # Listed explicitly (in signature order) rather than scraped from
        # locals(), so adding a local variable can never leak into the header.
        flags = {
            "analyze": analyze,
            "verbose": verbose,
            "settings": settings,
            "buffers": buffers,
            "wal": wal,
        }
        options_str = "|".join(name for name, enabled in flags.items() if enabled)
        self.headers["Accept"] = (
            f"application/vnd.pgrst.plan+{format}; options={options_str}"
        )
        return self

    def order(
        self: Self,
        column: str,
        *,
        desc: bool = False,
        nullsfirst: bool = False,
        foreign_table: Optional[str] = None,
    ) -> Self:
        """Sort the returned rows in some specific order.

        Repeated calls are combined into a single comma-separated ``order``
        parameter, in call order.

        Args:
            column: The column to order by
            desc: Whether the rows should be ordered in descending order or not.
            nullsfirst: Whether to append the ``.nullsfirst`` modifier
            foreign_table: Foreign table name whose results are to be ordered.
        .. versionchanged:: 0.10.3
           Allow ordering results for foreign tables with the foreign_table parameter.
        """
        target = f"{foreign_table}({column})" if foreign_table else column
        new_order_parameter = (
            f"{target}{'.desc' if desc else ''}{'.nullsfirst' if nullsfirst else ''}"
        )

        existing_order_parameter = self.params.get("order")
        if existing_order_parameter:
            # Replace the existing value with "<existing>,<new>".
            self.params = self.params.remove("order")
            new_order_parameter = f"{existing_order_parameter},{new_order_parameter}"

        self.params = self.params.add(
            "order",
            new_order_parameter,
        )
        return self

    def limit(self: Self, size: int, *, foreign_table: Optional[str] = None) -> Self:
        """Limit the number of rows returned by a query.

        Args:
            size: The number of rows to be returned
            foreign_table: Foreign table name to limit
        .. versionchanged:: 0.10.3
           Allow limiting results returned for foreign tables with the foreign_table parameter.
        """
        self.params = self.params.add(
            f"{foreign_table}.limit" if foreign_table else "limit",
            size,
        )
        return self

    def offset(self: Self, size: int) -> Self:
        """Set the starting row index returned by a query.

        Args:
            size: The number of the row to start at
        """
        self.params = self.params.add(
            "offset",
            size,
        )
        return self

    def range(
        self: Self, start: int, end: int, foreign_table: Optional[str] = None
    ) -> Self:
        """Restrict the result to the rows from ``start`` to ``end`` inclusive.

        Adds an ``offset`` of ``start`` and a ``limit`` of ``end - start + 1``.

        Args:
            start: Index of the first row to return
            end: Index of the last row to return
            foreign_table: If given, the parameters are prefixed with
                ``<foreign_table>.``
        """
        self.params = self.params.add(
            f"{foreign_table}.offset" if foreign_table else "offset", start
        )
        self.params = self.params.add(
            f"{foreign_table}.limit" if foreign_table else "limit",
            end - start + 1,
        )
        return self
class BaseRPCRequestBuilder(BaseSelectRequestBuilder[_ReturnT]):
    """Select builder with column selection and response-format modifiers."""

    def __init__(
        self,
        session: Union[AsyncClient, SyncClient],
        headers: Headers,
        params: QueryParams,
    ) -> None:
        # BaseSelectRequestBuilder[_ReturnT] is a typing alias, not the class
        # itself; resolve the real class first and call its __init__.
        parent_class = get_origin_and_cast(BaseSelectRequestBuilder[_ReturnT])
        parent_class.__init__(self, session, headers, params)

    def select(
        self,
        *columns: str,
    ) -> Self:
        """Run a SELECT query.

        Adds a ``select`` query parameter built from ``columns`` and sets the
        ``Prefer`` header to ``return=representation``.

        Args:
            *columns: The names of the columns to fetch.
        Returns:
            :class:`BaseSelectRequestBuilder`
        """
        select_args = pre_select(*columns, count=None)
        select_value = select_args.params.get("select")
        self.params = self.params.add("select", select_value)
        self.headers["Prefer"] = "return=representation"
        return self

    def single(self) -> Self:
        """Ask for the result as a single object instead of a list.

        .. caution::
            The API will raise an error if the query returned more than one row.
        """
        accept_object = "application/vnd.pgrst.object+json"
        self.headers["Accept"] = accept_object
        return self

    def maybe_single(self) -> Self:
        """Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error."""
        accept_object = "application/vnd.pgrst.object+json"
        self.headers["Accept"] = accept_object
        return self

    def csv(self) -> Self:
        """Ask for the result as a single CSV string."""
        accept_csv = "text/csv"
        self.headers["Accept"] = accept_csv
        return self