author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/main/execution.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz

    two versions of R2R are here (HEAD, master)
Diffstat (limited to 'R2R/r2r/main/execution.py')
-rwxr-xr-x  R2R/r2r/main/execution.py  |  421
1 file changed, 421 insertions(+), 0 deletions(-)
diff --git a/R2R/r2r/main/execution.py b/R2R/r2r/main/execution.py
new file mode 100755
index 00000000..187a2eea
--- /dev/null
+++ b/R2R/r2r/main/execution.py
@@ -0,0 +1,421 @@
+import ast
+import asyncio
+import json
+import os
+import uuid
+from typing import Optional, Union
+
+from fastapi import UploadFile
+
+from r2r.base import (
+ AnalysisTypes,
+ FilterCriteria,
+ GenerationConfig,
+ KGSearchSettings,
+ VectorSearchSettings,
+ generate_id_from_label,
+)
+
+from .api.client import R2RClient
+from .assembly.builder import R2RBuilder
+from .assembly.config import R2RConfig
+from .r2r import R2R
+
+
+class R2RExecutionWrapper:
+ """A demo class for the R2R library."""
+
+ def __init__(
+ self,
+        config_path: Optional[str] = None,
+        config_name: Optional[str] = None,
+        client_mode: bool = True,
+        base_url: str = "http://localhost:8000",
+    ):
+        if config_path and config_name:
+            raise ValueError(
+                "Cannot specify both config_path and config_name"
+            )
+
+        # fire may pass booleans through as strings; normalize client_mode.
+        if isinstance(client_mode, str):
+            client_mode = client_mode.lower() == "true"
+ self.client_mode = client_mode
+ self.base_url = base_url
+
+ if self.client_mode:
+ self.client = R2RClient(base_url)
+ self.app = None
+ else:
+ config = (
+ R2RConfig.from_json(config_path)
+ if config_path
+ else R2RConfig.from_json(
+ R2RBuilder.CONFIG_OPTIONS[config_name or "default"]
+ )
+ )
+
+ self.client = None
+ self.app = R2R(config=config)
+
+ def serve(self, host: str = "0.0.0.0", port: int = 8000):
+ if not self.client_mode:
+ self.app.serve(host, port)
+ else:
+ raise ValueError(
+ "Serve method is only available when `client_mode=False`."
+ )
+
+    @staticmethod
+    def _parse_metadata_string(metadata_string: str) -> list[dict]:
+ """
+ Convert a string representation of metadata into a list of dictionaries.
+
+ The input string can be in one of two formats:
+ 1. JSON array of objects: '[{"key": "value"}, {"key2": "value2"}]'
+ 2. Python-like list of dictionaries: "[{'key': 'value'}, {'key2': 'value2'}]"
+
+ Args:
+ metadata_string (str): The string representation of metadata.
+
+ Returns:
+ list[dict]: A list of dictionaries representing the metadata.
+
+ Raises:
+ ValueError: If the string cannot be parsed into a list of dictionaries.
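+
+        Example (illustrative):
+            >>> R2RExecutionWrapper._parse_metadata_string(
+            ...     '[{"title": "doc one"}, {"title": "doc two"}]'
+            ... )
+            [{'title': 'doc one'}, {'title': 'doc two'}]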
+ """
+ if not metadata_string:
+ return []
+
+ try:
+ # First, try to parse as JSON
+ return json.loads(metadata_string)
+ except json.JSONDecodeError as e:
+            try:
+                # If JSON parsing fails, try to evaluate as a Python literal
+                result = ast.literal_eval(metadata_string)
+            except (ValueError, SyntaxError) as exc:
+                raise ValueError(
+                    "Unable to parse the metadata string. "
+                    "Please ensure it's a valid JSON array or Python list of dictionaries."
+                ) from exc
+            # Validate outside the try block so this specific error is not
+            # masked by the generic parse-failure message above.
+            if not isinstance(result, list) or not all(
+                isinstance(item, dict) for item in result
+            ):
+                raise ValueError(
+                    "The string does not represent a list of dictionaries"
+                ) from e
+            return result
+
+ def ingest_files(
+ self,
+ file_paths: list[str],
+ metadatas: Optional[list[dict]] = None,
+ document_ids: Optional[list[Union[uuid.UUID, str]]] = None,
+ versions: Optional[list[str]] = None,
+ ):
+        # fire may pass list arguments as comma-separated strings; normalize.
+        if isinstance(file_paths, str):
+            file_paths = file_paths.split(",")
+        if isinstance(metadatas, str):
+            metadatas = self._parse_metadata_string(metadatas)
+        if isinstance(document_ids, str):
+            document_ids = document_ids.split(",")
+        if isinstance(versions, str):
+            versions = versions.split(",")
+
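+        # Expand any directories in file_paths into the files they contain.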
+ all_file_paths = []
+ for path in file_paths:
+ if os.path.isdir(path):
+ for root, _, files in os.walk(path):
+ all_file_paths.extend(
+ os.path.join(root, file) for file in files
+ )
+ else:
+ all_file_paths.append(path)
+
+ if not document_ids:
+ document_ids = [
+ generate_id_from_label(os.path.basename(file_path))
+ for file_path in all_file_paths
+ ]
+
+ files = [
+ UploadFile(
+ filename=os.path.basename(file_path),
+ file=open(file_path, "rb"),
+ )
+ for file_path in all_file_paths
+ ]
+
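+        # Manually attach each upload's size (UploadFile does not derive it
+        # from a raw file handle): seek to end, record offset, rewind.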
+ for file in files:
+ file.file.seek(0, 2)
+ file.size = file.file.tell()
+ file.file.seek(0)
+
+ try:
+ if self.client_mode:
+ return self.client.ingest_files(
+ file_paths=all_file_paths,
+ document_ids=document_ids,
+ metadatas=metadatas,
+ versions=versions,
+ monitor=True,
+ )["results"]
+ else:
+ return self.app.ingest_files(
+ files=files,
+ document_ids=document_ids,
+ metadatas=metadatas,
+ versions=versions,
+ )
+ finally:
+ for file in files:
+ file.file.close()
+
+ def update_files(
+ self,
+ file_paths: list[str],
+ document_ids: list[str],
+ metadatas: Optional[list[dict]] = None,
+ ):
+        # fire may pass list arguments as comma-separated strings; normalize.
+        if isinstance(file_paths, str):
+            file_paths = file_paths.split(",")
+        if isinstance(metadatas, str):
+            metadatas = self._parse_metadata_string(metadatas)
+        if isinstance(document_ids, str):
+            document_ids = document_ids.split(",")
+
+ if self.client_mode:
+ return self.client.update_files(
+ file_paths=file_paths,
+ document_ids=document_ids,
+ metadatas=metadatas,
+ monitor=True,
+ )["results"]
+ else:
+            files = [
+                UploadFile(
+                    filename=file_path,
+                    file=open(file_path, "rb"),
+                )
+                for file_path in file_paths
+            ]
+            try:
+                return self.app.update_files(
+                    files=files, document_ids=document_ids, metadatas=metadatas
+                )
+            finally:
+                # Close the handles we opened, mirroring ingest_files.
+                for file in files:
+                    file.file.close()
+
+ def search(
+ self,
+ query: str,
+ use_vector_search: bool = True,
+ search_filters: Optional[dict] = None,
+ search_limit: int = 10,
+ do_hybrid_search: bool = False,
+ use_kg_search: bool = False,
+ kg_agent_generation_config: Optional[dict] = None,
+ ):
+ if self.client_mode:
+ return self.client.search(
+ query,
+ use_vector_search,
+ search_filters,
+ search_limit,
+ do_hybrid_search,
+ use_kg_search,
+ kg_agent_generation_config,
+ )["results"]
+ else:
+ return self.app.search(
+ query,
+ VectorSearchSettings(
+ use_vector_search=use_vector_search,
+ search_filters=search_filters or {},
+ search_limit=search_limit,
+ do_hybrid_search=do_hybrid_search,
+ ),
+ KGSearchSettings(
+ use_kg_search=use_kg_search,
+ agent_generation_config=GenerationConfig(
+ **(kg_agent_generation_config or {})
+ ),
+ ),
+ )
+
+ def rag(
+ self,
+ query: str,
+ use_vector_search: bool = True,
+ search_filters: Optional[dict] = None,
+ search_limit: int = 10,
+ do_hybrid_search: bool = False,
+ use_kg_search: bool = False,
+ kg_agent_generation_config: Optional[dict] = None,
+ stream: bool = False,
+ rag_generation_config: Optional[dict] = None,
+ ):
+ if self.client_mode:
+ response = self.client.rag(
+ query=query,
+ use_vector_search=use_vector_search,
+ search_filters=search_filters or {},
+ search_limit=search_limit,
+ do_hybrid_search=do_hybrid_search,
+ use_kg_search=use_kg_search,
+ kg_agent_generation_config=kg_agent_generation_config,
+ rag_generation_config=rag_generation_config,
+ )
+            if not stream:
+                return response["results"]
+            return response
+ else:
+ response = self.app.rag(
+ query,
+ vector_search_settings=VectorSearchSettings(
+ use_vector_search=use_vector_search,
+ search_filters=search_filters or {},
+ search_limit=search_limit,
+ do_hybrid_search=do_hybrid_search,
+ ),
+ kg_search_settings=KGSearchSettings(
+ use_kg_search=use_kg_search,
+ agent_generation_config=GenerationConfig(
+ **(kg_agent_generation_config or {})
+ ),
+ ),
+ rag_generation_config=GenerationConfig(
+ **(rag_generation_config or {})
+ ),
+ )
+ if not stream:
+ return response
+
+            async def async_generator():
+                async for chunk in response:
+                    yield chunk
+
+            def sync_generator():
+                # Drain the async stream synchronously, one chunk per
+                # run_until_complete call, until it is exhausted.
+                loop = asyncio.get_event_loop()
+                async_gen = async_generator()
+                while True:
+                    try:
+                        yield loop.run_until_complete(
+                            async_gen.__anext__()
+                        )
+                    except StopAsyncIteration:
+                        break
+
+            return sync_generator()
+
+ def documents_overview(
+ self,
+ document_ids: Optional[list[str]] = None,
+ user_ids: Optional[list[str]] = None,
+ ):
+ if self.client_mode:
+ return self.client.documents_overview(document_ids, user_ids)[
+ "results"
+ ]
+ else:
+ return self.app.documents_overview(document_ids, user_ids)
+
+ def delete(
+ self,
+ keys: list[str],
+ values: list[str],
+ ):
+ if self.client_mode:
+ return self.client.delete(keys, values)["results"]
+ else:
+ return self.app.delete(keys, values)
+
+ def logs(self, log_type_filter: Optional[str] = None):
+ if self.client_mode:
+ return self.client.logs(log_type_filter)["results"]
+ else:
+ return self.app.logs(log_type_filter)
+
+ def document_chunks(self, document_id: str):
+ doc_uuid = uuid.UUID(document_id)
+ if self.client_mode:
+ return self.client.document_chunks(doc_uuid)["results"]
+ else:
+ return self.app.document_chunks(doc_uuid)
+
+ def app_settings(self):
+ if self.client_mode:
+ return self.client.app_settings()
+ else:
+ return self.app.app_settings()
+
+ def users_overview(self, user_ids: Optional[list[uuid.UUID]] = None):
+ if self.client_mode:
+ return self.client.users_overview(user_ids)["results"]
+ else:
+ return self.app.users_overview(user_ids)
+
+ def analytics(
+ self,
+ filters: Optional[str] = None,
+ analysis_types: Optional[str] = None,
+ ):
+ filter_criteria = FilterCriteria(filters=filters)
+ analysis_types = AnalysisTypes(analysis_types=analysis_types)
+
+ if self.client_mode:
+ return self.client.analytics(
+ filter_criteria=filter_criteria.model_dump(),
+ analysis_types=analysis_types.model_dump(),
+ )["results"]
+ else:
+ return self.app.analytics(
+ filter_criteria=filter_criteria, analysis_types=analysis_types
+ )
+
+    def ingest_sample_file(self, no_media: bool = True, option: int = 0):
+        """Ingest a single sample file into R2R."""
+        from r2r.examples.scripts.sample_data_ingestor import (
+            SampleDataIngestor,
+        )
+
+        sample_ingestor = SampleDataIngestor(self)
+        return sample_ingestor.ingest_sample_file(
+            no_media=no_media, option=option
+        )
+
+    def ingest_sample_files(self, no_media: bool = True):
+        """Ingest all sample files into R2R."""
+        from r2r.examples.scripts.sample_data_ingestor import (
+            SampleDataIngestor,
+        )
+
+        sample_ingestor = SampleDataIngestor(self)
+        return sample_ingestor.ingest_sample_files(no_media=no_media)
+
+ def inspect_knowledge_graph(self, limit: int = 100) -> str:
+ if self.client_mode:
+ return self.client.inspect_knowledge_graph(limit)["results"]
+ else:
+            return self.app.inspect_knowledge_graph(limit)
+
+ def health(self) -> str:
+ if self.client_mode:
+ return self.client.health()
+        else:
+            # A server-side health check is not implemented in this wrapper.
+            pass
+
+ def get_app(self):
+ if not self.client_mode:
+ return self.app.app.app
+        else:
+            raise ValueError(
+                "`get_app` method is only available when running with `client_mode=False`."
+            )
+
+
+if __name__ == "__main__":
+ import fire
+
+ fire.Fire(R2RExecutionWrapper)
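+
+    # Example fire invocations (illustrative; paths are placeholders):
+    #   python execution.py search --query="What is R2R?"
+    #   python execution.py ingest_files --file_paths="doc1.txt,doc2.txt"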