aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/hatchet_sdk/utils/aio_utils.py
blob: 459205f1f370262b6c83f8d806d47f6d14154b1f (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import asyncio
import inspect
from concurrent.futures import Executor
from functools import partial, wraps
from threading import Thread
from typing import Any


## TODO: Stricter typing here
def sync_to_async(func: Any) -> Any:
    """
    A decorator to run a synchronous function or coroutine in an asynchronous context with added
    asyncio loop safety.

    This decorator allows you to safely call synchronous functions or coroutines from an
    asynchronous function by running them in an executor.

    Args:
        func (callable): The synchronous function or coroutine to be run asynchronously.

    Returns:
        callable: An asynchronous wrapper function that runs the given function in an executor.

    Example:
        @sync_to_async
        def sync_function(x, y):
            return x + y

        @sync_to_async
        async def async_function(x, y):
            return x + y


        def undecorated_function(x, y):
            return x + y

        async def main():
            result1 = await sync_function(1, 2)
            result2 = await async_function(3, 4)
            result3 = await sync_to_async(undecorated_function)(5, 6)
            print(result1, result2, result3)

        asyncio.run(main())
    """

    ## TODO: Stricter typing here
    @wraps(func)
    async def run(
        *args: Any,
        loop: asyncio.AbstractEventLoop | None = None,
        executor: Executor | None = None,
        **kwargs: Any
    ) -> Any:
        """
        The asynchronous wrapper function that runs the given function in an executor.

        Args:
            *args: Positional arguments to pass to the function.
            loop (asyncio.AbstractEventLoop, optional): The event loop to use. If None, the current running loop is used.
            executor (concurrent.futures.Executor, optional): The executor to use. If None, the default executor is used.
            **kwargs: Keyword arguments to pass to the function.

        Returns:
            The result of the function call.
        """
        if loop is None:
            loop = asyncio.get_running_loop()

        if inspect.iscoroutinefunction(func):
            # Wrap the coroutine to run it in an executor
            async def wrapper() -> Any:
                return await func(*args, **kwargs)

            pfunc = partial(asyncio.run, wrapper())
            return await loop.run_in_executor(executor, pfunc)
        else:
            # Run the synchronous function in an executor
            pfunc = partial(func, *args, **kwargs)
            return await loop.run_in_executor(executor, pfunc)

    return run


class EventLoopThread:
    """
    A context manager that runs an asyncio event loop in a separate thread.

    Entering the context starts the thread, which runs the loop forever, and
    returns the loop. Work can then be submitted to it from other threads,
    e.g. with ``asyncio.run_coroutine_threadsafe``. Exiting the context stops
    the loop, waits for the thread to finish and closes the loop.

    An instance is single-use: its ``Thread`` can only be started once.
    """

    def __init__(self) -> None:
        """
        Initializes the EventLoopThread by creating an event loop
        and setting up a (not yet started) thread to run the loop.
        """
        self.loop = asyncio.new_event_loop()
        self.thread = Thread(target=self.run_loop_in_thread, args=(self.loop,))

    def __enter__(self, *a: Any, **kw: Any) -> asyncio.AbstractEventLoop:
        """
        Starts the thread running the event loop when entering the context.

        Returns:
            asyncio.AbstractEventLoop: The event loop running in the separate thread.
        """
        self.thread.start()
        return self.loop

    def __exit__(self, *a: Any, **kw: Any) -> None:
        """
        Stops the event loop, joins the thread and closes the loop when
        exiting the context. Exceptions raised inside the ``with`` body are
        not suppressed (this method returns None).
        """
        # loop.stop has to run on the loop's own thread, hence call_soon_threadsafe.
        self.loop.call_soon_threadsafe(self.loop.stop)
        self.thread.join()
        # Once the thread is joined, run_forever() has returned and the loop is
        # no longer running, so it is safe to close it. Otherwise its resources
        # are only released by the garbage collector, which also emits an
        # "unclosed event loop" ResourceWarning.
        self.loop.close()

    def run_loop_in_thread(self, loop: asyncio.AbstractEventLoop) -> None:
        """
        Sets the event loop for the current thread and runs it until stopped.

        Args:
            loop (asyncio.AbstractEventLoop): The event loop to run.
        """
        asyncio.set_event_loop(loop)
        loop.run_forever()


def get_active_event_loop() -> asyncio.AbstractEventLoop | None:
    """
    Get the active event loop.

    Returns:
        asyncio.AbstractEventLoop: The active event loop, or None if there is no active
        event loop in the current thread.
    """
    try:
        return asyncio.get_event_loop()
    except RuntimeError as e:
        if str(e).startswith("There is no current event loop in thread"):
            return None
        else:
            raise e