diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py new file mode 100644 index 00000000..08c57de8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py @@ -0,0 +1,81 @@ +import contextvars +import functools +import logging +from concurrent.futures import ThreadPoolExecutor +from io import StringIO +from typing import Any, Coroutine + +from hatchet_sdk import logger +from hatchet_sdk.clients.events import EventClient + +wr: contextvars.ContextVar[str | None] = contextvars.ContextVar( + "workflow_run_id", default=None +) +sr: contextvars.ContextVar[str | None] = contextvars.ContextVar( + "step_run_id", default=None +) + + +def copy_context_vars(ctx_vars, func, *args, **kwargs): + for var, value in ctx_vars: + var.set(value) + return func(*args, **kwargs) + + +class InjectingFilter(logging.Filter): + # For some reason, only the InjectingFilter has access to the contextvars method sr.get(), + # otherwise we would use emit within the CustomLogHandler + def filter(self, record): + record.workflow_run_id = wr.get() + record.step_run_id = sr.get() + return True + + +class CustomLogHandler(logging.StreamHandler): + def __init__(self, event_client: EventClient, stream=None): + super().__init__(stream) + self.logger_thread_pool = ThreadPoolExecutor(max_workers=1) + self.event_client = event_client + + def _log(self, line: str, step_run_id: str | None): + try: + if not step_run_id: + return + + self.event_client.log(message=line, step_run_id=step_run_id) + except Exception as e: + logger.error(f"Error logging: {e}") + + def emit(self, record): + super().emit(record) + + log_entry = self.format(record) + self.logger_thread_pool.submit(self._log, log_entry, record.step_run_id) + + +def capture_logs( + logger: logging.Logger, + event_client: EventClient, + func: Coroutine[Any, Any, Any], +): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + if not logger: + raise Exception("No logger configured on client") + + log_stream = StringIO() + custom_handler = CustomLogHandler(event_client, log_stream) + custom_handler.setLevel(logging.INFO) + custom_handler.addFilter(InjectingFilter()) + logger.addHandler(custom_handler) + + try: + result = await func(*args, **kwargs) + finally: + custom_handler.flush() + logger.removeHandler(custom_handler) + log_stream.close() + + return result + + return wrapper |