diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/agent/rag.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/core/agent/rag.py | 662 |
1 files changed, 662 insertions, 0 deletions
# type: ignore
# Reconstructed from the diff of
# .venv/lib/python3.12/site-packages/core/agent/rag.py (new file, index 6f3ab630).
import ast
import asyncio
import json
import logging
from typing import Any, Callable, Optional

from core.base import (
    format_search_results_for_llm,
)
from core.base.abstractions import (
    AggregateSearchResult,
    GenerationConfig,
    SearchSettings,
    WebPageSearchResult,
    WebSearchResult,
)
from core.base.agent import Tool
from core.base.providers import DatabaseProvider
from core.providers import (
    AnthropicCompletionProvider,
    LiteLLMCompletionProvider,
    OpenAICompletionProvider,
    R2RCompletionProvider,
)
from core.utils import (
    SearchResultsCollector,
    generate_id,
    num_tokens,
)

from ..base.agent.agent import RAGAgentConfig

# Import the base classes from the refactored base file
from .base import (
    R2RAgent,
    R2RStreamingAgent,
    R2RXMLStreamingAgent,
    R2RXMLToolsAgent,
)

logger = logging.getLogger(__name__)

# Scraped pages longer than this many characters are cut before being
# wrapped in a WebPageSearchResult.
_MAX_SCRAPED_MARKDOWN_CHARS = 100_000

# Completion providers accepted by every concrete RAG agent below.
_LLMProvider = (
    AnthropicCompletionProvider
    | LiteLLMCompletionProvider
    | OpenAICompletionProvider
    | R2RCompletionProvider
)


