aboutsummaryrefslogtreecommitdiff
import asyncio
import logging
from asyncio import Queue
from typing import Any, Optional

from ..base.abstractions.search import (
    AggregateSearchResult,
    KGSearchSettings,
    VectorSearchSettings,
)
from ..base.logging.kv_logger import KVLoggingSingleton
from ..base.logging.run_manager import RunManager, manage_run
from ..base.pipeline.base_pipeline import AsyncPipeline, dequeue_requests
from ..base.pipes.base_pipe import AsyncPipe, AsyncState

logger = logging.getLogger(__name__)


class SearchPipeline(AsyncPipeline):
    """A pipeline for search.

    Pipes are registered via :meth:`add_pipe` into one of two lazily created
    ``AsyncPipeline`` sub-pipelines (vector search or KG search). :meth:`run`
    forwards every input message to each enabled sub-pipeline concurrently
    and combines their outputs into an ``AggregateSearchResult``.
    """

    pipeline_type: str = "search"

    def __init__(
        self,
        pipe_logger: Optional[KVLoggingSingleton] = None,
        run_manager: Optional[RunManager] = None,
    ):
        super().__init__(pipe_logger, run_manager)
        # NOTE(review): not referenced by this class; presumably used by
        # subclasses or callers — confirm before removing.
        self._parsing_pipe = None
        # Sub-pipelines are created on demand by add_pipe().
        self._vector_search_pipeline = None
        self._kg_search_pipeline = None

    async def run(
        self,
        input: Any,
        state: Optional[AsyncState] = None,
        stream: bool = False,
        run_manager: Optional[RunManager] = None,
        log_run_info: bool = True,
        vector_search_settings: Optional[VectorSearchSettings] = None,
        kg_search_settings: Optional[KGSearchSettings] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Run the enabled search sub-pipelines over ``input`` concurrently.

        Args:
            input: Async iterable of requests; each message is forwarded to
                every enabled sub-pipeline.
            state: State for this run. A fresh ``AsyncState`` is created when
                omitted; the same object is passed to both sub-pipelines.
            stream: Forwarded to each sub-pipeline's ``run``.
            run_manager: Run manager for this run; falls back to
                ``self.run_manager`` when omitted.
            log_run_info: Whether to log the pipeline type for this run.
            vector_search_settings: Settings for the vector sub-pipeline;
                a fresh ``VectorSearchSettings()`` is used when omitted.
            kg_search_settings: Settings for the KG sub-pipeline; a fresh
                ``KGSearchSettings()`` is used when omitted.

        Returns:
            An ``AggregateSearchResult``. A field is ``None`` when the
            corresponding sub-pipeline is not registered or is disabled by
            its settings.
        """
        # Build defaults per call: a default written in the signature is
        # evaluated once and would be shared by every call.
        if vector_search_settings is None:
            vector_search_settings = VectorSearchSettings()
        if kg_search_settings is None:
            kg_search_settings = KGSearchSettings()

        state = state or AsyncState()
        self.state = state
        run_manager = run_manager or self.run_manager

        do_vector_search = (
            self._vector_search_pipeline is not None
            and vector_search_settings.use_vector_search
        )
        do_kg = (
            self._kg_search_pipeline is not None
            and kg_search_settings.use_kg_search
        )

        async with manage_run(run_manager, self.pipeline_type):
            if log_run_info:
                await run_manager.log_run_info(
                    key="pipeline_type",
                    value=self.pipeline_type,
                    is_info_log=True,
                )

            vector_search_queue = Queue()
            kg_queue = Queue()

            async def enqueue_requests():
                """Fan each input message out to the enabled queues."""
                try:
                    async for message in input:
                        if do_vector_search:
                            await vector_search_queue.put(message)
                        if do_kg:
                            await kg_queue.put(message)
                finally:
                    # Always emit the end-of-stream sentinel, even if reading
                    # ``input`` fails, so consumers never block forever. The
                    # queues are unbounded, so put_nowait cannot raise.
                    vector_search_queue.put_nowait(None)
                    kg_queue.put_nowait(None)

            # Start the request enqueuing process.
            enqueue_task = asyncio.create_task(enqueue_requests())

            # Start the vector search and KG pipelines in parallel.
            vector_search_task = None
            if do_vector_search:
                vector_search_task = asyncio.create_task(
                    self._vector_search_pipeline.run(
                        dequeue_requests(vector_search_queue),
                        state,
                        stream,
                        run_manager,
                        log_run_info=False,
                        vector_search_settings=vector_search_settings,
                    )
                )

            kg_task = None
            if do_kg:
                kg_task = asyncio.create_task(
                    self._kg_search_pipeline.run(
                        dequeue_requests(kg_queue),
                        state,
                        stream,
                        run_manager,
                        log_run_info=False,
                        kg_search_settings=kg_search_settings,
                    )
                )

            # Await inside the managed-run scope so the run is not closed
            # while its sub-pipelines are still working.
            await enqueue_task

            vector_search_results = (
                await vector_search_task
                if vector_search_task is not None
                else None
            )
            kg_results = await kg_task if kg_task is not None else None

        return AggregateSearchResult(
            vector_search_results=vector_search_results,
            kg_search_results=kg_results,
        )

    def add_pipe(
        self,
        pipe: AsyncPipe,
        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
        kg_pipe: bool = False,
        vector_search_pipe: bool = False,
        *args,
        **kwargs,
    ) -> None:
        """Register ``pipe`` with the KG or vector search sub-pipeline.

        The target sub-pipeline is created on first use. If both flags are
        set, ``kg_pipe`` takes precedence.

        Raises:
            ValueError: If neither ``kg_pipe`` nor ``vector_search_pipe`` is
                set.
        """
        logger.debug(
            "Adding pipe %s to the SearchPipeline", pipe.config.name
        )

        if kg_pipe:
            if self._kg_search_pipeline is None:
                self._kg_search_pipeline = AsyncPipeline()
            self._kg_search_pipeline.add_pipe(
                pipe, add_upstream_outputs, *args, **kwargs
            )
        elif vector_search_pipe:
            if self._vector_search_pipeline is None:
                self._vector_search_pipeline = AsyncPipeline()
            self._vector_search_pipeline.add_pipe(
                pipe, add_upstream_outputs, *args, **kwargs
            )
        else:
            raise ValueError("Pipe must be a vector search or KG pipe")