Diffstat (limited to '.venv/lib/python3.12/site-packages/sdk/sync_methods/retrieval.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/sdk/sync_methods/retrieval.py  554
1 file changed, 554 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/sdk/sync_methods/retrieval.py b/.venv/lib/python3.12/site-packages/sdk/sync_methods/retrieval.py
new file mode 100644
index 00000000..4a927014
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/sdk/sync_methods/retrieval.py
@@ -0,0 +1,554 @@
+import json
+import uuid
+from typing import Any, Generator, Optional
+
+from shared.api.models import (
+    WrappedAgentResponse,
+    WrappedEmbeddingResponse,
+    WrappedLLMChatCompletion,
+    WrappedRAGResponse,
+    WrappedSearchResponse,
+)
+
+from ..models import (
+    AgentEvent,
+    CitationData,
+    CitationEvent,
+    Delta,
+    DeltaPayload,
+    FinalAnswerData,
+    FinalAnswerEvent,
+    GenerationConfig,
+    Message,
+    MessageData,
+    MessageDelta,
+    MessageEvent,
+    SearchMode,
+    SearchResultsData,
+    SearchResultsEvent,
+    SearchSettings,
+    ThinkingData,
+    ThinkingEvent,
+    ToolCallData,
+    ToolCallEvent,
+    ToolResultData,
+    ToolResultEvent,
+    UnknownEvent,
+)
+
+
+def parse_retrieval_event(raw: dict) -> Optional[AgentEvent]:
+    """
+    Convert a raw SSE event dict into a typed Pydantic model.
+
+    Example raw dict:
+    {
+        "event": "message",
+        "data": "{\"id\": \"msg_partial\", \"object\": \"agent.message.delta\", \"delta\": {...}}"
+    }
+    """
+    event_type = raw.get("event", "unknown")
+
+    # A "done" event returns None to signal that the SSE stream is finished.
+    if event_type == "done":
+        return None
+
+    # The SSE "data" field is JSON-encoded, so parse it.
+    data_str = raw.get("data", "")
+    try:
+        data_obj = json.loads(data_str)
+    except json.JSONDecodeError as e:
+        # Alternatively, this could return UnknownEvent instead of raising.
+        raise ValueError(f"Could not parse JSON in SSE event data: {e}") from e
+
+    # Branch on event_type to build the right Pydantic model.
+    if event_type == "search_results":
+        return SearchResultsEvent(
+            event=event_type,
+            data=SearchResultsData(**data_obj),
+        )
+    elif event_type == "message":
+        # Parse the nested delta structure manually before creating MessageData.
+        if "delta" in data_obj and isinstance(data_obj["delta"], dict):
+            delta_dict = data_obj["delta"]
+
+            # Convert content items to MessageDelta objects.
+            if "content" in delta_dict and isinstance(
+                delta_dict["content"], list
+            ):
+                parsed_content = []
+                for item in delta_dict["content"]:
+                    if isinstance(item, dict):
+                        # Parse payload to DeltaPayload.
+                        if "payload" in item and isinstance(
+                            item["payload"], dict
+                        ):
+                            payload_dict = item["payload"]
+                            item["payload"] = DeltaPayload(**payload_dict)
+                        parsed_content.append(MessageDelta(**item))
+
+                # Replace with parsed content.
+                delta_dict["content"] = parsed_content
+
+            # Create a properly typed Delta object.
+            data_obj["delta"] = Delta(**delta_dict)
+
+        return MessageEvent(
+            event=event_type,
+            data=MessageData(**data_obj),
+        )
+    elif event_type == "citation":
+        return CitationEvent(event=event_type, data=CitationData(**data_obj))
+    elif event_type == "tool_call":
+        return ToolCallEvent(event=event_type, data=ToolCallData(**data_obj))
+    elif event_type == "tool_result":
+        return ToolResultEvent(
+            event=event_type, data=ToolResultData(**data_obj)
+        )
+    elif event_type == "thinking":
+        # Parse the nested delta structure manually before creating ThinkingData.
+        if "delta" in data_obj and isinstance(data_obj["delta"], dict):
+            delta_dict = data_obj["delta"]
+
+            # Convert content items to MessageDelta objects.
+            if "content" in delta_dict and isinstance(
+                delta_dict["content"], list
+            ):
+                parsed_content = []
+                for item in delta_dict["content"]:
+                    if isinstance(item, dict):
+                        # Parse payload to DeltaPayload.
+                        if "payload" in item and isinstance(
+                            item["payload"], dict
+                        ):
+                            payload_dict = item["payload"]
+                            item["payload"] = DeltaPayload(**payload_dict)
+                        parsed_content.append(MessageDelta(**item))
+
+                # Replace with parsed content.
+                delta_dict["content"] = parsed_content
+
+            # Create a properly typed Delta object.
+            data_obj["delta"] = Delta(**delta_dict)
+
+        return ThinkingEvent(
+            event=event_type,
+            data=ThinkingData(**data_obj),
+        )
+    elif event_type == "final_answer":
+        return FinalAnswerEvent(
+            event=event_type, data=FinalAnswerData(**data_obj)
+        )
+    else:
+        # Fallback if the event doesn't match any known type.
+        return UnknownEvent(
+            event=event_type,
+            data=data_obj,
+        )
+
+
+def search_arg_parser(
+    query: str,
+    search_mode: Optional[str | SearchMode] = "custom",
+    search_settings: Optional[dict | SearchSettings] = None,
+) -> dict:
+    if search_mode and not isinstance(search_mode, str):
+        search_mode = search_mode.value
+
+    if search_settings and not isinstance(search_settings, dict):
+        search_settings = search_settings.model_dump()
+
+    data: dict[str, Any] = {
+        "query": query,
+        "search_settings": search_settings,
+    }
+    if search_mode:
+        data["search_mode"] = search_mode
+
+    return data
+
+
+def completion_arg_parser(
+    messages: list[dict | Message],
+    generation_config: Optional[dict | GenerationConfig] = None,
+) -> dict:
+    # FIXME: Needs a proper return type
+    cast_messages: list[Message] = [
+        Message(**msg) if isinstance(msg, dict) else msg for msg in messages
+    ]
+
+    if generation_config and not isinstance(generation_config, dict):
+        generation_config = generation_config.model_dump()
+
+    data: dict[str, Any] = {
+        "messages": [msg.model_dump() for msg in cast_messages],
+        "generation_config": generation_config,
+    }
+    return data
+
+
+def embedding_arg_parser(
+    text: str,
+) -> dict:
+    data: dict[str, Any] = {
+        "text": text,
+    }
+    return data
+
+
+def rag_arg_parser(
+    query: str,
+    rag_generation_config: Optional[dict | GenerationConfig] = None,
+    search_mode: Optional[str | SearchMode] = "custom",
+    search_settings: Optional[dict | SearchSettings] = None,
+    task_prompt: Optional[str] = None,
+    include_title_if_available: Optional[bool] = False,
+    include_web_search: Optional[bool] = False,
+) -> dict:
+    if rag_generation_config and not isinstance(rag_generation_config, dict):
+        rag_generation_config = rag_generation_config.model_dump()
+    if search_settings and not isinstance(search_settings, dict):
+        search_settings = search_settings.model_dump()
+
+    data: dict[str, Any] = {
+        "query": query,
+        "rag_generation_config": rag_generation_config,
+        "search_settings": search_settings,
+        "task_prompt": task_prompt,
+        "include_title_if_available": include_title_if_available,
+        "include_web_search": include_web_search,
+    }
+    if search_mode:
+        data["search_mode"] = search_mode
+    return data
+
+
+def agent_arg_parser(
+    message: Optional[dict | Message] = None,
+    rag_generation_config: Optional[dict | GenerationConfig] = None,
+    research_generation_config: Optional[dict | GenerationConfig] = None,
+    search_mode: Optional[str | SearchMode] = "custom",
+    search_settings: Optional[dict | SearchSettings] = None,
+    task_prompt: Optional[str] = None,
+    include_title_if_available: Optional[bool] = True,
+    conversation_id: Optional[str | uuid.UUID] = None,
+    max_tool_context_length: Optional[int] = None,
+    use_system_context: Optional[bool] = True,
+    rag_tools: Optional[list[str]] = None,
+    research_tools: Optional[list[str]] = None,
+    tools: Optional[list[str]] = None,  # For backward compatibility
+    mode: Optional[str] = "rag",
+    needs_initial_conversation_name: Optional[bool] = None,
+) -> dict:
+    if rag_generation_config and not isinstance(rag_generation_config, dict):
+        rag_generation_config = rag_generation_config.model_dump()
+    if research_generation_config and not isinstance(
+        research_generation_config, dict
+    ):
+        research_generation_config = research_generation_config.model_dump()
+    if search_settings and not isinstance(search_settings, dict):
+        search_settings = search_settings.model_dump()
+
+    data: dict[str, Any] = {
+        "rag_generation_config": rag_generation_config or {},
+        "search_settings": search_settings,
+        "task_prompt": task_prompt,
+        "include_title_if_available": include_title_if_available,
+        "conversation_id": (str(conversation_id) if conversation_id else None),
+        "max_tool_context_length": max_tool_context_length,
+        "use_system_context": use_system_context,
+        "mode": mode,
+    }
+
+    # Handle generation configs based on mode
+    if research_generation_config and mode == "research":
+        data["research_generation_config"] = research_generation_config
+
+    # Handle tool configurations
+    if rag_tools:
+        data["rag_tools"] = rag_tools
+    if research_tools:
+        data["research_tools"] = research_tools
+    if tools:  # Backward compatibility
+        data["tools"] = tools
+
+    if search_mode:
+        data["search_mode"] = search_mode
+
+    if needs_initial_conversation_name:
+        data["needs_initial_conversation_name"] = (
+            needs_initial_conversation_name
+        )
+
+    if message:
+        cast_message: Message = (
+            Message(**message) if isinstance(message, dict) else message
+        )
+        data["message"] = cast_message.model_dump()
+    return data
+
+
+class RetrievalSDK:
+    """SDK for retrieval operations in the v3 API."""
+
+    def __init__(self, client):
+        self.client = client
+
+    def search(
+        self,
+        query: str,
+        search_mode: Optional[str | SearchMode] = "custom",
+        search_settings: Optional[dict | SearchSettings] = None,
+    ) -> WrappedSearchResponse:
+        """Conduct a vector and/or graph search.
+
+        Args:
+            query (str): The query to search for.
+            search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
+            search_settings (Optional[dict | SearchSettings]): Vector search settings.
+
+        Returns:
+            WrappedSearchResponse
+        """
+        response_dict = self.client._make_request(
+            "POST",
+            "retrieval/search",
+            json=search_arg_parser(
+                query=query,
+                search_mode=search_mode,
+                search_settings=search_settings,
+            ),
+            version="v3",
+        )
+
+        return WrappedSearchResponse(**response_dict)
+
+    def completion(
+        self,
+        messages: list[dict | Message],
+        generation_config: Optional[dict | GenerationConfig] = None,
+    ) -> WrappedLLMChatCompletion:
+        # Message casting and config dumping are handled by the arg parser.
+        response_dict = self.client._make_request(
+            "POST",
+            "retrieval/completion",
+            json=completion_arg_parser(messages, generation_config),
+            version="v3",
+        )
+
+        return WrappedLLMChatCompletion(**response_dict)
+
+    def embedding(
+        self,
+        text: str,
+    ) -> WrappedEmbeddingResponse:
+        response_dict = self.client._make_request(
+            "POST",
+            "retrieval/embedding",
+            data=embedding_arg_parser(text),
+            version="v3",
+        )
+
+        return WrappedEmbeddingResponse(**response_dict)
+
+    def rag(
+        self,
+        query: str,
+        rag_generation_config: Optional[dict | GenerationConfig] = None,
+        search_mode: Optional[str | SearchMode] = "custom",
+        search_settings: Optional[dict | SearchSettings] = None,
+        task_prompt: Optional[str] = None,
+        include_title_if_available: Optional[bool] = False,
+        include_web_search: Optional[bool] = False,
+    ) -> (
+        WrappedRAGResponse
+        | Generator[
+            ThinkingEvent
+            | SearchResultsEvent
+            | MessageEvent
+            | CitationEvent
+            | FinalAnswerEvent
+            | ToolCallEvent
+            | ToolResultEvent
+            | UnknownEvent
+            | None,
+            None,
+            None,
+        ]
+    ):
+        """Conducts a Retrieval Augmented Generation (RAG) search with the
+        given query.
+
+        Args:
+            query (str): The query to search for.
+            rag_generation_config (Optional[dict | GenerationConfig]): RAG generation configuration.
+            search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
+            search_settings (Optional[dict | SearchSettings]): Vector search settings.
+            task_prompt (Optional[str]): Task prompt override.
+            include_title_if_available (Optional[bool]): Include the title if available.
+            include_web_search (Optional[bool]): Include web search results.
+
+        Returns:
+            WrappedRAGResponse | Generator[AgentEvent | None, None, None]: The RAG response, or a generator of parsed events when streaming.
+        """
+        data = rag_arg_parser(
+            query=query,
+            rag_generation_config=rag_generation_config,
+            search_mode=search_mode,
+            search_settings=search_settings,
+            task_prompt=task_prompt,
+            include_title_if_available=include_title_if_available,
+            include_web_search=include_web_search,
+        )
+        rag_generation_config = data.get("rag_generation_config")
+        if rag_generation_config and rag_generation_config.get(  # type: ignore
+            "stream", False
+        ):
+            raw_stream = self.client._make_streaming_request(
+                "POST",
+                "retrieval/rag",
+                json=data,
+                version="v3",
+            )
+            # Wrap the raw stream to parse each event.
+            return (parse_retrieval_event(event) for event in raw_stream)
+
+        response_dict = self.client._make_request(
+            "POST",
+            "retrieval/rag",
+            json=data,
+            version="v3",
+        )
+
+        return WrappedRAGResponse(**response_dict)
+
+    def agent(
+        self,
+        message: Optional[dict | Message] = None,
+        rag_generation_config: Optional[dict | GenerationConfig] = None,
+        research_generation_config: Optional[dict | GenerationConfig] = None,
+        search_mode: Optional[str | SearchMode] = "custom",
+        search_settings: Optional[dict | SearchSettings] = None,
+        task_prompt: Optional[str] = None,
+        include_title_if_available: Optional[bool] = True,
+        conversation_id: Optional[str | uuid.UUID] = None,
+        max_tool_context_length: Optional[int] = None,
+        use_system_context: Optional[bool] = True,
+        # Tool configurations
+        rag_tools: Optional[list[str]] = None,
+        research_tools: Optional[list[str]] = None,
+        tools: Optional[list[str]] = None,  # For backward compatibility
+        mode: Optional[str] = "rag",
+        needs_initial_conversation_name: Optional[bool] = None,
+    ) -> (
+        WrappedAgentResponse
+        | Generator[
+            ThinkingEvent
+            | SearchResultsEvent
+            | MessageEvent
+            | CitationEvent
+            | FinalAnswerEvent
+            | ToolCallEvent
+            | ToolResultEvent
+            | UnknownEvent
+            | None,
+            None,
+            None,
+        ]
+    ):
+        """Performs a single turn in a conversation with a RAG agent.
+
+        Args:
+            message (Optional[dict | Message]): The message to send to the agent.
+            rag_generation_config (Optional[dict | GenerationConfig]): Configuration for RAG generation in 'rag' mode.
+            research_generation_config (Optional[dict | GenerationConfig]): Configuration for generation in 'research' mode.
+            search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
+            search_settings (Optional[dict | SearchSettings]): Vector search settings.
+            task_prompt (Optional[str]): Task prompt override.
+            include_title_if_available (Optional[bool]): Include the title if available.
+            conversation_id (Optional[str | uuid.UUID]): ID of the conversation for maintaining context.
+            max_tool_context_length (Optional[int]): Maximum context length for tool replies.
+            use_system_context (Optional[bool]): Whether to use system context in the prompt.
+            rag_tools (Optional[list[str]]): List of tools to enable for RAG mode.
+                Available tools: "search_file_knowledge", "content", "web_search", "web_scrape", "search_file_descriptions".
+            research_tools (Optional[list[str]]): List of tools to enable for Research mode.
+                Available tools: "rag", "reasoning", "critique", "python_executor".
+            tools (Optional[list[str]]): Deprecated. List of tools to execute.
+            mode (Optional[str]): Mode to use for generation: "rag" for standard retrieval or "research" for deep analysis.
+                Defaults to "rag".
+
+        Returns:
+            WrappedAgentResponse | Generator[AgentEvent | None, None, None]: The agent response, or a generator of parsed events when streaming.
+        """
+        # search_mode and message are folded into `data` by agent_arg_parser.
+        data = agent_arg_parser(
+            message=message,
+            rag_generation_config=rag_generation_config,
+            research_generation_config=research_generation_config,
+            search_mode=search_mode,
+            search_settings=search_settings,
+            task_prompt=task_prompt,
+            include_title_if_available=include_title_if_available,
+            conversation_id=conversation_id,
+            max_tool_context_length=max_tool_context_length,
+            use_system_context=use_system_context,
+            rag_tools=rag_tools,
+            research_tools=research_tools,
+            tools=tools,
+            mode=mode,
+            needs_initial_conversation_name=needs_initial_conversation_name,
+        )
+
+        # Determine if streaming is enabled.
+        is_stream = False
+        if mode != "research":
+            if rag_generation_config:
+                if isinstance(rag_generation_config, dict):
+                    is_stream = rag_generation_config.get(  # type: ignore
+                        "stream", False
+                    )
+                else:
+                    is_stream = rag_generation_config.stream
+        else:
+            if research_generation_config:
+                if isinstance(research_generation_config, dict):
+                    is_stream = research_generation_config.get(  # type: ignore
+                        "stream", False
+                    )
+                else:
+                    is_stream = research_generation_config.stream
+
+        if is_stream:
+            raw_stream = self.client._make_streaming_request(
+                "POST",
+                "retrieval/agent",
+                json=data,
+                version="v3",
+            )
+            return (parse_retrieval_event(event) for event in raw_stream)
+
+        response_dict = self.client._make_request(
+            "POST",
+            "retrieval/agent",
+            json=data,
+            version="v3",
+        )
+
+        return WrappedAgentResponse(**response_dict)