about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/postgrest/base_request_builder.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/postgrest/base_request_builder.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/postgrest/base_request_builder.py')
-rw-r--r--.venv/lib/python3.12/site-packages/postgrest/base_request_builder.py685
1 files changed, 685 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/postgrest/base_request_builder.py b/.venv/lib/python3.12/site-packages/postgrest/base_request_builder.py
new file mode 100644
index 00000000..7b5ab4b7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/postgrest/base_request_builder.py
@@ -0,0 +1,685 @@
+from __future__ import annotations
+
+import json
+from json import JSONDecodeError
+from re import search
+from typing import (
+    Any,
+    Dict,
+    Generic,
+    Iterable,
+    List,
+    Literal,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
+
+from httpx import Headers, QueryParams
+from httpx import Response as RequestResponse
+from pydantic import BaseModel
+
+try:
+    from typing import Self
+except ImportError:
+    from typing_extensions import Self
+
+try:
+    # >= 2.0.0
+    from pydantic import field_validator
+except ImportError:
+    # < 2.0.0
+    from pydantic import validator as field_validator
+
+from .types import CountMethod, Filters, RequestMethod, ReturnMethod
+from .utils import AsyncClient, SyncClient, get_origin_and_cast, sanitize_param
+
+
+class QueryArgs(NamedTuple):
+    # groups the method, json, headers and params for a query in a single object
+    method: RequestMethod
+    params: QueryParams
+    headers: Headers
+    json: Dict[Any, Any]
+
+
+def _unique_columns(json: List[Dict]):
+    unique_keys = {key for row in json for key in row.keys()}
+    columns = ",".join([f'"{k}"' for k in unique_keys])
+    return columns
+
+
+def _cleaned_columns(columns: Tuple[str, ...]) -> str:
+    quoted = False
+    cleaned = []
+
+    for column in columns:
+        clean_column = ""
+        for char in column:
+            if char.isspace() and not quoted:
+                continue
+            if char == '"':
+                quoted = not quoted
+            clean_column += char
+        cleaned.append(clean_column)
+
+    return ",".join(cleaned)
+
+
+def pre_select(
+    *columns: str,
+    count: Optional[CountMethod] = None,
+    head: Optional[bool] = None,
+) -> QueryArgs:
+    method = RequestMethod.HEAD if head else RequestMethod.GET
+    cleaned_columns = _cleaned_columns(columns or "*")
+    params = QueryParams({"select": cleaned_columns})
+
+    headers = Headers({"Prefer": f"count={count}"}) if count else Headers()
+    return QueryArgs(method, params, headers, {})
+
+
+def pre_insert(
+    json: Union[dict, list],
+    *,
+    count: Optional[CountMethod],
+    returning: ReturnMethod,
+    upsert: bool,
+    default_to_null: bool = True,
+) -> QueryArgs:
+    prefer_headers = [f"return={returning}"]
+    if count:
+        prefer_headers.append(f"count={count}")
+    if upsert:
+        prefer_headers.append("resolution=merge-duplicates")
+    if not default_to_null:
+        prefer_headers.append("missing=default")
+    headers = Headers({"Prefer": ",".join(prefer_headers)})
+    # Adding 'columns' query parameters
+    query_params = {}
+    if isinstance(json, list):
+        query_params = {"columns": _unique_columns(json)}
+    return QueryArgs(RequestMethod.POST, QueryParams(query_params), headers, json)
+
+
+def pre_upsert(
+    json: Union[dict, list],
+    *,
+    count: Optional[CountMethod],
+    returning: ReturnMethod,
+    ignore_duplicates: bool,
+    on_conflict: str = "",
+    default_to_null: bool = True,
+) -> QueryArgs:
+    query_params = {}
+    prefer_headers = [f"return={returning}"]
+    if count:
+        prefer_headers.append(f"count={count}")
+    resolution = "ignore" if ignore_duplicates else "merge"
+    prefer_headers.append(f"resolution={resolution}-duplicates")
+    if not default_to_null:
+        prefer_headers.append("missing=default")
+    headers = Headers({"Prefer": ",".join(prefer_headers)})
+    if on_conflict:
+        query_params["on_conflict"] = on_conflict
+    # Adding 'columns' query parameters
+    if isinstance(json, list):
+        query_params["columns"] = _unique_columns(json)
+    return QueryArgs(RequestMethod.POST, QueryParams(query_params), headers, json)
+
+
+def pre_update(
+    json: dict,
+    *,
+    count: Optional[CountMethod],
+    returning: ReturnMethod,
+) -> QueryArgs:
+    prefer_headers = [f"return={returning}"]
+    if count:
+        prefer_headers.append(f"count={count}")
+    headers = Headers({"Prefer": ",".join(prefer_headers)})
+    return QueryArgs(RequestMethod.PATCH, QueryParams(), headers, json)
+
+
+def pre_delete(
+    *,
+    count: Optional[CountMethod],
+    returning: ReturnMethod,
+) -> QueryArgs:
+    prefer_headers = [f"return={returning}"]
+    if count:
+        prefer_headers.append(f"count={count}")
+    headers = Headers({"Prefer": ",".join(prefer_headers)})
+    return QueryArgs(RequestMethod.DELETE, QueryParams(), headers, {})
+
+
+_ReturnT = TypeVar("_ReturnT")
+
+
+# the APIResponse.data is marked as _ReturnT instead of list[_ReturnT]
+# as it is also returned in the case of rpc() calls; and rpc calls do not
+# necessarily return lists.
+# https://github.com/supabase-community/postgrest-py/issues/200
+class APIResponse(BaseModel, Generic[_ReturnT]):
+    data: List[_ReturnT]
+    """The data returned by the query."""
+    count: Optional[int] = None
+    """The number of rows returned."""
+
+    @field_validator("data")
+    @classmethod
+    def raise_when_api_error(cls: Type[Self], value: Any) -> Any:
+        if isinstance(value, dict) and value.get("message"):
+            raise ValueError("You are passing an API error to the data field.")
+        return value
+
+    @staticmethod
+    def _get_count_from_content_range_header(
+        content_range_header: str,
+    ) -> Optional[int]:
+        content_range = content_range_header.split("/")
+        return None if len(content_range) < 2 else int(content_range[1])
+
+    @staticmethod
+    def _is_count_in_prefer_header(prefer_header: str) -> bool:
+        pattern = f"count=({'|'.join([cm.value for cm in CountMethod])})"
+        return bool(search(pattern, prefer_header))
+
+    @classmethod
+    def _get_count_from_http_request_response(
+        cls: Type[Self],
+        request_response: RequestResponse,
+    ) -> Optional[int]:
+        prefer_header: Optional[str] = request_response.request.headers.get("prefer")
+        if not prefer_header:
+            return None
+        is_count_in_prefer_header = cls._is_count_in_prefer_header(prefer_header)
+        content_range_header: Optional[str] = request_response.headers.get(
+            "content-range"
+        )
+        return (
+            cls._get_count_from_content_range_header(content_range_header)
+            if (is_count_in_prefer_header and content_range_header)
+            else None
+        )
+
+    @classmethod
+    def from_http_request_response(
+        cls: Type[Self], request_response: RequestResponse
+    ) -> Self:
+        count = cls._get_count_from_http_request_response(request_response)
+        try:
+            data = request_response.json()
+        except JSONDecodeError:
+            data = request_response.text if len(request_response.text) > 0 else []
+        # the type-ignore here is as pydantic needs us to pass the type parameter
+        # here explicitly, but pylance already knows that cls is correctly parametrized
+        return cls[_ReturnT](data=data, count=count)  # type: ignore
+
+    @classmethod
+    def from_dict(cls: Type[Self], dict: Dict[str, Any]) -> Self:
+        keys = dict.keys()
+        assert len(keys) == 3 and "data" in keys and "count" in keys and "error" in keys
+        return cls[_ReturnT](  # type: ignore
+            data=dict.get("data"), count=dict.get("count"), error=dict.get("error")
+        )
+
+
+class SingleAPIResponse(APIResponse[_ReturnT], Generic[_ReturnT]):
+    data: _ReturnT  # type: ignore
+    """The data returned by the query."""
+
+    @classmethod
+    def from_http_request_response(
+        cls: Type[Self], request_response: RequestResponse
+    ) -> Self:
+        count = cls._get_count_from_http_request_response(request_response)
+        try:
+            data = request_response.json()
+        except JSONDecodeError:
+            data = request_response.text if len(request_response.text) > 0 else []
+        return cls[_ReturnT](data=data, count=count)  # type: ignore
+
+    @classmethod
+    def from_dict(cls: Type[Self], dict: Dict[str, Any]) -> Self:
+        keys = dict.keys()
+        assert len(keys) == 3 and "data" in keys and "count" in keys and "error" in keys
+        return cls[_ReturnT](  # type: ignore
+            data=dict.get("data"), count=dict.get("count"), error=dict.get("error")
+        )
+
+
+class BaseFilterRequestBuilder(Generic[_ReturnT]):
+    def __init__(
+        self,
+        session: Union[AsyncClient, SyncClient],
+        headers: Headers,
+        params: QueryParams,
+    ) -> None:
+        self.session = session
+        self.headers = headers
+        self.params = params
+        self.negate_next = False
+
+    @property
+    def not_(self: Self) -> Self:
+        """Whether the filter applied next should be negated."""
+        self.negate_next = True
+        return self
+
+    def filter(self: Self, column: str, operator: str, criteria: str) -> Self:
+        """Apply filters on a query.
+
+        Args:
+            column: The name of the column to apply a filter on
+            operator: The operator to use while filtering
+            criteria: The value to filter by
+        """
+        if self.negate_next is True:
+            self.negate_next = False
+            operator = f"{Filters.NOT}.{operator}"
+        key, val = sanitize_param(column), f"{operator}.{criteria}"
+        self.params = self.params.add(key, val)
+        return self
+
+    def eq(self: Self, column: str, value: Any) -> Self:
+        """An 'equal to' filter.
+
+        Args:
+            column: The name of the column to apply a filter on
+            value: The value to filter by
+        """
+        return self.filter(column, Filters.EQ, value)
+
+    def neq(self: Self, column: str, value: Any) -> Self:
+        """A 'not equal to' filter
+
+        Args:
+            column: The name of the column to apply a filter on
+            value: The value to filter by
+        """
+        return self.filter(column, Filters.NEQ, value)
+
+    def gt(self: Self, column: str, value: Any) -> Self:
+        """A 'greater than' filter
+
+        Args:
+            column: The name of the column to apply a filter on
+            value: The value to filter by
+        """
+        return self.filter(column, Filters.GT, value)
+
+    def gte(self: Self, column: str, value: Any) -> Self:
+        """A 'greater than or equal to' filter
+
+        Args:
+            column: The name of the column to apply a filter on
+            value: The value to filter by
+        """
+        return self.filter(column, Filters.GTE, value)
+
+    def lt(self: Self, column: str, value: Any) -> Self:
+        """A 'less than' filter
+
+        Args:
+            column: The name of the column to apply a filter on
+            value: The value to filter by
+        """
+        return self.filter(column, Filters.LT, value)
+
+    def lte(self: Self, column: str, value: Any) -> Self:
+        """A 'less than or equal to' filter
+
+        Args:
+            column: The name of the column to apply a filter on
+            value: The value to filter by
+        """
+        return self.filter(column, Filters.LTE, value)
+
+    def is_(self: Self, column: str, value: Any) -> Self:
+        """An 'is' filter
+
+        Args:
+            column: The name of the column to apply a filter on
+            value: The value to filter by
+        """
+        if value is None:
+            value = "null"
+        return self.filter(column, Filters.IS, value)
+
+    def like(self: Self, column: str, pattern: str) -> Self:
+        """A 'LIKE' filter, to use for pattern matching.
+
+        Args:
+            column: The name of the column to apply a filter on
+            pattern: The pattern to filter by
+        """
+        return self.filter(column, Filters.LIKE, pattern)
+
+    def like_all_of(self: Self, column: str, pattern: str) -> Self:
+        """A 'LIKE' filter, to use for pattern matching.
+
+        Args:
+            column: The name of the column to apply a filter on
+            pattern: The pattern to filter by
+        """
+
+        return self.filter(column, Filters.LIKE_ALL, f"{{{pattern}}}")
+
+    def like_any_of(self: Self, column: str, pattern: str) -> Self:
+        """A 'LIKE' filter, to use for pattern matching.
+
+        Args:
+            column: The name of the column to apply a filter on
+            pattern: The pattern to filter by
+        """
+
+        return self.filter(column, Filters.LIKE_ANY, f"{{{pattern}}}")
+
+    def ilike_all_of(self: Self, column: str, pattern: str) -> Self:
+        """A 'ILIKE' filter, to use for pattern matching (case insensitive).
+
+        Args:
+            column: The name of the column to apply a filter on
+            pattern: The pattern to filter by
+        """
+
+        return self.filter(column, Filters.ILIKE_ALL, f"{{{pattern}}}")
+
+    def ilike_any_of(self: Self, column: str, pattern: str) -> Self:
+        """A 'ILIKE' filter, to use for pattern matching (case insensitive).
+
+        Args:
+            column: The name of the column to apply a filter on
+            pattern: The pattern to filter by
+        """
+
+        return self.filter(column, Filters.ILIKE_ANY, f"{{{pattern}}}")
+
+    def ilike(self: Self, column: str, pattern: str) -> Self:
+        """An 'ILIKE' filter, to use for pattern matching (case insensitive).
+
+        Args:
+            column: The name of the column to apply a filter on
+            pattern: The pattern to filter by
+        """
+        return self.filter(column, Filters.ILIKE, pattern)
+
+    def or_(self: Self, filters: str, reference_table: Optional[str] = None) -> Self:
+        """An 'or' filter
+
+        Args:
+            filters: The filters to use, following PostgREST syntax
+            reference_table: Set this to filter on referenced tables instead of the parent table
+        """
+        key = f"{sanitize_param(reference_table)}.or" if reference_table else "or"
+        self.params = self.params.add(key, f"({filters})")
+        return self
+
+    def fts(self: Self, column: str, query: Any) -> Self:
+        return self.filter(column, Filters.FTS, query)
+
+    def plfts(self: Self, column: str, query: Any) -> Self:
+        return self.filter(column, Filters.PLFTS, query)
+
+    def phfts(self: Self, column: str, query: Any) -> Self:
+        return self.filter(column, Filters.PHFTS, query)
+
+    def wfts(self: Self, column: str, query: Any) -> Self:
+        return self.filter(column, Filters.WFTS, query)
+
+    def in_(self: Self, column: str, values: Iterable[Any]) -> Self:
+        values = map(sanitize_param, values)
+        values = ",".join(values)
+        return self.filter(column, Filters.IN, f"({values})")
+
+    def cs(self: Self, column: str, values: Iterable[Any]) -> Self:
+        values = ",".join(values)
+        return self.filter(column, Filters.CS, f"{{{values}}}")
+
+    def cd(self: Self, column: str, values: Iterable[Any]) -> Self:
+        values = ",".join(values)
+        return self.filter(column, Filters.CD, f"{{{values}}}")
+
+    def contains(
+        self: Self, column: str, value: Union[Iterable[Any], str, Dict[Any, Any]]
+    ) -> Self:
+        if isinstance(value, str):
+            # range types can be inclusive '[', ']' or exclusive '(', ')' so just
+            # keep it simple and accept a string
+            return self.filter(column, Filters.CS, value)
+        if not isinstance(value, dict) and isinstance(value, Iterable):
+            # Expected to be some type of iterable
+            stringified_values = ",".join(value)
+            return self.filter(column, Filters.CS, f"{{{stringified_values}}}")
+
+        return self.filter(column, Filters.CS, json.dumps(value))
+
+    def contained_by(
+        self: Self, column: str, value: Union[Iterable[Any], str, Dict[Any, Any]]
+    ) -> Self:
+        if isinstance(value, str):
+            # range
+            return self.filter(column, Filters.CD, value)
+        if not isinstance(value, dict) and isinstance(value, Iterable):
+            stringified_values = ",".join(value)
+            return self.filter(column, Filters.CD, f"{{{stringified_values}}}")
+        return self.filter(column, Filters.CD, json.dumps(value))
+
+    def ov(self: Self, column: str, value: Iterable[Any]) -> Self:
+        if isinstance(value, str):
+            # range types can be inclusive '[', ']' or exclusive '(', ')' so just
+            # keep it simple and accept a string
+            return self.filter(column, Filters.OV, value)
+        if not isinstance(value, dict) and isinstance(value, Iterable):
+            # Expected to be some type of iterable
+            stringified_values = ",".join(value)
+            return self.filter(column, Filters.OV, f"{{{stringified_values}}}")
+        return self.filter(column, Filters.OV, json.dumps(value))
+
+    def sl(self: Self, column: str, range: Tuple[int, int]) -> Self:
+        return self.filter(column, Filters.SL, f"({range[0]},{range[1]})")
+
+    def sr(self: Self, column: str, range: Tuple[int, int]) -> Self:
+        return self.filter(column, Filters.SR, f"({range[0]},{range[1]})")
+
+    def nxl(self: Self, column: str, range: Tuple[int, int]) -> Self:
+        return self.filter(column, Filters.NXL, f"({range[0]},{range[1]})")
+
+    def nxr(self: Self, column: str, range: Tuple[int, int]) -> Self:
+        return self.filter(column, Filters.NXR, f"({range[0]},{range[1]})")
+
+    def adj(self: Self, column: str, range: Tuple[int, int]) -> Self:
+        return self.filter(column, Filters.ADJ, f"({range[0]},{range[1]})")
+
+    def range_gt(self: Self, column: str, range: Tuple[int, int]) -> Self:
+        return self.sr(column, range)
+
+    def range_gte(self: Self, column: str, range: Tuple[int, int]) -> Self:
+        return self.nxl(column, range)
+
+    def range_lt(self: Self, column: str, range: Tuple[int, int]) -> Self:
+        return self.sl(column, range)
+
+    def range_lte(self: Self, column: str, range: Tuple[int, int]) -> Self:
+        return self.nxr(column, range)
+
+    def range_adjacent(self: Self, column: str, range: Tuple[int, int]) -> Self:
+        return self.adj(column, range)
+
+    def overlaps(self: Self, column: str, values: Iterable[Any]) -> Self:
+        return self.ov(column, values)
+
+    def match(self: Self, query: Dict[str, Any]) -> Self:
+        updated_query = self
+
+        if not query:
+            raise ValueError(
+                "query dictionary should contain at least one key-value pair"
+            )
+
+        for key, value in query.items():
+            updated_query = self.eq(key, value)
+
+        return updated_query
+
+
+class BaseSelectRequestBuilder(BaseFilterRequestBuilder[_ReturnT]):
+    def __init__(
+        self,
+        session: Union[AsyncClient, SyncClient],
+        headers: Headers,
+        params: QueryParams,
+    ) -> None:
+        # Generic[T] is an instance of typing._GenericAlias, so doing Generic[T].__init__
+        # tries to call _GenericAlias.__init__ - which is the wrong method
+        # The __origin__ attribute of the _GenericAlias is the actual class
+        get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__(
+            self, session, headers, params
+        )
+
+    def explain(
+        self: Self,
+        analyze: bool = False,
+        verbose: bool = False,
+        settings: bool = False,
+        buffers: bool = False,
+        wal: bool = False,
+        format: Literal["text", "json"] = "text",
+    ) -> Self:
+        options = [
+            key
+            for key, value in locals().items()
+            if key not in ["self", "format"] and value
+        ]
+        options_str = "|".join(options)
+        self.headers["Accept"] = (
+            f"application/vnd.pgrst.plan+{format}; options={options_str}"
+        )
+        return self
+
+    def order(
+        self: Self,
+        column: str,
+        *,
+        desc: bool = False,
+        nullsfirst: bool = False,
+        foreign_table: Optional[str] = None,
+    ) -> Self:
+        """Sort the returned rows in some specific order.
+
+        Args:
+            column: The column to order by
+            desc: Whether the rows should be ordered in descending order or not.
+            nullsfirst: nullsfirst
+            foreign_table: Foreign table name whose results are to be ordered.
+        .. versionchanged:: 0.10.3
+           Allow ordering results for foreign tables with the foreign_table parameter.
+        """
+
+        new_order_parameter = (
+            f"{foreign_table + '(' if foreign_table else ''}{column}{')' if foreign_table else ''}"
+            f"{'.desc' if desc else ''}{'.nullsfirst' if nullsfirst else ''}"
+        )
+
+        existing_order_parameter = self.params.get("order")
+        if existing_order_parameter:
+            self.params = self.params.remove("order")
+            new_order_parameter = f"{existing_order_parameter},{new_order_parameter}"
+
+        self.params = self.params.add(
+            "order",
+            new_order_parameter,
+        )
+        return self
+
+    def limit(self: Self, size: int, *, foreign_table: Optional[str] = None) -> Self:
+        """Limit the number of rows returned by a query.
+
+        Args:
+            size: The number of rows to be returned
+            foreign_table: Foreign table name to limit
+        .. versionchanged:: 0.10.3
+           Allow limiting results returned for foreign tables with the foreign_table parameter.
+        """
+        self.params = self.params.add(
+            f"{foreign_table}.limit" if foreign_table else "limit",
+            size,
+        )
+        return self
+
+    def offset(self: _FilterT, size: int) -> _FilterT:
+        """Set the starting row index returned by a query.
+        Args:
+            size: The number of the row to start at
+        """
+        self.params = self.params.add(
+            "offset",
+            size,
+        )
+        return self
+
+    def range(
+        self: Self, start: int, end: int, foreign_table: Optional[str] = None
+    ) -> Self:
+        self.params = self.params.add(
+            f"{foreign_table}.offset" if foreign_table else "offset", start
+        )
+        self.params = self.params.add(
+            f"{foreign_table}.limit" if foreign_table else "limit",
+            end - start + 1,
+        )
+        return self
+
+
+class BaseRPCRequestBuilder(BaseSelectRequestBuilder[_ReturnT]):
+    def __init__(
+        self,
+        session: Union[AsyncClient, SyncClient],
+        headers: Headers,
+        params: QueryParams,
+    ) -> None:
+        # Generic[T] is an instance of typing._GenericAlias, so doing Generic[T].__init__
+        # tries to call _GenericAlias.__init__ - which is the wrong method
+        # The __origin__ attribute of the _GenericAlias is the actual class
+        get_origin_and_cast(BaseSelectRequestBuilder[_ReturnT]).__init__(
+            self, session, headers, params
+        )
+
+    def select(
+        self,
+        *columns: str,
+    ) -> Self:
+        """Run a SELECT query.
+
+        Args:
+            *columns: The names of the columns to fetch.
+        Returns:
+            :class:`BaseSelectRequestBuilder`
+        """
+        method, params, headers, json = pre_select(*columns, count=None)
+        self.params = self.params.add("select", params.get("select"))
+        self.headers["Prefer"] = "return=representation"
+        return self
+
+    def single(self) -> Self:
+        """Specify that the query will only return a single row in response.
+
+        .. caution::
+            The API will raise an error if the query returned more than one row.
+        """
+        self.headers["Accept"] = "application/vnd.pgrst.object+json"
+        return self
+
+    def maybe_single(self) -> Self:
+        """Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error."""
+        self.headers["Accept"] = "application/vnd.pgrst.object+json"
+        return self
+
+    def csv(self) -> Self:
+        """Specify that the query must retrieve data as a single CSV string."""
+        self.headers["Accept"] = "text/csv"
+        return self