diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/main/execution.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz |
Diffstat (limited to 'R2R/r2r/main/execution.py')
-rwxr-xr-x | R2R/r2r/main/execution.py | 421 |
1 file changed, 421 insertions(+), 0 deletions(-)
import ast
import asyncio
import json
import os
import uuid
from typing import Optional, Union

from fastapi import UploadFile

from r2r.base import (
    AnalysisTypes,
    FilterCriteria,
    GenerationConfig,
    KGSearchSettings,
    VectorSearchSettings,
    generate_id_from_label,
)

from .api.client import R2RClient
from .assembly.builder import R2RBuilder
from .assembly.config import R2RConfig
from .r2r import R2R


class R2RExecutionWrapper:
    """Thin execution wrapper around the R2R library.

    Depending on ``client_mode`` it either talks to a running R2R server
    over HTTP (via :class:`R2RClient`) or drives an in-process
    :class:`R2R` application instance. Intended to be driven from the
    command line via ``fire`` (see the ``__main__`` block).
    """

    def __init__(
        self,
        config_path: Optional[str] = None,
        config_name: Optional[str] = "default",
        client_mode: bool = True,
        base_url: str = "http://localhost:8000",
    ):
        # NOTE(review): because config_name defaults to "default", passing
        # config_path alone still trips this guard; callers must explicitly
        # pass config_name=None alongside config_path. Preserved as-is to
        # keep behavior identical — confirm the intent with the API owners.
        if config_path and config_name:
            raise Exception("Cannot specify both config_path and config_name")

        # `fire` passes CLI flags as strings; coerce "true"/"false" here.
        if isinstance(client_mode, str):
            client_mode = client_mode.lower() == "true"
        self.client_mode = client_mode
        self.base_url = base_url

        if self.client_mode:
            # Remote mode: all calls are proxied to an R2R server.
            self.client = R2RClient(base_url)
            self.app = None
        else:
            # Local mode: assemble an in-process R2R app from config.
            config = (
                R2RConfig.from_json(config_path)
                if config_path
                else R2RConfig.from_json(
                    R2RBuilder.CONFIG_OPTIONS[config_name or "default"]
                )
            )

            self.client = None
            self.app = R2R(config=config)

    def serve(self, host: str = "0.0.0.0", port: int = 8000):
        """Run the in-process R2R app as an HTTP server (local mode only)."""
        if not self.client_mode:
            self.app.serve(host, port)
        else:
            raise ValueError(
                "Serve method is only available when `client_mode=False`."
            )

    @staticmethod
    def _parse_metadata_string(metadata_string: str) -> list[dict]:
        """
        Convert a string representation of metadata into a list of dictionaries.

        The input string can be in one of two formats:
        1. JSON array of objects: '[{"key": "value"}, {"key2": "value2"}]'
        2. Python-like list of dictionaries: "[{'key': 'value'}, {'key2': 'value2'}]"

        Args:
            metadata_string (str): The string representation of metadata.

        Returns:
            list[dict]: A list of dictionaries representing the metadata.

        Raises:
            ValueError: If the string cannot be parsed into a list of dictionaries.
        """
        # Fixed: this helper was previously declared without `self` yet
        # called as `self._parse_metadata_string(...)`, which passed the
        # instance as `metadata_string` and raised TypeError. It uses no
        # instance state, so @staticmethod is the correct form.
        if not metadata_string:
            return []

        try:
            # First, try to parse as JSON.
            return json.loads(metadata_string)
        except json.JSONDecodeError as e:
            try:
                # Fall back to a Python literal (single-quoted dicts etc.).
                result = ast.literal_eval(metadata_string)
                if not isinstance(result, list) or not all(
                    isinstance(item, dict) for item in result
                ):
                    raise ValueError(
                        "The string does not represent a list of dictionaries"
                    ) from e
                return result
            except (ValueError, SyntaxError) as exc:
                raise ValueError(
                    "Unable to parse the metadata string. "
                    "Please ensure it's a valid JSON array or Python list of dictionaries."
                ) from exc

    def ingest_files(
        self,
        file_paths: list[str],
        metadatas: Optional[list[dict]] = None,
        document_ids: Optional[list[Union[uuid.UUID, str]]] = None,
        versions: Optional[list[str]] = None,
    ):
        """Ingest files (or whole directories, walked recursively) into R2R.

        CLI-friendly: comma-separated strings are accepted for the list
        arguments. When document_ids are omitted they are derived from each
        file's basename via generate_id_from_label.
        """
        # Normalize fire-style comma-separated string arguments.
        if isinstance(file_paths, str):
            file_paths = list(file_paths.split(","))
        if isinstance(metadatas, str):
            metadatas = self._parse_metadata_string(metadatas)
        if isinstance(document_ids, str):
            document_ids = list(document_ids.split(","))
        if isinstance(versions, str):
            versions = list(versions.split(","))

        # Expand directories into their contained files.
        all_file_paths = []
        for path in file_paths:
            if os.path.isdir(path):
                for root, _, files in os.walk(path):
                    all_file_paths.extend(
                        os.path.join(root, file) for file in files
                    )
            else:
                all_file_paths.append(path)

        if not document_ids:
            document_ids = [
                generate_id_from_label(os.path.basename(file_path))
                for file_path in all_file_paths
            ]

        files = [
            UploadFile(
                filename=os.path.basename(file_path),
                file=open(file_path, "rb"),
            )
            for file_path in all_file_paths
        ]

        # UploadFile.size is not populated when wrapping a raw file object;
        # compute it by seeking to the end and back.
        for file in files:
            file.file.seek(0, 2)
            file.size = file.file.tell()
            file.file.seek(0)

        try:
            if self.client_mode:
                return self.client.ingest_files(
                    file_paths=all_file_paths,
                    document_ids=document_ids,
                    metadatas=metadatas,
                    versions=versions,
                    monitor=True,
                )["results"]
            else:
                return self.app.ingest_files(
                    files=files,
                    document_ids=document_ids,
                    metadatas=metadatas,
                    versions=versions,
                )
        finally:
            # Always release the file handles, even on failure.
            for file in files:
                file.file.close()

    def update_files(
        self,
        file_paths: list[str],
        document_ids: list[str],
        metadatas: Optional[list[dict]] = None,
    ):
        """Replace previously-ingested documents with new file contents."""
        # Normalize fire-style comma-separated string arguments.
        if isinstance(file_paths, str):
            file_paths = list(file_paths.split(","))
        if isinstance(metadatas, str):
            metadatas = self._parse_metadata_string(metadatas)
        if isinstance(document_ids, str):
            document_ids = list(document_ids.split(","))

        if self.client_mode:
            return self.client.update_files(
                file_paths=file_paths,
                document_ids=document_ids,
                metadatas=metadatas,
                monitor=True,
            )["results"]
        else:
            # NOTE(review): unlike ingest_files, these handles are not
            # closed after use and sizes are not computed — presumably
            # app.update_files takes ownership; confirm upstream.
            files = [
                UploadFile(
                    filename=file_path,
                    file=open(file_path, "rb"),
                )
                for file_path in file_paths
            ]
            return self.app.update_files(
                files=files, document_ids=document_ids, metadatas=metadatas
            )

    def search(
        self,
        query: str,
        use_vector_search: bool = True,
        search_filters: Optional[dict] = None,
        search_limit: int = 10,
        do_hybrid_search: bool = False,
        use_kg_search: bool = False,
        kg_agent_generation_config: Optional[dict] = None,
    ):
        """Run a vector and/or knowledge-graph search for `query`."""
        if self.client_mode:
            return self.client.search(
                query,
                use_vector_search,
                search_filters,
                search_limit,
                do_hybrid_search,
                use_kg_search,
                kg_agent_generation_config,
            )["results"]
        else:
            return self.app.search(
                query,
                VectorSearchSettings(
                    use_vector_search=use_vector_search,
                    search_filters=search_filters or {},
                    search_limit=search_limit,
                    do_hybrid_search=do_hybrid_search,
                ),
                KGSearchSettings(
                    use_kg_search=use_kg_search,
                    agent_generation_config=GenerationConfig(
                        **(kg_agent_generation_config or {})
                    ),
                ),
            )

    def rag(
        self,
        query: str,
        use_vector_search: bool = True,
        search_filters: Optional[dict] = None,
        search_limit: int = 10,
        do_hybrid_search: bool = False,
        use_kg_search: bool = False,
        kg_agent_generation_config: Optional[dict] = None,
        stream: bool = False,
        rag_generation_config: Optional[dict] = None,
    ):
        """Retrieval-augmented generation over the ingested corpus.

        With stream=True in local mode, the async response stream is
        bridged into a synchronous generator for CLI consumption.
        """
        if self.client_mode:
            response = self.client.rag(
                query=query,
                use_vector_search=use_vector_search,
                search_filters=search_filters or {},
                search_limit=search_limit,
                do_hybrid_search=do_hybrid_search,
                use_kg_search=use_kg_search,
                kg_agent_generation_config=kg_agent_generation_config,
                rag_generation_config=rag_generation_config,
            )
            if not stream:
                response = response["results"]
                return response
            else:
                return response
        else:
            response = self.app.rag(
                query,
                vector_search_settings=VectorSearchSettings(
                    use_vector_search=use_vector_search,
                    search_filters=search_filters or {},
                    search_limit=search_limit,
                    do_hybrid_search=do_hybrid_search,
                ),
                kg_search_settings=KGSearchSettings(
                    use_kg_search=use_kg_search,
                    agent_generation_config=GenerationConfig(
                        **(kg_agent_generation_config or {})
                    ),
                ),
                rag_generation_config=GenerationConfig(
                    **(rag_generation_config or {})
                ),
            )
            if not stream:
                return response
            else:
                # Bridge the async response stream to a sync generator.
                async def async_generator():
                    async for chunk in response:
                        yield chunk

                def sync_generator():
                    # NOTE(review): get_event_loop() is deprecated outside a
                    # running loop, and the broad `except Exception: pass`
                    # silently truncates the stream on any error. Preserved
                    # as-is; revisit with asyncio.Runner or an explicit loop.
                    try:
                        loop = asyncio.get_event_loop()
                        async_gen = async_generator()
                        while True:
                            try:
                                yield loop.run_until_complete(
                                    async_gen.__anext__()
                                )
                            except StopAsyncIteration:
                                break
                    except Exception:
                        pass

                return sync_generator()

    def documents_overview(
        self,
        document_ids: Optional[list[str]] = None,
        user_ids: Optional[list[str]] = None,
    ):
        """Summarize ingested documents, optionally filtered by id/user."""
        if self.client_mode:
            return self.client.documents_overview(document_ids, user_ids)[
                "results"
            ]
        else:
            return self.app.documents_overview(document_ids, user_ids)

    def delete(
        self,
        keys: list[str],
        values: list[str],
    ):
        """Delete entries matching the given metadata key/value pairs."""
        if self.client_mode:
            return self.client.delete(keys, values)["results"]
        else:
            return self.app.delete(keys, values)

    def logs(self, log_type_filter: Optional[str] = None):
        """Fetch run logs, optionally filtered by log type."""
        if self.client_mode:
            return self.client.logs(log_type_filter)["results"]
        else:
            return self.app.logs(log_type_filter)

    def document_chunks(self, document_id: str):
        """List the stored chunks of a single document."""
        # Validate/normalize the id up front; raises ValueError if malformed.
        doc_uuid = uuid.UUID(document_id)
        if self.client_mode:
            return self.client.document_chunks(doc_uuid)["results"]
        else:
            return self.app.document_chunks(doc_uuid)

    def app_settings(self):
        """Return the application's settings/configuration snapshot."""
        if self.client_mode:
            return self.client.app_settings()
        else:
            return self.app.app_settings()

    def users_overview(self, user_ids: Optional[list[uuid.UUID]] = None):
        """Summarize users, optionally restricted to the given ids."""
        if self.client_mode:
            return self.client.users_overview(user_ids)["results"]
        else:
            return self.app.users_overview(user_ids)

    def analytics(
        self,
        filters: Optional[str] = None,
        analysis_types: Optional[str] = None,
    ):
        """Run analytics queries with the given filters and analysis types."""
        filter_criteria = FilterCriteria(filters=filters)
        analysis_types = AnalysisTypes(analysis_types=analysis_types)

        if self.client_mode:
            return self.client.analytics(
                filter_criteria=filter_criteria.model_dump(),
                analysis_types=analysis_types.model_dump(),
            )["results"]
        else:
            return self.app.analytics(
                filter_criteria=filter_criteria, analysis_types=analysis_types
            )

    def ingest_sample_file(self, no_media: bool = True, option: int = 0):
        """Ingest one bundled sample file into R2R."""
        # Imported lazily so the examples package is only required on use.
        from r2r.examples.scripts.sample_data_ingestor import (
            SampleDataIngestor,
        )

        sample_ingestor = SampleDataIngestor(self)
        return sample_ingestor.ingest_sample_file(
            no_media=no_media, option=option
        )

    def ingest_sample_files(self, no_media: bool = True):
        """Ingest all bundled sample files into R2R."""
        # Imported lazily so the examples package is only required on use.
        from r2r.examples.scripts.sample_data_ingestor import (
            SampleDataIngestor,
        )

        sample_ingestor = SampleDataIngestor(self)
        return sample_ingestor.ingest_sample_files(no_media=no_media)

    def inspect_knowledge_graph(self, limit: int = 100) -> str:
        """Return a textual inspection of the knowledge graph."""
        if self.client_mode:
            return self.client.inspect_knowledge_graph(limit)["results"]
        else:
            # Fixed: previously referenced nonexistent `self.engine`
            # (AttributeError); every other server-mode call uses self.app.
            return self.app.inspect_knowledge_graph(limit)

    def health(self) -> str:
        """Check server health (client mode only).

        NOTE(review): in local mode this returns None despite the `-> str`
        annotation — preserved as-is; confirm intended local-mode behavior.
        """
        if self.client_mode:
            return self.client.health()
        else:
            pass

    def get_app(self):
        """Expose the underlying FastAPI app object (local mode only)."""
        if not self.client_mode:
            return self.app.app.app
        else:
            raise Exception(
                "`get_app` method is only available when running with `client_mode=False`."
            )


if __name__ == "__main__":
    import fire

    fire.Fire(R2RExecutionWrapper)