diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/sdk/asnyc_methods/retrieval.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/sdk/asnyc_methods/retrieval.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/sdk/asnyc_methods/retrieval.py | 394 |
1 file changed, 394 insertions, 0 deletions
from typing import Generator

from shared.api.models import (
    CitationEvent,
    FinalAnswerEvent,
    MessageEvent,
    SearchResultsEvent,
    ThinkingEvent,
    ToolCallEvent,
    ToolResultEvent,
    UnknownEvent,
    WrappedAgentResponse,
    WrappedRAGResponse,
    WrappedSearchResponse,
)

from ..models import (
    Message,
)
from ..sync_methods.retrieval import parse_retrieval_event


class RetrievalSDK:
    """
    SDK for interacting with documents in the v3 API (Asynchronous).

    Every public method accepts its parameters via ``**kwargs`` (popping the
    names it knows) and forwards any remaining keys verbatim into the request
    payload, so new server-side parameters work without an SDK release.
    """

    def __init__(self, client):
        # `client` must expose an awaitable `_make_request` and a
        # `_make_streaming_request`; both come from the project HTTP client.
        self.client = client

    # ------------------------------------------------------------------
    # Internal helpers (shared by every endpoint method below).
    # ------------------------------------------------------------------

    @staticmethod
    def _as_dict(value):
        """Return *value* as-is when falsy or already a dict; otherwise its
        pydantic ``model_dump()`` representation."""
        if value and not isinstance(value, dict):
            return value.model_dump()
        return value

    @staticmethod
    def _mode_value(mode):
        """Normalize a SearchMode-style enum (or plain str) to its string value."""
        if mode and not isinstance(mode, str):
            return mode.value
        return mode

    @staticmethod
    def _drop_none(payload: dict) -> dict:
        """Strip keys whose value is None so they are omitted from the request."""
        return {k: v for k, v in payload.items() if v is not None}

    async def search(self, **kwargs) -> WrappedSearchResponse:
        """
        Conduct a vector and/or graph search (async).

        Args:
            query (str): Search query to find relevant documents.
            search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
            search_settings (Optional[dict | SearchSettings]): The search configuration object. If search_mode is "custom",
                these settings are used as-is. For "basic" or "advanced", these settings
                will override the default mode configuration.

        Returns:
            WrappedSearchResponse: The search results.

        Raises:
            ValueError: If ``query`` is not supplied.
        """
        query = kwargs.pop("query", None)
        if query is None:
            raise ValueError("'query' is a required parameter for search")

        # Normalize enum / pydantic inputs to JSON-serializable forms.
        search_mode = self._mode_value(kwargs.pop("search_mode", "custom"))
        search_settings = self._as_dict(kwargs.pop("search_settings", None))

        payload = self._drop_none(
            {
                "query": query,
                "search_mode": search_mode,
                "search_settings": search_settings,
                **kwargs,  # Include any additional parameters
            }
        )

        response_dict = await self.client._make_request(
            "POST",
            "retrieval/search",
            json=payload,
            version="v3",
        )
        return WrappedSearchResponse(**response_dict)

    async def completion(self, **kwargs):
        """
        Get a completion from the model (async).

        Args:
            messages (list[dict | Message]): List of messages to generate completion for. Each message
                should have a 'role' and 'content'.
            generation_config (Optional[dict | GenerationConfig]): Configuration for text generation.

        Returns:
            The completion response (raw response dict from the server).

        Raises:
            ValueError: If ``messages`` is not supplied.
        """
        messages = kwargs.pop("messages", None)
        if messages is None:
            raise ValueError(
                "'messages' is a required parameter for completion"
            )

        generation_config = self._as_dict(
            kwargs.pop("generation_config", None)
        )

        # Round-trip every message through the Message model so dict inputs
        # are validated and serialized uniformly.
        cast_messages = [
            Message(**msg) if isinstance(msg, dict) else msg
            for msg in messages
        ]

        payload = self._drop_none(
            {
                "messages": [msg.model_dump() for msg in cast_messages],
                "generation_config": generation_config,
                **kwargs,  # Include any additional parameters
            }
        )

        return await self.client._make_request(
            "POST",
            "retrieval/completion",
            json=payload,
            version="v3",
        )

    async def embedding(self, **kwargs):
        """
        Generate an embedding for given text (async).

        Args:
            text (str): Text to generate embeddings for.

        Returns:
            The embedding vector.

        Raises:
            ValueError: If ``text`` is not supplied.
        """
        text = kwargs.pop("text", None)
        if text is None:
            raise ValueError("'text' is a required parameter for embedding")

        payload = {"text": text, **kwargs}  # Include any additional parameters

        # NOTE(review): this endpoint sends `data=` (form-encoded) while every
        # other method here sends `json=` — confirm the server expects a form
        # body before changing it.
        return await self.client._make_request(
            "POST",
            "retrieval/embedding",
            data=payload,
            version="v3",
        )

    async def rag(
        self, **kwargs
    ) -> (
        WrappedRAGResponse
        | Generator[
            ThinkingEvent
            | SearchResultsEvent
            | MessageEvent
            | CitationEvent
            | FinalAnswerEvent
            | ToolCallEvent
            | ToolResultEvent
            | UnknownEvent
            | None,
            None,
            None,
        ]
    ):
        """
        Conducts a Retrieval Augmented Generation (RAG) search (async).
        May return a `WrappedRAGResponse` or a streaming generator if `stream=True`.

        Args:
            query (str): The search query.
            rag_generation_config (Optional[dict | GenerationConfig]): Configuration for RAG generation.
            search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
            search_settings (Optional[dict | SearchSettings]): The search configuration object.
            task_prompt (Optional[str]): Optional custom prompt to override default.
            include_title_if_available (Optional[bool]): Include document titles in responses when available.
            include_web_search (Optional[bool]): Include web search results provided to the LLM.

        Returns:
            Either a WrappedRAGResponse or a generator of parsed events when
            streaming is enabled.

        Raises:
            ValueError: If ``query`` is not supplied.
        """
        query = kwargs.pop("query", None)
        if query is None:
            raise ValueError("'query' is a required parameter for rag")

        # Normalize enum / pydantic inputs to JSON-serializable forms.
        rag_generation_config = self._as_dict(
            kwargs.pop("rag_generation_config", None)
        )
        search_mode = self._mode_value(kwargs.pop("search_mode", "custom"))
        search_settings = self._as_dict(kwargs.pop("search_settings", None))
        task_prompt = kwargs.pop("task_prompt", None)
        include_title_if_available = kwargs.pop(
            "include_title_if_available", False
        )
        include_web_search = kwargs.pop("include_web_search", False)

        payload = self._drop_none(
            {
                "query": query,
                "rag_generation_config": rag_generation_config,
                "search_mode": search_mode,
                "search_settings": search_settings,
                "task_prompt": task_prompt,
                "include_title_if_available": include_title_if_available,
                "include_web_search": include_web_search,
                **kwargs,  # Include any additional parameters
            }
        )

        # Streaming is requested via the generation config, not a kwarg.
        is_stream = bool(
            rag_generation_config
            and rag_generation_config.get("stream", False)
        )

        if is_stream:
            raw_stream = self.client._make_streaming_request(
                "POST",
                "retrieval/rag",
                json=payload,
                version="v3",
            )
            # NOTE(review): this wraps the stream in a *sync* generator
            # expression; if `_make_streaming_request` yields asynchronously,
            # an async generator would be required — confirm against the
            # client implementation.
            return (parse_retrieval_event(event) for event in raw_stream)

        # Otherwise, request fully and parse response
        response_dict = await self.client._make_request(
            "POST",
            "retrieval/rag",
            json=payload,
            version="v3",
        )
        return WrappedRAGResponse(**response_dict)

    async def agent(
        self, **kwargs
    ) -> (
        WrappedAgentResponse
        | Generator[
            ThinkingEvent
            | SearchResultsEvent
            | MessageEvent
            | CitationEvent
            | FinalAnswerEvent
            | ToolCallEvent
            | ToolResultEvent
            | UnknownEvent
            | None,
            None,
            None,
        ]
    ):
        """
        Performs a single turn in a conversation with a RAG agent (async).
        May return a `WrappedAgentResponse` or a streaming generator if `stream=True`.

        Args:
            message (Optional[dict | Message]): Current message to process.
            messages (Optional[list[dict | Message]]): List of messages (deprecated, use message instead).
            rag_generation_config (Optional[dict | GenerationConfig]): Configuration for RAG generation in 'rag' mode.
            research_generation_config (Optional[dict | GenerationConfig]): Configuration for generation in 'research' mode.
            search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
            search_settings (Optional[dict | SearchSettings]): The search configuration object.
            task_prompt (Optional[str]): Optional custom prompt to override default.
            include_title_if_available (Optional[bool]): Include document titles from search results.
            conversation_id (Optional[str | uuid.UUID]): ID of the conversation.
            tools (Optional[list[str]]): List of tools to execute (deprecated).
            rag_tools (Optional[list[str]]): List of tools to enable for RAG mode.
            research_tools (Optional[list[str]]): List of tools to enable for Research mode.
            max_tool_context_length (Optional[int]): Maximum length of returned tool context.
            use_system_context (Optional[bool]): Use extended prompt for generation.
            mode (Optional[Literal["rag", "research"]]): Mode to use for generation: 'rag' or 'research'.

        Returns:
            Either a WrappedAgentResponse or a generator of parsed events when
            streaming is enabled.
        """
        message = kwargs.pop("message", None)
        messages = kwargs.pop("messages", None)  # Deprecated
        rag_generation_config = self._as_dict(
            kwargs.pop("rag_generation_config", None)
        )
        research_generation_config = self._as_dict(
            kwargs.pop("research_generation_config", None)
        )
        search_mode = self._mode_value(kwargs.pop("search_mode", "custom"))
        search_settings = self._as_dict(kwargs.pop("search_settings", None))
        task_prompt = kwargs.pop("task_prompt", None)
        include_title_if_available = kwargs.pop(
            "include_title_if_available", True
        )
        conversation_id = kwargs.pop("conversation_id", None)
        tools = kwargs.pop("tools", None)  # Deprecated
        rag_tools = kwargs.pop("rag_tools", None)
        research_tools = kwargs.pop("research_tools", None)
        max_tool_context_length = kwargs.pop("max_tool_context_length", 32768)
        use_system_context = kwargs.pop("use_system_context", True)
        mode = kwargs.pop("mode", "rag")

        # Dict messages are round-tripped through Message for validation;
        # Message instances are serialized directly.
        if message and isinstance(message, dict):
            message = Message(**message).model_dump()
        elif message:
            message = message.model_dump()

        payload = self._drop_none(
            {
                "message": message,
                "messages": messages,  # Deprecated but included for backward compatibility
                "rag_generation_config": rag_generation_config,
                "research_generation_config": research_generation_config,
                "search_mode": search_mode,
                "search_settings": search_settings,
                "task_prompt": task_prompt,
                "include_title_if_available": include_title_if_available,
                "conversation_id": (
                    str(conversation_id) if conversation_id else None
                ),
                "tools": tools,  # Deprecated but included for backward compatibility
                "rag_tools": rag_tools,
                "research_tools": research_tools,
                "max_tool_context_length": max_tool_context_length,
                "use_system_context": use_system_context,
                "mode": mode,
                **kwargs,  # Include any additional parameters
            }
        )

        # Streaming is driven by whichever generation config matches the mode:
        # the RAG config always wins; the research config only applies in
        # 'research' mode.
        is_stream = False
        if rag_generation_config and rag_generation_config.get(
            "stream", False
        ):
            is_stream = True
        elif (
            research_generation_config
            and mode == "research"
            and research_generation_config.get("stream", False)
        ):
            is_stream = True

        if is_stream:
            raw_stream = self.client._make_streaming_request(
                "POST",
                "retrieval/agent",
                json=payload,
                version="v3",
            )
            # NOTE(review): sync generator expression over the stream — see
            # the matching note in `rag` about async iteration.
            return (parse_retrieval_event(event) for event in raw_stream)

        response_dict = await self.client._make_request(
            "POST",
            "retrieval/agent",
            json=payload,
            version="v3",
        )
        return WrappedAgentResponse(**response_dict)