about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--guix.scm6
-rwxr-xr-xmore_functions.py50
-rw-r--r--tests/test_network_gemini_ontology.py58
3 files changed, 109 insertions, 5 deletions
diff --git a/guix.scm b/guix.scm
index 77551cd..5c8519e 100644
--- a/guix.scm
+++ b/guix.scm
@@ -7,12 +7,14 @@
 ;;
 ;; Development shell:
 ;;
-;;   guix shell -L . -C -N -F edirect-25 genecup-gemini coreutils -- genecup --port 4201
+;;   guix shell -L . -C -N -F --expose=$HOME/.config/gemini --share=/export3/PubMed edirect-25 genecup-gemini coreutils -- genecup --port 4201
 ;;
 ;; In a shell you can run
 ;;
-;;   python3 -m unittest tests.test_network_esearch
+;;   guix shell -C -N -F -L . --expose=$HOME/.config/gemini --share=/export3/PubMed edirect-25 genecup-gemini
+;;   env EDIRECT_LOCAL_ARCHIVE=/export3/PubMed/Source python3 -m unittest tests.test_network_esearch
 ;;   env EDIRECT_LOCAL_ARCHIVE=/export3/PubMed/Source python3 -m unittest tests.test_local_xfetch -v
+;;   env EDIRECT_LOCAL_ARCHIVE=/export3/PubMed/Source python3 -m unittest tests.test_network_gemini_ontology
 ;;
 ;; Note: API key is read from ~/.config/gemini/credentials
 ;;
diff --git a/more_functions.py b/more_functions.py
index a115899..35e3646 100755
--- a/more_functions.py
+++ b/more_functions.py
@@ -3,6 +3,7 @@ from nltk.tokenize import sent_tokenize
 import hashlib
 import os
 import re
+import time
 
 from addiction_keywords import *
 from gene_synonyms import *
@@ -10,8 +11,51 @@ import ast
 
 global pubmed_path
 
-# In-memory cache for esearch results: hash(query) -> list of PMIDs
-_esearch_cache = {}
+# In-memory caches
+_esearch_cache = {}  # hash(query) -> list of PMIDs
+_gemini_query_cache = {}  # hash(prompt) -> response text
+
+def gemini_query(prompt, model='gemini-2.5-flash'):
+    """Send a prompt to the Gemini API with caching and retry.
+
+    Returns the response text, or raises on failure.
+    """
+    from google import genai
+
+    cache_key = hashlib.sha256(prompt.encode()).hexdigest()
+    if cache_key in _gemini_query_cache:
+        print(f"  Gemini query cache hit")
+        return _gemini_query_cache[cache_key]
+
+    api_key = os.environ.get("GEMINI_API_KEY", "")
+    if not api_key:
+        cred_file = os.path.expanduser("~/.config/gemini/credentials")
+        if os.path.isfile(cred_file):
+            with open(cred_file) as f:
+                api_key = f.read().strip()
+    if not api_key:
+        raise RuntimeError("No Gemini API key found")
+
+    client = genai.Client(api_key=api_key)
+    last_error = None
+    for attempt in range(3):
+        try:
+            if attempt > 0:
+                time.sleep(2 * attempt)
+                print(f"  Gemini retry {attempt + 1}/3")
+            print(f"  Gemini API call ({model}): {prompt[:80]}...")
+            response = client.models.generate_content(
+                model=model,
+                contents=prompt
+            )
+            result = response.text.strip()
+            print(f"  Gemini response: {result[:200]}")
+            _gemini_query_cache[cache_key] = result
+            return result
+        except Exception as e:
+            last_error = e
+            print(f"  Gemini attempt {attempt + 1}/3 failed: {e}")
+    raise RuntimeError(f"Gemini API failed after 3 attempts: {last_error}")
 
 def esearch_pmids(query):
     """Search PubMed for PMIDs matching query. Results are cached in memory.
@@ -246,7 +290,7 @@ pubmed_path=os.environ.get("EDIRECT_LOCAL_ARCHIVE", "./minipubmed")
 print(f"  pubmed_path={pubmed_path}")
 
 if not os.path.isdir(pubmed_path):
-    print(f"ERROR: EDIRECT_LOCAL_ARCHIVE directory not found: {pubmed_path} - note this is a recent env variable that replaces the others")
+    print(f"ERROR: EDIRECT_LOCAL_ARCHIVE directory not found: {pubmed_path} - note this is a recent env variable that replaces the others (ignore the minipub reference)")
     raise SystemExit(1)
 testdir = os.path.join(pubmed_path, "pubmed", "Archive", "00")
 if not os.path.isdir(testdir):
diff --git a/tests/test_network_gemini_ontology.py b/tests/test_network_gemini_ontology.py
new file mode 100644
index 0000000..4a9db34
--- /dev/null
+++ b/tests/test_network_gemini_ontology.py
@@ -0,0 +1,58 @@
+"""Test Gemini API for generating SUD ontology terms.
+
+Requires a Gemini API key in ~/.config/gemini/credentials and internet access.
+
+Run with: python3 -m unittest tests.test_network_gemini_ontology -v
+"""
+
+import os
+import sys
+import time
+import unittest
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+from more_functions import gemini_query
+
+PROMPT = (
+    """
+    Give me a list of terms on substance abuse disorder (SUD) that act
+    as traits and classifiers in scientific literature with a focus on
+    behaviour and brain attributes related to the hippocampus. Avoid
+    aliases and synonyms as well as gene names. Each term should be
+    1-3 words (max).  Give me a list of at least 20, but no more than
+    80, most used terms.  Return only the terms, one per line, no
+    numbering."""
+)
+
+class TestGeminiOntology(unittest.TestCase):
+    def test_1_sud_ontology_terms(self):
+        """Gemini should return 20-80 SUD ontology terms."""
+        t0 = time.time()
+        response = gemini_query(PROMPT)
+        elapsed = time.time() - t0
+        terms = [t.strip() for t in response.strip().split("\n") if t.strip()]
+        print(f"  Got {len(terms)} terms ({elapsed:.2f}s)")
+        for t in terms:
+            print(f"    - {t}")
+        self.assertGreaterEqual(len(terms), 20,
+                                f"Expected at least 20 terms, got {len(terms)}")
+        self.assertLessEqual(len(terms), 80,
+                             f"Expected at most 80 terms, got {len(terms)}")
+        # Each term should be short (1-3 words, allow some slack)
+        long_terms = [t for t in terms if len(t.split()) > 5]
+        self.assertEqual(len(long_terms), 0,
+                         f"Terms too long: {long_terms}")
+
+    def test_2_cached_ontology(self):
+        """Second call should use cache and be fast."""
+        # Ensure cache is populated from test_1
+        gemini_query(PROMPT)
+        t0 = time.time()
+        response = gemini_query(PROMPT)
+        elapsed = time.time() - t0
+        terms = [t.strip() for t in response.strip().split("\n") if t.strip()]
+        print(f"  Cached: {len(terms)} terms ({elapsed:.4f}s)")
+        self.assertLess(elapsed, 0.01, f"Cache lookup too slow: {elapsed:.4f}s")
+
+if __name__ == "__main__":
+    unittest.main()