import json
import uuid
from typing import Any, Generator, Optional
from shared.api.models import (
WrappedAgentResponse,
WrappedEmbeddingResponse,
WrappedLLMChatCompletion,
WrappedRAGResponse,
WrappedSearchResponse,
)
from ..models import (
AgentEvent,
CitationData,
CitationEvent,
Delta,
DeltaPayload,
FinalAnswerData,
FinalAnswerEvent,
GenerationConfig,
Message,
MessageData,
MessageDelta,
MessageEvent,
SearchMode,
SearchResultsData,
SearchResultsEvent,
SearchSettings,
ThinkingData,
ThinkingEvent,
ToolCallData,
ToolCallEvent,
ToolResultData,
ToolResultEvent,
UnknownEvent,
)
def parse_retrieval_event(raw: dict) -> Optional[AgentEvent]:
"""
Convert a raw SSE event dict into a typed Pydantic model.
Example raw dict:
{
"event": "message",
"data": "{\"id\": \"msg_partial\", \"object\": \"agent.message.delta\", \"delta\": {...}}"
}
"""
event_type = raw.get("event", "unknown")
    # A "done" event marks the end of the SSE stream; signal that with None.
if event_type == "done":
return None
# The SSE "data" is JSON-encoded, so parse it
data_str = raw.get("data", "")
try:
data_obj = json.loads(data_str)
except json.JSONDecodeError as e:
        # Malformed JSON is treated as a hard error here; returning an
        # UnknownEvent would be the lenient alternative.
        raise ValueError(f"Could not parse JSON in SSE event data: {e}") from e
    # Both "message" and "thinking" events carry the same nested delta
    # structure, so share one in-place parser for it.
    def _parse_nested_delta(obj: dict) -> None:
        """Convert a raw ``delta`` dict into typed Delta/MessageDelta/DeltaPayload models, in place."""
        if "delta" in obj and isinstance(obj["delta"], dict):
            delta_dict = obj["delta"]
            if "content" in delta_dict and isinstance(
                delta_dict["content"], list
            ):
                parsed_content = []
                for item in delta_dict["content"]:
                    if isinstance(item, dict):
                        if "payload" in item and isinstance(
                            item["payload"], dict
                        ):
                            item["payload"] = DeltaPayload(**item["payload"])
                        parsed_content.append(MessageDelta(**item))
                delta_dict["content"] = parsed_content
            obj["delta"] = Delta(**delta_dict)

    # Now branch on event_type to build the right Pydantic model
    if event_type == "search_results":
        return SearchResultsEvent(
            event=event_type,
            data=SearchResultsData(**data_obj),
        )
    elif event_type == "message":
        _parse_nested_delta(data_obj)
        return MessageEvent(
            event=event_type,
            data=MessageData(**data_obj),
        )
    elif event_type == "citation":
        return CitationEvent(event=event_type, data=CitationData(**data_obj))
    elif event_type == "tool_call":
        return ToolCallEvent(event=event_type, data=ToolCallData(**data_obj))
    elif event_type == "tool_result":
        return ToolResultEvent(
            event=event_type, data=ToolResultData(**data_obj)
        )
    elif event_type == "thinking":
        _parse_nested_delta(data_obj)
        return ThinkingEvent(
            event=event_type,
            data=ThinkingData(**data_obj),
        )
elif event_type == "final_answer":
return FinalAnswerEvent(
event=event_type, data=FinalAnswerData(**data_obj)
)
else:
# Fallback if it doesn't match any known event
return UnknownEvent(
event=event_type,
data=data_obj,
)
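# A minimal sketch of how a raw SSE dict flows through parse_retrieval_event;
# the field names inside "data" are illustrative, not a guaranteed server payload:
#
#     raw = {
#         "event": "message",
#         "data": '{"id": "msg_partial", "object": "agent.message.delta", "delta": {"content": []}}',
#     }
#     event = parse_retrieval_event(raw)                    # -> MessageEvent
#     parse_retrieval_event({"event": "done", "data": ""})  # -> None, stream finished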
def search_arg_parser(
query: str,
search_mode: Optional[str | SearchMode] = "custom",
search_settings: Optional[dict | SearchSettings] = None,
) -> dict:
if search_mode and not isinstance(search_mode, str):
search_mode = search_mode.value
if search_settings and not isinstance(search_settings, dict):
search_settings = search_settings.model_dump()
data: dict[str, Any] = {
"query": query,
"search_settings": search_settings,
}
if search_mode:
data["search_mode"] = search_mode
return data
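# For example, search_arg_parser("What is DNA?", search_mode="basic") returns
# {"query": "What is DNA?", "search_settings": None, "search_mode": "basic"},
# which is sent verbatim as the JSON body of POST retrieval/search.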
def completion_arg_parser(
messages: list[dict | Message],
generation_config: Optional[dict | GenerationConfig] = None,
) -> dict[str, Any]:
cast_messages: list[Message] = [
Message(**msg) if isinstance(msg, dict) else msg for msg in messages
]
if generation_config and not isinstance(generation_config, dict):
generation_config = generation_config.model_dump()
data: dict[str, Any] = {
"messages": [msg.model_dump() for msg in cast_messages],
"generation_config": generation_config,
}
return data
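# Illustrative call (the model name is an assumption, not an SDK default):
#
#     completion_arg_parser(
#         messages=[{"role": "user", "content": "Hello"}],
#         generation_config={"model": "openai/gpt-4o", "stream": False},
#     )
#
# Dict messages are validated through Message before being dumped back to
# dicts, so malformed messages fail fast on the client side.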
def embedding_arg_parser(
text: str,
) -> dict:
data: dict[str, Any] = {
"text": text,
}
return data
def rag_arg_parser(
query: str,
rag_generation_config: Optional[dict | GenerationConfig] = None,
search_mode: Optional[str | SearchMode] = "custom",
search_settings: Optional[dict | SearchSettings] = None,
task_prompt: Optional[str] = None,
include_title_if_available: Optional[bool] = False,
include_web_search: Optional[bool] = False,
) -> dict:
if rag_generation_config and not isinstance(rag_generation_config, dict):
rag_generation_config = rag_generation_config.model_dump()
if search_settings and not isinstance(search_settings, dict):
search_settings = search_settings.model_dump()
data: dict[str, Any] = {
"query": query,
"rag_generation_config": rag_generation_config,
"search_settings": search_settings,
"task_prompt": task_prompt,
"include_title_if_available": include_title_if_available,
"include_web_search": include_web_search,
}
if search_mode:
data["search_mode"] = search_mode
return data
def agent_arg_parser(
message: Optional[dict | Message] = None,
rag_generation_config: Optional[dict | GenerationConfig] = None,
research_generation_config: Optional[dict | GenerationConfig] = None,
search_mode: Optional[str | SearchMode] = "custom",
search_settings: Optional[dict | SearchSettings] = None,
task_prompt: Optional[str] = None,
include_title_if_available: Optional[bool] = True,
conversation_id: Optional[str | uuid.UUID] = None,
max_tool_context_length: Optional[int] = None,
use_system_context: Optional[bool] = True,
rag_tools: Optional[list[str]] = None,
research_tools: Optional[list[str]] = None,
tools: Optional[list[str]] = None, # For backward compatibility
mode: Optional[str] = "rag",
needs_initial_conversation_name: Optional[bool] = None,
) -> dict:
if rag_generation_config and not isinstance(rag_generation_config, dict):
rag_generation_config = rag_generation_config.model_dump()
if research_generation_config and not isinstance(
research_generation_config, dict
):
research_generation_config = research_generation_config.model_dump()
if search_settings and not isinstance(search_settings, dict):
search_settings = search_settings.model_dump()
data: dict[str, Any] = {
"rag_generation_config": rag_generation_config or {},
"search_settings": search_settings,
"task_prompt": task_prompt,
"include_title_if_available": include_title_if_available,
"conversation_id": (str(conversation_id) if conversation_id else None),
"max_tool_context_length": max_tool_context_length,
"use_system_context": use_system_context,
"mode": mode,
}
# Handle generation configs based on mode
if research_generation_config and mode == "research":
data["research_generation_config"] = research_generation_config
# Handle tool configurations
if rag_tools:
data["rag_tools"] = rag_tools
if research_tools:
data["research_tools"] = research_tools
if tools: # Backward compatibility
data["tools"] = tools
if search_mode:
data["search_mode"] = search_mode
if needs_initial_conversation_name:
data["needs_initial_conversation_name"] = (
needs_initial_conversation_name
)
if message:
cast_message: Message = (
Message(**message) if isinstance(message, dict) else message
)
data["message"] = cast_message.model_dump()
return data
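# Note that research_generation_config is only forwarded when mode == "research";
# in the default "rag" mode it is dropped even if supplied. Illustrative call:
#
#     agent_arg_parser(
#         message={"role": "user", "content": "Summarize my documents"},
#         rag_generation_config={"stream": True},
#         mode="rag",
#     )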
class RetrievalSDK:
"""SDK for interacting with documents in the v3 API."""
def __init__(self, client):
self.client = client
def search(
self,
query: str,
search_mode: Optional[str | SearchMode] = "custom",
search_settings: Optional[dict | SearchSettings] = None,
) -> WrappedSearchResponse:
"""Conduct a vector and/or graph search.
        Args:
            query (str): The query to search for.
            search_mode (Optional[str | SearchMode]): Pre-configured search mode: "basic", "advanced", or "custom".
            search_settings (Optional[dict | SearchSettings]): Vector search settings.
        Returns:
            WrappedSearchResponse
        """
response_dict = self.client._make_request(
"POST",
"retrieval/search",
json=search_arg_parser(
query=query,
search_mode=search_mode,
search_settings=search_settings,
),
version="v3",
)
return WrappedSearchResponse(**response_dict)
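    # Illustrative usage, assuming a configured client that exposes this SDK
    # as `client.retrieval` (the "limit" setting is a hypothetical example field):
    #
    #     results = client.retrieval.search(
    #         query="What is DNA?",
    #         search_settings={"limit": 10},
    #     )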
def completion(
self,
messages: list[dict | Message],
generation_config: Optional[dict | GenerationConfig] = None,
) -> WrappedLLMChatCompletion:
response_dict = self.client._make_request(
"POST",
"retrieval/completion",
json=completion_arg_parser(messages, generation_config),
version="v3",
)
return WrappedLLMChatCompletion(**response_dict)
def embedding(
self,
text: str,
) -> WrappedEmbeddingResponse:
response_dict = self.client._make_request(
"POST",
"retrieval/embedding",
data=embedding_arg_parser(text),
version="v3",
)
return WrappedEmbeddingResponse(**response_dict)
def rag(
self,
query: str,
rag_generation_config: Optional[dict | GenerationConfig] = None,
search_mode: Optional[str | SearchMode] = "custom",
search_settings: Optional[dict | SearchSettings] = None,
task_prompt: Optional[str] = None,
include_title_if_available: Optional[bool] = False,
include_web_search: Optional[bool] = False,
) -> (
WrappedRAGResponse
| Generator[
ThinkingEvent
| SearchResultsEvent
| MessageEvent
| CitationEvent
| FinalAnswerEvent
| ToolCallEvent
| ToolResultEvent
| UnknownEvent
| None,
None,
None,
]
):
"""Conducts a Retrieval Augmented Generation (RAG) search with the
given query.
        Args:
            query (str): The query to search for.
            rag_generation_config (Optional[dict | GenerationConfig]): RAG generation configuration.
            search_mode (Optional[str | SearchMode]): Pre-configured search mode: "basic", "advanced", or "custom".
            search_settings (Optional[dict | SearchSettings]): Vector search settings.
            task_prompt (Optional[str]): Task prompt override.
            include_title_if_available (Optional[bool]): Include the document title if available.
            include_web_search (Optional[bool]): Include web search results in the RAG context.
        Returns:
            WrappedRAGResponse | Generator[AgentEvent, None, None]: The RAG response, or a generator of typed events when streaming is enabled.
        """
data = rag_arg_parser(
query=query,
rag_generation_config=rag_generation_config,
search_mode=search_mode,
search_settings=search_settings,
task_prompt=task_prompt,
include_title_if_available=include_title_if_available,
include_web_search=include_web_search,
)
        # rag_arg_parser has already model_dump()-ed the config, so it is a plain dict here
        generation_config = data.get("rag_generation_config")
        if generation_config and generation_config.get("stream", False):
raw_stream = self.client._make_streaming_request(
"POST",
"retrieval/rag",
json=data,
version="v3",
)
# Wrap the raw stream to parse each event
return (parse_retrieval_event(event) for event in raw_stream)
response_dict = self.client._make_request(
"POST",
"retrieval/rag",
json=data,
version="v3",
)
return WrappedRAGResponse(**response_dict)
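    # Illustrative streaming usage, assuming `client.retrieval` as above; the
    # stream yields the typed events produced by parse_retrieval_event:
    #
    #     for event in client.retrieval.rag(
    #         query="What is DNA?",
    #         rag_generation_config={"stream": True},
    #     ):
    #         if isinstance(event, MessageEvent):
    #             ...  # incremental message delta
    #         elif isinstance(event, FinalAnswerEvent):
    #             ...  # complete generated answer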
def agent(
self,
message: Optional[dict | Message] = None,
rag_generation_config: Optional[dict | GenerationConfig] = None,
research_generation_config: Optional[dict | GenerationConfig] = None,
search_mode: Optional[str | SearchMode] = "custom",
search_settings: Optional[dict | SearchSettings] = None,
task_prompt: Optional[str] = None,
include_title_if_available: Optional[bool] = True,
conversation_id: Optional[str | uuid.UUID] = None,
max_tool_context_length: Optional[int] = None,
use_system_context: Optional[bool] = True,
# Tool configurations
rag_tools: Optional[list[str]] = None,
research_tools: Optional[list[str]] = None,
tools: Optional[list[str]] = None, # For backward compatibility
mode: Optional[str] = "rag",
needs_initial_conversation_name: Optional[bool] = None,
) -> (
WrappedAgentResponse
| Generator[
ThinkingEvent
| SearchResultsEvent
| MessageEvent
| CitationEvent
| FinalAnswerEvent
| ToolCallEvent
| ToolResultEvent
| UnknownEvent
| None,
None,
None,
]
):
"""Performs a single turn in a conversation with a RAG agent.
Args:
message (Optional[dict | Message]): The message to send to the agent.
rag_generation_config (Optional[dict | GenerationConfig]): Configuration for RAG generation in 'rag' mode.
research_generation_config (Optional[dict | GenerationConfig]): Configuration for generation in 'research' mode.
search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
search_settings (Optional[dict | SearchSettings]): Vector search settings.
task_prompt (Optional[str]): Task prompt override.
include_title_if_available (Optional[bool]): Include the title if available.
conversation_id (Optional[str | uuid.UUID]): ID of the conversation for maintaining context.
max_tool_context_length (Optional[int]): Maximum context length for tool replies.
use_system_context (Optional[bool]): Whether to use system context in the prompt.
rag_tools (Optional[list[str]]): List of tools to enable for RAG mode.
Available tools: "search_file_knowledge", "content", "web_search", "web_scrape", "search_file_descriptions".
research_tools (Optional[list[str]]): List of tools to enable for Research mode.
Available tools: "rag", "reasoning", "critique", "python_executor".
tools (Optional[list[str]]): Deprecated. List of tools to execute.
mode (Optional[str]): Mode to use for generation: "rag" for standard retrieval or "research" for deep analysis.
                Defaults to "rag".
            needs_initial_conversation_name (Optional[bool]): Whether the conversation still needs an auto-generated name.
        Returns:
            WrappedAgentResponse | Generator[AgentEvent, None, None]: The agent response, or a generator of typed events when streaming is enabled.
        """
data = agent_arg_parser(
message=message,
rag_generation_config=rag_generation_config,
research_generation_config=research_generation_config,
search_mode=search_mode,
search_settings=search_settings,
task_prompt=task_prompt,
include_title_if_available=include_title_if_available,
conversation_id=conversation_id,
max_tool_context_length=max_tool_context_length,
use_system_context=use_system_context,
rag_tools=rag_tools,
research_tools=research_tools,
tools=tools,
mode=mode,
needs_initial_conversation_name=needs_initial_conversation_name,
)
        # Determine if streaming is enabled. (search_mode and message were
        # already normalized into the payload by agent_arg_parser.)
is_stream = False
if mode != "research":
if rag_generation_config:
if isinstance(rag_generation_config, dict):
is_stream = rag_generation_config.get( # type: ignore
"stream", False
)
else:
is_stream = rag_generation_config.stream
else:
if research_generation_config:
if isinstance(research_generation_config, dict):
is_stream = research_generation_config.get( # type: ignore
"stream", False
)
else:
is_stream = research_generation_config.stream
if is_stream:
raw_stream = self.client._make_streaming_request(
"POST",
"retrieval/agent",
json=data,
version="v3",
)
return (parse_retrieval_event(event) for event in raw_stream)
response_dict = self.client._make_request(
"POST",
"retrieval/agent",
json=data,
version="v3",
)
return WrappedAgentResponse(**response_dict)
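    # Illustrative non-streaming agent turn, assuming `client.retrieval` as
    # above; the tool names mirror those listed in the docstring:
    #
    #     response = client.retrieval.agent(
    #         message={"role": "user", "content": "Compare these two reports"},
    #         rag_tools=["search_file_knowledge", "content"],
    #         mode="rag",
    #     )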