class RAGAgentMixin:
    """
    A Mixin for adding search_file_knowledge, web_search, and content tools
    to your R2R Agents.

    This allows your agent to:
      - call knowledge_search_method (semantic/hybrid search)
      - call content_method (fetch entire doc/chunk structures)
      - call file_search_method (document-level search)
      - call an external web search API / scrape a single web page

    Every tool result is also pushed into ``self.search_results_collector``
    so that it can be referenced later (e.g. for citations).
    """

    def __init__(
        self,
        *args,
        search_settings: SearchSettings,
        knowledge_search_method: Callable,
        content_method: Callable,
        file_search_method: Callable,
        max_tool_context_length: int = 10_000,
        max_context_window_tokens: int = 512_000,
        **kwargs,
    ):
        """Store the retrieval callables, then defer to the next initializer.

        Args:
            search_settings: Settings handed to the knowledge and file
                search callables; its ``filters`` are also combined with
                the filters passed to ``get_file_content``.
            knowledge_search_method: Awaitable callable invoked as
                ``knowledge_search_method(query=..., search_settings=...)``.
            content_method: Awaitable callable invoked as
                ``content_method(filters, options)``.
            file_search_method: Awaitable callable invoked as
                ``file_search_method(query=..., settings=...)``.
            max_tool_context_length: Token budget used when truncating
                formatted tool output for the LLM.
            max_context_window_tokens: Stored on the instance; not read
                anywhere in this module.
            *args, **kwargs: Forwarded unchanged to ``super().__init__``.
        """
        # Save references to the retrieval logic
        self.search_settings = search_settings
        self.knowledge_search_method = knowledge_search_method
        self.content_method = content_method
        self.file_search_method = file_search_method
        self.max_tool_context_length = max_tool_context_length
        self.max_context_window_tokens = max_context_window_tokens
        self.search_results_collector = SearchResultsCollector()
        super().__init__(*args, **kwargs)

    def _register_tools(self):
        """
        Called by the base R2RAgent to register all requested tools from
        self.config.rag_tools.

        Raises:
            ValueError: If a requested tool name is not one of the tools
                this mixin knows how to build.
        """
        if not self.config.rag_tools:
            return

        tool_factories = {
            "get_file_content": self.content,
            "web_scrape": self.web_scrape,
            "search_file_knowledge": self.search_file_knowledge,
            "search_file_descriptions": self.search_files,
            "web_search": self.web_search,
        }

        # dict.fromkeys() removes duplicates like set() would, but keeps the
        # configured order, so the tool list is identical from run to run.
        for tool_name in dict.fromkeys(self.config.rag_tools):
            factory = tool_factories.get(tool_name)
            if factory is None:
                raise ValueError(f"Unknown tool requested: {tool_name}")
            self._tools.append(factory())
        logger.debug("Registered %d RAG tools.", len(self._tools))

    # Local Search Tool
    def search_file_knowledge(self) -> Tool:
        """
        Tool to do a semantic/hybrid search on the local knowledge base
        using self.knowledge_search_method.
        """
        return Tool(
            name="search_file_knowledge",
            description=(
                "Search your local knowledge base using the R2R system. "
                "Use this when you want relevant text chunks or knowledge graph data."
            ),
            results_function=self._file_knowledge_search_function,
            llm_format_function=self.format_search_results_for_llm,
            parameters={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "User query to search in the local DB.",
                    },
                },
                "required": ["query"],
            },
        )

    async def _file_knowledge_search_function(
        self,
        query: str,
        *args,
        **kwargs,
    ) -> AggregateSearchResult:
        """
        Calls the passed-in `knowledge_search_method(query, search_settings)`.
        Expects either an AggregateSearchResult or a dict with
        chunk_search_results, etc.

        Raises:
            ValueError: If no knowledge_search_method was configured.
        """
        if not self.knowledge_search_method:
            raise ValueError(
                "No knowledge_search_method provided to RAGAgentMixin."
            )

        raw_response = await self.knowledge_search_method(
            query=query, search_settings=self.search_settings
        )

        if isinstance(raw_response, AggregateSearchResult):
            agg = raw_response
        else:
            # Anything else is treated as a mapping of result lists.
            agg = AggregateSearchResult(
                chunk_search_results=raw_response.get(
                    "chunk_search_results", []
                ),
                graph_search_results=raw_response.get(
                    "graph_search_results", []
                ),
            )

        # Store them so that we can do final citations later
        self.search_results_collector.add_aggregate_result(agg)
        return agg

    # Local Context
    def content(self) -> Tool:
        """Tool to fetch entire documents from the local database.

        Typically used if the agent needs deeper or more structured context
        from documents, not just chunk-level hits.

        When the configured model name contains "gemini", ``filters`` is
        declared as a string and the description carries an extra hint about
        document IDs; otherwise ``filters`` is declared as an object.
        """
        is_gemini = "gemini" in self.rag_generation_config.model

        description = (
            "Fetches the complete contents of all user documents from the local database. "
            "Can be used alongside filter criteria (e.g. doc IDs, collection IDs, etc.) to restrict the query. "
            "For instance, a single document can be returned with a filter like so: "
            "{'document_id': {'$eq': '...'}}."
        )
        if is_gemini:
            description += (
                " Be sure to use the full 32 character hexadecimal document ID, "
                "and not the shortened 8 character ID."
            )

        return Tool(
            name="get_file_content",
            description=description,
            results_function=self._content_function,
            llm_format_function=self.format_search_results_for_llm,
            parameters={
                "type": "object",
                "properties": {
                    "filters": {
                        "type": "string" if is_gemini else "object",
                        "description": (
                            "Dictionary with filter criteria, such as "
                            '{"$and": [{"document_id": {"$eq": "6c9d1c39..."}}, '
                            '{"collection_ids": {"$overlap": [...]}}]}'
                        ),
                    },
                },
                "required": ["filters"],
            },
        )

    @staticmethod
    def _coerce_filters(filters: Any) -> Any:
        """Normalise the ``filters`` tool argument before it is used.

        * ``dict`` -> shallow copy, so the caller's object is never mutated.
        * ``str``  -> parsed as JSON, falling back to a Python literal (the
          tool description shows a single-quoted example); a blank string or
          a parsed ``null``/``None`` yields ``None``.
        * anything else is returned untouched.

        Raises:
            ValueError: If a string cannot be parsed, or parses to something
                other than a dictionary.
        """
        if isinstance(filters, dict):
            return dict(filters)
        if not isinstance(filters, str):
            return filters

        text = filters.strip()
        if not text:
            return None

        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            try:
                parsed = ast.literal_eval(text)
            except (ValueError, SyntaxError, TypeError) as exc:
                raise ValueError(
                    f"Could not parse filters string: {filters!r}"
                ) from exc

        if parsed is None:
            return None
        if not isinstance(parsed, dict):
            raise ValueError(
                f"Filters must be a dictionary, got {type(parsed).__name__}."
            )
        return parsed

    async def _content_function(
        self,
        filters: Optional[dict[str, Any]] = None,
        options: Optional[dict[str, Any]] = None,
        *args,
        **kwargs,
    ) -> AggregateSearchResult:
        """Calls the passed-in `content_method(filters, options)` to fetch
        doc+chunk structures.

        Typically returns a list of dicts:
        [
            { 'document': {...}, 'chunks': [ {...}, {...}, ... ] },
            ...
        ]
        We'll store these in the field `document_search_results` of
        AggregateSearchResult so we don't collide with chunk_search_results.

        ``filters`` may also arrive as a string (the Gemini variant of the
        tool declares it that way); it is parsed into a dictionary first.

        Raises:
            ValueError: If no content_method was configured, or if a string
                ``filters`` value cannot be parsed into a dictionary.
        """
        if not self.content_method:
            raise ValueError("No content_method provided to RAGAgentMixin.")

        filters = self._coerce_filters(filters)

        if filters:
            # A top-level "document_id" key is renamed to "id" before the
            # filters are handed to content_method.
            if "document_id" in filters:
                filters["id"] = filters.pop("document_id")
            if self.search_settings.filters != {}:
                filters = {"$and": [filters, self.search_settings.filters]}
        else:
            filters = self.search_settings.filters

        options = options or {}

        # Actually call your data retrieval
        content = await self.content_method(filters, options)

        # Return them in the document-level aggregator field
        agg = AggregateSearchResult(
            # We won't put them in chunk_search_results:
            chunk_search_results=None,
            graph_search_results=None,
            web_search_results=None,
            document_search_results=content,
        )
        self.search_results_collector.add_aggregate_result(agg)
        return agg

    # Web Search Tool
    def web_search(self) -> Tool:
        """Tool that runs a live web search through `_web_search_function`."""
        return Tool(
            name="web_search",
            description=(
                "Search for information on the web - use this tool when the user "
                "query needs LIVE or recent data from the internet."
            ),
            results_function=self._web_search_function,
            llm_format_function=self.format_search_results_for_llm,
            parameters={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The query to search with an external web API.",
                    },
                },
                "required": ["query"],
            },
        )

    async def _web_search_function(
        self,
        query: str,
        *args,
        **kwargs,
    ) -> AggregateSearchResult:
        """
        Queries the Serper client in a worker thread and returns the organic
        results in an AggregateSearchResult.
        """
        # Imported lazily so the module loads without the web-search extras.
        from ..utils.serper import SerperClient  # adjust your import

        serper_client = SerperClient()

        # Both calls are run in worker threads so they do not block the
        # event loop.
        raw_results = await asyncio.to_thread(serper_client.get_raw, query)
        web_response = await asyncio.to_thread(
            WebSearchResult.from_serper_results, raw_results
        )

        agg = AggregateSearchResult(
            chunk_search_results=None,
            graph_search_results=None,
            web_search_results=web_response.organic_results,
        )
        self.search_results_collector.add_aggregate_result(agg)
        return agg

    def search_files(self) -> Tool:
        """
        A tool to search over file-level metadata (titles, doc-level
        descriptions, etc.) returning a list of DocumentResponse objects.
        """
        return Tool(
            name="search_file_descriptions",
            description=(
                "Semantic search over the stored documents over AI generated summaries of input documents. "
                "This does NOT retrieve chunk-level contents or knowledge-graph relationships. "
                "Use this when you need a broad overview of which documents (files) might be relevant."
            ),
            results_function=self._search_files_function,
            llm_format_function=self.format_search_results_for_llm,
            parameters={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Query string to semantic search over available files 'list documents about XYZ'.",
                    }
                },
                "required": ["query"],
            },
        )

    async def _search_files_function(
        self, query: str, *args, **kwargs
    ) -> AggregateSearchResult:
        """Run the document-level search and wrap the hits for the agent.

        Raises:
            ValueError: If no file_search_method was configured.
        """
        if not self.file_search_method:
            raise ValueError(
                "No file_search_method provided to RAGAgentMixin."
            )

        # call the doc-level search
        # FIXME: This is going to fail, as it requires an embedding NOT a query.
        # I've moved 'search_settings' to 'settings' which had been causing a
        # silent failure causing null content in the Message object.
        doc_results = await self.file_search_method(
            query=query,
            settings=self.search_settings,
        )

        # Wrap them in an AggregateSearchResult
        agg = AggregateSearchResult(document_search_results=doc_results)

        # Add them to the collector
        self.search_results_collector.add_aggregate_result(agg)
        return agg

    def format_search_results_for_llm(
        self, results: AggregateSearchResult
    ) -> str:
        """Render ``results`` as LLM context, truncated to the tool budget.

        The rendered text is cut proportionally by characters: if it has T
        tokens and the budget is B < T, roughly the first B/T of its
        characters are kept.
        """
        context = format_search_results_for_llm(
            results, self.search_results_collector
        )
        # +1 keeps the division below well-defined for a zero token count.
        context_tokens = num_tokens(context) + 1
        frac_to_return = self.max_tool_context_length / context_tokens

        if frac_to_return > 1:
            return context
        return context[: int(frac_to_return * len(context))]

    def web_scrape(self) -> Tool:
        """
        A Tool that uses Firecrawl to scrape a single URL and return
        its contents in an LLM-friendly format (e.g. markdown).
        """
        return Tool(
            name="web_scrape",
            description=(
                "Use Firecrawl to scrape a single webpage and retrieve its contents "
                "as clean markdown. Useful when you need the entire body of a page, "
                "not just a quick snippet or standard web search result."
            ),
            results_function=self._web_scrape_function,
            llm_format_function=self.format_search_results_for_llm,
            parameters={
                "type": "object",
                "properties": {
                    "url": {
                        "type": "string",
                        "description": (
                            "The absolute URL of the webpage you want to scrape. "
                            "Example: 'https://docs.firecrawl.dev/getting-started'"
                        ),
                    }
                },
                "required": ["url"],
            },
        )

    async def _web_scrape_function(
        self,
        url: str,
        *args,
        **kwargs,
    ) -> AggregateSearchResult:
        """
        Performs the Firecrawl scrape asynchronously, returning results
        as an `AggregateSearchResult` with a single WebPageSearchResult.
        """
        # Imported lazily so the module loads without firecrawl installed.
        from firecrawl import FirecrawlApp

        app = FirecrawlApp()
        logger.debug("[Firecrawl] Scraping URL=%s", url)

        # scrape_url is synchronous; run it in a worker thread so the event
        # loop stays responsive.
        response = await asyncio.to_thread(
            app.scrape_url,
            url=url,
            params={"formats": ["markdown"]},
        )

        # `or` guards against keys that are present but set to None.
        markdown_text = response.get("markdown", "") or ""
        metadata = response.get("metadata", {}) or {}
        page_title = metadata.get("title", "Untitled page") or "Untitled page"

        if len(markdown_text) > _MAX_SCRAPED_MARKDOWN_CHARS:
            markdown_text = (
                markdown_text[:_MAX_SCRAPED_MARKDOWN_CHARS]
                + "...FURTHER CONTENT TRUNCATED..."
            )

        # Create a single WebPageSearchResult HACK - TODO FIX
        web_result = WebPageSearchResult(
            title=page_title,
            link=url,
            snippet=markdown_text,
            position=0,
            id=generate_id(markdown_text),
            type="firecrawl",
        )

        agg = AggregateSearchResult(web_search_results=[web_result])

        # Add results to the collector. An identity check is used so that an
        # empty collector is never mistaken for a missing one.
        if self.search_results_collector is not None:
            self.search_results_collector.add_aggregate_result(agg)

        return agg


