aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author    Frederick Muriuki Muriithi    2025-03-26 10:37:38 -0500
committer Frederick Muriuki Muriithi    2025-03-26 13:11:20 -0500
commitc8efee1af29cef1b0471f898a0b0b7d5065ed7fc (patch)
tree14dd6f2cdee1beb3b633e61a2e8afb458932118b
parentf5ce3958b230f364f363ffa5a53bb25eafbc6e9a (diff)
downloadgn-uploader-c8efee1af29cef1b0471f898a0b0b7d5065ed7fc.tar.gz
Collect IDs once at the top-level call to save on iterations.
-rw-r--r--    scripts/phenotypes_bulk_edit.py | 21
1 file changed, 14 insertions(+), 7 deletions(-)
diff --git a/scripts/phenotypes_bulk_edit.py b/scripts/phenotypes_bulk_edit.py
index 1d6689e..6e5d416 100644
--- a/scripts/phenotypes_bulk_edit.py
+++ b/scripts/phenotypes_bulk_edit.py
@@ -4,6 +4,7 @@ import logging
import argparse
from pathlib import Path
from typing import Iterator
+from functools import reduce
from MySQLdb.cursors import DictCursor
@@ -14,11 +15,10 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
-def check_ids(conn, contents):
+def check_ids(conn, ids: tuple[tuple[int, int], ...]) -> bool:
"""Verify that all the `UniqueIdentifier` values are valid."""
logger.info("Checking the 'UniqueIdentifier' values.")
with conn.cursor(cursorclass=DictCursor) as cursor:
- ids = tuple((row["phenotype_id"], row["xref_id"]) for row in contents)
paramstr = ",".join(["(%s, %s)"] * len(ids))
cursor.execute(
"SELECT PhenotypeId AS phenotype_id, Id AS xref_id "
@@ -26,7 +26,7 @@ def check_ids(conn, contents):
f"WHERE (PhenotypeId, Id) IN ({paramstr})",
tuple(item for row in ids for item in row))
mysqldb.debug_query(cursor, logger)
- found = tuple((str(row["phenotype_id"]), str(row["xref_id"]))
+ found = tuple((row["phenotype_id"], row["xref_id"])
for row in cursor.fetchall())
not_found = tuple(item for item in ids if item not in found)
@@ -108,16 +108,23 @@ def read_file(filepath: Path) -> Iterator[str]:
_dict = dict(zip(headers, fields))
_pheno, _xref = _dict.pop("UniqueIdentifier").split("::")
- _dict["phenotype_id"] = _pheno.split(":")[1]
- _dict["xref_id"] = _xref.split(":")[1]
+ _dict["phenotype_id"] = int(_pheno.split(":")[1])
+ _dict["xref_id"] = int(_xref.split(":")[1])
yield _dict
count = count + 1
def run(conn, job):
"""Process the data and update it."""
- file_contents = tuple(read_file(Path(job["metadata"]["edit-file"])))
- check_ids(conn, file_contents)
+ file_contents = tuple(sorted(read_file(Path(job["metadata"]["edit-file"])),
+ key=lambda item: item["phenotype_id"]))
+ pheno_ids, pheno_xref_ids = reduce(
+ lambda coll, curr: (
+ coll[0] + (curr["phenotype_id"],),
+ coll[1] + ((curr["phenotype_id"], curr["xref_id"]),)),
+ file_contents,
+ (tuple(), tuple()))
+ check_ids(conn, pheno_xref_ids)
check_for_mandatory_fields()
# stop running here if any errors are found.
compute_differences()