from typing import AsyncGenerator, Generator

from shared.api.models import (
    CitationEvent,
    FinalAnswerEvent,
    MessageEvent,
    SearchResultsEvent,
    ThinkingEvent,
    ToolCallEvent,
    ToolResultEvent,
    UnknownEvent,
    WrappedAgentResponse,
    WrappedRAGResponse,
    WrappedSearchResponse,
)

from ..models import (
    Message,
)
from ..sync_methods.retrieval import parse_retrieval_event

# Every value a parsed streaming event may take (as annotated on the
# streaming return types of `rag` and `agent`).
_StreamEvent = (
    ThinkingEvent
    | SearchResultsEvent
    | MessageEvent
    | CitationEvent
    | FinalAnswerEvent
    | ToolCallEvent
    | ToolResultEvent
    | UnknownEvent
    | None
)


def _model_to_dict(value):
    """Return `value.model_dump()` for truthy non-dict values.

    Falsy values and plain dicts are returned unchanged.
    """
    if value and not isinstance(value, dict):
        return value.model_dump()
    return value


def _mode_to_str(search_mode):
    """Return `search_mode.value` for truthy non-string values.

    Falsy values and plain strings are returned unchanged.
    """
    if search_mode and not isinstance(search_mode, str):
        return search_mode.value
    return search_mode


def _without_none(payload):
    """Return a copy of `payload` without the keys whose value is None."""
    return {key: value for key, value in payload.items() if value is not None}


async def _parse_async_stream(raw_stream):
    """Yield `parse_retrieval_event(event)` for each event of an async stream."""
    async for event in raw_stream:
        yield parse_retrieval_event(event)


def _parse_event_stream(raw_stream):
    """Wrap a raw event stream so each event goes through
    `parse_retrieval_event`.

    A plain generator expression calls `iter()` on its source eagerly, which
    raises `TypeError` for an async iterable. Async iterables are therefore
    wrapped in an async generator; anything else keeps the lazy synchronous
    wrapping.
    """
    if hasattr(raw_stream, "__aiter__"):
        return _parse_async_stream(raw_stream)
    return (parse_retrieval_event(event) for event in raw_stream)


