about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/sdk/sync_methods/retrieval.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/sdk/sync_methods/retrieval.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/sdk/sync_methods/retrieval.py')
-rw-r--r--.venv/lib/python3.12/site-packages/sdk/sync_methods/retrieval.py554
1 files changed, 554 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/sdk/sync_methods/retrieval.py b/.venv/lib/python3.12/site-packages/sdk/sync_methods/retrieval.py
new file mode 100644
index 00000000..4a927014
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/sdk/sync_methods/retrieval.py
@@ -0,0 +1,554 @@
+import json
+import uuid
+from typing import Any, Generator, Optional
+
+from shared.api.models import (
+    WrappedAgentResponse,
+    WrappedEmbeddingResponse,
+    WrappedLLMChatCompletion,
+    WrappedRAGResponse,
+    WrappedSearchResponse,
+)
+
+from ..models import (
+    AgentEvent,
+    CitationData,
+    CitationEvent,
+    Delta,
+    DeltaPayload,
+    FinalAnswerData,
+    FinalAnswerEvent,
+    GenerationConfig,
+    Message,
+    MessageData,
+    MessageDelta,
+    MessageEvent,
+    SearchMode,
+    SearchResultsData,
+    SearchResultsEvent,
+    SearchSettings,
+    ThinkingData,
+    ThinkingEvent,
+    ToolCallData,
+    ToolCallEvent,
+    ToolResultData,
+    ToolResultEvent,
+    UnknownEvent,
+)
+
+
+def parse_retrieval_event(raw: dict) -> Optional[AgentEvent]:
+    """
+    Convert a raw SSE event dict into a typed Pydantic model.
+
+    Example raw dict:
+        {
+          "event": "message",
+          "data": "{\"id\": \"msg_partial\", \"object\": \"agent.message.delta\", \"delta\": {...}}"
+        }
+    """
+    event_type = raw.get("event", "unknown")
+
+    # If event_type == "done", we usually return None to signal the SSE stream is finished.
+    if event_type == "done":
+        return None
+
+    # The SSE "data" is JSON-encoded, so parse it
+    data_str = raw.get("data", "")
+    try:
+        data_obj = json.loads(data_str)
+    except json.JSONDecodeError as e:
+        # You can decide whether to raise or return UnknownEvent
+        raise ValueError(f"Could not parse JSON in SSE event data: {e}") from e
+
+    # Now branch on event_type to build the right Pydantic model
+    if event_type == "search_results":
+        return SearchResultsEvent(
+            event=event_type,
+            data=SearchResultsData(**data_obj),
+        )
+    elif event_type == "message":
+        # Parse nested delta structure manually before creating MessageData
+        if "delta" in data_obj and isinstance(data_obj["delta"], dict):
+            delta_dict = data_obj["delta"]
+
+            # Convert content items to MessageDelta objects
+            if "content" in delta_dict and isinstance(
+                delta_dict["content"], list
+            ):
+                parsed_content = []
+                for item in delta_dict["content"]:
+                    if isinstance(item, dict):
+                        # Parse payload to DeltaPayload
+                        if "payload" in item and isinstance(
+                            item["payload"], dict
+                        ):
+                            payload_dict = item["payload"]
+                            item["payload"] = DeltaPayload(**payload_dict)
+                        parsed_content.append(MessageDelta(**item))
+
+                # Replace with parsed content
+                delta_dict["content"] = parsed_content
+
+            # Create properly typed Delta object
+            data_obj["delta"] = Delta(**delta_dict)
+
+        return MessageEvent(
+            event=event_type,
+            data=MessageData(**data_obj),
+        )
+    elif event_type == "citation":
+        return CitationEvent(event=event_type, data=CitationData(**data_obj))
+    elif event_type == "tool_call":
+        return ToolCallEvent(event=event_type, data=ToolCallData(**data_obj))
+    elif event_type == "tool_result":
+        return ToolResultEvent(
+            event=event_type, data=ToolResultData(**data_obj)
+        )
+    elif event_type == "thinking":
+        # Parse nested delta structure manually before creating ThinkingData
+        if "delta" in data_obj and isinstance(data_obj["delta"], dict):
+            delta_dict = data_obj["delta"]
+
+            # Convert content items to MessageDelta objects
+            if "content" in delta_dict and isinstance(
+                delta_dict["content"], list
+            ):
+                parsed_content = []
+                for item in delta_dict["content"]:
+                    if isinstance(item, dict):
+                        # Parse payload to DeltaPayload
+                        if "payload" in item and isinstance(
+                            item["payload"], dict
+                        ):
+                            payload_dict = item["payload"]
+                            item["payload"] = DeltaPayload(**payload_dict)
+                        parsed_content.append(MessageDelta(**item))
+
+                # Replace with parsed content
+                delta_dict["content"] = parsed_content
+
+            # Create properly typed Delta object
+            data_obj["delta"] = Delta(**delta_dict)
+
+        return ThinkingEvent(
+            event=event_type,
+            data=ThinkingData(**data_obj),
+        )
+    elif event_type == "final_answer":
+        return FinalAnswerEvent(
+            event=event_type, data=FinalAnswerData(**data_obj)
+        )
+    else:
+        # Fallback if it doesn't match any known event
+        return UnknownEvent(
+            event=event_type,
+            data=data_obj,
+        )
+
+
+def search_arg_parser(
+    query: str,
+    search_mode: Optional[str | SearchMode] = "custom",
+    search_settings: Optional[dict | SearchSettings] = None,
+) -> dict:
+    if search_mode and not isinstance(search_mode, str):
+        search_mode = search_mode.value
+
+    if search_settings and not isinstance(search_settings, dict):
+        search_settings = search_settings.model_dump()
+
+    data: dict[str, Any] = {
+        "query": query,
+        "search_settings": search_settings,
+    }
+    if search_mode:
+        data["search_mode"] = search_mode
+
+    return data
+
+
+def completion_arg_parser(
+    messages: list[dict | Message],
+    generation_config: Optional[dict | GenerationConfig] = None,
+) -> dict:
+    # FIXME: Needs a proper return type
+    cast_messages: list[Message] = [
+        Message(**msg) if isinstance(msg, dict) else msg for msg in messages
+    ]
+
+    if generation_config and not isinstance(generation_config, dict):
+        generation_config = generation_config.model_dump()
+
+    data: dict[str, Any] = {
+        "messages": [msg.model_dump() for msg in cast_messages],
+        "generation_config": generation_config,
+    }
+    return data
+
+
+def embedding_arg_parser(
+    text: str,
+) -> dict:
+    data: dict[str, Any] = {
+        "text": text,
+    }
+    return data
+
+
+def rag_arg_parser(
+    query: str,
+    rag_generation_config: Optional[dict | GenerationConfig] = None,
+    search_mode: Optional[str | SearchMode] = "custom",
+    search_settings: Optional[dict | SearchSettings] = None,
+    task_prompt: Optional[str] = None,
+    include_title_if_available: Optional[bool] = False,
+    include_web_search: Optional[bool] = False,
+) -> dict:
+    if rag_generation_config and not isinstance(rag_generation_config, dict):
+        rag_generation_config = rag_generation_config.model_dump()
+    if search_settings and not isinstance(search_settings, dict):
+        search_settings = search_settings.model_dump()
+
+    data: dict[str, Any] = {
+        "query": query,
+        "rag_generation_config": rag_generation_config,
+        "search_settings": search_settings,
+        "task_prompt": task_prompt,
+        "include_title_if_available": include_title_if_available,
+        "include_web_search": include_web_search,
+    }
+    if search_mode:
+        data["search_mode"] = search_mode
+    return data
+
+
+def agent_arg_parser(
+    message: Optional[dict | Message] = None,
+    rag_generation_config: Optional[dict | GenerationConfig] = None,
+    research_generation_config: Optional[dict | GenerationConfig] = None,
+    search_mode: Optional[str | SearchMode] = "custom",
+    search_settings: Optional[dict | SearchSettings] = None,
+    task_prompt: Optional[str] = None,
+    include_title_if_available: Optional[bool] = True,
+    conversation_id: Optional[str | uuid.UUID] = None,
+    max_tool_context_length: Optional[int] = None,
+    use_system_context: Optional[bool] = True,
+    rag_tools: Optional[list[str]] = None,
+    research_tools: Optional[list[str]] = None,
+    tools: Optional[list[str]] = None,  # For backward compatibility
+    mode: Optional[str] = "rag",
+    needs_initial_conversation_name: Optional[bool] = None,
+) -> dict:
+    if rag_generation_config and not isinstance(rag_generation_config, dict):
+        rag_generation_config = rag_generation_config.model_dump()
+    if research_generation_config and not isinstance(
+        research_generation_config, dict
+    ):
+        research_generation_config = research_generation_config.model_dump()
+    if search_settings and not isinstance(search_settings, dict):
+        search_settings = search_settings.model_dump()
+
+    data: dict[str, Any] = {
+        "rag_generation_config": rag_generation_config or {},
+        "search_settings": search_settings,
+        "task_prompt": task_prompt,
+        "include_title_if_available": include_title_if_available,
+        "conversation_id": (str(conversation_id) if conversation_id else None),
+        "max_tool_context_length": max_tool_context_length,
+        "use_system_context": use_system_context,
+        "mode": mode,
+    }
+
+    # Handle generation configs based on mode
+    if research_generation_config and mode == "research":
+        data["research_generation_config"] = research_generation_config
+
+    # Handle tool configurations
+    if rag_tools:
+        data["rag_tools"] = rag_tools
+    if research_tools:
+        data["research_tools"] = research_tools
+    if tools:  # Backward compatibility
+        data["tools"] = tools
+
+    if search_mode:
+        data["search_mode"] = search_mode
+
+    if needs_initial_conversation_name:
+        data["needs_initial_conversation_name"] = (
+            needs_initial_conversation_name
+        )
+
+    if message:
+        cast_message: Message = (
+            Message(**message) if isinstance(message, dict) else message
+        )
+        data["message"] = cast_message.model_dump()
+    return data
+
+
+class RetrievalSDK:
+    """SDK for interacting with documents in the v3 API."""
+
+    def __init__(self, client):
+        self.client = client
+
+    def search(
+        self,
+        query: str,
+        search_mode: Optional[str | SearchMode] = "custom",
+        search_settings: Optional[dict | SearchSettings] = None,
+    ) -> WrappedSearchResponse:
+        """Conduct a vector and/or graph search.
+
+        Args:
+            query (str): The query to search for.
+            search_settings (Optional[dict, SearchSettings]]): Vector search settings.
+
+        Returns:
+            WrappedSearchResponse
+        """
+
+        response_dict = self.client._make_request(
+            "POST",
+            "retrieval/search",
+            json=search_arg_parser(
+                query=query,
+                search_mode=search_mode,
+                search_settings=search_settings,
+            ),
+            version="v3",
+        )
+
+        return WrappedSearchResponse(**response_dict)
+
+    def completion(
+        self,
+        messages: list[dict | Message],
+        generation_config: Optional[dict | GenerationConfig] = None,
+    ) -> WrappedLLMChatCompletion:
+        cast_messages: list[Message] = [
+            Message(**msg) if isinstance(msg, dict) else msg
+            for msg in messages
+        ]
+
+        if generation_config and not isinstance(generation_config, dict):
+            generation_config = generation_config.model_dump()
+
+        data: dict[str, Any] = {
+            "messages": [msg.model_dump() for msg in cast_messages],
+            "generation_config": generation_config,
+        }
+        response_dict = self.client._make_request(
+            "POST",
+            "retrieval/completion",
+            json=completion_arg_parser(messages, generation_config),
+            version="v3",
+        )
+
+        return WrappedLLMChatCompletion(**response_dict)
+
+    def embedding(
+        self,
+        text: str,
+    ) -> WrappedEmbeddingResponse:
+        response_dict = self.client._make_request(
+            "POST",
+            "retrieval/embedding",
+            data=embedding_arg_parser(text),
+            version="v3",
+        )
+
+        return WrappedEmbeddingResponse(**response_dict)
+
+    def rag(
+        self,
+        query: str,
+        rag_generation_config: Optional[dict | GenerationConfig] = None,
+        search_mode: Optional[str | SearchMode] = "custom",
+        search_settings: Optional[dict | SearchSettings] = None,
+        task_prompt: Optional[str] = None,
+        include_title_if_available: Optional[bool] = False,
+        include_web_search: Optional[bool] = False,
+    ) -> (
+        WrappedRAGResponse
+        | Generator[
+            ThinkingEvent
+            | SearchResultsEvent
+            | MessageEvent
+            | CitationEvent
+            | FinalAnswerEvent
+            | ToolCallEvent
+            | ToolResultEvent
+            | UnknownEvent
+            | None,
+            None,
+            None,
+        ]
+    ):
+        """Conducts a Retrieval Augmented Generation (RAG) search with the
+        given query.
+
+        Args:
+            query (str): The query to search for.
+            rag_generation_config (Optional[dict | GenerationConfig]): RAG generation configuration.
+            search_settings (Optional[dict | SearchSettings]): Vector search settings.
+            task_prompt (Optional[str]): Task prompt override.
+            include_title_if_available (Optional[bool]): Include the title if available.
+
+        Returns:
+            WrappedRAGResponse | AsyncGenerator[RAGResponse, None]: The RAG response
+        """
+        data = rag_arg_parser(
+            query=query,
+            rag_generation_config=rag_generation_config,
+            search_mode=search_mode,
+            search_settings=search_settings,
+            task_prompt=task_prompt,
+            include_title_if_available=include_title_if_available,
+            include_web_search=include_web_search,
+        )
+        rag_generation_config = data.get("rag_generation_config")
+        if rag_generation_config and rag_generation_config.get(  # type: ignore
+            "stream", False
+        ):
+            raw_stream = self.client._make_streaming_request(
+                "POST",
+                "retrieval/rag",
+                json=data,
+                version="v3",
+            )
+            # Wrap the raw stream to parse each event
+            return (parse_retrieval_event(event) for event in raw_stream)
+
+        response_dict = self.client._make_request(
+            "POST",
+            "retrieval/rag",
+            json=data,
+            version="v3",
+        )
+
+        return WrappedRAGResponse(**response_dict)
+
+    def agent(
+        self,
+        message: Optional[dict | Message] = None,
+        rag_generation_config: Optional[dict | GenerationConfig] = None,
+        research_generation_config: Optional[dict | GenerationConfig] = None,
+        search_mode: Optional[str | SearchMode] = "custom",
+        search_settings: Optional[dict | SearchSettings] = None,
+        task_prompt: Optional[str] = None,
+        include_title_if_available: Optional[bool] = True,
+        conversation_id: Optional[str | uuid.UUID] = None,
+        max_tool_context_length: Optional[int] = None,
+        use_system_context: Optional[bool] = True,
+        # Tool configurations
+        rag_tools: Optional[list[str]] = None,
+        research_tools: Optional[list[str]] = None,
+        tools: Optional[list[str]] = None,  # For backward compatibility
+        mode: Optional[str] = "rag",
+        needs_initial_conversation_name: Optional[bool] = None,
+    ) -> (
+        WrappedAgentResponse
+        | Generator[
+            ThinkingEvent
+            | SearchResultsEvent
+            | MessageEvent
+            | CitationEvent
+            | FinalAnswerEvent
+            | ToolCallEvent
+            | ToolResultEvent
+            | UnknownEvent
+            | None,
+            None,
+            None,
+        ]
+    ):
+        """Performs a single turn in a conversation with a RAG agent.
+
+        Args:
+            message (Optional[dict | Message]): The message to send to the agent.
+            rag_generation_config (Optional[dict | GenerationConfig]): Configuration for RAG generation in 'rag' mode.
+            research_generation_config (Optional[dict | GenerationConfig]): Configuration for generation in 'research' mode.
+            search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
+            search_settings (Optional[dict | SearchSettings]): Vector search settings.
+            task_prompt (Optional[str]): Task prompt override.
+            include_title_if_available (Optional[bool]): Include the title if available.
+            conversation_id (Optional[str | uuid.UUID]): ID of the conversation for maintaining context.
+            max_tool_context_length (Optional[int]): Maximum context length for tool replies.
+            use_system_context (Optional[bool]): Whether to use system context in the prompt.
+            rag_tools (Optional[list[str]]): List of tools to enable for RAG mode.
+                Available tools: "search_file_knowledge", "content", "web_search", "web_scrape", "search_file_descriptions".
+            research_tools (Optional[list[str]]): List of tools to enable for Research mode.
+                Available tools: "rag", "reasoning", "critique", "python_executor".
+            tools (Optional[list[str]]): Deprecated. List of tools to execute.
+            mode (Optional[str]): Mode to use for generation: "rag" for standard retrieval or "research" for deep analysis.
+                Defaults to "rag".
+
+        Returns:
+            WrappedAgentResponse | AsyncGenerator[AgentEvent, None]: The agent response.
+        """
+        data = agent_arg_parser(
+            message=message,
+            rag_generation_config=rag_generation_config,
+            research_generation_config=research_generation_config,
+            search_mode=search_mode,
+            search_settings=search_settings,
+            task_prompt=task_prompt,
+            include_title_if_available=include_title_if_available,
+            conversation_id=conversation_id,
+            max_tool_context_length=max_tool_context_length,
+            use_system_context=use_system_context,
+            rag_tools=rag_tools,
+            research_tools=research_tools,
+            tools=tools,
+            mode=mode,
+            needs_initial_conversation_name=needs_initial_conversation_name,
+        )
+
+        # Determine if streaming is enabled
+        if search_mode:
+            data["search_mode"] = search_mode
+
+        if message:
+            cast_message: Message = (
+                Message(**message) if isinstance(message, dict) else message
+            )
+            data["message"] = cast_message.model_dump()
+
+        is_stream = False
+        if mode != "research":
+            if rag_generation_config:
+                if isinstance(rag_generation_config, dict):
+                    is_stream = rag_generation_config.get(  # type: ignore
+                        "stream", False
+                    )
+                else:
+                    is_stream = rag_generation_config.stream
+        else:
+            if research_generation_config:
+                if isinstance(research_generation_config, dict):
+                    is_stream = research_generation_config.get(  # type: ignore
+                        "stream", False
+                    )
+                else:
+                    is_stream = research_generation_config.stream
+
+        if is_stream:
+            raw_stream = self.client._make_streaming_request(
+                "POST",
+                "retrieval/agent",
+                json=data,
+                version="v3",
+            )
+            return (parse_retrieval_event(event) for event in raw_stream)
+
+        response_dict = self.client._make_request(
+            "POST",
+            "retrieval/agent",
+            json=data,
+            version="v3",
+        )
+
+        return WrappedAgentResponse(**response_dict)