aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/shared/abstractions/base.py
blob: d90ba400518c703118308283f170cfa65a124bc7 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import asyncio
import json
from datetime import datetime
from enum import Enum
from typing import Any, Type, TypeVar
from uuid import UUID

from pydantic import BaseModel

# Type variable bound to R2RSerializable; used to annotate ``cls`` and the
# return type of the alternate constructors (from_dict / from_json).
T = TypeVar("T", bound="R2RSerializable")


class R2RSerializable(BaseModel):
    """Pydantic base model with dict/JSON conversion helpers.

    ``to_dict``/``as_dict`` dump only the fields that were explicitly set
    (``exclude_unset=True``) and convert UUID, Enum and datetime values into
    plain strings/values so the result can be passed to ``json.dumps``.
    """

    @classmethod
    def from_dict(cls: Type[T], data: dict[str, Any] | str) -> T:
        """Build an instance from a mapping or from a JSON object string.

        Args:
            data: Field values keyed by field name, or a JSON string that
                decodes to such an object.

        Returns:
            A new instance of ``cls`` constructed as ``cls(**fields)``.

        Raises:
            ValueError: If ``data`` is a string that is not valid JSON, or
                that decodes to something other than a JSON object.
        """
        if isinstance(data, str):
            try:
                data_dict = json.loads(data)
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON string: {e}") from e
            # Valid JSON may still be a list/number/etc.; unpacking that with
            # ** below would fail with an unrelated-looking TypeError.
            if not isinstance(data_dict, dict):
                raise ValueError(
                    "Invalid JSON string: expected a JSON object, got "
                    f"{type(data_dict).__name__}"
                )
        else:
            data_dict = data
        return cls(**data_dict)

    def as_dict(self) -> dict[str, Any]:
        """Return the explicitly-set fields as a dict of serialized values.

        Performs the same dump and conversion as ``to_dict``.
        """
        data = self.model_dump(exclude_unset=True)
        return self._serialize_values(data)

    def to_dict(self) -> dict[str, Any]:
        """Return the explicitly-set fields as a dict of serialized values.

        Fields that were never set are omitted (``exclude_unset=True``);
        values are converted by ``_serialize_values``.
        """
        data = self.model_dump(exclude_unset=True)
        return self._serialize_values(data)

    def to_json(self) -> str:
        """Return ``to_dict()`` encoded as a JSON string via ``json.dumps``."""
        data = self.to_dict()
        return json.dumps(data)

    @classmethod
    def from_json(cls: Type[T], json_str: str) -> T:
        """Build an instance from a JSON string.

        Parsing and validation are delegated to pydantic's
        ``model_validate_json``.
        """
        return cls.model_validate_json(json_str)

    @staticmethod
    def _serialize_values(data: Any) -> Any:
        """Recursively convert values that ``json.dumps`` cannot encode.

        Dicts, lists and tuples are walked recursively (the container type is
        preserved). UUIDs become strings, Enums their ``.value`` and
        datetimes ISO-8601 strings. Any other value is returned unchanged.
        """
        if isinstance(data, dict):
            return {
                k: R2RSerializable._serialize_values(v)
                for k, v in data.items()
            }
        if isinstance(data, list):
            return [R2RSerializable._serialize_values(v) for v in data]
        if isinstance(data, tuple):
            # Tuples can hold UUID/Enum/datetime members too; without this
            # branch they would reach json.dumps unconverted.
            return tuple(R2RSerializable._serialize_values(v) for v in data)
        if isinstance(data, UUID):
            return str(data)
        if isinstance(data, Enum):
            return data.value
        if isinstance(data, datetime):
            return data.isoformat()
        return data

    # NOTE(review): the methods above use the pydantic v2 API (model_dump,
    # model_validate_json); class-based Config / json_encoders is presumably
    # the legacy configuration style there -- confirm before migrating.
    class Config:
        arbitrary_types_allowed = True
        json_encoders = {
            UUID: str,
            # Undecodable bytes are dropped rather than raising.
            bytes: lambda v: v.decode("utf-8", errors="ignore"),
        }


class AsyncSyncMeta(type):
    """Metaclass that generates blocking twins of marked coroutine methods.

    For every coroutine function in the class body whose ``_syncable``
    attribute is truthy, a synchronous method is attached to the new class
    under the coroutine's name with its first character removed
    (e.g. ``aquery`` -> ``query``).
    """

    _event_loop = None  # Class-level shared event loop

    @classmethod
    def get_event_loop(cls):
        """Return the shared event loop, creating one if absent or closed.

        A newly created loop is also installed as the current event loop of
        the calling thread.
        """
        if cls._event_loop is None or cls._event_loop.is_closed():
            cls._event_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(cls._event_loop)
        return cls._event_loop

    def __new__(cls, name, bases, dct):
        """Create the class and attach a sync wrapper per syncable coroutine."""
        new_cls = super().__new__(cls, name, bases, dct)
        for attr_name, attr_value in dct.items():
            if asyncio.iscoroutinefunction(attr_value) and getattr(
                attr_value, "_syncable", False
            ):
                # The first character is dropped unconditionally; by
                # convention it is the leading 'a' of the async method name.
                sync_method_name = attr_name[1:]

                # Factory function so each wrapper closes over its own
                # coroutine rather than the loop variable.
                def make_sync_method(async_method, sync_name):
                    def sync_wrapper(self, *args, **kwargs):
                        loop = cls.get_event_loop()
                        if loop.is_running():
                            # The shared loop is already running: submit the
                            # coroutine to it and block on the result. This
                            # must happen from a thread other than the one
                            # running the loop.
                            future = asyncio.run_coroutine_threadsafe(
                                async_method(self, *args, **kwargs), loop
                            )
                            return future.result()

                        # Otherwise drive the loop to completion in a helper
                        # thread and wait for it here.
                        from threading import Thread

                        result = None
                        exception = None

                        def run():
                            nonlocal result, exception
                            try:
                                asyncio.set_event_loop(loop)
                                result = loop.run_until_complete(
                                    async_method(self, *args, **kwargs)
                                )
                            except BaseException as e:
                                # BaseException, not Exception: e.g.
                                # asyncio.CancelledError would otherwise die
                                # with the thread and the caller would get a
                                # silent None.
                                exception = e
                            finally:
                                generation_config = kwargs.get(
                                    "rag_generation_config", None
                                )
                                # Keep the loop open when a streaming
                                # generation config was passed; otherwise
                                # finalize async generators and close it.
                                if (
                                    not generation_config
                                    or not generation_config.stream
                                ):
                                    loop.run_until_complete(
                                        loop.shutdown_asyncgens()
                                    )
                                    loop.close()

                        thread = Thread(target=run)
                        thread.start()
                        thread.join()
                        # Re-raise in the caller's thread whatever the
                        # coroutine raised.
                        if exception is not None:
                            raise exception
                        return result

                    # Make the generated method identifiable in tracebacks
                    # and help() output.
                    sync_wrapper.__name__ = sync_name
                    sync_wrapper.__doc__ = async_method.__doc__
                    return sync_wrapper

                setattr(
                    new_cls,
                    sync_method_name,
                    make_sync_method(attr_value, sync_method_name),
                )
        return new_cls


def syncable(func):
    """Flag ``func`` so that a synchronous wrapper is generated for it.

    Sets ``func._syncable`` to ``True`` and hands back the very same
    function object, so it can be used as a plain decorator.
    """
    setattr(func, "_syncable", True)
    return func