Diffstat (limited to 'R2R/r2r/providers')
-rwxr-xr-x | R2R/r2r/providers/__init__.py | 0
-rwxr-xr-x | R2R/r2r/providers/embeddings/__init__.py | 11
-rwxr-xr-x | R2R/r2r/providers/embeddings/ollama/ollama_base.py | 156
-rwxr-xr-x | R2R/r2r/providers/embeddings/openai/openai_base.py | 200
-rwxr-xr-x | R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py | 160
-rwxr-xr-x | R2R/r2r/providers/eval/__init__.py | 3
-rwxr-xr-x | R2R/r2r/providers/eval/llm/base_llm_eval.py | 84
-rwxr-xr-x | R2R/r2r/providers/kg/__init__.py | 3
-rwxr-xr-x | R2R/r2r/providers/kg/neo4j/base_neo4j.py | 983
-rwxr-xr-x | R2R/r2r/providers/llms/__init__.py | 7
-rwxr-xr-x | R2R/r2r/providers/llms/litellm/base_litellm.py | 142
-rwxr-xr-x | R2R/r2r/providers/llms/openai/base_openai.py | 144
-rwxr-xr-x | R2R/r2r/providers/vector_dbs/__init__.py | 5
-rwxr-xr-x | R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py | 610
14 files changed, 2508 insertions, 0 deletions
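For orientation, a minimal usage sketch (not part of this commit's diff) of one of the providers added below. It assumes `EmbeddingConfig` accepts the fields the provider reads (`provider`, `base_model`, `base_dimension`) as keyword arguments and that `OPENAI_API_KEY` is set in the environment; names and values are illustrative only.

```python
# Hypothetical sketch: instantiating the OpenAIEmbeddingProvider added in this commit.
# Assumes EmbeddingConfig exposes the fields read in openai_base.py below.
from r2r.base import EmbeddingConfig
from r2r.providers.embeddings import OpenAIEmbeddingProvider

config = EmbeddingConfig(
    provider="openai",                    # must be "openai", otherwise __init__ raises
    base_model="text-embedding-3-small",  # must appear in MODEL_TO_TOKENIZER
    base_dimension=1536,                  # must be listed in MODEL_TO_DIMENSIONS[base_model]
)

provider = OpenAIEmbeddingProvider(config)           # requires OPENAI_API_KEY in the environment
embedding = provider.get_embedding("hello, world")   # list[float] of length 1536
```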
diff --git a/R2R/r2r/providers/__init__.py b/R2R/r2r/providers/__init__.py new file mode 100755 index 00000000..e69de29b --- /dev/null +++ b/R2R/r2r/providers/__init__.py diff --git a/R2R/r2r/providers/embeddings/__init__.py b/R2R/r2r/providers/embeddings/__init__.py new file mode 100755 index 00000000..6b0c8b83 --- /dev/null +++ b/R2R/r2r/providers/embeddings/__init__.py @@ -0,0 +1,11 @@ +from .ollama.ollama_base import OllamaEmbeddingProvider +from .openai.openai_base import OpenAIEmbeddingProvider +from .sentence_transformer.sentence_transformer_base import ( + SentenceTransformerEmbeddingProvider, +) + +__all__ = [ + "OllamaEmbeddingProvider", + "OpenAIEmbeddingProvider", + "SentenceTransformerEmbeddingProvider", +] diff --git a/R2R/r2r/providers/embeddings/ollama/ollama_base.py b/R2R/r2r/providers/embeddings/ollama/ollama_base.py new file mode 100755 index 00000000..31a8c717 --- /dev/null +++ b/R2R/r2r/providers/embeddings/ollama/ollama_base.py @@ -0,0 +1,156 @@ +import asyncio +import logging +import os +import random +from typing import Any + +from ollama import AsyncClient, Client + +from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult + +logger = logging.getLogger(__name__) + + +class OllamaEmbeddingProvider(EmbeddingProvider): + def __init__(self, config: EmbeddingConfig): + super().__init__(config) + provider = config.provider + if not provider: + raise ValueError( + "Must set provider in order to initialize `OllamaEmbeddingProvider`." + ) + if provider != "ollama": + raise ValueError( + "OllamaEmbeddingProvider must be initialized with provider `ollama`." + ) + if config.rerank_model: + raise ValueError( + "OllamaEmbeddingProvider does not support separate reranking." + ) + + self.base_model = config.base_model + self.base_dimension = config.base_dimension + self.base_url = os.getenv("OLLAMA_API_BASE") + logger.info( + f"Using Ollama API base URL: {self.base_url or 'http://127.0.0.1:11434'}" + ) + self.client = Client(host=self.base_url) + self.aclient = AsyncClient(host=self.base_url) + + self.request_queue = asyncio.Queue() + self.max_retries = 2 + self.initial_backoff = 1 + self.max_backoff = 60 + self.concurrency_limit = 10 + self.semaphore = asyncio.Semaphore(self.concurrency_limit) + + async def process_queue(self): + while True: + task = await self.request_queue.get() + try: + result = await self.execute_task_with_backoff(task) + task["future"].set_result(result) + except Exception as e: + task["future"].set_exception(e) + finally: + self.request_queue.task_done() + + async def execute_task_with_backoff(self, task: dict[str, Any]): + retries = 0 + backoff = self.initial_backoff + while retries < self.max_retries: + try: + async with self.semaphore: + response = await asyncio.wait_for( + self.aclient.embeddings( + prompt=task["text"], model=self.base_model + ), + timeout=30, + ) + return response["embedding"] + except Exception as e: + logger.warning( + f"Request failed (attempt {retries + 1}): {str(e)}" + ) + retries += 1 + if retries == self.max_retries: + raise Exception( + f"Max retries reached. Last error: {str(e)}" + ) + await asyncio.sleep(backoff + random.uniform(0, 1)) + backoff = min(backoff * 2, self.max_backoff) + + def get_embedding( + self, + text: str, + stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE, + ) -> list[float]: + if stage != EmbeddingProvider.PipeStage.BASE: + raise ValueError( + "OllamaEmbeddingProvider only supports search stage." 
+ ) + + try: + response = self.client.embeddings( + prompt=text, model=self.base_model + ) + return response["embedding"] + except Exception as e: + logger.error(f"Error getting embedding: {str(e)}") + raise + + def get_embeddings( + self, + texts: list[str], + stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE, + ) -> list[list[float]]: + return [self.get_embedding(text, stage) for text in texts] + + async def async_get_embeddings( + self, + texts: list[str], + stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE, + ) -> list[list[float]]: + if stage != EmbeddingProvider.PipeStage.BASE: + raise ValueError( + "OllamaEmbeddingProvider only supports search stage." + ) + + queue_processor = asyncio.create_task(self.process_queue()) + futures = [] + for text in texts: + future = asyncio.Future() + await self.request_queue.put({"text": text, "future": future}) + futures.append(future) + + try: + results = await asyncio.gather(*futures, return_exceptions=True) + # Check if any result is an exception and raise it + exceptions = set([r for r in results if isinstance(r, Exception)]) + if exceptions: + raise Exception( + f"Embedding generation failed for one or more embeddings." + ) + return results + except Exception as e: + logger.error(f"Embedding generation failed: {str(e)}") + raise + finally: + await self.request_queue.join() + queue_processor.cancel() + + def rerank( + self, + query: str, + results: list[VectorSearchResult], + stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK, + limit: int = 10, + ) -> list[VectorSearchResult]: + return results[:limit] + + def tokenize_string( + self, text: str, model: str, stage: EmbeddingProvider.PipeStage + ) -> list[int]: + raise NotImplementedError( + "Tokenization is not supported by OllamaEmbeddingProvider." + ) diff --git a/R2R/r2r/providers/embeddings/openai/openai_base.py b/R2R/r2r/providers/embeddings/openai/openai_base.py new file mode 100755 index 00000000..7e7d32aa --- /dev/null +++ b/R2R/r2r/providers/embeddings/openai/openai_base.py @@ -0,0 +1,200 @@ +import logging +import os + +from openai import AsyncOpenAI, AuthenticationError, OpenAI + +from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult + +logger = logging.getLogger(__name__) + + +class OpenAIEmbeddingProvider(EmbeddingProvider): + MODEL_TO_TOKENIZER = { + "text-embedding-ada-002": "cl100k_base", + "text-embedding-3-small": "cl100k_base", + "text-embedding-3-large": "cl100k_base", + } + MODEL_TO_DIMENSIONS = { + "text-embedding-ada-002": [1536], + "text-embedding-3-small": [512, 1536], + "text-embedding-3-large": [256, 1024, 3072], + } + + def __init__(self, config: EmbeddingConfig): + super().__init__(config) + provider = config.provider + if not provider: + raise ValueError( + "Must set provider in order to initialize OpenAIEmbeddingProvider." + ) + + if provider != "openai": + raise ValueError( + "OpenAIEmbeddingProvider must be initialized with provider `openai`." + ) + if not os.getenv("OPENAI_API_KEY"): + raise ValueError( + "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider." + ) + self.client = OpenAI() + self.async_client = AsyncOpenAI() + + if config.rerank_model: + raise ValueError( + "OpenAIEmbeddingProvider does not support separate reranking." 
+ ) + self.base_model = config.base_model + self.base_dimension = config.base_dimension + + if self.base_model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER: + raise ValueError( + f"OpenAI embedding model {self.base_model} not supported." + ) + if ( + self.base_dimension + and self.base_dimension + not in OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model] + ): + raise ValueError( + f"Dimensions {self.dimension} for {self.base_model} are not supported" + ) + + if not self.base_model or not self.base_dimension: + raise ValueError( + "Must set base_model and base_dimension in order to initialize OpenAIEmbeddingProvider." + ) + + if config.rerank_model: + raise ValueError( + "OpenAIEmbeddingProvider does not support separate reranking." + ) + + def get_embedding( + self, + text: str, + stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE, + ) -> list[float]: + if stage != EmbeddingProvider.PipeStage.BASE: + raise ValueError( + "OpenAIEmbeddingProvider only supports search stage." + ) + + try: + return ( + self.client.embeddings.create( + input=[text], + model=self.base_model, + dimensions=self.base_dimension + or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[ + self.base_model + ][-1], + ) + .data[0] + .embedding + ) + except AuthenticationError as e: + raise ValueError( + "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable." + ) from e + + async def async_get_embedding( + self, + text: str, + stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE, + ) -> list[float]: + if stage != EmbeddingProvider.PipeStage.BASE: + raise ValueError( + "OpenAIEmbeddingProvider only supports search stage." + ) + + try: + response = await self.async_client.embeddings.create( + input=[text], + model=self.base_model, + dimensions=self.base_dimension + or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[ + self.base_model + ][-1], + ) + return response.data[0].embedding + except AuthenticationError as e: + raise ValueError( + "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable." + ) from e + + def get_embeddings( + self, + texts: list[str], + stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE, + ) -> list[list[float]]: + if stage != EmbeddingProvider.PipeStage.BASE: + raise ValueError( + "OpenAIEmbeddingProvider only supports search stage." + ) + + try: + return [ + ele.embedding + for ele in self.client.embeddings.create( + input=texts, + model=self.base_model, + dimensions=self.base_dimension + or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[ + self.base_model + ][-1], + ).data + ] + except AuthenticationError as e: + raise ValueError( + "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable." + ) from e + + async def async_get_embeddings( + self, + texts: list[str], + stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE, + ) -> list[list[float]]: + if stage != EmbeddingProvider.PipeStage.BASE: + raise ValueError( + "OpenAIEmbeddingProvider only supports search stage." + ) + + try: + response = await self.async_client.embeddings.create( + input=texts, + model=self.base_model, + dimensions=self.base_dimension + or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[ + self.base_model + ][-1], + ) + return [ele.embedding for ele in response.data] + except AuthenticationError as e: + raise ValueError( + "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable." 
+ ) from e + + def rerank( + self, + query: str, + results: list[VectorSearchResult], + stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK, + limit: int = 10, + ): + return results[:limit] + + def tokenize_string(self, text: str, model: str) -> list[int]: + try: + import tiktoken + except ImportError: + raise ValueError( + "Must download tiktoken library to run `tokenize_string`." + ) + # tiktoken encoding - + # cl100k_base - gpt-4, gpt-3.5-turbo, text-embedding-ada-002, text-embedding-3-small, text-embedding-3-large + if model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER: + raise ValueError(f"OpenAI embedding model {model} not supported.") + encoding = tiktoken.get_encoding( + OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER[model] + ) + return encoding.encode(text) diff --git a/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py b/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py new file mode 100755 index 00000000..3316cb60 --- /dev/null +++ b/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py @@ -0,0 +1,160 @@ +import logging + +from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult + +logger = logging.getLogger(__name__) + + +class SentenceTransformerEmbeddingProvider(EmbeddingProvider): + def __init__( + self, + config: EmbeddingConfig, + ): + super().__init__(config) + logger.info( + "Initializing `SentenceTransformerEmbeddingProvider` with separate models for search and rerank." + ) + provider = config.provider + if not provider: + raise ValueError( + "Must set provider in order to initialize SentenceTransformerEmbeddingProvider." + ) + if provider != "sentence-transformers": + raise ValueError( + "SentenceTransformerEmbeddingProvider must be initialized with provider `sentence-transformers`." + ) + try: + from sentence_transformers import CrossEncoder, SentenceTransformer + + self.SentenceTransformer = SentenceTransformer + # TODO - Modify this to be configurable, as `bge-reranker-large` is a `SentenceTransformer` model + self.CrossEncoder = CrossEncoder + except ImportError as e: + raise ValueError( + "Must download sentence-transformers library to run `SentenceTransformerEmbeddingProvider`." + ) from e + + # Initialize separate models for search and rerank + self.do_search = False + self.do_rerank = False + + self.search_encoder = self._init_model( + config, EmbeddingProvider.PipeStage.BASE + ) + self.rerank_encoder = self._init_model( + config, EmbeddingProvider.PipeStage.RERANK + ) + + def _init_model(self, config: EmbeddingConfig, stage: str): + stage_name = stage.name.lower() + model = config.dict().get(f"{stage_name}_model", None) + dimension = config.dict().get(f"{stage_name}_dimension", None) + + transformer_type = config.dict().get( + f"{stage_name}_transformer_type", "SentenceTransformer" + ) + + if stage == EmbeddingProvider.PipeStage.BASE: + self.do_search = True + # Check if a model is set for the stage + if not (model and dimension and transformer_type): + raise ValueError( + f"Must set {stage.name.lower()}_model and {stage.name.lower()}_dimension for {stage} stage in order to initialize SentenceTransformerEmbeddingProvider." 
+ ) + + if stage == EmbeddingProvider.PipeStage.RERANK: + # Check if a model is set for the stage + if not (model and dimension and transformer_type): + return None + + self.do_rerank = True + if transformer_type == "SentenceTransformer": + raise ValueError( + f"`SentenceTransformer` models are not yet supported for {stage} stage in SentenceTransformerEmbeddingProvider." + ) + + # Save the model_key and dimension into instance variables + setattr(self, f"{stage_name}_model", model) + setattr(self, f"{stage_name}_dimension", dimension) + setattr(self, f"{stage_name}_transformer_type", transformer_type) + + # Initialize the model + encoder = ( + self.SentenceTransformer( + model, truncate_dim=dimension, trust_remote_code=True + ) + if transformer_type == "SentenceTransformer" + else self.CrossEncoder(model, trust_remote_code=True) + ) + return encoder + + def get_embedding( + self, + text: str, + stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE, + ) -> list[float]: + if stage != EmbeddingProvider.PipeStage.BASE: + raise ValueError("`get_embedding` only supports `SEARCH` stage.") + if not self.do_search: + raise ValueError( + "`get_embedding` can only be called for the search stage if a search model is set." + ) + encoder = self.search_encoder + return encoder.encode([text]).tolist()[0] + + def get_embeddings( + self, + texts: list[str], + stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE, + ) -> list[list[float]]: + if stage != EmbeddingProvider.PipeStage.BASE: + raise ValueError("`get_embeddings` only supports `SEARCH` stage.") + if not self.do_search: + raise ValueError( + "`get_embeddings` can only be called for the search stage if a search model is set." + ) + encoder = ( + self.search_encoder + if stage == EmbeddingProvider.PipeStage.BASE + else self.rerank_encoder + ) + return encoder.encode(texts).tolist() + + def rerank( + self, + query: str, + results: list[VectorSearchResult], + stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK, + limit: int = 10, + ) -> list[VectorSearchResult]: + if stage != EmbeddingProvider.PipeStage.RERANK: + raise ValueError("`rerank` only supports `RERANK` stage.") + if not self.do_rerank: + return results[:limit] + + from copy import copy + + texts = copy([doc.metadata["text"] for doc in results]) + # Use the rank method from the rerank_encoder, which is a CrossEncoder model + reranked_scores = self.rerank_encoder.rank( + query, texts, return_documents=False, top_k=limit + ) + # Map the reranked scores back to the original documents + reranked_results = [] + for score in reranked_scores: + corpus_id = score["corpus_id"] + new_result = results[corpus_id] + new_result.score = float(score["score"]) + reranked_results.append(new_result) + + # Sort the documents by the new scores in descending order + reranked_results.sort(key=lambda doc: doc.score, reverse=True) + return reranked_results + + def tokenize_string( + self, + stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE, + ) -> list[int]: + raise ValueError( + "SentenceTransformerEmbeddingProvider does not support tokenize_string." 
+ ) diff --git a/R2R/r2r/providers/eval/__init__.py b/R2R/r2r/providers/eval/__init__.py new file mode 100755 index 00000000..3f5e1b51 --- /dev/null +++ b/R2R/r2r/providers/eval/__init__.py @@ -0,0 +1,3 @@ +from .llm.base_llm_eval import LLMEvalProvider + +__all__ = ["LLMEvalProvider"] diff --git a/R2R/r2r/providers/eval/llm/base_llm_eval.py b/R2R/r2r/providers/eval/llm/base_llm_eval.py new file mode 100755 index 00000000..7c573a34 --- /dev/null +++ b/R2R/r2r/providers/eval/llm/base_llm_eval.py @@ -0,0 +1,84 @@ +from fractions import Fraction +from typing import Union + +from r2r import EvalConfig, EvalProvider, LLMProvider, PromptProvider +from r2r.base.abstractions.llm import GenerationConfig + + +class LLMEvalProvider(EvalProvider): + def __init__( + self, + config: EvalConfig, + llm_provider: LLMProvider, + prompt_provider: PromptProvider, + ): + super().__init__(config) + + self.llm_provider = llm_provider + self.prompt_provider = prompt_provider + + def _calc_query_context_relevancy(self, query: str, context: str) -> float: + system_prompt = self.prompt_provider.get_prompt("default_system") + eval_prompt = self.prompt_provider.get_prompt( + "rag_context_eval", {"query": query, "context": context} + ) + response = self.llm_provider.get_completion( + self.prompt_provider._get_message_payload( + system_prompt, eval_prompt + ), + self.eval_generation_config, + ) + response_text = response.choices[0].message.content + fraction = ( + response_text + # Get the fraction in the returned tuple + .split(",")[-1][:-1] + # Remove any quotes and spaces + .replace("'", "") + .replace('"', "") + .strip() + ) + return float(Fraction(fraction)) + + def _calc_answer_grounding( + self, query: str, context: str, answer: str + ) -> float: + system_prompt = self.prompt_provider.get_prompt("default_system") + eval_prompt = self.prompt_provider.get_prompt( + "rag_answer_eval", + {"query": query, "context": context, "answer": answer}, + ) + response = self.llm_provider.get_completion( + self.prompt_provider._get_message_payload( + system_prompt, eval_prompt + ), + self.eval_generation_config, + ) + response_text = response.choices[0].message.content + fraction = ( + response_text + # Get the fraction in the returned tuple + .split(",")[-1][:-1] + # Remove any quotes and spaces + .replace("'", "") + .replace('"', "") + .strip() + ) + return float(Fraction(fraction)) + + def _evaluate( + self, + query: str, + context: str, + answer: str, + eval_generation_config: GenerationConfig, + ) -> dict[str, dict[str, Union[str, float]]]: + self.eval_generation_config = eval_generation_config + query_context_relevancy = self._calc_query_context_relevancy( + query, context + ) + answer_grounding = self._calc_answer_grounding(query, context, answer) + return { + "query_context_relevancy": query_context_relevancy, + "answer_grounding": answer_grounding, + } diff --git a/R2R/r2r/providers/kg/__init__.py b/R2R/r2r/providers/kg/__init__.py new file mode 100755 index 00000000..36bc79a2 --- /dev/null +++ b/R2R/r2r/providers/kg/__init__.py @@ -0,0 +1,3 @@ +from .neo4j.base_neo4j import Neo4jKGProvider + +__all__ = ["Neo4jKGProvider"] diff --git a/R2R/r2r/providers/kg/neo4j/base_neo4j.py b/R2R/r2r/providers/kg/neo4j/base_neo4j.py new file mode 100755 index 00000000..9ede2b85 --- /dev/null +++ b/R2R/r2r/providers/kg/neo4j/base_neo4j.py @@ -0,0 +1,983 @@ +# abstractions are taken from LlamaIndex +# Neo4jKGProvider is almost entirely taken from LlamaIndex Neo4jPropertyGraphStore +# https://github.com/run-llama/llama_index +import json 
+import os +from typing import Any, Dict, List, Optional, Tuple + +from r2r.base import ( + EntityType, + KGConfig, + KGProvider, + PromptProvider, + format_entity_types, + format_relations, +) +from r2r.base.abstractions.llama_abstractions import ( + LIST_LIMIT, + ChunkNode, + EntityNode, + LabelledNode, + PropertyGraphStore, + Relation, + Triplet, + VectorStoreQuery, + clean_string_values, + value_sanitize, +) + + +def remove_empty_values(input_dict): + """ + Remove entries with empty values from the dictionary. + + Parameters: + input_dict (dict): The dictionary from which empty values need to be removed. + + Returns: + dict: A new dictionary with all empty values removed. + """ + # Create a new dictionary excluding empty values + return {key: value for key, value in input_dict.items() if value} + + +BASE_ENTITY_LABEL = "__Entity__" +EXCLUDED_LABELS = ["_Bloom_Perspective_", "_Bloom_Scene_"] +EXCLUDED_RELS = ["_Bloom_HAS_SCENE_"] +EXHAUSTIVE_SEARCH_LIMIT = 10000 +# Threshold for returning all available prop values in graph schema +DISTINCT_VALUE_LIMIT = 10 + +node_properties_query = """ +CALL apoc.meta.data() +YIELD label, other, elementType, type, property +WHERE NOT type = "RELATIONSHIP" AND elementType = "node" + AND NOT label IN $EXCLUDED_LABELS +WITH label AS nodeLabels, collect({property:property, type:type}) AS properties +RETURN {labels: nodeLabels, properties: properties} AS output + +""" + +rel_properties_query = """ +CALL apoc.meta.data() +YIELD label, other, elementType, type, property +WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship" + AND NOT label in $EXCLUDED_LABELS +WITH label AS nodeLabels, collect({property:property, type:type}) AS properties +RETURN {type: nodeLabels, properties: properties} AS output +""" + +rel_query = """ +CALL apoc.meta.data() +YIELD label, other, elementType, type, property +WHERE type = "RELATIONSHIP" AND elementType = "node" +UNWIND other AS other_node +WITH * WHERE NOT label IN $EXCLUDED_LABELS + AND NOT other_node IN $EXCLUDED_LABELS +RETURN {start: label, type: property, end: toString(other_node)} AS output +""" + + +class Neo4jKGProvider(PropertyGraphStore, KGProvider): + r""" + Neo4j Property Graph Store. + + This class implements a Neo4j property graph store. + + If you are using local Neo4j instead of aura, here's a helpful + command for launching the docker container: + + ```bash + docker run \ + -p 7474:7474 -p 7687:7687 \ + -v $PWD/data:/data -v $PWD/plugins:/plugins \ + --name neo4j-apoc \ + -e NEO4J_apoc_export_file_enabled=true \ + -e NEO4J_apoc_import_file_enabled=true \ + -e NEO4J_apoc_import_file_use__neo4j__config=true \ + -e NEO4JLABS_PLUGINS=\\[\"apoc\"\\] \ + neo4j:latest + ``` + + Args: + username (str): The username for the Neo4j database. + password (str): The password for the Neo4j database. + url (str): The URL for the Neo4j database. + database (Optional[str]): The name of the database to connect to. Defaults to "neo4j". 
+ + Examples: + `pip install llama-index-graph-stores-neo4j` + + ```python + from llama_index.core.indices.property_graph import PropertyGraphIndex + from llama_index.graph_stores.neo4j import Neo4jKGProvider + + # Create a Neo4jKGProvider instance + graph_store = Neo4jKGProvider( + username="neo4j", + password="neo4j", + url="bolt://localhost:7687", + database="neo4j" + ) + + # create the index + index = PropertyGraphIndex.from_documents( + documents, + property_graph_store=graph_store, + ) + ``` + """ + + supports_structured_queries: bool = True + supports_vector_queries: bool = True + + def __init__( + self, + config: KGConfig, + refresh_schema: bool = True, + sanitize_query_output: bool = True, + enhanced_schema: bool = False, + *args: Any, + **kwargs: Any, + ) -> None: + if config.provider != "neo4j": + raise ValueError( + "Neo4jKGProvider must be initialized with config with `neo4j` provider." + ) + + try: + import neo4j + except ImportError: + raise ImportError("Please install neo4j: pip install neo4j") + + username = os.getenv("NEO4J_USER") + password = os.getenv("NEO4J_PASSWORD") + url = os.getenv("NEO4J_URL") + database = os.getenv("NEO4J_DATABASE", "neo4j") + + if not username or not password or not url: + raise ValueError( + "Neo4j configuration values are missing. Please set NEO4J_USER, NEO4J_PASSWORD, and NEO4J_URL environment variables." + ) + + self.sanitize_query_output = sanitize_query_output + self.enhcnaced_schema = enhanced_schema + self._driver = neo4j.GraphDatabase.driver( + url, auth=(username, password), **kwargs + ) + self._async_driver = neo4j.AsyncGraphDatabase.driver( + url, + auth=(username, password), + **kwargs, + ) + self._database = database + self.structured_schema = {} + if refresh_schema: + self.refresh_schema() + self.neo4j = neo4j + self.config = config + + @property + def client(self): + return self._driver + + def refresh_schema(self) -> None: + """Refresh the schema.""" + node_query_results = self.structured_query( + node_properties_query, + param_map={ + "EXCLUDED_LABELS": [*EXCLUDED_LABELS, BASE_ENTITY_LABEL] + }, + ) + node_properties = ( + [el["output"] for el in node_query_results] + if node_query_results + else [] + ) + + rels_query_result = self.structured_query( + rel_properties_query, param_map={"EXCLUDED_LABELS": EXCLUDED_RELS} + ) + rel_properties = ( + [el["output"] for el in rels_query_result] + if rels_query_result + else [] + ) + + rel_objs_query_result = self.structured_query( + rel_query, + param_map={ + "EXCLUDED_LABELS": [*EXCLUDED_LABELS, BASE_ENTITY_LABEL] + }, + ) + relationships = ( + [el["output"] for el in rel_objs_query_result] + if rel_objs_query_result + else [] + ) + + # Get constraints & indexes + try: + constraint = self.structured_query("SHOW CONSTRAINTS") + index = self.structured_query( + "CALL apoc.schema.nodes() YIELD label, properties, type, size, " + "valuesSelectivity WHERE type = 'RANGE' RETURN *, " + "size * valuesSelectivity as distinctValues" + ) + except ( + self.neo4j.exceptions.ClientError + ): # Read-only user might not have access to schema information + constraint = [] + index = [] + + self.structured_schema = { + "node_props": { + el["labels"]: el["properties"] for el in node_properties + }, + "rel_props": { + el["type"]: el["properties"] for el in rel_properties + }, + "relationships": relationships, + "metadata": {"constraint": constraint, "index": index}, + } + schema_counts = self.structured_query( + "CALL apoc.meta.graphSample() YIELD nodes, relationships " + "RETURN nodes, [rel in 
relationships | {name:apoc.any.property" + "(rel, 'type'), count: apoc.any.property(rel, 'count')}]" + " AS relationships" + ) + # Update node info + for node in schema_counts[0].get("nodes", []): + # Skip bloom labels + if node["name"] in EXCLUDED_LABELS: + continue + node_props = self.structured_schema["node_props"].get(node["name"]) + if not node_props: # The node has no properties + continue + enhanced_cypher = self._enhanced_schema_cypher( + node["name"], + node_props, + node["count"] < EXHAUSTIVE_SEARCH_LIMIT, + ) + enhanced_info = self.structured_query(enhanced_cypher)[0]["output"] + for prop in node_props: + if prop["property"] in enhanced_info: + prop.update(enhanced_info[prop["property"]]) + # Update rel info + for rel in schema_counts[0].get("relationships", []): + # Skip bloom labels + if rel["name"] in EXCLUDED_RELS: + continue + rel_props = self.structured_schema["rel_props"].get(rel["name"]) + if not rel_props: # The rel has no properties + continue + enhanced_cypher = self._enhanced_schema_cypher( + rel["name"], + rel_props, + rel["count"] < EXHAUSTIVE_SEARCH_LIMIT, + is_relationship=True, + ) + try: + enhanced_info = self.structured_query(enhanced_cypher)[0][ + "output" + ] + for prop in rel_props: + if prop["property"] in enhanced_info: + prop.update(enhanced_info[prop["property"]]) + except self.neo4j.exceptions.ClientError: + # Sometimes the types are not consistent in the db + pass + + def upsert_nodes(self, nodes: List[LabelledNode]) -> None: + # Lists to hold separated types + entity_dicts: List[dict] = [] + chunk_dicts: List[dict] = [] + + # Sort by type + for item in nodes: + if isinstance(item, EntityNode): + entity_dicts.append({**item.dict(), "id": item.id}) + elif isinstance(item, ChunkNode): + chunk_dicts.append({**item.dict(), "id": item.id}) + else: + # Log that we do not support these types of nodes + # Or raise an error? 
+ pass + + if chunk_dicts: + self.structured_query( + """ + UNWIND $data AS row + MERGE (c:Chunk {id: row.id}) + SET c.text = row.text + WITH c, row + SET c += row.properties + WITH c, row.embedding AS embedding + WHERE embedding IS NOT NULL + CALL db.create.setNodeVectorProperty(c, 'embedding', embedding) + RETURN count(*) + """, + param_map={"data": chunk_dicts}, + ) + + if entity_dicts: + self.structured_query( + """ + UNWIND $data AS row + MERGE (e:`__Entity__` {id: row.id}) + SET e += apoc.map.clean(row.properties, [], []) + SET e.name = row.name + WITH e, row + CALL apoc.create.addLabels(e, [row.label]) + YIELD node + WITH e, row + CALL { + WITH e, row + WITH e, row + WHERE row.embedding IS NOT NULL + CALL db.create.setNodeVectorProperty(e, 'embedding', row.embedding) + RETURN count(*) AS count + } + WITH e, row WHERE row.properties.triplet_source_id IS NOT NULL + MERGE (c:Chunk {id: row.properties.triplet_source_id}) + MERGE (e)<-[:MENTIONS]-(c) + """, + param_map={"data": entity_dicts}, + ) + + def upsert_relations(self, relations: List[Relation]) -> None: + """Add relations.""" + params = [r.dict() for r in relations] + + self.structured_query( + """ + UNWIND $data AS row + MERGE (source {id: row.source_id}) + MERGE (target {id: row.target_id}) + WITH source, target, row + CALL apoc.merge.relationship(source, row.label, {}, row.properties, target) YIELD rel + RETURN count(*) + """, + param_map={"data": params}, + ) + + def get( + self, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> List[LabelledNode]: + """Get nodes.""" + cypher_statement = "MATCH (e) " + + params = {} + if properties or ids: + cypher_statement += "WHERE " + + if ids: + cypher_statement += "e.id in $ids " + params["ids"] = ids + + if properties: + prop_list = [] + for i, prop in enumerate(properties): + prop_list.append(f"e.`{prop}` = $property_{i}") + params[f"property_{i}"] = properties[prop] + cypher_statement += " AND ".join(prop_list) + + return_statement = """ + WITH e + RETURN e.id AS name, + [l in labels(e) WHERE l <> '__Entity__' | l][0] AS type, + e{.* , embedding: Null, id: Null} AS properties + """ + cypher_statement += return_statement + + response = self.structured_query(cypher_statement, param_map=params) + response = response if response else [] + + nodes = [] + for record in response: + # text indicates a chunk node + # none on the type indicates an implicit node, likely a chunk node + if "text" in record["properties"] or record["type"] is None: + text = record["properties"].pop("text", "") + nodes.append( + ChunkNode( + id_=record["name"], + text=text, + properties=remove_empty_values(record["properties"]), + ) + ) + else: + nodes.append( + EntityNode( + name=record["name"], + label=record["type"], + properties=remove_empty_values(record["properties"]), + ) + ) + + return nodes + + def get_triplets( + self, + entity_names: Optional[List[str]] = None, + relation_names: Optional[List[str]] = None, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> List[Triplet]: + # TODO: handle ids of chunk nodes + cypher_statement = "MATCH (e:`__Entity__`) " + + params = {} + if entity_names or properties or ids: + cypher_statement += "WHERE " + + if entity_names: + cypher_statement += "e.name in $entity_names " + params["entity_names"] = entity_names + + if ids: + cypher_statement += "e.id in $ids " + params["ids"] = ids + + if properties: + prop_list = [] + for i, prop in enumerate(properties): + prop_list.append(f"e.`{prop}` = $property_{i}") + 
params[f"property_{i}"] = properties[prop] + cypher_statement += " AND ".join(prop_list) + + return_statement = f""" + WITH e + CALL {{ + WITH e + MATCH (e)-[r{':`' + '`|`'.join(relation_names) + '`' if relation_names else ''}]->(t) + RETURN e.name AS source_id, [l in labels(e) WHERE l <> '__Entity__' | l][0] AS source_type, + e{{.* , embedding: Null, name: Null}} AS source_properties, + type(r) AS type, + t.name AS target_id, [l in labels(t) WHERE l <> '__Entity__' | l][0] AS target_type, + t{{.* , embedding: Null, name: Null}} AS target_properties + UNION ALL + WITH e + MATCH (e)<-[r{':`' + '`|`'.join(relation_names) + '`' if relation_names else ''}]-(t) + RETURN t.name AS source_id, [l in labels(t) WHERE l <> '__Entity__' | l][0] AS source_type, + e{{.* , embedding: Null, name: Null}} AS source_properties, + type(r) AS type, + e.name AS target_id, [l in labels(e) WHERE l <> '__Entity__' | l][0] AS target_type, + t{{.* , embedding: Null, name: Null}} AS target_properties + }} + RETURN source_id, source_type, type, target_id, target_type, source_properties, target_properties""" + cypher_statement += return_statement + + data = self.structured_query(cypher_statement, param_map=params) + data = data if data else [] + + triples = [] + for record in data: + source = EntityNode( + name=record["source_id"], + label=record["source_type"], + properties=remove_empty_values(record["source_properties"]), + ) + target = EntityNode( + name=record["target_id"], + label=record["target_type"], + properties=remove_empty_values(record["target_properties"]), + ) + rel = Relation( + source_id=record["source_id"], + target_id=record["target_id"], + label=record["type"], + ) + triples.append([source, rel, target]) + return triples + + def get_rel_map( + self, + graph_nodes: List[LabelledNode], + depth: int = 2, + limit: int = 30, + ignore_rels: Optional[List[str]] = None, + ) -> List[Triplet]: + """Get depth-aware rel map.""" + triples = [] + + ids = [node.id for node in graph_nodes] + # Needs some optimization + response = self.structured_query( + f""" + MATCH (e:`__Entity__`) + WHERE e.id in $ids + MATCH p=(e)-[r*1..{depth}]-(other) + WHERE ALL(rel in relationships(p) WHERE type(rel) <> 'MENTIONS') + UNWIND relationships(p) AS rel + WITH distinct rel + WITH startNode(rel) AS source, + type(rel) AS type, + endNode(rel) AS endNode + RETURN source.id AS source_id, [l in labels(source) WHERE l <> '__Entity__' | l][0] AS source_type, + source{{.* , embedding: Null, id: Null}} AS source_properties, + type, + endNode.id AS target_id, [l in labels(endNode) WHERE l <> '__Entity__' | l][0] AS target_type, + endNode{{.* , embedding: Null, id: Null}} AS target_properties + LIMIT toInteger($limit) + """, + param_map={"ids": ids, "limit": limit}, + ) + response = response if response else [] + + ignore_rels = ignore_rels or [] + for record in response: + if record["type"] in ignore_rels: + continue + + source = EntityNode( + name=record["source_id"], + label=record["source_type"], + properties=remove_empty_values(record["source_properties"]), + ) + target = EntityNode( + name=record["target_id"], + label=record["target_type"], + properties=remove_empty_values(record["target_properties"]), + ) + rel = Relation( + source_id=record["source_id"], + target_id=record["target_id"], + label=record["type"], + ) + triples.append([source, rel, target]) + + return triples + + def structured_query( + self, query: str, param_map: Optional[Dict[str, Any]] = None + ) -> Any: + param_map = param_map or {} + + with 
self._driver.session(database=self._database) as session: + result = session.run(query, param_map) + full_result = [d.data() for d in result] + + if self.sanitize_query_output: + return value_sanitize(full_result) + + return full_result + + def vector_query( + self, query: VectorStoreQuery, **kwargs: Any + ) -> Tuple[List[LabelledNode], List[float]]: + """Query the graph store with a vector store query.""" + data = self.structured_query( + """MATCH (e:`__Entity__`) + WHERE e.embedding IS NOT NULL AND size(e.embedding) = $dimension + WITH e, vector.similarity.cosine(e.embedding, $embedding) AS score + ORDER BY score DESC LIMIT toInteger($limit) + RETURN e.id AS name, + [l in labels(e) WHERE l <> '__Entity__' | l][0] AS type, + e{.* , embedding: Null, name: Null, id: Null} AS properties, + score""", + param_map={ + "embedding": query.query_embedding, + "dimension": len(query.query_embedding), + "limit": query.similarity_top_k, + }, + ) + data = data if data else [] + + nodes = [] + scores = [] + for record in data: + node = EntityNode( + name=record["name"], + label=record["type"], + properties=remove_empty_values(record["properties"]), + ) + nodes.append(node) + scores.append(record["score"]) + + return (nodes, scores) + + def delete( + self, + entity_names: Optional[List[str]] = None, + relation_names: Optional[List[str]] = None, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> None: + """Delete matching data.""" + if entity_names: + self.structured_query( + "MATCH (n) WHERE n.name IN $entity_names DETACH DELETE n", + param_map={"entity_names": entity_names}, + ) + + if ids: + self.structured_query( + "MATCH (n) WHERE n.id IN $ids DETACH DELETE n", + param_map={"ids": ids}, + ) + + if relation_names: + for rel in relation_names: + self.structured_query(f"MATCH ()-[r:`{rel}`]->() DELETE r") + + if properties: + cypher = "MATCH (e) WHERE " + prop_list = [] + params = {} + for i, prop in enumerate(properties): + prop_list.append(f"e.`{prop}` = $property_{i}") + params[f"property_{i}"] = properties[prop] + cypher += " AND ".join(prop_list) + self.structured_query( + cypher + " DETACH DELETE e", param_map=params + ) + + def _enhanced_schema_cypher( + self, + label_or_type: str, + properties: List[Dict[str, Any]], + exhaustive: bool, + is_relationship: bool = False, + ) -> str: + if is_relationship: + match_clause = f"MATCH ()-[n:`{label_or_type}`]->()" + else: + match_clause = f"MATCH (n:`{label_or_type}`)" + + with_clauses = [] + return_clauses = [] + output_dict = {} + if exhaustive: + for prop in properties: + prop_name = prop["property"] + prop_type = prop["type"] + if prop_type == "STRING": + with_clauses.append( + f"collect(distinct substring(toString(n.`{prop_name}`), 0, 50)) " + f"AS `{prop_name}_values`" + ) + return_clauses.append( + f"values:`{prop_name}_values`[..{DISTINCT_VALUE_LIMIT}]," + f" distinct_count: size(`{prop_name}_values`)" + ) + elif prop_type in [ + "INTEGER", + "FLOAT", + "DATE", + "DATE_TIME", + "LOCAL_DATE_TIME", + ]: + with_clauses.append( + f"min(n.`{prop_name}`) AS `{prop_name}_min`" + ) + with_clauses.append( + f"max(n.`{prop_name}`) AS `{prop_name}_max`" + ) + with_clauses.append( + f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`" + ) + return_clauses.append( + f"min: toString(`{prop_name}_min`), " + f"max: toString(`{prop_name}_max`), " + f"distinct_count: `{prop_name}_distinct`" + ) + elif prop_type == "LIST": + with_clauses.append( + f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, " + 
f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`" + ) + return_clauses.append( + f"min_size: `{prop_name}_size_min`, " + f"max_size: `{prop_name}_size_max`" + ) + elif prop_type in ["BOOLEAN", "POINT", "DURATION"]: + continue + output_dict[prop_name] = "{" + return_clauses.pop() + "}" + else: + # Just sample 5 random nodes + match_clause += " WITH n LIMIT 5" + for prop in properties: + prop_name = prop["property"] + prop_type = prop["type"] + + # Check if indexed property, we can still do exhaustive + prop_index = [ + el + for el in self.structured_schema["metadata"]["index"] + if el["label"] == label_or_type + and el["properties"] == [prop_name] + and el["type"] == "RANGE" + ] + if prop_type == "STRING": + if ( + prop_index + and prop_index[0].get("size") > 0 + and prop_index[0].get("distinctValues") + <= DISTINCT_VALUE_LIMIT + ): + distinct_values = self.query( + f"CALL apoc.schema.properties.distinct(" + f"'{label_or_type}', '{prop_name}') YIELD value" + )[0]["value"] + return_clauses.append( + f"values: {distinct_values}," + f" distinct_count: {len(distinct_values)}" + ) + else: + with_clauses.append( + f"collect(distinct substring(n.`{prop_name}`, 0, 50)) " + f"AS `{prop_name}_values`" + ) + return_clauses.append(f"values: `{prop_name}_values`") + elif prop_type in [ + "INTEGER", + "FLOAT", + "DATE", + "DATE_TIME", + "LOCAL_DATE_TIME", + ]: + if not prop_index: + with_clauses.append( + f"collect(distinct toString(n.`{prop_name}`)) " + f"AS `{prop_name}_values`" + ) + return_clauses.append(f"values: `{prop_name}_values`") + else: + with_clauses.append( + f"min(n.`{prop_name}`) AS `{prop_name}_min`" + ) + with_clauses.append( + f"max(n.`{prop_name}`) AS `{prop_name}_max`" + ) + with_clauses.append( + f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`" + ) + return_clauses.append( + f"min: toString(`{prop_name}_min`), " + f"max: toString(`{prop_name}_max`), " + f"distinct_count: `{prop_name}_distinct`" + ) + + elif prop_type == "LIST": + with_clauses.append( + f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, " + f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`" + ) + return_clauses.append( + f"min_size: `{prop_name}_size_min`, " + f"max_size: `{prop_name}_size_max`" + ) + elif prop_type in ["BOOLEAN", "POINT", "DURATION"]: + continue + + output_dict[prop_name] = "{" + return_clauses.pop() + "}" + + with_clause = "WITH " + ",\n ".join(with_clauses) + return_clause = ( + "RETURN {" + + ", ".join(f"`{k}`: {v}" for k, v in output_dict.items()) + + "} AS output" + ) + + # Combine all parts of the Cypher query + return f"{match_clause}\n{with_clause}\n{return_clause}" + + def get_schema(self, refresh: bool = False) -> Any: + if refresh: + self.refresh_schema() + + return self.structured_schema + + def get_schema_str(self, refresh: bool = False) -> str: + schema = self.get_schema(refresh=refresh) + + formatted_node_props = [] + formatted_rel_props = [] + + if self.enhcnaced_schema: + # Enhanced formatting for nodes + for node_type, properties in schema["node_props"].items(): + formatted_node_props.append(f"- **{node_type}**") + for prop in properties: + example = "" + if prop["type"] == "STRING" and prop.get("values"): + if ( + prop.get("distinct_count", 11) + > DISTINCT_VALUE_LIMIT + ): + example = ( + f'Example: "{clean_string_values(prop["values"][0])}"' + if prop["values"] + else "" + ) + else: # If less than 10 possible values return all + example = ( + ( + "Available options: " + f'{[clean_string_values(el) for el in prop["values"]]}' + ) + if prop["values"] + 
else "" + ) + + elif prop["type"] in [ + "INTEGER", + "FLOAT", + "DATE", + "DATE_TIME", + "LOCAL_DATE_TIME", + ]: + if prop.get("min") is not None: + example = f'Min: {prop["min"]}, Max: {prop["max"]}' + else: + example = ( + f'Example: "{prop["values"][0]}"' + if prop.get("values") + else "" + ) + elif prop["type"] == "LIST": + # Skip embeddings + if ( + not prop.get("min_size") + or prop["min_size"] > LIST_LIMIT + ): + continue + example = f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}' + formatted_node_props.append( + f" - `{prop['property']}`: {prop['type']} {example}" + ) + + # Enhanced formatting for relationships + for rel_type, properties in schema["rel_props"].items(): + formatted_rel_props.append(f"- **{rel_type}**") + for prop in properties: + example = "" + if prop["type"] == "STRING": + if ( + prop.get("distinct_count", 11) + > DISTINCT_VALUE_LIMIT + ): + example = ( + f'Example: "{clean_string_values(prop["values"][0])}"' + if prop.get("values") + else "" + ) + else: # If less than 10 possible values return all + example = ( + ( + "Available options: " + f'{[clean_string_values(el) for el in prop["values"]]}' + ) + if prop.get("values") + else "" + ) + elif prop["type"] in [ + "INTEGER", + "FLOAT", + "DATE", + "DATE_TIME", + "LOCAL_DATE_TIME", + ]: + if prop.get("min"): # If we have min/max + example = ( + f'Min: {prop["min"]}, Max: {prop["max"]}' + ) + else: # return a single value + example = ( + f'Example: "{prop["values"][0]}"' + if prop.get("values") + else "" + ) + elif prop["type"] == "LIST": + # Skip embeddings + if prop["min_size"] > LIST_LIMIT: + continue + example = f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}' + formatted_rel_props.append( + f" - `{prop['property']}: {prop['type']}` {example}" + ) + else: + # Format node properties + for label, props in schema["node_props"].items(): + props_str = ", ".join( + [f"{prop['property']}: {prop['type']}" for prop in props] + ) + formatted_node_props.append(f"{label} {{{props_str}}}") + + # Format relationship properties using structured_schema + for type, props in schema["rel_props"].items(): + props_str = ", ".join( + [f"{prop['property']}: {prop['type']}" for prop in props] + ) + formatted_rel_props.append(f"{type} {{{props_str}}}") + + # Format relationships + formatted_rels = [ + f"(:{el['start']})-[:{el['type']}]->(:{el['end']})" + for el in schema["relationships"] + ] + + return "\n".join( + [ + "Node properties:", + "\n".join(formatted_node_props), + "Relationship properties:", + "\n".join(formatted_rel_props), + "The relationships:", + "\n".join(formatted_rels), + ] + ) + + def update_extraction_prompt( + self, + prompt_provider: PromptProvider, + entity_types: list[EntityType], + relations: list[Relation], + ): + # Fetch the kg extraction prompt with blank entity types and relations + # Note - Assumes that for given prompt there is a `_with_spec` that can have entities + relations specified + few_shot_ner_kg_extraction_with_spec = prompt_provider.get_prompt( + f"{self.config.kg_extraction_prompt}_with_spec" + ) + + # Format the prompt to include the desired entity types and relations + few_shot_ner_kg_extraction = ( + few_shot_ner_kg_extraction_with_spec.replace( + "{entity_types}", format_entity_types(entity_types) + ).replace("{relations}", format_relations(relations)) + ) + + # Update the "few_shot_ner_kg_extraction" prompt used in downstream KG construction + prompt_provider.update_prompt( + self.config.kg_extraction_prompt, + json.dumps(few_shot_ner_kg_extraction, 
ensure_ascii=False), + ) + + def update_kg_agent_prompt( + self, + prompt_provider: PromptProvider, + entity_types: list[EntityType], + relations: list[Relation], + ): + # Fetch the kg extraction prompt with blank entity types and relations + # Note - Assumes that for given prompt there is a `_with_spec` that can have entities + relations specified + few_shot_ner_kg_extraction_with_spec = prompt_provider.get_prompt( + f"{self.config.kg_agent_prompt}_with_spec" + ) + + # Format the prompt to include the desired entity types and relations + few_shot_ner_kg_extraction = ( + few_shot_ner_kg_extraction_with_spec.replace( + "{entity_types}", + format_entity_types(entity_types, ignore_subcats=True), + ).replace("{relations}", format_relations(relations)) + ) + + # Update the "few_shot_ner_kg_extraction" prompt used in downstream KG construction + prompt_provider.update_prompt( + self.config.kg_agent_prompt, + json.dumps(few_shot_ner_kg_extraction, ensure_ascii=False), + ) diff --git a/R2R/r2r/providers/llms/__init__.py b/R2R/r2r/providers/llms/__init__.py new file mode 100755 index 00000000..38a1c54a --- /dev/null +++ b/R2R/r2r/providers/llms/__init__.py @@ -0,0 +1,7 @@ +from .litellm.base_litellm import LiteLLM +from .openai.base_openai import OpenAILLM + +__all__ = [ + "LiteLLM", + "OpenAILLM", +] diff --git a/R2R/r2r/providers/llms/litellm/base_litellm.py b/R2R/r2r/providers/llms/litellm/base_litellm.py new file mode 100755 index 00000000..581cce9a --- /dev/null +++ b/R2R/r2r/providers/llms/litellm/base_litellm.py @@ -0,0 +1,142 @@ +import logging +from typing import Any, Generator, Union + +from r2r.base import ( + LLMChatCompletion, + LLMChatCompletionChunk, + LLMConfig, + LLMProvider, +) +from r2r.base.abstractions.llm import GenerationConfig + +logger = logging.getLogger(__name__) + + +class LiteLLM(LLMProvider): + """A concrete class for creating LiteLLM models.""" + + def __init__( + self, + config: LLMConfig, + *args, + **kwargs, + ) -> None: + try: + from litellm import acompletion, completion + + self.litellm_completion = completion + self.litellm_acompletion = acompletion + except ImportError: + raise ImportError( + "Error, `litellm` is required to run a LiteLLM. Please install it using `pip install litellm`." + ) + super().__init__(config) + + def get_completion( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> LLMChatCompletion: + if generation_config.stream: + raise ValueError( + "Stream must be set to False to use the `get_completion` method." + ) + return self._get_completion(messages, generation_config, **kwargs) + + def get_completion_stream( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> Generator[LLMChatCompletionChunk, None, None]: + if not generation_config.stream: + raise ValueError( + "Stream must be set to True to use the `get_completion_stream` method." 
+ ) + return self._get_completion(messages, generation_config, **kwargs) + + def extract_content(self, response: LLMChatCompletion) -> str: + return response.choices[0].message.content + + def _get_completion( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> Union[ + LLMChatCompletion, Generator[LLMChatCompletionChunk, None, None] + ]: + # Create a dictionary with the default arguments + args = self._get_base_args(generation_config) + args["messages"] = messages + + # Conditionally add the 'functions' argument if it's not None + if generation_config.functions is not None: + args["functions"] = generation_config.functions + + args = {**args, **kwargs} + response = self.litellm_completion(**args) + + if not generation_config.stream: + return LLMChatCompletion(**response.dict()) + else: + return self._get_chat_completion(response) + + def _get_chat_completion( + self, + response: Any, + ) -> Generator[LLMChatCompletionChunk, None, None]: + for part in response: + yield LLMChatCompletionChunk(**part.dict()) + + def _get_base_args( + self, + generation_config: GenerationConfig, + prompt=None, + ) -> dict: + """Get the base arguments for the LiteLLM API.""" + args = { + "model": generation_config.model, + "temperature": generation_config.temperature, + "top_p": generation_config.top_p, + "stream": generation_config.stream, + # TODO - We need to cap this to avoid potential errors when exceed max allowable context + "max_tokens": generation_config.max_tokens_to_sample, + } + return args + + async def aget_completion( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> LLMChatCompletion: + if generation_config.stream: + raise ValueError( + "Stream must be set to False to use the `aget_completion` method." + ) + return await self._aget_completion( + messages, generation_config, **kwargs + ) + + async def _aget_completion( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> Union[LLMChatCompletion, LLMChatCompletionChunk]: + """Asynchronously get a completion from the OpenAI API based on the provided messages.""" + + # Create a dictionary with the default arguments + args = self._get_base_args(generation_config) + + args["messages"] = messages + + # Conditionally add the 'functions' argument if it's not None + if generation_config.functions is not None: + args["functions"] = generation_config.functions + + args = {**args, **kwargs} + # Create the chat completion + return await self.litellm_acompletion(**args) diff --git a/R2R/r2r/providers/llms/openai/base_openai.py b/R2R/r2r/providers/llms/openai/base_openai.py new file mode 100755 index 00000000..460c0f0b --- /dev/null +++ b/R2R/r2r/providers/llms/openai/base_openai.py @@ -0,0 +1,144 @@ +"""A module for creating OpenAI model abstractions.""" + +import logging +import os +from typing import Union + +from r2r.base import ( + LLMChatCompletion, + LLMChatCompletionChunk, + LLMConfig, + LLMProvider, +) +from r2r.base.abstractions.llm import GenerationConfig + +logger = logging.getLogger(__name__) + + +class OpenAILLM(LLMProvider): + """A concrete class for creating OpenAI models.""" + + def __init__( + self, + config: LLMConfig, + *args, + **kwargs, + ) -> None: + if not isinstance(config, LLMConfig): + raise ValueError( + "The provided config must be an instance of OpenAIConfig." + ) + try: + from openai import OpenAI # noqa + except ImportError: + raise ImportError( + "Error, `openai` is required to run an OpenAILLM. 
Please install it using `pip install openai`." + ) + if config.provider != "openai": + raise ValueError( + "OpenAILLM must be initialized with config with `openai` provider." + ) + if not os.getenv("OPENAI_API_KEY"): + raise ValueError( + "OpenAI API key not found. Please set the OPENAI_API_KEY environment variable." + ) + super().__init__(config) + self.config: LLMConfig = config + self.client = OpenAI() + + def get_completion( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> LLMChatCompletion: + if generation_config.stream: + raise ValueError( + "Stream must be set to False to use the `get_completion` method." + ) + return self._get_completion(messages, generation_config, **kwargs) + + def get_completion_stream( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> LLMChatCompletionChunk: + if not generation_config.stream: + raise ValueError( + "Stream must be set to True to use the `get_completion_stream` method." + ) + return self._get_completion(messages, generation_config, **kwargs) + + def _get_completion( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> Union[LLMChatCompletion, LLMChatCompletionChunk]: + """Get a completion from the OpenAI API based on the provided messages.""" + + # Create a dictionary with the default arguments + args = self._get_base_args(generation_config) + + args["messages"] = messages + + # Conditionally add the 'functions' argument if it's not None + if generation_config.functions is not None: + args["functions"] = generation_config.functions + + args = {**args, **kwargs} + # Create the chat completion + return self.client.chat.completions.create(**args) + + def _get_base_args( + self, + generation_config: GenerationConfig, + ) -> dict: + """Get the base arguments for the OpenAI API.""" + + args = { + "model": generation_config.model, + "temperature": generation_config.temperature, + "top_p": generation_config.top_p, + "stream": generation_config.stream, + # TODO - We need to cap this to avoid potential errors when exceed max allowable context + "max_tokens": generation_config.max_tokens_to_sample, + } + + return args + + async def aget_completion( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> LLMChatCompletion: + if generation_config.stream: + raise ValueError( + "Stream must be set to False to use the `aget_completion` method." 
+ ) + return await self._aget_completion( + messages, generation_config, **kwargs + ) + + async def _aget_completion( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> Union[LLMChatCompletion, LLMChatCompletionChunk]: + """Asynchronously get a completion from the OpenAI API based on the provided messages.""" + + # Create a dictionary with the default arguments + args = self._get_base_args(generation_config) + + args["messages"] = messages + + # Conditionally add the 'functions' argument if it's not None + if generation_config.functions is not None: + args["functions"] = generation_config.functions + + args = {**args, **kwargs} + # Create the chat completion + return await self.client.chat.completions.create(**args) diff --git a/R2R/r2r/providers/vector_dbs/__init__.py b/R2R/r2r/providers/vector_dbs/__init__.py new file mode 100755 index 00000000..38ea0890 --- /dev/null +++ b/R2R/r2r/providers/vector_dbs/__init__.py @@ -0,0 +1,5 @@ +from .pgvector.pgvector_db import PGVectorDB + +__all__ = [ + "PGVectorDB", +] diff --git a/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py b/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py new file mode 100755 index 00000000..8cf728d1 --- /dev/null +++ b/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py @@ -0,0 +1,610 @@ +import json +import logging +import os +import time +from typing import Literal, Optional, Union + +from sqlalchemy import exc, text +from sqlalchemy.engine.url import make_url + +from r2r.base import ( + DocumentInfo, + UserStats, + VectorDBConfig, + VectorDBProvider, + VectorEntry, + VectorSearchResult, +) +from r2r.vecs.client import Client +from r2r.vecs.collection import Collection + +logger = logging.getLogger(__name__) + + +class PGVectorDB(VectorDBProvider): + def __init__(self, config: VectorDBConfig) -> None: + super().__init__(config) + try: + import r2r.vecs + except ImportError: + raise ValueError( + f"Error, PGVectorDB requires the vecs library. Please run `pip install vecs`." + ) + + # Check if a complete Postgres URI is provided + postgres_uri = self.config.extra_fields.get( + "postgres_uri" + ) or os.getenv("POSTGRES_URI") + + if postgres_uri: + # Log loudly that Postgres URI is being used + logger.warning("=" * 50) + logger.warning( + "ATTENTION: Using provided Postgres URI for connection" + ) + logger.warning("=" * 50) + + # Validate and use the provided URI + try: + parsed_uri = make_url(postgres_uri) + if not all([parsed_uri.username, parsed_uri.database]): + raise ValueError( + "The provided Postgres URI is missing required components." 
+                    )
+                DB_CONNECTION = postgres_uri
+
+                # Log the sanitized URI (without password)
+                sanitized_uri = parsed_uri.set(password="*****")
+                logger.info(f"Connecting using URI: {sanitized_uri}")
+            except Exception as e:
+                raise ValueError(f"Invalid Postgres URI provided: {e}")
+        else:
+            # Fall back to existing logic for individual connection parameters
+            user = self.config.extra_fields.get("user", None) or os.getenv(
+                "POSTGRES_USER"
+            )
+            password = self.config.extra_fields.get(
+                "password", None
+            ) or os.getenv("POSTGRES_PASSWORD")
+            host = self.config.extra_fields.get("host", None) or os.getenv(
+                "POSTGRES_HOST"
+            )
+            port = self.config.extra_fields.get("port", None) or os.getenv(
+                "POSTGRES_PORT"
+            )
+            db_name = self.config.extra_fields.get(
+                "db_name", None
+            ) or os.getenv("POSTGRES_DBNAME")
+
+            if not all([user, password, host, db_name]):
+                raise ValueError(
+                    "Error, please set the POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_HOST, POSTGRES_DBNAME environment variables or provide them in the config."
+                )
+
+            # Check if it's a Unix socket connection
+            if host.startswith("/") and not port:
+                DB_CONNECTION = (
+                    f"postgresql://{user}:{password}@/{db_name}?host={host}"
+                )
+                logger.info("Using Unix socket connection")
+            else:
+                DB_CONNECTION = (
+                    f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
+                )
+                logger.info("Using TCP connection")
+
+        # The rest of the initialization remains the same
+        try:
+            self.vx: Client = r2r.vecs.create_client(DB_CONNECTION)
+        except Exception as e:
+            raise ValueError(
+                f"Error {e} occurred while attempting to connect to the pgvector provider with {DB_CONNECTION}."
+            )
+
+        self.collection_name = self.config.extra_fields.get(
+            "vecs_collection"
+        ) or os.getenv("POSTGRES_VECS_COLLECTION")
+        if not self.collection_name:
+            raise ValueError(
+                "Error, please set a valid POSTGRES_VECS_COLLECTION environment variable or set a 'vecs_collection' in the 'vector_database' settings of your `config.json`."
+            )
+
+        self.collection: Optional[Collection] = None
+
+        logger.info(
+            f"Successfully initialized PGVectorDB with collection: {self.collection_name}"
+        )
+
+    def initialize_collection(self, dimension: int) -> None:
+        self.collection = self.vx.get_or_create_collection(
+            name=self.collection_name, dimension=dimension
+        )
+        self._create_document_info_table()
+        self._create_hybrid_search_function()
+
+    def _create_document_info_table(self):
+        with self.vx.Session() as sess:
+            with sess.begin():
+                try:
+                    # Enable uuid-ossp extension
+                    sess.execute(
+                        text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
+                    )
+                except exc.ProgrammingError as e:
+                    logger.error(f"Error enabling uuid-ossp extension: {e}")
+                    raise
+
+                # Create the table if it doesn't exist.
+                # Note: the full table name must be quoted as a single identifier.
+                create_table_query = f"""
+                CREATE TABLE IF NOT EXISTS "document_info_{self.collection_name}" (
+                    document_id UUID PRIMARY KEY,
+                    title TEXT,
+                    user_id UUID NULL,
+                    version TEXT,
+                    size_in_bytes INT,
+                    created_at TIMESTAMPTZ DEFAULT NOW(),
+                    updated_at TIMESTAMPTZ DEFAULT NOW(),
+                    metadata JSONB,
+                    status TEXT
+                );
+                """
+                sess.execute(text(create_table_query))
+
+                # Add the new column if it doesn't exist.
+                # information_schema stores the table name without quotes.
+                add_column_query = f"""
+                DO $$
+                BEGIN
+                    IF NOT EXISTS (
+                        SELECT 1
+                        FROM information_schema.columns
+                        WHERE table_name = 'document_info_{self.collection_name}'
+                        AND column_name = 'status'
+                    ) THEN
+                        ALTER TABLE "document_info_{self.collection_name}"
+                        ADD COLUMN status TEXT DEFAULT 'processing';
+                    END IF;
+                END $$;
+                """
+                sess.execute(text(add_column_query))
+
+                sess.commit()
+
+    def _create_hybrid_search_function(self):
+        hybrid_search_function = f"""
+        CREATE OR REPLACE FUNCTION hybrid_search_{self.collection_name}(
+            query_text TEXT,
+            query_embedding VECTOR(512),
+            match_limit INT,
+            full_text_weight FLOAT = 1,
+            semantic_weight FLOAT = 1,
+            rrf_k INT = 50,
+            filter_condition JSONB = NULL
+        )
+        RETURNS SETOF vecs."{self.collection_name}"
+        LANGUAGE sql
+        AS $$
+        WITH full_text AS (
+            SELECT
+                id,
+                ROW_NUMBER() OVER (ORDER BY ts_rank(to_tsvector('english', metadata->>'text'), websearch_to_tsquery(query_text)) DESC) AS rank_ix
+            FROM vecs."{self.collection_name}"
+            WHERE to_tsvector('english', metadata->>'text') @@ websearch_to_tsquery(query_text)
+            AND (filter_condition IS NULL OR (metadata @> filter_condition))
+            ORDER BY rank_ix
+            LIMIT LEAST(match_limit, 30) * 2
+        ),
+        semantic AS (
+            SELECT
+                id,
+                ROW_NUMBER() OVER (ORDER BY vec <#> query_embedding) AS rank_ix
+            FROM vecs."{self.collection_name}"
+            WHERE filter_condition IS NULL OR (metadata @> filter_condition)
+            ORDER BY rank_ix
+            LIMIT LEAST(match_limit, 30) * 2
+        )
+        SELECT
+            vecs."{self.collection_name}".*
+        FROM
+            full_text
+            FULL OUTER JOIN semantic
+                ON full_text.id = semantic.id
+            JOIN vecs."{self.collection_name}"
+                ON vecs."{self.collection_name}".id = COALESCE(full_text.id, semantic.id)
+        ORDER BY
+            COALESCE(1.0 / (rrf_k + full_text.rank_ix), 0.0) * full_text_weight +
+            COALESCE(1.0 / (rrf_k + semantic.rank_ix), 0.0) * semantic_weight
+            DESC
+        LIMIT
+            LEAST(match_limit, 30);
+        $$;
+        """
+        retry_attempts = 5
+        for attempt in range(retry_attempts):
+            try:
+                with self.vx.Session() as sess:
+                    # Acquire an advisory lock
+                    sess.execute(text("SELECT pg_advisory_lock(123456789)"))
+                    try:
+                        sess.execute(text(hybrid_search_function))
+                        sess.commit()
+                    finally:
+                        # Release the advisory lock
+                        sess.execute(
+                            text("SELECT pg_advisory_unlock(123456789)")
+                        )
+                break  # Break the loop if successful
+            except exc.InternalError as e:
+                if "tuple concurrently updated" in str(e):
+                    time.sleep(2**attempt)  # Exponential backoff
+                else:
+                    raise  # Re-raise the exception if it's not a concurrency issue
+        else:
+            raise RuntimeError(
+                "Failed to create hybrid search function after multiple attempts"
+            )
+
+    def copy(self, entry: VectorEntry, commit=True) -> None:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `copy`."
+            )
+
+        serializable_entry = entry.to_serializable()
+
+        self.collection.copy(
+            records=[
+                (
+                    serializable_entry["id"],
+                    serializable_entry["vector"],
+                    serializable_entry["metadata"],
+                )
+            ]
+        )
+
+    def copy_entries(
+        self, entries: list[VectorEntry], commit: bool = True
+    ) -> None:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `copy_entries`."
+            )
+
+        self.collection.copy(
+            records=[
+                (
+                    str(entry.id),
+                    entry.vector.data,
+                    entry.to_serializable()["metadata"],
+                )
+                for entry in entries
+            ]
+        )
+
+    def upsert(self, entry: VectorEntry, commit=True) -> None:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `upsert`."
+            )
+
+        self.collection.upsert(
+            records=[
+                (
+                    str(entry.id),
+                    entry.vector.data,
+                    entry.to_serializable()["metadata"],
+                )
+            ]
+        )
+
+    def upsert_entries(
+        self, entries: list[VectorEntry], commit: bool = True
+    ) -> None:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `upsert_entries`."
+            )
+
+        self.collection.upsert(
+            records=[
+                (
+                    str(entry.id),
+                    entry.vector.data,
+                    entry.to_serializable()["metadata"],
+                )
+                for entry in entries
+            ]
+        )
+
+    def search(
+        self,
+        query_vector: list[float],
+        filters: dict[str, Union[bool, int, str]] = {},
+        limit: int = 10,
+        *args,
+        **kwargs,
+    ) -> list[VectorSearchResult]:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `search`."
+            )
+        measure = kwargs.get("measure", "cosine_distance")
+        mapped_filters = {
+            key: {"$eq": value} for key, value in filters.items()
+        }
+
+        return [
+            VectorSearchResult(id=ele[0], score=float(1 - ele[1]), metadata=ele[2])  # type: ignore
+            for ele in self.collection.query(
+                data=query_vector,
+                limit=limit,
+                filters=mapped_filters,
+                measure=measure,
+                include_value=True,
+                include_metadata=True,
+            )
+        ]
+
+    def hybrid_search(
+        self,
+        query_text: str,
+        query_vector: list[float],
+        limit: int = 10,
+        filters: Optional[dict[str, Union[bool, int, str]]] = None,
+        # Hybrid search parameters
+        full_text_weight: float = 1.0,
+        semantic_weight: float = 1.0,
+        rrf_k: int = 20,  # typical value is ~2x the number of results you want
+        *args,
+        **kwargs,
+    ) -> list[VectorSearchResult]:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `hybrid_search`."
+            )
+
+        # Convert filters to a JSON-compatible format
+        filter_condition = None
+        if filters:
+            filter_condition = json.dumps(filters)
+
+        query = text(
+            f"""
+            SELECT * FROM hybrid_search_{self.collection_name}(
+                cast(:query_text as TEXT), cast(:query_embedding as VECTOR), cast(:match_limit as INT),
+                cast(:full_text_weight as FLOAT), cast(:semantic_weight as FLOAT), cast(:rrf_k as INT),
+                cast(:filter_condition as JSONB)
+            )
+            """
+        )
+
+        params = {
+            "query_text": str(query_text),
+            "query_embedding": list(query_vector),
+            "match_limit": limit,
+            "full_text_weight": full_text_weight,
+            "semantic_weight": semantic_weight,
+            "rrf_k": rrf_k,
+            "filter_condition": filter_condition,
+        }
+
+        with self.vx.Session() as session:
+            result = session.execute(query, params).fetchall()
+            return [
+                VectorSearchResult(id=row[0], score=1.0, metadata=row[-1])
+                for row in result
+            ]
+
+    def create_index(self, index_type, column_name, index_options):
+        pass
+
+    def delete_by_metadata(
+        self,
+        metadata_fields: list[str],
+        metadata_values: list[Union[bool, int, str]],
+        logic: Literal["AND", "OR"] = "AND",
+    ) -> list[str]:
+        if logic == "OR":
+            raise ValueError(
+                "OR logic is still being tested before official support for `delete_by_metadata` in pgvector."
+            )
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `delete_by_metadata`."
+            )
+
+        if len(metadata_fields) != len(metadata_values):
+            raise ValueError(
+                "The number of metadata fields must match the number of metadata values."
+            )
+
+        # Construct the filter
+        if logic == "AND":
+            filters = {
+                k: {"$eq": v} for k, v in zip(metadata_fields, metadata_values)
+            }
+        else:  # OR logic
+            # TODO - Test 'or' logic and remove check above
+            filters = {
+                "$or": [
+                    {k: {"$eq": v}}
+                    for k, v in zip(metadata_fields, metadata_values)
+                ]
+            }
+        return self.collection.delete(filters=filters)
+
+    def get_metadatas(
+        self,
+        metadata_fields: list[str],
+        filter_field: Optional[str] = None,
+        filter_value: Optional[Union[bool, int, str]] = None,
+    ) -> list[dict]:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `get_metadatas`."
+            )
+
+        results = {tuple(metadata_fields): {}}
+        for field in metadata_fields:
+            unique_values = self.collection.get_unique_metadata_values(
+                field=field,
+                filter_field=filter_field,
+                filter_value=filter_value,
+            )
+            for value in unique_values:
+                if value not in results:
+                    results[value] = {}
+                results[value][field] = value
+
+        return [
+            results[key] for key in results if key != tuple(metadata_fields)
+        ]
+
+    def upsert_documents_overview(
+        self, documents_overview: list[DocumentInfo]
+    ) -> None:
+        for document_info in documents_overview:
+            db_entry = document_info.convert_to_db_entry()
+
+            # Convert 'None' string to None type for user_id
+            if db_entry["user_id"] == "None":
+                db_entry["user_id"] = None
+
+            query = text(
+                f"""
+                INSERT INTO "document_info_{self.collection_name}" (document_id, title, user_id, version, created_at, updated_at, size_in_bytes, metadata, status)
+                VALUES (:document_id, :title, :user_id, :version, :created_at, :updated_at, :size_in_bytes, :metadata, :status)
+                ON CONFLICT (document_id) DO UPDATE SET
+                    title = EXCLUDED.title,
+                    user_id = EXCLUDED.user_id,
+                    version = EXCLUDED.version,
+                    updated_at = EXCLUDED.updated_at,
+                    size_in_bytes = EXCLUDED.size_in_bytes,
+                    metadata = EXCLUDED.metadata,
+                    status = EXCLUDED.status;
+                """
+            )
+            with self.vx.Session() as sess:
+                sess.execute(query, db_entry)
+                sess.commit()
+
+    def delete_from_documents_overview(
+        self, document_id: str, version: Optional[str] = None
+    ) -> None:
+        query = f"""
+        DELETE FROM "document_info_{self.collection_name}"
+        WHERE document_id = :document_id
+        """
+        params = {"document_id": document_id}
+
+        if version is not None:
+            query += " AND version = :version"
+            params["version"] = version
+
+        with self.vx.Session() as sess:
+            with sess.begin():
+                sess.execute(text(query), params)
+                sess.commit()
+
+    def get_documents_overview(
+        self,
+        filter_document_ids: Optional[list[str]] = None,
+        filter_user_ids: Optional[list[str]] = None,
+    ):
+        conditions = []
+        params = {}
+
+        if filter_document_ids:
+            placeholders = ", ".join(
+                f":doc_id_{i}" for i in range(len(filter_document_ids))
+            )
+            conditions.append(f"document_id IN ({placeholders})")
+            params.update(
+                {
+                    f"doc_id_{i}": str(document_id)
+                    for i, document_id in enumerate(filter_document_ids)
+                }
+            )
+        if filter_user_ids:
+            placeholders = ", ".join(
+                f":user_id_{i}" for i in range(len(filter_user_ids))
+            )
+            conditions.append(f"user_id IN ({placeholders})")
+            params.update(
+                {
+                    f"user_id_{i}": str(user_id)
+                    for i, user_id in enumerate(filter_user_ids)
+                }
+            )
+
+        query = f"""
+        SELECT document_id, title, user_id, version, size_in_bytes, created_at, updated_at, metadata, status
+        FROM "document_info_{self.collection_name}"
+        """
+        if conditions:
+            query += " WHERE " + " AND ".join(conditions)
+
+        with self.vx.Session() as sess:
+            results = sess.execute(text(query), params).fetchall()
+            return [
+                DocumentInfo(
+                    document_id=row[0],
+                    title=row[1],
+                    user_id=row[2],
+                    version=row[3],
+                    size_in_bytes=row[4],
+                    created_at=row[5],
+                    updated_at=row[6],
+                    metadata=row[7],
+                    status=row[8],
+                )
+                for row in results
+            ]
+
+    def get_document_chunks(self, document_id: str) -> list[dict]:
+        if not self.collection:
+            raise ValueError("Collection is not initialized.")
+
+        table_name = self.collection.table.name
+        query = text(
+            f"""
+            SELECT metadata
+            FROM vecs."{table_name}"
+            WHERE metadata->>'document_id' = :document_id
+            ORDER BY CAST(metadata->>'chunk_order' AS INTEGER)
+            """
+        )
+
+        params = {"document_id": document_id}
+
+        with self.vx.Session() as sess:
+            results = sess.execute(query, params).fetchall()
+            return [result[0] for result in results]
+
+    def get_users_overview(self, user_ids: Optional[list[str]] = None):
+        user_ids_condition = ""
+        params = {}
+        if user_ids:
+            user_ids_condition = "WHERE user_id IN :user_ids"
+            params["user_ids"] = tuple(
+                map(str, user_ids)
+            )  # Convert UUIDs to strings
+
+        query = f"""
+        SELECT user_id, COUNT(document_id) AS num_files, SUM(size_in_bytes) AS total_size_in_bytes, ARRAY_AGG(document_id) AS document_ids
+        FROM "document_info_{self.collection_name}"
+        {user_ids_condition}
+        GROUP BY user_id
+        """
+
+        with self.vx.Session() as sess:
+            results = sess.execute(text(query), params).fetchall()
+            return [
+                UserStats(
+                    user_id=row[0],
+                    num_files=row[1],
+                    total_size_in_bytes=row[2],
+                    document_ids=row[3],
+                )
+                for row in results
+                if row[0] is not None
+            ]
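
For reference (not part of the diff): a minimal usage sketch of the PGVectorDB provider added above. The `VectorDBConfig` constructor lives in `r2r.base` and is not shown in this change, so its construction below is an assumption; the connection string, collection name, 512-dimension embedding, and filter value are placeholders only.

    import os

    from r2r.base import VectorDBConfig  # fields/constructor assumed; defined outside this diff
    from r2r.providers.vector_dbs import PGVectorDB

    # The provider reads these when the config carries no explicit connection fields.
    os.environ["POSTGRES_URI"] = "postgresql://postgres:postgres@localhost:5432/r2r"  # placeholder
    os.environ["POSTGRES_VECS_COLLECTION"] = "demo_vecs"  # placeholder

    config = VectorDBConfig(provider="pgvector")  # hypothetical construction
    db = PGVectorDB(config)

    # Must be called before copy/upsert/search; also creates the
    # document_info table and the hybrid_search_<collection> SQL function.
    db.initialize_collection(dimension=512)

    query_vector = [0.0] * 512  # replace with a real embedding of the query text
    semantic_hits = db.search(query_vector=query_vector, limit=5)
    hybrid_hits = db.hybrid_search(
        query_text="termination clauses",
        query_vector=query_vector,
        limit=5,
        filters={"document_id": "00000000-0000-0000-0000-000000000000"},  # placeholder
    )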