about summary refs log tree commit diff
path: root/gn3/db/traits.py
diff options
context:
space:
mode:
authorzsloan2021-10-12 20:56:31 +0000
committerzsloan2021-10-12 20:56:31 +0000
commit6e211182354fb4d6941e3a44ec1ec9d378b0e4ef (patch)
tree60d9aaf382eefbb47cdbab9c74d98481cf0983de /gn3/db/traits.py
parentb815236123ff8e144bd84f349357a1852df95651 (diff)
parent77c274b79c3ec01de60e90db3299763cb58f715b (diff)
downloadgenenetwork3-6e211182354fb4d6941e3a44ec1ec9d378b0e4ef.tar.gz
Merge branch 'main' of https://github.com/genenetwork/genenetwork3 into bug/fix_rqtl_covariates
Diffstat (limited to 'gn3/db/traits.py')
-rw-r--r--gn3/db/traits.py78
1 files changed, 43 insertions, 35 deletions
diff --git a/gn3/db/traits.py b/gn3/db/traits.py
index 1031e44..f2673c8 100644
--- a/gn3/db/traits.py
+++ b/gn3/db/traits.py
@@ -1,5 +1,8 @@
 """This class contains functions relating to trait data manipulation"""
+import os
 from typing import Any, Dict, Union, Sequence
+from gn3.settings import TMPDIR
+from gn3.random import random_string
 from gn3.function_helpers import compose
 from gn3.db.datasets import retrieve_trait_dataset
 
@@ -43,7 +46,7 @@ def update_sample_data(conn: Any,
                        count: Union[int, str]):
     """Given the right parameters, update sample-data from the relevant
     table."""
-    # pylint: disable=[R0913, R0914]
+    # pylint: disable=[R0913, R0914, C0103]
     STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s"
     PUBLISH_DATA_SQL: str = ("UPDATE PublishData SET value = %s "
                              "WHERE StrainId = %s AND Id = %s")
