From 4a52a71956a8d46fcb7294ac71734504bb09bcc2 Mon Sep 17 00:00:00 2001 From: S. Solomon Darnell Date: Fri, 28 Mar 2025 21:52:21 -0500 Subject: two version of R2R are here --- .../site-packages/shared/abstractions/base.py | 145 +++++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 .venv/lib/python3.12/site-packages/shared/abstractions/base.py (limited to '.venv/lib/python3.12/site-packages/shared/abstractions/base.py') diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/base.py b/.venv/lib/python3.12/site-packages/shared/abstractions/base.py new file mode 100644 index 00000000..d90ba400 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/base.py @@ -0,0 +1,145 @@ +import asyncio +import json +from datetime import datetime +from enum import Enum +from typing import Any, Type, TypeVar +from uuid import UUID + +from pydantic import BaseModel + +T = TypeVar("T", bound="R2RSerializable") + + +class R2RSerializable(BaseModel): + @classmethod + def from_dict(cls: Type[T], data: dict[str, Any] | str) -> T: + if isinstance(data, str): + try: + data_dict = json.loads(data) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON string: {e}") from e + else: + data_dict = data + return cls(**data_dict) + + def as_dict(self) -> dict[str, Any]: + data = self.model_dump(exclude_unset=True) + return self._serialize_values(data) + + def to_dict(self) -> dict[str, Any]: + data = self.model_dump(exclude_unset=True) + return self._serialize_values(data) + + def to_json(self) -> str: + data = self.to_dict() + return json.dumps(data) + + @classmethod + def from_json(cls: Type[T], json_str: str) -> T: + return cls.model_validate_json(json_str) + + @staticmethod + def _serialize_values(data: Any) -> Any: + if isinstance(data, dict): + return { + k: R2RSerializable._serialize_values(v) + for k, v in data.items() + } + elif isinstance(data, list): + return [R2RSerializable._serialize_values(v) for v in data] + elif isinstance(data, UUID): + return str(data) + elif isinstance(data, Enum): + return data.value + elif isinstance(data, datetime): + return data.isoformat() + else: + return data + + class Config: + arbitrary_types_allowed = True + json_encoders = { + UUID: str, + bytes: lambda v: v.decode("utf-8", errors="ignore"), + } + + +class AsyncSyncMeta(type): + _event_loop = None # Class-level shared event loop + + @classmethod + def get_event_loop(cls): + if cls._event_loop is None or cls._event_loop.is_closed(): + cls._event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(cls._event_loop) + return cls._event_loop + + def __new__(cls, name, bases, dct): + new_cls = super().__new__(cls, name, bases, dct) + for attr_name, attr_value in dct.items(): + if asyncio.iscoroutinefunction(attr_value) and getattr( + attr_value, "_syncable", False + ): + sync_method_name = attr_name[ + 1: + ] # Remove leading 'a' for sync method + async_method = attr_value + + def make_sync_method(async_method): + def sync_wrapper(self, *args, **kwargs): + loop = cls.get_event_loop() + if not loop.is_running(): + # Setup to run the loop in a background thread if necessary + # to prevent blocking the main thread in a synchronous call environment + from threading import Thread + + result = None + exception = None + + def run(): + nonlocal result, exception + try: + asyncio.set_event_loop(loop) + result = loop.run_until_complete( + async_method(self, *args, **kwargs) + ) + except Exception as e: + exception = e + finally: + generation_config = kwargs.get( + "rag_generation_config", None + ) + if ( + not generation_config + or not generation_config.stream + ): + loop.run_until_complete( + loop.shutdown_asyncgens() + ) + loop.close() + + thread = Thread(target=run) + thread.start() + thread.join() + if exception: + raise exception + return result + else: + # If there's already a running loop, schedule and execute the coroutine + future = asyncio.run_coroutine_threadsafe( + async_method(self, *args, **kwargs), loop + ) + return future.result() + + return sync_wrapper + + setattr( + new_cls, sync_method_name, make_sync_method(async_method) + ) + return new_cls + + +def syncable(func): + """Decorator to mark methods for synchronous wrapper creation.""" + func._syncable = True + return func -- cgit v1.2.3