aboutsummaryrefslogtreecommitdiff
path: root/scripts/insert_data.py
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-07-19 15:55:05 +0300
committerFrederick Muriuki Muriithi2022-07-19 15:55:05 +0300
commit3c47afe8092ff9f5bc152320e724c07973c1941a (patch)
tree30f0f54a2ecac03a0bb67dda6b1ea7f0580af005 /scripts/insert_data.py
parentc52570a4069abb6b8953e486adb326392ce6714c (diff)
downloadgn-uploader-3c47afe8092ff9f5bc152320e724c07973c1941a.tar.gz
Save standard error data. Fix linting and typing errors.
Diffstat (limited to 'scripts/insert_data.py')
-rw-r--r--scripts/insert_data.py136
1 files changed, 112 insertions, 24 deletions
diff --git a/scripts/insert_data.py b/scripts/insert_data.py
index 5e596ff..53bf4bd 100644
--- a/scripts/insert_data.py
+++ b/scripts/insert_data.py
@@ -2,28 +2,32 @@
import sys
import argparse
from typing import Tuple
+from functools import reduce
import MySQLdb as mdb
from redis import Redis
from MySQLdb.cursors import DictCursor
from quality_control.parsing import take
-from qc_app.db_utils import database_connection
from quality_control.file_utils import open_file
+from qc_app.db_utils import database_connection
from qc_app.check_connections import check_db, check_redis
def translate_alias(heading):
+ "Translate strain aliases into canonical names"
translations = {"B6": "C57BL/6J", "D2": "DBA/2J"}
return translations.get(heading, heading)
-def read_file_headings(filepath):
+def read_file_headings(filepath) -> Tuple[str, ...]:
"Get the file headings"
with open_file(filepath) as input_file:
- for line_number, line_contents in enumerate(input_file):
- if line_number == 0:
- return tuple(
- translate_alias(heading.strip())
- for heading in line_contents.split("\t"))
+ for line_contents in input_file:
+ headings = tuple(
+ translate_alias(heading.strip())
+ for heading in line_contents.split("\t"))
+ break
+
+ return headings
def read_file_contents(filepath):
"Get the file contents"
@@ -35,16 +39,20 @@ def read_file_contents(filepath):
yield tuple(
field.strip() for field in line_contents.split("\t"))
-def strains_info(dbconn: mdb.Connection, strain_names: Tuple[str, ...]) -> dict:
+def strains_info(
+ dbconn: mdb.Connection, strain_names: Tuple[str, ...],
+ speciesid: int) -> dict:
"Retrieve information for the strains"
with dbconn.cursor(cursorclass=DictCursor) as cursor:
query = (
"SELECT * FROM Strain WHERE Name IN "
- f"({', '.join(['%s']*len(strain_names))})")
- cursor.execute(query, tuple(strain_names))
+ f"({', '.join(['%s']*len(strain_names))}) "
+ "AND SpeciesId = %s")
+ cursor.execute(query, tuple(strain_names) + (speciesid,))
return {strain["Name"]: strain for strain in cursor.fetchall()}
-def read_means(filepath, headings, strain_info):
+def read_datavalues(filepath, headings, strain_info):
+ "Read data values from file"
for row in (
dict(zip(headings, line))
for line in read_file_contents(filepath)):
@@ -52,7 +60,7 @@ def read_means(filepath, headings, strain_info):
yield {
"ProbeSetId": int(row["ProbeSetID"]),
"StrainId": strain_info[sname]["Id"],
- "ProbeSetDataValue": float(row[sname])
+ "DataValue": float(row[sname])
}
def last_data_id(dbconn: mdb.Connection) -> int:
@@ -61,24 +69,65 @@ def last_data_id(dbconn: mdb.Connection) -> int:
cursor.execute("SELECT MAX(Id) FROM ProbeSetData")
return int(cursor.fetchone()[0])
+def check_strains(headings_strains, db_strains):
+ "Check strains in headings exist in database"
+ from_db = tuple(db_strains.keys())
+ not_in_db = tuple(
+ strain for strain in headings_strains if strain not in from_db)
+ if len(not_in_db) == 0:
+ return True
+
+ str_not_in_db = "', '".join(not_in_db)
+ print(
+ (f"ERROR: The strain(s) '{str_not_in_db}' w(as|ere) not found in the "
+ "database."),
+ file=sys.stderr)
+ sys.exit(1)
+
+def annotationinfo(
+ dbconn: mdb.Connection, platformid: int, datasetid: int) -> dict:
+ "Get annotation information from the database."
+ # This is somewhat slow. Look into optimising the behaviour
+ def __organise_annotations__(accm, item):
+ names_dict = (
+ {**accm[0], item["Name"]: item} if bool(item["Name"]) else accm[0])
+ targs_dict = (
+ {**accm[1], item["TargetId"]: item}
+ if bool(item["TargetId"]) else accm[1])
+ return (names_dict, targs_dict)
+
+ query = (
+ "SELECT ProbeSet.Name, ProbeSet.ChipId, ProbeSet.TargetId, "
+ "ProbeSetXRef.DataId, ProbeSetXRef.ProbeSetFreezeId "
+ "FROM ProbeSet INNER JOIN ProbeSetXRef "
+ "ON ProbeSet.Id=ProbeSetXRef.ProbeSetId "
+ "WHERE ProbeSet.ChipId=%s AND ProbeSetXRef.ProbeSetFreezeId=%s")
+ with dbconn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(query, (platformid, datasetid))
+ annot_dicts = reduce(# type: ignore[var-annotated]
+ __organise_annotations__, cursor.fetchall(), ({}, {}))
+ return {**annot_dicts[0], **annot_dicts[1]}
+ return {}
+
def insert_means(
- filepath: str, dataset_id: int, dbconn: mdb.Connection,
- rconn: Redis) -> int:
+ filepath: str, speciesid: int, datasetid: int, dbconn: mdb.Connection,
+ rconn: Redis) -> int: # pylint: disable=[unused-argument]
"Insert the means/averages data into the database"
print("INSERTING MEANS/AVERAGES DATA.")
headings = read_file_headings(filepath)
- strains = strains_info(dbconn, headings[1:])
+ strains = strains_info(dbconn, headings[1:], speciesid)
+ check_strains(headings[1:], strains)
means_query = (
"INSERT INTO ProbeSetData "
- "VALUES(%(ProbeSetDataId)s, %(StrainId)s, %(ProbeSetDataValue)s)")
+ "VALUES(%(ProbeSetDataId)s, %(StrainId)s, %(DataValue)s)")
xref_query = (
"INSERT INTO ProbeSetXRef(ProbeSetFreezeId, ProbeSetId, DataId) "
"VALUES (%(ProbeSetFreezeId)s, %(ProbeSetId)s, %(ProbeSetDataId)s)")
the_means = (
- {"ProbeSetFreezeId": dataset_id, "ProbeSetDataId": data_id, **mean}
+ {"ProbeSetFreezeId": datasetid, "ProbeSetDataId": data_id, **mean}
for data_id, mean in
enumerate(
- read_means(filepath, headings, strains),
+ read_datavalues(filepath, headings, strains),
start=(last_data_id(dbconn)+1)))
with dbconn.cursor(cursorclass=DictCursor) as cursor:
while True:
@@ -92,15 +141,42 @@ def insert_means(
cursor.executemany(xref_query, means)
return 0
-def insert_se(
- filepath: str, dataset_id: int, dbconn: mdb.Connection,
- rconn: Redis) -> int:
+def insert_se(# pylint: disable = [too-many-arguments]
+ filepath: str, speciesid: int, platformid: int, datasetid: int,
+ dbconn: mdb.Connection, rconn: Redis) -> int: # pylint: disable=[unused-argument]
"Insert the standard-error data into the database"
print("INSERTING STANDARD ERROR DATA...")
+ headings = read_file_headings(filepath)
+ strains = strains_info(dbconn, headings[1:], speciesid)
+ check_strains(headings[1:], strains)
+ se_query = (
+ "INSERT INTO ProbeSetSE "
+ "VALUES(%(DataId)s, %(StrainId)s, %(DataValue)s)")
+ annotations = annotationinfo(dbconn, platformid, datasetid)
+ if not bool(annotations):
+ print(
+ (f"ERROR: No annotations found for platform {platformid} and "
+ f"dataset {datasetid}. Quiting!"),
+ file=sys.stderr)
+ return 1
+
+ se_values = (
+ {"DataId": annotations[str(item["ProbeSetId"])]["DataId"], **item}
+ for item in read_datavalues(filepath, headings, strains))
+ with dbconn.cursor(cursorclass=DictCursor) as cursor:
+ while True:
+ serrors = tuple(take(se_values, 1000))
+ if not bool(serrors):
+ break
+ print(
+ f"\nEXECUTING QUERY:\n\t* {se_query}\n\tWITH PARAMETERS:\n\t"
+ f"{serrors}")
+ cursor.executemany(se_query, serrors)
return 0
if __name__ == "__main__":
def cli_args():
+ "Compute the CLI arguments"
parser = argparse.ArgumentParser(
prog="InsertData", description=(
"Script to insert data from an 'averages' file into the "
@@ -111,10 +187,13 @@ if __name__ == "__main__":
parser.add_argument(
"filepath", help="path to the file with the 'averages' data.")
parser.add_argument(
- "species_id", help="Identifier for the species in the database.",
+ "speciesid", help="Identifier for the species in the database.",
type=int)
parser.add_argument(
- "dataset_id", help="Identifier for the dataset in the database.",
+ "platformid", help="Identifier for the platform in the database.",
+ type=int)
+ parser.add_argument(
+ "datasetid", help="Identifier for the dataset in the database.",
type=int)
parser.add_argument(
"database_uri",
@@ -134,12 +213,21 @@ if __name__ == "__main__":
"standard-error": insert_se
}
+ extract_args = {
+ "average": lambda args, dbconn, rconn: (
+ args.filepath, args.speciesid, args.datasetid, dbconn, rconn),
+ "standard-error": lambda args, dbconn, rconn: (
+ args.filepath, args.speciesid, args.platformid, args.datasetid,
+ dbconn, rconn),
+ }
+
def main():
+ "Main entry point"
args = cli_args()
with Redis.from_url(args.redisuri, decode_responses=True) as rconn:
with database_connection(args.database_uri) as dbconn:
return insert_fns[args.filetype](
- args.filepath, args.dataset_id, dbconn, rconn)
+ *extract_args[args.filetype](args, dbconn, rconn))
return 2