diff options
-rw-r--r-- | gn3/computations/datasets.py | 90 | ||||
-rw-r--r-- | tests/unit/computations/test_datasets.py | 28 |
2 files changed, 118 insertions, 0 deletions
from math import ceil
from collections import defaultdict
from typing import List


def divide_into_chunks(the_list, number_chunks):
    """Divide *the_list* into approximately *number_chunks* contiguous chunks.

    >>> divide_into_chunks([1, 2, 7, 3, 22, 8, 5, 22, 333], 3)
    [[1, 2, 7], [3, 22, 8], [5, 22, 333]]

    Returns ``[[]]`` for an empty list so callers always get at least one
    (possibly empty) chunk to iterate over.
    """
    length = len(the_list)
    if length == 0:
        return [[]]

    # Clamp the chunk count: never more chunks than items, and never fewer
    # than one — a number_chunks of 0 would otherwise divide by zero below.
    number_chunks = max(1, min(number_chunks, length))
    chunk_size = int(ceil(length / number_chunks))

    return [the_list[start:start + chunk_size]
            for start in range(0, length, chunk_size)]


def mescape(*items) -> List:
    """Escape multiple values for interpolation into SQL text.

    Each item is stringified, escaped, and decoded back to ``str`` so it can
    be used directly in ``str.format`` (``escape_string`` returns ``bytes``
    under Python 3).
    """
    # Deferred import: keeps this module importable when the MySQL driver
    # is not installed (e.g. in unit-test environments).
    from MySQLdb import escape_string  # type: ignore

    return [escape_string(str(item)).decode('utf8') for item in items]


def get_traits_data(sample_ids, database_instance, dataset_name, dataset_type):
    """Fetch trait data for every trait in a dataset, keyed by trait name.

    Args:
        sample_ids: list of strain/sample ids whose values to fetch.
        database_instance: open DB handle passed to fetch_from_db_sample_data.
        dataset_name: name of the dataset's *Freeze table entry.
        dataset_type: e.g. "Publish", "ProbeSet"; "Publish" data lives in
            the Phenotype tables.

    Returns:
        defaultdict mapping trait name -> list of sample values.
    """
    # MySQL limits the number of tables in a join to 61, so we break the
    # sample ids into smaller chunks (one joined table per sample id).
    # Postgres doesn't have that limit, so this can go after transition.
    trait_data = defaultdict(list)
    chunk_size = 50
    number_chunks = int(ceil(len(sample_ids) / chunk_size))
    for sample_ids_step in divide_into_chunks(sample_ids, number_chunks):
        if dataset_type == "Publish":
            full_dataset_type = "Phenotype"
        else:
            full_dataset_type = dataset_type
        temp = ['T%s.value' % item for item in sample_ids_step]

        if dataset_type:
            # BUGFIX: mescape (not raw escape_string) so the value is a str;
            # escape_string returns bytes and would format as "b'...'".
            query = "SELECT {}XRef.Id,".format(*mescape(dataset_type))
        else:
            query = "SELECT {}.Name,".format(*mescape(full_dataset_type))

        query += ', '.join(temp)
        query += ' FROM ({}, {}XRef, {}Freeze) '.format(*mescape(full_dataset_type,
                                                                 dataset_type,
                                                                 dataset_type))
        for item in sample_ids_step:
            query += """
                    left join {}Data as T{} on T{}.Id = {}XRef.DataId
                    and T{}.StrainId={}\n
                    """.format(*mescape(dataset_type, item, item, dataset_type, item, item))

        if dataset_type == "Publish":
            query += """
                    WHERE {}XRef.InbredSetId = {}Freeze.InbredSetId
                    and {}Freeze.Name = '{}'
                    and {}.Id = {}XRef.{}Id
                    order by {}.Id
                    """.format(*mescape(dataset_type, dataset_type, dataset_type, dataset_name,
                                        full_dataset_type, dataset_type, dataset_type, dataset_type))
        else:
            query += """
                    WHERE {}XRef.{}FreezeId = {}Freeze.Id
                    and {}Freeze.Name = '{}'
                    and {}.Id = {}XRef.{}Id
                    order by {}.Id
                    """.format(*mescape(dataset_type, dataset_type, dataset_type, dataset_type,
                                        dataset_name, full_dataset_type, dataset_type,
                                        dataset_type, full_dataset_type))

        results = fetch_from_db_sample_data(query, database_instance)
        if not results:
            # No rows for this chunk — skip rather than crash on results[0].
            continue

        trait_name = results[0]
        sample_value_results = results[1:]
        trait_data[trait_name] += sample_value_results
    return trait_data
@@ -179,3 +181,29 @@ class TestDatasets(TestCase): samplelist=strain_list, database=database_instance, species="mouse") self.assertEqual(results, expected_results) + + @mock.patch("gn3.computations.datasets.fetch_from_db_sample_data") + @mock.patch("gn3.computations.datasets.divide_into_chunks") + def test_get_traits_data(self, mock_divide_into_chunks, mock_fetch_samples): + """test for for function to get data\ + of traits in dataset""" + + expected_results = {'AT_DSAFDS': [ + 12, 14, 13, 23, 12, 14, 13, 23, 12, 14, 13, 23]} + database = mock.Mock() + sample_id = [1, 2, 7, 3, 22, 8] + mock_divide_into_chunks.return_value = [ + [1, 2, 7], [3, 22, 8], [5, 22, 333]] + mock_fetch_samples.return_value = ("AT_DSAFDS", 12, 14, 13, 23) + results = get_traits_data(sample_id, database, "HC_M2", "Publish") + + self.assertEqual(expected_results, dict(results)) + + def test_divide_into_chunks(self): + """test for dividing a list into given number of\ + chunks for example""" + results = divide_into_chunks([1, 2, 7, 3, 22, 8, 5, 22, 333], 3) + + expected_results = [[1, 2, 7], [3, 22, 8], [5, 22, 333]] + + self.assertEqual(results, expected_results) |