about summary refs log tree commit diff
path: root/scripts/index-genenetwork
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/index-genenetwork')
-rwxr-xr-xscripts/index-genenetwork251
1 files changed, 221 insertions, 30 deletions
diff --git a/scripts/index-genenetwork b/scripts/index-genenetwork
index 1f649cf..2779abc 100755
--- a/scripts/index-genenetwork
+++ b/scripts/index-genenetwork
@@ -8,21 +8,26 @@ xapian index. This xapian index is later used in providing search
 through the web interface.
 
 """
-
-from collections import deque, namedtuple
+from dataclasses import dataclass
+from collections import deque, namedtuple, Counter
 import contextlib
+import time
+import datetime
 from functools import partial
 import itertools
 import json
 import logging
-from multiprocessing import Lock, Process
+from multiprocessing import Lock, Manager, Process, managers
 import os
 import pathlib
 import resource
+import re
 import shutil
 import sys
+import hashlib
 import tempfile
-from typing import Callable, Generator, Iterable, List
+from typing import Callable, Dict, Generator, Hashable, Iterable, List
+from SPARQLWrapper import SPARQLWrapper, JSON
 
 import MySQLdb
 import click
@@ -33,7 +38,10 @@ import xapian
 from gn3.db_utils import database_connection
 from gn3.monads import query_sql
 
-DOCUMENTS_PER_CHUNK = 100000
+DOCUMENTS_PER_CHUNK = 100_000
+# Running the script in prod consumers ~1GB per process when handling 100_000 Documents per chunk.
+# To prevent running out of RAM, we set this as the upper bound for total concurrent processes
+PROCESS_COUNT_LIMIT = 67
 
 SQLQuery = namedtuple("SQLQuery",
                       ["fields", "tables", "where", "offset", "limit"],
@@ -122,6 +130,38 @@ phenotypes_query = SQLQuery(
      SQLTableClause("LEFT JOIN", "Geno",
                     "PublishXRef.Locus = Geno.Name AND Geno.SpeciesId = Species.Id")])
 
+WIKI_CACHE_QUERY = """
+PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
+PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
+PREFIX gnt: <http://genenetwork.org/term/>
+PREFIX gnc: <http://genenetwork.org/category/>
+
+SELECT ?symbolName ?speciesName GROUP_CONCAT(DISTINCT ?comment ; separator=\"\\n\") AS ?comment WHERE {
+    ?symbol rdfs:comment _:node ;
+            rdfs:label ?symbolName .
+_:node rdf:type gnc:GNWikiEntry ;
+       gnt:belongsToSpecies ?species ;
+       rdfs:comment ?comment .
+?species gnt:shortName ?speciesName .
+} GROUP BY ?speciesName ?symbolName
+"""
+
+RIF_CACHE_QUERY = """
+PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
+PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
+PREFIX gnt: <http://genenetwork.org/term/>
+PREFIX gnc: <http://genenetwork.org/category/>
+
+SELECT ?symbolName ?speciesName GROUP_CONCAT(DISTINCT ?comment ; separator=\"\\n\") AS ?comment WHERE {
+    ?symbol rdfs:comment _:node ;
+            rdfs:label ?symbolName .
+_:node rdf:type gnc:NCBIWikiEntry ;
+       gnt:belongsToSpecies ?species ;
+       rdfs:comment ?comment .
+?species gnt:shortName ?speciesName .
+} GROUP BY ?speciesName ?symbolName
+"""
+
 
 def serialize_sql(query: SQLQuery) -> str:
     """Serialize SQLQuery object to a string."""
@@ -168,6 +208,48 @@ def locked_xapian_writable_database(path: pathlib.Path) -> xapian.WritableDataba
         db.close()
 
 
+def build_rdf_cache(sparql_uri: str, query: str, remove_common_words: bool = False):
+    cache = {}
+    sparql = SPARQLWrapper(sparql_uri)
+    sparql.setReturnFormat(JSON)
+    sparql.setQuery(query)
+    results = sparql.queryAndConvert()
+    if not isinstance(results, dict):
+        raise TypeError(f"Expected results to be a dict but found {type(results)}")
+    bindings = results["results"]["bindings"]
+    count: Counter[str] = Counter()
+    words_regex = re.compile(r"\w+")
+    for entry in bindings :
+        x = (entry["speciesName"]["value"], entry["symbolName"]["value"],)
+        value = entry["comment"]["value"]
+        value = " ".join(words_regex.findall(value)) # remove punctuation
+        cache[x] = value
+        count.update(Counter(value.lower().strip().split()))
+
+    if not remove_common_words:
+        return cache
+
+    words_to_drop = set()
+    for word, cnt in count.most_common(1000):
+        if len(word) < 4 or cnt > 3000:
+            words_to_drop.add(word)
+    smaller_cache = {}
+    for entry, value in cache.items():
+        new_value = set(word for word in value.lower().split() if word not in words_to_drop)
+        smaller_cache[entry] = " ".join(new_value)
+    return smaller_cache
+
+
+def md5hash_ttl_dir(ttl_dir: pathlib.Path) -> str:
+    if not ttl_dir.exists():
+        return "-1"
+    ttl_hash = hashlib.new("md5")
+    for ttl_file in ttl_dir.glob("*.ttl"):
+        with open(ttl_file, encoding="utf-8") as f_:
+            ttl_hash.update(f_.read().encode())
+    return ttl_hash.hexdigest()
+
+
 # pylint: disable=invalid-name
 def write_document(db: xapian.WritableDatabase, identifier: str,
                    doctype: str, doc: xapian.Document) -> None:
@@ -181,15 +263,23 @@ def write_document(db: xapian.WritableDatabase, identifier: str,
 
 termgenerator = xapian.TermGenerator()
 termgenerator.set_stemmer(xapian.Stem("en"))
+termgenerator.set_stopper_strategy(xapian.TermGenerator.STOP_ALL)
+termgenerator.set_stopper(xapian.SimpleStopper())
 
 def index_text(text: str) -> None:
     """Index text and increase term position."""
     termgenerator.index_text(text)
     termgenerator.increase_termpos()
 
-# pylint: disable=unnecessary-lambda
-index_text_without_positions = lambda text: termgenerator.index_text_without_positions(text)
+@curry(3)
+def index_from_dictionary(keys: Hashable, prefix: str, dictionary: dict):
+    entry = dictionary.get(keys)
+    if not entry:
+        return
+    termgenerator.index_text_without_positions(entry, 0, prefix)
+
 
+index_text_without_positions = lambda text: termgenerator.index_text_without_positions(text)
 index_authors = lambda authors: termgenerator.index_text(authors, 0, "A")
 index_species = lambda species: termgenerator.index_text_without_positions(species, 0, "XS")
 index_group = lambda group: termgenerator.index_text_without_positions(group, 0, "XG")
@@ -206,10 +296,17 @@ add_peakmb = lambda doc, peakmb: doc.add_value(3, xapian.sortable_serialise(peak
 add_additive = lambda doc, additive: doc.add_value(4, xapian.sortable_serialise(additive))
 add_year = lambda doc, year: doc.add_value(5, xapian.sortable_serialise(float(year)))
 
+
+
+
 # When a child process is forked, it inherits a copy of the memory of
 # its parent. We use this to pass data retrieved from SQL from parent
 # to child. Specifically, we use this global variable.
-data: Iterable
+# This is copy-on-write so make sure child processes don't modify this data
+mysql_data: Iterable
+rif_cache: Iterable
+wiki_cache: Iterable
+
 # We use this lock to ensure that only one process writes its Xapian
 # index to disk at a time.
 xapian_lock = Lock()
@@ -217,7 +314,7 @@ xapian_lock = Lock()
 def index_genes(xapian_build_directory: pathlib.Path, chunk_index: int) -> None:
     """Index genes data into a Xapian index."""
     with locked_xapian_writable_database(xapian_build_directory / f"genes-{chunk_index:04d}") as db:
-        for trait in data:
+        for trait in mysql_data:
             # pylint: disable=cell-var-from-loop
             doc = xapian.Document()
             termgenerator.set_document(doc)
@@ -230,7 +327,7 @@ def index_genes(xapian_build_directory: pathlib.Path, chunk_index: int) -> None:
             trait["additive"].bind(partial(add_additive, doc))
 
             # Index free text.
-            for key in ["description", "tissue", "dataset_fullname"]:
+            for key in ["description", "tissue", "dataset"]:
                 trait[key].bind(index_text)
             trait.pop("probe_target_description").bind(index_text)
             for key in ["name", "symbol", "species", "group"]:
@@ -242,11 +339,23 @@ def index_genes(xapian_build_directory: pathlib.Path, chunk_index: int) -> None:
             trait["species"].bind(index_species)
             trait["group"].bind(index_group)
             trait["tissue"].bind(index_tissue)
-            trait["dataset_fullname"].bind(index_dataset)
+            trait["dataset"].bind(index_dataset)
             trait["symbol"].bind(index_symbol)
             trait["chr"].bind(index_chr)
             trait["geno_chr"].bind(index_peakchr)
 
+            Maybe.apply(index_from_dictionary).to_arguments(
+                    Just((trait["species"].value, trait["symbol"].value)),
+                    Just("XRF"),
+                    Just(rif_cache)
+                    )
+
+            Maybe.apply(index_from_dictionary).to_arguments(
+                    Just((trait["species"].value, trait["symbol"].value)),
+                    Just("XWK"),
+                    Just(wiki_cache)
+                    )
+
             doc.set_data(json.dumps(trait.data))
             (Maybe.apply(curry(2, lambda name, dataset: f"{name}:{dataset}"))
              .to_arguments(trait["name"], trait["dataset"])
@@ -257,7 +366,8 @@ def index_phenotypes(xapian_build_directory: pathlib.Path, chunk_index: int) ->
     """Index phenotypes data into a Xapian index."""
     with locked_xapian_writable_database(
             xapian_build_directory / f"phenotypes-{chunk_index:04d}") as db:
-        for trait in data:
+
+        for trait in mysql_data:
             # pylint: disable=cell-var-from-loop
             doc = xapian.Document()
             termgenerator.set_document(doc)
@@ -270,7 +380,7 @@ def index_phenotypes(xapian_build_directory: pathlib.Path, chunk_index: int) ->
             trait["year"].bind(partial(add_year, doc))
 
             # Index free text.
-            for key in ["description", "authors", "dataset_fullname"]:
+            for key in ["description", "authors", "dataset"]:
                 trait[key].bind(index_text)
             for key in ["Abstract", "Title"]:
                 trait.pop(key).bind(index_text)
@@ -284,7 +394,7 @@ def index_phenotypes(xapian_build_directory: pathlib.Path, chunk_index: int) ->
             trait["group"].bind(index_group)
             trait["authors"].bind(index_authors)
             trait["geno_chr"].bind(index_peakchr)
-            trait["dataset_fullname"].bind(index_dataset)
+            trait["dataset"].bind(index_dataset)
 
             # Convert name from integer to string.
             trait["name"] = trait["name"].map(str)
@@ -320,12 +430,16 @@ def worker_queue(number_of_workers: int = os.cpu_count() or 1) -> Generator:
         process.join()
 
 
-def index_query(index_function: Callable, query: SQLQuery,
-                xapian_build_directory: pathlib.Path, sql_uri: str, start: int = 0) -> None:
+def index_query(index_function: Callable[[pathlib.Path, int], None], query: SQLQuery,
+                xapian_build_directory: pathlib.Path, sql_uri: str,
+                sparql_uri: str, start: int = 0) -> None:
     """Run SQL query, and index its results for Xapian."""
     i = start
+    default_no_of_workers = os.cpu_count() or 1
+    no_of_workers = min(default_no_of_workers, PROCESS_COUNT_LIMIT)
+
     try:
-        with worker_queue() as spawn_worker:
+        with worker_queue(no_of_workers) as spawn_worker:
             with database_connection(sql_uri) as conn:
                 for chunk in group(query_sql(conn, serialize_sql(
                         # KLUDGE: MariaDB does not allow an offset
@@ -335,9 +449,8 @@ def index_query(index_function: Callable, query: SQLQuery,
                                        offset=start*DOCUMENTS_PER_CHUNK)),
                                                    server_side=True),
                                    DOCUMENTS_PER_CHUNK):
-                    # pylint: disable=global-statement
-                    global data
-                    data = chunk
+                    global mysql_data
+                    mysql_data = chunk
                     spawn_worker(index_function, (xapian_build_directory, i))
                     logging.debug("Spawned worker process on chunk %s", i)
                     i += 1
@@ -347,7 +460,7 @@ def index_query(index_function: Callable, query: SQLQuery,
     except MySQLdb._exceptions.OperationalError:
         logging.warning("Reopening connection to recovering from SQL operational error",
                         exc_info=True)
-        index_query(index_function, query, xapian_build_directory, sql_uri, i)
+        index_query(index_function, query, xapian_build_directory, sql_uri, sparql_uri, i)
 
 
 @contextlib.contextmanager
@@ -357,12 +470,33 @@ def temporary_directory(prefix: str, parent_directory: str) -> Generator:
         yield pathlib.Path(tmpdirname)
 
 
+def parallel_xapian_compact(combined_index: pathlib.Path, indices: List[pathlib.Path]) -> None:
+    # We found that compacting 50 files of ~600MB has decent performance
+    no_of_workers = 20
+    file_groupings = 50
+    with temporary_directory("parallel_combine", str(combined_index)) as parallel_combine:
+        parallel_combine.mkdir(parents=True, exist_ok=True)
+        with worker_queue(no_of_workers) as spawn_worker:
+            i = 0
+            while i < len(indices):
+                end_index = (i + file_groupings)
+                files = indices[i:end_index]
+                last_item_idx = i + len(files)
+                spawn_worker(xapian_compact, (parallel_combine / f"{i}_{last_item_idx}", files))
+                logging.debug("Spawned worker to compact files from %s to %s", i, last_item_idx)
+                i = end_index
+        logging.debug("Completed parallel xapian compacts")
+        xapian_compact(combined_index, list(parallel_combine.iterdir()))
+
+
 def xapian_compact(combined_index: pathlib.Path, indices: List[pathlib.Path]) -> None:
     """Compact and combine several Xapian indices."""
     # xapian-compact opens all indices simultaneously. So, raise the limit on
     # the number of open files.
     soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
     resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, min(10*len(indices), hard)), hard))
+    combined_index.mkdir(parents=True, exist_ok=True)
+    start = time.monotonic()
     db = xapian.Database()
     try:
         for index in indices:
@@ -370,32 +504,73 @@ def xapian_compact(combined_index: pathlib.Path, indices: List[pathlib.Path]) ->
         db.compact(str(combined_index), xapian.DBCOMPACT_MULTIPASS | xapian.Compactor.FULLER)
     finally:
         db.close()
+    logging.debug("Removing databases that were compacted into %s", combined_index.name)
+    for folder in indices:
+        shutil.rmtree(folder)
+    logging.debug("Completed xapian-compact %s; handled %s files in %s minutes", combined_index.name, len(indices), (time.monotonic() - start) / 60)
+
+
+@click.command(help="Verify checksums and return True when the data has been changed.")
+@click.argument("xapian_directory")
+@click.argument("sql_uri")
+@click.argument("sparql_uri")
+def is_data_modified(xapian_directory: str,
+                     sql_uri: str,
+                     sparql_uri: str) -> None:
+    dir_ = pathlib.Path(xapian_directory)
+    with locked_xapian_writable_database(dir_) as db, database_connection(sql_uri) as conn:
+        checksums = "-1"
+        if db.get_metadata('tables'):
+            checksums = " ".join([
+                str(result["Checksum"].value)
+                for result in query_sql(
+                        conn,
+                        f"CHECKSUM TABLE {', '.join(db.get_metadata('tables').decode().split())}")
+            ])
+        # Return a zero exit status code when the data has changed;
+        # otherwise exit with a 1 exit status code.
+        generif = pathlib.Path("/var/lib/data/")
+        if (db.get_metadata("generif-checksum").decode() == md5hash_ttl_dir(generif) and
+            db.get_metadata("checksums").decode() == checksums):
+            sys.exit(1)
+        sys.exit(0)
 
 
 @click.command(help="Index GeneNetwork data and build Xapian search index in XAPIAN_DIRECTORY.")
 @click.argument("xapian_directory")
 @click.argument("sql_uri")
+@click.argument("sparql_uri")
 # pylint: disable=missing-function-docstring
-def main(xapian_directory: str, sql_uri: str) -> None:
+def create_xapian_index(xapian_directory: str, sql_uri: str,
+                        sparql_uri: str) -> None:
     logging.basicConfig(level=os.environ.get("LOGLEVEL", "DEBUG"),
-                        format='%(relativeCreated)s: %(levelname)s: %(message)s')
+                        format='%(asctime)s %(levelname)s: %(message)s',
+                        datefmt='%Y-%m-%d %H:%M:%S %Z')
+    if not pathlib.Path(xapian_directory).exists():
+        pathlib.Path(xapian_directory).mkdir()
 
     # Ensure no other build process is running.
-    if pathlib.Path(xapian_directory).exists():
-        logging.error("Build directory %s already exists; "
+    if any(pathlib.Path(xapian_directory).iterdir()):
+        logging.error("Build directory %s has build files; "
                       "perhaps another build process is running.",
                       xapian_directory)
         sys.exit(1)
 
-    pathlib.Path(xapian_directory).mkdir()
+    start_time = time.perf_counter()
     with temporary_directory("combined", xapian_directory) as combined_index:
         with temporary_directory("build", xapian_directory) as xapian_build_directory:
+            global rif_cache
+            global wiki_cache
+            logging.info("Building wiki cache")
+            wiki_cache = build_rdf_cache(sparql_uri, WIKI_CACHE_QUERY, remove_common_words=True)
+            logging.info("Building rif cache")
+            rif_cache = build_rdf_cache(sparql_uri, RIF_CACHE_QUERY, remove_common_words=True)
             logging.info("Indexing genes")
-            index_query(index_genes, genes_query, xapian_build_directory, sql_uri)
+            index_query(index_genes, genes_query, xapian_build_directory, sql_uri, sparql_uri)
             logging.info("Indexing phenotypes")
-            index_query(index_phenotypes, phenotypes_query, xapian_build_directory, sql_uri)
+            index_query(index_phenotypes, phenotypes_query, xapian_build_directory, sql_uri, sparql_uri)
             logging.info("Combining and compacting indices")
-            xapian_compact(combined_index, list(xapian_build_directory.iterdir()))
+            parallel_xapian_compact(combined_index, list(xapian_build_directory.iterdir()))
             logging.info("Writing table checksums into index")
             with locked_xapian_writable_database(combined_index) as db:
                 # Build a (deduplicated) set of all tables referenced in
@@ -409,11 +584,27 @@ def main(xapian_directory: str, sql_uri: str) -> None:
                     ]
                 db.set_metadata("tables", " ".join(tables))
                 db.set_metadata("checksums", " ".join(checksums))
+                logging.info("Writing generif checksums into index")
+                db.set_metadata(
+                    "generif-checksum",
+                    md5hash_ttl_dir(pathlib.Path("/var/lib/data/")).encode())
         for child in combined_index.iterdir():
             shutil.move(child, xapian_directory)
     logging.info("Index built")
+    end_time = time.perf_counter()
+    index_time = datetime.timedelta(seconds=end_time - start_time)
+    logging.info(f"Time to Index: {index_time}")
+
+
+@click.group()
+def cli():
+    pass
+
+
+cli.add_command(is_data_modified)
+cli.add_command(create_xapian_index)
 
 
 if __name__ == "__main__":
     # pylint: disable=no-value-for-parameter
-    main()
+    cli()