# In-memory caches
_esearch_cache = {}       # hash(query) -> list of PMIDs
_gemini_query_cache = {}  # hash(model + prompt) -> response text

# Total number of API attempts (first try plus retries) before giving up.
_GEMINI_MAX_ATTEMPTS = 3


def _gemini_cache_key(prompt, model):
    """Return the cache key for a (prompt, model) pair."""
    # The model is part of the key so that the same prompt sent to two
    # different models never shares a cached answer.
    return hashlib.sha256(f"{model}\0{prompt}".encode()).hexdigest()


def _gemini_api_key():
    """Return the Gemini API key from GEMINI_API_KEY or the credentials file.

    Raises:
        RuntimeError: if neither source yields a non-empty key.
    """
    api_key = os.environ.get("GEMINI_API_KEY", "")
    if not api_key:
        cred_file = os.path.expanduser("~/.config/gemini/credentials")
        if os.path.isfile(cred_file):
            with open(cred_file, encoding="utf-8") as f:
                api_key = f.read().strip()
    if not api_key:
        raise RuntimeError("No Gemini API key found")
    return api_key


def gemini_query(prompt, model='gemini-2.5-flash'):
    """Send a prompt to the Gemini API with caching and retry.

    Successful responses are cached in memory per (model, prompt), so a
    repeated call returns immediately without touching the network.

    Args:
        prompt: The prompt text to send.
        model: Name of the Gemini model to query.

    Returns:
        The response text with surrounding whitespace stripped.

    Raises:
        RuntimeError: if no API key can be found, or if every attempt
            fails (the last underlying error is chained as the cause).
    """
    cache_key = _gemini_cache_key(prompt, model)
    if cache_key in _gemini_query_cache:
        print(" Gemini query cache hit")
        return _gemini_query_cache[cache_key]

    api_key = _gemini_api_key()

    # Imported lazily, after the cache and key checks, so that cache hits
    # and the missing-key error do not require the SDK to be installed.
    from google import genai

    client = genai.Client(api_key=api_key)
    last_error = None
    for attempt in range(_GEMINI_MAX_ATTEMPTS):
        if attempt > 0:
            # Linear back-off: 2s before the second try, 4s before the third.
            time.sleep(2 * attempt)
            print(f" Gemini retry {attempt + 1}/{_GEMINI_MAX_ATTEMPTS}")
        try:
            print(f" Gemini API call ({model}): {prompt[:80]}...")
            response = client.models.generate_content(
                model=model,
                contents=prompt,
            )
            text = response.text
            # NOTE(review): the SDK appears to return None for responses
            # without text parts (e.g. blocked output) - confirm against
            # the google-genai docs. Treat it as a failed attempt.
            if text is None:
                raise ValueError("Gemini response contained no text")
            result = text.strip()
        except Exception as e:
            # Broad on purpose: the SDK's error types are not visible here
            # and any failure of a single attempt should trigger a retry.
            last_error = e
            print(f" Gemini attempt {attempt + 1}/{_GEMINI_MAX_ATTEMPTS} failed: {e}")
        else:
            print(f" Gemini response: {result[:200]}")
            _gemini_query_cache[cache_key] = result
            return result
    raise RuntimeError(
        f"Gemini API failed after {_GEMINI_MAX_ATTEMPTS} attempts: {last_error}"
    ) from last_error
PROMPT = (
    """
    Give me a list of terms on substance abuse disorder (SUD) that act
    as traits and classifiers in scientific literature with a focus on
    behaviour and brain attributes related to the hippocampus. Avoid
    aliases and synonyms as well as gene names. Each term should be
    1-3 words (max). Give me a list of at least 20, but no more than
    80, most used terms. Return only the terms, one per line, no
    numbering."""
)


class TestGeminiOntology(unittest.TestCase):
    """Network tests: ask Gemini for SUD ontology terms and check caching.

    Requires a Gemini API key and internet access; calls ``gemini_query``
    from ``more_functions``.
    """

    # Bounds requested in PROMPT ("at least 20, but no more than 80").
    MIN_TERMS = 20
    MAX_TERMS = 80
    # The prompt asks for 1-3 words per term; allow some slack.
    MAX_WORDS_PER_TERM = 5

    @staticmethod
    def _parse_terms(response):
        """Split a response into non-empty, whitespace-stripped lines."""
        return [t.strip() for t in response.strip().split("\n") if t.strip()]

    def test_1_sud_ontology_terms(self):
        """Gemini should return 20-80 SUD ontology terms."""
        # perf_counter is monotonic and high-resolution, unlike time.time().
        t0 = time.perf_counter()
        response = gemini_query(PROMPT)
        elapsed = time.perf_counter() - t0
        terms = self._parse_terms(response)
        print(f" Got {len(terms)} terms ({elapsed:.2f}s)")
        for t in terms:
            print(f" - {t}")
        self.assertGreaterEqual(len(terms), self.MIN_TERMS,
                                f"Expected at least 20 terms, got {len(terms)}")
        self.assertLessEqual(len(terms), self.MAX_TERMS,
                             f"Expected at most 80 terms, got {len(terms)}")
        long_terms = [t for t in terms
                      if len(t.split()) > self.MAX_WORDS_PER_TERM]
        self.assertEqual(len(long_terms), 0,
                         f"Terms too long: {long_terms}")

    def test_2_cached_ontology(self):
        """Second call should use cache and be fast."""
        # Populate the cache here so this test does not depend on test_1
        # having run first (or at all).
        first = gemini_query(PROMPT)
        t0 = time.perf_counter()
        response = gemini_query(PROMPT)
        elapsed = time.perf_counter() - t0
        terms = self._parse_terms(response)
        print(f" Cached: {len(terms)} terms ({elapsed:.4f}s)")
        # A cache hit must hand back exactly what the first call returned.
        self.assertEqual(response, first,
                         "Cached response differs from the first response")
        self.assertLess(elapsed, 0.01, f"Cache lookup too slow: {elapsed:.4f}s")


if __name__ == "__main__":
    unittest.main()