Diffstat (limited to 'R2R/r2r/providers')
-rwxr-xr-x  R2R/r2r/providers/__init__.py                                                    0
-rwxr-xr-x  R2R/r2r/providers/embeddings/__init__.py                                        11
-rwxr-xr-x  R2R/r2r/providers/embeddings/ollama/ollama_base.py                             156
-rwxr-xr-x  R2R/r2r/providers/embeddings/openai/openai_base.py                             195
-rwxr-xr-x  R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py 160
-rwxr-xr-x  R2R/r2r/providers/eval/__init__.py                                               3
-rwxr-xr-x  R2R/r2r/providers/eval/llm/base_llm_eval.py                                     84
-rwxr-xr-x  R2R/r2r/providers/kg/__init__.py                                                 3
-rwxr-xr-x  R2R/r2r/providers/kg/neo4j/base_neo4j.py                                       983
-rwxr-xr-x  R2R/r2r/providers/llms/__init__.py                                               7
-rwxr-xr-x  R2R/r2r/providers/llms/litellm/base_litellm.py                                 142
-rwxr-xr-x  R2R/r2r/providers/llms/openai/base_openai.py                                   144
-rwxr-xr-x  R2R/r2r/providers/vector_dbs/__init__.py                                         5
-rwxr-xr-x  R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py                           610
14 files changed, 2503 insertions(+), 0 deletions(-)
diff --git a/R2R/r2r/providers/__init__.py b/R2R/r2r/providers/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/providers/__init__.py
diff --git a/R2R/r2r/providers/embeddings/__init__.py b/R2R/r2r/providers/embeddings/__init__.py
new file mode 100755
index 00000000..6b0c8b83
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/__init__.py
@@ -0,0 +1,11 @@
+from .ollama.ollama_base import OllamaEmbeddingProvider
+from .openai.openai_base import OpenAIEmbeddingProvider
+from .sentence_transformer.sentence_transformer_base import (
+    SentenceTransformerEmbeddingProvider,
+)
+
+__all__ = [
+    "OllamaEmbeddingProvider",
+    "OpenAIEmbeddingProvider",
+    "SentenceTransformerEmbeddingProvider",
+]
diff --git a/R2R/r2r/providers/embeddings/ollama/ollama_base.py b/R2R/r2r/providers/embeddings/ollama/ollama_base.py
new file mode 100755
index 00000000..31a8c717
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/ollama/ollama_base.py
@@ -0,0 +1,156 @@
+import asyncio
+import logging
+import os
+import random
+from typing import Any
+
+from ollama import AsyncClient, Client
+
+from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
+class OllamaEmbeddingProvider(EmbeddingProvider):
+    def __init__(self, config: EmbeddingConfig):
+        super().__init__(config)
+        provider = config.provider
+        if not provider:
+            raise ValueError(
+                "Must set provider in order to initialize `OllamaEmbeddingProvider`."
+            )
+        if provider != "ollama":
+            raise ValueError(
+                "OllamaEmbeddingProvider must be initialized with provider `ollama`."
+            )
+        if config.rerank_model:
+            raise ValueError(
+                "OllamaEmbeddingProvider does not support separate reranking."
+            )
+
+        self.base_model = config.base_model
+        self.base_dimension = config.base_dimension
+        self.base_url = os.getenv("OLLAMA_API_BASE")
+        logger.info(
+            f"Using Ollama API base URL: {self.base_url or 'http://127.0.0.1:11434'}"
+        )
+        self.client = Client(host=self.base_url)
+        self.aclient = AsyncClient(host=self.base_url)
+
+        self.request_queue = asyncio.Queue()
+        self.max_retries = 2
+        self.initial_backoff = 1
+        self.max_backoff = 60
+        self.concurrency_limit = 10
+        self.semaphore = asyncio.Semaphore(self.concurrency_limit)
+
+    async def process_queue(self):
+        while True:
+            task = await self.request_queue.get()
+            try:
+                result = await self.execute_task_with_backoff(task)
+                task["future"].set_result(result)
+            except Exception as e:
+                task["future"].set_exception(e)
+            finally:
+                self.request_queue.task_done()
+
+    async def execute_task_with_backoff(self, task: dict[str, Any]):
+        retries = 0
+        backoff = self.initial_backoff
+        while retries < self.max_retries:
+            try:
+                async with self.semaphore:
+                    response = await asyncio.wait_for(
+                        self.aclient.embeddings(
+                            prompt=task["text"], model=self.base_model
+                        ),
+                        timeout=30,
+                    )
+                return response["embedding"]
+            except Exception as e:
+                logger.warning(
+                    f"Request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.max_retries:
+                    raise Exception(
+                        f"Max retries reached. Last error: {str(e)}"
+                    )
+                await asyncio.sleep(backoff + random.uniform(0, 1))
+                backoff = min(backoff * 2, self.max_backoff)
+
+    def get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OllamaEmbeddingProvider only supports search stage."
+            )
+
+        try:
+            response = self.client.embeddings(
+                prompt=text, model=self.base_model
+            )
+            return response["embedding"]
+        except Exception as e:
+            logger.error(f"Error getting embedding: {str(e)}")
+            raise
+
+    def get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[list[float]]:
+        return [self.get_embedding(text, stage) for text in texts]
+
+    async def async_get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OllamaEmbeddingProvider only supports search stage."
+            )
+
+        queue_processor = asyncio.create_task(self.process_queue())
+        futures = []
+        for text in texts:
+            future = asyncio.Future()
+            await self.request_queue.put({"text": text, "future": future})
+            futures.append(future)
+
+        try:
+            results = await asyncio.gather(*futures, return_exceptions=True)
+            # Check if any result is an exception and raise it
+            exceptions = [r for r in results if isinstance(r, Exception)]
+            if exceptions:
+                raise Exception(
+                    "Embedding generation failed for one or more embeddings."
+                )
+            return results
+        except Exception as e:
+            logger.error(f"Embedding generation failed: {str(e)}")
+            raise
+        finally:
+            await self.request_queue.join()
+            queue_processor.cancel()
+
+    def rerank(
+        self,
+        query: str,
+        results: list[VectorSearchResult],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+        limit: int = 10,
+    ) -> list[VectorSearchResult]:
+        return results[:limit]
+
+    def tokenize_string(
+        self, text: str, model: str, stage: EmbeddingProvider.PipeStage
+    ) -> list[int]:
+        raise NotImplementedError(
+            "Tokenization is not supported by OllamaEmbeddingProvider."
+        )
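
A minimal usage sketch for the provider above, not part of the commit. The `EmbeddingConfig` field names (`provider`, `base_model`, `base_dimension`) follow what the constructor reads; the model name and a locally running Ollama server are assumptions.

```python
# Hypothetical usage sketch; field names are taken from how
# OllamaEmbeddingProvider reads its config above.
import asyncio

from r2r.base import EmbeddingConfig
from r2r.providers.embeddings import OllamaEmbeddingProvider

config = EmbeddingConfig(
    provider="ollama",
    base_model="mxbai-embed-large",  # any embedding model pulled into Ollama
    base_dimension=1024,
)
provider = OllamaEmbeddingProvider(config)

# Synchronous path: one request per text
vector = provider.get_embedding("What is a property graph?")
print(len(vector))  # dimensionality is determined by the model itself

# Async path: texts are queued and embedded with bounded concurrency,
# retries, and exponential backoff (see process_queue above)
vectors = asyncio.run(
    provider.async_get_embeddings(["first text", "second text"])
)
```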
diff --git a/R2R/r2r/providers/embeddings/openai/openai_base.py b/R2R/r2r/providers/embeddings/openai/openai_base.py
new file mode 100755
index 00000000..7e7d32aa
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/openai/openai_base.py
@@ -0,0 +1,195 @@
+import logging
+import os
+
+from openai import AsyncOpenAI, AuthenticationError, OpenAI
+
+from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIEmbeddingProvider(EmbeddingProvider):
+    MODEL_TO_TOKENIZER = {
+        "text-embedding-ada-002": "cl100k_base",
+        "text-embedding-3-small": "cl100k_base",
+        "text-embedding-3-large": "cl100k_base",
+    }
+    MODEL_TO_DIMENSIONS = {
+        "text-embedding-ada-002": [1536],
+        "text-embedding-3-small": [512, 1536],
+        "text-embedding-3-large": [256, 1024, 3072],
+    }
+
+    def __init__(self, config: EmbeddingConfig):
+        super().__init__(config)
+        provider = config.provider
+        if not provider:
+            raise ValueError(
+                "Must set provider in order to initialize OpenAIEmbeddingProvider."
+            )
+
+        if provider != "openai":
+            raise ValueError(
+                "OpenAIEmbeddingProvider must be initialized with provider `openai`."
+            )
+        if not os.getenv("OPENAI_API_KEY"):
+            raise ValueError(
+                "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
+            )
+        self.client = OpenAI()
+        self.async_client = AsyncOpenAI()
+
+        if config.rerank_model:
+            raise ValueError(
+                "OpenAIEmbeddingProvider does not support separate reranking."
+            )
+        self.base_model = config.base_model
+        self.base_dimension = config.base_dimension
+
+        if self.base_model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
+            raise ValueError(
+                f"OpenAI embedding model {self.base_model} not supported."
+            )
+        if (
+            self.base_dimension
+            and self.base_dimension
+            not in OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model]
+        ):
+            raise ValueError(
+                f"Dimensions {self.dimension} for {self.base_model} are not supported"
+            )
+
+        if not self.base_model or not self.base_dimension:
+            raise ValueError(
+                "Must set base_model and base_dimension in order to initialize OpenAIEmbeddingProvider."
+            )
+
+    def get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OpenAIEmbeddingProvider only supports search stage."
+            )
+
+        try:
+            return (
+                self.client.embeddings.create(
+                    input=[text],
+                    model=self.base_model,
+                    dimensions=self.base_dimension
+                    or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+                        self.base_model
+                    ][-1],
+                )
+                .data[0]
+                .embedding
+            )
+        except AuthenticationError as e:
+            raise ValueError(
+                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+            ) from e
+
+    async def async_get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OpenAIEmbeddingProvider only supports search stage."
+            )
+
+        try:
+            response = await self.async_client.embeddings.create(
+                input=[text],
+                model=self.base_model,
+                dimensions=self.base_dimension
+                or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+                    self.base_model
+                ][-1],
+            )
+            return response.data[0].embedding
+        except AuthenticationError as e:
+            raise ValueError(
+                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+            ) from e
+
+    def get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OpenAIEmbeddingProvider only supports search stage."
+            )
+
+        try:
+            return [
+                ele.embedding
+                for ele in self.client.embeddings.create(
+                    input=texts,
+                    model=self.base_model,
+                    dimensions=self.base_dimension
+                    or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+                        self.base_model
+                    ][-1],
+                ).data
+            ]
+        except AuthenticationError as e:
+            raise ValueError(
+                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+            ) from e
+
+    async def async_get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OpenAIEmbeddingProvider only supports search stage."
+            )
+
+        try:
+            response = await self.async_client.embeddings.create(
+                input=texts,
+                model=self.base_model,
+                dimensions=self.base_dimension
+                or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+                    self.base_model
+                ][-1],
+            )
+            return [ele.embedding for ele in response.data]
+        except AuthenticationError as e:
+            raise ValueError(
+                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+            ) from e
+
+    def rerank(
+        self,
+        query: str,
+        results: list[VectorSearchResult],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+        limit: int = 10,
+    ):
+        return results[:limit]
+
+    def tokenize_string(self, text: str, model: str) -> list[int]:
+        try:
+            import tiktoken
+        except ImportError as e:
+            raise ValueError(
+                "Must install the `tiktoken` library to run `tokenize_string`."
+            ) from e
+        # tiktoken encoding -
+        # cl100k_base -	gpt-4, gpt-3.5-turbo, text-embedding-ada-002, text-embedding-3-small, text-embedding-3-large
+        if model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
+            raise ValueError(f"OpenAI embedding model {model} not supported.")
+        encoding = tiktoken.get_encoding(
+            OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER[model]
+        )
+        return encoding.encode(text)
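
A short usage sketch for the provider above, not part of the commit. Note that the `dimensions` argument falls back to the largest entry in `MODEL_TO_DIMENSIONS` whenever `base_dimension` is falsy; the snippet assumes `OPENAI_API_KEY` is exported and that `EmbeddingConfig` exposes the fields the constructor reads.

```python
# Hypothetical usage sketch; assumes OPENAI_API_KEY is set in the environment.
from r2r.base import EmbeddingConfig
from r2r.providers.embeddings import OpenAIEmbeddingProvider

config = EmbeddingConfig(
    provider="openai",
    base_model="text-embedding-3-small",
    base_dimension=512,  # must appear in MODEL_TO_DIMENSIONS for the model
)
provider = OpenAIEmbeddingProvider(config)

# Batched request; each returned vector has 512 dimensions
vectors = provider.get_embeddings(["first text", "second text"])

# cl100k_base tokenization via tiktoken
tokens = provider.tokenize_string("first text", "text-embedding-3-small")
```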
diff --git a/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py b/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py
new file mode 100755
index 00000000..3316cb60
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py
@@ -0,0 +1,160 @@
+import logging
+
+from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
+class SentenceTransformerEmbeddingProvider(EmbeddingProvider):
+    def __init__(
+        self,
+        config: EmbeddingConfig,
+    ):
+        super().__init__(config)
+        logger.info(
+            "Initializing `SentenceTransformerEmbeddingProvider` with separate models for search and rerank."
+        )
+        provider = config.provider
+        if not provider:
+            raise ValueError(
+                "Must set provider in order to initialize SentenceTransformerEmbeddingProvider."
+            )
+        if provider != "sentence-transformers":
+            raise ValueError(
+                "SentenceTransformerEmbeddingProvider must be initialized with provider `sentence-transformers`."
+            )
+        try:
+            from sentence_transformers import CrossEncoder, SentenceTransformer
+
+            self.SentenceTransformer = SentenceTransformer
+            # TODO - Modify this to be configurable, as `bge-reranker-large` is a `SentenceTransformer` model
+            self.CrossEncoder = CrossEncoder
+        except ImportError as e:
+            raise ValueError(
+                "Must download sentence-transformers library to run `SentenceTransformerEmbeddingProvider`."
+            ) from e
+
+        # Initialize separate models for search and rerank
+        self.do_search = False
+        self.do_rerank = False
+
+        self.search_encoder = self._init_model(
+            config, EmbeddingProvider.PipeStage.BASE
+        )
+        self.rerank_encoder = self._init_model(
+            config, EmbeddingProvider.PipeStage.RERANK
+        )
+
+    def _init_model(self, config: EmbeddingConfig, stage: EmbeddingProvider.PipeStage):
+        stage_name = stage.name.lower()
+        model = config.dict().get(f"{stage_name}_model", None)
+        dimension = config.dict().get(f"{stage_name}_dimension", None)
+
+        transformer_type = config.dict().get(
+            f"{stage_name}_transformer_type", "SentenceTransformer"
+        )
+
+        if stage == EmbeddingProvider.PipeStage.BASE:
+            self.do_search = True
+            # Check if a model is set for the stage
+            if not (model and dimension and transformer_type):
+                raise ValueError(
+                    f"Must set {stage.name.lower()}_model and {stage.name.lower()}_dimension for {stage} stage in order to initialize SentenceTransformerEmbeddingProvider."
+                )
+
+        if stage == EmbeddingProvider.PipeStage.RERANK:
+            # Check if a model is set for the stage
+            if not (model and dimension and transformer_type):
+                return None
+
+            self.do_rerank = True
+            if transformer_type == "SentenceTransformer":
+                raise ValueError(
+                    f"`SentenceTransformer` models are not yet supported for {stage} stage in SentenceTransformerEmbeddingProvider."
+                )
+
+        # Save the model_key and dimension into instance variables
+        setattr(self, f"{stage_name}_model", model)
+        setattr(self, f"{stage_name}_dimension", dimension)
+        setattr(self, f"{stage_name}_transformer_type", transformer_type)
+
+        # Initialize the model
+        encoder = (
+            self.SentenceTransformer(
+                model, truncate_dim=dimension, trust_remote_code=True
+            )
+            if transformer_type == "SentenceTransformer"
+            else self.CrossEncoder(model, trust_remote_code=True)
+        )
+        return encoder
+
+    def get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError("`get_embedding` only supports `SEARCH` stage.")
+        if not self.do_search:
+            raise ValueError(
+                "`get_embedding` can only be called for the search stage if a search model is set."
+            )
+        encoder = self.search_encoder
+        return encoder.encode([text]).tolist()[0]
+
+    def get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError("`get_embeddings` only supports `SEARCH` stage.")
+        if not self.do_search:
+            raise ValueError(
+                "`get_embeddings` can only be called for the search stage if a search model is set."
+            )
+        encoder = (
+            self.search_encoder
+            if stage == EmbeddingProvider.PipeStage.BASE
+            else self.rerank_encoder
+        )
+        return encoder.encode(texts).tolist()
+
+    def rerank(
+        self,
+        query: str,
+        results: list[VectorSearchResult],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+        limit: int = 10,
+    ) -> list[VectorSearchResult]:
+        if stage != EmbeddingProvider.PipeStage.RERANK:
+            raise ValueError("`rerank` only supports `RERANK` stage.")
+        if not self.do_rerank:
+            return results[:limit]
+
+        from copy import copy
+
+        texts = copy([doc.metadata["text"] for doc in results])
+        # Use the rank method from the rerank_encoder, which is a CrossEncoder model
+        reranked_scores = self.rerank_encoder.rank(
+            query, texts, return_documents=False, top_k=limit
+        )
+        # Map the reranked scores back to the original documents
+        reranked_results = []
+        for score in reranked_scores:
+            corpus_id = score["corpus_id"]
+            new_result = results[corpus_id]
+            new_result.score = float(score["score"])
+            reranked_results.append(new_result)
+
+        # Sort the documents by the new scores in descending order
+        reranked_results.sort(key=lambda doc: doc.score, reverse=True)
+        return reranked_results
+
+    def tokenize_string(
+        self, text: str, model: str,
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[int]:
+        raise ValueError(
+            "SentenceTransformerEmbeddingProvider does not support tokenize_string."
+        )
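
A configuration sketch for the provider above, not part of the commit, wiring both a base (search) model and a cross-encoder rerank model. The `base_*`/`rerank_*` key names mirror what `_init_model` reads from `config.dict()`; the model names are illustrative assumptions.

```python
# Hypothetical usage sketch; key names follow _init_model above.
from r2r.base import EmbeddingConfig
from r2r.providers.embeddings import SentenceTransformerEmbeddingProvider

config = EmbeddingConfig(
    provider="sentence-transformers",
    base_model="mixedbread-ai/mxbai-embed-large-v1",
    base_dimension=512,  # passed as truncate_dim to SentenceTransformer
    rerank_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
    rerank_dimension=384,  # only checked for presence on the rerank stage
    rerank_transformer_type="CrossEncoder",
)
provider = SentenceTransformerEmbeddingProvider(config)

query_vector = provider.get_embedding("how do cross-encoders rerank?")
# provider.rerank(query, results) re-scores VectorSearchResult objects
# with the CrossEncoder and sorts them by the new scores.
```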
diff --git a/R2R/r2r/providers/eval/__init__.py b/R2R/r2r/providers/eval/__init__.py
new file mode 100755
index 00000000..3f5e1b51
--- /dev/null
+++ b/R2R/r2r/providers/eval/__init__.py
@@ -0,0 +1,3 @@
+from .llm.base_llm_eval import LLMEvalProvider
+
+__all__ = ["LLMEvalProvider"]
diff --git a/R2R/r2r/providers/eval/llm/base_llm_eval.py b/R2R/r2r/providers/eval/llm/base_llm_eval.py
new file mode 100755
index 00000000..7c573a34
--- /dev/null
+++ b/R2R/r2r/providers/eval/llm/base_llm_eval.py
@@ -0,0 +1,84 @@
+from fractions import Fraction
+from typing import Union
+
+from r2r import EvalConfig, EvalProvider, LLMProvider, PromptProvider
+from r2r.base.abstractions.llm import GenerationConfig
+
+
+class LLMEvalProvider(EvalProvider):
+    def __init__(
+        self,
+        config: EvalConfig,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+    ):
+        super().__init__(config)
+
+        self.llm_provider = llm_provider
+        self.prompt_provider = prompt_provider
+
+    def _calc_query_context_relevancy(self, query: str, context: str) -> float:
+        system_prompt = self.prompt_provider.get_prompt("default_system")
+        eval_prompt = self.prompt_provider.get_prompt(
+            "rag_context_eval", {"query": query, "context": context}
+        )
+        response = self.llm_provider.get_completion(
+            self.prompt_provider._get_message_payload(
+                system_prompt, eval_prompt
+            ),
+            self.eval_generation_config,
+        )
+        response_text = response.choices[0].message.content
+        fraction = (
+            response_text
+            # Get the fraction in the returned tuple
+            .split(",")[-1][:-1]
+            # Remove any quotes and spaces
+            .replace("'", "")
+            .replace('"', "")
+            .strip()
+        )
+        return float(Fraction(fraction))
+
+    def _calc_answer_grounding(
+        self, query: str, context: str, answer: str
+    ) -> float:
+        system_prompt = self.prompt_provider.get_prompt("default_system")
+        eval_prompt = self.prompt_provider.get_prompt(
+            "rag_answer_eval",
+            {"query": query, "context": context, "answer": answer},
+        )
+        response = self.llm_provider.get_completion(
+            self.prompt_provider._get_message_payload(
+                system_prompt, eval_prompt
+            ),
+            self.eval_generation_config,
+        )
+        response_text = response.choices[0].message.content
+        fraction = (
+            response_text
+            # Get the fraction in the returned tuple
+            .split(",")[-1][:-1]
+            # Remove any quotes and spaces
+            .replace("'", "")
+            .replace('"', "")
+            .strip()
+        )
+        return float(Fraction(fraction))
+
+    def _evaluate(
+        self,
+        query: str,
+        context: str,
+        answer: str,
+        eval_generation_config: GenerationConfig,
+    ) -> dict[str, Union[str, float]]:
+        self.eval_generation_config = eval_generation_config
+        query_context_relevancy = self._calc_query_context_relevancy(
+            query, context
+        )
+        answer_grounding = self._calc_answer_grounding(query, context, answer)
+        return {
+            "query_context_relevancy": query_context_relevancy,
+            "answer_grounding": answer_grounding,
+        }
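
The two scoring helpers above expect the model to answer with a tuple-like string whose last element is a fraction, e.g. `('reasoning', '3/5')`. A worked example of the slice-and-strip parsing, using a hypothetical model response:

```python
# Worked example of the fraction parsing in _calc_query_context_relevancy.
from fractions import Fraction

# Hypothetical model response in the expected (reasoning, score) format
response_text = "('The context mostly answers the query.', '3/5')"

fraction = (
    response_text
    # Take the last comma-separated element and drop the trailing ")"
    .split(",")[-1][:-1]
    # Remove any quotes and spaces
    .replace("'", "")
    .replace('"', "")
    .strip()
)
print(fraction)                   # 3/5
print(float(Fraction(fraction)))  # 0.6
```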
diff --git a/R2R/r2r/providers/kg/__init__.py b/R2R/r2r/providers/kg/__init__.py
new file mode 100755
index 00000000..36bc79a2
--- /dev/null
+++ b/R2R/r2r/providers/kg/__init__.py
@@ -0,0 +1,3 @@
+from .neo4j.base_neo4j import Neo4jKGProvider
+
+__all__ = ["Neo4jKGProvider"]
diff --git a/R2R/r2r/providers/kg/neo4j/base_neo4j.py b/R2R/r2r/providers/kg/neo4j/base_neo4j.py
new file mode 100755
index 00000000..9ede2b85
--- /dev/null
+++ b/R2R/r2r/providers/kg/neo4j/base_neo4j.py
@@ -0,0 +1,983 @@
+# abstractions are taken from LlamaIndex
+# Neo4jKGProvider is almost entirely taken from LlamaIndex Neo4jPropertyGraphStore
+# https://github.com/run-llama/llama_index
+import json
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+from r2r.base import (
+    EntityType,
+    KGConfig,
+    KGProvider,
+    PromptProvider,
+    format_entity_types,
+    format_relations,
+)
+from r2r.base.abstractions.llama_abstractions import (
+    LIST_LIMIT,
+    ChunkNode,
+    EntityNode,
+    LabelledNode,
+    PropertyGraphStore,
+    Relation,
+    Triplet,
+    VectorStoreQuery,
+    clean_string_values,
+    value_sanitize,
+)
+
+
+def remove_empty_values(input_dict):
+    """
+    Remove entries with empty values from the dictionary.
+
+    Parameters:
+    input_dict (dict): The dictionary from which empty values need to be removed.
+
+    Returns:
+    dict: A new dictionary with all empty values removed.
+    """
+    # Create a new dictionary excluding empty values
+    return {key: value for key, value in input_dict.items() if value}
+
+
+BASE_ENTITY_LABEL = "__Entity__"
+EXCLUDED_LABELS = ["_Bloom_Perspective_", "_Bloom_Scene_"]
+EXCLUDED_RELS = ["_Bloom_HAS_SCENE_"]
+EXHAUSTIVE_SEARCH_LIMIT = 10000
+# Threshold for returning all available prop values in graph schema
+DISTINCT_VALUE_LIMIT = 10
+
+node_properties_query = """
+CALL apoc.meta.data()
+YIELD label, other, elementType, type, property
+WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
+  AND NOT label IN $EXCLUDED_LABELS
+WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
+RETURN {labels: nodeLabels, properties: properties} AS output
+
+"""
+
+rel_properties_query = """
+CALL apoc.meta.data()
+YIELD label, other, elementType, type, property
+WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship"
+      AND NOT label in $EXCLUDED_LABELS
+WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
+RETURN {type: nodeLabels, properties: properties} AS output
+"""
+
+rel_query = """
+CALL apoc.meta.data()
+YIELD label, other, elementType, type, property
+WHERE type = "RELATIONSHIP" AND elementType = "node"
+UNWIND other AS other_node
+WITH * WHERE NOT label IN $EXCLUDED_LABELS
+    AND NOT other_node IN $EXCLUDED_LABELS
+RETURN {start: label, type: property, end: toString(other_node)} AS output
+"""
+
+
+class Neo4jKGProvider(PropertyGraphStore, KGProvider):
+    r"""
+    Neo4j Property Graph Store.
+
+    This class implements a Neo4j property graph store.
+
+    If you are using a local Neo4j instance instead of Aura, here's a helpful
+    command for launching the docker container:
+
+    ```bash
+    docker run \
+        -p 7474:7474 -p 7687:7687 \
+        -v $PWD/data:/data -v $PWD/plugins:/plugins \
+        --name neo4j-apoc \
+        -e NEO4J_apoc_export_file_enabled=true \
+        -e NEO4J_apoc_import_file_enabled=true \
+        -e NEO4J_apoc_import_file_use__neo4j__config=true \
+        -e NEO4JLABS_PLUGINS=\\[\"apoc\"\\] \
+        neo4j:latest
+    ```
+
+    Args:
+        config (KGConfig): Provider config; `config.provider` must be "neo4j".
+        refresh_schema (bool): Whether to refresh the graph schema on init.
+        sanitize_query_output (bool): Whether to sanitize query results.
+        enhanced_schema (bool): Whether to collect enhanced schema statistics.
+
+    Connection settings are read from the environment: NEO4J_USER,
+    NEO4J_PASSWORD, NEO4J_URL, and optionally NEO4J_DATABASE
+    (defaults to "neo4j").
+
+    Examples:
+        `pip install neo4j`
+
+        ```python
+        import os
+        from r2r.base import KGConfig
+        from r2r.providers.kg import Neo4jKGProvider
+
+        os.environ["NEO4J_USER"] = "neo4j"
+        os.environ["NEO4J_PASSWORD"] = "password"
+        os.environ["NEO4J_URL"] = "bolt://localhost:7687"
+
+        # Create a Neo4jKGProvider instance
+        graph_store = Neo4jKGProvider(
+            config=KGConfig(provider="neo4j"),
+        )
+        ```
+    """
+
+    supports_structured_queries: bool = True
+    supports_vector_queries: bool = True
+
+    def __init__(
+        self,
+        config: KGConfig,
+        refresh_schema: bool = True,
+        sanitize_query_output: bool = True,
+        enhanced_schema: bool = False,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        if config.provider != "neo4j":
+            raise ValueError(
+                "Neo4jKGProvider must be initialized with config with `neo4j` provider."
+            )
+
+        try:
+            import neo4j
+        except ImportError:
+            raise ImportError("Please install neo4j: pip install neo4j")
+
+        username = os.getenv("NEO4J_USER")
+        password = os.getenv("NEO4J_PASSWORD")
+        url = os.getenv("NEO4J_URL")
+        database = os.getenv("NEO4J_DATABASE", "neo4j")
+
+        if not username or not password or not url:
+            raise ValueError(
+                "Neo4j configuration values are missing. Please set NEO4J_USER, NEO4J_PASSWORD, and NEO4J_URL environment variables."
+            )
+
+        self.sanitize_query_output = sanitize_query_output
+        self.enhanced_schema = enhanced_schema
+        self._driver = neo4j.GraphDatabase.driver(
+            url, auth=(username, password), **kwargs
+        )
+        self._async_driver = neo4j.AsyncGraphDatabase.driver(
+            url,
+            auth=(username, password),
+            **kwargs,
+        )
+        self._database = database
+        self.structured_schema = {}
+        self.neo4j = neo4j
+        self.config = config
+        if refresh_schema:
+            self.refresh_schema()
+
+    @property
+    def client(self):
+        return self._driver
+
+    def refresh_schema(self) -> None:
+        """Refresh the schema."""
+        node_query_results = self.structured_query(
+            node_properties_query,
+            param_map={
+                "EXCLUDED_LABELS": [*EXCLUDED_LABELS, BASE_ENTITY_LABEL]
+            },
+        )
+        node_properties = (
+            [el["output"] for el in node_query_results]
+            if node_query_results
+            else []
+        )
+
+        rels_query_result = self.structured_query(
+            rel_properties_query, param_map={"EXCLUDED_LABELS": EXCLUDED_RELS}
+        )
+        rel_properties = (
+            [el["output"] for el in rels_query_result]
+            if rels_query_result
+            else []
+        )
+
+        rel_objs_query_result = self.structured_query(
+            rel_query,
+            param_map={
+                "EXCLUDED_LABELS": [*EXCLUDED_LABELS, BASE_ENTITY_LABEL]
+            },
+        )
+        relationships = (
+            [el["output"] for el in rel_objs_query_result]
+            if rel_objs_query_result
+            else []
+        )
+
+        # Get constraints & indexes
+        try:
+            constraint = self.structured_query("SHOW CONSTRAINTS")
+            index = self.structured_query(
+                "CALL apoc.schema.nodes() YIELD label, properties, type, size, "
+                "valuesSelectivity WHERE type = 'RANGE' RETURN *, "
+                "size * valuesSelectivity as distinctValues"
+            )
+        except (
+            self.neo4j.exceptions.ClientError
+        ):  # Read-only user might not have access to schema information
+            constraint = []
+            index = []
+
+        self.structured_schema = {
+            "node_props": {
+                el["labels"]: el["properties"] for el in node_properties
+            },
+            "rel_props": {
+                el["type"]: el["properties"] for el in rel_properties
+            },
+            "relationships": relationships,
+            "metadata": {"constraint": constraint, "index": index},
+        }
+        schema_counts = self.structured_query(
+            "CALL apoc.meta.graphSample() YIELD nodes, relationships "
+            "RETURN nodes, [rel in relationships | {name:apoc.any.property"
+            "(rel, 'type'), count: apoc.any.property(rel, 'count')}]"
+            " AS relationships"
+        )
+        # Update node info
+        for node in schema_counts[0].get("nodes", []):
+            # Skip bloom labels
+            if node["name"] in EXCLUDED_LABELS:
+                continue
+            node_props = self.structured_schema["node_props"].get(node["name"])
+            if not node_props:  # The node has no properties
+                continue
+            enhanced_cypher = self._enhanced_schema_cypher(
+                node["name"],
+                node_props,
+                node["count"] < EXHAUSTIVE_SEARCH_LIMIT,
+            )
+            enhanced_info = self.structured_query(enhanced_cypher)[0]["output"]
+            for prop in node_props:
+                if prop["property"] in enhanced_info:
+                    prop.update(enhanced_info[prop["property"]])
+        # Update rel info
+        for rel in schema_counts[0].get("relationships", []):
+            # Skip bloom labels
+            if rel["name"] in EXCLUDED_RELS:
+                continue
+            rel_props = self.structured_schema["rel_props"].get(rel["name"])
+            if not rel_props:  # The rel has no properties
+                continue
+            enhanced_cypher = self._enhanced_schema_cypher(
+                rel["name"],
+                rel_props,
+                rel["count"] < EXHAUSTIVE_SEARCH_LIMIT,
+                is_relationship=True,
+            )
+            try:
+                enhanced_info = self.structured_query(enhanced_cypher)[0][
+                    "output"
+                ]
+                for prop in rel_props:
+                    if prop["property"] in enhanced_info:
+                        prop.update(enhanced_info[prop["property"]])
+            except self.neo4j.exceptions.ClientError:
+                # Sometimes the types are not consistent in the db
+                pass
+
+    def upsert_nodes(self, nodes: List[LabelledNode]) -> None:
+        # Lists to hold separated types
+        entity_dicts: List[dict] = []
+        chunk_dicts: List[dict] = []
+
+        # Sort by type
+        for item in nodes:
+            if isinstance(item, EntityNode):
+                entity_dicts.append({**item.dict(), "id": item.id})
+            elif isinstance(item, ChunkNode):
+                chunk_dicts.append({**item.dict(), "id": item.id})
+            else:
+                # Log that we do not support these types of nodes
+                # Or raise an error?
+                pass
+
+        if chunk_dicts:
+            self.structured_query(
+                """
+                UNWIND $data AS row
+                MERGE (c:Chunk {id: row.id})
+                SET c.text = row.text
+                WITH c, row
+                SET c += row.properties
+                WITH c, row.embedding AS embedding
+                WHERE embedding IS NOT NULL
+                CALL db.create.setNodeVectorProperty(c, 'embedding', embedding)
+                RETURN count(*)
+                """,
+                param_map={"data": chunk_dicts},
+            )
+
+        if entity_dicts:
+            self.structured_query(
+                """
+                UNWIND $data AS row
+                MERGE (e:`__Entity__` {id: row.id})
+                SET e += apoc.map.clean(row.properties, [], [])
+                SET e.name = row.name
+                WITH e, row
+                CALL apoc.create.addLabels(e, [row.label])
+                YIELD node
+                WITH e, row
+                CALL {
+                    WITH e, row
+                    WITH e, row
+                    WHERE row.embedding IS NOT NULL
+                    CALL db.create.setNodeVectorProperty(e, 'embedding', row.embedding)
+                    RETURN count(*) AS count
+                }
+                WITH e, row WHERE row.properties.triplet_source_id IS NOT NULL
+                MERGE (c:Chunk {id: row.properties.triplet_source_id})
+                MERGE (e)<-[:MENTIONS]-(c)
+                """,
+                param_map={"data": entity_dicts},
+            )
+
+    def upsert_relations(self, relations: List[Relation]) -> None:
+        """Add relations."""
+        params = [r.dict() for r in relations]
+
+        self.structured_query(
+            """
+            UNWIND $data AS row
+            MERGE (source {id: row.source_id})
+            MERGE (target {id: row.target_id})
+            WITH source, target, row
+            CALL apoc.merge.relationship(source, row.label, {}, row.properties, target) YIELD rel
+            RETURN count(*)
+            """,
+            param_map={"data": params},
+        )
+
+    def get(
+        self,
+        properties: Optional[dict] = None,
+        ids: Optional[List[str]] = None,
+    ) -> List[LabelledNode]:
+        """Get nodes."""
+        cypher_statement = "MATCH (e) "
+
+        params = {}
+        conditions = []
+        if ids:
+            conditions.append("e.id in $ids")
+            params["ids"] = ids
+
+        if properties:
+            for i, prop in enumerate(properties):
+                conditions.append(f"e.`{prop}` = $property_{i}")
+                params[f"property_{i}"] = properties[prop]
+
+        if conditions:
+            # Join filters so combined id/property filters stay valid Cypher
+            cypher_statement += "WHERE " + " AND ".join(conditions)
+
+        return_statement = """
+        WITH e
+        RETURN e.id AS name,
+               [l in labels(e) WHERE l <> '__Entity__' | l][0] AS type,
+               e{.* , embedding: Null, id: Null} AS properties
+        """
+        cypher_statement += return_statement
+
+        response = self.structured_query(cypher_statement, param_map=params)
+        response = response if response else []
+
+        nodes = []
+        for record in response:
+            # text indicates a chunk node
+            # none on the type indicates an implicit node, likely a chunk node
+            if "text" in record["properties"] or record["type"] is None:
+                text = record["properties"].pop("text", "")
+                nodes.append(
+                    ChunkNode(
+                        id_=record["name"],
+                        text=text,
+                        properties=remove_empty_values(record["properties"]),
+                    )
+                )
+            else:
+                nodes.append(
+                    EntityNode(
+                        name=record["name"],
+                        label=record["type"],
+                        properties=remove_empty_values(record["properties"]),
+                    )
+                )
+
+        return nodes
+
+    def get_triplets(
+        self,
+        entity_names: Optional[List[str]] = None,
+        relation_names: Optional[List[str]] = None,
+        properties: Optional[dict] = None,
+        ids: Optional[List[str]] = None,
+    ) -> List[Triplet]:
+        # TODO: handle ids of chunk nodes
+        cypher_statement = "MATCH (e:`__Entity__`) "
+
+        params = {}
+        conditions = []
+        if entity_names:
+            conditions.append("e.name in $entity_names")
+            params["entity_names"] = entity_names
+
+        if ids:
+            conditions.append("e.id in $ids")
+            params["ids"] = ids
+
+        if properties:
+            for i, prop in enumerate(properties):
+                conditions.append(f"e.`{prop}` = $property_{i}")
+                params[f"property_{i}"] = properties[prop]
+
+        if conditions:
+            # Join filters so combined filters stay valid Cypher
+            cypher_statement += "WHERE " + " AND ".join(conditions)
+
+        return_statement = f"""
+        WITH e
+        CALL {{
+            WITH e
+            MATCH (e)-[r{':`' + '`|`'.join(relation_names) + '`' if relation_names else ''}]->(t)
+            RETURN e.name AS source_id, [l in labels(e) WHERE l <> '__Entity__' | l][0] AS source_type,
+                   e{{.* , embedding: Null, name: Null}} AS source_properties,
+                   type(r) AS type,
+                   t.name AS target_id, [l in labels(t) WHERE l <> '__Entity__' | l][0] AS target_type,
+                   t{{.* , embedding: Null, name: Null}} AS target_properties
+            UNION ALL
+            WITH e
+            MATCH (e)<-[r{':`' + '`|`'.join(relation_names) + '`' if relation_names else ''}]-(t)
+            RETURN t.name AS source_id, [l in labels(t) WHERE l <> '__Entity__' | l][0] AS source_type,
+                   e{{.* , embedding: Null, name: Null}} AS source_properties,
+                   type(r) AS type,
+                   e.name AS target_id, [l in labels(e) WHERE l <> '__Entity__' | l][0] AS target_type,
+                   t{{.* , embedding: Null, name: Null}} AS target_properties
+        }}
+        RETURN source_id, source_type, type, target_id, target_type, source_properties, target_properties"""
+        cypher_statement += return_statement
+
+        data = self.structured_query(cypher_statement, param_map=params)
+        data = data if data else []
+
+        triples = []
+        for record in data:
+            source = EntityNode(
+                name=record["source_id"],
+                label=record["source_type"],
+                properties=remove_empty_values(record["source_properties"]),
+            )
+            target = EntityNode(
+                name=record["target_id"],
+                label=record["target_type"],
+                properties=remove_empty_values(record["target_properties"]),
+            )
+            rel = Relation(
+                source_id=record["source_id"],
+                target_id=record["target_id"],
+                label=record["type"],
+            )
+            triples.append([source, rel, target])
+        return triples
+
+    def get_rel_map(
+        self,
+        graph_nodes: List[LabelledNode],
+        depth: int = 2,
+        limit: int = 30,
+        ignore_rels: Optional[List[str]] = None,
+    ) -> List[Triplet]:
+        """Get depth-aware rel map."""
+        triples = []
+
+        ids = [node.id for node in graph_nodes]
+        # Needs some optimization
+        response = self.structured_query(
+            f"""
+            MATCH (e:`__Entity__`)
+            WHERE e.id in $ids
+            MATCH p=(e)-[r*1..{depth}]-(other)
+            WHERE ALL(rel in relationships(p) WHERE type(rel) <> 'MENTIONS')
+            UNWIND relationships(p) AS rel
+            WITH distinct rel
+            WITH startNode(rel) AS source,
+                type(rel) AS type,
+                endNode(rel) AS endNode
+            RETURN source.id AS source_id, [l in labels(source) WHERE l <> '__Entity__' | l][0] AS source_type,
+                    source{{.* , embedding: Null, id: Null}} AS source_properties,
+                    type,
+                    endNode.id AS target_id, [l in labels(endNode) WHERE l <> '__Entity__' | l][0] AS target_type,
+                    endNode{{.* , embedding: Null, id: Null}} AS target_properties
+            LIMIT toInteger($limit)
+            """,
+            param_map={"ids": ids, "limit": limit},
+        )
+        response = response if response else []
+
+        ignore_rels = ignore_rels or []
+        for record in response:
+            if record["type"] in ignore_rels:
+                continue
+
+            source = EntityNode(
+                name=record["source_id"],
+                label=record["source_type"],
+                properties=remove_empty_values(record["source_properties"]),
+            )
+            target = EntityNode(
+                name=record["target_id"],
+                label=record["target_type"],
+                properties=remove_empty_values(record["target_properties"]),
+            )
+            rel = Relation(
+                source_id=record["source_id"],
+                target_id=record["target_id"],
+                label=record["type"],
+            )
+            triples.append([source, rel, target])
+
+        return triples
+
+    def structured_query(
+        self, query: str, param_map: Optional[Dict[str, Any]] = None
+    ) -> Any:
+        param_map = param_map or {}
+
+        with self._driver.session(database=self._database) as session:
+            result = session.run(query, param_map)
+            full_result = [d.data() for d in result]
+
+        if self.sanitize_query_output:
+            return value_sanitize(full_result)
+
+        return full_result
+
+    def vector_query(
+        self, query: VectorStoreQuery, **kwargs: Any
+    ) -> Tuple[List[LabelledNode], List[float]]:
+        """Query the graph store with a vector store query."""
+        data = self.structured_query(
+            """MATCH (e:`__Entity__`)
+            WHERE e.embedding IS NOT NULL AND size(e.embedding) = $dimension
+            WITH e, vector.similarity.cosine(e.embedding, $embedding) AS score
+            ORDER BY score DESC LIMIT toInteger($limit)
+            RETURN e.id AS name,
+               [l in labels(e) WHERE l <> '__Entity__' | l][0] AS type,
+               e{.* , embedding: Null, name: Null, id: Null} AS properties,
+               score""",
+            param_map={
+                "embedding": query.query_embedding,
+                "dimension": len(query.query_embedding),
+                "limit": query.similarity_top_k,
+            },
+        )
+        data = data if data else []
+
+        nodes = []
+        scores = []
+        for record in data:
+            node = EntityNode(
+                name=record["name"],
+                label=record["type"],
+                properties=remove_empty_values(record["properties"]),
+            )
+            nodes.append(node)
+            scores.append(record["score"])
+
+        return (nodes, scores)
+
+    def delete(
+        self,
+        entity_names: Optional[List[str]] = None,
+        relation_names: Optional[List[str]] = None,
+        properties: Optional[dict] = None,
+        ids: Optional[List[str]] = None,
+    ) -> None:
+        """Delete matching data."""
+        if entity_names:
+            self.structured_query(
+                "MATCH (n) WHERE n.name IN $entity_names DETACH DELETE n",
+                param_map={"entity_names": entity_names},
+            )
+
+        if ids:
+            self.structured_query(
+                "MATCH (n) WHERE n.id IN $ids DETACH DELETE n",
+                param_map={"ids": ids},
+            )
+
+        if relation_names:
+            for rel in relation_names:
+                self.structured_query(f"MATCH ()-[r:`{rel}`]->() DELETE r")
+
+        if properties:
+            cypher = "MATCH (e) WHERE "
+            prop_list = []
+            params = {}
+            for i, prop in enumerate(properties):
+                prop_list.append(f"e.`{prop}` = $property_{i}")
+                params[f"property_{i}"] = properties[prop]
+            cypher += " AND ".join(prop_list)
+            self.structured_query(
+                cypher + " DETACH DELETE e", param_map=params
+            )
+
+    def _enhanced_schema_cypher(
+        self,
+        label_or_type: str,
+        properties: List[Dict[str, Any]],
+        exhaustive: bool,
+        is_relationship: bool = False,
+    ) -> str:
+        if is_relationship:
+            match_clause = f"MATCH ()-[n:`{label_or_type}`]->()"
+        else:
+            match_clause = f"MATCH (n:`{label_or_type}`)"
+
+        with_clauses = []
+        return_clauses = []
+        output_dict = {}
+        if exhaustive:
+            for prop in properties:
+                prop_name = prop["property"]
+                prop_type = prop["type"]
+                if prop_type == "STRING":
+                    with_clauses.append(
+                        f"collect(distinct substring(toString(n.`{prop_name}`), 0, 50)) "
+                        f"AS `{prop_name}_values`"
+                    )
+                    return_clauses.append(
+                        f"values:`{prop_name}_values`[..{DISTINCT_VALUE_LIMIT}],"
+                        f" distinct_count: size(`{prop_name}_values`)"
+                    )
+                elif prop_type in [
+                    "INTEGER",
+                    "FLOAT",
+                    "DATE",
+                    "DATE_TIME",
+                    "LOCAL_DATE_TIME",
+                ]:
+                    with_clauses.append(
+                        f"min(n.`{prop_name}`) AS `{prop_name}_min`"
+                    )
+                    with_clauses.append(
+                        f"max(n.`{prop_name}`) AS `{prop_name}_max`"
+                    )
+                    with_clauses.append(
+                        f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`"
+                    )
+                    return_clauses.append(
+                        f"min: toString(`{prop_name}_min`), "
+                        f"max: toString(`{prop_name}_max`), "
+                        f"distinct_count: `{prop_name}_distinct`"
+                    )
+                elif prop_type == "LIST":
+                    with_clauses.append(
+                        f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, "
+                        f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`"
+                    )
+                    return_clauses.append(
+                        f"min_size: `{prop_name}_size_min`, "
+                        f"max_size: `{prop_name}_size_max`"
+                    )
+                elif prop_type in ["BOOLEAN", "POINT", "DURATION"]:
+                    continue
+                output_dict[prop_name] = "{" + return_clauses.pop() + "}"
+        else:
+            # Just sample 5 random nodes
+            match_clause += " WITH n LIMIT 5"
+            for prop in properties:
+                prop_name = prop["property"]
+                prop_type = prop["type"]
+
+                # Check if indexed property, we can still do exhaustive
+                prop_index = [
+                    el
+                    for el in self.structured_schema["metadata"]["index"]
+                    if el["label"] == label_or_type
+                    and el["properties"] == [prop_name]
+                    and el["type"] == "RANGE"
+                ]
+                if prop_type == "STRING":
+                    if (
+                        prop_index
+                        and prop_index[0].get("size") > 0
+                        and prop_index[0].get("distinctValues")
+                        <= DISTINCT_VALUE_LIMIT
+                    ):
+                        distinct_values = self.structured_query(
+                            f"CALL apoc.schema.properties.distinct("
+                            f"'{label_or_type}', '{prop_name}') YIELD value"
+                        )[0]["value"]
+                        return_clauses.append(
+                            f"values: {distinct_values},"
+                            f" distinct_count: {len(distinct_values)}"
+                        )
+                    else:
+                        with_clauses.append(
+                            f"collect(distinct substring(n.`{prop_name}`, 0, 50)) "
+                            f"AS `{prop_name}_values`"
+                        )
+                        return_clauses.append(f"values: `{prop_name}_values`")
+                elif prop_type in [
+                    "INTEGER",
+                    "FLOAT",
+                    "DATE",
+                    "DATE_TIME",
+                    "LOCAL_DATE_TIME",
+                ]:
+                    if not prop_index:
+                        with_clauses.append(
+                            f"collect(distinct toString(n.`{prop_name}`)) "
+                            f"AS `{prop_name}_values`"
+                        )
+                        return_clauses.append(f"values: `{prop_name}_values`")
+                    else:
+                        with_clauses.append(
+                            f"min(n.`{prop_name}`) AS `{prop_name}_min`"
+                        )
+                        with_clauses.append(
+                            f"max(n.`{prop_name}`) AS `{prop_name}_max`"
+                        )
+                        with_clauses.append(
+                            f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`"
+                        )
+                        return_clauses.append(
+                            f"min: toString(`{prop_name}_min`), "
+                            f"max: toString(`{prop_name}_max`), "
+                            f"distinct_count: `{prop_name}_distinct`"
+                        )
+
+                elif prop_type == "LIST":
+                    with_clauses.append(
+                        f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, "
+                        f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`"
+                    )
+                    return_clauses.append(
+                        f"min_size: `{prop_name}_size_min`, "
+                        f"max_size: `{prop_name}_size_max`"
+                    )
+                elif prop_type in ["BOOLEAN", "POINT", "DURATION"]:
+                    continue
+
+                output_dict[prop_name] = "{" + return_clauses.pop() + "}"
+
+        with_clause = "WITH " + ",\n     ".join(with_clauses)
+        return_clause = (
+            "RETURN {"
+            + ", ".join(f"`{k}`: {v}" for k, v in output_dict.items())
+            + "} AS output"
+        )
+
+        # Combine all parts of the Cypher query
+        return f"{match_clause}\n{with_clause}\n{return_clause}"
+
+    def get_schema(self, refresh: bool = False) -> Any:
+        if refresh:
+            self.refresh_schema()
+
+        return self.structured_schema
+
+    def get_schema_str(self, refresh: bool = False) -> str:
+        schema = self.get_schema(refresh=refresh)
+
+        formatted_node_props = []
+        formatted_rel_props = []
+
+        if self.enhanced_schema:
+            # Enhanced formatting for nodes
+            for node_type, properties in schema["node_props"].items():
+                formatted_node_props.append(f"- **{node_type}**")
+                for prop in properties:
+                    example = ""
+                    if prop["type"] == "STRING" and prop.get("values"):
+                        if (
+                            prop.get("distinct_count", 11)
+                            > DISTINCT_VALUE_LIMIT
+                        ):
+                            example = (
+                                f'Example: "{clean_string_values(prop["values"][0])}"'
+                                if prop["values"]
+                                else ""
+                            )
+                        else:  # If less than 10 possible values return all
+                            example = (
+                                (
+                                    "Available options: "
+                                    f'{[clean_string_values(el) for el in prop["values"]]}'
+                                )
+                                if prop["values"]
+                                else ""
+                            )
+
+                    elif prop["type"] in [
+                        "INTEGER",
+                        "FLOAT",
+                        "DATE",
+                        "DATE_TIME",
+                        "LOCAL_DATE_TIME",
+                    ]:
+                        if prop.get("min") is not None:
+                            example = f'Min: {prop["min"]}, Max: {prop["max"]}'
+                        else:
+                            example = (
+                                f'Example: "{prop["values"][0]}"'
+                                if prop.get("values")
+                                else ""
+                            )
+                    elif prop["type"] == "LIST":
+                        # Skip embeddings
+                        if (
+                            not prop.get("min_size")
+                            or prop["min_size"] > LIST_LIMIT
+                        ):
+                            continue
+                        example = f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}'
+                    formatted_node_props.append(
+                        f"  - `{prop['property']}`: {prop['type']} {example}"
+                    )
+
+            # Enhanced formatting for relationships
+            for rel_type, properties in schema["rel_props"].items():
+                formatted_rel_props.append(f"- **{rel_type}**")
+                for prop in properties:
+                    example = ""
+                    if prop["type"] == "STRING":
+                        if (
+                            prop.get("distinct_count", 11)
+                            > DISTINCT_VALUE_LIMIT
+                        ):
+                            example = (
+                                f'Example: "{clean_string_values(prop["values"][0])}"'
+                                if prop.get("values")
+                                else ""
+                            )
+                        else:  # If there are fewer distinct values than the limit, return them all
+                            example = (
+                                (
+                                    "Available options: "
+                                    f'{[clean_string_values(el) for el in prop["values"]]}'
+                                )
+                                if prop.get("values")
+                                else ""
+                            )
+                    elif prop["type"] in [
+                        "INTEGER",
+                        "FLOAT",
+                        "DATE",
+                        "DATE_TIME",
+                        "LOCAL_DATE_TIME",
+                    ]:
+                        if prop.get("min") is not None:  # If we have min/max
+                            example = (
+                                f'Min: {prop["min"]}, Max: {prop["max"]}'
+                            )
+                        else:  # return a single value
+                            example = (
+                                f'Example: "{prop["values"][0]}"'
+                                if prop.get("values")
+                                else ""
+                            )
+                    elif prop["type"] == "LIST":
+                        # Skip embeddings
+                        if (
+                            not prop.get("min_size")
+                            or prop["min_size"] > LIST_LIMIT
+                        ):
+                            continue
+                        example = f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}'
+                    formatted_rel_props.append(
+                        f"  - `{prop['property']}`: {prop['type']} {example}"
+                    )
+        else:
+            # Format node properties
+            for label, props in schema["node_props"].items():
+                props_str = ", ".join(
+                    [f"{prop['property']}: {prop['type']}" for prop in props]
+                )
+                formatted_node_props.append(f"{label} {{{props_str}}}")
+
+            # Format relationship properties using structured_schema
+            for rel_type, props in schema["rel_props"].items():
+                props_str = ", ".join(
+                    [f"{prop['property']}: {prop['type']}" for prop in props]
+                )
+                formatted_rel_props.append(f"{rel_type} {{{props_str}}}")
+
+        # Format relationships
+        formatted_rels = [
+            f"(:{el['start']})-[:{el['type']}]->(:{el['end']})"
+            for el in schema["relationships"]
+        ]
+
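+        # The assembled string has three sections (labels and properties
+        # below are illustrative):
+        #   Node properties:
+        #   Person {name: STRING, age: INTEGER}
+        #   Relationship properties:
+        #   KNOWS {since: DATE}
+        #   The relationships:
+        #   (:Person)-[:KNOWS]->(:Person)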
+        return "\n".join(
+            [
+                "Node properties:",
+                "\n".join(formatted_node_props),
+                "Relationship properties:",
+                "\n".join(formatted_rel_props),
+                "The relationships:",
+                "\n".join(formatted_rels),
+            ]
+        )
+
+    def update_extraction_prompt(
+        self,
+        prompt_provider: PromptProvider,
+        entity_types: list[EntityType],
+        relations: list[Relation],
+    ):
+        # Fetch the kg extraction prompt with blank entity types and relations
+        # Note - Assumes that for a given prompt there is a `_with_spec` variant that can have entities + relations specified
+        few_shot_ner_kg_extraction_with_spec = prompt_provider.get_prompt(
+            f"{self.config.kg_extraction_prompt}_with_spec"
+        )
+
+        # Format the prompt to include the desired entity types and relations
+        few_shot_ner_kg_extraction = (
+            few_shot_ner_kg_extraction_with_spec.replace(
+                "{entity_types}", format_entity_types(entity_types)
+            ).replace("{relations}", format_relations(relations))
+        )
+
+        # Update the "few_shot_ner_kg_extraction" prompt used in downstream KG construction
+        prompt_provider.update_prompt(
+            self.config.kg_extraction_prompt,
+            json.dumps(few_shot_ner_kg_extraction, ensure_ascii=False),
+        )
+
+    def update_kg_agent_prompt(
+        self,
+        prompt_provider: PromptProvider,
+        entity_types: list[EntityType],
+        relations: list[Relation],
+    ):
+        # Fetch the kg agent prompt with blank entity types and relations
+        # Note - Assumes that for a given prompt there is a `_with_spec` variant that can have entities + relations specified
+        few_shot_ner_kg_extraction_with_spec = prompt_provider.get_prompt(
+            f"{self.config.kg_agent_prompt}_with_spec"
+        )
+
+        # Format the prompt to include the desired entity types and relations
+        few_shot_ner_kg_extraction = (
+            few_shot_ner_kg_extraction_with_spec.replace(
+                "{entity_types}",
+                format_entity_types(entity_types, ignore_subcats=True),
+            ).replace("{relations}", format_relations(relations))
+        )
+
+        # Update the kg agent prompt used downstream
+        prompt_provider.update_prompt(
+            self.config.kg_agent_prompt,
+            json.dumps(few_shot_ner_kg_extraction, ensure_ascii=False),
+        )
diff --git a/R2R/r2r/providers/llms/__init__.py b/R2R/r2r/providers/llms/__init__.py
new file mode 100755
index 00000000..38a1c54a
--- /dev/null
+++ b/R2R/r2r/providers/llms/__init__.py
@@ -0,0 +1,7 @@
+from .litellm.base_litellm import LiteLLM
+from .openai.base_openai import OpenAILLM
+
+__all__ = [
+    "LiteLLM",
+    "OpenAILLM",
+]
diff --git a/R2R/r2r/providers/llms/litellm/base_litellm.py b/R2R/r2r/providers/llms/litellm/base_litellm.py
new file mode 100755
index 00000000..581cce9a
--- /dev/null
+++ b/R2R/r2r/providers/llms/litellm/base_litellm.py
@@ -0,0 +1,142 @@
+import logging
+from typing import Any, Generator, Union
+
+from r2r.base import (
+    LLMChatCompletion,
+    LLMChatCompletionChunk,
+    LLMConfig,
+    LLMProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class LiteLLM(LLMProvider):
+    """A concrete class for creating LiteLLM models."""
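+
+    # Illustrative usage (constructor keyword arguments for `LLMConfig` and
+    # `GenerationConfig` are assumptions here; any LiteLLM-routable model
+    # string should work):
+    #   llm = LiteLLM(LLMConfig(provider="litellm"))
+    #   response = llm.get_completion(
+    #       [{"role": "user", "content": "Hello"}],
+    #       GenerationConfig(model="gpt-3.5-turbo", stream=False),
+    #   )
+    #   print(llm.extract_content(response))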
+
+    def __init__(
+        self,
+        config: LLMConfig,
+        *args,
+        **kwargs,
+    ) -> None:
+        try:
+            from litellm import acompletion, completion
+
+            self.litellm_completion = completion
+            self.litellm_acompletion = acompletion
+        except ImportError:
+            raise ImportError(
+                "Error, `litellm` is required to run LiteLLM. Please install it using `pip install litellm`."
+            )
+        super().__init__(config)
+
+    def get_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletion:
+        if generation_config.stream:
+            raise ValueError(
+                "Stream must be set to False to use the `get_completion` method."
+            )
+        return self._get_completion(messages, generation_config, **kwargs)
+
+    def get_completion_stream(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Generator[LLMChatCompletionChunk, None, None]:
+        if not generation_config.stream:
+            raise ValueError(
+                "Stream must be set to True to use the `get_completion_stream` method."
+            )
+        return self._get_completion(messages, generation_config, **kwargs)
+
+    def extract_content(self, response: LLMChatCompletion) -> str:
+        return response.choices[0].message.content
+
+    def _get_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Union[
+        LLMChatCompletion, Generator[LLMChatCompletionChunk, None, None]
+    ]:
+        # Create a dictionary with the default arguments
+        args = self._get_base_args(generation_config)
+        args["messages"] = messages
+
+        # Conditionally add the 'functions' argument if it's not None
+        if generation_config.functions is not None:
+            args["functions"] = generation_config.functions
+
+        args = {**args, **kwargs}
+        response = self.litellm_completion(**args)
+
+        if not generation_config.stream:
+            return LLMChatCompletion(**response.dict())
+        else:
+            return self._get_chat_completion(response)
+
+    def _get_chat_completion(
+        self,
+        response: Any,
+    ) -> Generator[LLMChatCompletionChunk, None, None]:
+        for part in response:
+            yield LLMChatCompletionChunk(**part.dict())
+
+    def _get_base_args(
+        self,
+        generation_config: GenerationConfig,
+    ) -> dict:
+        """Get the base arguments for the LiteLLM API."""
+        args = {
+            "model": generation_config.model,
+            "temperature": generation_config.temperature,
+            "top_p": generation_config.top_p,
+            "stream": generation_config.stream,
+            # TODO - We need to cap this to avoid potential errors when exceed max allowable context
+            "max_tokens": generation_config.max_tokens_to_sample,
+        }
+        return args
+
+    async def aget_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletion:
+        if generation_config.stream:
+            raise ValueError(
+                "Stream must be set to False to use the `aget_completion` method."
+            )
+        return await self._aget_completion(
+            messages, generation_config, **kwargs
+        )
+
+    async def _aget_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Union[LLMChatCompletion, LLMChatCompletionChunk]:
+        """Asynchronously get a completion via LiteLLM based on the provided messages."""
+
+        # Create a dictionary with the default arguments
+        args = self._get_base_args(generation_config)
+
+        args["messages"] = messages
+
+        # Conditionally add the 'functions' argument if it's not None
+        if generation_config.functions is not None:
+            args["functions"] = generation_config.functions
+
+        args = {**args, **kwargs}
+        # Create the chat completion
+        return await self.litellm_acompletion(**args)
diff --git a/R2R/r2r/providers/llms/openai/base_openai.py b/R2R/r2r/providers/llms/openai/base_openai.py
new file mode 100755
index 00000000..460c0f0b
--- /dev/null
+++ b/R2R/r2r/providers/llms/openai/base_openai.py
@@ -0,0 +1,144 @@
+"""A module for creating OpenAI model abstractions."""
+
+import logging
+import os
+from typing import Generator, Union
+
+from r2r.base import (
+    LLMChatCompletion,
+    LLMChatCompletionChunk,
+    LLMConfig,
+    LLMProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAILLM(LLMProvider):
+    """A concrete class for creating OpenAI models."""
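+
+    # Illustrative usage (requires OPENAI_API_KEY in the environment; the
+    # constructor keyword arguments shown are assumptions about the config
+    # models):
+    #   llm = OpenAILLM(LLMConfig(provider="openai"))
+    #   response = llm.get_completion(
+    #       [{"role": "user", "content": "Hello"}],
+    #       GenerationConfig(model="gpt-4", stream=False),
+    #   )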
+
+    def __init__(
+        self,
+        config: LLMConfig,
+        *args,
+        **kwargs,
+    ) -> None:
+        if not isinstance(config, LLMConfig):
+            raise ValueError(
+                "The provided config must be an instance of LLMConfig."
+            )
+        try:
+            from openai import AsyncOpenAI, OpenAI  # noqa
+        except ImportError:
+            raise ImportError(
+                "Error, `openai` is required to run an OpenAILLM. Please install it using `pip install openai`."
+            )
+        if config.provider != "openai":
+            raise ValueError(
+                "OpenAILLM must be initialized with config with `openai` provider."
+            )
+        if not os.getenv("OPENAI_API_KEY"):
+            raise ValueError(
+                "OpenAI API key not found. Please set the OPENAI_API_KEY environment variable."
+            )
+        super().__init__(config)
+        self.config: LLMConfig = config
+        self.client = OpenAI()
+        self.async_client = AsyncOpenAI()
+
+    def get_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletion:
+        if generation_config.stream:
+            raise ValueError(
+                "Stream must be set to False to use the `get_completion` method."
+            )
+        return self._get_completion(messages, generation_config, **kwargs)
+
+    def get_completion_stream(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Generator[LLMChatCompletionChunk, None, None]:
+        if not generation_config.stream:
+            raise ValueError(
+                "Stream must be set to True to use the `get_completion_stream` method."
+            )
+        return self._get_completion(messages, generation_config, **kwargs)
+
+    def _get_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Union[LLMChatCompletion, LLMChatCompletionChunk]:
+        """Get a completion from the OpenAI API based on the provided messages."""
+
+        # Create a dictionary with the default arguments
+        args = self._get_base_args(generation_config)
+
+        args["messages"] = messages
+
+        # Conditionally add the 'functions' argument if it's not None
+        if generation_config.functions is not None:
+            args["functions"] = generation_config.functions
+
+        args = {**args, **kwargs}
+        # Create the chat completion
+        return self.client.chat.completions.create(**args)
+
+    def _get_base_args(
+        self,
+        generation_config: GenerationConfig,
+    ) -> dict:
+        """Get the base arguments for the OpenAI API."""
+
+        args = {
+            "model": generation_config.model,
+            "temperature": generation_config.temperature,
+            "top_p": generation_config.top_p,
+            "stream": generation_config.stream,
+            # TODO - We need to cap this to avoid potential errors when exceed max allowable context
+            "max_tokens": generation_config.max_tokens_to_sample,
+        }
+
+        return args
+
+    async def aget_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletion:
+        if generation_config.stream:
+            raise ValueError(
+                "Stream must be set to False to use the `aget_completion` method."
+            )
+        return await self._aget_completion(
+            messages, generation_config, **kwargs
+        )
+
+    async def _aget_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Union[LLMChatCompletion, LLMChatCompletionChunk]:
+        """Asynchronously get a completion from the OpenAI API based on the provided messages."""
+
+        # Create a dictionary with the default arguments
+        args = self._get_base_args(generation_config)
+
+        args["messages"] = messages
+
+        # Conditionally add the 'functions' argument if it's not None
+        if generation_config.functions is not None:
+            args["functions"] = generation_config.functions
+
+        args = {**args, **kwargs}
+        # Create the chat completion
+        return await self.async_client.chat.completions.create(**args)
diff --git a/R2R/r2r/providers/vector_dbs/__init__.py b/R2R/r2r/providers/vector_dbs/__init__.py
new file mode 100755
index 00000000..38ea0890
--- /dev/null
+++ b/R2R/r2r/providers/vector_dbs/__init__.py
@@ -0,0 +1,5 @@
+from .pgvector.pgvector_db import PGVectorDB
+
+__all__ = [
+    "PGVectorDB",
+]
diff --git a/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py b/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py
new file mode 100755
index 00000000..8cf728d1
--- /dev/null
+++ b/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py
@@ -0,0 +1,610 @@
+import json
+import logging
+import os
+import time
+from typing import Literal, Optional, Union
+
+from sqlalchemy import exc, text
+from sqlalchemy.engine.url import make_url
+
+from r2r.base import (
+    DocumentInfo,
+    UserStats,
+    VectorDBConfig,
+    VectorDBProvider,
+    VectorEntry,
+    VectorSearchResult,
+)
+from r2r.vecs.client import Client
+from r2r.vecs.collection import Collection
+
+logger = logging.getLogger(__name__)
+
+
+class PGVectorDB(VectorDBProvider):
+    def __init__(self, config: VectorDBConfig) -> None:
+        super().__init__(config)
+        try:
+            import r2r.vecs
+        except ImportError:
+            raise ValueError(
+                "Error, PGVectorDB requires the vecs library. Please run `pip install vecs`."
+            )
+
+        # Check if a complete Postgres URI is provided
+        postgres_uri = self.config.extra_fields.get(
+            "postgres_uri"
+        ) or os.getenv("POSTGRES_URI")
+
+        if postgres_uri:
+            # Log loudly that Postgres URI is being used
+            logger.warning("=" * 50)
+            logger.warning(
+                "ATTENTION: Using provided Postgres URI for connection"
+            )
+            logger.warning("=" * 50)
+
+            # Validate and use the provided URI
+            try:
+                parsed_uri = make_url(postgres_uri)
+                if not all([parsed_uri.username, parsed_uri.database]):
+                    raise ValueError(
+                        "The provided Postgres URI is missing required components."
+                    )
+                DB_CONNECTION = postgres_uri
+
+                # Log the sanitized URI (without password)
+                sanitized_uri = parsed_uri.set(password="*****")
+                logger.info(f"Connecting using URI: {sanitized_uri}")
+            except Exception as e:
+                raise ValueError(f"Invalid Postgres URI provided: {e}")
+        else:
+            # Fall back to existing logic for individual connection parameters
+            user = self.config.extra_fields.get("user", None) or os.getenv(
+                "POSTGRES_USER"
+            )
+            password = self.config.extra_fields.get(
+                "password", None
+            ) or os.getenv("POSTGRES_PASSWORD")
+            host = self.config.extra_fields.get("host", None) or os.getenv(
+                "POSTGRES_HOST"
+            )
+            port = self.config.extra_fields.get("port", None) or os.getenv(
+                "POSTGRES_PORT"
+            )
+            db_name = self.config.extra_fields.get(
+                "db_name", None
+            ) or os.getenv("POSTGRES_DBNAME")
+
+            if not all([user, password, host, db_name]):
+                raise ValueError(
+                    "Error, please set the POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_HOST, POSTGRES_DBNAME environment variables or provide them in the config."
+                )
+
+            # Check if it's a Unix socket connection
+            if host.startswith("/") and not port:
+                DB_CONNECTION = (
+                    f"postgresql://{user}:{password}@/{db_name}?host={host}"
+                )
+                logger.info("Using Unix socket connection")
+            else:
+                DB_CONNECTION = (
+                    f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
+                )
+                logger.info("Using TCP connection")
+
+        # The rest of the initialization remains the same
+        try:
+            self.vx: Client = r2r.vecs.create_client(DB_CONNECTION)
+        except Exception as e:
+            raise ValueError(
+                f"Error {e} occurred while attempting to connect to the pgvector provider; connection string withheld to avoid logging credentials."
+            )
+
+        self.collection_name = self.config.extra_fields.get(
+            "vecs_collection"
+        ) or os.getenv("POSTGRES_VECS_COLLECTION")
+        if not self.collection_name:
+            raise ValueError(
+                "Error, please set a valid POSTGRES_VECS_COLLECTION environment variable or set a 'vecs_collection' in the 'vector_database' settings of your `config.json`."
+            )
+
+        self.collection: Optional[Collection] = None
+
+        logger.info(
+            f"Successfully initialized PGVectorDB with collection: {self.collection_name}"
+        )
+
+    def initialize_collection(self, dimension: int) -> None:
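+        # This must run once (with the embedding dimension) before any of the
+        # copy/upsert/search methods below, which raise if the collection is
+        # unset; it also provisions the documents-overview table and the
+        # hybrid-search SQL function.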
+        self.collection = self.vx.get_or_create_collection(
+            name=self.collection_name, dimension=dimension
+        )
+        self._create_document_info_table()
+        self._create_hybrid_search_function()
+
+    def _create_document_info_table(self):
+        with self.vx.Session() as sess:
+            with sess.begin():
+                try:
+                    # Enable uuid-ossp extension
+                    sess.execute(
+                        text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
+                    )
+                except exc.ProgrammingError as e:
+                    logger.error(f"Error enabling uuid-ossp extension: {e}")
+                    raise
+
+                # Create the table if it doesn't exist
+                create_table_query = f"""
+                CREATE TABLE IF NOT EXISTS "document_info_{self.collection_name}" (
+                    document_id UUID PRIMARY KEY,
+                    title TEXT,
+                    user_id UUID NULL,
+                    version TEXT,
+                    size_in_bytes INT,
+                    created_at TIMESTAMPTZ DEFAULT NOW(),
+                    updated_at TIMESTAMPTZ DEFAULT NOW(),
+                    metadata JSONB,
+                    status TEXT
+                );
+                """
+                sess.execute(text(create_table_query))
+
+                # Add the new column if it doesn't exist
+                add_column_query = f"""
+                DO $$
+                BEGIN
+                    IF NOT EXISTS (
+                        SELECT 1
+                        FROM information_schema.columns
+                        WHERE table_name = 'document_info_{self.collection_name}'
+                        AND column_name = 'status'
+                    ) THEN
+                        ALTER TABLE "document_info_{self.collection_name}"
+                        ADD COLUMN status TEXT DEFAULT 'processing';
+                    END IF;
+                END $$;
+                """
+                sess.execute(text(add_column_query))
+
+    def _create_hybrid_search_function(self):
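+        # Installs a Postgres function that fuses full-text and vector ranks
+        # via reciprocal rank fusion: score ~ sum(weight / (rrf_k + rank)).
+        # Note the parameter is declared as VECTOR(512), so the function as
+        # written assumes 512-dimensional embeddings.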
+        hybrid_search_function = f"""
+        CREATE OR REPLACE FUNCTION hybrid_search_{self.collection_name}(
+            query_text TEXT,
+            query_embedding VECTOR(512),
+            match_limit INT,
+            full_text_weight FLOAT = 1,
+            semantic_weight FLOAT = 1,
+            rrf_k INT = 50,
+            filter_condition JSONB = NULL
+        )
+        RETURNS SETOF vecs."{self.collection_name}"
+        LANGUAGE sql
+        AS $$
+        WITH full_text AS (
+            SELECT
+                id,
+                ROW_NUMBER() OVER (ORDER BY ts_rank(to_tsvector('english', metadata->>'text'), websearch_to_tsquery(query_text)) DESC) AS rank_ix
+            FROM vecs."{self.collection_name}"
+            WHERE to_tsvector('english', metadata->>'text') @@ websearch_to_tsquery(query_text)
+            AND (filter_condition IS NULL OR (metadata @> filter_condition))
+            ORDER BY rank_ix
+            LIMIT LEAST(match_limit, 30) * 2
+        ),
+        semantic AS (
+            SELECT
+                id,
+                ROW_NUMBER() OVER (ORDER BY vec <#> query_embedding) AS rank_ix
+            FROM vecs."{self.collection_name}"
+            WHERE filter_condition IS NULL OR (metadata @> filter_condition)
+            ORDER BY rank_ix
+            LIMIT LEAST(match_limit, 30) * 2
+        )
+        SELECT
+            vecs."{self.collection_name}".*
+        FROM
+            full_text
+            FULL OUTER JOIN semantic
+                ON full_text.id = semantic.id
+            JOIN vecs."{self.collection_name}"
+                ON vecs."{self.collection_name}".id = COALESCE(full_text.id, semantic.id)
+        ORDER BY
+            COALESCE(1.0 / (rrf_k + full_text.rank_ix), 0.0) * full_text_weight +
+            COALESCE(1.0 / (rrf_k + semantic.rank_ix), 0.0) * semantic_weight
+            DESC
+        LIMIT
+            LEAST(match_limit, 30);
+        $$;
+        """
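+        # CREATE OR REPLACE FUNCTION can race across concurrent workers
+        # ("tuple concurrently updated"), so creation is serialized with an
+        # advisory lock and retried with exponential backoff.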
+        retry_attempts = 5
+        for attempt in range(retry_attempts):
+            try:
+                with self.vx.Session() as sess:
+                    # Acquire an advisory lock
+                    sess.execute(text("SELECT pg_advisory_lock(123456789)"))
+                    try:
+                        sess.execute(text(hybrid_search_function))
+                        sess.commit()
+                    finally:
+                        # Release the advisory lock
+                        sess.execute(
+                            text("SELECT pg_advisory_unlock(123456789)")
+                        )
+                break  # Break the loop if successful
+            except exc.InternalError as e:
+                if "tuple concurrently updated" in str(e):
+                    time.sleep(2**attempt)  # Exponential backoff
+                else:
+                    raise  # Re-raise the exception if it's not a concurrency issue
+        else:
+            raise RuntimeError(
+                "Failed to create hybrid search function after multiple attempts"
+            )
+
+    def copy(self, entry: VectorEntry, commit=True) -> None:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `copy`."
+            )
+
+        serializable_entry = entry.to_serializable()
+
+        self.collection.copy(
+            records=[
+                (
+                    serializable_entry["id"],
+                    serializable_entry["vector"],
+                    serializable_entry["metadata"],
+                )
+            ]
+        )
+
+    def copy_entries(
+        self, entries: list[VectorEntry], commit: bool = True
+    ) -> None:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `copy_entries`."
+            )
+
+        self.collection.copy(
+            records=[
+                (
+                    str(entry.id),
+                    entry.vector.data,
+                    entry.to_serializable()["metadata"],
+                )
+                for entry in entries
+            ]
+        )
+
+    def upsert(self, entry: VectorEntry, commit=True) -> None:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `upsert`."
+            )
+
+        self.collection.upsert(
+            records=[
+                (
+                    str(entry.id),
+                    entry.vector.data,
+                    entry.to_serializable()["metadata"],
+                )
+            ]
+        )
+
+    def upsert_entries(
+        self, entries: list[VectorEntry], commit: bool = True
+    ) -> None:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `upsert_entries`."
+            )
+
+        self.collection.upsert(
+            records=[
+                (
+                    str(entry.id),
+                    entry.vector.data,
+                    entry.to_serializable()["metadata"],
+                )
+                for entry in entries
+            ]
+        )
+
+    def search(
+        self,
+        query_vector: list[float],
+        filters: dict[str, Union[bool, int, str]] = {},
+        limit: int = 10,
+        *args,
+        **kwargs,
+    ) -> list[VectorSearchResult]:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `search`."
+            )
+        measure = kwargs.get("measure", "cosine_distance")
+        mapped_filters = {
+            key: {"$eq": value} for key, value in filters.items()
+        }
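+        # vecs returns (id, distance, metadata) tuples; the score below is
+        # derived as 1 - distance, a similarity-style score for the default
+        # cosine_distance measure.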
+
+        return [
+            VectorSearchResult(id=ele[0], score=float(1 - ele[1]), metadata=ele[2])  # type: ignore
+            for ele in self.collection.query(
+                data=query_vector,
+                limit=limit,
+                filters=mapped_filters,
+                measure=measure,
+                include_value=True,
+                include_metadata=True,
+            )
+        ]
+
+    def hybrid_search(
+        self,
+        query_text: str,
+        query_vector: list[float],
+        limit: int = 10,
+        filters: Optional[dict[str, Union[bool, int, str]]] = None,
+        # Hybrid search parameters
+        full_text_weight: float = 1.0,
+        semantic_weight: float = 1.0,
+        rrf_k: int = 20,  # typical value is ~2x the number of results you want
+        *args,
+        **kwargs,
+    ) -> list[VectorSearchResult]:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `hybrid_search`."
+            )
+
+        # Convert filters to a JSON-compatible format
+        filter_condition = None
+        if filters:
+            filter_condition = json.dumps(filters)
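+        # e.g. {"user_id": "abc"} is serialized here and applied inside the
+        # SQL function via the JSONB containment operator (metadata @> filter).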
+
+        query = text(
+            f"""
+            SELECT * FROM hybrid_search_{self.collection_name}(
+                cast(:query_text as TEXT), cast(:query_embedding as VECTOR), cast(:match_limit as INT),
+                cast(:full_text_weight as FLOAT), cast(:semantic_weight as FLOAT), cast(:rrf_k as INT),
+                cast(:filter_condition as JSONB)
+            )
+        """
+        )
+
+        params = {
+            "query_text": str(query_text),
+            "query_embedding": list(query_vector),
+            "match_limit": limit,
+            "full_text_weight": full_text_weight,
+            "semantic_weight": semantic_weight,
+            "rrf_k": rrf_k,
+            "filter_condition": filter_condition,
+        }
+
+        with self.vx.Session() as session:
+            result = session.execute(query, params).fetchall()
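+        # The SQL function returns rows already ordered by fused rank but
+        # exposes no numeric score, so a placeholder score of 1.0 is attached.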
+        return [
+            VectorSearchResult(id=row[0], score=1.0, metadata=row[-1])
+            for row in result
+        ]
+
+    def create_index(self, index_type, column_name, index_options):
+        pass
+
+    def delete_by_metadata(
+        self,
+        metadata_fields: list[str],
+        metadata_values: list[Union[bool, int, str]],
+        logic: Literal["AND", "OR"] = "AND",
+    ) -> list[str]:
+        if logic == "OR":
+            raise ValueError(
+                "OR logic is still being tested before official support for `delete_by_metadata` in pgvector."
+            )
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `delete_by_metadata`."
+            )
+
+        if len(metadata_fields) != len(metadata_values):
+            raise ValueError(
+                "The number of metadata fields must match the number of metadata values."
+            )
+
+        # Construct the filter
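+        # e.g. fields=["user_id"], values=["abc"] with AND logic yields
+        #   {"user_id": {"$eq": "abc"}}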
+        if logic == "AND":
+            filters = {
+                k: {"$eq": v} for k, v in zip(metadata_fields, metadata_values)
+            }
+        else:  # OR logic
+            # TODO - Test 'or' logic and remove check above
+            filters = {
+                "$or": [
+                    {k: {"$eq": v}}
+                    for k, v in zip(metadata_fields, metadata_values)
+                ]
+            }
+        return self.collection.delete(filters=filters)
+
+    def get_metadatas(
+        self,
+        metadata_fields: list[str],
+        filter_field: Optional[str] = None,
+        filter_value: Optional[Union[bool, int, str]] = None,
+    ) -> list[dict]:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `get_metadatas`."
+            )
+
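+        # Seed the dict with a sentinel key (the full field tuple) so it can
+        # be excluded from the returned rows below.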
+        results = {tuple(metadata_fields): {}}
+        for field in metadata_fields:
+            unique_values = self.collection.get_unique_metadata_values(
+                field=field,
+                filter_field=filter_field,
+                filter_value=filter_value,
+            )
+            for value in unique_values:
+                if value not in results:
+                    results[value] = {}
+                results[value][field] = value
+
+        return [
+            results[key] for key in results if key != tuple(metadata_fields)
+        ]
+
+    def upsert_documents_overview(
+        self, documents_overview: list[DocumentInfo]
+    ) -> None:
+        for document_info in documents_overview:
+            db_entry = document_info.convert_to_db_entry()
+
+            # Convert 'None' string to None type for user_id
+            if db_entry["user_id"] == "None":
+                db_entry["user_id"] = None
+
+            query = text(
+                f"""
+                INSERT INTO "document_info_{self.collection_name}" (document_id, title, user_id, version, created_at, updated_at, size_in_bytes, metadata, status)
+                VALUES (:document_id, :title, :user_id, :version, :created_at, :updated_at, :size_in_bytes, :metadata, :status)
+                ON CONFLICT (document_id) DO UPDATE SET
+                    title = EXCLUDED.title,
+                    user_id = EXCLUDED.user_id,
+                    version = EXCLUDED.version,
+                    updated_at = EXCLUDED.updated_at,
+                    size_in_bytes = EXCLUDED.size_in_bytes,
+                    metadata = EXCLUDED.metadata,
+                    status = EXCLUDED.status;
+            """
+            )
+            with self.vx.Session() as sess:
+                sess.execute(query, db_entry)
+                sess.commit()
+
+    def delete_from_documents_overview(
+        self, document_id: str, version: Optional[str] = None
+    ) -> None:
+        query = f"""
+            DELETE FROM "document_info_{self.collection_name}"
+            WHERE document_id = :document_id
+        """
+        params = {"document_id": document_id}
+
+        if version is not None:
+            query += " AND version = :version"
+            params["version"] = version
+
+        with self.vx.Session() as sess:
+            with sess.begin():
+                sess.execute(text(query), params)
+
+    def get_documents_overview(
+        self,
+        filter_document_ids: Optional[list[str]] = None,
+        filter_user_ids: Optional[list[str]] = None,
+    ):
+        conditions = []
+        params = {}
+
+        if filter_document_ids:
+            placeholders = ", ".join(
+                f":doc_id_{i}" for i in range(len(filter_document_ids))
+            )
+            conditions.append(f"document_id IN ({placeholders})")
+            params.update(
+                {
+                    f"doc_id_{i}": str(document_id)
+                    for i, document_id in enumerate(filter_document_ids)
+                }
+            )
+        if filter_user_ids:
+            placeholders = ", ".join(
+                f":user_id_{i}" for i in range(len(filter_user_ids))
+            )
+            conditions.append(f"user_id IN ({placeholders})")
+            params.update(
+                {
+                    f"user_id_{i}": str(user_id)
+                    for i, user_id in enumerate(filter_user_ids)
+                }
+            )
+
+        query = f"""
+            SELECT document_id, title, user_id, version, size_in_bytes, created_at, updated_at, metadata, status
+            FROM "document_info_{self.collection_name}"
+        """
+        if conditions:
+            query += " WHERE " + " AND ".join(conditions)
+
+        with self.vx.Session() as sess:
+            results = sess.execute(text(query), params).fetchall()
+            return [
+                DocumentInfo(
+                    document_id=row[0],
+                    title=row[1],
+                    user_id=row[2],
+                    version=row[3],
+                    size_in_bytes=row[4],
+                    created_at=row[5],
+                    updated_at=row[6],
+                    metadata=row[7],
+                    status=row[8],
+                )
+                for row in results
+            ]
+
+    def get_document_chunks(self, document_id: str) -> list[dict]:
+        if not self.collection:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `get_document_chunks`."
+            )
+
+        table_name = self.collection.table.name
+        query = text(
+            f"""
+            SELECT metadata
+            FROM vecs."{table_name}"
+            WHERE metadata->>'document_id' = :document_id
+            ORDER BY CAST(metadata->>'chunk_order' AS INTEGER)
+        """
+        )
+
+        params = {"document_id": document_id}
+
+        with self.vx.Session() as sess:
+            results = sess.execute(query, params).fetchall()
+            return [result[0] for result in results]
+
+    def get_users_overview(self, user_ids: Optional[list[str]] = None):
+        user_ids_condition = ""
+        params = {}
+        if user_ids:
+            user_ids_condition = "WHERE user_id IN :user_ids"
+            params["user_ids"] = tuple(
+                map(str, user_ids)
+            )  # Convert UUIDs to strings
+
+        query = f"""
+            SELECT user_id, COUNT(document_id) AS num_files, SUM(size_in_bytes) AS total_size_in_bytes, ARRAY_AGG(document_id) AS document_ids
+            FROM "document_info_{self.collection_name}"
+            {user_ids_condition}
+            GROUP BY user_id
+        """
+
+        with self.vx.Session() as sess:
+            results = sess.execute(text(query), params).fetchall()
+        return [
+            UserStats(
+                user_id=row[0],
+                num_files=row[1],
+                total_size_in_bytes=row[2],
+                document_ids=row[3],
+            )
+            for row in results
+            if row[0] is not None
+        ]