author    S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/providers/kg/neo4j
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)

    two versions of R2R are here
Diffstat (limited to 'R2R/r2r/providers/kg/neo4j')
-rwxr-xr-x  R2R/r2r/providers/kg/neo4j/base_neo4j.py  983
1 file changed, 983 insertions, 0 deletions
diff --git a/R2R/r2r/providers/kg/neo4j/base_neo4j.py b/R2R/r2r/providers/kg/neo4j/base_neo4j.py
new file mode 100755
index 00000000..9ede2b85
--- /dev/null
+++ b/R2R/r2r/providers/kg/neo4j/base_neo4j.py
@@ -0,0 +1,983 @@
+# abstractions are taken from LlamaIndex
+# Neo4jKGProvider is almost entirely taken from LlamaIndex Neo4jPropertyGraphStore
+# https://github.com/run-llama/llama_index
+import json
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+from r2r.base import (
+ EntityType,
+ KGConfig,
+ KGProvider,
+ PromptProvider,
+ format_entity_types,
+ format_relations,
+)
+from r2r.base.abstractions.llama_abstractions import (
+ LIST_LIMIT,
+ ChunkNode,
+ EntityNode,
+ LabelledNode,
+ PropertyGraphStore,
+ Relation,
+ Triplet,
+ VectorStoreQuery,
+ clean_string_values,
+ value_sanitize,
+)
+
+
+def remove_empty_values(input_dict):
+ """
+ Remove entries with empty values from the dictionary.
+
+ Parameters:
+ input_dict (dict): The dictionary from which empty values need to be removed.
+
+ Returns:
+ dict: A new dictionary with all empty values removed.
+ """
+ # Create a new dictionary excluding empty values
+ return {key: value for key, value in input_dict.items() if value}
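+# Illustrative check (hypothetical input): every falsy value is dropped,
+# including 0, "", and empty collections.
+#   >>> remove_empty_values({"name": "Alice", "age": 0, "tags": []})
+#   {'name': 'Alice'}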
+
+
+BASE_ENTITY_LABEL = "__Entity__"
+EXCLUDED_LABELS = ["_Bloom_Perspective_", "_Bloom_Scene_"]
+EXCLUDED_RELS = ["_Bloom_HAS_SCENE_"]
+EXHAUSTIVE_SEARCH_LIMIT = 10000
+# Threshold for returning all available prop values in graph schema
+DISTINCT_VALUE_LIMIT = 10
+
+node_properties_query = """
+CALL apoc.meta.data()
+YIELD label, other, elementType, type, property
+WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
+ AND NOT label IN $EXCLUDED_LABELS
+WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
+RETURN {labels: nodeLabels, properties: properties} AS output
+"""
+
+rel_properties_query = """
+CALL apoc.meta.data()
+YIELD label, other, elementType, type, property
+WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship"
+ AND NOT label in $EXCLUDED_LABELS
+WITH label AS relType, collect({property:property, type:type}) AS properties
+RETURN {type: relType, properties: properties} AS output
+"""
+
+rel_query = """
+CALL apoc.meta.data()
+YIELD label, other, elementType, type, property
+WHERE type = "RELATIONSHIP" AND elementType = "node"
+UNWIND other AS other_node
+WITH * WHERE NOT label IN $EXCLUDED_LABELS
+ AND NOT other_node IN $EXCLUDED_LABELS
+RETURN {start: label, type: property, end: toString(other_node)} AS output
+"""
+
+
+class Neo4jKGProvider(PropertyGraphStore, KGProvider):
+ r"""
+ Neo4j Property Graph Store.
+
+ This class implements a Neo4j property graph store.
+
+    If you are using a local Neo4j instance instead of Aura, here's a helpful
+ command for launching the docker container:
+
+ ```bash
+ docker run \
+ -p 7474:7474 -p 7687:7687 \
+ -v $PWD/data:/data -v $PWD/plugins:/plugins \
+ --name neo4j-apoc \
+ -e NEO4J_apoc_export_file_enabled=true \
+ -e NEO4J_apoc_import_file_enabled=true \
+ -e NEO4J_apoc_import_file_use__neo4j__config=true \
+        -e NEO4JLABS_PLUGINS='["apoc"]' \
+ neo4j:latest
+ ```
+
+    Args:
+        config (KGConfig): Provider configuration; `config.provider` must be "neo4j".
+        refresh_schema (bool): Whether to refresh the schema on initialization. Defaults to True.
+        sanitize_query_output (bool): Whether to sanitize query results. Defaults to True.
+        enhanced_schema (bool): Whether to collect enhanced schema statistics. Defaults to False.
+
+    Connection settings are read from the NEO4J_USER, NEO4J_PASSWORD, NEO4J_URL,
+    and NEO4J_DATABASE environment variables; NEO4J_DATABASE defaults to "neo4j".
+
+    Examples:
+        `pip install neo4j`
+
+        ```python
+        from r2r.base import KGConfig
+
+        # Assumes NEO4J_USER, NEO4J_PASSWORD, and NEO4J_URL are set in the
+        # environment; the exact KGConfig construction below is illustrative.
+        graph_store = Neo4jKGProvider(config=KGConfig(provider="neo4j"))
+
+        # Inspect the current graph schema
+        print(graph_store.get_schema_str())
+        ```
+ """
+
+ supports_structured_queries: bool = True
+ supports_vector_queries: bool = True
+
+ def __init__(
+ self,
+ config: KGConfig,
+ refresh_schema: bool = True,
+ sanitize_query_output: bool = True,
+ enhanced_schema: bool = False,
+ *args: Any,
+ **kwargs: Any,
+ ) -> None:
+ if config.provider != "neo4j":
+ raise ValueError(
+ "Neo4jKGProvider must be initialized with config with `neo4j` provider."
+ )
+
+ try:
+ import neo4j
+ except ImportError:
+ raise ImportError("Please install neo4j: pip install neo4j")
+
+ username = os.getenv("NEO4J_USER")
+ password = os.getenv("NEO4J_PASSWORD")
+ url = os.getenv("NEO4J_URL")
+ database = os.getenv("NEO4J_DATABASE", "neo4j")
+
+ if not username or not password or not url:
+ raise ValueError(
+ "Neo4j configuration values are missing. Please set NEO4J_USER, NEO4J_PASSWORD, and NEO4J_URL environment variables."
+ )
+
+ self.sanitize_query_output = sanitize_query_output
+        self.enhanced_schema = enhanced_schema
+ self._driver = neo4j.GraphDatabase.driver(
+ url, auth=(username, password), **kwargs
+ )
+ self._async_driver = neo4j.AsyncGraphDatabase.driver(
+ url,
+ auth=(username, password),
+ **kwargs,
+ )
+        self._database = database
+        self.structured_schema = {}
+        # Assign these before refreshing the schema: refresh_schema references
+        # self.neo4j.exceptions.ClientError in its error handling.
+        self.neo4j = neo4j
+        self.config = config
+        if refresh_schema:
+            self.refresh_schema()
+
+ @property
+ def client(self):
+ return self._driver
+
+ def refresh_schema(self) -> None:
+ """Refresh the schema."""
+ node_query_results = self.structured_query(
+ node_properties_query,
+ param_map={
+ "EXCLUDED_LABELS": [*EXCLUDED_LABELS, BASE_ENTITY_LABEL]
+ },
+ )
+ node_properties = (
+ [el["output"] for el in node_query_results]
+ if node_query_results
+ else []
+ )
+
+ rels_query_result = self.structured_query(
+ rel_properties_query, param_map={"EXCLUDED_LABELS": EXCLUDED_RELS}
+ )
+ rel_properties = (
+ [el["output"] for el in rels_query_result]
+ if rels_query_result
+ else []
+ )
+
+ rel_objs_query_result = self.structured_query(
+ rel_query,
+ param_map={
+ "EXCLUDED_LABELS": [*EXCLUDED_LABELS, BASE_ENTITY_LABEL]
+ },
+ )
+ relationships = (
+ [el["output"] for el in rel_objs_query_result]
+ if rel_objs_query_result
+ else []
+ )
+
+ # Get constraints & indexes
+ try:
+ constraint = self.structured_query("SHOW CONSTRAINTS")
+ index = self.structured_query(
+ "CALL apoc.schema.nodes() YIELD label, properties, type, size, "
+ "valuesSelectivity WHERE type = 'RANGE' RETURN *, "
+ "size * valuesSelectivity as distinctValues"
+ )
+ except (
+ self.neo4j.exceptions.ClientError
+ ): # Read-only user might not have access to schema information
+ constraint = []
+ index = []
+
+ self.structured_schema = {
+ "node_props": {
+ el["labels"]: el["properties"] for el in node_properties
+ },
+ "rel_props": {
+ el["type"]: el["properties"] for el in rel_properties
+ },
+ "relationships": relationships,
+ "metadata": {"constraint": constraint, "index": index},
+ }
+ schema_counts = self.structured_query(
+ "CALL apoc.meta.graphSample() YIELD nodes, relationships "
+ "RETURN nodes, [rel in relationships | {name:apoc.any.property"
+ "(rel, 'type'), count: apoc.any.property(rel, 'count')}]"
+ " AS relationships"
+ )
+ # Update node info
+ for node in schema_counts[0].get("nodes", []):
+ # Skip bloom labels
+ if node["name"] in EXCLUDED_LABELS:
+ continue
+ node_props = self.structured_schema["node_props"].get(node["name"])
+ if not node_props: # The node has no properties
+ continue
+ enhanced_cypher = self._enhanced_schema_cypher(
+ node["name"],
+ node_props,
+ node["count"] < EXHAUSTIVE_SEARCH_LIMIT,
+ )
+ enhanced_info = self.structured_query(enhanced_cypher)[0]["output"]
+ for prop in node_props:
+ if prop["property"] in enhanced_info:
+ prop.update(enhanced_info[prop["property"]])
+ # Update rel info
+ for rel in schema_counts[0].get("relationships", []):
+ # Skip bloom labels
+ if rel["name"] in EXCLUDED_RELS:
+ continue
+ rel_props = self.structured_schema["rel_props"].get(rel["name"])
+ if not rel_props: # The rel has no properties
+ continue
+ enhanced_cypher = self._enhanced_schema_cypher(
+ rel["name"],
+ rel_props,
+ rel["count"] < EXHAUSTIVE_SEARCH_LIMIT,
+ is_relationship=True,
+ )
+ try:
+ enhanced_info = self.structured_query(enhanced_cypher)[0][
+ "output"
+ ]
+ for prop in rel_props:
+ if prop["property"] in enhanced_info:
+ prop.update(enhanced_info[prop["property"]])
+ except self.neo4j.exceptions.ClientError:
+ # Sometimes the types are not consistent in the db
+ pass
+
+ def upsert_nodes(self, nodes: List[LabelledNode]) -> None:
+ # Lists to hold separated types
+ entity_dicts: List[dict] = []
+ chunk_dicts: List[dict] = []
+
+ # Sort by type
+ for item in nodes:
+ if isinstance(item, EntityNode):
+ entity_dicts.append({**item.dict(), "id": item.id})
+ elif isinstance(item, ChunkNode):
+ chunk_dicts.append({**item.dict(), "id": item.id})
+ else:
+ # Log that we do not support these types of nodes
+ # Or raise an error?
+ pass
+
+ if chunk_dicts:
+ self.structured_query(
+ """
+ UNWIND $data AS row
+ MERGE (c:Chunk {id: row.id})
+ SET c.text = row.text
+ WITH c, row
+ SET c += row.properties
+ WITH c, row.embedding AS embedding
+ WHERE embedding IS NOT NULL
+ CALL db.create.setNodeVectorProperty(c, 'embedding', embedding)
+ RETURN count(*)
+ """,
+ param_map={"data": chunk_dicts},
+ )
+
+ if entity_dicts:
+ self.structured_query(
+ """
+ UNWIND $data AS row
+ MERGE (e:`__Entity__` {id: row.id})
+ SET e += apoc.map.clean(row.properties, [], [])
+ SET e.name = row.name
+ WITH e, row
+ CALL apoc.create.addLabels(e, [row.label])
+ YIELD node
+ WITH e, row
+ CALL {
+ WITH e, row
+ WITH e, row
+ WHERE row.embedding IS NOT NULL
+ CALL db.create.setNodeVectorProperty(e, 'embedding', row.embedding)
+ RETURN count(*) AS count
+ }
+ WITH e, row WHERE row.properties.triplet_source_id IS NOT NULL
+ MERGE (c:Chunk {id: row.properties.triplet_source_id})
+ MERGE (e)<-[:MENTIONS]-(c)
+ """,
+ param_map={"data": entity_dicts},
+ )
+
+ def upsert_relations(self, relations: List[Relation]) -> None:
+ """Add relations."""
+ params = [r.dict() for r in relations]
+
+ self.structured_query(
+ """
+ UNWIND $data AS row
+ MERGE (source {id: row.source_id})
+ MERGE (target {id: row.target_id})
+ WITH source, target, row
+ CALL apoc.merge.relationship(source, row.label, {}, row.properties, target) YIELD rel
+ RETURN count(*)
+ """,
+ param_map={"data": params},
+ )
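+
+    # Sketch of calling upsert_relations (hypothetical IDs), using the Relation
+    # abstraction imported above:
+    #   provider.upsert_relations(
+    #       [Relation(source_id="alice", target_id="acme", label="WORKS_AT")]
+    #   )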
+
+ def get(
+ self,
+ properties: Optional[dict] = None,
+ ids: Optional[List[str]] = None,
+ ) -> List[LabelledNode]:
+ """Get nodes."""
+ cypher_statement = "MATCH (e) "
+
+ params = {}
+ if properties or ids:
+ cypher_statement += "WHERE "
+
+ if ids:
+ cypher_statement += "e.id in $ids "
+ params["ids"] = ids
+
+        if properties:
+            prop_list = []
+            for i, prop in enumerate(properties):
+                prop_list.append(f"e.`{prop}` = $property_{i}")
+                params[f"property_{i}"] = properties[prop]
+            # Join with the preceding id filter, if any
+            if ids:
+                cypher_statement += "AND "
+            cypher_statement += " AND ".join(prop_list)
+
+ return_statement = """
+ WITH e
+ RETURN e.id AS name,
+ [l in labels(e) WHERE l <> '__Entity__' | l][0] AS type,
+ e{.* , embedding: Null, id: Null} AS properties
+ """
+ cypher_statement += return_statement
+
+ response = self.structured_query(cypher_statement, param_map=params)
+ response = response if response else []
+
+ nodes = []
+ for record in response:
+            # A `text` property indicates a chunk node; a missing type
+            # indicates an implicit node, likely also a chunk node.
+ if "text" in record["properties"] or record["type"] is None:
+ text = record["properties"].pop("text", "")
+ nodes.append(
+ ChunkNode(
+ id_=record["name"],
+ text=text,
+ properties=remove_empty_values(record["properties"]),
+ )
+ )
+ else:
+ nodes.append(
+ EntityNode(
+ name=record["name"],
+ label=record["type"],
+ properties=remove_empty_values(record["properties"]),
+ )
+ )
+
+ return nodes
+
+ def get_triplets(
+ self,
+ entity_names: Optional[List[str]] = None,
+ relation_names: Optional[List[str]] = None,
+ properties: Optional[dict] = None,
+ ids: Optional[List[str]] = None,
+ ) -> List[Triplet]:
+ # TODO: handle ids of chunk nodes
+ cypher_statement = "MATCH (e:`__Entity__`) "
+
+ params = {}
+ if entity_names or properties or ids:
+ cypher_statement += "WHERE "
+
+ if entity_names:
+ cypher_statement += "e.name in $entity_names "
+ params["entity_names"] = entity_names
+
+        if ids:
+            # Join with the preceding name filter, if any
+            if entity_names:
+                cypher_statement += "AND "
+            cypher_statement += "e.id in $ids "
+            params["ids"] = ids
+
+        if properties:
+            prop_list = []
+            for i, prop in enumerate(properties):
+                prop_list.append(f"e.`{prop}` = $property_{i}")
+                params[f"property_{i}"] = properties[prop]
+            if entity_names or ids:
+                cypher_statement += "AND "
+            cypher_statement += " AND ".join(prop_list)
+
+ return_statement = f"""
+ WITH e
+ CALL {{
+ WITH e
+ MATCH (e)-[r{':`' + '`|`'.join(relation_names) + '`' if relation_names else ''}]->(t)
+ RETURN e.name AS source_id, [l in labels(e) WHERE l <> '__Entity__' | l][0] AS source_type,
+ e{{.* , embedding: Null, name: Null}} AS source_properties,
+ type(r) AS type,
+ t.name AS target_id, [l in labels(t) WHERE l <> '__Entity__' | l][0] AS target_type,
+ t{{.* , embedding: Null, name: Null}} AS target_properties
+ UNION ALL
+ WITH e
+ MATCH (e)<-[r{':`' + '`|`'.join(relation_names) + '`' if relation_names else ''}]-(t)
+ RETURN t.name AS source_id, [l in labels(t) WHERE l <> '__Entity__' | l][0] AS source_type,
+ e{{.* , embedding: Null, name: Null}} AS source_properties,
+ type(r) AS type,
+ e.name AS target_id, [l in labels(e) WHERE l <> '__Entity__' | l][0] AS target_type,
+ t{{.* , embedding: Null, name: Null}} AS target_properties
+ }}
+ RETURN source_id, source_type, type, target_id, target_type, source_properties, target_properties"""
+ cypher_statement += return_statement
+
+ data = self.structured_query(cypher_statement, param_map=params)
+ data = data if data else []
+
+ triples = []
+ for record in data:
+ source = EntityNode(
+ name=record["source_id"],
+ label=record["source_type"],
+ properties=remove_empty_values(record["source_properties"]),
+ )
+ target = EntityNode(
+ name=record["target_id"],
+ label=record["target_type"],
+ properties=remove_empty_values(record["target_properties"]),
+ )
+ rel = Relation(
+ source_id=record["source_id"],
+ target_id=record["target_id"],
+ label=record["type"],
+ )
+ triples.append([source, rel, target])
+ return triples
+
+ def get_rel_map(
+ self,
+ graph_nodes: List[LabelledNode],
+ depth: int = 2,
+ limit: int = 30,
+ ignore_rels: Optional[List[str]] = None,
+ ) -> List[Triplet]:
+ """Get depth-aware rel map."""
+ triples = []
+
+ ids = [node.id for node in graph_nodes]
+ # Needs some optimization
+ response = self.structured_query(
+ f"""
+ MATCH (e:`__Entity__`)
+ WHERE e.id in $ids
+ MATCH p=(e)-[r*1..{depth}]-(other)
+ WHERE ALL(rel in relationships(p) WHERE type(rel) <> 'MENTIONS')
+ UNWIND relationships(p) AS rel
+ WITH distinct rel
+ WITH startNode(rel) AS source,
+ type(rel) AS type,
+ endNode(rel) AS endNode
+ RETURN source.id AS source_id, [l in labels(source) WHERE l <> '__Entity__' | l][0] AS source_type,
+ source{{.* , embedding: Null, id: Null}} AS source_properties,
+ type,
+ endNode.id AS target_id, [l in labels(endNode) WHERE l <> '__Entity__' | l][0] AS target_type,
+ endNode{{.* , embedding: Null, id: Null}} AS target_properties
+ LIMIT toInteger($limit)
+ """,
+ param_map={"ids": ids, "limit": limit},
+ )
+ response = response if response else []
+
+ ignore_rels = ignore_rels or []
+ for record in response:
+ if record["type"] in ignore_rels:
+ continue
+
+ source = EntityNode(
+ name=record["source_id"],
+ label=record["source_type"],
+ properties=remove_empty_values(record["source_properties"]),
+ )
+ target = EntityNode(
+ name=record["target_id"],
+ label=record["target_type"],
+ properties=remove_empty_values(record["target_properties"]),
+ )
+ rel = Relation(
+ source_id=record["source_id"],
+ target_id=record["target_id"],
+ label=record["type"],
+ )
+ triples.append([source, rel, target])
+
+ return triples
+
+ def structured_query(
+ self, query: str, param_map: Optional[Dict[str, Any]] = None
+ ) -> Any:
+ param_map = param_map or {}
+
+ with self._driver.session(database=self._database) as session:
+ result = session.run(query, param_map)
+ full_result = [d.data() for d in result]
+
+ if self.sanitize_query_output:
+ return value_sanitize(full_result)
+
+ return full_result
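+
+    # Minimal structured_query usage (assuming `provider` is an instance and the
+    # database is reachable):
+    #   total = provider.structured_query("MATCH (n) RETURN count(n) AS n")[0]["n"]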
+
+ def vector_query(
+ self, query: VectorStoreQuery, **kwargs: Any
+ ) -> Tuple[List[LabelledNode], List[float]]:
+ """Query the graph store with a vector store query."""
+ data = self.structured_query(
+ """MATCH (e:`__Entity__`)
+ WHERE e.embedding IS NOT NULL AND size(e.embedding) = $dimension
+ WITH e, vector.similarity.cosine(e.embedding, $embedding) AS score
+ ORDER BY score DESC LIMIT toInteger($limit)
+ RETURN e.id AS name,
+ [l in labels(e) WHERE l <> '__Entity__' | l][0] AS type,
+ e{.* , embedding: Null, name: Null, id: Null} AS properties,
+ score""",
+ param_map={
+ "embedding": query.query_embedding,
+ "dimension": len(query.query_embedding),
+ "limit": query.similarity_top_k,
+ },
+ )
+ data = data if data else []
+
+ nodes = []
+ scores = []
+ for record in data:
+ node = EntityNode(
+ name=record["name"],
+ label=record["type"],
+ properties=remove_empty_values(record["properties"]),
+ )
+ nodes.append(node)
+ scores.append(record["score"])
+
+ return (nodes, scores)
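+
+    # Sketch of a vector_query call (hypothetical 1536-dim embedding), using the
+    # VectorStoreQuery abstraction imported above; the keyword names follow the
+    # fields read in this method:
+    #   nodes, scores = provider.vector_query(
+    #       VectorStoreQuery(query_embedding=[0.1] * 1536, similarity_top_k=5)
+    #   )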
+
+ def delete(
+ self,
+ entity_names: Optional[List[str]] = None,
+ relation_names: Optional[List[str]] = None,
+ properties: Optional[dict] = None,
+ ids: Optional[List[str]] = None,
+ ) -> None:
+ """Delete matching data."""
+ if entity_names:
+ self.structured_query(
+ "MATCH (n) WHERE n.name IN $entity_names DETACH DELETE n",
+ param_map={"entity_names": entity_names},
+ )
+
+ if ids:
+ self.structured_query(
+ "MATCH (n) WHERE n.id IN $ids DETACH DELETE n",
+ param_map={"ids": ids},
+ )
+
+ if relation_names:
+ for rel in relation_names:
+ self.structured_query(f"MATCH ()-[r:`{rel}`]->() DELETE r")
+
+ if properties:
+ cypher = "MATCH (e) WHERE "
+ prop_list = []
+ params = {}
+ for i, prop in enumerate(properties):
+ prop_list.append(f"e.`{prop}` = $property_{i}")
+ params[f"property_{i}"] = properties[prop]
+ cypher += " AND ".join(prop_list)
+ self.structured_query(
+ cypher + " DETACH DELETE e", param_map=params
+ )
+
+ def _enhanced_schema_cypher(
+ self,
+ label_or_type: str,
+ properties: List[Dict[str, Any]],
+ exhaustive: bool,
+ is_relationship: bool = False,
+ ) -> str:
+ if is_relationship:
+ match_clause = f"MATCH ()-[n:`{label_or_type}`]->()"
+ else:
+ match_clause = f"MATCH (n:`{label_or_type}`)"
+
+ with_clauses = []
+ return_clauses = []
+ output_dict = {}
+ if exhaustive:
+ for prop in properties:
+ prop_name = prop["property"]
+ prop_type = prop["type"]
+ if prop_type == "STRING":
+ with_clauses.append(
+ f"collect(distinct substring(toString(n.`{prop_name}`), 0, 50)) "
+ f"AS `{prop_name}_values`"
+ )
+ return_clauses.append(
+ f"values:`{prop_name}_values`[..{DISTINCT_VALUE_LIMIT}],"
+ f" distinct_count: size(`{prop_name}_values`)"
+ )
+ elif prop_type in [
+ "INTEGER",
+ "FLOAT",
+ "DATE",
+ "DATE_TIME",
+ "LOCAL_DATE_TIME",
+ ]:
+ with_clauses.append(
+ f"min(n.`{prop_name}`) AS `{prop_name}_min`"
+ )
+ with_clauses.append(
+ f"max(n.`{prop_name}`) AS `{prop_name}_max`"
+ )
+ with_clauses.append(
+ f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`"
+ )
+ return_clauses.append(
+ f"min: toString(`{prop_name}_min`), "
+ f"max: toString(`{prop_name}_max`), "
+ f"distinct_count: `{prop_name}_distinct`"
+ )
+ elif prop_type == "LIST":
+ with_clauses.append(
+ f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, "
+ f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`"
+ )
+ return_clauses.append(
+ f"min_size: `{prop_name}_size_min`, "
+ f"max_size: `{prop_name}_size_max`"
+ )
+ elif prop_type in ["BOOLEAN", "POINT", "DURATION"]:
+ continue
+ output_dict[prop_name] = "{" + return_clauses.pop() + "}"
+ else:
+ # Just sample 5 random nodes
+ match_clause += " WITH n LIMIT 5"
+ for prop in properties:
+ prop_name = prop["property"]
+ prop_type = prop["type"]
+
+ # Check if indexed property, we can still do exhaustive
+ prop_index = [
+ el
+ for el in self.structured_schema["metadata"]["index"]
+ if el["label"] == label_or_type
+ and el["properties"] == [prop_name]
+ and el["type"] == "RANGE"
+ ]
+ if prop_type == "STRING":
+ if (
+ prop_index
+ and prop_index[0].get("size") > 0
+ and prop_index[0].get("distinctValues")
+ <= DISTINCT_VALUE_LIMIT
+ ):
+                    distinct_values = self.structured_query(
+ f"CALL apoc.schema.properties.distinct("
+ f"'{label_or_type}', '{prop_name}') YIELD value"
+ )[0]["value"]
+ return_clauses.append(
+ f"values: {distinct_values},"
+ f" distinct_count: {len(distinct_values)}"
+ )
+ else:
+ with_clauses.append(
+ f"collect(distinct substring(n.`{prop_name}`, 0, 50)) "
+ f"AS `{prop_name}_values`"
+ )
+ return_clauses.append(f"values: `{prop_name}_values`")
+ elif prop_type in [
+ "INTEGER",
+ "FLOAT",
+ "DATE",
+ "DATE_TIME",
+ "LOCAL_DATE_TIME",
+ ]:
+ if not prop_index:
+ with_clauses.append(
+ f"collect(distinct toString(n.`{prop_name}`)) "
+ f"AS `{prop_name}_values`"
+ )
+ return_clauses.append(f"values: `{prop_name}_values`")
+ else:
+ with_clauses.append(
+ f"min(n.`{prop_name}`) AS `{prop_name}_min`"
+ )
+ with_clauses.append(
+ f"max(n.`{prop_name}`) AS `{prop_name}_max`"
+ )
+ with_clauses.append(
+ f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`"
+ )
+ return_clauses.append(
+ f"min: toString(`{prop_name}_min`), "
+ f"max: toString(`{prop_name}_max`), "
+ f"distinct_count: `{prop_name}_distinct`"
+ )
+
+ elif prop_type == "LIST":
+ with_clauses.append(
+ f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, "
+ f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`"
+ )
+ return_clauses.append(
+ f"min_size: `{prop_name}_size_min`, "
+ f"max_size: `{prop_name}_size_max`"
+ )
+ elif prop_type in ["BOOLEAN", "POINT", "DURATION"]:
+ continue
+
+ output_dict[prop_name] = "{" + return_clauses.pop() + "}"
+
+ with_clause = "WITH " + ",\n ".join(with_clauses)
+ return_clause = (
+ "RETURN {"
+ + ", ".join(f"`{k}`: {v}" for k, v in output_dict.items())
+ + "} AS output"
+ )
+
+ # Combine all parts of the Cypher query
+ return f"{match_clause}\n{with_clause}\n{return_clause}"
+
+ def get_schema(self, refresh: bool = False) -> Any:
+ if refresh:
+ self.refresh_schema()
+
+ return self.structured_schema
+
+ def get_schema_str(self, refresh: bool = False) -> str:
+ schema = self.get_schema(refresh=refresh)
+
+ formatted_node_props = []
+ formatted_rel_props = []
+
+        if self.enhanced_schema:
+ # Enhanced formatting for nodes
+ for node_type, properties in schema["node_props"].items():
+ formatted_node_props.append(f"- **{node_type}**")
+ for prop in properties:
+ example = ""
+ if prop["type"] == "STRING" and prop.get("values"):
+ if (
+ prop.get("distinct_count", 11)
+ > DISTINCT_VALUE_LIMIT
+ ):
+ example = (
+ f'Example: "{clean_string_values(prop["values"][0])}"'
+ if prop["values"]
+ else ""
+ )
+                    else:  # DISTINCT_VALUE_LIMIT or fewer possible values; return them all
+ example = (
+ (
+ "Available options: "
+ f'{[clean_string_values(el) for el in prop["values"]]}'
+ )
+ if prop["values"]
+ else ""
+ )
+
+ elif prop["type"] in [
+ "INTEGER",
+ "FLOAT",
+ "DATE",
+ "DATE_TIME",
+ "LOCAL_DATE_TIME",
+ ]:
+ if prop.get("min") is not None:
+ example = f'Min: {prop["min"]}, Max: {prop["max"]}'
+ else:
+ example = (
+ f'Example: "{prop["values"][0]}"'
+ if prop.get("values")
+ else ""
+ )
+ elif prop["type"] == "LIST":
+ # Skip embeddings
+ if (
+ not prop.get("min_size")
+ or prop["min_size"] > LIST_LIMIT
+ ):
+ continue
+ example = f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}'
+ formatted_node_props.append(
+ f" - `{prop['property']}`: {prop['type']} {example}"
+ )
+
+ # Enhanced formatting for relationships
+ for rel_type, properties in schema["rel_props"].items():
+ formatted_rel_props.append(f"- **{rel_type}**")
+ for prop in properties:
+ example = ""
+ if prop["type"] == "STRING":
+ if (
+ prop.get("distinct_count", 11)
+ > DISTINCT_VALUE_LIMIT
+ ):
+ example = (
+ f'Example: "{clean_string_values(prop["values"][0])}"'
+ if prop.get("values")
+ else ""
+ )
+                        else:  # DISTINCT_VALUE_LIMIT or fewer possible values; return them all
+ example = (
+ (
+ "Available options: "
+ f'{[clean_string_values(el) for el in prop["values"]]}'
+ )
+ if prop.get("values")
+ else ""
+ )
+ elif prop["type"] in [
+ "INTEGER",
+ "FLOAT",
+ "DATE",
+ "DATE_TIME",
+ "LOCAL_DATE_TIME",
+ ]:
+ if prop.get("min"): # If we have min/max
+ example = (
+ f'Min: {prop["min"]}, Max: {prop["max"]}'
+ )
+ else: # return a single value
+ example = (
+ f'Example: "{prop["values"][0]}"'
+ if prop.get("values")
+ else ""
+ )
+ elif prop["type"] == "LIST":
+ # Skip embeddings
+                        if (
+                            not prop.get("min_size")
+                            or prop["min_size"] > LIST_LIMIT
+                        ):
+ continue
+ example = f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}'
+ formatted_rel_props.append(
+ f" - `{prop['property']}: {prop['type']}` {example}"
+ )
+ else:
+ # Format node properties
+ for label, props in schema["node_props"].items():
+ props_str = ", ".join(
+ [f"{prop['property']}: {prop['type']}" for prop in props]
+ )
+ formatted_node_props.append(f"{label} {{{props_str}}}")
+
+ # Format relationship properties using structured_schema
+            for rel_type, props in schema["rel_props"].items():
+                props_str = ", ".join(
+                    [f"{prop['property']}: {prop['type']}" for prop in props]
+                )
+                formatted_rel_props.append(f"{rel_type} {{{props_str}}}")
+
+ # Format relationships
+ formatted_rels = [
+ f"(:{el['start']})-[:{el['type']}]->(:{el['end']})"
+ for el in schema["relationships"]
+ ]
+
+ return "\n".join(
+ [
+ "Node properties:",
+ "\n".join(formatted_node_props),
+ "Relationship properties:",
+ "\n".join(formatted_rel_props),
+ "The relationships:",
+ "\n".join(formatted_rels),
+ ]
+ )
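+
+    # With enhanced_schema=False, get_schema_str returns text like the following
+    # (hypothetical schema):
+    #   Node properties:
+    #   Person {name: STRING, age: INTEGER}
+    #   Relationship properties:
+    #   KNOWS {since: DATE}
+    #   The relationships:
+    #   (:Person)-[:KNOWS]->(:Person)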
+
+ def update_extraction_prompt(
+ self,
+ prompt_provider: PromptProvider,
+ entity_types: list[EntityType],
+ relations: list[Relation],
+ ):
+        # Fetch the kg extraction prompt with blank entity types and relations.
+        # Note: assumes that for a given prompt there is a `_with_spec` variant that can have entity types + relations specified.
+ few_shot_ner_kg_extraction_with_spec = prompt_provider.get_prompt(
+ f"{self.config.kg_extraction_prompt}_with_spec"
+ )
+
+ # Format the prompt to include the desired entity types and relations
+ few_shot_ner_kg_extraction = (
+ few_shot_ner_kg_extraction_with_spec.replace(
+ "{entity_types}", format_entity_types(entity_types)
+ ).replace("{relations}", format_relations(relations))
+ )
+
+ # Update the "few_shot_ner_kg_extraction" prompt used in downstream KG construction
+ prompt_provider.update_prompt(
+ self.config.kg_extraction_prompt,
+ json.dumps(few_shot_ner_kg_extraction, ensure_ascii=False),
+ )
+
+ def update_kg_agent_prompt(
+ self,
+ prompt_provider: PromptProvider,
+ entity_types: list[EntityType],
+ relations: list[Relation],
+ ):
+        # Fetch the kg agent prompt with blank entity types and relations.
+        # Note: assumes that for a given prompt there is a `_with_spec` variant that can have entity types + relations specified.
+ few_shot_ner_kg_extraction_with_spec = prompt_provider.get_prompt(
+ f"{self.config.kg_agent_prompt}_with_spec"
+ )
+
+ # Format the prompt to include the desired entity types and relations
+ few_shot_ner_kg_extraction = (
+ few_shot_ner_kg_extraction_with_spec.replace(
+ "{entity_types}",
+ format_entity_types(entity_types, ignore_subcats=True),
+ ).replace("{relations}", format_relations(relations))
+ )
+
+        # Update the kg agent prompt used in downstream KG querying
+ prompt_provider.update_prompt(
+ self.config.kg_agent_prompt,
+ json.dumps(few_shot_ner_kg_extraction, ensure_ascii=False),
+ )