about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--scripts/insert_data.py120
1 files changed, 77 insertions, 43 deletions
diff --git a/scripts/insert_data.py b/scripts/insert_data.py
index b3e9eea..56d880b 100644
--- a/scripts/insert_data.py
+++ b/scripts/insert_data.py
@@ -3,8 +3,8 @@ import sys
 import string
 import random
 import argparse
-from typing import Tuple
 from functools import reduce
+from typing import Tuple, Iterator
 
 import MySQLdb as mdb
 from redis import Redis
@@ -59,22 +59,22 @@ def strains_info(
         return {strain["Name"]: strain for strain in cursor.fetchall()}
 
 def read_datavalues(filepath, headings, strain_info):
-    "Read data values from file"
-    for row in (
-            dict(zip(headings, line))
-            for line in read_file_contents(filepath)):
-        for sname in headings[1:]:
-            yield {
-                "ProbeSetId": int(row["ProbeSetID"]),
-                "StrainId": strain_info[sname]["Id"],
-                "DataValue": float(row[sname])
-            }
+    from quality_control.debug import __pk__
+    return {
+        str(row["ProbeSetID"]): tuple({
+            "ProbeSetName": str(row["ProbeSetID"]),
+            "StrainId": strain_info[sname]["Id"],
+            "DataValue": float(row[sname])
+        } for sname in headings[1:])
+        for row in
+        (dict(zip(headings, line)) for line in read_file_contents(filepath))
+    }
 
 def read_probesets(filepath, headings):
     """Read the ProbeSet names."""
     for row in (dict(zip(headings, line))
                 for line in read_file_contents(filepath)):
-        yield {"Name": int(row["ProbeSetID"])}
+        yield {"Name": str(row["ProbeSetID"])}
 
 def last_data_id(dbconn: mdb.Connection) -> int:
     "Get the last id from the database"
@@ -138,21 +138,50 @@ def __format_query__(query, params):
         "INSERT INTO ", "INSERT INTO\n\t")
     return f"{insert_str}\nVALUES\n\t{values_str};"
 
-def __xref_params__(
-        dbconn: mdb.Connection, means: tuple[dict, ...]) -> tuple[dict, ...]:
-    """Process params for cross-reference table."""
-    xref_names = tuple({mean["ProbeSetId"] for mean in means})
+def insert_probesets(filepath: str,
+                     dbconn: mdb.Connection,
+                     platform_id: int,
+                     headings: tuple[str, ...],
+                     session_rand_str: str) -> tuple[str, ...]:
+    probeset_query = (
+        "INSERT INTO ProbeSet(ChipId, Name) "
+        "VALUES (%(ChipId)s, %(Name)s) ")
+    the_probesets = ({
+        **row,
+        "Name": f"{row['Name']}{session_rand_str}",
+        "ChipId": platform_id
+    } for row in read_probesets(filepath, headings))
+    probeset_names = tuple()
     with dbconn.cursor(cursorclass=DictCursor) as cursor:
