aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/sdk/asnyc_methods/retrieval.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/sdk/asnyc_methods/retrieval.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two versions of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/sdk/asnyc_methods/retrieval.py')
-rw-r--r--.venv/lib/python3.12/site-packages/sdk/asnyc_methods/retrieval.py394
1 files changed, 394 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/sdk/asnyc_methods/retrieval.py b/.venv/lib/python3.12/site-packages/sdk/asnyc_methods/retrieval.py
new file mode 100644
index 00000000..d825a91f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/sdk/asnyc_methods/retrieval.py
@@ -0,0 +1,394 @@
+from typing import Generator
+
+from shared.api.models import (
+ CitationEvent,
+ FinalAnswerEvent,
+ MessageEvent,
+ SearchResultsEvent,
+ ThinkingEvent,
+ ToolCallEvent,
+ ToolResultEvent,
+ UnknownEvent,
+ WrappedAgentResponse,
+ WrappedRAGResponse,
+ WrappedSearchResponse,
+)
+
+from ..models import (
+ Message,
+)
+from ..sync_methods.retrieval import parse_retrieval_event
+
+
class RetrievalSDK:
    """
    SDK for interacting with documents in the v3 API (Asynchronous).
    """

    def __init__(self, client):
        # Shared HTTP client. Must expose `_make_request` (awaitable) and
        # `_make_streaming_request` (SSE stream) — see the method bodies below.
        self.client = client
