Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/orchestration')
3 files changed, 170 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/orchestration/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/orchestration/__init__.py
new file mode 100644
index 00000000..b41d79b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/orchestration/__init__.py
@@ -0,0 +1,4 @@
+from .hatchet import HatchetOrchestrationProvider
+from .simple import SimpleOrchestrationProvider
+
+__all__ = ["HatchetOrchestrationProvider", "SimpleOrchestrationProvider"]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/orchestration/hatchet.py b/.venv/lib/python3.12/site-packages/core/providers/orchestration/hatchet.py
new file mode 100644
index 00000000..941e2048
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/orchestration/hatchet.py
@@ -0,0 +1,105 @@
+# FIXME: Once the Hatchet workflows are type annotated, remove the type: ignore comments
+import asyncio
+import logging
+from typing import Any, Callable, Optional
+
+from core.base import OrchestrationConfig, OrchestrationProvider, Workflow
+
+logger = logging.getLogger()
+
+
+class HatchetOrchestrationProvider(OrchestrationProvider):
+    def __init__(self, config: OrchestrationConfig):
+        super().__init__(config)
+        try:
+            from hatchet_sdk import ClientConfig, Hatchet
+        except ImportError:
+            raise ImportError(
+                "Hatchet SDK not installed. Please install it using `pip install hatchet-sdk`."
+            ) from None
+        root_logger = logging.getLogger()
+
+        self.orchestrator = Hatchet(
+            config=ClientConfig(
+                logger=root_logger,
+            ),
+        )
+        self.root_logger = root_logger
+        self.config: OrchestrationConfig = config
+        self.messages: dict[str, str] = {}
+
+    def workflow(self, *args, **kwargs) -> Callable:
+        return self.orchestrator.workflow(*args, **kwargs)
+
+    def step(self, *args, **kwargs) -> Callable:
+        return self.orchestrator.step(*args, **kwargs)
+
+    def failure(self, *args, **kwargs) -> Callable:
+        return self.orchestrator.on_failure_step(*args, **kwargs)
+
+    def get_worker(self, name: str, max_runs: Optional[int] = None) -> Any:
+        if not max_runs:
+            max_runs = self.config.max_runs
+        self.worker = self.orchestrator.worker(name, max_runs)  # type: ignore
+        return self.worker
+
+    def concurrency(self, *args, **kwargs) -> Callable:
+        return self.orchestrator.concurrency(*args, **kwargs)
+
+    async def start_worker(self):
+        if not self.worker:
+            raise ValueError(
+                "Worker not initialized. Call get_worker() first."
+            )
+
+        asyncio.create_task(self.worker.async_start())
+
+    async def run_workflow(
+        self,
+        workflow_name: str,
+        parameters: dict,
+        options: dict,
+        *args,
+        **kwargs,
+    ) -> Any:
+        task_id = self.orchestrator.admin.run_workflow(  # type: ignore
+            workflow_name,
+            parameters,
+            options=options,  # type: ignore
+            *args,
+            **kwargs,
+        )
+        return {
+            "task_id": str(task_id),
+            "message": self.messages.get(
+                workflow_name, "Workflow queued successfully."
+            ),  # Return message based on workflow name
+        }
+
+    def register_workflows(
+        self, workflow: Workflow, service: Any, messages: dict
+    ) -> None:
+        self.messages.update(messages)
+
+        logger.info(
+            f"Registering workflows for {workflow} with messages {messages}."
+        )
+        if workflow == Workflow.INGESTION:
+            from core.main.orchestration.hatchet.ingestion_workflow import (  # type: ignore
+                hatchet_ingestion_factory,
+            )
+
+            workflows = hatchet_ingestion_factory(self, service)
+            if self.worker:
+                for workflow in workflows.values():
+                    self.worker.register_workflow(workflow)
+
+        elif workflow == Workflow.GRAPH:
+            from core.main.orchestration.hatchet.graph_workflow import (  # type: ignore
+                hatchet_graph_search_results_factory,
+            )
+
+            workflows = hatchet_graph_search_results_factory(self, service)
+            if self.worker:
+                for workflow in workflows.values():
+                    self.worker.register_workflow(workflow)
diff --git a/.venv/lib/python3.12/site-packages/core/providers/orchestration/simple.py b/.venv/lib/python3.12/site-packages/core/providers/orchestration/simple.py
new file mode 100644
index 00000000..33028afe
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/orchestration/simple.py
@@ -0,0 +1,61 @@
+from typing import Any
+
+from core.base import OrchestrationConfig, OrchestrationProvider, Workflow
+
+
+class SimpleOrchestrationProvider(OrchestrationProvider):
+    def __init__(self, config: OrchestrationConfig):
+        super().__init__(config)
+        self.config = config
+        self.messages: dict[str, str] = {}
+
+    async def start_worker(self):
+        pass
+
+    def get_worker(self, name: str, max_runs: int) -> Any:
+        pass
+
+    def step(self, *args, **kwargs) -> Any:
+        pass
+
+    def workflow(self, *args, **kwargs) -> Any:
+        pass
+
+    def failure(self, *args, **kwargs) -> Any:
+        pass
+
+    def register_workflows(
+        self, workflow: Workflow, service: Any, messages: dict
+    ) -> None:
+        for key, msg in messages.items():
+            self.messages[key] = msg
+
+        if workflow == Workflow.INGESTION:
+            from core.main.orchestration import simple_ingestion_factory
+
+            self.ingestion_workflows = simple_ingestion_factory(service)
+
+        elif workflow == Workflow.GRAPH:
+            from core.main.orchestration.simple.graph_workflow import (
+                simple_graph_search_results_factory,
+            )
+
+            self.graph_search_results_workflows = (
+                simple_graph_search_results_factory(service)
+            )
+
+    async def run_workflow(
+        self, workflow_name: str, parameters: dict, options: dict
+    ) -> dict[str, str]:
+        if workflow_name in self.ingestion_workflows:
+            await self.ingestion_workflows[workflow_name](
+                parameters.get("request")
+            )
+            return {"message": self.messages[workflow_name]}
+        elif workflow_name in self.graph_search_results_workflows:
+            await self.graph_search_results_workflows[workflow_name](
+                parameters.get("request")
+            )
+            return {"message": self.messages[workflow_name]}
+        else:
+            raise ValueError(f"Workflow '{workflow_name}' not found.")
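Note: the following is a minimal, self-contained sketch (not part of the diff above) that mirrors the register-then-dispatch pattern used by SimpleOrchestrationProvider.run_workflow: workflow coroutines are stored under string names alongside a confirmation message, then looked up and awaited with the "request" payload. The ToyOrchestrator class, the "ingest-files" name, and the payload are invented here purely for illustration and are not defined by the files changed in this diff.

import asyncio
from typing import Any, Awaitable, Callable

# A registered workflow is an async callable that receives the request payload.
WorkflowFn = Callable[[dict[str, Any]], Awaitable[None]]


class ToyOrchestrator:
    def __init__(self) -> None:
        # Name -> workflow coroutine, and name -> message returned to the caller.
        self.workflows: dict[str, WorkflowFn] = {}
        self.messages: dict[str, str] = {}

    def register(self, name: str, fn: WorkflowFn, message: str) -> None:
        self.workflows[name] = fn
        self.messages[name] = message

    async def run_workflow(self, name: str, parameters: dict) -> dict[str, str]:
        # Same dispatch shape as SimpleOrchestrationProvider.run_workflow:
        # unknown names raise, known names are awaited with parameters["request"].
        if name not in self.workflows:
            raise ValueError(f"Workflow '{name}' not found.")
        await self.workflows[name](parameters.get("request", {}))
        return {"message": self.messages[name]}


async def ingest(request: dict[str, Any]) -> None:
    print(f"ingesting {request}")


async def main() -> None:
    orchestrator = ToyOrchestrator()
    orchestrator.register("ingest-files", ingest, "Ingestion task queued successfully.")
    result = await orchestrator.run_workflow("ingest-files", {"request": {"id": 1}})
    print(result["message"])


asyncio.run(main())

In the real providers the same pattern is filled in by the workflow factories (simple_ingestion_factory, simple_graph_search_results_factory), while HatchetOrchestrationProvider delegates registration and execution to the Hatchet SDK instead of an in-process dictionary.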