class RetrievalSDK:
    """
    SDK for the retrieval endpoints of the v3 API (Asynchronous).

    Wraps the `retrieval/search`, `retrieval/completion`,
    `retrieval/embedding`, `retrieval/rag` and `retrieval/agent` routes.
    """

    def __init__(self, client):
        # `client` must provide `_make_request` (awaited) and
        # `_make_streaming_request`; both are called with version="v3".
        self.client = client

    async def search(self, **kwargs) -> WrappedSearchResponse:
        """
        Conduct a vector and/or graph search (async).

        Args:
            query (str): Search query to find relevant documents.
            search_mode (Optional[str | SearchMode]): Pre-configured search
                modes: "basic", "advanced", or "custom". Defaults to
                "custom".
            search_settings (Optional[dict | SearchSettings]): The search
                configuration object. If search_mode is "custom", these
                settings are used as-is. For "basic" or "advanced", these
                settings will override the default mode configuration.

        Any other keyword argument is forwarded in the request payload.

        Returns:
            WrappedSearchResponse: The search results.

        Raises:
            ValueError: If `query` is missing or None.
        """
        query = kwargs.pop("query", None)
        if query is None:
            raise ValueError("'query' is a required parameter for search")

        search_mode = _mode_to_str(kwargs.pop("search_mode", "custom"))
        search_settings = _model_to_dict(kwargs.pop("search_settings", None))

        payload = _without_none(
            {
                "query": query,
                "search_mode": search_mode,
                "search_settings": search_settings,
                **kwargs,  # Include any additional parameters
            }
        )

        response_dict = await self.client._make_request(
            "POST",
            "retrieval/search",
            json=payload,
            version="v3",
        )
        return WrappedSearchResponse(**response_dict)

    async def completion(self, **kwargs):
        """
        Get a completion from the model (async).

        Args:
            messages (list[dict | Message]): List of messages to generate
                completion for. Each message should have a 'role' and
                'content'. Dicts are converted to `Message` before being
                serialized.
            generation_config (Optional[dict | GenerationConfig]):
                Configuration for text generation.

        Any other keyword argument is forwarded in the request payload.

        Returns:
            The completion response, as returned by the client.

        Raises:
            ValueError: If `messages` is missing or None.
        """
        messages = kwargs.pop("messages", None)
        if messages is None:
            raise ValueError(
                "'messages' is a required parameter for completion"
            )

        generation_config = _model_to_dict(
            kwargs.pop("generation_config", None)
        )

        cast_messages = [
            Message(**msg) if isinstance(msg, dict) else msg
            for msg in messages
        ]

        payload = _without_none(
            {
                "messages": [msg.model_dump() for msg in cast_messages],
                "generation_config": generation_config,
                **kwargs,  # Include any additional parameters
            }
        )

        return await self.client._make_request(
            "POST",
            "retrieval/completion",
            json=payload,
            version="v3",
        )

    async def embedding(self, **kwargs):
        """
        Generate an embedding for given text (async).

        Args:
            text (str): Text to generate embeddings for.

        Any other keyword argument is forwarded in the request payload.

        Returns:
            The embedding response, as returned by the client.

        Raises:
            ValueError: If `text` is missing or None.
        """
        text = kwargs.pop("text", None)
        if text is None:
            raise ValueError("'text' is a required parameter for embedding")

        payload = {"text": text, **kwargs}  # Include any additional parameters

        # NOTE(review): this is the only endpoint here that sends `data=`
        # rather than `json=`; kept as-is -- confirm it is intentional.
        return await self.client._make_request(
            "POST",
            "retrieval/embedding",
            data=payload,
            version="v3",
        )

    async def rag(
        self, **kwargs
    ) -> (
        WrappedRAGResponse
        | AsyncGenerator[_StreamEvent, None]
        | Generator[_StreamEvent, None, None]
    ):
        """
        Conducts a Retrieval Augmented Generation (RAG) search (async).

        Returns a `WrappedRAGResponse`, or a stream of parsed events when
        `rag_generation_config` has a truthy "stream" entry.

        Args:
            query (str): The search query.
            rag_generation_config (Optional[dict | GenerationConfig]):
                Configuration for RAG generation.
            search_mode (Optional[str | SearchMode]): Pre-configured search
                modes: "basic", "advanced", or "custom". Defaults to
                "custom".
            search_settings (Optional[dict | SearchSettings]): The search
                configuration object.
            task_prompt (Optional[str]): Optional custom prompt to override
                default.
            include_title_if_available (Optional[bool]): Include document
                titles in responses when available. Defaults to False.
            include_web_search (Optional[bool]): Include web search results
                provided to the LLM. Defaults to False.

        Any other keyword argument is forwarded in the request payload.

        Returns:
            Either a WrappedRAGResponse or an AsyncGenerator for streaming.

        Raises:
            ValueError: If `query` is missing or None.
        """
        query = kwargs.pop("query", None)
        if query is None:
            raise ValueError("'query' is a required parameter for rag")

        rag_generation_config = _model_to_dict(
            kwargs.pop("rag_generation_config", None)
        )
        search_mode = _mode_to_str(kwargs.pop("search_mode", "custom"))
        search_settings = _model_to_dict(kwargs.pop("search_settings", None))
        task_prompt = kwargs.pop("task_prompt", None)
        include_title_if_available = kwargs.pop(
            "include_title_if_available", False
        )
        include_web_search = kwargs.pop("include_web_search", False)

        payload = _without_none(
            {
                "query": query,
                "rag_generation_config": rag_generation_config,
                "search_mode": search_mode,
                "search_settings": search_settings,
                "task_prompt": task_prompt,
                "include_title_if_available": include_title_if_available,
                "include_web_search": include_web_search,
                **kwargs,  # Include any additional parameters
            }
        )

        # Streaming is requested through the (already dict-converted)
        # generation config.
        is_stream = bool(
            rag_generation_config
            and rag_generation_config.get("stream", False)
        )

        if is_stream:
            raw_stream = self.client._make_streaming_request(
                "POST",
                "retrieval/rag",
                json=payload,
                version="v3",
            )
            # Wrap each raw event with parse_retrieval_event
            return _parse_event_stream(raw_stream)

        # Otherwise, request fully and parse response
        response_dict = await self.client._make_request(
            "POST",
            "retrieval/rag",
            json=payload,
            version="v3",
        )
        return WrappedRAGResponse(**response_dict)

    async def agent(
        self, **kwargs
    ) -> (
        WrappedAgentResponse
        | AsyncGenerator[_StreamEvent, None]
        | Generator[_StreamEvent, None, None]
    ):
        """
        Performs a single turn in a conversation with a RAG agent (async).

        Returns a `WrappedAgentResponse`, or a stream of parsed events when
        streaming is enabled: either `rag_generation_config` has a truthy
        "stream" entry, or `mode` is "research" and
        `research_generation_config` has one.

        Args:
            message (Optional[dict | Message]): Current message to process.
            messages (Optional[list[dict | Message]]): List of messages
                (deprecated, use message instead). Forwarded unchanged.
            rag_generation_config (Optional[dict | GenerationConfig]):
                Configuration for RAG generation in 'rag' mode.
            research_generation_config (Optional[dict | GenerationConfig]):
                Configuration for generation in 'research' mode.
            search_mode (Optional[str | SearchMode]): Pre-configured search
                modes: "basic", "advanced", or "custom". Defaults to
                "custom".
            search_settings (Optional[dict | SearchSettings]): The search
                configuration object.
            task_prompt (Optional[str]): Optional custom prompt to override
                default.
            include_title_if_available (Optional[bool]): Include document
                titles from search results. Defaults to True.
            conversation_id (Optional[str | uuid.UUID]): ID of the
                conversation; sent as a string.
            tools (Optional[list[str]]): List of tools to execute
                (deprecated).
            rag_tools (Optional[list[str]]): List of tools to enable for RAG
                mode.
            research_tools (Optional[list[str]]): List of tools to enable
                for Research mode.
            max_tool_context_length (Optional[int]): Maximum length of
                returned tool context. Defaults to 32768.
            use_system_context (Optional[bool]): Use extended prompt for
                generation. Defaults to True.
            mode (Optional[Literal["rag", "research"]]): Mode to use for
                generation: 'rag' or 'research'. Defaults to "rag".

        Any other keyword argument is forwarded in the request payload.

        Returns:
            Either a WrappedAgentResponse or an AsyncGenerator for streaming.
        """
        # Extract parameters
        message = kwargs.pop("message", None)
        messages = kwargs.pop("messages", None)  # Deprecated
        rag_generation_config = _model_to_dict(
            kwargs.pop("rag_generation_config", None)
        )
        research_generation_config = _model_to_dict(
            kwargs.pop("research_generation_config", None)
        )
        search_mode = _mode_to_str(kwargs.pop("search_mode", "custom"))
        search_settings = _model_to_dict(kwargs.pop("search_settings", None))
        task_prompt = kwargs.pop("task_prompt", None)
        include_title_if_available = kwargs.pop(
            "include_title_if_available", True
        )
        conversation_id = kwargs.pop("conversation_id", None)
        tools = kwargs.pop("tools", None)  # Deprecated
        rag_tools = kwargs.pop("rag_tools", None)
        research_tools = kwargs.pop("research_tools", None)
        max_tool_context_length = kwargs.pop("max_tool_context_length", 32768)
        use_system_context = kwargs.pop("use_system_context", True)
        mode = kwargs.pop("mode", "rag")

        # Dict messages are round-tripped through Message before being
        # dumped; Message instances are dumped directly.
        if message and isinstance(message, dict):
            message = Message(**message).model_dump()
        elif message:
            message = message.model_dump()

        payload = _without_none(
            {
                "message": message,
                # Deprecated but included for backward compatibility
                "messages": messages,
                "rag_generation_config": rag_generation_config,
                "research_generation_config": research_generation_config,
                "search_mode": search_mode,
                "search_settings": search_settings,
                "task_prompt": task_prompt,
                "include_title_if_available": include_title_if_available,
                "conversation_id": (
                    str(conversation_id) if conversation_id else None
                ),
                # Deprecated but included for backward compatibility
                "tools": tools,
                "rag_tools": rag_tools,
                "research_tools": research_tools,
                "max_tool_context_length": max_tool_context_length,
                "use_system_context": use_system_context,
                "mode": mode,
                **kwargs,  # Include any additional parameters
            }
        )

        # Check if streaming is enabled. The RAG config is consulted
        # regardless of `mode`; the research config only in "research" mode.
        is_stream = False
        if rag_generation_config and rag_generation_config.get(
            "stream", False
        ):
            is_stream = True
        elif (
            research_generation_config
            and mode == "research"
            and research_generation_config.get("stream", False)
        ):
            is_stream = True

        if is_stream:
            raw_stream = self.client._make_streaming_request(
                "POST",
                "retrieval/agent",
                json=payload,
                version="v3",
            )
            # Parse each event in the stream
            return _parse_event_stream(raw_stream)

        response_dict = await self.client._make_request(
            "POST",
            "retrieval/agent",
            json=payload,
            version="v3",
        )
        return WrappedAgentResponse(**response_dict)