import asyncio
import json
from datetime import datetime
from enum import Enum
from typing import Any, Type, TypeVar
from uuid import UUID
from pydantic import BaseModel
# Type variable bound to R2RSerializable so alternate constructors such as
# ``from_dict``/``from_json`` are typed as returning the calling (sub)class.
T = TypeVar("T", bound="R2RSerializable")
class R2RSerializable(BaseModel):
    """Pydantic base model with dict/JSON round-trip helpers.

    ``to_dict``/``as_dict`` dump only explicitly set fields
    (``exclude_unset=True``) and recursively turn ``UUID``, ``Enum`` and
    ``datetime`` values into JSON-friendly primitives.
    """

    @classmethod
    def from_dict(cls: Type[T], data: dict[str, Any] | str) -> T:
        """Build an instance from a dict or a JSON-encoded object.

        Args:
            data: Field values, either as a dict or as a JSON string whose
                top-level value is an object.

        Returns:
            A new instance of ``cls``.

        Raises:
            ValueError: If ``data`` is a string that is not valid JSON, or
                whose top-level JSON value is not an object.
        """
        if isinstance(data, str):
            try:
                data_dict = json.loads(data)
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON string: {e}") from e
            # json.loads may also return a list, number, string, ...;
            # unpacking those with ``**`` would fail with an opaque TypeError.
            if not isinstance(data_dict, dict):
                raise ValueError(
                    "JSON string must decode to an object, got "
                    f"{type(data_dict).__name__}"
                )
        else:
            data_dict = data
        return cls(**data_dict)

    def as_dict(self) -> dict[str, Any]:
        """Return the explicitly set fields as a JSON-friendly dict.

        Equivalent to ``to_dict``.
        """
        data = self.model_dump(exclude_unset=True)
        return self._serialize_values(data)

    def to_dict(self) -> dict[str, Any]:
        """Return the explicitly set fields as a JSON-friendly dict."""
        data = self.model_dump(exclude_unset=True)
        return self._serialize_values(data)

    def to_json(self) -> str:
        """Serialize the explicitly set fields to a JSON string.

        ``bytes``/``bytearray`` values are decoded as UTF-8 with undecodable
        bytes dropped, matching the ``bytes`` entry in ``Config.json_encoders``
        (which the stdlib ``json`` module used here does not consult).

        Raises:
            TypeError: If a value is still not JSON serializable.
        """

        def _encode_fallback(value: Any) -> Any:
            # json.dumps calls this only for objects it cannot encode natively.
            if isinstance(value, (bytes, bytearray)):
                return value.decode("utf-8", errors="ignore")
            raise TypeError(
                f"Object of type {type(value).__name__} is not JSON serializable"
            )

        return json.dumps(self.to_dict(), default=_encode_fallback)

    @classmethod
    def from_json(cls: Type[T], json_str: str) -> T:
        """Parse and validate ``json_str`` into an instance via pydantic."""
        return cls.model_validate_json(json_str)

    @staticmethod
    def _serialize_values(data: Any) -> Any:
        """Recursively convert values into JSON-friendly primitives.

        Dicts, lists and tuples are traversed (tuples stay tuples, namedtuples
        keep their type). ``UUID`` becomes ``str``, ``Enum`` becomes its
        ``value``, ``datetime`` becomes an ISO-8601 string. Anything else is
        returned unchanged.
        """
        if isinstance(data, dict):
            return {
                k: R2RSerializable._serialize_values(v)
                for k, v in data.items()
            }
        elif isinstance(data, list):
            return [R2RSerializable._serialize_values(v) for v in data]
        elif isinstance(data, UUID):
            return str(data)
        elif isinstance(data, Enum):
            return data.value
        elif isinstance(data, datetime):
            return data.isoformat()
        elif isinstance(data, tuple):
            # model_dump may hand back tuples for tuple-typed fields; without
            # recursing, UUIDs/datetimes inside them would break to_json.
            items = [R2RSerializable._serialize_values(v) for v in data]
            if hasattr(data, "_fields"):  # namedtuple: keep its concrete type
                return type(data)(*items)
            return tuple(items)
        else:
            return data

    class Config:
        # NOTE: json_encoders is not used by to_dict/to_json above, which go
        # through _serialize_values and the stdlib json module instead.
        arbitrary_types_allowed = True
        json_encoders = {
            UUID: str,
            bytes: lambda v: v.decode("utf-8", errors="ignore"),
        }
