author      S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer   S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit      4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree        ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/examples/scripts/advanced_kg_cookbook.py
parent      cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download    gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
Diffstat (limited to 'R2R/r2r/examples/scripts/advanced_kg_cookbook.py')
-rwxr-xr-x  R2R/r2r/examples/scripts/advanced_kg_cookbook.py  194
1 file changed, 194 insertions(+), 0 deletions(-)
diff --git a/R2R/r2r/examples/scripts/advanced_kg_cookbook.py b/R2R/r2r/examples/scripts/advanced_kg_cookbook.py
new file mode 100755
index 00000000..a4d59a79
--- /dev/null
+++ b/R2R/r2r/examples/scripts/advanced_kg_cookbook.py
@@ -0,0 +1,194 @@
+import json
+import os
+
+import fire
+import requests
+from bs4 import BeautifulSoup, Comment
+
+from r2r import (
+    EntityType,
+    R2RClient,
+    R2RPromptProvider,
+    Relation,
+    update_kg_prompt,
+)
+
+
+def escape_braces(text):
+    return text.replace("{", "{{").replace("}", "}}")
+
+
+def get_all_yc_co_directory_urls():
+    this_file_path = os.path.abspath(os.path.dirname(__file__))
+    yc_company_dump_path = os.path.join(
+        this_file_path, "..", "data", "yc_companies.txt"
+    )
+
+    with open(yc_company_dump_path, "r") as f:
+        urls = f.readlines()
+        urls = [url.strip() for url in urls]
+        return {url.split("/")[-1]: url for url in urls}
+
+
+# Function to fetch and clean HTML content
+def fetch_and_clean_yc_co_data(url):
+    # Fetch the HTML content from the URL
+    response = requests.get(url)
+    response.raise_for_status()  # Raise an error for bad status codes
+    html_content = response.text
+
+    # Parse the HTML content with BeautifulSoup
+    soup = BeautifulSoup(html_content, "html.parser")
+
+    # Remove all <script>, <style>, <meta>, <link>, <header>, <nav>, and <footer> elements
+    for element in soup(
+        ["script", "style", "meta", "link", "header", "nav", "footer"]
+    ):
+        element.decompose()
+
+    # Remove comments
+    for comment in soup.find_all(string=lambda text: isinstance(text, Comment)):
+        comment.extract()
+
+    # Select the main content (adjust the selector to match the structure of the target pages)
+    main_content = soup.select_one("main") or soup.body
+
+    if main_content:
+        spans = main_content.find_all(["span", "a"])
+
+        proc_spans = []
+        for span in spans:
+            proc_spans.append(span.get_text(separator=" ", strip=True))
+        span_text = "\n".join(proc_spans)
+
+        # Extract the text content from the main content
+        paragraphs = main_content.find_all(
+            ["p", "h1", "h2", "h3", "h4", "h5", "h6", "li"]
+        )
+        cleaned_text = (
+            "### Bulk:\n\n"
+            + "\n\n".join(
+                paragraph.get_text(separator=" ", strip=True)
+                for paragraph in paragraphs
+            )
+            + "\n\n### Metadata:\n\n"
+            + span_text
+        )
+
+        return cleaned_text
+    else:
+        return "Main content not found"
+
+
+def execute_query(provider, query, params=None):
+    print(f"Executing query: {query}")
+    with provider.client.session(database=provider._database) as session:
+        result = session.run(query, params or {})
+        return [record.data() for record in result]
+
+
+def main(
+    max_entries=50,
+    local_mode=True,
+    base_url="http://localhost:8000",
+):
+
+    # Specify the entity types for the KG extraction prompt
+    entity_types = [
+        EntityType("COMPANY"),
+        EntityType("SCHOOL"),
+        EntityType("LOCATION"),
+        EntityType("PERSON"),
+        EntityType("DATE"),
+        EntityType("OTHER"),
+        EntityType("QUANTITY"),
+        EntityType("EVENT"),
+        EntityType("INDUSTRY"),
+        EntityType("MEDIA"),
+    ]
+
+    # Specify the relations for the KG construction
+    relations = [
+        # Founder relations
+        Relation("EDUCATED_AT"),
+        Relation("WORKED_AT"),
+        Relation("FOUNDED"),
+        # Company relations
+        Relation("RAISED"),
+        Relation("REVENUE"),
+        Relation("TEAM_SIZE"),
+        Relation("LOCATION"),
+        Relation("ACQUIRED_BY"),
+        Relation("ANNOUNCED"),
+        Relation("INDUSTRY"),
+        # Product relations
+        Relation("PRODUCT"),
+        Relation("FEATURES"),
+        Relation("TECHNOLOGY"),
+        # Additional relations
+        Relation("HAS"),
+        Relation("AS_OF"),
+        Relation("PARTICIPATED"),
+        Relation("ASSOCIATED"),
+    ]
+
+    client = R2RClient(base_url=base_url)
+    r2r_prompts = R2RPromptProvider()
+
+    prompt_base = (
+        "zero_shot_ner_kg_extraction"
+        if local_mode
+        else "few_shot_ner_kg_extraction"
+    )
+
+    update_kg_prompt(client, r2r_prompts, prompt_base, entity_types, relations)
+
+    url_map = get_all_yc_co_directory_urls()
+
+    i = 0
+    # Fetch, clean, and ingest data for each company (up to max_entries)
+    for company, url in url_map.items():
+        if i >= max_entries:
+            break
+        i += 1
+        company_data = fetch_and_clean_yc_co_data(url)
+
+        try:
+            # Ingest as a text document
+            file_name = f"{company}.txt"
+            with open(file_name, "w") as f:
+                f.write(company_data)
+
+            client.ingest_files(
+                [file_name],
+                metadatas=[{"title": company}],
+            )
+            os.remove(file_name)
+        except Exception:  # Skip documents that fail to write or ingest
+            continue
+
+    print(client.inspect_knowledge_graph(1_000)["results"])
+
+    if not local_mode:
+
+        update_kg_prompt(
+            client, r2r_prompts, "kg_agent", entity_types, relations
+        )
+
+        result = client.search(
+            query="Find up to 10 founders that worked at Google",
+            use_kg_search=True,
+        )["results"]
+
+        print("result:\n", result)
+        print("Search Result:\n", result["kg_search_results"])
+
+        result = client.rag(
+            query="Find up to 10 founders that worked at Google",
+            use_kg_search=True,
+        )
+        print("RAG Result:\n", result)
+
+
+if __name__ == "__main__":
+    fire.Fire(main)