class R2RRAGAgent(RAGAgentMixin, R2RAgent):
    """
    Non-streaming RAG Agent that supports search_file_knowledge, content,
    web_search.
    """

    def __init__(
        self,
        database_provider: DatabaseProvider,
        llm_provider: _LLMProvider,
        config: RAGAgentConfig,
        search_settings: SearchSettings,
        rag_generation_config: GenerationConfig,
        knowledge_search_method: Callable,
        content_method: Callable,
        file_search_method: Callable,
        max_tool_context_length: int = 20_000,
    ):
        # Initialize base R2RAgent
        R2RAgent.__init__(
            self,
            database_provider=database_provider,
            llm_provider=llm_provider,
            config=config,
            rag_generation_config=rag_generation_config,
        )
        # Initialize the RAGAgentMixin
        # NOTE(review): the mixin forwards its extra kwargs to
        # super().__init__, so the base initializer appears to run a second
        # time here - confirm whether the explicit call above is still needed.
        RAGAgentMixin.__init__(
            self,
            database_provider=database_provider,
            llm_provider=llm_provider,
            config=config,
            search_settings=search_settings,
            rag_generation_config=rag_generation_config,
            max_tool_context_length=max_tool_context_length,
            knowledge_search_method=knowledge_search_method,
            file_search_method=file_search_method,
            content_method=content_method,
        )


