about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/core/providers/orchestration
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/orchestration
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/orchestration')
-rw-r--r--.venv/lib/python3.12/site-packages/core/providers/orchestration/__init__.py4
-rw-r--r--.venv/lib/python3.12/site-packages/core/providers/orchestration/hatchet.py105
-rw-r--r--.venv/lib/python3.12/site-packages/core/providers/orchestration/simple.py61
3 files changed, 170 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/orchestration/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/orchestration/__init__.py
new file mode 100644
index 00000000..b41d79b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/orchestration/__init__.py
@@ -0,0 +1,4 @@
+from .hatchet import HatchetOrchestrationProvider
+from .simple import SimpleOrchestrationProvider
+
+__all__ = ["HatchetOrchestrationProvider", "SimpleOrchestrationProvider"]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/orchestration/hatchet.py b/.venv/lib/python3.12/site-packages/core/providers/orchestration/hatchet.py
new file mode 100644
index 00000000..941e2048
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/orchestration/hatchet.py
@@ -0,0 +1,105 @@
+# FIXME: Once the Hatchet workflows are type annotated, remove the type: ignore comments
+import asyncio
+import logging
+from typing import Any, Callable, Optional
+
+from core.base import OrchestrationConfig, OrchestrationProvider, Workflow
+
+logger = logging.getLogger()
+
+
+class HatchetOrchestrationProvider(OrchestrationProvider):
+    def __init__(self, config: OrchestrationConfig):
+        super().__init__(config)
+        try:
+            from hatchet_sdk import ClientConfig, Hatchet
+        except ImportError:
+            raise ImportError(
+                "Hatchet SDK not installed. Please install it using `pip install hatchet-sdk`."
+            ) from None
+        root_logger = logging.getLogger()
+
+        self.orchestrator = Hatchet(
+            config=ClientConfig(
+                logger=root_logger,
+            ),
+        )
+        self.root_logger = root_logger
+        self.config: OrchestrationConfig = config
+        self.messages: dict[str, str] = {}
+
+    def workflow(self, *args, **kwargs) -> Callable:
+        return self.orchestrator.workflow(*args, **kwargs)
+
+    def step(self, *args, **kwargs) -> Callable:
+        return self.orchestrator.step(*args, **kwargs)
+
+    def failure(self, *args, **kwargs) -> Callable:
+        return self.orchestrator.on_failure_step(*args, **kwargs)
+
+    def get_worker(self, name: str, max_runs: Optional[int] = None) -> Any:
+        if not max_runs:
+            max_runs = self.config.max_runs
+        self.worker = self.orchestrator.worker(name, max_runs)  # type: ignore
+        return self.worker
+
+    def concurrency(self, *args, **kwargs) -> Callable:
+        return self.orchestrator.concurrency(*args, **kwargs)
+
+    async def start_worker(self):
+        if not self.worker:
+            raise ValueError(
+                "Worker not initialized. Call get_worker() first."
+            )
+
+        asyncio.create_task(self.worker.async_start())
+
+    async def run_workflow(
+        self,
+        workflow_name: str,
+        parameters: dict,
+        options: dict,
+        *args,
+        **kwargs,
+    ) -> Any:
+        task_id = self.orchestrator.admin.run_workflow(  # type: ignore
+            workflow_name,
+            parameters,
+            options=options,  # type: ignore
+            *args,
+            **kwargs,
+        )
+        return {
+            "task_id": str(task_id),
+            "message": self.messages.get(
+                workflow_name, "Workflow queued successfully."
+            ),  # Return message based on workflow name
+        }
+
+    def register_workflows(
+        self, workflow: Workflow, service: Any, messages: dict
+    ) -> None:
+        self.messages.update(messages)
+
+        logger.info(
+            f"Registering workflows for {workflow} with messages {messages}."
+        )
+        if workflow == Workflow.INGESTION:
+            from core.main.orchestration.hatchet.ingestion_workflow import (  # type: ignore
+                hatchet_ingestion_factory,
+            )
+
+            workflows = hatchet_ingestion_factory(self, service)
+            if self.worker:
+                for workflow in workflows.values():
+                    self.worker.register_workflow(workflow)
+
+        elif workflow == Workflow.GRAPH:
+            from core.main.orchestration.hatchet.graph_workflow import (  # type: ignore
+                hatchet_graph_search_results_factory,
+            )
+
+            workflows = hatchet_graph_search_results_factory(self, service)
+            if self.worker:
+                for workflow in workflows.values():
+                    self.worker.register_workflow(workflow)
diff --git a/.venv/lib/python3.12/site-packages/core/providers/orchestration/simple.py b/.venv/lib/python3.12/site-packages/core/providers/orchestration/simple.py
new file mode 100644
index 00000000..33028afe
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/orchestration/simple.py
@@ -0,0 +1,61 @@
+from typing import Any
+
+from core.base import OrchestrationConfig, OrchestrationProvider, Workflow
+
+
+class SimpleOrchestrationProvider(OrchestrationProvider):
+    def __init__(self, config: OrchestrationConfig):
+        super().__init__(config)
+        self.config = config
+        self.messages: dict[str, str] = {}
+
+    async def start_worker(self):
+        pass
+
+    def get_worker(self, name: str, max_runs: int) -> Any:
+        pass
+
+    def step(self, *args, **kwargs) -> Any:
+        pass
+
+    def workflow(self, *args, **kwargs) -> Any:
+        pass
+
+    def failure(self, *args, **kwargs) -> Any:
+        pass
+
+    def register_workflows(
+        self, workflow: Workflow, service: Any, messages: dict
+    ) -> None:
+        for key, msg in messages.items():
+            self.messages[key] = msg
+
+        if workflow == Workflow.INGESTION:
+            from core.main.orchestration import simple_ingestion_factory
+
+            self.ingestion_workflows = simple_ingestion_factory(service)
+
+        elif workflow == Workflow.GRAPH:
+            from core.main.orchestration.simple.graph_workflow import (
+                simple_graph_search_results_factory,
+            )
+
+            self.graph_search_results_workflows = (
+                simple_graph_search_results_factory(service)
+            )
+
+    async def run_workflow(
+        self, workflow_name: str, parameters: dict, options: dict
+    ) -> dict[str, str]:
+        if workflow_name in self.ingestion_workflows:
+            await self.ingestion_workflows[workflow_name](
+                parameters.get("request")
+            )
+            return {"message": self.messages[workflow_name]}
+        elif workflow_name in self.graph_search_results_workflows:
+            await self.graph_search_results_workflows[workflow_name](
+                parameters.get("request")
+            )
+            return {"message": self.messages[workflow_name]}
+        else:
+            raise ValueError(f"Workflow '{workflow_name}' not found.")