import json import uuid from typing import Any, Generator, Optional from shared.api.models import ( WrappedAgentResponse, WrappedEmbeddingResponse, WrappedLLMChatCompletion, WrappedRAGResponse, WrappedSearchResponse, ) from ..models import ( AgentEvent, CitationData, CitationEvent, Delta, DeltaPayload, FinalAnswerData, FinalAnswerEvent, GenerationConfig, Message, MessageData, MessageDelta, MessageEvent, SearchMode, SearchResultsData, SearchResultsEvent, SearchSettings, ThinkingData, ThinkingEvent, ToolCallData, ToolCallEvent, ToolResultData, ToolResultEvent, UnknownEvent, ) def parse_retrieval_event(raw: dict) -> Optional[AgentEvent]: """ Convert a raw SSE event dict into a typed Pydantic model. Example raw dict: { "event": "message", "data": "{\"id\": \"msg_partial\", \"object\": \"agent.message.delta\", \"delta\": {...}}" } """ event_type = raw.get("event", "unknown") # If event_type == "done", we usually return None to signal the SSE stream is finished. if event_type == "done": return None # The SSE "data" is JSON-encoded, so parse it data_str = raw.get("data", "") try: data_obj = json.loads(data_str) except json.JSONDecodeError as e: # You can decide whether to raise or return UnknownEvent raise ValueError(f"Could not parse JSON in SSE event data: {e}") from e # Now branch on event_type to build the right Pydantic model if event_type == "search_results": return SearchResultsEvent( event=event_type, data=SearchResultsData(**data_obj), ) elif event_type == "message": # Parse nested delta structure manually before creating MessageData if "delta" in data_obj and isinstance(data_obj["delta"], dict): delta_dict = data_obj["delta"] # Convert content items to MessageDelta objects if "content" in delta_dict and isinstance( delta_dict["content"], list ): parsed_content = [] for item in delta_dict["content"]: if isinstance(item, dict): # Parse payload to DeltaPayload if "payload" in item and isinstance( item["payload"], dict ): payload_dict = item["payload"] item["payload"] = DeltaPayload(**payload_dict) parsed_content.append(MessageDelta(**item)) # Replace with parsed content delta_dict["content"] = parsed_content # Create properly typed Delta object data_obj["delta"] = Delta(**delta_dict) return MessageEvent( event=event_type, data=MessageData(**data_obj), ) elif event_type == "citation": return CitationEvent(event=event_type, data=CitationData(**data_obj)) elif event_type == "tool_call": return ToolCallEvent(event=event_type, data=ToolCallData(**data_obj)) elif event_type == "tool_result": return ToolResultEvent( event=event_type, data=ToolResultData(**data_obj) ) elif event_type == "thinking": # Parse nested delta structure manually before creating ThinkingData if "delta" in data_obj and isinstance(data_obj["delta"], dict): delta_dict = data_obj["delta"] # Convert content items to MessageDelta objects if "content" in delta_dict and isinstance( delta_dict["content"], list ): parsed_content = [] for item in delta_dict["content"]: if isinstance(item, dict): # Parse payload to DeltaPayload if "payload" in item and isinstance( item["payload"], dict ): payload_dict = item["payload"] item["payload"] = DeltaPayload(**payload_dict) parsed_content.append(MessageDelta(**item)) # Replace with parsed content delta_dict["content"] = parsed_content # Create properly typed Delta object data_obj["delta"] = Delta(**delta_dict) return ThinkingEvent( event=event_type, data=ThinkingData(**data_obj), ) elif event_type == "final_answer": return FinalAnswerEvent( event=event_type, data=FinalAnswerData(**data_obj) ) else: # Fallback if it doesn't match any known event return UnknownEvent( event=event_type, data=data_obj, ) def search_arg_parser( query: str, search_mode: Optional[str | SearchMode] = "custom", search_settings: Optional[dict | SearchSettings] = None, ) -> dict: if search_mode and not isinstance(search_mode, str): search_mode = search_mode.value if search_settings and not isinstance(search_settings, dict): search_settings = search_settings.model_dump() data: dict[str, Any] = { "query": query, "search_settings": search_settings, } if search_mode: data["search_mode"] = search_mode return data def completion_arg_parser( messages: list[dict | Message], generation_config: Optional[dict | GenerationConfig] = None, ) -> dict: # FIXME: Needs a proper return type cast_messages: list[Message] = [ Message(**msg) if isinstance(msg, dict) else msg for msg in messages ] if generation_config and not isinstance(generation_config, dict): generation_config = generation_config.model_dump() data: dict[str, Any] = { "messages": [msg.model_dump() for msg in cast_messages], "generation_config": generation_config, } return data def embedding_arg_parser( text: str, ) -> dict: data: dict[str, Any] = { "text": text, } return data def rag_arg_parser( query: str, rag_generation_config: Optional[dict | GenerationConfig] = None, search_mode: Optional[str | SearchMode] = "custom", search_settings: Optional[dict | SearchSettings] = None, task_prompt: Optional[str] = None, include_title_if_available: Optional[bool] = False, include_web_search: Optional[bool] = False, ) -> dict: if rag_generation_config and not isinstance(rag_generation_config, dict): rag_generation_config = rag_generation_config.model_dump() if search_settings and not isinstance(search_settings, dict): search_settings = search_settings.model_dump() data: dict[str, Any] = { "query": query, "rag_generation_config": rag_generation_config, "search_settings": search_settings, "task_prompt": task_prompt, "include_title_if_available": include_title_if_available, "include_web_search": include_web_search, } if search_mode: data["search_mode"] = search_mode return data def agent_arg_parser( message: Optional[dict | Message] = None, rag_generation_config: Optional[dict | GenerationConfig] = None, research_generation_config: Optional[dict | GenerationConfig] = None, search_mode: Optional[str | SearchMode] = "custom", search_settings: Optional[dict | SearchSettings] = None, task_prompt: Optional[str] = None, include_title_if_available: Optional[bool] = True, conversation_id: Optional[str | uuid.UUID] = None, max_tool_context_length: Optional[int] = None, use_system_context: Optional[bool] = True, rag_tools: Optional[list[str]] = None, research_tools: Optional[list[str]] = None, tools: Optional[list[str]] = None, # For backward compatibility mode: Optional[str] = "rag", needs_initial_conversation_name: Optional[bool] = None, ) -> dict: if rag_generation_config and not isinstance(rag_generation_config, dict): rag_generation_config = rag_generation_config.model_dump() if research_generation_config and not isinstance( research_generation_config, dict ): research_generation_config = research_generation_config.model_dump() if search_settings and not isinstance(search_settings, dict): search_settings = search_settings.model_dump() data: dict[str, Any] = { "rag_generation_config": rag_generation_config or {}, "search_settings": search_settings, "task_prompt": task_prompt, "include_title_if_available": include_title_if_available, "conversation_id": (str(conversation_id) if conversation_id else None), "max_tool_context_length": max_tool_context_length, "use_system_context": use_system_context, "mode": mode, } # Handle generation configs based on mode if research_generation_config and mode == "research": data["research_generation_config"] = research_generation_config # Handle tool configurations if rag_tools: data["rag_tools"] = rag_tools if research_tools: data["research_tools"] = research_tools if tools: # Backward compatibility data["tools"] = tools if search_mode: data["search_mode"] = search_mode if needs_initial_conversation_name: data["needs_initial_conversation_name"] = ( needs_initial_conversation_name ) if message: cast_message: Message = ( Message(**message) if isinstance(message, dict) else message ) data["message"] = cast_message.model_dump() return data class RetrievalSDK: """SDK for interacting with documents in the v3 API.""" def __init__(self, client): self.client = client def search( self, query: str, search_mode: Optional[str | SearchMode] = "custom", search_settings: Optional[dict | SearchSettings] = None, ) -> WrappedSearchResponse: """Conduct a vector and/or graph search. Args: query (str): The query to search for. search_settings (Optional[dict, SearchSettings]]): Vector search settings. Returns: WrappedSearchResponse """ response_dict = self.client._make_request( "POST", "retrieval/search", json=search_arg_parser( query=query, search_mode=search_mode, search_settings=search_settings, ), version="v3", ) return WrappedSearchResponse(**response_dict) def completion( self, messages: list[dict | Message], generation_config: Optional[dict | GenerationConfig] = None, ) -> WrappedLLMChatCompletion: cast_messages: list[Message] = [ Message(**msg) if isinstance(msg, dict) else msg for msg in messages ] if generation_config and not isinstance(generation_config, dict): generation_config = generation_config.model_dump() data: dict[str, Any] = { "messages": [msg.model_dump() for msg in cast_messages], "generation_config": generation_config, } response_dict = self.client._make_request( "POST", "retrieval/completion", json=completion_arg_parser(messages, generation_config), version="v3", ) return WrappedLLMChatCompletion(**response_dict) def embedding( self, text: str, ) -> WrappedEmbeddingResponse: response_dict = self.client._make_request( "POST", "retrieval/embedding", data=embedding_arg_parser(text), version="v3", ) return WrappedEmbeddingResponse(**response_dict) def rag( self, query: str, rag_generation_config: Optional[dict | GenerationConfig] = None, search_mode: Optional[str | SearchMode] = "custom", search_settings: Optional[dict | SearchSettings] = None, task_prompt: Optional[str] = None, include_title_if_available: Optional[bool] = False, include_web_search: Optional[bool] = False, ) -> ( WrappedRAGResponse | Generator[ ThinkingEvent | SearchResultsEvent | MessageEvent | CitationEvent | FinalAnswerEvent | ToolCallEvent | ToolResultEvent | UnknownEvent | None, None, None, ] ): """Conducts a Retrieval Augmented Generation (RAG) search with the given query. Args: query (str): The query to search for. rag_generation_config (Optional[dict | GenerationConfig]): RAG generation configuration. search_settings (Optional[dict | SearchSettings]): Vector search settings. task_prompt (Optional[str]): Task prompt override. include_title_if_available (Optional[bool]): Include the title if available. Returns: WrappedRAGResponse | AsyncGenerator[RAGResponse, None]: The RAG response """ data = rag_arg_parser( query=query, rag_generation_config=rag_generation_config, search_mode=search_mode, search_settings=search_settings, task_prompt=task_prompt, include_title_if_available=include_title_if_available, include_web_search=include_web_search, ) rag_generation_config = data.get("rag_generation_config") if rag_generation_config and rag_generation_config.get( # type: ignore "stream", False ): raw_stream = self.client._make_streaming_request( "POST", "retrieval/rag", json=data, version="v3", ) # Wrap the raw stream to parse each event return (parse_retrieval_event(event) for event in raw_stream) response_dict = self.client._make_request( "POST", "retrieval/rag", json=data, version="v3", ) return WrappedRAGResponse(**response_dict) def agent( self, message: Optional[dict | Message] = None, rag_generation_config: Optional[dict | GenerationConfig] = None, research_generation_config: Optional[dict | GenerationConfig] = None, search_mode: Optional[str | SearchMode] = "custom", search_settings: Optional[dict | SearchSettings] = None, task_prompt: Optional[str] = None, include_title_if_available: Optional[bool] = True, conversation_id: Optional[str | uuid.UUID] = None, max_tool_context_length: Optional[int] = None, use_system_context: Optional[bool] = True, # Tool configurations rag_tools: Optional[list[str]] = None, research_tools: Optional[list[str]] = None, tools: Optional[list[str]] = None, # For backward compatibility mode: Optional[str] = "rag", needs_initial_conversation_name: Optional[bool] = None, ) -> ( WrappedAgentResponse | Generator[ ThinkingEvent | SearchResultsEvent | MessageEvent | CitationEvent | FinalAnswerEvent | ToolCallEvent | ToolResultEvent | UnknownEvent | None, None, None, ] ): """Performs a single turn in a conversation with a RAG agent. Args: message (Optional[dict | Message]): The message to send to the agent. rag_generation_config (Optional[dict | GenerationConfig]): Configuration for RAG generation in 'rag' mode. research_generation_config (Optional[dict | GenerationConfig]): Configuration for generation in 'research' mode. search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom". search_settings (Optional[dict | SearchSettings]): Vector search settings. task_prompt (Optional[str]): Task prompt override. include_title_if_available (Optional[bool]): Include the title if available. conversation_id (Optional[str | uuid.UUID]): ID of the conversation for maintaining context. max_tool_context_length (Optional[int]): Maximum context length for tool replies. use_system_context (Optional[bool]): Whether to use system context in the prompt. rag_tools (Optional[list[str]]): List of tools to enable for RAG mode. Available tools: "search_file_knowledge", "content", "web_search", "web_scrape", "search_file_descriptions". research_tools (Optional[list[str]]): List of tools to enable for Research mode. Available tools: "rag", "reasoning", "critique", "python_executor". tools (Optional[list[str]]): Deprecated. List of tools to execute. mode (Optional[str]): Mode to use for generation: "rag" for standard retrieval or "research" for deep analysis. Defaults to "rag". Returns: WrappedAgentResponse | AsyncGenerator[AgentEvent, None]: The agent response. """ data = agent_arg_parser( message=message, rag_generation_config=rag_generation_config, research_generation_config=research_generation_config, search_mode=search_mode, search_settings=search_settings, task_prompt=task_prompt, include_title_if_available=include_title_if_available, conversation_id=conversation_id, max_tool_context_length=max_tool_context_length, use_system_context=use_system_context, rag_tools=rag_tools, research_tools=research_tools, tools=tools, mode=mode, needs_initial_conversation_name=needs_initial_conversation_name, ) # Determine if streaming is enabled if search_mode: data["search_mode"] = search_mode if message: cast_message: Message = ( Message(**message) if isinstance(message, dict) else message ) data["message"] = cast_message.model_dump() is_stream = False if mode != "research": if rag_generation_config: if isinstance(rag_generation_config, dict): is_stream = rag_generation_config.get( # type: ignore "stream", False ) else: is_stream = rag_generation_config.stream else: if research_generation_config: if isinstance(research_generation_config, dict): is_stream = research_generation_config.get( # type: ignore "stream", False ) else: is_stream = research_generation_config.stream if is_stream: raw_stream = self.client._make_streaming_request( "POST", "retrieval/agent", json=data, version="v3", ) return (parse_retrieval_event(event) for event in raw_stream) response_dict = self.client._make_request( "POST", "retrieval/agent", json=data, version="v3", ) return WrappedAgentResponse(**response_dict)