+
+ async def search(self, **kwargs) -> WrappedSearchResponse:
+ """
+ Conduct a vector and/or graph search (async).
+
+ Args:
+ query (str): Search query to find relevant documents.
+ search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
+ search_settings (Optional[dict | SearchSettings]): The search configuration object. If search_mode is "custom",
+ these settings are used as-is. For "basic" or "advanced", these settings
+ will override the default mode configuration.
+
+ Returns:
+ WrappedSearchResponse: The search results.
+ """
+ # Extract the required query parameter
+ query = kwargs.pop("query", None)
+ if query is None:
+ raise ValueError("'query' is a required parameter for search")
+
+ # Process common parameters
+ search_mode = kwargs.pop("search_mode", "custom")
+ search_settings = kwargs.pop("search_settings", None)
+
+ # Handle type conversions
+ if search_mode and not isinstance(search_mode, str):
+ search_mode = search_mode.value
+ if search_settings and not isinstance(search_settings, dict):
+ search_settings = search_settings.model_dump()
+
+ # Build payload
+ payload = {
+ "query": query,
+ "search_mode": search_mode,
+ "search_settings": search_settings,
+ **kwargs, # Include any additional parameters
+ }
+
+ # Filter out None values
+ payload = {k: v for k, v in payload.items() if v is not None}
+
+ response_dict = await self.client._make_request(
+ "POST",
+ "retrieval/search",
+ json=payload,
+ version="v3",
+ )
+ return WrappedSearchResponse(**response_dict)
+
+ async def completion(self, **kwargs):
+ """
+ Get a completion from the model (async).
+
+ Args:
+ messages (list[dict | Message]): List of messages to generate completion for. Each message
+ should have a 'role' and 'content'.
+ generation_config (Optional[dict | GenerationConfig]): Configuration for text generation.
+
+ Returns:
+ The completion response.
+ """
+ # Extract required parameters
+ messages = kwargs.pop("messages", None)
+ if messages is None:
+ raise ValueError(
+ "'messages' is a required parameter for completion"
+ )
+
+ # Process optional parameters
+ generation_config = kwargs.pop("generation_config", None)
+
+ # Handle type conversions
+ cast_messages = [
+ Message(**msg) if isinstance(msg, dict) else msg
+ for msg in messages
+ ]
+ if generation_config and not isinstance(generation_config, dict):
+ generation_config = generation_config.model_dump()
+
+ # Build payload
+ payload = {
+ "messages": [msg.model_dump() for msg in cast_messages],
+ "generation_config": generation_config,
+ **kwargs, # Include any additional parameters
+ }
+
+ # Filter out None values
+ payload = {k: v for k, v in payload.items() if v is not None}
+
+ return await self.client._make_request(
+ "POST",
+ "retrieval/completion",
+ json=payload,
+ version="v3",
+ )
+
+ async def embedding(self, **kwargs):
+ """
+ Generate an embedding for given text (async).
+
+ Args:
+ text (str): Text to generate embeddings for.
+
+ Returns:
+ The embedding vector.
+ """
+ # Extract required parameters
+ text = kwargs.pop("text", None)
+ if text is None:
+ raise ValueError("'text' is a required parameter for embedding")
+
+ # Build payload
+ payload = {"text": text, **kwargs} # Include any additional parameters
+
+ return await self.client._make_request(
+ "POST",
+ "retrieval/embedding",
+ data=payload,
+ version="v3",
+ )
+
    async def rag(
        self, **kwargs
    ) -> (
        WrappedRAGResponse
        | Generator[
            ThinkingEvent
            | SearchResultsEvent
            | MessageEvent
            | CitationEvent
            | FinalAnswerEvent
            | ToolCallEvent
            | ToolResultEvent
            | UnknownEvent
            | None,
            None,
            None,
        ]
    ):
        """
        Conducts a Retrieval Augmented Generation (RAG) search (async).
        May return a `WrappedRAGResponse` or a streaming generator if `stream=True`
        is set in `rag_generation_config`.

        Args:
            query (str): The search query.
            rag_generation_config (Optional[dict | GenerationConfig]): Configuration for RAG generation.
            search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
            search_settings (Optional[dict | SearchSettings]): The search configuration object.
            task_prompt (Optional[str]): Optional custom prompt to override default.
            include_title_if_available (Optional[bool]): Include document titles in responses when available.
            include_web_search (Optional[bool]): Include web search results provided to the LLM.

        Returns:
            Either a WrappedRAGResponse or an AsyncGenerator for streaming.

        Raises:
            ValueError: If 'query' is missing.
        """
        # Extract required parameters
        query = kwargs.pop("query", None)
        if query is None:
            raise ValueError("'query' is a required parameter for rag")

        # Process optional parameters
        rag_generation_config = kwargs.pop("rag_generation_config", None)
        search_mode = kwargs.pop("search_mode", "custom")
        search_settings = kwargs.pop("search_settings", None)
        task_prompt = kwargs.pop("task_prompt", None)
        include_title_if_available = kwargs.pop(
            "include_title_if_available", False
        )
        include_web_search = kwargs.pop("include_web_search", False)

        # Handle type conversions: coerce enums / pydantic models to
        # plain strings / dicts so they are JSON-serializable.
        if rag_generation_config and not isinstance(
            rag_generation_config, dict
        ):
            rag_generation_config = rag_generation_config.model_dump()
        if search_mode and not isinstance(search_mode, str):
            search_mode = search_mode.value
        if search_settings and not isinstance(search_settings, dict):
            search_settings = search_settings.model_dump()

        # Build payload
        payload = {
            "query": query,
            "rag_generation_config": rag_generation_config,
            "search_mode": search_mode,
            "search_settings": search_settings,
            "task_prompt": task_prompt,
            "include_title_if_available": include_title_if_available,
            "include_web_search": include_web_search,
            **kwargs,  # Include any additional parameters
        }

        # Filter out None values
        payload = {k: v for k, v in payload.items() if v is not None}

        # Check if streaming is enabled. At this point rag_generation_config
        # has been normalized to a dict (or is None), so .get() is safe.
        is_stream = False
        if rag_generation_config and rag_generation_config.get(
            "stream", False
        ):
            is_stream = True

        if is_stream:
            # Return a streaming generator over the SSE response.
            raw_stream = self.client._make_streaming_request(
                "POST",
                "retrieval/rag",
                json=payload,
                version="v3",
            )
            # Wrap each raw SSE event with parse_retrieval_event.
            # NOTE(review): this is a *sync* generator expression — it
            # assumes `_make_streaming_request` yields synchronously even
            # on this async client; confirm against the client helper.
            return (parse_retrieval_event(event) for event in raw_stream)

        # Otherwise, request fully and parse response
        response_dict = await self.client._make_request(
            "POST",
            "retrieval/rag",
            json=payload,
            version="v3",
        )
        return WrappedRAGResponse(**response_dict)
