aboutsummaryrefslogtreecommitdiff
import asyncio
import functools
import json
import threading
from datetime import datetime
from enum import Enum
from typing import Any, Type, TypeVar
from uuid import UUID

from pydantic import BaseModel

# Bound TypeVar so the ``from_dict``/``from_json`` classmethods are typed as
# returning the calling subclass rather than ``R2RSerializable`` itself.
T = TypeVar("T", bound="R2RSerializable")


class R2RSerializable(BaseModel):
    """Pydantic base model with dict/JSON conversion helpers.

    ``to_dict``/``as_dict`` dump only explicitly set fields and recursively
    convert UUIDs, enums and datetimes into JSON-friendly primitives.
    """

    @classmethod
    def from_dict(cls: Type[T], data: dict[str, Any] | str) -> T:
        """Construct an instance from a mapping or a JSON object string.

        Args:
            data: Field values as a dict, or a JSON string encoding an object.

        Returns:
            A new instance built as ``cls(**data)``.

        Raises:
            ValueError: If ``data`` is a string that is not valid JSON or
                does not decode to a JSON object.
        """
        if isinstance(data, str):
            try:
                data_dict = json.loads(data)
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON string: {e}") from e
            # ``cls(**...)`` needs a mapping; a JSON array or scalar would
            # otherwise surface as an opaque TypeError.
            if not isinstance(data_dict, dict):
                raise ValueError(
                    "JSON string must decode to an object, got "
                    f"{type(data_dict).__name__}"
                )
        else:
            data_dict = data
        return cls(**data_dict)

    def as_dict(self) -> dict[str, Any]:
        """Return explicitly set fields as JSON-friendly values (same as ``to_dict``)."""
        data = self.model_dump(exclude_unset=True)
        return self._serialize_values(data)

    def to_dict(self) -> dict[str, Any]:
        """Return explicitly set fields with UUIDs, enums and datetimes converted."""
        data = self.model_dump(exclude_unset=True)
        return self._serialize_values(data)

    def to_json(self) -> str:
        """Serialize ``to_dict()`` to a JSON string.

        ``bytes`` values are decoded as UTF-8 with invalid sequences dropped,
        mirroring the ``bytes`` entry in ``Config.json_encoders``. Any other
        value ``json`` cannot encode still raises ``TypeError``.
        """
        return json.dumps(self.to_dict(), default=self._json_default)

    @classmethod
    def from_json(cls: Type[T], json_str: str) -> T:
        """Parse and validate ``json_str`` with ``model_validate_json``."""
        return cls.model_validate_json(json_str)

    @staticmethod
    def _json_default(value: Any) -> Any:
        """Fallback used by ``json.dumps`` for values it cannot encode natively."""
        if isinstance(value, (bytes, bytearray)):
            return bytes(value).decode("utf-8", errors="ignore")
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable"
        )

    @staticmethod
    def _serialize_values(data: Any) -> Any:
        """Recursively convert values into JSON-friendly primitives.

        Dicts, lists and tuples are rebuilt with converted items (dict keys
        are left as-is). UUIDs become strings, enums their ``value``, and
        datetimes ISO-8601 strings. Anything else is returned unchanged.
        """
        if isinstance(data, dict):
            return {
                k: R2RSerializable._serialize_values(v)
                for k, v in data.items()
            }
        if isinstance(data, list):
            return [R2RSerializable._serialize_values(v) for v in data]
        if isinstance(data, UUID):
            return str(data)
        if isinstance(data, Enum):
            return data.value
        if isinstance(data, datetime):
            return data.isoformat()
        if isinstance(data, tuple):
            # Python-mode ``model_dump`` can yield tuples; convert their items
            # too so nested UUIDs/enums/datetimes match the list handling.
            return tuple(R2RSerializable._serialize_values(v) for v in data)
        return data

    class Config:
        arbitrary_types_allowed = True
        json_encoders = {
            UUID: str,
            bytes: lambda v: v.decode("utf-8", errors="ignore"),
        }


