about summary refs log tree commit diff
diff options
context:
space:
mode:
authorAlexander Kabui2021-04-03 13:12:59 +0300
committerAlexander Kabui2021-04-03 13:12:59 +0300
commitd2e24157130ea28a8ac5e7a4511074bb82b6d634 (patch)
tree6328fc0a17e2df8d3095224d54313f69b7d01601
parente4cf6ac08a961f8d647f46a6d984d8d66d9f83ae (diff)
downloadgenenetwork3-d2e24157130ea28a8ac5e7a4511074bb82b6d634.tar.gz
add tests for getting trait data
-rw-r--r--gn3/computations/datasets.py90
-rw-r--r--tests/unit/computations/test_datasets.py28
2 files changed, 118 insertions, 0 deletions
diff --git a/gn3/computations/datasets.py b/gn3/computations/datasets.py
index 28d40a1..533ebdd 100644
--- a/gn3/computations/datasets.py
+++ b/gn3/computations/datasets.py
@@ -1,11 +1,15 @@
 """module contains the code all related to datasets"""
 import json
 from unittest import mock
+from math import ceil
+from collections import defaultdict
 
 from typing import Optional
 from typing import List
 
 from dataclasses import dataclass
+from MySQLdb import escape_string  # type: ignore
+
 import requests
 
 from gn3.experimental_db import database_connector
@@ -224,3 +228,89 @@ def fetch_dataset_sample_id(samplelist: List, database, species: str) -> dict:
     results = database_cursor.fetchall()
 
     return dict(results)
+
+
+def divide_into_chunks(the_list, number_chunks):
+    """Divides a list into approximately number_chunks
+    >>> divide_into_chunks([1, 2, 7, 3, 22, 8, 5, 22, 333], 3)
+    [[1, 2, 7], [3, 22, 8], [5, 22, 333]]"""
+
+    length = len(the_list)
+    if length == 0:
+        return [[]]
+
+    if length <= number_chunks:
+        number_chunks = length
+    chunk_size = int(ceil(length/number_chunks))
+    chunks = []
+
+    for counter in range(0, length, chunk_size):
+        chunks.append(the_list[counter:counter+chunk_size])
+    return chunks
+
+
+def mescape(*items) -> List:
+    """multiple escape for query values"""
+
+    return [escape_string(str(item)).decode('utf8') for item in items]
+
+
+def get_traits_data(sample_ids, database_instance, dataset_name, dataset_type):
+    """function to fetch trait data"""
+    # MySQL limits the number of tables that can be used in a join to 61,
+    # so we break the sample ids into smaller chunks
+    # Postgres doesn't have that limit, so we can get rid of this after we transition
+
+    trait_data = defaultdict(list)
+    chunk_size = 50
+    number_chunks = int(ceil(len(sample_ids) / chunk_size))
+    for sample_ids_step in divide_into_chunks(sample_ids, number_chunks):
+        if dataset_type == "Publish":
+            full_dataset_type = "Phenotype"
+        else:
+            full_dataset_type = dataset_type
+        temp = ['T%s.value' % item for item in sample_ids_step]
+
+        if dataset_type:
+            query = "SELECT {}XRef.Id,".format(escape_string(dataset_type))
+
+        else:
+            query = "SELECT {}.Name,".format(escape_string(full_dataset_type))
+
+        query += ', '.join(temp)
+        query += ' FROM ({}, {}XRef, {}Freeze) '.format(*mescape(full_dataset_type,
+                                                                 dataset_type,
+                                                                 dataset_type))
+        for item in sample_ids_step:
+            query += """
+                    left join {}Data as T{} on T{}.Id = {}XRef.DataId
+                    and T{}.StrainId={}\n
+                    """.format(*mescape(dataset_type, item, item, dataset_type, item, item))
+
+        if dataset_type == "Publish":
+            query += """
+            WHERE {}XRef.InbredSetId = {}Freeze.InbredSetId
+            and {}Freeze.Name = '{}'
+            and {}.Id = {}XRef.{}Id
+            order by {}.Id
+            """.format(*mescape(dataset_type, dataset_type, dataset_type, dataset_name,
+                                full_dataset_type, dataset_type, dataset_type, dataset_type))
+        else:
+
+            query += """
+            WHERE {}XRef.{}FreezeId = {}Freeze.Id
+            and {}Freeze.Name = '{}'
+            and {}.Id = {}XRef.{}Id
+            order by {}.Id
+            """.format(*mescape(dataset_type, dataset_type, dataset_type, dataset_type,
+                                dataset_name, full_dataset_type, dataset_type,
+                                dataset_type, full_dataset_type))
+
+        results = fetch_from_db_sample_data(query, database_instance)
+
+        trait_name = results[0]
+
+        sample_value_results = results[1:]
+
+        trait_data[trait_name] += (sample_value_results)
+    return trait_data
diff --git a/tests/unit/computations/test_datasets.py b/tests/unit/computations/test_datasets.py
index b169ba3..1b37d26 100644
--- a/tests/unit/computations/test_datasets.py
+++ b/tests/unit/computations/test_datasets.py
@@ -14,6 +14,8 @@ from gn3.computations.datasets import dataset_creator_store
 from gn3.computations.datasets import dataset_type_getter
 from gn3.computations.datasets import fetch_dataset_type_from_gn2_api
 from gn3.computations.datasets import fetch_dataset_sample_id
+from gn3.computations.datasets import divide_into_chunks
+from gn3.computations.datasets import get_traits_data
 
 
 class TestDatasets(TestCase):
@@ -179,3 +181,29 @@ class TestDatasets(TestCase):
             samplelist=strain_list, database=database_instance, species="mouse")
 
         self.assertEqual(results, expected_results)
+
+    @mock.patch("gn3.computations.datasets.fetch_from_db_sample_data")
+    @mock.patch("gn3.computations.datasets.divide_into_chunks")
+    def test_get_traits_data(self, mock_divide_into_chunks, mock_fetch_samples):
+        """test for for function to get data\
+        of traits in dataset"""
+
+        expected_results = {'AT_DSAFDS': [
+            12, 14, 13, 23, 12, 14, 13, 23, 12, 14, 13, 23]}
+        database = mock.Mock()
+        sample_id = [1, 2, 7, 3, 22, 8]
+        mock_divide_into_chunks.return_value = [
+            [1, 2, 7], [3, 22, 8], [5, 22, 333]]
+        mock_fetch_samples.return_value = ("AT_DSAFDS", 12, 14, 13, 23)
+        results = get_traits_data(sample_id, database, "HC_M2", "Publish")
+
+        self.assertEqual(expected_results, dict(results))
+
+    def test_divide_into_chunks(self):
+        """test for dividing a list into given number of\
+        chunks for example"""
+        results = divide_into_chunks([1, 2, 7, 3, 22, 8, 5, 22, 333], 3)
+
+        expected_results = [[1, 2, 7], [3, 22, 8], [5, 22, 333]]
+
+        self.assertEqual(results, expected_results)