class R2RXMLToolsRAGAgent(RAGAgentMixin, R2RXMLToolsAgent):
    """
    Non-streaming RAG Agent built on R2RXMLToolsAgent that supports
    search_file_knowledge, content, web_search.
    """

    def __init__(
        self,
        database_provider: DatabaseProvider,
        llm_provider: _LLMProvider,
        config: RAGAgentConfig,
        search_settings: SearchSettings,
        rag_generation_config: GenerationConfig,
        knowledge_search_method: Callable,
        content_method: Callable,
        file_search_method: Callable,
        max_tool_context_length: int = 20_000,
    ):
        # Initialize base R2RXMLToolsAgent
        R2RXMLToolsAgent.__init__(
            self,
            database_provider=database_provider,
            llm_provider=llm_provider,
            config=config,
            rag_generation_config=rag_generation_config,
        )
        # Initialize the RAGAgentMixin
        # NOTE(review): the mixin forwards its extra kwargs to
        # super().__init__, so the base initializer appears to run a second
        # time here - confirm whether the explicit call above is still needed.
        RAGAgentMixin.__init__(
            self,
            database_provider=database_provider,
            llm_provider=llm_provider,
            config=config,
            search_settings=search_settings,
            rag_generation_config=rag_generation_config,
            max_tool_context_length=max_tool_context_length,
            knowledge_search_method=knowledge_search_method,
            file_search_method=file_search_method,
            content_method=content_method,
        )