-        params_str = ", ".join(["%s"] * len(xref_names))
+        while True:
+            probeset_params = tuple(take(the_probesets, 10000))
+            if not bool(probeset_params):
+                break
+            print(__format_query__(probeset_query, probeset_params))
+            print()
+            cursor.executemany(probeset_query, probeset_params)
+            probeset_names = probeset_names + tuple(
+                name[0:name.index("::RAND_")] for name in (
+                    row["Name"] for row in probeset_params))
+
+    return probeset_names
+
+def probeset_ids(dbconn: mdb.Connection,
+                 chip_id: int,
+                 probeset_names: tuple[str, ...]) -> Iterator[tuple[str, int]]:
+    """Fetch the IDs of the probesets with the given names."""
+    with dbconn.cursor() as cursor:
+        params_str = ", ".join(["%s"] * len(probeset_names))
         cursor.execute(
-            f"SELECT Name, Id FROM ProbeSet WHERE Name IN ({params_str})",
-            xref_names)
-        ids = {row["Name"]: row["Id"] for row in cursor.fetchall()}
-        return tuple({
-            **mean,
-            "ProbeSetName": mean["ProbeSetId"],
-            "ProbeSetId": ids[str(mean["ProbeSetId"])]
-        } for mean in means)
+            "SELECT Name, Id FROM ProbeSet "
+            "WHERE ChipId=%s "
+            f"AND Name IN ({params_str})",
+            (chip_id,) + probeset_names)
+        while True:
+            row = cursor.fetchone()
+            if not bool(row):
+                break
+            yield row
 
 def insert_means(# pylint: disable=[too-many-locals, too-many-arguments]
         filepath: str, speciesid: int, platform_id: int, datasetid: int,
@@ -161,40 +190,45 @@ def insert_means(# pylint: disable=[too-many-locals, too-many-arguments]
     headings = read_file_headings(filepath)
     strains = strains_info(dbconn, headings[1:], speciesid)
     check_strains(headings[1:], strains)
-    probeset_query = (
-        "INSERT INTO ProbeSet(ChipId, Name) "
-        "VALUES (%(ChipId)s, %(Name)s) ")
     means_query = (
         "INSERT INTO ProbeSetData "
         "VALUES(%(ProbeSetDataId)s, %(StrainId)s, %(DataValue)s)")
     xref_query = (
         "INSERT INTO ProbeSetXRef(ProbeSetFreezeId, ProbeSetId, DataId) "
         "VALUES(%(ProbeSetFreezeId)s, %(ProbeSetId)s, %(ProbeSetDataId)s)")
+
+    # A random string to avoid over-write chances.
+    #   This is needed because the `ProbeSet` table is defined with
+    #   UNIQUE KEY `ProbeSetId` (`ChipId`,`Name`)
+    #   which means that we cannot have 2 (or more) ProbeSets which share both
+    #   the name and chip_id (platform) at the same time.
+    rand_str = f"::RAND_{random_string()}"
+    pset_ids = {
+        name: pset_id
+        for name, pset_id in probeset_ids(
+                dbconn,
+                platform_id,
+                insert_probesets(
+                    filepath, dbconn, platform_id, headings, rand_str))
+    }
     the_means = ({
-        "ProbeSetFreezeId": datasetid, "ProbeSetDataId": data_id,
-        "ChipId": platform_id, **mean
-    } for data_id, mean in enumerate(
-        read_datavalues(filepath, headings, strains),
-        start=(last_data_id(dbconn)+1)))
-    the_probesets = ({
-        **row,
-        "Name": f"{row['Name']}::RAND_{random_string()}",
-        "ChipId": platform_id
-    } for row in read_probesets(filepath, headings))
+        **mean, "ProbeSetFreezeId": datasetid, "ProbeSetDataId": data_id,
+        "ChipId": platform_id, "ProbeSetId": pset_ids[mean["ProbeSetName"]]
+    } for data_id, mean in enumerate((
+        item for sublist in
+        read_datavalues(filepath, headings, strains).values()
+        for item in sublist),
+                                     start=(last_data_id(dbconn)+1)))
     with dbconn.cursor(cursorclass=DictCursor) as cursor:
         while True:
             means = tuple(take(the_means, 10000))
-            probeset_params = tuple(take(the_probesets, 10000))
             if not bool(means):
                 break
-            print(__format_query__(probeset_query, probeset_params))
-            print()
             print(__format_query__(means_query, means))
             print()
             print(__format_query__(xref_query, means))
-            cursor.executemany(probeset_query, probeset_params)
             cursor.executemany(means_query, means)
-            cursor.executemany(xref_query, __xref_params__(dbconn, means))
+            cursor.executemany(xref_query, means)
     return 0
 
 def insert_se(# pylint: disable = [too-many-arguments]