Diffstat (limited to '.venv/lib/python3.12/site-packages/sdk/sync_methods/retrieval.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/sdk/sync_methods/retrieval.py  554
1 file changed, 554 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/sdk/sync_methods/retrieval.py b/.venv/lib/python3.12/site-packages/sdk/sync_methods/retrieval.py
new file mode 100644
index 00000000..4a927014
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/sdk/sync_methods/retrieval.py
@@ -0,0 +1,554 @@
+import json
+import uuid
+from typing import Any, Generator, Optional
+
+from shared.api.models import (
+ WrappedAgentResponse,
+ WrappedEmbeddingResponse,
+ WrappedLLMChatCompletion,
+ WrappedRAGResponse,
+ WrappedSearchResponse,
+)
+
+from ..models import (
+ AgentEvent,
+ CitationData,
+ CitationEvent,
+ Delta,
+ DeltaPayload,
+ FinalAnswerData,
+ FinalAnswerEvent,
+ GenerationConfig,
+ Message,
+ MessageData,
+ MessageDelta,
+ MessageEvent,
+ SearchMode,
+ SearchResultsData,
+ SearchResultsEvent,
+ SearchSettings,
+ ThinkingData,
+ ThinkingEvent,
+ ToolCallData,
+ ToolCallEvent,
+ ToolResultData,
+ ToolResultEvent,
+ UnknownEvent,
+)
+
+
+def _coerce_delta(data_obj: dict) -> dict:
+    """Coerce a nested "delta" dict into typed Pydantic models, in place.
+
+    Each content item becomes a MessageDelta (with its "payload" parsed
+    into a DeltaPayload), and data_obj["delta"] is replaced with a Delta
+    model. Shared by the "message" and "thinking" event branches.
+    """
+    if "delta" in data_obj and isinstance(data_obj["delta"], dict):
+        delta_dict = data_obj["delta"]
+
+        # Convert content items to MessageDelta objects
+        if "content" in delta_dict and isinstance(delta_dict["content"], list):
+            parsed_content = []
+            for item in delta_dict["content"]:
+                if isinstance(item, dict):
+                    # Parse payload to DeltaPayload
+                    if "payload" in item and isinstance(item["payload"], dict):
+                        item["payload"] = DeltaPayload(**item["payload"])
+                    parsed_content.append(MessageDelta(**item))
+
+            # Replace with parsed content
+            delta_dict["content"] = parsed_content
+
+        # Create a properly typed Delta object
+        data_obj["delta"] = Delta(**delta_dict)
+    return data_obj
+
+
+def parse_retrieval_event(raw: dict) -> Optional[AgentEvent]:
+    """
+    Convert a raw SSE event dict into a typed Pydantic model.
+
+    Example raw dict:
+        {
+            "event": "message",
+            "data": "{\"id\": \"msg_partial\", \"object\": \"agent.message.delta\", \"delta\": {...}}"
+        }
+    """
+    event_type = raw.get("event", "unknown")
+
+    # A "done" event signals that the SSE stream is finished.
+    if event_type == "done":
+        return None
+
+    # The SSE "data" payload is JSON-encoded, so parse it first
+    data_str = raw.get("data", "")
+    try:
+        data_obj = json.loads(data_str)
+    except json.JSONDecodeError as e:
+        # Malformed JSON is treated as a hard error; returning an
+        # UnknownEvent would be a more lenient alternative.
+        raise ValueError(f"Could not parse JSON in SSE event data: {e}") from e
+
+    # Branch on event_type to build the matching Pydantic model
+    if event_type == "search_results":
+        return SearchResultsEvent(
+            event=event_type,
+            data=SearchResultsData(**data_obj),
+        )
+    elif event_type == "message":
+        # Parse the nested delta structure before creating MessageData
+        return MessageEvent(
+            event=event_type,
+            data=MessageData(**_coerce_delta(data_obj)),
+        )
+    elif event_type == "citation":
+        return CitationEvent(event=event_type, data=CitationData(**data_obj))
+    elif event_type == "tool_call":
+        return ToolCallEvent(event=event_type, data=ToolCallData(**data_obj))
+    elif event_type == "tool_result":
+        return ToolResultEvent(
+            event=event_type, data=ToolResultData(**data_obj)
+        )
+    elif event_type == "thinking":
+        # Parse the nested delta structure before creating ThinkingData
+        return ThinkingEvent(
+            event=event_type,
+            data=ThinkingData(**_coerce_delta(data_obj)),
+        )
+    elif event_type == "final_answer":
+        return FinalAnswerEvent(
+            event=event_type, data=FinalAnswerData(**data_obj)
+        )
+    else:
+        # Fallback if it doesn't match any known event type
+        return UnknownEvent(
+            event=event_type,
+            data=data_obj,
+        )
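+
+
+# Illustrative sketch (not from the upstream source): how a raw SSE frame
+# maps to a typed event. The keys inside "data" are assumptions for
+# demonstration and must match whatever CitationData actually defines.
+#
+#     raw = {"event": "citation", "data": '{"id": "cit_1"}'}
+#     event = parse_retrieval_event(raw)
+#     # isinstance(event, CitationEvent) -> True
+#
+#     parse_retrieval_event({"event": "done", "data": ""})  # -> None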
+
+
+def search_arg_parser(
+ query: str,
+ search_mode: Optional[str | SearchMode] = "custom",
+ search_settings: Optional[dict | SearchSettings] = None,
+) -> dict:
+ if search_mode and not isinstance(search_mode, str):
+ search_mode = search_mode.value
+
+ if search_settings and not isinstance(search_settings, dict):
+ search_settings = search_settings.model_dump()
+
+ data: dict[str, Any] = {
+ "query": query,
+ "search_settings": search_settings,
+ }
+ if search_mode:
+ data["search_mode"] = search_mode
+
+ return data
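+
+
+# Illustrative sketch (assumption): search_arg_parser flattens enums and
+# Pydantic models into a JSON-ready payload. SearchMode.basic is a
+# hypothetical enum member, and {"limit": 5} a hypothetical settings dict.
+#
+#     search_arg_parser(
+#         query="what is RAG?",
+#         search_mode=SearchMode.basic,   # enum -> "basic" via .value
+#         search_settings={"limit": 5},   # dicts pass through unchanged
+#     )
+#     # -> {"query": "what is RAG?", "search_settings": {"limit": 5},
+#     #     "search_mode": "basic"}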
+
+
+def completion_arg_parser(
+ messages: list[dict | Message],
+ generation_config: Optional[dict | GenerationConfig] = None,
+) -> dict:
+ # FIXME: Needs a proper return type
+ cast_messages: list[Message] = [
+ Message(**msg) if isinstance(msg, dict) else msg for msg in messages
+ ]
+
+ if generation_config and not isinstance(generation_config, dict):
+ generation_config = generation_config.model_dump()
+
+ data: dict[str, Any] = {
+ "messages": [msg.model_dump() for msg in cast_messages],
+ "generation_config": generation_config,
+ }
+ return data
+
+
+def embedding_arg_parser(
+ text: str,
+) -> dict:
+ data: dict[str, Any] = {
+ "text": text,
+ }
+ return data
+
+
+def rag_arg_parser(
+ query: str,
+ rag_generation_config: Optional[dict | GenerationConfig] = None,
+ search_mode: Optional[str | SearchMode] = "custom",
+ search_settings: Optional[dict | SearchSettings] = None,
+ task_prompt: Optional[str] = None,
+ include_title_if_available: Optional[bool] = False,
+ include_web_search: Optional[bool] = False,
+) -> dict:
+    if rag_generation_config and not isinstance(rag_generation_config, dict):
+        rag_generation_config = rag_generation_config.model_dump()
+    if search_mode and not isinstance(search_mode, str):
+        search_mode = search_mode.value
+    if search_settings and not isinstance(search_settings, dict):
+        search_settings = search_settings.model_dump()
+
+ data: dict[str, Any] = {
+ "query": query,
+ "rag_generation_config": rag_generation_config,
+ "search_settings": search_settings,
+ "task_prompt": task_prompt,
+ "include_title_if_available": include_title_if_available,
+ "include_web_search": include_web_search,
+ }
+ if search_mode:
+ data["search_mode"] = search_mode
+ return data
+
+
+def agent_arg_parser(
+ message: Optional[dict | Message] = None,
+ rag_generation_config: Optional[dict | GenerationConfig] = None,
+ research_generation_config: Optional[dict | GenerationConfig] = None,
+ search_mode: Optional[str | SearchMode] = "custom",
+ search_settings: Optional[dict | SearchSettings] = None,
+ task_prompt: Optional[str] = None,
+ include_title_if_available: Optional[bool] = True,
+ conversation_id: Optional[str | uuid.UUID] = None,
+ max_tool_context_length: Optional[int] = None,
+ use_system_context: Optional[bool] = True,
+ rag_tools: Optional[list[str]] = None,
+ research_tools: Optional[list[str]] = None,
+ tools: Optional[list[str]] = None, # For backward compatibility
+ mode: Optional[str] = "rag",
+ needs_initial_conversation_name: Optional[bool] = None,
+) -> dict:
+    if rag_generation_config and not isinstance(rag_generation_config, dict):
+        rag_generation_config = rag_generation_config.model_dump()
+    if research_generation_config and not isinstance(
+        research_generation_config, dict
+    ):
+        research_generation_config = research_generation_config.model_dump()
+    if search_mode and not isinstance(search_mode, str):
+        search_mode = search_mode.value
+    if search_settings and not isinstance(search_settings, dict):
+        search_settings = search_settings.model_dump()
+
+ data: dict[str, Any] = {
+ "rag_generation_config": rag_generation_config or {},
+ "search_settings": search_settings,
+ "task_prompt": task_prompt,
+ "include_title_if_available": include_title_if_available,
+ "conversation_id": (str(conversation_id) if conversation_id else None),
+ "max_tool_context_length": max_tool_context_length,
+ "use_system_context": use_system_context,
+ "mode": mode,
+ }
+
+ # Handle generation configs based on mode
+ if research_generation_config and mode == "research":
+ data["research_generation_config"] = research_generation_config
+
+ # Handle tool configurations
+ if rag_tools:
+ data["rag_tools"] = rag_tools
+ if research_tools:
+ data["research_tools"] = research_tools
+ if tools: # Backward compatibility
+ data["tools"] = tools
+
+ if search_mode:
+ data["search_mode"] = search_mode
+
+ if needs_initial_conversation_name:
+ data["needs_initial_conversation_name"] = (
+ needs_initial_conversation_name
+ )
+
+ if message:
+ cast_message: Message = (
+ Message(**message) if isinstance(message, dict) else message
+ )
+ data["message"] = cast_message.model_dump()
+ return data
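+
+
+# Illustrative sketch (assumption): the payload agent_arg_parser builds for
+# a plain RAG-mode turn. Mode-specific keys (research_generation_config,
+# rag_tools, research_tools, tools) are only added when provided.
+#
+#     agent_arg_parser(
+#         message={"role": "user", "content": "Summarize my documents."},
+#         conversation_id="123e4567-e89b-12d3-a456-426614174000",
+#         mode="rag",
+#     )
+#     # -> includes "message", the stringified "conversation_id",
+#     #    "mode": "rag", and "search_mode": "custom" from the default.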
+
+
+class RetrievalSDK:
+ """SDK for interacting with documents in the v3 API."""
+
+ def __init__(self, client):
+ self.client = client
+
+ def search(
+ self,
+ query: str,
+ search_mode: Optional[str | SearchMode] = "custom",
+ search_settings: Optional[dict | SearchSettings] = None,
+ ) -> WrappedSearchResponse:
+ """Conduct a vector and/or graph search.
+
+        Args:
+            query (str): The query to search for.
+            search_mode (Optional[str | SearchMode]): Pre-configured search mode: "basic", "advanced", or "custom".
+            search_settings (Optional[dict | SearchSettings]): Vector search settings.
+
+        Returns:
+            WrappedSearchResponse
+ """
+
+ response_dict = self.client._make_request(
+ "POST",
+ "retrieval/search",
+ json=search_arg_parser(
+ query=query,
+ search_mode=search_mode,
+ search_settings=search_settings,
+ ),
+ version="v3",
+ )
+
+ return WrappedSearchResponse(**response_dict)
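+
+    # Illustrative usage (assumption: a configured sync client exposes this
+    # SDK, e.g. as `client.retrieval`):
+    #
+    #     results = client.retrieval.search(
+    #         query="What is a vector database?",
+    #         search_mode="basic",
+    #     )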
+
+ def completion(
+ self,
+ messages: list[dict | Message],
+ generation_config: Optional[dict | GenerationConfig] = None,
+ ) -> WrappedLLMChatCompletion:
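+        """Generate a chat completion for the given messages.
+
+        Args:
+            messages (list[dict | Message]): The conversation messages.
+            generation_config (Optional[dict | GenerationConfig]): Generation configuration.
+
+        Returns:
+            WrappedLLMChatCompletion
+        """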
+ response_dict = self.client._make_request(
+ "POST",
+ "retrieval/completion",
+ json=completion_arg_parser(messages, generation_config),
+ version="v3",
+ )
+
+ return WrappedLLMChatCompletion(**response_dict)
+
+ def embedding(
+ self,
+ text: str,
+ ) -> WrappedEmbeddingResponse:
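+        """Generate an embedding for the given text.
+
+        Args:
+            text (str): The text to embed.
+
+        Returns:
+            WrappedEmbeddingResponse
+        """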
+ response_dict = self.client._make_request(
+ "POST",
+ "retrieval/embedding",
+ data=embedding_arg_parser(text),
+ version="v3",
+ )
+
+ return WrappedEmbeddingResponse(**response_dict)
+
+ def rag(
+ self,
+ query: str,
+ rag_generation_config: Optional[dict | GenerationConfig] = None,
+ search_mode: Optional[str | SearchMode] = "custom",
+ search_settings: Optional[dict | SearchSettings] = None,
+ task_prompt: Optional[str] = None,
+ include_title_if_available: Optional[bool] = False,
+ include_web_search: Optional[bool] = False,
+ ) -> (
+ WrappedRAGResponse
+ | Generator[
+ ThinkingEvent
+ | SearchResultsEvent
+ | MessageEvent
+ | CitationEvent
+ | FinalAnswerEvent
+ | ToolCallEvent
+ | ToolResultEvent
+ | UnknownEvent
+ | None,
+ None,
+ None,
+ ]
+ ):
+ """Conducts a Retrieval Augmented Generation (RAG) search with the
+ given query.
+
+        Args:
+            query (str): The query to search for.
+            rag_generation_config (Optional[dict | GenerationConfig]): RAG generation configuration.
+            search_mode (Optional[str | SearchMode]): Pre-configured search mode: "basic", "advanced", or "custom".
+            search_settings (Optional[dict | SearchSettings]): Vector search settings.
+            task_prompt (Optional[str]): Task prompt override.
+            include_title_if_available (Optional[bool]): Include the title if available.
+            include_web_search (Optional[bool]): Include web search results in the RAG context.
+
+        Returns:
+            WrappedRAGResponse | Generator[AgentEvent | None, None, None]: The RAG
+                response, or a generator of typed events when streaming is enabled.
+ """
+ data = rag_arg_parser(
+ query=query,
+ rag_generation_config=rag_generation_config,
+ search_mode=search_mode,
+ search_settings=search_settings,
+ task_prompt=task_prompt,
+ include_title_if_available=include_title_if_available,
+ include_web_search=include_web_search,
+ )
+ rag_generation_config = data.get("rag_generation_config")
+ if rag_generation_config and rag_generation_config.get( # type: ignore
+ "stream", False
+ ):
+ raw_stream = self.client._make_streaming_request(
+ "POST",
+ "retrieval/rag",
+ json=data,
+ version="v3",
+ )
+ # Wrap the raw stream to parse each event
+ return (parse_retrieval_event(event) for event in raw_stream)
+
+ response_dict = self.client._make_request(
+ "POST",
+ "retrieval/rag",
+ json=data,
+ version="v3",
+ )
+
+ return WrappedRAGResponse(**response_dict)
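+
+    # Illustrative usage (assumption: a configured sync client exposes this
+    # SDK as `client.retrieval`, and streaming is requested via the
+    # generation config):
+    #
+    #     for event in client.retrieval.rag(
+    #         query="Summarize the uploaded reports.",
+    #         rag_generation_config={"stream": True},
+    #     ):
+    #         print(type(event).__name__)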
+
+ def agent(
+ self,
+ message: Optional[dict | Message] = None,
+ rag_generation_config: Optional[dict | GenerationConfig] = None,
+ research_generation_config: Optional[dict | GenerationConfig] = None,
+ search_mode: Optional[str | SearchMode] = "custom",
+ search_settings: Optional[dict | SearchSettings] = None,
+ task_prompt: Optional[str] = None,
+ include_title_if_available: Optional[bool] = True,
+ conversation_id: Optional[str | uuid.UUID] = None,
+ max_tool_context_length: Optional[int] = None,
+ use_system_context: Optional[bool] = True,
+ # Tool configurations
+ rag_tools: Optional[list[str]] = None,
+ research_tools: Optional[list[str]] = None,
+ tools: Optional[list[str]] = None, # For backward compatibility
+ mode: Optional[str] = "rag",
+ needs_initial_conversation_name: Optional[bool] = None,
+ ) -> (
+ WrappedAgentResponse
+ | Generator[
+ ThinkingEvent
+ | SearchResultsEvent
+ | MessageEvent
+ | CitationEvent
+ | FinalAnswerEvent
+ | ToolCallEvent
+ | ToolResultEvent
+ | UnknownEvent
+ | None,
+ None,
+ None,
+ ]
+ ):
+ """Performs a single turn in a conversation with a RAG agent.
+
+ Args:
+ message (Optional[dict | Message]): The message to send to the agent.
+ rag_generation_config (Optional[dict | GenerationConfig]): Configuration for RAG generation in 'rag' mode.
+ research_generation_config (Optional[dict | GenerationConfig]): Configuration for generation in 'research' mode.
+ search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
+ search_settings (Optional[dict | SearchSettings]): Vector search settings.
+ task_prompt (Optional[str]): Task prompt override.
+ include_title_if_available (Optional[bool]): Include the title if available.
+ conversation_id (Optional[str | uuid.UUID]): ID of the conversation for maintaining context.
+ max_tool_context_length (Optional[int]): Maximum context length for tool replies.
+ use_system_context (Optional[bool]): Whether to use system context in the prompt.
+ rag_tools (Optional[list[str]]): List of tools to enable for RAG mode.
+ Available tools: "search_file_knowledge", "content", "web_search", "web_scrape", "search_file_descriptions".
+ research_tools (Optional[list[str]]): List of tools to enable for Research mode.
+ Available tools: "rag", "reasoning", "critique", "python_executor".
+            tools (Optional[list[str]]): Deprecated. List of tools to execute.
+            mode (Optional[str]): Mode to use for generation: "rag" for standard retrieval or "research" for deep analysis.
+                Defaults to "rag".
+            needs_initial_conversation_name (Optional[bool]): Whether an initial name should be generated for the conversation.
+
+        Returns:
+            WrappedAgentResponse | Generator[AgentEvent, None, None]: The agent response,
+                or a generator of typed events when streaming is enabled.
+ """
+ data = agent_arg_parser(
+ message=message,
+ rag_generation_config=rag_generation_config,
+ research_generation_config=research_generation_config,
+ search_mode=search_mode,
+ search_settings=search_settings,
+ task_prompt=task_prompt,
+ include_title_if_available=include_title_if_available,
+ conversation_id=conversation_id,
+ max_tool_context_length=max_tool_context_length,
+ use_system_context=use_system_context,
+ rag_tools=rag_tools,
+ research_tools=research_tools,
+ tools=tools,
+ mode=mode,
+ needs_initial_conversation_name=needs_initial_conversation_name,
+ )
+
+        # agent_arg_parser has already applied search_mode and message, so
+        # all that remains is to determine whether streaming is enabled.
+ is_stream = False
+ if mode != "research":
+ if rag_generation_config:
+ if isinstance(rag_generation_config, dict):
+ is_stream = rag_generation_config.get( # type: ignore
+ "stream", False
+ )
+ else:
+ is_stream = rag_generation_config.stream
+ else:
+ if research_generation_config:
+ if isinstance(research_generation_config, dict):
+ is_stream = research_generation_config.get( # type: ignore
+ "stream", False
+ )
+ else:
+ is_stream = research_generation_config.stream
+
+ if is_stream:
+ raw_stream = self.client._make_streaming_request(
+ "POST",
+ "retrieval/agent",
+ json=data,
+ version="v3",
+ )
+ return (parse_retrieval_event(event) for event in raw_stream)
+
+ response_dict = self.client._make_request(
+ "POST",
+ "retrieval/agent",
+ json=data,
+ version="v3",
+ )
+
+ return WrappedAgentResponse(**response_dict)
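+
+
+# Illustrative usage (assumption: `client` is a configured sync client that
+# exposes this SDK as `client.retrieval`):
+#
+#     response = client.retrieval.agent(
+#         message={"role": "user", "content": "What do my contracts say about termination?"},
+#         mode="rag",
+#         rag_tools=["search_file_knowledge", "content"],
+#     )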