class AsyncSyncMeta(type):
    """Metaclass that attaches blocking twins to ``@syncable`` coroutines.

    For every coroutine function in the class body that carries a truthy
    ``_syncable`` attribute and whose name starts with ``"a"`` (for example
    ``asearch``), a synchronous method named without that prefix
    (``search``) is set on the new class. Calling the sync method drives the
    coroutine to completion on a shared, class-level event loop.
    """

    _event_loop = None  # Class-level shared event loop

    @classmethod
    def get_event_loop(cls):
        """Return the shared event loop, creating one if it is unset or closed.

        A newly created loop is also installed as the calling thread's
        current event loop via ``asyncio.set_event_loop``.
        """
        if cls._event_loop is None or cls._event_loop.is_closed():
            cls._event_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(cls._event_loop)
        return cls._event_loop

    def __new__(cls, name, bases, dct):
        new_cls = super().__new__(cls, name, bases, dct)
        for attr_name, attr_value in dct.items():
            if not (
                asyncio.iscoroutinefunction(attr_value)
                and getattr(attr_value, "_syncable", False)
            ):
                continue
            # Only names with a leading "a" have a meaningful sync twin;
            # blindly stripping the first character of any other name
            # (e.g. "search" -> "earch") would create a bogus attribute.
            if len(attr_name) < 2 or not attr_name.startswith("a"):
                continue
            sync_method_name = attr_name[1:]
            setattr(
                new_cls,
                sync_method_name,
                cls._make_sync_method(attr_value, sync_method_name),
            )
        return new_cls

    @classmethod
    def _make_sync_method(cls, async_method, sync_name):
        """Return a blocking method that runs ``async_method`` to completion."""

        # ``updated=()`` keeps the ``_syncable`` marker off the sync wrapper.
        @functools.wraps(async_method, updated=())
        def sync_wrapper(self, *args, **kwargs):
            loop = cls.get_event_loop()
            if loop.is_running():
                # The shared loop is already being driven (e.g. by a sync call
                # in progress on another thread): submit the coroutine to it
                # and block on the result.
                # NOTE(review): calling this from a coroutine running on that
                # same loop would deadlock on ``future.result()``.
                future = asyncio.run_coroutine_threadsafe(
                    async_method(self, *args, **kwargs), loop
                )
                return future.result()
            return cls._run_in_thread(
                loop, lambda: async_method(self, *args, **kwargs), kwargs
            )

        # ``wraps`` copied the async method's name; expose the sync name.
        sync_wrapper.__name__ = sync_name
        owner, _, _ = async_method.__qualname__.rpartition(".")
        sync_wrapper.__qualname__ = f"{owner}.{sync_name}" if owner else sync_name
        return sync_wrapper

    @staticmethod
    def _run_in_thread(loop, make_coro, kwargs):
        """Run ``make_coro()`` on ``loop`` in a helper thread and return its result.

        Any exception raised while running the coroutine, including
        ``BaseException`` subclasses such as ``asyncio.CancelledError``, is
        re-raised in the calling thread. Unless ``kwargs`` carries a
        ``rag_generation_config`` whose ``stream`` attribute is truthy, the
        loop's async generators are shut down and the loop is closed
        afterwards.
        """
        result = None
        error = None

        def run():
            nonlocal result, error
            try:
                asyncio.set_event_loop(loop)
                result = loop.run_until_complete(make_coro())
            except BaseException as exc:
                # Capture everything: an exception escaping this thread would
                # only be printed, and the caller would silently get ``None``.
                error = exc
            finally:
                generation_config = kwargs.get("rag_generation_config", None)
                streaming = bool(generation_config) and getattr(
                    generation_config, "stream", False
                )
                # When streaming, the loop is left open (presumably the
                # streamed result still needs it -- confirm with callers).
                if not streaming:
                    loop.run_until_complete(loop.shutdown_asyncgens())
                    loop.close()

        worker = threading.Thread(target=run)
        worker.start()
        worker.join()
        if error is not None:
            raise error
        return result


def syncable(func):
    """Flag ``func`` so that ``AsyncSyncMeta`` attaches a blocking twin for it.

    Sets ``func._syncable = True`` and hands back the very same function
    object, so it works as a plain decorator.
    """
    setattr(func, "_syncable", True)
    return func