@@ -60,22 +63,22 @@ def update_sample_data(conn: Any,
     with conn.cursor() as cursor:
         # Update the Strains table
         cursor.execute(STRAIN_ID_SQL, (strain_name, strain_id))
-        updated_strains: int = cursor.rowcount
+        updated_strains = cursor.rowcount
         # Update the PublishData table
         cursor.execute(PUBLISH_DATA_SQL,
                        (None if value == "x" else value,
                         strain_id, publish_data_id))
-        updated_published_data: int = cursor.rowcount
+        updated_published_data = cursor.rowcount
         # Update the PublishSE table
         cursor.execute(PUBLISH_SE_SQL,
                        (None if error == "x" else error,
                         strain_id, publish_data_id))
-        updated_se_data: int = cursor.rowcount
+        updated_se_data = cursor.rowcount
         # Update the NStrain table
         cursor.execute(N_STRAIN_SQL,
                        (None if count == "x" else count,
                         strain_id, publish_data_id))
-        updated_n_strains: int = cursor.rowcount
+        updated_n_strains = cursor.rowcount
     return (updated_strains, updated_published_data,
             updated_se_data, updated_n_strains)
 
@@ -223,7 +226,7 @@ def set_homologene_id_field_probeset(trait_info, conn):
     """
     query = (
         "SELECT HomologeneId FROM Homologene, Species, InbredSet"
-        " WHERE Homologene.GeneId = %(geneid)s AND InbredSet.Name = %(riset)s"
+        " WHERE Homologene.GeneId = %(geneid)s AND InbredSet.Name = %(group)s"
         " AND InbredSet.SpeciesId = Species.Id AND"
         " Species.TaxonomyId = Homologene.TaxonomyId")
     with conn.cursor() as cursor:
@@ -231,7 +234,7 @@ def set_homologene_id_field_probeset(trait_info, conn):
             query,
             {
                 k:v for k, v in trait_info.items()
-                if k in ["geneid", "riset"]
+                if k in ["geneid", "group"]
             })
         res = cursor.fetchone()
         if res:
@@ -419,7 +422,7 @@ def retrieve_trait_info(
     if trait_info["haveinfo"]:
         return {
             **trait_post_processing_functions_table[trait_dataset_type](
-                {**trait_info, "riset": trait_dataset["riset"]}),
+                {**trait_info, "group": trait_dataset["group"]}),
             "db": {**trait["db"], **trait_dataset}
         }
     return trait_info
@@ -442,18 +445,18 @@ def retrieve_temp_trait_data(trait_info: dict, conn: Any):
             query,
             {"trait_name": trait_info["trait_name"]})
         return [dict(zip(
-            ["strain_name", "value", "se_error", "nstrain", "id"], row))
+            ["sample_name", "value", "se_error", "nstrain", "id"], row))
                 for row in cursor.fetchall()]
     return []
 
-def retrieve_species_id(riset, conn: Any):
+def retrieve_species_id(group, conn: Any):
     """
-    Retrieve a species id given the RISet value
+    Retrieve a species id given the Group value
     """
     with conn.cursor as cursor:
         cursor.execute(
-            "SELECT SpeciesId from InbredSet WHERE Name = %(riset)s",
-            {"riset": riset})
+            "SELECT SpeciesId from InbredSet WHERE Name = %(group)s",
+            {"group": group})
         return cursor.fetchone()[0]
     return None
 
@@ -479,9 +482,9 @@ def retrieve_geno_trait_data(trait_info: Dict, conn: Any):
             {"trait_name": trait_info["trait_name"],
              "dataset_name": trait_info["db"]["dataset_name"],
              "species_id": retrieve_species_id(
-                 trait_info["db"]["riset"], conn)})
+                 trait_info["db"]["group"], conn)})
         return [dict(zip(
-            ["strain_name", "value", "se_error", "id"], row))
+            ["sample_name", "value", "se_error", "id"], row))
                 for row in cursor.fetchall()]
     return []
 
@@ -512,7 +515,7 @@ def retrieve_publish_trait_data(trait_info: Dict, conn: Any):
             {"trait_name": trait_info["trait_name"],
              "dataset_id": trait_info["db"]["dataset_id"]})
         return [dict(zip(
-            ["strain_name", "value", "se_error", "nstrain", "id"], row))
+            ["sample_name", "value", "se_error", "nstrain", "id"], row))
                 for row in cursor.fetchall()]
     return []
 
@@ -545,7 +548,7 @@ def retrieve_cellid_trait_data(trait_info: Dict, conn: Any):
              "trait_name": trait_info["trait_name"],
              "dataset_id": trait_info["db"]["dataset_id"]})
         return [dict(zip(
-            ["strain_name", "value", "se_error", "id"], row))
+            ["sample_name", "value", "se_error", "id"], row))
                 for row in cursor.fetchall()]
     return []
 
@@ -574,29 +577,29 @@ def retrieve_probeset_trait_data(trait_info: Dict, conn: Any):
             {"trait_name": trait_info["trait_name"],
              "dataset_name": trait_info["db"]["dataset_name"]})
         return [dict(zip(
-            ["strain_name", "value", "se_error", "id"], row))
+            ["sample_name", "value", "se_error", "id"], row))
                 for row in cursor.fetchall()]
     return []
 
-def with_strainlist_data_setup(strainlist: Sequence[str]):
+def with_samplelist_data_setup(samplelist: Sequence[str]):
     """
-    Build function that computes the trait data from provided list of strains.
+    Build function that computes the trait data from provided list of samples.
 
     PARAMETERS
-    strainlist: (list)
-      A list of strain names
+    samplelist: (list)
+      A list of sample names
 
     RETURNS:
       Returns a function that given some data from the database, computes the
-      strain's value, variance and ndata values, only if the strain is present
-      in the provided `strainlist` variable.
+      sample's value, variance and ndata values, only if the sample is present
+      in the provided `samplelist` variable.
     """
     def setup_fn(tdata):
-        if tdata["strain_name"] in strainlist:
+        if tdata["sample_name"] in samplelist:
             val = tdata["value"]
             if val is not None:
                 return {
-                    "strain_name": tdata["strain_name"],
+                    "sample_name": tdata["sample_name"],
                     "value": val,
                     "variance": tdata["se_error"],
                     "ndata": tdata.get("nstrain", None)
@@ -604,19 +607,19 @@ def with_strainlist_data_setup(strainlist: Sequence[str]):
         return None
     return setup_fn
 
-def without_strainlist_data_setup():
+def without_samplelist_data_setup():
     """
     Build function that computes the trait data.
 
     RETURNS:
       Returns a function that given some data from the database, computes the
-      strain's value, variance and ndata values.
+      sample's value, variance and ndata values.
     """
     def setup_fn(tdata):
         val = tdata["value"]
         if val is not None:
             return {
-                "strain_name": tdata["strain_name"],
+                "sample_name": tdata["sample_name"],
                 "value": val,
                 "variance": tdata["se_error"],
                 "ndata": tdata.get("nstrain", None)
@@ -624,7 +627,7 @@ def without_strainlist_data_setup():
         return None
     return setup_fn
 
-def retrieve_trait_data(trait: dict, conn: Any, strainlist: Sequence[str] = tuple()):
+def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tuple()):
     """
     Retrieve trait data
 
@@ -647,22 +650,27 @@ def retrieve_trait_data(trait: dict, conn: Any, strainlist: Sequence[str] = tupl
     if results:
         # do something with mysqlid
         mysqlid = results[0]["id"]
-        if strainlist:
+        if samplelist:
             data = [
                 item for item in
-                map(with_strainlist_data_setup(strainlist), results)
+                map(with_samplelist_data_setup(samplelist), results)
                 if item is not None]
         else:
             data = [
                 item for item in
-                map(without_strainlist_data_setup(), results)
+                map(without_samplelist_data_setup(), results)
                 if item is not None]
 
         return {
             "mysqlid": mysqlid,
             "data": dict(map(
                 lambda x: (
-                    x["strain_name"],
-                    {k:v for k, v in x.items() if x != "strain_name"}),
+                    x["sample_name"],
+                    {k:v for k, v in x.items() if x != "sample_name"}),
                 data))}
     return {}
+
+def generate_traits_filename(base_path: str = TMPDIR):
+    """Generate a unique filename for use with generated traits files."""
+    return "{}/traits_test_file_{}.txt".format(
+        os.path.abspath(base_path), random_string(10))