from typing import AsyncGenerator, Generator

from shared.api.models import (
    CitationEvent,
    FinalAnswerEvent,
    MessageEvent,
    SearchResultsEvent,
    ThinkingEvent,
    ToolCallEvent,
    ToolResultEvent,
    UnknownEvent,
    WrappedAgentResponse,
    WrappedRAGResponse,
    WrappedSearchResponse,
)

from ..models import (
    Message,
)
from ..sync_methods.retrieval import parse_retrieval_event
class RetrievalSDK:
    """
    SDK for the retrieval endpoints of the v3 API (asynchronous).

    Covers search, completion, embedding, RAG and agent requests, all sent
    through the ``client`` object handed to the constructor.
    """

    def __init__(self, client):
        """
        Args:
            client: Object exposing ``_make_request`` (awaited) and
                ``_make_streaming_request``; every call is delegated to it.
        """
        self.client = client

    # ------------------------------------------------------------------
    # Private helpers shared by the endpoint methods
    # ------------------------------------------------------------------
    @staticmethod
    def _dump_model(value):
        """Return ``value.model_dump()`` for truthy non-dict values, else ``value`` unchanged."""
        if value and not isinstance(value, dict):
            return value.model_dump()
        return value

    @staticmethod
    def _enum_value(value):
        """Return ``value.value`` for truthy non-string values, else ``value`` unchanged."""
        if value and not isinstance(value, str):
            return value.value
        return value

    @staticmethod
    def _drop_none(payload):
        """Return a copy of ``payload`` without the keys whose value is ``None``."""
        return {key: val for key, val in payload.items() if val is not None}

    @staticmethod
    def _stream_requested(generation_config):
        """Tell whether a (dict) generation config has a truthy ``"stream"`` entry."""
        return bool(generation_config and generation_config.get("stream", False))

    @staticmethod
    def _parse_events(raw_stream):
        """
        Wrap every raw event of ``raw_stream`` with ``parse_retrieval_event``.

        An asynchronous iterable yields an async generator. Anything else
        falls back to a plain generator expression; feeding an async
        iterable to that expression would raise ``TypeError`` immediately,
        because a generator expression calls ``iter()`` on its source when
        it is created.
        """
        if hasattr(raw_stream, "__aiter__"):

            async def _events():
                async for event in raw_stream:
                    yield parse_retrieval_event(event)

            return _events()
        return (parse_retrieval_event(event) for event in raw_stream)

    # ------------------------------------------------------------------
    # Endpoints
    # ------------------------------------------------------------------
    # NOTE: return annotations are written as forward-reference strings so
    # that defining the class does not evaluate them.
    async def search(self, **kwargs) -> "WrappedSearchResponse":
        """
        Conduct a vector and/or graph search (async).

        Args:
            query (str): Search query to find relevant documents.
            search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
                Defaults to "custom".
            search_settings (Optional[dict | SearchSettings]): The search configuration object. If search_mode is "custom",
                these settings are used as-is. For "basic" or "advanced", these settings
                will override the default mode configuration.
            **kwargs: Any other keyword argument is forwarded in the payload as-is.

        Returns:
            WrappedSearchResponse: The search results.

        Raises:
            ValueError: If ``query`` is missing or ``None``.
        """
        query = kwargs.pop("query", None)
        if query is None:
            raise ValueError("'query' is a required parameter for search")

        # Enum members and settings models are reduced to plain values.
        search_mode = self._enum_value(kwargs.pop("search_mode", "custom"))
        search_settings = self._dump_model(kwargs.pop("search_settings", None))

        payload = self._drop_none(
            {
                "query": query,
                "search_mode": search_mode,
                "search_settings": search_settings,
                **kwargs,  # Include any additional parameters
            }
        )

        response_dict = await self.client._make_request(
            "POST",
            "retrieval/search",
            json=payload,
            version="v3",
        )
        return WrappedSearchResponse(**response_dict)

    async def completion(self, **kwargs):
        """
        Get a completion from the model (async).

        Args:
            messages (list[dict | Message]): List of messages to generate completion for. Each message
                should have a 'role' and 'content'.
            generation_config (Optional[dict | GenerationConfig]): Configuration for text generation.
            **kwargs: Any other keyword argument is forwarded in the payload as-is.

        Returns:
            The completion response, exactly as returned by the client.

        Raises:
            ValueError: If ``messages`` is missing or ``None``.
        """
        messages = kwargs.pop("messages", None)
        if messages is None:
            raise ValueError(
                "'messages' is a required parameter for completion"
            )

        generation_config = self._dump_model(
            kwargs.pop("generation_config", None)
        )

        # Dict messages go through Message first, so every entry is dumped
        # the same way.
        cast_messages = [
            Message(**msg) if isinstance(msg, dict) else msg
            for msg in messages
        ]

        payload = self._drop_none(
            {
                "messages": [msg.model_dump() for msg in cast_messages],
                "generation_config": generation_config,
                **kwargs,  # Include any additional parameters
            }
        )

        return await self.client._make_request(
            "POST",
            "retrieval/completion",
            json=payload,
            version="v3",
        )

    async def embedding(self, **kwargs):
        """
        Generate an embedding for given text (async).

        Args:
            text (str): Text to generate embeddings for.
            **kwargs: Any other keyword argument is forwarded in the payload as-is.

        Returns:
            The embedding response, exactly as returned by the client.

        Raises:
            ValueError: If ``text`` is missing or ``None``.
        """
        text = kwargs.pop("text", None)
        if text is None:
            raise ValueError("'text' is a required parameter for embedding")

        payload = {"text": text, **kwargs}  # Include any additional parameters

        # NOTE(review): this is the only endpoint here that sends the
        # payload through `data=` instead of `json=` -- confirm that this
        # is what the server expects.
        return await self.client._make_request(
            "POST",
            "retrieval/embedding",
            data=payload,
            version="v3",
        )

    async def rag(
        self, **kwargs
    ) -> (
        "WrappedRAGResponse"
        " | AsyncGenerator["
        "ThinkingEvent | SearchResultsEvent | MessageEvent | CitationEvent"
        " | FinalAnswerEvent | ToolCallEvent | ToolResultEvent | UnknownEvent"
        " | None, None]"
        " | Generator["
        "ThinkingEvent | SearchResultsEvent | MessageEvent | CitationEvent"
        " | FinalAnswerEvent | ToolCallEvent | ToolResultEvent | UnknownEvent"
        " | None, None, None]"
    ):
        """
        Conducts a Retrieval Augmented Generation (RAG) search (async).

        Streaming is selected by a truthy ``"stream"`` entry in
        ``rag_generation_config``; otherwise the full response is requested.

        Args:
            query (str): The search query.
            rag_generation_config (Optional[dict | GenerationConfig]): Configuration for RAG generation.
            search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
                Defaults to "custom".
            search_settings (Optional[dict | SearchSettings]): The search configuration object.
            task_prompt (Optional[str]): Optional custom prompt to override default.
            include_title_if_available (Optional[bool]): Include document titles in responses when available.
                Defaults to False.
            include_web_search (Optional[bool]): Include web search results provided to the LLM.
                Defaults to False.
            **kwargs: Any other keyword argument is forwarded in the payload as-is.

        Returns:
            A WrappedRAGResponse, or when streaming a generator of parsed
            events (asynchronous if the client's stream is asynchronous).

        Raises:
            ValueError: If ``query`` is missing or ``None``.
        """
        query = kwargs.pop("query", None)
        if query is None:
            raise ValueError("'query' is a required parameter for rag")

        rag_generation_config = self._dump_model(
            kwargs.pop("rag_generation_config", None)
        )
        search_mode = self._enum_value(kwargs.pop("search_mode", "custom"))
        search_settings = self._dump_model(kwargs.pop("search_settings", None))
        task_prompt = kwargs.pop("task_prompt", None)
        include_title_if_available = kwargs.pop(
            "include_title_if_available", False
        )
        include_web_search = kwargs.pop("include_web_search", False)

        payload = self._drop_none(
            {
                "query": query,
                "rag_generation_config": rag_generation_config,
                "search_mode": search_mode,
                "search_settings": search_settings,
                "task_prompt": task_prompt,
                "include_title_if_available": include_title_if_available,
                "include_web_search": include_web_search,
                **kwargs,  # Include any additional parameters
            }
        )

        if self._stream_requested(rag_generation_config):
            raw_stream = self.client._make_streaming_request(
                "POST",
                "retrieval/rag",
                json=payload,
                version="v3",
            )
            # Wrap each raw event with parse_retrieval_event
            return self._parse_events(raw_stream)

        # Otherwise, request fully and parse response
        response_dict = await self.client._make_request(
            "POST",
            "retrieval/rag",
            json=payload,
            version="v3",
        )
        return WrappedRAGResponse(**response_dict)

    async def agent(
        self, **kwargs
    ) -> (
        "WrappedAgentResponse"
        " | AsyncGenerator["
        "ThinkingEvent | SearchResultsEvent | MessageEvent | CitationEvent"
        " | FinalAnswerEvent | ToolCallEvent | ToolResultEvent | UnknownEvent"
        " | None, None]"
        " | Generator["
        "ThinkingEvent | SearchResultsEvent | MessageEvent | CitationEvent"
        " | FinalAnswerEvent | ToolCallEvent | ToolResultEvent | UnknownEvent"
        " | None, None, None]"
    ):
        """
        Performs a single turn in a conversation with a RAG agent (async).

        Streaming is selected by a truthy ``"stream"`` entry in
        ``rag_generation_config``, or in ``research_generation_config``
        when ``mode`` is "research"; otherwise the full response is requested.

        Args:
            message (Optional[dict | Message]): Current message to process.
            messages (Optional[list[dict | Message]]): List of messages (deprecated, use message instead).
            rag_generation_config (Optional[dict | GenerationConfig]): Configuration for RAG generation in 'rag' mode.
            research_generation_config (Optional[dict | GenerationConfig]): Configuration for generation in 'research' mode.
            search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
                Defaults to "custom".
            search_settings (Optional[dict | SearchSettings]): The search configuration object.
            task_prompt (Optional[str]): Optional custom prompt to override default.
            include_title_if_available (Optional[bool]): Include document titles from search results.
                Defaults to True.
            conversation_id (Optional[str | uuid.UUID]): ID of the conversation.
            tools (Optional[list[str]]): List of tools to execute (deprecated).
            rag_tools (Optional[list[str]]): List of tools to enable for RAG mode.
            research_tools (Optional[list[str]]): List of tools to enable for Research mode.
            max_tool_context_length (Optional[int]): Maximum length of returned tool context.
                Defaults to 32768.
            use_system_context (Optional[bool]): Use extended prompt for generation.
                Defaults to True.
            mode (Optional[Literal["rag", "research"]]): Mode to use for generation: 'rag' or 'research'.
                Defaults to "rag".
            **kwargs: Any other keyword argument is forwarded in the payload as-is.

        Returns:
            A WrappedAgentResponse, or when streaming a generator of parsed
            events (asynchronous if the client's stream is asynchronous).
        """
        message = kwargs.pop("message", None)
        messages = kwargs.pop("messages", None)  # Deprecated
        rag_generation_config = self._dump_model(
            kwargs.pop("rag_generation_config", None)
        )
        research_generation_config = self._dump_model(
            kwargs.pop("research_generation_config", None)
        )
        search_mode = self._enum_value(kwargs.pop("search_mode", "custom"))
        search_settings = self._dump_model(kwargs.pop("search_settings", None))
        task_prompt = kwargs.pop("task_prompt", None)
        include_title_if_available = kwargs.pop(
            "include_title_if_available", True
        )
        conversation_id = kwargs.pop("conversation_id", None)
        tools = kwargs.pop("tools", None)  # Deprecated
        rag_tools = kwargs.pop("rag_tools", None)
        research_tools = kwargs.pop("research_tools", None)
        max_tool_context_length = kwargs.pop("max_tool_context_length", 32768)
        use_system_context = kwargs.pop("use_system_context", True)
        mode = kwargs.pop("mode", "rag")

        # A dict message goes through Message first; a model is dumped directly.
        if message and isinstance(message, dict):
            message = Message(**message).model_dump()
        elif message:
            message = message.model_dump()

        # Model instances in the deprecated list are dumped to plain dicts so
        # they can travel in the JSON payload; dict entries are left untouched.
        if messages and isinstance(messages, (list, tuple)):
            messages = [
                entry.model_dump() if hasattr(entry, "model_dump") else entry
                for entry in messages
            ]

        payload = self._drop_none(
            {
                "message": message,
                "messages": messages,  # Deprecated but included for backward compatibility
                "rag_generation_config": rag_generation_config,
                "research_generation_config": research_generation_config,
                "search_mode": search_mode,
                "search_settings": search_settings,
                "task_prompt": task_prompt,
                "include_title_if_available": include_title_if_available,
                "conversation_id": (
                    str(conversation_id) if conversation_id else None
                ),
                "tools": tools,  # Deprecated but included for backward compatibility
                "rag_tools": rag_tools,
                "research_tools": research_tools,
                "max_tool_context_length": max_tool_context_length,
                "use_system_context": use_system_context,
                "mode": mode,
                **kwargs,  # Include any additional parameters
            }
        )

        # The research config only decides streaming when mode is "research".
        is_stream = self._stream_requested(rag_generation_config) or (
            mode == "research"
            and self._stream_requested(research_generation_config)
        )

        if is_stream:
            raw_stream = self.client._make_streaming_request(
                "POST",
                "retrieval/agent",
                json=payload,
                version="v3",
            )
            # Parse each event in the stream
            return self._parse_events(raw_stream)

        response_dict = await self.client._make_request(
            "POST",
            "retrieval/agent",
            json=payload,
            version="v3",
        )
        return WrappedAgentResponse(**response_dict)