1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
|
import contextvars
import functools
import logging
from concurrent.futures import ThreadPoolExecutor
from io import StringIO
from typing import Any, Callable, Coroutine

from hatchet_sdk import logger
from hatchet_sdk.clients.events import EventClient
wr: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"workflow_run_id", default=None
)
sr: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"step_run_id", default=None
)
def copy_context_vars(ctx_vars, func, *args, **kwargs):
    """Apply context-variable values, then invoke ``func``.

    Args:
        ctx_vars: Iterable of ``(variable, value)`` pairs. ``variable.set(value)``
            is called for each pair, in order, before ``func`` runs.
        func: Callable to invoke once every variable has been set.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func(*args, **kwargs)`` returns.
    """
    for context_var, captured_value in ctx_vars:
        context_var.set(captured_value)

    outcome = func(*args, **kwargs)
    return outcome
class InjectingFilter(logging.Filter):
    """Logging filter that stamps each record with the current run ids.

    The ids are read from the module-level ``wr`` and ``sr`` context variables
    and stored on the record as ``workflow_run_id`` and ``step_run_id``.
    """

    # NOTE: the original author observed that sr.get() was only usable from
    # inside this filter; otherwise the lookup would have lived in
    # CustomLogHandler.emit instead.
    def filter(self, record):
        """Attach the run ids to ``record``; always lets the record through."""
        step_run_id = sr.get()
        workflow_run_id = wr.get()

        record.step_run_id = step_run_id
        record.workflow_run_id = workflow_run_id
        return True
class CustomLogHandler(logging.StreamHandler):
    """Stream handler that also forwards each record to an event client.

    Every record is written to the underlying stream as usual and then handed
    to a single-worker thread pool, which calls ``event_client.log`` so the
    logging call itself does not wait on the client.
    """

    def __init__(self, event_client: EventClient, stream=None):
        """Create the handler.

        Args:
            event_client: Client whose ``log(message=..., step_run_id=...)``
                method receives the formatted records.
            stream: Stream passed through to ``logging.StreamHandler``.
        """
        super().__init__(stream)
        # One worker: queued lines are forwarded one at a time, in order.
        self.logger_thread_pool = ThreadPoolExecutor(max_workers=1)
        self.event_client = event_client

    def _log(self, line: str, step_run_id: str | None):
        """Send ``line`` to the event client; runs on the worker thread.

        Lines without a step run id are dropped. Client failures are reported
        through the SDK logger instead of being raised.
        """
        if not step_run_id:
            return

        try:
            self.event_client.log(message=line, step_run_id=step_run_id)
        except Exception as e:
            logger.error(f"Error logging: {e}")

    def emit(self, record):
        """Write ``record`` to the stream and queue it for the event client."""
        super().emit(record)

        try:
            log_entry = self.format(record)
            # The id is set on the record by a filter. A record that never
            # passed through one is still written to the stream, but it is
            # not forwarded (``_log`` drops lines without an id).
            step_run_id = getattr(record, "step_run_id", None)
            self.logger_thread_pool.submit(self._log, log_entry, step_run_id)
        except Exception:
            # A handler must not break the caller's logging call, e.g. when
            # submit() is refused because the pool was already shut down.
            self.handleError(record)

    def close(self):
        """Stop accepting new lines and release the worker thread.

        Lines that are already queued are still delivered; this call does not
        block waiting for them.
        """
        self.logger_thread_pool.shutdown(wait=False)
        super().close()
def capture_logs(
    logger: logging.Logger,
    event_client: EventClient,
    func: Callable[..., Coroutine[Any, Any, Any]],
):
    """Wrap an async callable so its log output is forwarded to ``event_client``.

    For the duration of each call a temporary ``CustomLogHandler`` (level
    INFO, with an ``InjectingFilter`` attached) is added to ``logger``. It is
    detached and closed again when the call finishes, whether the call
    returns or raises.

    Args:
        logger: Logger to capture records from.
        event_client: Client handed to the temporary handler.
        func: Coroutine function to wrap; it is called and then awaited.

    Returns:
        An async wrapper (decorated with ``functools.wraps(func)``) that
        returns the awaited result of ``func``.

    Raises:
        Exception: Raised by the wrapper at call time when ``logger`` is falsy.
    """

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        if not logger:
            raise Exception("No logger configured on client")

        log_stream = StringIO()
        custom_handler = CustomLogHandler(event_client, log_stream)
        custom_handler.setLevel(logging.INFO)
        custom_handler.addFilter(InjectingFilter())
        logger.addHandler(custom_handler)

        try:
            return await func(*args, **kwargs)
        finally:
            # Detach first so no further records reach a handler that is
            # being torn down.
            logger.removeHandler(custom_handler)
            custom_handler.flush()
            # close() is the logging API's hook for a handler to release
            # whatever resources it holds.
            custom_handler.close()
            log_stream.close()

    return wrapper
|