aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/shared/abstractions/base.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/shared/abstractions/base.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/shared/abstractions/base.py')
-rw-r--r--.venv/lib/python3.12/site-packages/shared/abstractions/base.py145
1 files changed, 145 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/base.py b/.venv/lib/python3.12/site-packages/shared/abstractions/base.py
new file mode 100644
index 00000000..d90ba400
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/base.py
@@ -0,0 +1,145 @@
+import asyncio
+import json
+from datetime import datetime
+from enum import Enum
+from typing import Any, Type, TypeVar
+from uuid import UUID
+
+from pydantic import BaseModel
+
+T = TypeVar("T", bound="R2RSerializable")
+
+
+class R2RSerializable(BaseModel):
+ @classmethod
+ def from_dict(cls: Type[T], data: dict[str, Any] | str) -> T:
+ if isinstance(data, str):
+ try:
+ data_dict = json.loads(data)
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Invalid JSON string: {e}") from e
+ else:
+ data_dict = data
+ return cls(**data_dict)
+
+ def as_dict(self) -> dict[str, Any]:
+ data = self.model_dump(exclude_unset=True)
+ return self._serialize_values(data)
+
+ def to_dict(self) -> dict[str, Any]:
+ data = self.model_dump(exclude_unset=True)
+ return self._serialize_values(data)
+
+ def to_json(self) -> str:
+ data = self.to_dict()
+ return json.dumps(data)
+
+ @classmethod
+ def from_json(cls: Type[T], json_str: str) -> T:
+ return cls.model_validate_json(json_str)
+
+ @staticmethod
+ def _serialize_values(data: Any) -> Any:
+ if isinstance(data, dict):
+ return {
+ k: R2RSerializable._serialize_values(v)
+ for k, v in data.items()
+ }
+ elif isinstance(data, list):
+ return [R2RSerializable._serialize_values(v) for v in data]
+ elif isinstance(data, UUID):
+ return str(data)
+ elif isinstance(data, Enum):
+ return data.value
+ elif isinstance(data, datetime):
+ return data.isoformat()
+ else:
+ return data
+
+ class Config:
+ arbitrary_types_allowed = True
+ json_encoders = {
+ UUID: str,
+ bytes: lambda v: v.decode("utf-8", errors="ignore"),
+ }
+
+
+class AsyncSyncMeta(type):
+ _event_loop = None # Class-level shared event loop
+
+ @classmethod
+ def get_event_loop(cls):
+ if cls._event_loop is None or cls._event_loop.is_closed():
+ cls._event_loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(cls._event_loop)
+ return cls._event_loop
+
+ def __new__(cls, name, bases, dct):
+ new_cls = super().__new__(cls, name, bases, dct)
+ for attr_name, attr_value in dct.items():
+ if asyncio.iscoroutinefunction(attr_value) and getattr(
+ attr_value, "_syncable", False
+ ):
+ sync_method_name = attr_name[
+ 1:
+ ] # Remove leading 'a' for sync method
+ async_method = attr_value
+
+ def make_sync_method(async_method):
+ def sync_wrapper(self, *args, **kwargs):
+ loop = cls.get_event_loop()
+ if not loop.is_running():
+ # Setup to run the loop in a background thread if necessary
+ # to prevent blocking the main thread in a synchronous call environment
+ from threading import Thread
+
+ result = None
+ exception = None
+
+ def run():
+ nonlocal result, exception
+ try:
+ asyncio.set_event_loop(loop)
+ result = loop.run_until_complete(
+ async_method(self, *args, **kwargs)
+ )
+ except Exception as e:
+ exception = e
+ finally:
+ generation_config = kwargs.get(
+ "rag_generation_config", None
+ )
+ if (
+ not generation_config
+ or not generation_config.stream
+ ):
+ loop.run_until_complete(
+ loop.shutdown_asyncgens()
+ )
+ loop.close()
+
+ thread = Thread(target=run)
+ thread.start()
+ thread.join()
+ if exception:
+ raise exception
+ return result
+ else:
+ # If there's already a running loop, schedule and execute the coroutine
+ future = asyncio.run_coroutine_threadsafe(
+ async_method(self, *args, **kwargs), loop
+ )
+ return future.result()
+
+ return sync_wrapper
+
+ setattr(
+ new_cls, sync_method_name, make_sync_method(async_method)
+ )
+ return new_cls
+
+
+def syncable(func):
+ """Decorator to mark methods for synchronous wrapper creation."""
+ func._syncable = True
+ return func