aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/core/agent/rag.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/agent/rag.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/agent/rag.py662
1 files changed, 662 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/agent/rag.py b/.venv/lib/python3.12/site-packages/core/agent/rag.py
new file mode 100644
index 00000000..6f3ab630
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/agent/rag.py
@@ -0,0 +1,662 @@
+# type: ignore
+import logging
+from typing import Any, Callable, Optional
+
+from core.base import (
+ format_search_results_for_llm,
+)
+from core.base.abstractions import (
+ AggregateSearchResult,
+ GenerationConfig,
+ SearchSettings,
+ WebPageSearchResult,
+ WebSearchResult,
+)
+from core.base.agent import Tool
+from core.base.providers import DatabaseProvider
+from core.providers import (
+ AnthropicCompletionProvider,
+ LiteLLMCompletionProvider,
+ OpenAICompletionProvider,
+ R2RCompletionProvider,
+)
+from core.utils import (
+ SearchResultsCollector,
+ generate_id,
+ num_tokens,
+)
+
+from ..base.agent.agent import RAGAgentConfig
+
+# Import the base classes from the refactored base file
+from .base import (
+ R2RAgent,
+ R2RStreamingAgent,
+ R2RXMLStreamingAgent,
+ R2RXMLToolsAgent,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class RAGAgentMixin:
+ """
+ A Mixin for adding search_file_knowledge, web_search, and content tools
+ to your R2R Agents. This allows your agent to:
+ - call knowledge_search_method (semantic/hybrid search)
+ - call content_method (fetch entire doc/chunk structures)
+ - call an external web search API
+ """
+
+ def __init__(
+ self,
+ *args,
+ search_settings: SearchSettings,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length=10_000,
+ max_context_window_tokens=512_000,
+ **kwargs,
+ ):
+ # Save references to the retrieval logic
+ self.search_settings = search_settings
+ self.knowledge_search_method = knowledge_search_method
+ self.content_method = content_method
+ self.file_search_method = file_search_method
+ self.max_tool_context_length = max_tool_context_length
+ self.max_context_window_tokens = max_context_window_tokens
+ self.search_results_collector = SearchResultsCollector()
+ super().__init__(*args, **kwargs)
+
+ def _register_tools(self):
+ """
+ Called by the base R2RAgent to register all requested tools from self.config.rag_tools.
+ """
+ if not self.config.rag_tools:
+ return
+
+ for tool_name in set(self.config.rag_tools):
+ if tool_name == "get_file_content":
+ self._tools.append(self.content())
+ elif tool_name == "web_scrape":
+ self._tools.append(self.web_scrape())
+ elif tool_name == "search_file_knowledge":
+ self._tools.append(self.search_file_knowledge())
+ elif tool_name == "search_file_descriptions":
+ self._tools.append(self.search_files())
+ elif tool_name == "web_search":
+ self._tools.append(self.web_search())
+ else:
+ raise ValueError(f"Unknown tool requested: {tool_name}")
+ logger.debug(f"Registered {len(self._tools)} RAG tools.")
+
+ # Local Search Tool
+ def search_file_knowledge(self) -> Tool:
+ """
+ Tool to do a semantic/hybrid search on the local knowledge base
+ using self.knowledge_search_method.
+ """
+ return Tool(
+ name="search_file_knowledge",
+ description=(
+ "Search your local knowledge base using the R2R system. "
+ "Use this when you want relevant text chunks or knowledge graph data."
+ ),
+ results_function=self._file_knowledge_search_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "User query to search in the local DB.",
+ },
+ },
+ "required": ["query"],
+ },
+ )
+
+ async def _file_knowledge_search_function(
+ self,
+ query: str,
+ *args,
+ **kwargs,
+ ) -> AggregateSearchResult:
+ """
+ Calls the passed-in `knowledge_search_method(query, search_settings)`.
+ Expects either an AggregateSearchResult or a dict with chunk_search_results, etc.
+ """
+ if not self.knowledge_search_method:
+ raise ValueError(
+ "No knowledge_search_method provided to RAGAgentMixin."
+ )
+
+ raw_response = await self.knowledge_search_method(
+ query=query, search_settings=self.search_settings
+ )
+
+ if isinstance(raw_response, AggregateSearchResult):
+ agg = raw_response
+ else:
+ agg = AggregateSearchResult(
+ chunk_search_results=raw_response.get(
+ "chunk_search_results", []
+ ),
+ graph_search_results=raw_response.get(
+ "graph_search_results", []
+ ),
+ )
+
+ # 1) Store them so that we can do final citations later
+ self.search_results_collector.add_aggregate_result(agg)
+ return agg
+
+ # 2) Local Context
+ def content(self) -> Tool:
+ """Tool to fetch entire documents from the local database.
+
+ Typically used if the agent needs deeper or more structured context
+ from documents, not just chunk-level hits.
+ """
+ if "gemini" in self.rag_generation_config.model:
+ tool = Tool(
+ name="get_file_content",
+ description=(
+ "Fetches the complete contents of all user documents from the local database. "
+ "Can be used alongside filter criteria (e.g. doc IDs, collection IDs, etc.) to restrict the query."
+ "For instance, a single document can be returned with a filter like so:"
+ "{'document_id': {'$eq': '...'}}."
+ "Be sure to use the full 32 character hexidecimal document ID, and not the shortened 8 character ID."
+ ),
+ results_function=self._content_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "filters": {
+ "type": "string",
+ "description": (
+ "Dictionary with filter criteria, such as "
+ '{"$and": [{"document_id": {"$eq": "6c9d1c39..."}, {"collection_ids": {"$overlap": [...]}]}'
+ ),
+ },
+ },
+ "required": ["filters"],
+ },
+ )
+
+ else:
+ tool = Tool(
+ name="get_file_content",
+ description=(
+ "Fetches the complete contents of all user documents from the local database. "
+ "Can be used alongside filter criteria (e.g. doc IDs, collection IDs, etc.) to restrict the query."
+ "For instance, a single document can be returned with a filter like so:"
+ "{'document_id': {'$eq': '...'}}."
+ ),
+ results_function=self._content_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "filters": {
+ "type": "object",
+ "description": (
+ "Dictionary with filter criteria, such as "
+ '{"$and": [{"document_id": {"$eq": "6c9d1c39..."}, {"collection_ids": {"$overlap": [...]}]}'
+ ),
+ },
+ },
+ "required": ["filters"],
+ },
+ )
+ return tool
+
+ async def _content_function(
+ self,
+ filters: Optional[dict[str, Any]] = None,
+ options: Optional[dict[str, Any]] = None,
+ *args,
+ **kwargs,
+ ) -> AggregateSearchResult:
+ """Calls the passed-in `content_method(filters, options)` to fetch
+ doc+chunk structures.
+
+ Typically returns a list of dicts:
+ [
+ { 'document': {...}, 'chunks': [ {...}, {...}, ... ] },
+ ...
+ ]
+ We'll store these in a new field `document_search_results` of
+ AggregateSearchResult so we don't collide with chunk_search_results.
+ """
+ if not self.content_method:
+ raise ValueError("No content_method provided to RAGAgentMixin.")
+
+ if filters:
+ if "document_id" in filters:
+ filters["id"] = filters.pop("document_id")
+ if self.search_settings.filters != {}:
+ filters = {"$and": [filters, self.search_settings.filters]}
+ else:
+ filters = self.search_settings.filters
+
+ options = options or {}
+
+ # Actually call your data retrieval
+ content = await self.content_method(filters, options)
+ # raw_context presumably is a list[dict], each with 'document' + 'chunks'.
+
+ # Return them in the new aggregator field
+ agg = AggregateSearchResult(
+ # We won't put them in chunk_search_results:
+ chunk_search_results=None,
+ graph_search_results=None,
+ web_search_results=None,
+ document_search_results=content,
+ )
+ self.search_results_collector.add_aggregate_result(agg)
+ return agg
+
+ # Web Search Tool
+ def web_search(self) -> Tool:
+ return Tool(
+ name="web_search",
+ description=(
+ "Search for information on the web - use this tool when the user "
+ "query needs LIVE or recent data from the internet."
+ ),
+ results_function=self._web_search_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "The query to search with an external web API.",
+ },
+ },
+ "required": ["query"],
+ },
+ )
+
+ async def _web_search_function(
+ self,
+ query: str,
+ *args,
+ **kwargs,
+ ) -> AggregateSearchResult:
+ """
+ Calls an external search engine (Serper, Google, etc.) asynchronously
+ and returns results in an AggregateSearchResult.
+ """
+ import asyncio
+
+ from ..utils.serper import SerperClient # adjust your import
+
+ serper_client = SerperClient()
+
+ # If SerperClient.get_raw is not already async, wrap it in run_in_executor
+ raw_results = await asyncio.get_event_loop().run_in_executor(
+ None, # Uses the default executor
+ lambda: serper_client.get_raw(query),
+ )
+
+ # If from_serper_results is not already async, wrap it in run_in_executor too
+ web_response = await asyncio.get_event_loop().run_in_executor(
+ None, lambda: WebSearchResult.from_serper_results(raw_results)
+ )
+
+ agg = AggregateSearchResult(
+ chunk_search_results=None,
+ graph_search_results=None,
+ web_search_results=web_response.organic_results,
+ )
+ self.search_results_collector.add_aggregate_result(agg)
+ return agg
+
+ def search_files(self) -> Tool:
+ """
+ A tool to search over file-level metadata (titles, doc-level descriptions, etc.)
+ returning a list of DocumentResponse objects.
+ """
+ return Tool(
+ name="search_file_descriptions",
+ description=(
+ "Semantic search over the stored documents over AI generated summaries of input documents. "
+ "This does NOT retrieve chunk-level contents or knowledge-graph relationships. "
+ "Use this when you need a broad overview of which documents (files) might be relevant."
+ ),
+ results_function=self._search_files_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "Query string to semantic search over available files 'list documents about XYZ'.",
+ }
+ },
+ "required": ["query"],
+ },
+ )
+
+ async def _search_files_function(
+ self, query: str, *args, **kwargs
+ ) -> AggregateSearchResult:
+ if not self.file_search_method:
+ raise ValueError(
+ "No file_search_method provided to RAGAgentMixin."
+ )
+
+ # call the doc-level search
+ """
+ FIXME: This is going to fail, as it requires an embedding NOT a query.
+ I've moved 'search_settings' to 'settings' which had been causing a silent failure
+ causing null content in the Message object.
+ """
+ doc_results = await self.file_search_method(
+ query=query,
+ settings=self.search_settings,
+ )
+
+ # Wrap them in an AggregateSearchResult
+ agg = AggregateSearchResult(document_search_results=doc_results)
+
+ # Add them to the collector
+ self.search_results_collector.add_aggregate_result(agg)
+ return agg
+
+ def format_search_results_for_llm(
+ self, results: AggregateSearchResult
+ ) -> str:
+ context = format_search_results_for_llm(
+ results, self.search_results_collector
+ )
+ context_tokens = num_tokens(context) + 1
+ frac_to_return = self.max_tool_context_length / (context_tokens)
+
+ if frac_to_return > 1:
+ return context
+ else:
+ return context[: int(frac_to_return * len(context))]
+
+ def web_scrape(self) -> Tool:
+ """
+ A new Tool that uses Firecrawl to scrape a single URL and return
+ its contents in an LLM-friendly format (e.g. markdown).
+ """
+ return Tool(
+ name="web_scrape",
+ description=(
+ "Use Firecrawl to scrape a single webpage and retrieve its contents "
+ "as clean markdown. Useful when you need the entire body of a page, "
+ "not just a quick snippet or standard web search result."
+ ),
+ results_function=self._web_scrape_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "url": {
+ "type": "string",
+ "description": (
+ "The absolute URL of the webpage you want to scrape. "
+ "Example: 'https://docs.firecrawl.dev/getting-started'"
+ ),
+ }
+ },
+ "required": ["url"],
+ },
+ )
+
+ async def _web_scrape_function(
+ self,
+ url: str,
+ *args,
+ **kwargs,
+ ) -> AggregateSearchResult:
+ """
+ Performs the Firecrawl scrape asynchronously, returning results
+ as an `AggregateSearchResult` with a single WebPageSearchResult.
+ """
+ import asyncio
+
+ from firecrawl import FirecrawlApp
+
+ app = FirecrawlApp()
+ logger.debug(f"[Firecrawl] Scraping URL={url}")
+
+ # Create a proper async wrapper for the synchronous scrape_url method
+ # This offloads the blocking operation to a thread pool
+ response = await asyncio.get_event_loop().run_in_executor(
+ None, # Uses the default executor
+ lambda: app.scrape_url(
+ url=url,
+ params={"formats": ["markdown"]},
+ ),
+ )
+
+ markdown_text = response.get("markdown", "")
+ metadata = response.get("metadata", {})
+ page_title = metadata.get("title", "Untitled page")
+
+ if len(markdown_text) > 100_000:
+ markdown_text = (
+ markdown_text[:100_000] + "...FURTHER CONTENT TRUNCATED..."
+ )
+
+ # Create a single WebPageSearchResult HACK - TODO FIX
+ web_result = WebPageSearchResult(
+ title=page_title,
+ link=url,
+ snippet=markdown_text,
+ position=0,
+ id=generate_id(markdown_text),
+ type="firecrawl",
+ )
+
+ agg = AggregateSearchResult(web_search_results=[web_result])
+
+ # Add results to the collector
+ if self.search_results_collector:
+ self.search_results_collector.add_aggregate_result(agg)
+
+ return agg
+
+
+class R2RRAGAgent(RAGAgentMixin, R2RAgent):
+ """
+ Non-streaming RAG Agent that supports search_file_knowledge, content, web_search.
+ """
+
+ def __init__(
+ self,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 20_000,
+ ):
+ # Initialize base R2RAgent
+ R2RAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ rag_generation_config=rag_generation_config,
+ )
+ # Initialize the RAGAgentMixin
+ RAGAgentMixin.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ file_search_method=file_search_method,
+ content_method=content_method,
+ )
+
+
+class R2RXMLToolsRAGAgent(RAGAgentMixin, R2RXMLToolsAgent):
+ """
+ Non-streaming RAG Agent that supports search_file_knowledge, content, web_search.
+ """
+
+ def __init__(
+ self,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 20_000,
+ ):
+ # Initialize base R2RAgent
+ R2RXMLToolsAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ rag_generation_config=rag_generation_config,
+ )
+ # Initialize the RAGAgentMixin
+ RAGAgentMixin.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ file_search_method=file_search_method,
+ content_method=content_method,
+ )
+
+
+class R2RStreamingRAGAgent(RAGAgentMixin, R2RStreamingAgent):
+ """
+ Streaming-capable RAG Agent that supports search_file_knowledge, content, web_search,
+ and emits citations as [abc1234] short IDs if the LLM includes them in brackets.
+ """
+
+ def __init__(
+ self,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 10_000,
+ ):
+ # Force streaming on
+ config.stream = True
+
+ # Initialize base R2RStreamingAgent
+ R2RStreamingAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ rag_generation_config=rag_generation_config,
+ )
+
+ # Initialize the RAGAgentMixin
+ RAGAgentMixin.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ )
+
+
+class R2RXMLToolsStreamingRAGAgent(RAGAgentMixin, R2RXMLStreamingAgent):
+ """
+ A streaming agent that:
+ - treats <think> or <Thought> blocks as chain-of-thought
+ and emits them incrementally as SSE "thinking" events.
+ - accumulates user-visible text outside those tags as SSE "message" events.
+ - filters out all XML tags related to tool calls and actions.
+ - upon finishing each iteration, it parses <Action><ToolCalls><ToolCall> blocks,
+ calls the appropriate tool, and emits SSE "tool_call" / "tool_result".
+ - properly emits citations when they appear in the text
+ """
+
+ def __init__(
+ self,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 10_000,
+ ):
+ # Force streaming on
+ config.stream = True
+
+ # Initialize base R2RXMLStreamingAgent
+ R2RXMLStreamingAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ rag_generation_config=rag_generation_config,
+ )
+
+ # Initialize the RAGAgentMixin
+ RAGAgentMixin.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ )