author    S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/providers
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download  gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to 'R2R/r2r/providers')
-rwxr-xr-x  R2R/r2r/providers/__init__.py                                                      0
-rwxr-xr-x  R2R/r2r/providers/embeddings/__init__.py                                          11
-rwxr-xr-x  R2R/r2r/providers/embeddings/ollama/ollama_base.py                               156
-rwxr-xr-x  R2R/r2r/providers/embeddings/openai/openai_base.py                               200
-rwxr-xr-x  R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py   160
-rwxr-xr-x  R2R/r2r/providers/eval/__init__.py                                                 3
-rwxr-xr-x  R2R/r2r/providers/eval/llm/base_llm_eval.py                                       84
-rwxr-xr-x  R2R/r2r/providers/kg/__init__.py                                                   3
-rwxr-xr-x  R2R/r2r/providers/kg/neo4j/base_neo4j.py                                          983
-rwxr-xr-x  R2R/r2r/providers/llms/__init__.py                                                 7
-rwxr-xr-x  R2R/r2r/providers/llms/litellm/base_litellm.py                                    142
-rwxr-xr-x  R2R/r2r/providers/llms/openai/base_openai.py                                      144
-rwxr-xr-x  R2R/r2r/providers/vector_dbs/__init__.py                                           5
-rwxr-xr-x  R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py                              610
14 files changed, 2508 insertions, 0 deletions
diff --git a/R2R/r2r/providers/__init__.py b/R2R/r2r/providers/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/providers/__init__.py
diff --git a/R2R/r2r/providers/embeddings/__init__.py b/R2R/r2r/providers/embeddings/__init__.py
new file mode 100755
index 00000000..6b0c8b83
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/__init__.py
@@ -0,0 +1,11 @@
+from .ollama.ollama_base import OllamaEmbeddingProvider
+from .openai.openai_base import OpenAIEmbeddingProvider
+from .sentence_transformer.sentence_transformer_base import (
+ SentenceTransformerEmbeddingProvider,
+)
+
+__all__ = [
+ "OllamaEmbeddingProvider",
+ "OpenAIEmbeddingProvider",
+ "SentenceTransformerEmbeddingProvider",
+]
diff --git a/R2R/r2r/providers/embeddings/ollama/ollama_base.py b/R2R/r2r/providers/embeddings/ollama/ollama_base.py
new file mode 100755
index 00000000..31a8c717
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/ollama/ollama_base.py
@@ -0,0 +1,156 @@
+import asyncio
+import logging
+import os
+import random
+from typing import Any
+
+from ollama import AsyncClient, Client
+
+from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
+class OllamaEmbeddingProvider(EmbeddingProvider):
+ def __init__(self, config: EmbeddingConfig):
+ super().__init__(config)
+ provider = config.provider
+ if not provider:
+ raise ValueError(
+ "Must set provider in order to initialize `OllamaEmbeddingProvider`."
+ )
+ if provider != "ollama":
+ raise ValueError(
+ "OllamaEmbeddingProvider must be initialized with provider `ollama`."
+ )
+ if config.rerank_model:
+ raise ValueError(
+ "OllamaEmbeddingProvider does not support separate reranking."
+ )
+
+ self.base_model = config.base_model
+ self.base_dimension = config.base_dimension
+ self.base_url = os.getenv("OLLAMA_API_BASE")
+ logger.info(
+ f"Using Ollama API base URL: {self.base_url or 'http://127.0.0.1:11434'}"
+ )
+ self.client = Client(host=self.base_url)
+ self.aclient = AsyncClient(host=self.base_url)
+
+ self.request_queue = asyncio.Queue()
+ self.max_retries = 2
+ self.initial_backoff = 1
+ self.max_backoff = 60
+ self.concurrency_limit = 10
+ self.semaphore = asyncio.Semaphore(self.concurrency_limit)
+
+ async def process_queue(self):
+ while True:
+ task = await self.request_queue.get()
+ try:
+ result = await self.execute_task_with_backoff(task)
+ task["future"].set_result(result)
+ except Exception as e:
+ task["future"].set_exception(e)
+ finally:
+ self.request_queue.task_done()
+
+ async def execute_task_with_backoff(self, task: dict[str, Any]):
+ retries = 0
+ backoff = self.initial_backoff
+ while retries < self.max_retries:
+ try:
+ async with self.semaphore:
+ response = await asyncio.wait_for(
+ self.aclient.embeddings(
+ prompt=task["text"], model=self.base_model
+ ),
+ timeout=30,
+ )
+ return response["embedding"]
+ except Exception as e:
+ logger.warning(
+ f"Request failed (attempt {retries + 1}): {str(e)}"
+ )
+ retries += 1
+ if retries == self.max_retries:
+ raise Exception(
+ f"Max retries reached. Last error: {str(e)}"
+ )
+ await asyncio.sleep(backoff + random.uniform(0, 1))
+ backoff = min(backoff * 2, self.max_backoff)
+
+ def get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
+ try:
+ response = self.client.embeddings(
+ prompt=text, model=self.base_model
+ )
+ return response["embedding"]
+ except Exception as e:
+ logger.error(f"Error getting embedding: {str(e)}")
+ raise
+
+ def get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[list[float]]:
+ return [self.get_embedding(text, stage) for text in texts]
+
+ async def async_get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
+ queue_processor = asyncio.create_task(self.process_queue())
+ futures = []
+ for text in texts:
+ future = asyncio.Future()
+ await self.request_queue.put({"text": text, "future": future})
+ futures.append(future)
+
+ try:
+ results = await asyncio.gather(*futures, return_exceptions=True)
+ # Check if any result is an exception and raise it
+            exceptions = [r for r in results if isinstance(r, Exception)]
+            if exceptions:
+                raise Exception(
+                    "Embedding generation failed for one or more texts."
+                )
+ return results
+ except Exception as e:
+ logger.error(f"Embedding generation failed: {str(e)}")
+ raise
+ finally:
+ await self.request_queue.join()
+ queue_processor.cancel()
+
+ def rerank(
+ self,
+ query: str,
+ results: list[VectorSearchResult],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+ limit: int = 10,
+ ) -> list[VectorSearchResult]:
+ return results[:limit]
+
+ def tokenize_string(
+ self, text: str, model: str, stage: EmbeddingProvider.PipeStage
+ ) -> list[int]:
+ raise NotImplementedError(
+ "Tokenization is not supported by OllamaEmbeddingProvider."
+ )
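A minimal usage sketch for `OllamaEmbeddingProvider`, assuming `EmbeddingConfig` is a pydantic-style model that accepts the fields the provider reads (`provider`, `base_model`, `base_dimension`) and that a local Ollama server is reachable; the model name is a placeholder.

```python
import asyncio

from r2r.base import EmbeddingConfig
from r2r.providers.embeddings import OllamaEmbeddingProvider

# Assumed config shape; base_model/base_dimension must match the Ollama model you pulled.
config = EmbeddingConfig(
    provider="ollama",
    base_model="mxbai-embed-large",
    base_dimension=1024,
)
provider = OllamaEmbeddingProvider(config)

# Synchronous, single text
vector = provider.get_embedding("What is R2R?")

# Asynchronous batch: texts flow through the internal queue with retry/backoff
vectors = asyncio.run(provider.async_get_embeddings(["chunk one", "chunk two"]))
print(len(vector), len(vectors))
```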
diff --git a/R2R/r2r/providers/embeddings/openai/openai_base.py b/R2R/r2r/providers/embeddings/openai/openai_base.py
new file mode 100755
index 00000000..7e7d32aa
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/openai/openai_base.py
@@ -0,0 +1,200 @@
+import logging
+import os
+
+from openai import AsyncOpenAI, AuthenticationError, OpenAI
+
+from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIEmbeddingProvider(EmbeddingProvider):
+ MODEL_TO_TOKENIZER = {
+ "text-embedding-ada-002": "cl100k_base",
+ "text-embedding-3-small": "cl100k_base",
+ "text-embedding-3-large": "cl100k_base",
+ }
+ MODEL_TO_DIMENSIONS = {
+ "text-embedding-ada-002": [1536],
+ "text-embedding-3-small": [512, 1536],
+ "text-embedding-3-large": [256, 1024, 3072],
+ }
+
+ def __init__(self, config: EmbeddingConfig):
+ super().__init__(config)
+ provider = config.provider
+ if not provider:
+ raise ValueError(
+ "Must set provider in order to initialize OpenAIEmbeddingProvider."
+ )
+
+ if provider != "openai":
+ raise ValueError(
+ "OpenAIEmbeddingProvider must be initialized with provider `openai`."
+ )
+ if not os.getenv("OPENAI_API_KEY"):
+ raise ValueError(
+ "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
+ )
+ self.client = OpenAI()
+ self.async_client = AsyncOpenAI()
+
+ if config.rerank_model:
+ raise ValueError(
+ "OpenAIEmbeddingProvider does not support separate reranking."
+ )
+ self.base_model = config.base_model
+ self.base_dimension = config.base_dimension
+
+ if self.base_model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
+ raise ValueError(
+ f"OpenAI embedding model {self.base_model} not supported."
+ )
+ if (
+ self.base_dimension
+ and self.base_dimension
+ not in OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model]
+ ):
+ raise ValueError(
+ f"Dimensions {self.dimension} for {self.base_model} are not supported"
+ )
+
+ if not self.base_model or not self.base_dimension:
+ raise ValueError(
+ "Must set base_model and base_dimension in order to initialize OpenAIEmbeddingProvider."
+ )
+
+ if config.rerank_model:
+ raise ValueError(
+ "OpenAIEmbeddingProvider does not support separate reranking."
+ )
+
+ def get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError(
+ "OpenAIEmbeddingProvider only supports search stage."
+ )
+
+ try:
+ return (
+ self.client.embeddings.create(
+ input=[text],
+ model=self.base_model,
+ dimensions=self.base_dimension
+ or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+ self.base_model
+ ][-1],
+ )
+ .data[0]
+ .embedding
+ )
+ except AuthenticationError as e:
+ raise ValueError(
+ "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+ ) from e
+
+ async def async_get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError(
+ "OpenAIEmbeddingProvider only supports search stage."
+ )
+
+ try:
+ response = await self.async_client.embeddings.create(
+ input=[text],
+ model=self.base_model,
+ dimensions=self.base_dimension
+ or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+ self.base_model
+ ][-1],
+ )
+ return response.data[0].embedding
+ except AuthenticationError as e:
+ raise ValueError(
+ "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+ ) from e
+
+ def get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError(
+ "OpenAIEmbeddingProvider only supports search stage."
+ )
+
+ try:
+ return [
+ ele.embedding
+ for ele in self.client.embeddings.create(
+ input=texts,
+ model=self.base_model,
+ dimensions=self.base_dimension
+ or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+ self.base_model
+ ][-1],
+ ).data
+ ]
+ except AuthenticationError as e:
+ raise ValueError(
+ "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+ ) from e
+
+ async def async_get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError(
+ "OpenAIEmbeddingProvider only supports search stage."
+ )
+
+ try:
+ response = await self.async_client.embeddings.create(
+ input=texts,
+ model=self.base_model,
+ dimensions=self.base_dimension
+ or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+ self.base_model
+ ][-1],
+ )
+ return [ele.embedding for ele in response.data]
+ except AuthenticationError as e:
+ raise ValueError(
+ "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+ ) from e
+
+ def rerank(
+ self,
+ query: str,
+ results: list[VectorSearchResult],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+ limit: int = 10,
+ ):
+ return results[:limit]
+
+ def tokenize_string(self, text: str, model: str) -> list[int]:
+ try:
+ import tiktoken
+ except ImportError:
+ raise ValueError(
+ "Must download tiktoken library to run `tokenize_string`."
+ )
+ # tiktoken encoding -
+ # cl100k_base - gpt-4, gpt-3.5-turbo, text-embedding-ada-002, text-embedding-3-small, text-embedding-3-large
+ if model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
+ raise ValueError(f"OpenAI embedding model {model} not supported.")
+ encoding = tiktoken.get_encoding(
+ OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER[model]
+ )
+ return encoding.encode(text)
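A short sketch of `OpenAIEmbeddingProvider`, assuming `EmbeddingConfig` accepts these keyword fields and `OPENAI_API_KEY` is exported; `tokenize_string` additionally requires `tiktoken` to be installed.

```python
from r2r.base import EmbeddingConfig
from r2r.providers.embeddings import OpenAIEmbeddingProvider

config = EmbeddingConfig(
    provider="openai",
    base_model="text-embedding-3-small",
    base_dimension=512,  # must appear in MODEL_TO_DIMENSIONS for the chosen model
)
provider = OpenAIEmbeddingProvider(config)

vectors = provider.get_embeddings(["first chunk", "second chunk"])
tokens = provider.tokenize_string("first chunk", "text-embedding-3-small")
print(len(vectors[0]), len(tokens))
```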
diff --git a/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py b/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py
new file mode 100755
index 00000000..3316cb60
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py
@@ -0,0 +1,160 @@
+import logging
+
+from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
+class SentenceTransformerEmbeddingProvider(EmbeddingProvider):
+ def __init__(
+ self,
+ config: EmbeddingConfig,
+ ):
+ super().__init__(config)
+ logger.info(
+ "Initializing `SentenceTransformerEmbeddingProvider` with separate models for search and rerank."
+ )
+ provider = config.provider
+ if not provider:
+ raise ValueError(
+ "Must set provider in order to initialize SentenceTransformerEmbeddingProvider."
+ )
+ if provider != "sentence-transformers":
+ raise ValueError(
+ "SentenceTransformerEmbeddingProvider must be initialized with provider `sentence-transformers`."
+ )
+ try:
+ from sentence_transformers import CrossEncoder, SentenceTransformer
+
+ self.SentenceTransformer = SentenceTransformer
+ # TODO - Modify this to be configurable, as `bge-reranker-large` is a `SentenceTransformer` model
+ self.CrossEncoder = CrossEncoder
+ except ImportError as e:
+ raise ValueError(
+ "Must download sentence-transformers library to run `SentenceTransformerEmbeddingProvider`."
+ ) from e
+
+ # Initialize separate models for search and rerank
+ self.do_search = False
+ self.do_rerank = False
+
+ self.search_encoder = self._init_model(
+ config, EmbeddingProvider.PipeStage.BASE
+ )
+ self.rerank_encoder = self._init_model(
+ config, EmbeddingProvider.PipeStage.RERANK
+ )
+
+    def _init_model(self, config: EmbeddingConfig, stage: EmbeddingProvider.PipeStage):
+ stage_name = stage.name.lower()
+ model = config.dict().get(f"{stage_name}_model", None)
+ dimension = config.dict().get(f"{stage_name}_dimension", None)
+
+ transformer_type = config.dict().get(
+ f"{stage_name}_transformer_type", "SentenceTransformer"
+ )
+
+ if stage == EmbeddingProvider.PipeStage.BASE:
+ self.do_search = True
+ # Check if a model is set for the stage
+ if not (model and dimension and transformer_type):
+ raise ValueError(
+ f"Must set {stage.name.lower()}_model and {stage.name.lower()}_dimension for {stage} stage in order to initialize SentenceTransformerEmbeddingProvider."
+ )
+
+ if stage == EmbeddingProvider.PipeStage.RERANK:
+ # Check if a model is set for the stage
+ if not (model and dimension and transformer_type):
+ return None
+
+ self.do_rerank = True
+ if transformer_type == "SentenceTransformer":
+ raise ValueError(
+ f"`SentenceTransformer` models are not yet supported for {stage} stage in SentenceTransformerEmbeddingProvider."
+ )
+
+ # Save the model_key and dimension into instance variables
+ setattr(self, f"{stage_name}_model", model)
+ setattr(self, f"{stage_name}_dimension", dimension)
+ setattr(self, f"{stage_name}_transformer_type", transformer_type)
+
+ # Initialize the model
+ encoder = (
+ self.SentenceTransformer(
+ model, truncate_dim=dimension, trust_remote_code=True
+ )
+ if transformer_type == "SentenceTransformer"
+ else self.CrossEncoder(model, trust_remote_code=True)
+ )
+ return encoder
+
+ def get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError("`get_embedding` only supports `SEARCH` stage.")
+ if not self.do_search:
+ raise ValueError(
+ "`get_embedding` can only be called for the search stage if a search model is set."
+ )
+ encoder = self.search_encoder
+ return encoder.encode([text]).tolist()[0]
+
+ def get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.PipeStage.BASE:
+ raise ValueError("`get_embeddings` only supports `SEARCH` stage.")
+ if not self.do_search:
+ raise ValueError(
+ "`get_embeddings` can only be called for the search stage if a search model is set."
+ )
+        encoder = self.search_encoder
+ return encoder.encode(texts).tolist()
+
+ def rerank(
+ self,
+ query: str,
+ results: list[VectorSearchResult],
+ stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+ limit: int = 10,
+ ) -> list[VectorSearchResult]:
+ if stage != EmbeddingProvider.PipeStage.RERANK:
+ raise ValueError("`rerank` only supports `RERANK` stage.")
+ if not self.do_rerank:
+ return results[:limit]
+
+ from copy import copy
+
+ texts = copy([doc.metadata["text"] for doc in results])
+ # Use the rank method from the rerank_encoder, which is a CrossEncoder model
+ reranked_scores = self.rerank_encoder.rank(
+ query, texts, return_documents=False, top_k=limit
+ )
+ # Map the reranked scores back to the original documents
+ reranked_results = []
+ for score in reranked_scores:
+ corpus_id = score["corpus_id"]
+ new_result = results[corpus_id]
+ new_result.score = float(score["score"])
+ reranked_results.append(new_result)
+
+ # Sort the documents by the new scores in descending order
+ reranked_results.sort(key=lambda doc: doc.score, reverse=True)
+ return reranked_results
+
+    def tokenize_string(
+        self,
+        text: str,
+        model: str,
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[int]:
+ raise ValueError(
+ "SentenceTransformerEmbeddingProvider does not support tokenize_string."
+ )
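A sketch of configuring both stages of `SentenceTransformerEmbeddingProvider`, assuming `EmbeddingConfig` tolerates the extra `*_transformer_type` fields that `_init_model` reads via `config.dict()`; the model names and rerank dimension are illustrative placeholders.

```python
from r2r.base import EmbeddingConfig
from r2r.providers.embeddings import SentenceTransformerEmbeddingProvider

config = EmbeddingConfig(
    provider="sentence-transformers",
    base_model="all-MiniLM-L6-v2",
    base_dimension=384,
    rerank_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
    rerank_dimension=1,
    rerank_transformer_type="CrossEncoder",  # the rerank stage rejects plain SentenceTransformer models
)
provider = SentenceTransformerEmbeddingProvider(config)

query_vector = provider.get_embedding("what is a knowledge graph?")
# rerank() re-scores VectorSearchResult objects with the CrossEncoder and returns
# them sorted by the new scores; without a rerank model it falls back to a slice.
```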
diff --git a/R2R/r2r/providers/eval/__init__.py b/R2R/r2r/providers/eval/__init__.py
new file mode 100755
index 00000000..3f5e1b51
--- /dev/null
+++ b/R2R/r2r/providers/eval/__init__.py
@@ -0,0 +1,3 @@
+from .llm.base_llm_eval import LLMEvalProvider
+
+__all__ = ["LLMEvalProvider"]
diff --git a/R2R/r2r/providers/eval/llm/base_llm_eval.py b/R2R/r2r/providers/eval/llm/base_llm_eval.py
new file mode 100755
index 00000000..7c573a34
--- /dev/null
+++ b/R2R/r2r/providers/eval/llm/base_llm_eval.py
@@ -0,0 +1,84 @@
+from fractions import Fraction
+from typing import Union
+
+from r2r import EvalConfig, EvalProvider, LLMProvider, PromptProvider
+from r2r.base.abstractions.llm import GenerationConfig
+
+
+class LLMEvalProvider(EvalProvider):
+ def __init__(
+ self,
+ config: EvalConfig,
+ llm_provider: LLMProvider,
+ prompt_provider: PromptProvider,
+ ):
+ super().__init__(config)
+
+ self.llm_provider = llm_provider
+ self.prompt_provider = prompt_provider
+
+ def _calc_query_context_relevancy(self, query: str, context: str) -> float:
+ system_prompt = self.prompt_provider.get_prompt("default_system")
+ eval_prompt = self.prompt_provider.get_prompt(
+ "rag_context_eval", {"query": query, "context": context}
+ )
+ response = self.llm_provider.get_completion(
+ self.prompt_provider._get_message_payload(
+ system_prompt, eval_prompt
+ ),
+ self.eval_generation_config,
+ )
+ response_text = response.choices[0].message.content
+ fraction = (
+ response_text
+ # Get the fraction in the returned tuple
+ .split(",")[-1][:-1]
+ # Remove any quotes and spaces
+ .replace("'", "")
+ .replace('"', "")
+ .strip()
+ )
+ return float(Fraction(fraction))
+
+ def _calc_answer_grounding(
+ self, query: str, context: str, answer: str
+ ) -> float:
+ system_prompt = self.prompt_provider.get_prompt("default_system")
+ eval_prompt = self.prompt_provider.get_prompt(
+ "rag_answer_eval",
+ {"query": query, "context": context, "answer": answer},
+ )
+ response = self.llm_provider.get_completion(
+ self.prompt_provider._get_message_payload(
+ system_prompt, eval_prompt
+ ),
+ self.eval_generation_config,
+ )
+ response_text = response.choices[0].message.content
+ fraction = (
+ response_text
+ # Get the fraction in the returned tuple
+ .split(",")[-1][:-1]
+ # Remove any quotes and spaces
+ .replace("'", "")
+ .replace('"', "")
+ .strip()
+ )
+ return float(Fraction(fraction))
+
+ def _evaluate(
+ self,
+ query: str,
+ context: str,
+ answer: str,
+ eval_generation_config: GenerationConfig,
+ ) -> dict[str, dict[str, Union[str, float]]]:
+ self.eval_generation_config = eval_generation_config
+ query_context_relevancy = self._calc_query_context_relevancy(
+ query, context
+ )
+ answer_grounding = self._calc_answer_grounding(query, context, answer)
+ return {
+ "query_context_relevancy": query_context_relevancy,
+ "answer_grounding": answer_grounding,
+ }
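The two `_calc_*` helpers assume the eval prompt returns a tuple-like string whose last element is a fraction score. A standalone illustration of that parsing step (no LLM call, hypothetical model output):

```python
from fractions import Fraction

# Hypothetical model output in the expected format: (reasoning, score)
response_text = "('the context partially answers the query', '4/5')"

fraction = (
    response_text.split(",")[-1][:-1]  # take the last tuple element, drop the trailing ')'
    .replace("'", "")
    .replace('"', "")
    .strip()
)
print(float(Fraction(fraction)))  # 0.8
```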
diff --git a/R2R/r2r/providers/kg/__init__.py b/R2R/r2r/providers/kg/__init__.py
new file mode 100755
index 00000000..36bc79a2
--- /dev/null
+++ b/R2R/r2r/providers/kg/__init__.py
@@ -0,0 +1,3 @@
+from .neo4j.base_neo4j import Neo4jKGProvider
+
+__all__ = ["Neo4jKGProvider"]
diff --git a/R2R/r2r/providers/kg/neo4j/base_neo4j.py b/R2R/r2r/providers/kg/neo4j/base_neo4j.py
new file mode 100755
index 00000000..9ede2b85
--- /dev/null
+++ b/R2R/r2r/providers/kg/neo4j/base_neo4j.py
@@ -0,0 +1,983 @@
+# abstractions are taken from LlamaIndex
+# Neo4jKGProvider is almost entirely taken from LlamaIndex Neo4jPropertyGraphStore
+# https://github.com/run-llama/llama_index
+import json
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+from r2r.base import (
+ EntityType,
+ KGConfig,
+ KGProvider,
+ PromptProvider,
+ format_entity_types,
+ format_relations,
+)
+from r2r.base.abstractions.llama_abstractions import (
+ LIST_LIMIT,
+ ChunkNode,
+ EntityNode,
+ LabelledNode,
+ PropertyGraphStore,
+ Relation,
+ Triplet,
+ VectorStoreQuery,
+ clean_string_values,
+ value_sanitize,
+)
+
+
+def remove_empty_values(input_dict):
+ """
+ Remove entries with empty values from the dictionary.
+
+ Parameters:
+ input_dict (dict): The dictionary from which empty values need to be removed.
+
+ Returns:
+ dict: A new dictionary with all empty values removed.
+ """
+ # Create a new dictionary excluding empty values
+ return {key: value for key, value in input_dict.items() if value}
+
+
+BASE_ENTITY_LABEL = "__Entity__"
+EXCLUDED_LABELS = ["_Bloom_Perspective_", "_Bloom_Scene_"]
+EXCLUDED_RELS = ["_Bloom_HAS_SCENE_"]
+EXHAUSTIVE_SEARCH_LIMIT = 10000
+# Threshold for returning all available prop values in graph schema
+DISTINCT_VALUE_LIMIT = 10
+
+node_properties_query = """
+CALL apoc.meta.data()
+YIELD label, other, elementType, type, property
+WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
+ AND NOT label IN $EXCLUDED_LABELS
+WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
+RETURN {labels: nodeLabels, properties: properties} AS output
+
+"""
+
+rel_properties_query = """
+CALL apoc.meta.data()
+YIELD label, other, elementType, type, property
+WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship"
+ AND NOT label in $EXCLUDED_LABELS
+WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
+RETURN {type: nodeLabels, properties: properties} AS output
+"""
+
+rel_query = """
+CALL apoc.meta.data()
+YIELD label, other, elementType, type, property
+WHERE type = "RELATIONSHIP" AND elementType = "node"
+UNWIND other AS other_node
+WITH * WHERE NOT label IN $EXCLUDED_LABELS
+ AND NOT other_node IN $EXCLUDED_LABELS
+RETURN {start: label, type: property, end: toString(other_node)} AS output
+"""
+
+
+class Neo4jKGProvider(PropertyGraphStore, KGProvider):
+ r"""
+ Neo4j Property Graph Store.
+
+ This class implements a Neo4j property graph store.
+
+ If you are using local Neo4j instead of aura, here's a helpful
+ command for launching the docker container:
+
+ ```bash
+ docker run \
+ -p 7474:7474 -p 7687:7687 \
+ -v $PWD/data:/data -v $PWD/plugins:/plugins \
+ --name neo4j-apoc \
+ -e NEO4J_apoc_export_file_enabled=true \
+ -e NEO4J_apoc_import_file_enabled=true \
+ -e NEO4J_apoc_import_file_use__neo4j__config=true \
+ -e NEO4JLABS_PLUGINS=\\[\"apoc\"\\] \
+ neo4j:latest
+ ```
+
+    Connection details are read from environment variables:
+        NEO4J_USER: The username for the Neo4j database.
+        NEO4J_PASSWORD: The password for the Neo4j database.
+        NEO4J_URL: The bolt URL for the Neo4j database.
+        NEO4J_DATABASE: The database to connect to. Defaults to "neo4j".
+
+    Examples:
+        `pip install neo4j`
+
+    ```python
+        from r2r.base import KGConfig
+        from r2r.providers.kg import Neo4jKGProvider
+
+        # Create a Neo4jKGProvider instance; connection details are read
+        # from the NEO4J_* environment variables above.
+        graph_store = Neo4jKGProvider(config=KGConfig(provider="neo4j"))
+    ```
+ """
+
+ supports_structured_queries: bool = True
+ supports_vector_queries: bool = True
+
+ def __init__(
+ self,
+ config: KGConfig,
+ refresh_schema: bool = True,
+ sanitize_query_output: bool = True,
+ enhanced_schema: bool = False,
+ *args: Any,
+ **kwargs: Any,
+ ) -> None:
+ if config.provider != "neo4j":
+ raise ValueError(
+ "Neo4jKGProvider must be initialized with config with `neo4j` provider."
+ )
+
+ try:
+ import neo4j
+ except ImportError:
+ raise ImportError("Please install neo4j: pip install neo4j")
+
+ username = os.getenv("NEO4J_USER")
+ password = os.getenv("NEO4J_PASSWORD")
+ url = os.getenv("NEO4J_URL")
+ database = os.getenv("NEO4J_DATABASE", "neo4j")
+
+ if not username or not password or not url:
+ raise ValueError(
+ "Neo4j configuration values are missing. Please set NEO4J_USER, NEO4J_PASSWORD, and NEO4J_URL environment variables."
+ )
+
+ self.sanitize_query_output = sanitize_query_output
+        self.enhanced_schema = enhanced_schema
+ self._driver = neo4j.GraphDatabase.driver(
+ url, auth=(username, password), **kwargs
+ )
+ self._async_driver = neo4j.AsyncGraphDatabase.driver(
+ url,
+ auth=(username, password),
+ **kwargs,
+ )
+ self._database = database
+ self.structured_schema = {}
+ if refresh_schema:
+ self.refresh_schema()
+ self.neo4j = neo4j
+ self.config = config
+
+ @property
+ def client(self):
+ return self._driver
+
+ def refresh_schema(self) -> None:
+ """Refresh the schema."""
+ node_query_results = self.structured_query(
+ node_properties_query,
+ param_map={
+ "EXCLUDED_LABELS": [*EXCLUDED_LABELS, BASE_ENTITY_LABEL]
+ },
+ )
+ node_properties = (
+ [el["output"] for el in node_query_results]
+ if node_query_results
+ else []
+ )
+
+ rels_query_result = self.structured_query(
+ rel_properties_query, param_map={"EXCLUDED_LABELS": EXCLUDED_RELS}
+ )
+ rel_properties = (
+ [el["output"] for el in rels_query_result]
+ if rels_query_result
+ else []
+ )
+
+ rel_objs_query_result = self.structured_query(
+ rel_query,
+ param_map={
+ "EXCLUDED_LABELS": [*EXCLUDED_LABELS, BASE_ENTITY_LABEL]
+ },
+ )
+ relationships = (
+ [el["output"] for el in rel_objs_query_result]
+ if rel_objs_query_result
+ else []
+ )
+
+ # Get constraints & indexes
+ try:
+ constraint = self.structured_query("SHOW CONSTRAINTS")
+ index = self.structured_query(
+ "CALL apoc.schema.nodes() YIELD label, properties, type, size, "
+ "valuesSelectivity WHERE type = 'RANGE' RETURN *, "
+ "size * valuesSelectivity as distinctValues"
+ )
+ except (
+ self.neo4j.exceptions.ClientError
+ ): # Read-only user might not have access to schema information
+ constraint = []
+ index = []
+
+ self.structured_schema = {
+ "node_props": {
+ el["labels"]: el["properties"] for el in node_properties
+ },
+ "rel_props": {
+ el["type"]: el["properties"] for el in rel_properties
+ },
+ "relationships": relationships,
+ "metadata": {"constraint": constraint, "index": index},
+ }
+ schema_counts = self.structured_query(
+ "CALL apoc.meta.graphSample() YIELD nodes, relationships "
+ "RETURN nodes, [rel in relationships | {name:apoc.any.property"
+ "(rel, 'type'), count: apoc.any.property(rel, 'count')}]"
+ " AS relationships"
+ )
+ # Update node info
+ for node in schema_counts[0].get("nodes", []):
+ # Skip bloom labels
+ if node["name"] in EXCLUDED_LABELS:
+ continue
+ node_props = self.structured_schema["node_props"].get(node["name"])
+ if not node_props: # The node has no properties
+ continue
+ enhanced_cypher = self._enhanced_schema_cypher(
+ node["name"],
+ node_props,
+ node["count"] < EXHAUSTIVE_SEARCH_LIMIT,
+ )
+ enhanced_info = self.structured_query(enhanced_cypher)[0]["output"]
+ for prop in node_props:
+ if prop["property"] in enhanced_info:
+ prop.update(enhanced_info[prop["property"]])
+ # Update rel info
+ for rel in schema_counts[0].get("relationships", []):
+ # Skip bloom labels
+ if rel["name"] in EXCLUDED_RELS:
+ continue
+ rel_props = self.structured_schema["rel_props"].get(rel["name"])
+ if not rel_props: # The rel has no properties
+ continue
+ enhanced_cypher = self._enhanced_schema_cypher(
+ rel["name"],
+ rel_props,
+ rel["count"] < EXHAUSTIVE_SEARCH_LIMIT,
+ is_relationship=True,
+ )
+ try:
+ enhanced_info = self.structured_query(enhanced_cypher)[0][
+ "output"
+ ]
+ for prop in rel_props:
+ if prop["property"] in enhanced_info:
+ prop.update(enhanced_info[prop["property"]])
+ except self.neo4j.exceptions.ClientError:
+ # Sometimes the types are not consistent in the db
+ pass
+
+ def upsert_nodes(self, nodes: List[LabelledNode]) -> None:
+ # Lists to hold separated types
+ entity_dicts: List[dict] = []
+ chunk_dicts: List[dict] = []
+
+ # Sort by type
+ for item in nodes:
+ if isinstance(item, EntityNode):
+ entity_dicts.append({**item.dict(), "id": item.id})
+ elif isinstance(item, ChunkNode):
+ chunk_dicts.append({**item.dict(), "id": item.id})
+ else:
+ # Log that we do not support these types of nodes
+ # Or raise an error?
+ pass
+
+ if chunk_dicts:
+ self.structured_query(
+ """
+ UNWIND $data AS row
+ MERGE (c:Chunk {id: row.id})
+ SET c.text = row.text
+ WITH c, row
+ SET c += row.properties
+ WITH c, row.embedding AS embedding
+ WHERE embedding IS NOT NULL
+ CALL db.create.setNodeVectorProperty(c, 'embedding', embedding)
+ RETURN count(*)
+ """,
+ param_map={"data": chunk_dicts},
+ )
+
+ if entity_dicts:
+ self.structured_query(
+ """
+ UNWIND $data AS row
+ MERGE (e:`__Entity__` {id: row.id})
+ SET e += apoc.map.clean(row.properties, [], [])
+ SET e.name = row.name
+ WITH e, row
+ CALL apoc.create.addLabels(e, [row.label])
+ YIELD node
+ WITH e, row
+ CALL {
+ WITH e, row
+ WITH e, row
+ WHERE row.embedding IS NOT NULL
+ CALL db.create.setNodeVectorProperty(e, 'embedding', row.embedding)
+ RETURN count(*) AS count
+ }
+ WITH e, row WHERE row.properties.triplet_source_id IS NOT NULL
+ MERGE (c:Chunk {id: row.properties.triplet_source_id})
+ MERGE (e)<-[:MENTIONS]-(c)
+ """,
+ param_map={"data": entity_dicts},
+ )
+
+ def upsert_relations(self, relations: List[Relation]) -> None:
+ """Add relations."""
+ params = [r.dict() for r in relations]
+
+ self.structured_query(
+ """
+ UNWIND $data AS row
+ MERGE (source {id: row.source_id})
+ MERGE (target {id: row.target_id})
+ WITH source, target, row
+ CALL apoc.merge.relationship(source, row.label, {}, row.properties, target) YIELD rel
+ RETURN count(*)
+ """,
+ param_map={"data": params},
+ )
+
+ def get(
+ self,
+ properties: Optional[dict] = None,
+ ids: Optional[List[str]] = None,
+ ) -> List[LabelledNode]:
+ """Get nodes."""
+ cypher_statement = "MATCH (e) "
+
+ params = {}
+ if properties or ids:
+ cypher_statement += "WHERE "
+
+ if ids:
+ cypher_statement += "e.id in $ids "
+ params["ids"] = ids
+
+ if properties:
+ prop_list = []
+ for i, prop in enumerate(properties):
+ prop_list.append(f"e.`{prop}` = $property_{i}")
+ params[f"property_{i}"] = properties[prop]
+ cypher_statement += " AND ".join(prop_list)
+
+ return_statement = """
+ WITH e
+ RETURN e.id AS name,
+ [l in labels(e) WHERE l <> '__Entity__' | l][0] AS type,
+ e{.* , embedding: Null, id: Null} AS properties
+ """
+ cypher_statement += return_statement
+
+ response = self.structured_query(cypher_statement, param_map=params)
+ response = response if response else []
+
+ nodes = []
+ for record in response:
+ # text indicates a chunk node
+ # none on the type indicates an implicit node, likely a chunk node
+ if "text" in record["properties"] or record["type"] is None:
+ text = record["properties"].pop("text", "")
+ nodes.append(
+ ChunkNode(
+ id_=record["name"],
+ text=text,
+ properties=remove_empty_values(record["properties"]),
+ )
+ )
+ else:
+ nodes.append(
+ EntityNode(
+ name=record["name"],
+ label=record["type"],
+ properties=remove_empty_values(record["properties"]),
+ )
+ )
+
+ return nodes
+
+ def get_triplets(
+ self,
+ entity_names: Optional[List[str]] = None,
+ relation_names: Optional[List[str]] = None,
+ properties: Optional[dict] = None,
+ ids: Optional[List[str]] = None,
+ ) -> List[Triplet]:
+ # TODO: handle ids of chunk nodes
+ cypher_statement = "MATCH (e:`__Entity__`) "
+
+ params = {}
+ if entity_names or properties or ids:
+ cypher_statement += "WHERE "
+
+ if entity_names:
+ cypher_statement += "e.name in $entity_names "
+ params["entity_names"] = entity_names
+
+ if ids:
+ cypher_statement += "e.id in $ids "
+ params["ids"] = ids
+
+ if properties:
+ prop_list = []
+ for i, prop in enumerate(properties):
+ prop_list.append(f"e.`{prop}` = $property_{i}")
+ params[f"property_{i}"] = properties[prop]
+ cypher_statement += " AND ".join(prop_list)
+
+ return_statement = f"""
+ WITH e
+ CALL {{
+ WITH e
+ MATCH (e)-[r{':`' + '`|`'.join(relation_names) + '`' if relation_names else ''}]->(t)
+ RETURN e.name AS source_id, [l in labels(e) WHERE l <> '__Entity__' | l][0] AS source_type,
+ e{{.* , embedding: Null, name: Null}} AS source_properties,
+ type(r) AS type,
+ t.name AS target_id, [l in labels(t) WHERE l <> '__Entity__' | l][0] AS target_type,
+ t{{.* , embedding: Null, name: Null}} AS target_properties
+ UNION ALL
+ WITH e
+ MATCH (e)<-[r{':`' + '`|`'.join(relation_names) + '`' if relation_names else ''}]-(t)
+ RETURN t.name AS source_id, [l in labels(t) WHERE l <> '__Entity__' | l][0] AS source_type,
+ e{{.* , embedding: Null, name: Null}} AS source_properties,
+ type(r) AS type,
+ e.name AS target_id, [l in labels(e) WHERE l <> '__Entity__' | l][0] AS target_type,
+ t{{.* , embedding: Null, name: Null}} AS target_properties
+ }}
+ RETURN source_id, source_type, type, target_id, target_type, source_properties, target_properties"""
+ cypher_statement += return_statement
+
+ data = self.structured_query(cypher_statement, param_map=params)
+ data = data if data else []
+
+ triples = []
+ for record in data:
+ source = EntityNode(
+ name=record["source_id"],
+ label=record["source_type"],
+ properties=remove_empty_values(record["source_properties"]),
+ )
+ target = EntityNode(
+ name=record["target_id"],
+ label=record["target_type"],
+ properties=remove_empty_values(record["target_properties"]),
+ )
+ rel = Relation(
+ source_id=record["source_id"],
+ target_id=record["target_id"],
+ label=record["type"],
+ )
+ triples.append([source, rel, target])
+ return triples
+
+ def get_rel_map(
+ self,
+ graph_nodes: List[LabelledNode],
+ depth: int = 2,
+ limit: int = 30,
+ ignore_rels: Optional[List[str]] = None,
+ ) -> List[Triplet]:
+ """Get depth-aware rel map."""
+ triples = []
+
+ ids = [node.id for node in graph_nodes]
+ # Needs some optimization
+ response = self.structured_query(
+ f"""
+ MATCH (e:`__Entity__`)
+ WHERE e.id in $ids
+ MATCH p=(e)-[r*1..{depth}]-(other)
+ WHERE ALL(rel in relationships(p) WHERE type(rel) <> 'MENTIONS')
+ UNWIND relationships(p) AS rel
+ WITH distinct rel
+ WITH startNode(rel) AS source,
+ type(rel) AS type,
+ endNode(rel) AS endNode
+ RETURN source.id AS source_id, [l in labels(source) WHERE l <> '__Entity__' | l][0] AS source_type,
+ source{{.* , embedding: Null, id: Null}} AS source_properties,
+ type,
+ endNode.id AS target_id, [l in labels(endNode) WHERE l <> '__Entity__' | l][0] AS target_type,
+ endNode{{.* , embedding: Null, id: Null}} AS target_properties
+ LIMIT toInteger($limit)
+ """,
+ param_map={"ids": ids, "limit": limit},
+ )
+ response = response if response else []
+
+ ignore_rels = ignore_rels or []
+ for record in response:
+ if record["type"] in ignore_rels:
+ continue
+
+ source = EntityNode(
+ name=record["source_id"],
+ label=record["source_type"],
+ properties=remove_empty_values(record["source_properties"]),
+ )
+ target = EntityNode(
+ name=record["target_id"],
+ label=record["target_type"],
+ properties=remove_empty_values(record["target_properties"]),
+ )
+ rel = Relation(
+ source_id=record["source_id"],
+ target_id=record["target_id"],
+ label=record["type"],
+ )
+ triples.append([source, rel, target])
+
+ return triples
+
+ def structured_query(
+ self, query: str, param_map: Optional[Dict[str, Any]] = None
+ ) -> Any:
+ param_map = param_map or {}
+
+ with self._driver.session(database=self._database) as session:
+ result = session.run(query, param_map)
+ full_result = [d.data() for d in result]
+
+ if self.sanitize_query_output:
+ return value_sanitize(full_result)
+
+ return full_result
+
+ def vector_query(
+ self, query: VectorStoreQuery, **kwargs: Any
+ ) -> Tuple[List[LabelledNode], List[float]]:
+ """Query the graph store with a vector store query."""
+ data = self.structured_query(
+ """MATCH (e:`__Entity__`)
+ WHERE e.embedding IS NOT NULL AND size(e.embedding) = $dimension
+ WITH e, vector.similarity.cosine(e.embedding, $embedding) AS score
+ ORDER BY score DESC LIMIT toInteger($limit)
+ RETURN e.id AS name,
+ [l in labels(e) WHERE l <> '__Entity__' | l][0] AS type,
+ e{.* , embedding: Null, name: Null, id: Null} AS properties,
+ score""",
+ param_map={
+ "embedding": query.query_embedding,
+ "dimension": len(query.query_embedding),
+ "limit": query.similarity_top_k,
+ },
+ )
+ data = data if data else []
+
+ nodes = []
+ scores = []
+ for record in data:
+ node = EntityNode(
+ name=record["name"],
+ label=record["type"],
+ properties=remove_empty_values(record["properties"]),
+ )
+ nodes.append(node)
+ scores.append(record["score"])
+
+ return (nodes, scores)
+
+ def delete(
+ self,
+ entity_names: Optional[List[str]] = None,
+ relation_names: Optional[List[str]] = None,
+ properties: Optional[dict] = None,
+ ids: Optional[List[str]] = None,
+ ) -> None:
+ """Delete matching data."""
+ if entity_names:
+ self.structured_query(
+ "MATCH (n) WHERE n.name IN $entity_names DETACH DELETE n",
+ param_map={"entity_names": entity_names},
+ )
+
+ if ids:
+ self.structured_query(
+ "MATCH (n) WHERE n.id IN $ids DETACH DELETE n",
+ param_map={"ids": ids},
+ )
+
+ if relation_names:
+ for rel in relation_names:
+ self.structured_query(f"MATCH ()-[r:`{rel}`]->() DELETE r")
+
+ if properties:
+ cypher = "MATCH (e) WHERE "
+ prop_list = []
+ params = {}
+ for i, prop in enumerate(properties):
+ prop_list.append(f"e.`{prop}` = $property_{i}")
+ params[f"property_{i}"] = properties[prop]
+ cypher += " AND ".join(prop_list)
+ self.structured_query(
+ cypher + " DETACH DELETE e", param_map=params
+ )
+
+ def _enhanced_schema_cypher(
+ self,
+ label_or_type: str,
+ properties: List[Dict[str, Any]],
+ exhaustive: bool,
+ is_relationship: bool = False,
+ ) -> str:
+ if is_relationship:
+ match_clause = f"MATCH ()-[n:`{label_or_type}`]->()"
+ else:
+ match_clause = f"MATCH (n:`{label_or_type}`)"
+
+ with_clauses = []
+ return_clauses = []
+ output_dict = {}
+ if exhaustive:
+ for prop in properties:
+ prop_name = prop["property"]
+ prop_type = prop["type"]
+ if prop_type == "STRING":
+ with_clauses.append(
+ f"collect(distinct substring(toString(n.`{prop_name}`), 0, 50)) "
+ f"AS `{prop_name}_values`"
+ )
+ return_clauses.append(
+ f"values:`{prop_name}_values`[..{DISTINCT_VALUE_LIMIT}],"
+ f" distinct_count: size(`{prop_name}_values`)"
+ )
+ elif prop_type in [
+ "INTEGER",
+ "FLOAT",
+ "DATE",
+ "DATE_TIME",
+ "LOCAL_DATE_TIME",
+ ]:
+ with_clauses.append(
+ f"min(n.`{prop_name}`) AS `{prop_name}_min`"
+ )
+ with_clauses.append(
+ f"max(n.`{prop_name}`) AS `{prop_name}_max`"
+ )
+ with_clauses.append(
+ f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`"
+ )
+ return_clauses.append(
+ f"min: toString(`{prop_name}_min`), "
+ f"max: toString(`{prop_name}_max`), "
+ f"distinct_count: `{prop_name}_distinct`"
+ )
+ elif prop_type == "LIST":
+ with_clauses.append(
+ f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, "
+ f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`"
+ )
+ return_clauses.append(
+ f"min_size: `{prop_name}_size_min`, "
+ f"max_size: `{prop_name}_size_max`"
+ )
+ elif prop_type in ["BOOLEAN", "POINT", "DURATION"]:
+ continue
+ output_dict[prop_name] = "{" + return_clauses.pop() + "}"
+ else:
+ # Just sample 5 random nodes
+ match_clause += " WITH n LIMIT 5"
+ for prop in properties:
+ prop_name = prop["property"]
+ prop_type = prop["type"]
+
+ # Check if indexed property, we can still do exhaustive
+ prop_index = [
+ el
+ for el in self.structured_schema["metadata"]["index"]
+ if el["label"] == label_or_type
+ and el["properties"] == [prop_name]
+ and el["type"] == "RANGE"
+ ]
+ if prop_type == "STRING":
+ if (
+ prop_index
+ and prop_index[0].get("size") > 0
+ and prop_index[0].get("distinctValues")
+ <= DISTINCT_VALUE_LIMIT
+ ):
+ distinct_values = self.query(
+ f"CALL apoc.schema.properties.distinct("
+ f"'{label_or_type}', '{prop_name}') YIELD value"
+ )[0]["value"]
+ return_clauses.append(
+ f"values: {distinct_values},"
+ f" distinct_count: {len(distinct_values)}"
+ )
+ else:
+ with_clauses.append(
+ f"collect(distinct substring(n.`{prop_name}`, 0, 50)) "
+ f"AS `{prop_name}_values`"
+ )
+ return_clauses.append(f"values: `{prop_name}_values`")
+ elif prop_type in [
+ "INTEGER",
+ "FLOAT",
+ "DATE",
+ "DATE_TIME",
+ "LOCAL_DATE_TIME",
+ ]:
+ if not prop_index:
+ with_clauses.append(
+ f"collect(distinct toString(n.`{prop_name}`)) "
+ f"AS `{prop_name}_values`"
+ )
+ return_clauses.append(f"values: `{prop_name}_values`")
+ else:
+ with_clauses.append(
+ f"min(n.`{prop_name}`) AS `{prop_name}_min`"
+ )
+ with_clauses.append(
+ f"max(n.`{prop_name}`) AS `{prop_name}_max`"
+ )
+ with_clauses.append(
+ f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`"
+ )
+ return_clauses.append(
+ f"min: toString(`{prop_name}_min`), "
+ f"max: toString(`{prop_name}_max`), "
+ f"distinct_count: `{prop_name}_distinct`"
+ )
+
+ elif prop_type == "LIST":
+ with_clauses.append(
+ f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, "
+ f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`"
+ )
+ return_clauses.append(
+ f"min_size: `{prop_name}_size_min`, "
+ f"max_size: `{prop_name}_size_max`"
+ )
+ elif prop_type in ["BOOLEAN", "POINT", "DURATION"]:
+ continue
+
+ output_dict[prop_name] = "{" + return_clauses.pop() + "}"
+
+ with_clause = "WITH " + ",\n ".join(with_clauses)
+ return_clause = (
+ "RETURN {"
+ + ", ".join(f"`{k}`: {v}" for k, v in output_dict.items())
+ + "} AS output"
+ )
+
+ # Combine all parts of the Cypher query
+ return f"{match_clause}\n{with_clause}\n{return_clause}"
+
+ def get_schema(self, refresh: bool = False) -> Any:
+ if refresh:
+ self.refresh_schema()
+
+ return self.structured_schema
+
+ def get_schema_str(self, refresh: bool = False) -> str:
+ schema = self.get_schema(refresh=refresh)
+
+ formatted_node_props = []
+ formatted_rel_props = []
+
+        if self.enhanced_schema:
+ # Enhanced formatting for nodes
+ for node_type, properties in schema["node_props"].items():
+ formatted_node_props.append(f"- **{node_type}**")
+ for prop in properties:
+ example = ""
+ if prop["type"] == "STRING" and prop.get("values"):
+ if (
+ prop.get("distinct_count", 11)
+ > DISTINCT_VALUE_LIMIT
+ ):
+ example = (
+ f'Example: "{clean_string_values(prop["values"][0])}"'
+ if prop["values"]
+ else ""
+ )
+ else: # If less than 10 possible values return all
+ example = (
+ (
+ "Available options: "
+ f'{[clean_string_values(el) for el in prop["values"]]}'
+ )
+ if prop["values"]
+ else ""
+ )
+
+ elif prop["type"] in [
+ "INTEGER",
+ "FLOAT",
+ "DATE",
+ "DATE_TIME",
+ "LOCAL_DATE_TIME",
+ ]:
+ if prop.get("min") is not None:
+ example = f'Min: {prop["min"]}, Max: {prop["max"]}'
+ else:
+ example = (
+ f'Example: "{prop["values"][0]}"'
+ if prop.get("values")
+ else ""
+ )
+ elif prop["type"] == "LIST":
+ # Skip embeddings
+ if (
+ not prop.get("min_size")
+ or prop["min_size"] > LIST_LIMIT
+ ):
+ continue
+ example = f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}'
+ formatted_node_props.append(
+ f" - `{prop['property']}`: {prop['type']} {example}"
+ )
+
+ # Enhanced formatting for relationships
+ for rel_type, properties in schema["rel_props"].items():
+ formatted_rel_props.append(f"- **{rel_type}**")
+ for prop in properties:
+ example = ""
+ if prop["type"] == "STRING":
+ if (
+ prop.get("distinct_count", 11)
+ > DISTINCT_VALUE_LIMIT
+ ):
+ example = (
+ f'Example: "{clean_string_values(prop["values"][0])}"'
+ if prop.get("values")
+ else ""
+ )
+ else: # If less than 10 possible values return all
+ example = (
+ (
+ "Available options: "
+ f'{[clean_string_values(el) for el in prop["values"]]}'
+ )
+ if prop.get("values")
+ else ""
+ )
+ elif prop["type"] in [
+ "INTEGER",
+ "FLOAT",
+ "DATE",
+ "DATE_TIME",
+ "LOCAL_DATE_TIME",
+ ]:
+ if prop.get("min"): # If we have min/max
+ example = (
+ f'Min: {prop["min"]}, Max: {prop["max"]}'
+ )
+ else: # return a single value
+ example = (
+ f'Example: "{prop["values"][0]}"'
+ if prop.get("values")
+ else ""
+ )
+ elif prop["type"] == "LIST":
+ # Skip embeddings
+ if prop["min_size"] > LIST_LIMIT:
+ continue
+ example = f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}'
+ formatted_rel_props.append(
+ f" - `{prop['property']}: {prop['type']}` {example}"
+ )
+ else:
+ # Format node properties
+ for label, props in schema["node_props"].items():
+ props_str = ", ".join(
+ [f"{prop['property']}: {prop['type']}" for prop in props]
+ )
+ formatted_node_props.append(f"{label} {{{props_str}}}")
+
+ # Format relationship properties using structured_schema
+ for type, props in schema["rel_props"].items():
+ props_str = ", ".join(
+ [f"{prop['property']}: {prop['type']}" for prop in props]
+ )
+ formatted_rel_props.append(f"{type} {{{props_str}}}")
+
+ # Format relationships
+ formatted_rels = [
+ f"(:{el['start']})-[:{el['type']}]->(:{el['end']})"
+ for el in schema["relationships"]
+ ]
+
+ return "\n".join(
+ [
+ "Node properties:",
+ "\n".join(formatted_node_props),
+ "Relationship properties:",
+ "\n".join(formatted_rel_props),
+ "The relationships:",
+ "\n".join(formatted_rels),
+ ]
+ )
+
+ def update_extraction_prompt(
+ self,
+ prompt_provider: PromptProvider,
+ entity_types: list[EntityType],
+ relations: list[Relation],
+ ):
+ # Fetch the kg extraction prompt with blank entity types and relations
+ # Note - Assumes that for given prompt there is a `_with_spec` that can have entities + relations specified
+ few_shot_ner_kg_extraction_with_spec = prompt_provider.get_prompt(
+ f"{self.config.kg_extraction_prompt}_with_spec"
+ )
+
+ # Format the prompt to include the desired entity types and relations
+ few_shot_ner_kg_extraction = (
+ few_shot_ner_kg_extraction_with_spec.replace(
+ "{entity_types}", format_entity_types(entity_types)
+ ).replace("{relations}", format_relations(relations))
+ )
+
+ # Update the "few_shot_ner_kg_extraction" prompt used in downstream KG construction
+ prompt_provider.update_prompt(
+ self.config.kg_extraction_prompt,
+ json.dumps(few_shot_ner_kg_extraction, ensure_ascii=False),
+ )
+
+ def update_kg_agent_prompt(
+ self,
+ prompt_provider: PromptProvider,
+ entity_types: list[EntityType],
+ relations: list[Relation],
+ ):
+ # Fetch the kg extraction prompt with blank entity types and relations
+ # Note - Assumes that for given prompt there is a `_with_spec` that can have entities + relations specified
+ few_shot_ner_kg_extraction_with_spec = prompt_provider.get_prompt(
+ f"{self.config.kg_agent_prompt}_with_spec"
+ )
+
+ # Format the prompt to include the desired entity types and relations
+ few_shot_ner_kg_extraction = (
+ few_shot_ner_kg_extraction_with_spec.replace(
+ "{entity_types}",
+ format_entity_types(entity_types, ignore_subcats=True),
+ ).replace("{relations}", format_relations(relations))
+ )
+
+ # Update the "few_shot_ner_kg_extraction" prompt used in downstream KG construction
+ prompt_provider.update_prompt(
+ self.config.kg_agent_prompt,
+ json.dumps(few_shot_ner_kg_extraction, ensure_ascii=False),
+ )
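A minimal sketch of querying through `Neo4jKGProvider`, assuming `KGConfig` accepts `provider="neo4j"`, the `NEO4J_USER`/`NEO4J_PASSWORD`/`NEO4J_URL` environment variables are set, and the target database has the APOC plugin available (the schema refresh relies on `apoc.meta.*` procedures).

```python
from r2r.base import KGConfig
from r2r.providers.kg import Neo4jKGProvider

kg = Neo4jKGProvider(config=KGConfig(provider="neo4j"))

# Arbitrary parameterized Cypher through the synchronous driver
rows = kg.structured_query(
    "MATCH (e:`__Entity__`) RETURN e.name AS name LIMIT $n", {"n": 5}
)
print(rows)

# Human-readable schema summary used when prompting an LLM
print(kg.get_schema_str())
```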
diff --git a/R2R/r2r/providers/llms/__init__.py b/R2R/r2r/providers/llms/__init__.py
new file mode 100755
index 00000000..38a1c54a
--- /dev/null
+++ b/R2R/r2r/providers/llms/__init__.py
@@ -0,0 +1,7 @@
+from .litellm.base_litellm import LiteLLM
+from .openai.base_openai import OpenAILLM
+
+__all__ = [
+ "LiteLLM",
+ "OpenAILLM",
+]
diff --git a/R2R/r2r/providers/llms/litellm/base_litellm.py b/R2R/r2r/providers/llms/litellm/base_litellm.py
new file mode 100755
index 00000000..581cce9a
--- /dev/null
+++ b/R2R/r2r/providers/llms/litellm/base_litellm.py
@@ -0,0 +1,142 @@
+import logging
+from typing import Any, Generator, Union
+
+from r2r.base import (
+ LLMChatCompletion,
+ LLMChatCompletionChunk,
+ LLMConfig,
+ LLMProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class LiteLLM(LLMProvider):
+ """A concrete class for creating LiteLLM models."""
+
+ def __init__(
+ self,
+ config: LLMConfig,
+ *args,
+ **kwargs,
+ ) -> None:
+ try:
+ from litellm import acompletion, completion
+
+ self.litellm_completion = completion
+ self.litellm_acompletion = acompletion
+ except ImportError:
+ raise ImportError(
+ "Error, `litellm` is required to run a LiteLLM. Please install it using `pip install litellm`."
+ )
+ super().__init__(config)
+
+ def get_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> LLMChatCompletion:
+ if generation_config.stream:
+ raise ValueError(
+ "Stream must be set to False to use the `get_completion` method."
+ )
+ return self._get_completion(messages, generation_config, **kwargs)
+
+ def get_completion_stream(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> Generator[LLMChatCompletionChunk, None, None]:
+ if not generation_config.stream:
+ raise ValueError(
+ "Stream must be set to True to use the `get_completion_stream` method."
+ )
+ return self._get_completion(messages, generation_config, **kwargs)
+
+ def extract_content(self, response: LLMChatCompletion) -> str:
+ return response.choices[0].message.content
+
+ def _get_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> Union[
+ LLMChatCompletion, Generator[LLMChatCompletionChunk, None, None]
+ ]:
+ # Create a dictionary with the default arguments
+ args = self._get_base_args(generation_config)
+ args["messages"] = messages
+
+ # Conditionally add the 'functions' argument if it's not None
+ if generation_config.functions is not None:
+ args["functions"] = generation_config.functions
+
+ args = {**args, **kwargs}
+ response = self.litellm_completion(**args)
+
+ if not generation_config.stream:
+ return LLMChatCompletion(**response.dict())
+ else:
+ return self._get_chat_completion(response)
+
+ def _get_chat_completion(
+ self,
+ response: Any,
+ ) -> Generator[LLMChatCompletionChunk, None, None]:
+ for part in response:
+ yield LLMChatCompletionChunk(**part.dict())
+
+ def _get_base_args(
+ self,
+ generation_config: GenerationConfig,
+ prompt=None,
+ ) -> dict:
+ """Get the base arguments for the LiteLLM API."""
+ args = {
+ "model": generation_config.model,
+ "temperature": generation_config.temperature,
+ "top_p": generation_config.top_p,
+ "stream": generation_config.stream,
+            # TODO - We need to cap this to avoid potential errors when exceeding the max allowable context
+ "max_tokens": generation_config.max_tokens_to_sample,
+ }
+ return args
+
+ async def aget_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> LLMChatCompletion:
+ if generation_config.stream:
+ raise ValueError(
+ "Stream must be set to False to use the `aget_completion` method."
+ )
+ return await self._aget_completion(
+ messages, generation_config, **kwargs
+ )
+
+ async def _aget_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> Union[LLMChatCompletion, LLMChatCompletionChunk]:
+ """Asynchronously get a completion from the OpenAI API based on the provided messages."""
+
+ # Create a dictionary with the default arguments
+ args = self._get_base_args(generation_config)
+
+ args["messages"] = messages
+
+ # Conditionally add the 'functions' argument if it's not None
+ if generation_config.functions is not None:
+ args["functions"] = generation_config.functions
+
+ args = {**args, **kwargs}
+ # Create the chat completion
+ return await self.litellm_acompletion(**args)
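A sketch of a non-streaming completion through `LiteLLM`, assuming `LLMConfig` and `GenerationConfig` accept the keyword fields read by `_get_base_args` and that credentials for the chosen model are configured the way litellm expects; the model name is illustrative.

```python
from r2r.base import LLMConfig
from r2r.base.abstractions.llm import GenerationConfig
from r2r.providers.llms import LiteLLM

llm = LiteLLM(LLMConfig(provider="litellm"))

generation_config = GenerationConfig(
    model="gpt-4o-mini",
    temperature=0.1,
    top_p=1.0,
    stream=False,  # get_completion refuses streaming configs
    max_tokens_to_sample=256,
)

completion = llm.get_completion(
    [{"role": "user", "content": "Summarize R2R in one sentence."}],
    generation_config,
)
print(llm.extract_content(completion))
```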
diff --git a/R2R/r2r/providers/llms/openai/base_openai.py b/R2R/r2r/providers/llms/openai/base_openai.py
new file mode 100755
index 00000000..460c0f0b
--- /dev/null
+++ b/R2R/r2r/providers/llms/openai/base_openai.py
@@ -0,0 +1,144 @@
+"""A module for creating OpenAI model abstractions."""
+
+import logging
+import os
+from typing import Union
+
+from r2r.base import (
+ LLMChatCompletion,
+ LLMChatCompletionChunk,
+ LLMConfig,
+ LLMProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAILLM(LLMProvider):
+ """A concrete class for creating OpenAI models."""
+
+ def __init__(
+ self,
+ config: LLMConfig,
+ *args,
+ **kwargs,
+ ) -> None:
+        if not isinstance(config, LLMConfig):
+            raise ValueError(
+                "The provided config must be an instance of LLMConfig."
+            )
+ try:
+ from openai import OpenAI # noqa
+ except ImportError:
+ raise ImportError(
+ "Error, `openai` is required to run an OpenAILLM. Please install it using `pip install openai`."
+ )
+ if config.provider != "openai":
+ raise ValueError(
+ "OpenAILLM must be initialized with config with `openai` provider."
+ )
+ if not os.getenv("OPENAI_API_KEY"):
+ raise ValueError(
+ "OpenAI API key not found. Please set the OPENAI_API_KEY environment variable."
+ )
+ super().__init__(config)
+ self.config: LLMConfig = config
+ self.client = OpenAI()
+
+ def get_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> LLMChatCompletion:
+ if generation_config.stream:
+ raise ValueError(
+ "Stream must be set to False to use the `get_completion` method."
+ )
+ return self._get_completion(messages, generation_config, **kwargs)
+
+ def get_completion_stream(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> LLMChatCompletionChunk:
+ if not generation_config.stream:
+ raise ValueError(
+ "Stream must be set to True to use the `get_completion_stream` method."
+ )
+ return self._get_completion(messages, generation_config, **kwargs)
+
+ def _get_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> Union[LLMChatCompletion, LLMChatCompletionChunk]:
+ """Get a completion from the OpenAI API based on the provided messages."""
+
+ # Create a dictionary with the default arguments
+ args = self._get_base_args(generation_config)
+
+ args["messages"] = messages
+
+ # Conditionally add the 'functions' argument if it's not None
+ if generation_config.functions is not None:
+ args["functions"] = generation_config.functions
+
+ args = {**args, **kwargs}
+ # Create the chat completion
+ return self.client.chat.completions.create(**args)
+
+ def _get_base_args(
+ self,
+ generation_config: GenerationConfig,
+ ) -> dict:
+ """Get the base arguments for the OpenAI API."""
+
+ args = {
+ "model": generation_config.model,
+ "temperature": generation_config.temperature,
+ "top_p": generation_config.top_p,
+ "stream": generation_config.stream,
+            # TODO - We need to cap this to avoid potential errors when exceeding the max allowable context
+ "max_tokens": generation_config.max_tokens_to_sample,
+ }
+
+ return args
+
+ async def aget_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> LLMChatCompletion:
+ if generation_config.stream:
+ raise ValueError(
+ "Stream must be set to False to use the `aget_completion` method."
+ )
+ return await self._aget_completion(
+ messages, generation_config, **kwargs
+ )
+
+ async def _aget_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> Union[LLMChatCompletion, LLMChatCompletionChunk]:
+ """Asynchronously get a completion from the OpenAI API based on the provided messages."""
+
+ # Create a dictionary with the default arguments
+ args = self._get_base_args(generation_config)
+
+ args["messages"] = messages
+
+ # Conditionally add the 'functions' argument if it's not None
+ if generation_config.functions is not None:
+ args["functions"] = generation_config.functions
+
+ args = {**args, **kwargs}
+ # Create the chat completion
+ return await self.client.chat.completions.create(**args)
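+
+# Illustrative usage sketch (not part of the provider; assumes OPENAI_API_KEY is
+# set and that LLMConfig / GenerationConfig accept the fields shown):
+#
+# llm = OpenAILLM(LLMConfig(provider="openai"))
+# response = llm.get_completion(
+# messages=[{"role": "user", "content": "Hello"}],
+# generation_config=GenerationConfig(model="gpt-4o-mini", stream=False),
+# )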
diff --git a/R2R/r2r/providers/vector_dbs/__init__.py b/R2R/r2r/providers/vector_dbs/__init__.py
new file mode 100755
index 00000000..38ea0890
--- /dev/null
+++ b/R2R/r2r/providers/vector_dbs/__init__.py
@@ -0,0 +1,5 @@
+from .pgvector.pgvector_db import PGVectorDB
+
+__all__ = [
+ "PGVectorDB",
+]
diff --git a/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py b/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py
new file mode 100755
index 00000000..8cf728d1
--- /dev/null
+++ b/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py
@@ -0,0 +1,610 @@
+import json
+import logging
+import os
+import time
+from typing import Literal, Optional, Union
+
+from sqlalchemy import exc, text
+from sqlalchemy.engine.url import make_url
+
+from r2r.base import (
+ DocumentInfo,
+ UserStats,
+ VectorDBConfig,
+ VectorDBProvider,
+ VectorEntry,
+ VectorSearchResult,
+)
+from r2r.vecs.client import Client
+from r2r.vecs.collection import Collection
+
+logger = logging.getLogger(__name__)
+
+
+class PGVectorDB(VectorDBProvider):
+ def __init__(self, config: VectorDBConfig) -> None:
+ super().__init__(config)
+ try:
+ import r2r.vecs
+ except ImportError:
+ raise ValueError(
+ f"Error, PGVectorDB requires the vecs library. Please run `pip install vecs`."
+ )
+
+ # Check if a complete Postgres URI is provided
+ postgres_uri = self.config.extra_fields.get(
+ "postgres_uri"
+ ) or os.getenv("POSTGRES_URI")
+
+ if postgres_uri:
+ # Log loudly that Postgres URI is being used
+ logger.warning("=" * 50)
+ logger.warning(
+ "ATTENTION: Using provided Postgres URI for connection"
+ )
+ logger.warning("=" * 50)
+
+ # Validate and use the provided URI
+ try:
+ parsed_uri = make_url(postgres_uri)
+ if not all([parsed_uri.username, parsed_uri.database]):
+ raise ValueError(
+ "The provided Postgres URI is missing required components."
+ )
+ DB_CONNECTION = postgres_uri
+
+ # Log the sanitized URI (without password)
+ sanitized_uri = parsed_uri.set(password="*****")
+ logger.info(f"Connecting using URI: {sanitized_uri}")
+ except Exception as e:
+ raise ValueError(f"Invalid Postgres URI provided: {e}")
+ else:
+ # Fall back to existing logic for individual connection parameters
+ user = self.config.extra_fields.get("user", None) or os.getenv(
+ "POSTGRES_USER"
+ )
+ password = self.config.extra_fields.get(
+ "password", None
+ ) or os.getenv("POSTGRES_PASSWORD")
+ host = self.config.extra_fields.get("host", None) or os.getenv(
+ "POSTGRES_HOST"
+ )
+ port = self.config.extra_fields.get("port", None) or os.getenv(
+ "POSTGRES_PORT"
+ )
+ db_name = self.config.extra_fields.get(
+ "db_name", None
+ ) or os.getenv("POSTGRES_DBNAME")
+
+ if not all([user, password, host, db_name]):
+ raise ValueError(
+ "Error, please set the POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_HOST, POSTGRES_DBNAME environment variables or provide them in the config."
+ )
+
+ # Check if it's a Unix socket connection
+ if host.startswith("/") and not port:
+ DB_CONNECTION = (
+ f"postgresql://{user}:{password}@/{db_name}?host={host}"
+ )
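+ # e.g. postgresql://user:pass@/mydb?host=/var/run/postgresql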
+ logger.info("Using Unix socket connection")
+ else:
+ DB_CONNECTION = (
+ f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
+ )
+ logger.info("Using TCP connection")
+
+ # The rest of the initialization remains the same
+ try:
+ self.vx: Client = r2r.vecs.create_client(DB_CONNECTION)
+ except Exception as e:
+ raise ValueError(
+ f"Error {e} occurred while attempting to connect to the pgvector provider with {DB_CONNECTION}."
+ )
+
+ self.collection_name = self.config.extra_fields.get(
+ "vecs_collection"
+ ) or os.getenv("POSTGRES_VECS_COLLECTION")
+ if not self.collection_name:
+ raise ValueError(
+ "Error, please set a valid POSTGRES_VECS_COLLECTION environment variable or set a 'vecs_collection' in the 'vector_database' settings of your `config.json`."
+ )
+
+ self.collection: Optional[Collection] = None
+
+ logger.info(
+ f"Successfully initialized PGVectorDB with collection: {self.collection_name}"
+ )
+
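+ # `initialize_collection` must be called (with the embedding dimension) before
+ # any copy/upsert/search call; those methods raise a ValueError otherwise.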
+ def initialize_collection(self, dimension: int) -> None:
+ self.collection = self.vx.get_or_create_collection(
+ name=self.collection_name, dimension=dimension
+ )
+ self._create_document_info_table()
+ self._create_hybrid_search_function()
+
+ def _create_document_info_table(self):
+ with self.vx.Session() as sess:
+ with sess.begin():
+ try:
+ # Enable uuid-ossp extension
+ sess.execute(
+ text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
+ )
+ except exc.ProgrammingError as e:
+ logger.error(f"Error enabling uuid-ossp extension: {e}")
+ raise
+
+ # Create the table if it doesn't exist
+ create_table_query = f"""
+ CREATE TABLE IF NOT EXISTS "document_info_{self.collection_name}" (
+ document_id UUID PRIMARY KEY,
+ title TEXT,
+ user_id UUID NULL,
+ version TEXT,
+ size_in_bytes INT,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW(),
+ metadata JSONB,
+ status TEXT
+ );
+ """
+ sess.execute(text(create_table_query))
+
+ # Add the new column if it doesn't exist
+ add_column_query = f"""
+ DO $$
+ BEGIN
+ IF NOT EXISTS (
+ SELECT 1
+ FROM information_schema.columns
+ WHERE table_name = 'document_info_{self.collection_name}'
+ AND column_name = 'status'
+ ) THEN
+ ALTER TABLE "document_info_{self.collection_name}"
+ ADD COLUMN status TEXT DEFAULT 'processing';
+ END IF;
+ END $$;
+ """
+ sess.execute(text(add_column_query))
+
+ sess.commit()
+
+ def _create_hybrid_search_function(self):
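+ # Build a Postgres function that fuses full-text and vector rankings via
+ # reciprocal rank fusion (RRF): each row scores COALESCE(1 / (rrf_k + rank), 0)
+ # from each ranking, weighted by full_text_weight / semantic_weight. Note that
+ # the signature hard-codes VECTOR(512), so 512-dimensional embeddings are assumed.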
+ hybrid_search_function = f"""
+ CREATE OR REPLACE FUNCTION hybrid_search_{self.collection_name}(
+ query_text TEXT,
+ query_embedding VECTOR(512),
+ match_limit INT,
+ full_text_weight FLOAT = 1,
+ semantic_weight FLOAT = 1,
+ rrf_k INT = 50,
+ filter_condition JSONB = NULL
+ )
+ RETURNS SETOF vecs."{self.collection_name}"
+ LANGUAGE sql
+ AS $$
+ WITH full_text AS (
+ SELECT
+ id,
+ ROW_NUMBER() OVER (ORDER BY ts_rank(to_tsvector('english', metadata->>'text'), websearch_to_tsquery(query_text)) DESC) AS rank_ix
+ FROM vecs."{self.collection_name}"
+ WHERE to_tsvector('english', metadata->>'text') @@ websearch_to_tsquery(query_text)
+ AND (filter_condition IS NULL OR (metadata @> filter_condition))
+ ORDER BY rank_ix
+ LIMIT LEAST(match_limit, 30) * 2
+ ),
+ semantic AS (
+ SELECT
+ id,
+ ROW_NUMBER() OVER (ORDER BY vec <#> query_embedding) AS rank_ix
+ FROM vecs."{self.collection_name}"
+ WHERE filter_condition IS NULL OR (metadata @> filter_condition)
+ ORDER BY rank_ix
+ LIMIT LEAST(match_limit, 30) * 2
+ )
+ SELECT
+ vecs."{self.collection_name}".*
+ FROM
+ full_text
+ FULL OUTER JOIN semantic
+ ON full_text.id = semantic.id
+ JOIN vecs."{self.collection_name}"
+ ON vecs."{self.collection_name}".id = COALESCE(full_text.id, semantic.id)
+ ORDER BY
+ COALESCE(1.0 / (rrf_k + full_text.rank_ix), 0.0) * full_text_weight +
+ COALESCE(1.0 / (rrf_k + semantic.rank_ix), 0.0) * semantic_weight
+ DESC
+ LIMIT
+ LEAST(match_limit, 30);
+ $$;
+ """
+ retry_attempts = 5
+ for attempt in range(retry_attempts):
+ try:
+ with self.vx.Session() as sess:
+ # Acquire an advisory lock
+ sess.execute(text("SELECT pg_advisory_lock(123456789)"))
+ try:
+ sess.execute(text(hybrid_search_function))
+ sess.commit()
+ finally:
+ # Release the advisory lock
+ sess.execute(
+ text("SELECT pg_advisory_unlock(123456789)")
+ )
+ break # Break the loop if successful
+ except exc.InternalError as e:
+ if "tuple concurrently updated" in str(e):
+ time.sleep(2**attempt) # Exponential backoff
+ else:
+ raise # Re-raise the exception if it's not a concurrency issue
+ else:
+ raise RuntimeError(
+ "Failed to create hybrid search function after multiple attempts"
+ )
+
+ def copy(self, entry: VectorEntry, commit=True) -> None:
+ if self.collection is None:
+ raise ValueError(
+ "Please call `initialize_collection` before attempting to run `copy`."
+ )
+
+ serializable_entry = entry.to_serializable()
+
+ self.collection.copy(
+ records=[
+ (
+ serializable_entry["id"],
+ serializable_entry["vector"],
+ serializable_entry["metadata"],
+ )
+ ]
+ )
+
+ def copy_entries(
+ self, entries: list[VectorEntry], commit: bool = True
+ ) -> None:
+ if self.collection is None:
+ raise ValueError(
+ "Please call `initialize_collection` before attempting to run `copy_entries`."
+ )
+
+ self.collection.copy(
+ records=[
+ (
+ str(entry.id),
+ entry.vector.data,
+ entry.to_serializable()["metadata"],
+ )
+ for entry in entries
+ ]
+ )
+
+ def upsert(self, entry: VectorEntry, commit=True) -> None:
+ if self.collection is None:
+ raise ValueError(
+ "Please call `initialize_collection` before attempting to run `upsert`."
+ )
+
+ self.collection.upsert(
+ records=[
+ (
+ str(entry.id),
+ entry.vector.data,
+ entry.to_serializable()["metadata"],
+ )
+ ]
+ )
+
+ def upsert_entries(
+ self, entries: list[VectorEntry], commit: bool = True
+ ) -> None:
+ if self.collection is None:
+ raise ValueError(
+ "Please call `initialize_collection` before attempting to run `upsert_entries`."
+ )
+
+ self.collection.upsert(
+ records=[
+ (
+ str(entry.id),
+ entry.vector.data,
+ entry.to_serializable()["metadata"],
+ )
+ for entry in entries
+ ]
+ )
+
+ def search(
+ self,
+ query_vector: list[float],
+ filters: dict[str, Union[bool, int, str]] = {},
+ limit: int = 10,
+ *args,
+ **kwargs,
+ ) -> list[VectorSearchResult]:
+ if self.collection is None:
+ raise ValueError(
+ "Please call `initialize_collection` before attempting to run `search`."
+ )
+ measure = kwargs.get("measure", "cosine_distance")
+ mapped_filters = {
+ key: {"$eq": value} for key, value in filters.items()
+ }
+
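+ # vecs returns (id, distance, metadata) tuples here; 1 - distance converts the
+ # cosine distance into a similarity-style score.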
+ return [
+ VectorSearchResult(id=ele[0], score=float(1 - ele[1]), metadata=ele[2]) # type: ignore
+ for ele in self.collection.query(
+ data=query_vector,
+ limit=limit,
+ filters=mapped_filters,
+ measure=measure,
+ include_value=True,
+ include_metadata=True,
+ )
+ ]
+
+ def hybrid_search(
+ self,
+ query_text: str,
+ query_vector: list[float],
+ limit: int = 10,
+ filters: Optional[dict[str, Union[bool, int, str]]] = None,
+ # Hybrid search parameters
+ full_text_weight: float = 1.0,
+ semantic_weight: float = 1.0,
+ rrf_k: int = 20, # typical value is ~2x the number of results you want
+ *args,
+ **kwargs,
+ ) -> list[VectorSearchResult]:
+ if self.collection is None:
+ raise ValueError(
+ "Please call `initialize_collection` before attempting to run `hybrid_search`."
+ )
+
+ # Convert filters to a JSON-compatible format
+ filter_condition = None
+ if filters:
+ filter_condition = json.dumps(filters)
+
+ query = text(
+ f"""
+ SELECT * FROM hybrid_search_{self.collection_name}(
+ cast(:query_text as TEXT), cast(:query_embedding as VECTOR), cast(:match_limit as INT),
+ cast(:full_text_weight as FLOAT), cast(:semantic_weight as FLOAT), cast(:rrf_k as INT),
+ cast(:filter_condition as JSONB)
+ )
+ """
+ )
+
+ params = {
+ "query_text": str(query_text),
+ "query_embedding": list(query_vector),
+ "match_limit": limit,
+ "full_text_weight": full_text_weight,
+ "semantic_weight": semantic_weight,
+ "rrf_k": rrf_k,
+ "filter_condition": filter_condition,
+ }
+
+ with self.vx.Session() as session:
+ result = session.execute(query, params).fetchall()
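+ # The SQL function returns the fused rows but not the fused score, so a
+ # placeholder score of 1.0 is attached to each result.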
+ return [
+ VectorSearchResult(id=row[0], score=1.0, metadata=row[-1])
+ for row in result
+ ]
+
+ def create_index(self, index_type, column_name, index_options):
+ pass
+
+ def delete_by_metadata(
+ self,
+ metadata_fields: list[str],
+ metadata_values: list[Union[bool, int, str]],
+ logic: Literal["AND", "OR"] = "AND",
+ ) -> list[str]:
+ if logic == "OR":
+ raise ValueError(
+ "OR logic is still being tested before official support for `delete_by_metadata` in pgvector."
+ )
+ if self.collection is None:
+ raise ValueError(
+ "Please call `initialize_collection` before attempting to run `delete_by_metadata`."
+ )
+
+ if len(metadata_fields) != len(metadata_values):
+ raise ValueError(
+ "The number of metadata fields must match the number of metadata values."
+ )
+
+ # Construct the filter
+ if logic == "AND":
+ filters = {
+ k: {"$eq": v} for k, v in zip(metadata_fields, metadata_values)
+ }
+ else: # OR logic
+ # TODO - Test 'or' logic and remove check above
+ filters = {
+ "$or": [
+ {k: {"$eq": v}}
+ for k, v in zip(metadata_fields, metadata_values)
+ ]
+ }
+ return self.collection.delete(filters=filters)
+
+ def get_metadatas(
+ self,
+ metadata_fields: list[str],
+ filter_field: Optional[str] = None,
+ filter_value: Optional[Union[bool, int, str]] = None,
+ ) -> list[dict]:
+ if self.collection is None:
+ raise ValueError(
+ "Please call `initialize_collection` before attempting to run `get_metadatas`."
+ )
+
+ results = {tuple(metadata_fields): {}}
+ for field in metadata_fields:
+ unique_values = self.collection.get_unique_metadata_values(
+ field=field,
+ filter_field=filter_field,
+ filter_value=filter_value,
+ )
+ for value in unique_values:
+ if value not in results:
+ results[value] = {}
+ results[value][field] = value
+
+ return [
+ results[key] for key in results if key != tuple(metadata_fields)
+ ]
+
+ def upsert_documents_overview(
+ self, documents_overview: list[DocumentInfo]
+ ) -> None:
+ for document_info in documents_overview:
+ db_entry = document_info.convert_to_db_entry()
+
+ # Convert 'None' string to None type for user_id
+ if db_entry["user_id"] == "None":
+ db_entry["user_id"] = None
+
+ query = text(
+ f"""
+ INSERT INTO "document_info_{self.collection_name}" (document_id, title, user_id, version, created_at, updated_at, size_in_bytes, metadata, status)
+ VALUES (:document_id, :title, :user_id, :version, :created_at, :updated_at, :size_in_bytes, :metadata, :status)
+ ON CONFLICT (document_id) DO UPDATE SET
+ title = EXCLUDED.title,
+ user_id = EXCLUDED.user_id,
+ version = EXCLUDED.version,
+ updated_at = EXCLUDED.updated_at,
+ size_in_bytes = EXCLUDED.size_in_bytes,
+ metadata = EXCLUDED.metadata,
+ status = EXCLUDED.status;
+ """
+ )
+ with self.vx.Session() as sess:
+ sess.execute(query, db_entry)
+ sess.commit()
+
+ def delete_from_documents_overview(
+ self, document_id: str, version: Optional[str] = None
+ ) -> None:
+ query = f"""
+ DELETE FROM "document_info_{self.collection_name}"
+ WHERE document_id = :document_id
+ """
+ params = {"document_id": document_id}
+
+ if version is not None:
+ query += " AND version = :version"
+ params["version"] = version
+
+ with self.vx.Session() as sess:
+ with sess.begin():
+ sess.execute(text(query), params)
+ sess.commit()
+
+ def get_documents_overview(
+ self,
+ filter_document_ids: Optional[list[str]] = None,
+ filter_user_ids: Optional[list[str]] = None,
+ ):
+ conditions = []
+ params = {}
+
+ if filter_document_ids:
+ placeholders = ", ".join(
+ f":doc_id_{i}" for i in range(len(filter_document_ids))
+ )
+ conditions.append(f"document_id IN ({placeholders})")
+ params.update(
+ {
+ f"doc_id_{i}": str(document_id)
+ for i, document_id in enumerate(filter_document_ids)
+ }
+ )
+ if filter_user_ids:
+ placeholders = ", ".join(
+ f":user_id_{i}" for i in range(len(filter_user_ids))
+ )
+ conditions.append(f"user_id IN ({placeholders})")
+ params.update(
+ {
+ f"user_id_{i}": str(user_id)
+ for i, user_id in enumerate(filter_user_ids)
+ }
+ )
+
+ query = f"""
+ SELECT document_id, title, user_id, version, size_in_bytes, created_at, updated_at, metadata, status
+ FROM "document_info_{self.collection_name}"
+ """
+ if conditions:
+ query += " WHERE " + " AND ".join(conditions)
+
+ with self.vx.Session() as sess:
+ results = sess.execute(text(query), params).fetchall()
+ return [
+ DocumentInfo(
+ document_id=row[0],
+ title=row[1],
+ user_id=row[2],
+ version=row[3],
+ size_in_bytes=row[4],
+ created_at=row[5],
+ updated_at=row[6],
+ metadata=row[7],
+ status=row[8],
+ )
+ for row in results
+ ]
+
+ def get_document_chunks(self, document_id: str) -> list[dict]:
+ if not self.collection:
+ raise ValueError("Collection is not initialized.")
+
+ table_name = self.collection.table.name
+ query = text(
+ f"""
+ SELECT metadata
+ FROM vecs."{table_name}"
+ WHERE metadata->>'document_id' = :document_id
+ ORDER BY CAST(metadata->>'chunk_order' AS INTEGER)
+ """
+ )
+
+ params = {"document_id": document_id}
+
+ with self.vx.Session() as sess:
+ results = sess.execute(query, params).fetchall()
+ return [result[0] for result in results]
+
+ def get_users_overview(self, user_ids: Optional[list[str]] = None):
+ user_ids_condition = ""
+ params = {}
+ if user_ids:
+ user_ids_condition = "WHERE user_id IN :user_ids"
+ params["user_ids"] = tuple(
+ map(str, user_ids)
+ ) # Convert UUIDs to strings
+
+ query = f"""
+ SELECT user_id, COUNT(document_id) AS num_files, SUM(size_in_bytes) AS total_size_in_bytes, ARRAY_AGG(document_id) AS document_ids
+ FROM "document_info_{self.collection_name}"
+ {user_ids_condition}
+ GROUP BY user_id
+ """
+
+ with self.vx.Session() as sess:
+ results = sess.execute(text(query), params).fetchall()
+ return [
+ UserStats(
+ user_id=row[0],
+ num_files=row[1],
+ total_size_in_bytes=row[2],
+ document_ids=row[3],
+ )
+ for row in results
+ if row[0] is not None
+ ]
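+
+# Illustrative usage sketch (hypothetical config values; assumes the Postgres
+# connection settings above are provided and that embeddings are 512-dimensional
+# to match the hybrid-search function signature):
+#
+# db = PGVectorDB(config) # config: VectorDBConfig with provider "pgvector"
+# db.initialize_collection(dimension=512)
+# db.upsert(entry) # entry: VectorEntry
+# hits = db.search(query_vector=query_embedding, limit=5)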