+
    async def agent(
        self, **kwargs
    ) -> (
        WrappedAgentResponse
        | Generator[
            ThinkingEvent
            | SearchResultsEvent
            | MessageEvent
            | CitationEvent
            | FinalAnswerEvent
            | ToolCallEvent
            | ToolResultEvent
            | UnknownEvent
            | None,
            None,
            None,
        ]
    ):
        """
        Performs a single turn in a conversation with a RAG agent (async).
        May return a `WrappedAgentResponse` or a streaming generator if
        `stream=True` is set in the generation config for the active mode.

        Args:
            message (Optional[dict | Message]): Current message to process.
            messages (Optional[list[dict | Message]]): List of messages (deprecated, use message instead).
            rag_generation_config (Optional[dict | GenerationConfig]): Configuration for RAG generation in 'rag' mode.
            research_generation_config (Optional[dict | GenerationConfig]): Configuration for generation in 'research' mode.
            search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
            search_settings (Optional[dict | SearchSettings]): The search configuration object.
            task_prompt (Optional[str]): Optional custom prompt to override default.
            include_title_if_available (Optional[bool]): Include document titles from search results.
            conversation_id (Optional[str | uuid.UUID]): ID of the conversation.
            tools (Optional[list[str]]): List of tools to execute (deprecated).
            rag_tools (Optional[list[str]]): List of tools to enable for RAG mode.
            research_tools (Optional[list[str]]): List of tools to enable for Research mode.
            max_tool_context_length (Optional[int]): Maximum length of returned tool context.
            use_system_context (Optional[bool]): Use extended prompt for generation.
            mode (Optional[Literal["rag", "research"]]): Mode to use for generation: 'rag' or 'research'.

        Returns:
            Either a WrappedAgentResponse or an AsyncGenerator for streaming.
        """
        # Extract parameters
        message = kwargs.pop("message", None)
        messages = kwargs.pop("messages", None)  # Deprecated
        rag_generation_config = kwargs.pop("rag_generation_config", None)
        research_generation_config = kwargs.pop(
            "research_generation_config", None
        )
        search_mode = kwargs.pop("search_mode", "custom")
        search_settings = kwargs.pop("search_settings", None)
        task_prompt = kwargs.pop("task_prompt", None)
        include_title_if_available = kwargs.pop(
            "include_title_if_available", True
        )
        conversation_id = kwargs.pop("conversation_id", None)
        tools = kwargs.pop("tools", None)  # Deprecated
        rag_tools = kwargs.pop("rag_tools", None)
        research_tools = kwargs.pop("research_tools", None)
        max_tool_context_length = kwargs.pop("max_tool_context_length", 32768)
        use_system_context = kwargs.pop("use_system_context", True)
        mode = kwargs.pop("mode", "rag")

        # Handle type conversions: a dict message is validated through the
        # Message model first; anything else is assumed to be a pydantic
        # model and dumped directly.
        if message and isinstance(message, dict):
            message = Message(**message).model_dump()
        elif message:
            message = message.model_dump()

        if rag_generation_config and not isinstance(
            rag_generation_config, dict
        ):
            rag_generation_config = rag_generation_config.model_dump()

        if research_generation_config and not isinstance(
            research_generation_config, dict
        ):
            research_generation_config = (
                research_generation_config.model_dump()
            )

        if search_mode and not isinstance(search_mode, str):
            search_mode = search_mode.value

        if search_settings and not isinstance(search_settings, dict):
            search_settings = search_settings.model_dump()

        # Build payload
        payload = {
            "message": message,
            "messages": messages,  # Deprecated but included for backward compatibility
            "rag_generation_config": rag_generation_config,
            "research_generation_config": research_generation_config,
            "search_mode": search_mode,
            "search_settings": search_settings,
            "task_prompt": task_prompt,
            "include_title_if_available": include_title_if_available,
            # conversation_id may be a uuid.UUID; stringify for JSON.
            "conversation_id": (
                str(conversation_id) if conversation_id else None
            ),
            "tools": tools,  # Deprecated but included for backward compatibility
            "rag_tools": rag_tools,
            "research_tools": research_tools,
            "max_tool_context_length": max_tool_context_length,
            "use_system_context": use_system_context,
            "mode": mode,
            **kwargs,  # Include any additional parameters
        }

        # Remove None values
        payload = {k: v for k, v in payload.items() if v is not None}

        # Check if streaming is enabled. The rag config's stream flag wins
        # unconditionally; the research config's flag only applies when
        # mode == "research". Both configs were normalized to dicts above.
        is_stream = False
        if rag_generation_config and rag_generation_config.get(
            "stream", False
        ):
            is_stream = True
        elif (
            research_generation_config
            and mode == "research"
            and research_generation_config.get("stream", False)
        ):
            is_stream = True

        if is_stream:
            # Return a streaming generator over the SSE response.
            raw_stream = self.client._make_streaming_request(
                "POST",
                "retrieval/agent",
                json=payload,
                version="v3",
            )
            # Parse each event in the stream.
            # NOTE(review): sync generator expression — assumes
            # `_make_streaming_request` yields synchronously even on this
            # async client; confirm against the client helper.
            return (parse_retrieval_event(event) for event in raw_stream)

        response_dict = await self.client._make_request(
            "POST",
            "retrieval/agent",
            json=payload,
            version="v3",
        )
        return WrappedAgentResponse(**response_dict)