about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils')
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py81
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/error_with_traceback.py6
2 files changed, 87 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py
new file mode 100644
index 00000000..08c57de8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py
@@ -0,0 +1,81 @@
+import contextvars
+import functools
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from io import StringIO
+from typing import Any, Coroutine
+
+from hatchet_sdk import logger
+from hatchet_sdk.clients.events import EventClient
+
+wr: contextvars.ContextVar[str | None] = contextvars.ContextVar(
+    "workflow_run_id", default=None
+)
+sr: contextvars.ContextVar[str | None] = contextvars.ContextVar(
+    "step_run_id", default=None
+)
+
+
+def copy_context_vars(ctx_vars, func, *args, **kwargs):
+    for var, value in ctx_vars:
+        var.set(value)
+    return func(*args, **kwargs)
+
+
+class InjectingFilter(logging.Filter):
+    # For some reason, only the InjectingFilter has access to the contextvars method sr.get(),
+    # otherwise we would use emit within the CustomLogHandler
+    def filter(self, record):
+        record.workflow_run_id = wr.get()
+        record.step_run_id = sr.get()
+        return True
+
+
+class CustomLogHandler(logging.StreamHandler):
+    def __init__(self, event_client: EventClient, stream=None):
+        super().__init__(stream)
+        self.logger_thread_pool = ThreadPoolExecutor(max_workers=1)
+        self.event_client = event_client
+
+    def _log(self, line: str, step_run_id: str | None):
+        try:
+            if not step_run_id:
+                return
+
+            self.event_client.log(message=line, step_run_id=step_run_id)
+        except Exception as e:
+            logger.error(f"Error logging: {e}")
+
+    def emit(self, record):
+        super().emit(record)
+
+        log_entry = self.format(record)
+        self.logger_thread_pool.submit(self._log, log_entry, record.step_run_id)
+
+
+def capture_logs(
+    logger: logging.Logger,
+    event_client: EventClient,
+    func: Coroutine[Any, Any, Any],
+):
+    @functools.wraps(func)
+    async def wrapper(*args, **kwargs):
+        if not logger:
+            raise Exception("No logger configured on client")
+
+        log_stream = StringIO()
+        custom_handler = CustomLogHandler(event_client, log_stream)
+        custom_handler.setLevel(logging.INFO)
+        custom_handler.addFilter(InjectingFilter())
+        logger.addHandler(custom_handler)
+
+        try:
+            result = await func(*args, **kwargs)
+        finally:
+            custom_handler.flush()
+            logger.removeHandler(custom_handler)
+            log_stream.close()
+
+        return result
+
+    return wrapper
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/error_with_traceback.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/error_with_traceback.py
new file mode 100644
index 00000000..9c09602f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/error_with_traceback.py
@@ -0,0 +1,6 @@
+import traceback
+
+
+def errorWithTraceback(message: str, e: Exception):
+    trace = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+    return f"{message}\n{trace}"