about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/starlette/schemas.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/starlette/schemas.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/schemas.py')
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/schemas.py147
1 files changed, 147 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/schemas.py b/.venv/lib/python3.12/site-packages/starlette/schemas.py
new file mode 100644
index 00000000..bfc40e2a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/schemas.py
@@ -0,0 +1,147 @@
+from __future__ import annotations
+
+import inspect
+import re
+import typing
+
+from starlette.requests import Request
+from starlette.responses import Response
+from starlette.routing import BaseRoute, Host, Mount, Route
+
+try:
+    import yaml
+except ModuleNotFoundError:  # pragma: no cover
+    yaml = None  # type: ignore[assignment]
+
+
+class OpenAPIResponse(Response):
+    media_type = "application/vnd.oai.openapi"
+
+    def render(self, content: typing.Any) -> bytes:
+        assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
+        assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary."
+        return yaml.dump(content, default_flow_style=False).encode("utf-8")
+
+
+class EndpointInfo(typing.NamedTuple):
+    path: str
+    http_method: str
+    func: typing.Callable[..., typing.Any]
+
+
+_remove_converter_pattern = re.compile(r":\w+}")
+
+
+class BaseSchemaGenerator:
+    def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]:
+        raise NotImplementedError()  # pragma: no cover
+
+    def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
+        """
+        Given the routes, yields the following information:
+
+        - path
+            eg: /users/
+        - http_method
+            one of 'get', 'post', 'put', 'patch', 'delete', 'options'
+        - func
+            method ready to extract the docstring
+        """
+        endpoints_info: list[EndpointInfo] = []
+
+        for route in routes:
+            if isinstance(route, (Mount, Host)):
+                routes = route.routes or []
+                if isinstance(route, Mount):
+                    path = self._remove_converter(route.path)
+                else:
+                    path = ""
+                sub_endpoints = [
+                    EndpointInfo(
+                        path="".join((path, sub_endpoint.path)),
+                        http_method=sub_endpoint.http_method,
+                        func=sub_endpoint.func,
+                    )
+                    for sub_endpoint in self.get_endpoints(routes)
+                ]
+                endpoints_info.extend(sub_endpoints)
+
+            elif not isinstance(route, Route) or not route.include_in_schema:
+                continue
+
+            elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint):
+                path = self._remove_converter(route.path)
+                for method in route.methods or ["GET"]:
+                    if method == "HEAD":
+                        continue
+                    endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint))
+            else:
+                path = self._remove_converter(route.path)
+                for method in ["get", "post", "put", "patch", "delete", "options"]:
+                    if not hasattr(route.endpoint, method):
+                        continue
+                    func = getattr(route.endpoint, method)
+                    endpoints_info.append(EndpointInfo(path, method.lower(), func))
+
+        return endpoints_info
+
+    def _remove_converter(self, path: str) -> str:
+        """
+        Remove the converter from the path.
+        For example, a route like this:
+            Route("/users/{id:int}", endpoint=get_user, methods=["GET"])
+        Should be represented as `/users/{id}` in the OpenAPI schema.
+        """
+        return _remove_converter_pattern.sub("}", path)
+
+    def parse_docstring(self, func_or_method: typing.Callable[..., typing.Any]) -> dict[str, typing.Any]:
+        """
+        Given a function, parse the docstring as YAML and return a dictionary of info.
+        """
+        docstring = func_or_method.__doc__
+        if not docstring:
+            return {}
+
+        assert yaml is not None, "`pyyaml` must be installed to use parse_docstring."
+
+        # We support having regular docstrings before the schema
+        # definition. Here we return just the schema part from
+        # the docstring.
+        docstring = docstring.split("---")[-1]
+
+        parsed = yaml.safe_load(docstring)
+
+        if not isinstance(parsed, dict):
+            # A regular docstring (not yaml formatted) can return
+            # a simple string here, which wouldn't follow the schema.
+            return {}
+
+        return parsed
+
+    def OpenAPIResponse(self, request: Request) -> Response:
+        routes = request.app.routes
+        schema = self.get_schema(routes=routes)
+        return OpenAPIResponse(schema)
+
+
+class SchemaGenerator(BaseSchemaGenerator):
+    def __init__(self, base_schema: dict[str, typing.Any]) -> None:
+        self.base_schema = base_schema
+
+    def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]:
+        schema = dict(self.base_schema)
+        schema.setdefault("paths", {})
+        endpoints_info = self.get_endpoints(routes)
+
+        for endpoint in endpoints_info:
+            parsed = self.parse_docstring(endpoint.func)
+
+            if not parsed:
+                continue
+
+            if endpoint.path not in schema["paths"]:
+                schema["paths"][endpoint.path] = {}
+
+            schema["paths"][endpoint.path][endpoint.http_method] = parsed
+
+        return schema