class AsyncSyncMeta(type):
    """Metaclass that adds blocking wrappers for ``@syncable`` coroutine methods.

    For every coroutine function in the class body whose ``_syncable``
    attribute is truthy, a synchronous method is attached to the new class
    under the same name minus its first character (``aingest`` -> ``ingest``).
    The wrapper runs the coroutine on an event loop shared at metaclass level.
    """

    _event_loop = None  # Class-level shared event loop

    @classmethod
    def get_event_loop(cls):
        """Return the shared event loop, creating a fresh one when needed.

        A new loop is created when none exists yet or the previous one was
        closed; it is also installed as the calling thread's event loop.
        """
        if cls._event_loop is None or cls._event_loop.is_closed():
            cls._event_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(cls._event_loop)
        return cls._event_loop

    def __new__(cls, name, bases, dct):
        new_cls = super().__new__(cls, name, bases, dct)

        def make_sync_method(async_method, sync_method_name):
            # Factory so each wrapper closes over its own ``async_method``
            # instead of the loop variable below (late-binding pitfall).

            def sync_wrapper(self, *args, **kwargs):
                loop = cls.get_event_loop()
                if not loop.is_running():
                    # Drive the loop from a worker thread and block until it
                    # finishes, so the caller gets a plain synchronous call.
                    from threading import Thread

                    result = None
                    exception = None

                    def run():
                        nonlocal result, exception
                        try:
                            asyncio.set_event_loop(loop)
                            result = loop.run_until_complete(
                                async_method(self, *args, **kwargs)
                            )
                        except BaseException as e:
                            # BaseException, not Exception: CancelledError and
                            # KeyboardInterrupt must reach the caller too,
                            # otherwise the wrapper would silently return None.
                            exception = e
                        finally:
                            generation_config = kwargs.get(
                                "rag_generation_config", None
                            )
                            # The loop is kept open only when a streaming
                            # rag_generation_config is passed as a keyword;
                            # presumably the returned stream still needs the
                            # loop to be consumed — TODO confirm.
                            if (
                                not generation_config
                                or not generation_config.stream
                            ):
                                loop.run_until_complete(
                                    loop.shutdown_asyncgens()
                                )
                                loop.close()

                    thread = Thread(target=run)
                    thread.start()
                    thread.join()
                    if exception is not None:
                        raise exception
                    return result

                # The loop is already running. If it is running in *this*
                # thread (a sync wrapper called from a coroutine on the shared
                # loop), blocking on the future below could never complete,
                # so fail fast instead of deadlocking.
                try:
                    current_loop = asyncio.get_running_loop()
                except RuntimeError:
                    current_loop = None
                if current_loop is loop:
                    raise RuntimeError(
                        f"{sync_method_name}() cannot be called from a "
                        "coroutine running on the shared event loop; await "
                        f"{async_method.__name__}() instead"
                    )
                future = asyncio.run_coroutine_threadsafe(
                    async_method(self, *args, **kwargs), loop
                )
                return future.result()

            sync_wrapper.__name__ = sync_method_name
            sync_wrapper.__qualname__ = f"{name}.{sync_method_name}"
            sync_wrapper.__doc__ = async_method.__doc__
            return sync_wrapper

        for attr_name, attr_value in dct.items():
            if asyncio.iscoroutinefunction(attr_value) and getattr(
                attr_value, "_syncable", False
            ):
                # Remove leading 'a' for the sync method name.
                # NOTE(review): the first character is dropped unconditionally,
                # so a syncable method not prefixed with "a" gets a truncated
                # name — confirm every syncable method follows the convention.
                sync_method_name = attr_name[1:]
                setattr(
                    new_cls,
                    sync_method_name,
                    make_sync_method(attr_value, sync_method_name),
                )
        return new_cls
def syncable(func):
    """Tag ``func`` so ``AsyncSyncMeta`` builds a blocking counterpart for it.

    Only coroutine functions carrying this tag get a wrapper. The same
    function object is returned, with ``_syncable`` set to ``True``.
    """
    setattr(func, "_syncable", True)
    return func