diff options
author | Frederick Muriuki Muriithi | 2021-11-12 04:07:42 +0300 |
---|---|---|
committer | Frederick Muriuki Muriithi | 2021-11-12 04:07:42 +0300 |
commit | d1617bd8af25bf7c7777be7a634559fd31b491ad (patch) | |
tree | 9565d4fcca4fa553dc21a9543353a8b29357ab4a /gn3 | |
parent | d895eea22ab908c11f4ebb77f99518367879b1f6 (diff) | |
parent | 85405fe6875358d3bb98b03621271d5909dd393f (diff) | |
download | genenetwork3-d1617bd8af25bf7c7777be7a634559fd31b491ad.tar.gz |
Merge branch 'main' of github.com:genenetwork/genenetwork3 into partial-correlations
Diffstat (limited to 'gn3')
-rw-r--r-- | gn3/authentication.py | 67 | ||||
-rw-r--r-- | gn3/computations/correlations.py | 30 | ||||
-rw-r--r-- | gn3/computations/correlations2.py | 36 | ||||
-rw-r--r-- | gn3/heatmaps.py | 6 |
4 files changed, 86 insertions, 53 deletions
diff --git a/gn3/authentication.py b/gn3/authentication.py index 7bc7b77..6719631 100644 --- a/gn3/authentication.py +++ b/gn3/authentication.py @@ -1,9 +1,12 @@ """Methods for interacting with gn-proxy.""" import functools import json +import uuid +import datetime + from urllib.parse import urljoin from enum import Enum, unique -from typing import Dict, Union +from typing import Dict, List, Optional, Union from redis import Redis import requests @@ -95,3 +98,65 @@ def get_highest_user_access_role( for key, value in json.loads(response.content).items(): access_role[key] = max(map(lambda role: role_mapping[role], value)) return access_role + + +def get_groups_by_user_uid(user_uid: str, conn: Redis) -> Dict: + """Given a user uid, get the groups in which they are a member or admin of. + + Args: + - user_uid: A user's unique id + - conn: A redis connection + + Returns: + - A dictionary containing the list of groups the user is part of e.g.: + {"admin": [], "member": ["ce0dddd1-6c50-4587-9eec-6c687a54ad86"]} + """ + admin = [] + member = [] + for uuid, group_info in conn.hgetall("groups").items(): + group_info = json.loads(group_info) + group_info["uuid"] = uuid + if user_uid in group_info.get('admins'): + admin.append(group_info) + if user_uid in group_info.get('members'): + member.append(group_info) + return { + "admin": admin, + "member": member, + } + + +def get_user_info_by_key(key: str, value: str, + conn: Redis) -> Optional[Dict]: + """Given a key, get a user's information if value is matched""" + if key != "user_id": + for uuid, user_info in conn.hgetall("users").items(): + user_info = json.loads(user_info) + if (key in user_info and + user_info.get(key) == value): + user_info["user_id"] = uuid + return user_info + elif key == "user_id": + if user_info := conn.hget("users", value): + user_info = json.loads(user_info) + user_info["user_id"] = value + return user_info + return None + + +def create_group(conn: Redis, group_name: Optional[str], + admin_user_uids: List = [], + member_user_uids: List = []) -> Optional[Dict]: + """Create a group given the group name, members and admins of that group.""" + if group_name and bool(admin_user_uids + member_user_uids): + timestamp = datetime.datetime.utcnow().strftime('%b %d %Y %I:%M%p') + group = { + "id": (group_id := str(uuid.uuid4())), + "admins": admin_user_uids, + "members": member_user_uids, + "name": group_name, + "created_timestamp": timestamp, + "changed_timestamp": timestamp, + } + conn.hset("groups", group_id, json.dumps(group)) + return group diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py index c930df0..c5c56db 100644 --- a/gn3/computations/correlations.py +++ b/gn3/computations/correlations.py @@ -1,6 +1,7 @@ """module contains code for correlations""" import math import multiprocessing +from contextlib import closing from typing import List from typing import Tuple @@ -49,13 +50,9 @@ def normalize_values(a_values: List, ([2.3, 4.1, 5], [3.4, 6.2, 4.1], 3) """ - a_new = [] - b_new = [] for a_val, b_val in zip(a_values, b_values): if (a_val and b_val is not None): - a_new.append(a_val) - b_new.append(b_val) - return a_new, b_new, len(a_new) + yield a_val, b_val def compute_corr_coeff_p_value(primary_values: List, target_values: List, @@ -81,8 +78,10 @@ def compute_sample_r_correlation(trait_name, corr_method, trait_vals, correlation coeff and p value """ - (sanitized_traits_vals, sanitized_target_vals, - num_overlap) = normalize_values(trait_vals, target_samples_vals) + + sanitized_traits_vals, sanitized_target_vals = list( + zip(*list(normalize_values(trait_vals, target_samples_vals)))) + num_overlap = len(sanitized_traits_vals) if num_overlap > 5: @@ -114,13 +113,9 @@ def filter_shared_sample_keys(this_samplelist, filter the values using the shared keys """ - this_vals = [] - target_vals = [] for key, value in target_samplelist.items(): if key in this_samplelist: - target_vals.append(value) - this_vals.append(this_samplelist[key]) - return (this_vals, target_vals) + yield this_samplelist[key], value def fast_compute_all_sample_correlation(this_trait, @@ -139,9 +134,10 @@ def fast_compute_all_sample_correlation(this_trait, for target_trait in target_dataset: trait_name = target_trait.get("trait_id") target_trait_data = target_trait["trait_sample_data"] - processed_values.append((trait_name, corr_method, *filter_shared_sample_keys( - this_trait_samples, target_trait_data))) - with multiprocessing.Pool(4) as pool: + processed_values.append((trait_name, corr_method, + list(zip(*list(filter_shared_sample_keys( + this_trait_samples, target_trait_data)))))) + with closing(multiprocessing.Pool()) as pool: results = pool.starmap(compute_sample_r_correlation, processed_values) for sample_correlation in results: @@ -172,8 +168,8 @@ def compute_all_sample_correlation(this_trait, for target_trait in target_dataset: trait_name = target_trait.get("trait_id") target_trait_data = target_trait["trait_sample_data"] - this_vals, target_vals = filter_shared_sample_keys( - this_trait_samples, target_trait_data) + this_vals, target_vals = list(zip(*list(filter_shared_sample_keys( + this_trait_samples, target_trait_data)))) sample_correlation = compute_sample_r_correlation( trait_name=trait_name, diff --git a/gn3/computations/correlations2.py b/gn3/computations/correlations2.py index 93db3fa..d0222ae 100644 --- a/gn3/computations/correlations2.py +++ b/gn3/computations/correlations2.py @@ -6,45 +6,21 @@ FUNCTIONS: compute_correlation: TODO: Describe what the function does...""" -from math import sqrt -from functools import reduce +from scipy import stats ## From GN1: mostly for clustering and heatmap generation def __items_with_values(dbdata, userdata): """Retains only corresponding items in the data items that are not `None` values. This should probably be renamed to something sensible""" - def both_not_none(item1, item2): - """Check that both items are not the value `None`.""" - if (item1 is not None) and (item2 is not None): - return (item1, item2) - return None - def split_lists(accumulator, item): - """Separate the 'x' and 'y' items.""" - return [accumulator[0] + [item[0]], accumulator[1] + [item[1]]] - return reduce( - split_lists, - filter(lambda x: x is not None, map(both_not_none, dbdata, userdata)), - [[], []]) + filtered = [x for x in zip(dbdata, userdata) if x[0] is not None and x[1] is not None] + return tuple(zip(*filtered)) if filtered else ([], []) def compute_correlation(dbdata, userdata): - """Compute some form of correlation. + """Compute the Pearson correlation coefficient. This is extracted from https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/utility/webqtlUtil.py#L622-L647 """ x_items, y_items = __items_with_values(dbdata, userdata) - if len(x_items) < 6: - return (0.0, len(x_items)) - meanx = sum(x_items)/len(x_items) - meany = sum(y_items)/len(y_items) - def cal_corr_vals(acc, item): - xitem, yitem = item - return [ - acc[0] + ((xitem - meanx) * (yitem - meany)), - acc[1] + ((xitem - meanx) * (xitem - meanx)), - acc[2] + ((yitem - meany) * (yitem - meany))] - xyd, sxd, syd = reduce(cal_corr_vals, zip(x_items, y_items), [0.0, 0.0, 0.0]) - try: - return ((xyd/(sqrt(sxd)*sqrt(syd))), len(x_items)) - except ZeroDivisionError: - return(0, len(x_items)) + correlation = stats.pearsonr(x_items, y_items)[0] if len(x_items) >= 6 else 0 + return (correlation, len(x_items)) diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py index bf9dfd1..f0af409 100644 --- a/gn3/heatmaps.py +++ b/gn3/heatmaps.py @@ -64,11 +64,7 @@ def cluster_traits(traits_data_list: Sequence[Dict]): def __compute_corr(tdata_i, tdata_j): if tdata_i[0] == tdata_j[0]: return 0.0 - corr_vals = compute_correlation(tdata_i[1], tdata_j[1]) - corr = corr_vals[0] - if (1 - corr) < 0: - return 0.0 - return 1 - corr + return 1 - compute_correlation(tdata_i[1], tdata_j[1])[0] def __cluster(tdata_i): return tuple( |