about summary refs log tree commit diff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rw-r--r--scripts/load_phenotypes_to_db.py34
1 file changed, 33 insertions, 1 deletion
diff --git a/scripts/load_phenotypes_to_db.py b/scripts/load_phenotypes_to_db.py
index 3737f2d..9158307 100644
--- a/scripts/load_phenotypes_to_db.py
+++ b/scripts/load_phenotypes_to_db.py
@@ -29,6 +29,8 @@ from uploader.publications.models import fetch_publication_by_id
 
 from scripts.rqtl2.bundleutils import build_line_joiner, build_line_splitter
 
+from functional_tools import take
+
 logging.basicConfig(
     format="%(asctime)s — %(filename)s:%(lineno)s — %(levelname)s: %(message)s")
 logger = logging.getLogger(__name__)
@@ -406,6 +408,32 @@ def load_data(conn: mysqldb.Connection, job: dict) -> int:#pylint: disable=[too-
     return (_species, _population, _dataset, _xrefs)
 
 
+def update_means(
+        conn: mysqldb.Connection,
+        population_id: int,
+        xref_ids: tuple[int, ...]
+):
+    """Compute the means from the data and update them in the database."""
+    query = (
+        "UPDATE PublishXRef SET mean = "
+        "(SELECT AVG(value) FROM PublishData"
+        " WHERE PublishData.Id=PublishXRef.DataId) "
+        "WHERE PublishXRef.Id=%(xref_id)s "
+        "AND PublishXRef.InbredSetId=%(population_id)s")
+    _xref_iterator = (_xref_id for _xref_id in xref_ids)
+    with conn.cursor(cursorclass=DictCursor) as cursor:
+        while True:
+            batch = take(_xref_iterator, 10000)
+            if len(batch) == 0:
+                break
+            cursor.executemany(
+                query,
+                tuple({
+                    "population_id": population_id,
+                    "xref_id": _xref_id
+                } for _xref_id in batch))
+
+
 if __name__ == "__main__":
     def parse_args():
         """Setup command-line arguments."""
@@ -469,15 +497,19 @@ if __name__ == "__main__":
                     f"{_table} WRITE" for _table in _db_tables_))
 
             db_results = load_data(conn, job)
+            _xref_ids = tuple(xref["xref_id"] for xref in db_results[3])
             jobs.update_metadata(
                 jobs_conn,
                 args.job_id,
                 "xref_ids",
-                json.dumps([xref["xref_id"] for xref in db_results[3]]))
+                json.dumps(_xref_ids))
 
             logger.info("Unlocking all database tables.")
             cursor.execute("UNLOCK TABLES")
 
+            logger.info("Updating means.")
+            update_means(conn, db_results[1]["Id"], _xref_ids)
+
         # Update authorisations (break this down) — maybe loop until it works?
         logger.info("Updating authorisation.")
         _job_metadata = job["metadata"]