class R2RStreamingRAGAgent(RAGAgentMixin, R2RStreamingAgent):
    """
    Streaming-capable RAG Agent that supports search_file_knowledge, content,
    web_search, and emits citations as [abc1234] short IDs if the LLM includes
    them in brackets.
    """

    def __init__(
        self,
        database_provider: DatabaseProvider,
        llm_provider: _LLMProvider,
        config: RAGAgentConfig,
        search_settings: SearchSettings,
        rag_generation_config: GenerationConfig,
        knowledge_search_method: Callable,
        content_method: Callable,
        file_search_method: Callable,
        max_tool_context_length: int = 10_000,
    ):
        # Force streaming on (this mutates the config object passed in)
        config.stream = True

        # Initialize base R2RStreamingAgent
        R2RStreamingAgent.__init__(
            self,
            database_provider=database_provider,
            llm_provider=llm_provider,
            config=config,
            rag_generation_config=rag_generation_config,
        )

        # Initialize the RAGAgentMixin
        # NOTE(review): the mixin forwards its extra kwargs to
        # super().__init__, so the base initializer appears to run a second
        # time here - confirm whether the explicit call above is still needed.
        RAGAgentMixin.__init__(
            self,
            database_provider=database_provider,
            llm_provider=llm_provider,
            config=config,
            search_settings=search_settings,
            rag_generation_config=rag_generation_config,
            max_tool_context_length=max_tool_context_length,
            knowledge_search_method=knowledge_search_method,
            content_method=content_method,
            file_search_method=file_search_method,
        )


class R2RXMLToolsStreamingRAGAgent(RAGAgentMixin, R2RXMLStreamingAgent):
    """
    A streaming agent that:
     - treats <think> or <Thought> blocks as chain-of-thought
       and emits them incrementally as SSE "thinking" events.
     - accumulates user-visible text outside those tags as SSE "message" events.
     - filters out all XML tags related to tool calls and actions.
     - upon finishing each iteration, it parses <Action><ToolCalls><ToolCall> blocks,
       calls the appropriate tool, and emits SSE "tool_call" / "tool_result".
     - properly emits citations when they appear in the text
    """

    def __init__(
        self,
        database_provider: DatabaseProvider,
        llm_provider: _LLMProvider,
        config: RAGAgentConfig,
        search_settings: SearchSettings,
        rag_generation_config: GenerationConfig,
        knowledge_search_method: Callable,
        content_method: Callable,
        file_search_method: Callable,
        max_tool_context_length: int = 10_000,
    ):
        # Force streaming on (this mutates the config object passed in)
        config.stream = True

        # Initialize base R2RXMLStreamingAgent
        R2RXMLStreamingAgent.__init__(
            self,
            database_provider=database_provider,
            llm_provider=llm_provider,
            config=config,
            rag_generation_config=rag_generation_config,
        )

        # Initialize the RAGAgentMixin
        # NOTE(review): the mixin forwards its extra kwargs to
        # super().__init__, so the base initializer appears to run a second
        # time here - confirm whether the explicit call above is still needed.
        RAGAgentMixin.__init__(
            self,
            database_provider=database_provider,
            llm_provider=llm_provider,
            config=config,
            search_settings=search_settings,
            rag_generation_config=rag_generation_config,
            max_tool_context_length=max_tool_context_length,
            knowledge_search_method=knowledge_search_method,
            content_method=content_method,
            file_search_method=file_search_method,
        )