import ast
import asyncio
import json
import logging
import os
import uuid
from contextlib import ExitStack
from typing import Optional, Union

from fastapi import UploadFile

from r2r.base import (
    AnalysisTypes,
    FilterCriteria,
    GenerationConfig,
    KGSearchSettings,
    VectorSearchSettings,
    generate_id_from_label,
)

from .api.client import R2RClient
from .assembly.builder import R2RBuilder
from .assembly.config import R2RConfig
from .r2r import R2R

logger = logging.getLogger(__name__)


class R2RExecutionWrapper:
    """A demo class for the R2R library.

    Every public method dispatches on ``self.client_mode``: when it is true
    the call is forwarded to ``self.client`` (an ``R2RClient``), otherwise to
    ``self.app`` (an in-process ``R2R`` instance).

    The class is exposed through ``fire`` at the bottom of this module, so
    several methods also accept comma-separated strings where the annotations
    say ``list``.
    """

    def __init__(
        self,
        config_path: Optional[str] = None,
        config_name: Optional[str] = None,
        client_mode: bool = True,
        base_url="http://localhost:8000",
    ):
        """Create the wrapper in client mode or local (in-process) mode.

        Args:
            config_path: Path to a JSON config file (local mode only).
            config_name: Key into ``R2RBuilder.CONFIG_OPTIONS`` (local mode
                only). When neither this nor ``config_path`` is given, the
                ``"default"`` entry is used.
            client_mode: If true, talk to a server at ``base_url``; if false,
                build an ``R2R`` app in this process. A string is accepted
                and compared case-insensitively against ``"true"``.
            base_url: Base URL handed to ``R2RClient`` in client mode.

        Raises:
            Exception: If both ``config_path`` and ``config_name`` are given.
        """
        # NOTE: `config_name` used to default to "default", which is truthy,
        # so this check fired whenever `config_path` was supplied and made
        # `config_path` unusable. The "default" fallback is applied below.
        if config_path and config_name:
            raise Exception("Cannot specify both config_path and config_name")

        # Handle fire CLI
        if isinstance(client_mode, str):
            client_mode = client_mode.lower() == "true"
        self.client_mode = client_mode
        self.base_url = base_url

        if self.client_mode:
            self.client = R2RClient(base_url)
            self.app = None
        else:
            config = (
                R2RConfig.from_json(config_path)
                if config_path
                else R2RConfig.from_json(
                    R2RBuilder.CONFIG_OPTIONS[config_name or "default"]
                )
            )
            self.client = None
            self.app = R2R(config=config)

    def serve(self, host: str = "0.0.0.0", port: int = 8000):
        """Serve the local app on ``host``/``port``.

        Raises:
            ValueError: If the wrapper is running in client mode.
        """
        if not self.client_mode:
            self.app.serve(host, port)
        else:
            raise ValueError(
                "Serve method is only available when `client_mode=False`."
            )

    @staticmethod
    def _parse_metadata_string(metadata_string: str) -> list[dict]:
        """
        Convert a string representation of metadata into a list of dictionaries.

        The input string can be in one of two formats:
        1. JSON array of objects: '[{"key": "value"}, {"key2": "value2"}]'
        2. Python-like list of dictionaries: "[{'key': 'value'}, {'key2': 'value2'}]"

        Args:
            metadata_string (str): The string representation of metadata.

        Returns:
            list[dict]: A list of dictionaries representing the metadata.
                An empty input yields an empty list.

        Raises:
            ValueError: If the string cannot be parsed into a list of dictionaries.
        """
        if not metadata_string:
            return []

        try:
            # First, try to parse as JSON
            result = json.loads(metadata_string)
        except json.JSONDecodeError:
            try:
                # If JSON parsing fails, try to evaluate as a Python literal
                result = ast.literal_eval(metadata_string)
            except (ValueError, SyntaxError, TypeError) as exc:
                raise ValueError(
                    "Unable to parse the metadata string. "
                    "Please ensure it's a valid JSON array or Python list of dictionaries."
                ) from exc

        # Validate the shape regardless of which parser succeeded, so that
        # e.g. a bare JSON object or number is rejected as documented.
        if not isinstance(result, list) or not all(
            isinstance(item, dict) for item in result
        ):
            raise ValueError(
                "The string does not represent a list of dictionaries"
            )
        return result

    def ingest_files(
        self,
        file_paths: list[str],
        metadatas: Optional[list[dict]] = None,
        document_ids: Optional[list[Union[uuid.UUID, str]]] = None,
        versions: Optional[list[str]] = None,
    ):
        """Ingest files, expanding any directories in ``file_paths``.

        String arguments are accepted for CLI use: ``file_paths``,
        ``document_ids`` and ``versions`` are split on commas and
        ``metadatas`` is parsed with ``_parse_metadata_string``.

        Directories are walked recursively. When ``document_ids`` is empty,
        one id per file is generated from the file's base name.

        Returns:
            The ``"results"`` entry of the client response in client mode,
            otherwise whatever ``self.app.ingest_files`` returns.
        """
        if isinstance(file_paths, str):
            file_paths = file_paths.split(",")
        if isinstance(metadatas, str):
            metadatas = self._parse_metadata_string(metadatas)
        if isinstance(document_ids, str):
            document_ids = document_ids.split(",")
        if isinstance(versions, str):
            versions = versions.split(",")

        all_file_paths = []
        for path in file_paths:
            if os.path.isdir(path):
                for root, _, names in os.walk(path):
                    all_file_paths.extend(
                        os.path.join(root, name) for name in names
                    )
            else:
                all_file_paths.append(path)

        if not document_ids:
            document_ids = [
                generate_id_from_label(os.path.basename(file_path))
                for file_path in all_file_paths
            ]

        # ExitStack closes every handle opened so far even if a later
        # `open` fails or the ingestion call raises. Files are opened in
        # both modes (as before), so an unreadable path fails here first.
        with ExitStack() as stack:
            files = []
            for file_path in all_file_paths:
                handle = stack.enter_context(open(file_path, "rb"))
                upload = UploadFile(
                    filename=os.path.basename(file_path),
                    file=handle,
                )
                # Measure the size by seeking to the end, then rewind.
                handle.seek(0, 2)
                upload.size = handle.tell()
                handle.seek(0)
                files.append(upload)

            if self.client_mode:
                return self.client.ingest_files(
                    file_paths=all_file_paths,
                    document_ids=document_ids,
                    metadatas=metadatas,
                    versions=versions,
                    monitor=True,
                )["results"]
            return self.app.ingest_files(
                files=files,
                document_ids=document_ids,
                metadatas=metadatas,
                versions=versions,
            )

    def update_files(
        self,
        file_paths: list[str],
        document_ids: list[str],
        metadatas: Optional[list[dict]] = None,
    ):
        """Update existing documents from the given files.

        String arguments are accepted for CLI use, with the same
        comma-splitting / metadata parsing as ``ingest_files``.

        Returns:
            The ``"results"`` entry of the client response in client mode,
            otherwise whatever ``self.app.update_files`` returns.
        """
        if isinstance(file_paths, str):
            file_paths = file_paths.split(",")
        if isinstance(metadatas, str):
            metadatas = self._parse_metadata_string(metadatas)
        if isinstance(document_ids, str):
            document_ids = document_ids.split(",")

        if self.client_mode:
            return self.client.update_files(
                file_paths=file_paths,
                document_ids=document_ids,
                metadatas=metadatas,
                monitor=True,
            )["results"]

        # Close the handles once the update call returns (or raises);
        # previously they were left open.
        with ExitStack() as stack:
            files = [
                UploadFile(
                    filename=file_path,
                    file=stack.enter_context(open(file_path, "rb")),
                )
                for file_path in file_paths
            ]
            return self.app.update_files(
                files=files, document_ids=document_ids, metadatas=metadatas
            )

    def search(
        self,
        query: str,
        use_vector_search: bool = True,
        search_filters: Optional[dict] = None,
        search_limit: int = 10,
        do_hybrid_search: bool = False,
        use_kg_search: bool = False,
        kg_agent_generation_config: Optional[dict] = None,
    ):
        """Run a search for ``query``.

        In client mode the arguments are forwarded positionally to the
        client and the ``"results"`` entry is returned. In local mode they
        are packed into ``VectorSearchSettings`` / ``KGSearchSettings``.
        """
        if self.client_mode:
            return self.client.search(
                query,
                use_vector_search,
                search_filters,
                search_limit,
                do_hybrid_search,
                use_kg_search,
                kg_agent_generation_config,
            )["results"]

        return self.app.search(
            query,
            VectorSearchSettings(
                use_vector_search=use_vector_search,
                search_filters=search_filters or {},
                search_limit=search_limit,
                do_hybrid_search=do_hybrid_search,
            ),
            KGSearchSettings(
                use_kg_search=use_kg_search,
                agent_generation_config=GenerationConfig(
                    **(kg_agent_generation_config or {})
                ),
            ),
        )

    def rag(
        self,
        query: str,
        use_vector_search: bool = True,
        search_filters: Optional[dict] = None,
        search_limit: int = 10,
        do_hybrid_search: bool = False,
        use_kg_search: bool = False,
        kg_agent_generation_config: Optional[dict] = None,
        stream: bool = False,
        rag_generation_config: Optional[dict] = None,
    ):
        """Run retrieval-augmented generation for ``query``.

        Returns:
            Client mode: the raw client response when ``stream`` is true,
            otherwise its ``"results"`` entry.
            Local mode: the app response when ``stream`` is false; otherwise
            a synchronous generator yielding the chunks of the app's async
            response.
        """
        if self.client_mode:
            response = self.client.rag(
                query=query,
                use_vector_search=use_vector_search,
                search_filters=search_filters or {},
                search_limit=search_limit,
                do_hybrid_search=do_hybrid_search,
                use_kg_search=use_kg_search,
                kg_agent_generation_config=kg_agent_generation_config,
                rag_generation_config=rag_generation_config,
            )
            return response if stream else response["results"]

        response = self.app.rag(
            query,
            vector_search_settings=VectorSearchSettings(
                use_vector_search=use_vector_search,
                search_filters=search_filters or {},
                search_limit=search_limit,
                do_hybrid_search=do_hybrid_search,
            ),
            kg_search_settings=KGSearchSettings(
                use_kg_search=use_kg_search,
                agent_generation_config=GenerationConfig(
                    **(kg_agent_generation_config or {})
                ),
            ),
            rag_generation_config=GenerationConfig(
                **(rag_generation_config or {})
            ),
        )
        if not stream:
            return response

        async def async_generator():
            async for chunk in response:
                yield chunk

        def sync_generator():
            # get_event_loop() can raise RuntimeError when no loop is set
            # for this thread; previously that was swallowed and produced a
            # silently empty stream.
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            async_gen = async_generator()
            try:
                while True:
                    try:
                        yield loop.run_until_complete(async_gen.__anext__())
                    except StopAsyncIteration:
                        break
            except Exception:
                # Still best-effort (the stream simply ends), but the
                # failure is now recorded instead of being discarded.
                logger.exception(
                    "RAG streaming failed; ending the stream early."
                )

        return sync_generator()

    def documents_overview(
        self,
        document_ids: Optional[list[str]] = None,
        user_ids: Optional[list[str]] = None,
    ):
        """Return the documents overview, optionally filtered by ids."""
        if self.client_mode:
            return self.client.documents_overview(document_ids, user_ids)[
                "results"
            ]
        return self.app.documents_overview(document_ids, user_ids)

    def delete(
        self,
        keys: list[str],
        values: list[str],
    ):
        """Forward a delete request for the given ``keys`` / ``values``."""
        if self.client_mode:
            return self.client.delete(keys, values)["results"]
        return self.app.delete(keys, values)

    def logs(self, log_type_filter: Optional[str] = None):
        """Return logs, optionally restricted by ``log_type_filter``."""
        if self.client_mode:
            return self.client.logs(log_type_filter)["results"]
        return self.app.logs(log_type_filter)

    def document_chunks(self, document_id: str):
        """Return the chunks of one document.

        Raises:
            ValueError: If ``document_id`` is not a valid UUID string
                (raised by ``uuid.UUID``).
        """
        doc_uuid = uuid.UUID(document_id)
        if self.client_mode:
            return self.client.document_chunks(doc_uuid)["results"]
        return self.app.document_chunks(doc_uuid)

    def app_settings(self):
        """Return the app settings from the client or the local app."""
        if self.client_mode:
            return self.client.app_settings()
        return self.app.app_settings()

    def users_overview(self, user_ids: Optional[list[uuid.UUID]] = None):
        """Return the users overview, optionally filtered by ``user_ids``."""
        if self.client_mode:
            return self.client.users_overview(user_ids)["results"]
        return self.app.users_overview(user_ids)

    def analytics(
        self,
        filters: Optional[str] = None,
        analysis_types: Optional[str] = None,
    ):
        """Run analytics with the given filters and analysis types.

        The arguments are wrapped in ``FilterCriteria`` / ``AnalysisTypes``;
        in client mode those models are dumped to plain data before sending.
        """
        filter_criteria = FilterCriteria(filters=filters)
        analysis_spec = AnalysisTypes(analysis_types=analysis_types)

        if self.client_mode:
            return self.client.analytics(
                filter_criteria=filter_criteria.model_dump(),
                analysis_types=analysis_spec.model_dump(),
            )["results"]
        return self.app.analytics(
            filter_criteria=filter_criteria, analysis_types=analysis_spec
        )

    def ingest_sample_file(self, no_media: bool = True, option: int = 0):
        """Ingest a single sample file into R2R via ``SampleDataIngestor``."""
        # Imported lazily, as in the original, so the examples package is
        # only loaded when this method is actually used.
        from r2r.examples.scripts.sample_data_ingestor import (
            SampleDataIngestor,
        )

        sample_ingestor = SampleDataIngestor(self)
        return sample_ingestor.ingest_sample_file(
            no_media=no_media, option=option
        )

    def ingest_sample_files(self, no_media: bool = True):
        """Ingest the sample files into R2R via ``SampleDataIngestor``."""
        from r2r.examples.scripts.sample_data_ingestor import (
            SampleDataIngestor,
        )

        sample_ingestor = SampleDataIngestor(self)
        return sample_ingestor.ingest_sample_files(no_media=no_media)

    def inspect_knowledge_graph(self, limit: int = 100) -> str:
        """Return a knowledge-graph inspection limited to ``limit`` entries."""
        if self.client_mode:
            return self.client.inspect_knowledge_graph(limit)["results"]
        # `self.engine` is never assigned on this class; use `self.app` like
        # every other local-mode call here.
        return self.app.inspect_knowledge_graph(limit)

    def health(self) -> Optional[str]:
        """Return the client's health response; ``None`` in local mode."""
        if self.client_mode:
            return self.client.health()
        # No local-mode health check is implemented.
        return None

    def get_app(self):
        """Return the underlying app object (``self.app.app.app``).

        Raises:
            Exception: If the wrapper is running in client mode.
        """
        if not self.client_mode:
            return self.app.app.app
        raise Exception(
            "`get_app` method is only available when running with `client_mode=False`."
        )


if __name__ == "__main__":
    import fire

    fire.Fire(R2RExecutionWrapper)