about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/shared/abstractions/base.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/shared/abstractions/base.py')
-rw-r--r--.venv/lib/python3.12/site-packages/shared/abstractions/base.py145
1 files changed, 145 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/base.py b/.venv/lib/python3.12/site-packages/shared/abstractions/base.py
new file mode 100644
index 00000000..d90ba400
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/base.py
@@ -0,0 +1,145 @@
+import asyncio
+import json
+from datetime import datetime
+from enum import Enum
+from typing import Any, Type, TypeVar
+from uuid import UUID
+
+from pydantic import BaseModel
+
+T = TypeVar("T", bound="R2RSerializable")
+
+
+class R2RSerializable(BaseModel):
+    @classmethod
+    def from_dict(cls: Type[T], data: dict[str, Any] | str) -> T:
+        if isinstance(data, str):
+            try:
+                data_dict = json.loads(data)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Invalid JSON string: {e}") from e
+        else:
+            data_dict = data
+        return cls(**data_dict)
+
+    def as_dict(self) -> dict[str, Any]:
+        data = self.model_dump(exclude_unset=True)
+        return self._serialize_values(data)
+
+    def to_dict(self) -> dict[str, Any]:
+        data = self.model_dump(exclude_unset=True)
+        return self._serialize_values(data)
+
+    def to_json(self) -> str:
+        data = self.to_dict()
+        return json.dumps(data)
+
+    @classmethod
+    def from_json(cls: Type[T], json_str: str) -> T:
+        return cls.model_validate_json(json_str)
+
+    @staticmethod
+    def _serialize_values(data: Any) -> Any:
+        if isinstance(data, dict):
+            return {
+                k: R2RSerializable._serialize_values(v)
+                for k, v in data.items()
+            }
+        elif isinstance(data, list):
+            return [R2RSerializable._serialize_values(v) for v in data]
+        elif isinstance(data, UUID):
+            return str(data)
+        elif isinstance(data, Enum):
+            return data.value
+        elif isinstance(data, datetime):
+            return data.isoformat()
+        else:
+            return data
+
+    class Config:
+        arbitrary_types_allowed = True
+        json_encoders = {
+            UUID: str,
+            bytes: lambda v: v.decode("utf-8", errors="ignore"),
+        }
+
+
+class AsyncSyncMeta(type):
+    _event_loop = None  # Class-level shared event loop
+
+    @classmethod
+    def get_event_loop(cls):
+        if cls._event_loop is None or cls._event_loop.is_closed():
+            cls._event_loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(cls._event_loop)
+        return cls._event_loop
+
+    def __new__(cls, name, bases, dct):
+        new_cls = super().__new__(cls, name, bases, dct)
+        for attr_name, attr_value in dct.items():
+            if asyncio.iscoroutinefunction(attr_value) and getattr(
+                attr_value, "_syncable", False
+            ):
+                sync_method_name = attr_name[
+                    1:
+                ]  # Remove leading 'a' for sync method
+                async_method = attr_value
+
+                def make_sync_method(async_method):
+                    def sync_wrapper(self, *args, **kwargs):
+                        loop = cls.get_event_loop()
+                        if not loop.is_running():
+                            # Setup to run the loop in a background thread if necessary
+                            # to prevent blocking the main thread in a synchronous call environment
+                            from threading import Thread
+
+                            result = None
+                            exception = None
+
+                            def run():
+                                nonlocal result, exception
+                                try:
+                                    asyncio.set_event_loop(loop)
+                                    result = loop.run_until_complete(
+                                        async_method(self, *args, **kwargs)
+                                    )
+                                except Exception as e:
+                                    exception = e
+                                finally:
+                                    generation_config = kwargs.get(
+                                        "rag_generation_config", None
+                                    )
+                                    if (
+                                        not generation_config
+                                        or not generation_config.stream
+                                    ):
+                                        loop.run_until_complete(
+                                            loop.shutdown_asyncgens()
+                                        )
+                                        loop.close()
+
+                            thread = Thread(target=run)
+                            thread.start()
+                            thread.join()
+                            if exception:
+                                raise exception
+                            return result
+                        else:
+                            # If there's already a running loop, schedule and execute the coroutine
+                            future = asyncio.run_coroutine_threadsafe(
+                                async_method(self, *args, **kwargs), loop
+                            )
+                            return future.result()
+
+                    return sync_wrapper
+
+                setattr(
+                    new_cls, sync_method_name, make_sync_method(async_method)
+                )
+        return new_cls
+
+
+def syncable(func):
+    """Decorator to mark methods for synchronous wrapper creation."""
+    func._syncable = True
+    return func