From 8f022ae1a31224d0526443ad9779f30206b4a770 Mon Sep 17 00:00:00 2001
From: Muriithi Frederick Muriuki
Date: Mon, 9 Aug 2021 14:22:54 +0300
Subject: Retrieve the trait data

Issue:
https://github.com/genenetwork/gn-gemtext-threads/blob/main/topics/gn1-migration-to-gn2/clustering.gmi

* Add functions to retrieve the `value`, `variance`, and `ndata` values for
  any given trait.
---
 gn3/db/traits.py | 245 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 244 insertions(+), 1 deletion(-)

diff --git a/gn3/db/traits.py b/gn3/db/traits.py
index be46437..a740352 100644
--- a/gn3/db/traits.py
+++ b/gn3/db/traits.py
@@ -1,5 +1,5 @@
 """This class contains functions relating to trait data manipulation"""
-from typing import Any, Dict, Union
+from typing import Any, Dict, Union, Sequence
 from gn3.function_helpers import compose
 from gn3.db.datasets import retrieve_trait_dataset
 
@@ -408,3 +408,246 @@ def retrieve_trait_info(
             "riset": trait_dataset["riset"]
         }
     return trait_info
+
+def retrieve_temp_trait_data(trait_info: dict, conn: Any):
+    """
+    Retrieve trait data for `Temp` traits.
+    """
+    query = (
+        "SELECT "
+        "Strain.Name, TempData.value, TempData.SE, TempData.NStrain, "
+        "TempData.Id "
+        "FROM TempData, Temp, Strain "
+        "WHERE TempData.StrainId = Strain.Id "
+        "AND TempData.Id = Temp.DataId "
+        "AND Temp.name = %(trait_name)s "
+        "ORDER BY Strain.Name")
+    with conn.cursor() as cursor:
+        cursor.execute(
+            query,
+            {"trait_name": trait_info["trait_name"]})
+        return [dict(zip(
+            ["strain_name", "value", "se_error", "nstrain", "id"], row))
+                for row in cursor.fetchall()]
+    return []
+
+def retrieve_species_id(riset, conn: Any):
+    """
+    Retrieve a species id given the RISet value
+    """
+    with conn.cursor as cursor:
+        cursor.execute(
+            "SELECT SpeciesId from InbredSet WHERE Name = %(riset)s",
+            {"riset": riset})
+        return cursor.fetchone()[0]
+    return None
+
+def retrieve_geno_trait_data(trait_info: Dict, conn: Any):
+    """
+    Retrieve trait data for `Geno` traits.
+    """
+    query = (
+        "SELECT Strain.Name, GenoData.value, GenoSE.error, GenoData.Id "
+        "FROM (GenoData, GenoFreeze, Strain, Geno, GenoXRef) "
+        "LEFT JOIN GenoSE ON "
+        "(GenoSE.DataId = GenoData.Id AND GenoSE.StrainId = GenoData.StrainId) "
+        "WHERE Geno.SpeciesId = %(species_id)s "
+        "AND Geno.Name = %(trait_name)s AND GenoXRef.GenoId = Geno.Id "
+        "AND GenoXRef.GenoFreezeId = GenoFreeze.Id "
+        "AND GenoFreeze.Name = %(dataset_name)s "
+        "AND GenoXRef.DataId = GenoData.Id "
+        "AND GenoData.StrainId = Strain.Id "
+        "ORDER BY Strain.Name")
+    with conn.cursor() as cursor:
+        cursor.execute(
+            query,
+            {"trait_name": trait_info["trait_name"],
+             "dataset_name": trait_info["db"]["dataset_name"],
+             "species_id": retrieve_species_id(
+                 trait_info["db"]["riset"], conn)})
+        return [dict(zip(
+            ["strain_name", "value", "se_error", "id"], row))
+                for row in cursor.fetchall()]
+    return []
+
+def retrieve_publish_trait_data(trait_info: Dict, conn: Any):
+    """
+    Retrieve trait data for `Publish` traits.
+    """
+    query = (
+        "SELECT "
+        "Strain.Name, PublishData.value, PublishSE.error, NStrain.count, "
+        "PublishData.Id "
+        "FROM (PublishData, Strain, PublishXRef, PublishFreeze) "
+        "LEFT JOIN PublishSE ON "
+        "(PublishSE.DataId = PublishData.Id "
+        "AND PublishSE.StrainId = PublishData.StrainId) "
+        "LEFT JOIN NStrain ON "
+        "(NStrain.DataId = PublishData.Id "
+        "AND NStrain.StrainId = PublishData.StrainId) "
+        "WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
+        "AND PublishData.Id = PublishXRef.DataId "
+        "AND PublishXRef.Id = %(trait_name)s "
+        "AND PublishFreeze.Id = %(dataset_id)s "
+        "AND PublishData.StrainId = Strain.Id "
+        "ORDER BY Strain.Name")
+    with conn.cursor() as cursor:
+        cursor.execute(
+            query,
+            {"trait_name": trait_info["trait_name"],
+             "dataset_id": trait_info["db"]["dataset_id"]})
+        return [dict(zip(
+            ["strain_name", "value", "se_error", "nstrain", "id"], row))
+                for row in cursor.fetchall()]
+    return []
+
+def retrieve_cellid_trait_data(trait_info: Dict, conn: Any):
+    """
+    Retrieve trait data for `Probe Data` types.
+    """
+    query = (
+        "SELECT "
+        "Strain.Name, ProbeData.value, ProbeSE.error, ProbeData.Id "
+        "FROM (ProbeData, ProbeFreeze, ProbeSetFreeze, ProbeXRef, Strain,"
+        " Probe, ProbeSet) "
+        "LEFT JOIN ProbeSE ON "
+        "(ProbeSE.DataId = ProbeData.Id "
+        " AND ProbeSE.StrainId = ProbeData.StrainId) "
+        "WHERE Probe.Name = %(cellid)s "
+        "AND ProbeSet.Name = %(trait_name)s "
+        "AND Probe.ProbeSetId = ProbeSet.Id "
+        "AND ProbeXRef.ProbeId = Probe.Id "
+        "AND ProbeXRef.ProbeFreezeId = ProbeFreeze.Id "
+        "AND ProbeSetFreeze.ProbeFreezeId = ProbeFreeze.Id "
+        "AND ProbeSetFreeze.Name = %(dataset_name)s "
+        "AND ProbeXRef.DataId = ProbeData.Id "
+        "AND ProbeData.StrainId = Strain.Id "
+        "ORDER BY Strain.Name")
+    with conn.cursor() as cursor:
+        cursor.execute(
+            query,
+            {"cellid": trait_info["cellid"],
+             "trait_name": trait_info["trait_name"],
+             "dataset_id": trait_info["db"]["dataset_id"]})
+        return [dict(zip(
+            ["strain_name", "value", "se_error", "id"], row))
+                for row in cursor.fetchall()]
+    return []
+
+def retrieve_probeset_trait_data(trait_info: Dict, conn: Any):
+    """
+    Retrieve trait data for `ProbeSet` traits.
+    """
+    query = (
+        "SELECT Strain.Name, ProbeSetData.value, ProbeSetSE.error, "
+        "ProbeSetData.Id "
+        "FROM (ProbeSetData, ProbeSetFreeze, Strain, ProbeSet, ProbeSetXRef) "
+        "LEFT JOIN ProbeSetSE ON "
+        "(ProbeSetSE.DataId = ProbeSetData.Id "
+        "AND ProbeSetSE.StrainId = ProbeSetData.StrainId) "
+        "WHERE ProbeSet.Name = %(trait_name)s "
+        "AND ProbeSetXRef.ProbeSetId = ProbeSet.Id "
+        "AND ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id "
+        "AND ProbeSetFreeze.Name = %(dataset_name)s "
+        "AND ProbeSetXRef.DataId = ProbeSetData.Id "
+        "AND ProbeSetData.StrainId = Strain.Id "
+        "ORDER BY Strain.Name")
+
+    with conn.cursor() as cursor:
+        cursor.execute(
+            query,
+            {"trait_name": trait_info["trait_name"],
+             "dataset_name": trait_info["db"]["dataset_name"]})
+        return [dict(zip(
+            ["strain_name", "value", "se_error", "id"], row))
+                for row in cursor.fetchall()]
+    return []
+
+def with_strainlist_data_setup(strainlist: Sequence[str]):
+    """
+    Build function that computes the trait data from provided list of strains.
+
+    PARAMETERS
+    strainlist: (list)
+      A list of strain names
+
+    RETURNS:
+      Returns a function that given some data from the database, computes the
+      strain's value, variance and ndata values, only if the strain is present
+      in the provided `strainlist` variable.
+    """
+    def setup_fn(tdata):
+        if tdata["strain_name"] in strainlist:
+            val = tdata["value"]
+            if val is not None:
+                return {
+                    "strain_name": tdata["strain_name"],
+                    "value": val,
+                    "variance": tdata["se_error"],
+                    "ndata": tdata.get("nstrain", None)
+                }
+        return None
+    return setup_fn
+
+def without_strainlist_data_setup():
+    """
+    Build function that computes the trait data.
+
+    RETURNS:
+      Returns a function that given some data from the database, computes the
+      strain's value, variance and ndata values.
+    """
+    def setup_fn(tdata):
+        val = tdata["value"]
+        if val is not None:
+            return {
+                "strain_name": tdata["strain_name"],
+                "value": val,
+                "variance": tdata["se_error"],
+                "ndata": tdata.get("nstrain", None)
+            }
+        return None
+    return setup_fn
+
+def retrieve_trait_data(trait: dict, conn: Any, strainlist: Sequence[str] = tuple()):
+    """
+    Retrieve trait data
+
+    DESCRIPTION
+    Retrieve trait data as is done in
+    https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L258-L386
+    """
+    # I do not like this section, but it retains the flow in the old codebase
+    if trait["db"]["dataset_type"] == "Temp":
+        results = retrieve_temp_trait_data(trait, conn)
+    elif trait["db"]["dataset_type"] == "Publish":
+        results = retrieve_publish_trait_data(trait, conn)
+    elif trait["cellid"]:
+        results = retrieve_cellid_trait_data(trait, conn)
+    elif trait["db"]["dataset_type"] == "ProbeSet":
+        results = retrieve_probeset_trait_data(trait, conn)
+    else:
+        results = retrieve_geno_trait_data(trait, conn)
+
+    if results:
+        # do something with mysqlid
+        mysqlid = results[0]["id"]
+        if strainlist:
+            data = [
+                item for item in
+                map(with_strainlist_data_setup(strainlist), results)
+                if item is not None]
+        else:
+            data = [
+                item for item in
+                map(without_strainlist_data_setup(), results)
+                if item is not None]
+
+        return {
+            "mysqlid": mysqlid,
+            "data": dict(map(
+                lambda x: (
+                    x["strain_name"],
+                    {k:v for k, v in x.items() if x != "strain_name"}),
+                data))}
+    return {}
-- 
cgit v1.2.3