aboutsummaryrefslogtreecommitdiff
path: root/gn3/auth/authorisation/groups/data.py
blob: 0c821d386fdf3a31317a40653d3ee0661690cc05 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
"""Handles the resource objects' data."""
from typing import Any, Sequence

from MySQLdb.cursors import DictCursor

from gn3 import db_utils as gn3db
from gn3.auth import db as authdb
from gn3.auth.authorisation.groups import Group
from gn3.auth.authorisation.checks import authorised_p
from gn3.auth.authorisation.errors import InvalidData, NotFoundError

def __fetch_grouped_data__(
        conn: authdb.DbConnection, dataset_type: str) -> Sequence[dict[str, Any]]:
    """
    Retrieve ids for all data of one type that are linked to groups.

    Reads the `linked_group_data` table in the auth db and returns one dict
    per row, each holding the `dataset_type` and `dataset_or_trait_id`
    columns.

    Only the stored column is lower-cased for the comparison; `dataset_type`
    is bound as given, so callers should pass it in lower case.
    """
    query = (
        "SELECT dataset_type, dataset_or_trait_id FROM linked_group_data "
        "WHERE LOWER(dataset_type)=?")
    with authdb.cursor(conn) as cursor:
        cursor.execute(query, (dataset_type,))
        linked_rows = cursor.fetchall()
        return tuple(dict(linked) for linked in linked_rows)

def __fetch_ungrouped_mrna_data__(
        conn: gn3db.Connection, grouped_data, offset: int) -> Sequence[dict]:
    """
    Fetch one page of mRNA Assay data that is not linked to any group.

    `grouped_data` is a sequence of dicts, each with a `dataset_or_trait_id`
    key. `ProbeSetFreeze` rows whose `Id` is among those values are left out.
    If `grouped_data` is empty, no exclusion clause is added.

    At most 100 rows are returned, starting at row `offset`. Each row is a
    dict with the keys `Id`, `dataset_name`, `dataset_fullname` and
    `accession_id`.
    """
    query = ("SELECT psf.Id, psf.Name AS dataset_name, "
             "psf.FullName AS dataset_fullname, "
             "ifiles.GN_AccesionId AS accession_id FROM ProbeSetFreeze AS psf "
             "INNER JOIN InfoFiles AS ifiles ON psf.Name=ifiles.InfoPageName")
    params: tuple[Any, ...] = tuple()
    if grouped_data:
        clause = ", ".join(["%s"] * len(grouped_data))
        query = f"{query} WHERE psf.Id NOT IN ({clause})"
        params = tuple(item["dataset_or_trait_id"] for item in grouped_data)

    # LIMIT/OFFSET paging is only repeatable over an ordered result: without
    # ORDER BY the server may return rows in any order, so consecutive pages
    # could repeat or skip datasets.
    query = f"{query} ORDER BY psf.Id LIMIT 100 OFFSET %s"
    with conn.cursor(DictCursor) as cursor:
        cursor.execute(query, params + (offset,))
        return tuple(dict(row) for row in cursor.fetchall())

def __fetch_ungrouped_geno_data__(
        conn: gn3db.Connection, grouped_data, offset: int) -> Sequence[dict]:
    """
    Fetch one page of Genotype data that is not linked to any group.

    `grouped_data` is a sequence of dicts, each with a `dataset_or_trait_id`
    key. `GenoFreeze` rows whose `Id` is among those values are left out.
    If `grouped_data` is empty, no exclusion clause is added.

    At most 100 rows are returned, starting at row `offset`. Each row is a
    dict with the keys `Id`, `dataset_name`, `dataset_fullname` and
    `accession_id`.
    """
    query = ("SELECT gf.Id, gf.Name AS dataset_name, "
             "gf.FullName AS dataset_fullname, "
             "ifiles.GN_AccesionId AS accession_id FROM GenoFreeze AS gf "
             "INNER JOIN InfoFiles AS ifiles ON gf.Name=ifiles.InfoPageName")
    params: tuple[Any, ...] = tuple()
    if grouped_data:
        clause = ", ".join(["%s"] * len(grouped_data))
        query = f"{query} WHERE gf.Id NOT IN ({clause})"
        params = tuple(item["dataset_or_trait_id"] for item in grouped_data)

    # LIMIT/OFFSET paging is only repeatable over an ordered result: without
    # ORDER BY the server may return rows in any order, so consecutive pages
    # could repeat or skip datasets.
    query = f"{query} ORDER BY gf.Id LIMIT 100 OFFSET %s"
    with conn.cursor(DictCursor) as cursor:
        cursor.execute(query, params + (offset,))
        return tuple(dict(row) for row in cursor.fetchall())

def __fetch_ungrouped_pheno_data__(
        conn: gn3db.Connection, grouped_data, offset: int) -> Sequence[dict]:
    """
    Fetch one page of Phenotype data that is not linked to any group.

    `grouped_data` is a sequence of dicts, each with a `dataset_or_trait_id`
    key. `PublishXRef` rows whose `Id` is among those values are left out.
    If `grouped_data` is empty, no exclusion clause is added.

    At most 100 rows are returned, starting at row `offset`. Each row is a
    dict with the keys `Id`, `InbredSetName`, `dataset_name`,
    `dataset_fullname` and `dataset_shortname`. The three `dataset_*` values
    come from a LEFT JOIN on `PublishFreeze`, so they may be NULL/None.
    """
    query = ("SELECT "
             "pxf.Id, iset.InbredSetName, pf.Name AS dataset_name, "
             "pf.FullName AS dataset_fullname, "
             "pf.ShortName AS dataset_shortname "
             "FROM PublishXRef AS pxf "
             "INNER JOIN InbredSet AS iset "
             "ON pxf.InbredSetId=iset.InbredSetId "
             "LEFT JOIN PublishFreeze AS pf "
             "ON iset.InbredSetId=pf.InbredSetId")
    params: tuple[Any, ...] = tuple()
    if grouped_data:
        clause = ", ".join(["%s"] * len(grouped_data))
        query = f"{query} WHERE pxf.Id NOT IN ({clause})"
        params = tuple(item["dataset_or_trait_id"] for item in grouped_data)

    # LIMIT/OFFSET paging is only repeatable over an ordered result: without
    # ORDER BY the server may return rows in any order, so consecutive pages
    # could repeat or skip rows.
    # NOTE(review): InbredSetId is a tie-breaker in case PublishXRef.Id is
    # not unique on its own -- confirm against the schema.
    query = f"{query} ORDER BY pxf.Id, pxf.InbredSetId LIMIT 100 OFFSET %s"
    with conn.cursor(DictCursor) as cursor:
        cursor.execute(query, params + (offset,))
        return tuple(dict(row) for row in cursor.fetchall())

def __fetch_ungrouped_data__(
        conn: gn3db.Connection, dataset_type: str,
        ungrouped: Sequence[dict[str, Any]],
        offset) -> Sequence[dict[str, Any]]:
    """
    Fetch any ungrouped data of the requested type.

    Picks the fetcher for `dataset_type` ('mrna', 'genotype' or 'phenotype')
    and calls it with `conn`, `ungrouped` and `offset`. Any other
    `dataset_type` raises `KeyError`.

    Despite its name, `ungrouped` is passed on as the fetcher's
    `grouped_data` argument, i.e. the rows to exclude from the result.
    """
    fetcher = {
        "mrna": __fetch_ungrouped_mrna_data__,
        "genotype": __fetch_ungrouped_geno_data__,
        "phenotype": __fetch_ungrouped_pheno_data__
    }[dataset_type]
    return fetcher(conn, ungrouped, offset)

@authorised_p(("system:data:link-to-group",),
              error_description=(
                  "You do not have sufficient privileges to link data to (a) "
                  "group(s)."),
              oauth2_scope="profile group resource")
def retrieve_ungrouped_data(
        authconn: authdb.DbConnection,
        gn3conn: gn3db.Connection,
        dataset_type: str,
        offset: int = 0) -> Sequence[dict]:
    """
    Retrieve any data of the given type that is not linked to any group.

    First reads the already-linked ids for `dataset_type` through `authconn`,
    then fetches a page of rows through `gn3conn` that excludes those ids,
    starting at `offset`.

    Raises `InvalidData` if `dataset_type` is not one of 'mrna', 'genotype'
    or 'phenotype'.

    NOTE(review): access control is delegated to the `authorised_p`
    decorator; see its definition for what is enforced.
    """
    valid_types = ("mrna", "genotype", "phenotype")
    if dataset_type in valid_types:
        already_linked = __fetch_grouped_data__(authconn, dataset_type)
        return __fetch_ungrouped_data__(
            gn3conn, dataset_type, already_linked, offset)
    raise InvalidData(
        "Requested dataset type is invalid. Expected one of "
        "'mrna', 'genotype' or 'phenotype'.")

def __fetch_mrna_data_by_ids__(
        conn: gn3db.Connection, dataset_ids: tuple[str, ...]) -> tuple[
            dict, ...]:
    """
    Fetch mRNA Assay data by ID.

    Returns one dict per `ProbeSetFreeze` row whose `Id` is in `dataset_ids`,
    with the keys `Id`, `dataset_name`, `dataset_fullname` and
    `accession_id`.

    Raises `NotFoundError` if `dataset_ids` is empty or no row matches.
    """
    if not dataset_ids:
        # An empty id list would render as `IN ()`, which MySQL rejects as a
        # syntax error; treat it like any other lookup that finds nothing.
        raise NotFoundError("Could not find mRNA Assay data with the given ID.")
    paramstr = ", ".join(["%s"] * len(dataset_ids))
    with conn.cursor(DictCursor) as cursor:
        cursor.execute(
            "SELECT psf.Id, psf.Name AS dataset_name, "
            "psf.FullName AS dataset_fullname, "
            "ifiles.GN_AccesionId AS accession_id FROM ProbeSetFreeze AS psf "
            "INNER JOIN InfoFiles AS ifiles ON psf.Name=ifiles.InfoPageName "
            f"WHERE psf.Id IN ({paramstr})",
            dataset_ids)
        res = cursor.fetchall()
        if res:
            return tuple(dict(row) for row in res)
        raise NotFoundError("Could not find mRNA Assay data with the given ID.")

def __fetch_geno_data_by_ids__(
        conn: gn3db.Connection, dataset_ids: tuple[str, ...]) -> tuple[
            dict, ...]:
    """
    Fetch genotype data by ID.

    Returns one dict per `GenoFreeze` row whose `Id` is in `dataset_ids`,
    with the keys `Id`, `dataset_name`, `dataset_fullname` and
    `accession_id`.

    Raises `NotFoundError` if `dataset_ids` is empty or no row matches.
    """
    if not dataset_ids:
        # An empty id list would render as `IN ()`, which MySQL rejects as a
        # syntax error; treat it like any other lookup that finds nothing.
        raise NotFoundError("Could not find Genotype data with the given ID.")
    paramstr = ", ".join(["%s"] * len(dataset_ids))
    with conn.cursor(DictCursor) as cursor:
        cursor.execute(
            "SELECT gf.Id, gf.Name AS dataset_name, "
            "gf.FullName AS dataset_fullname, "
            "ifiles.GN_AccesionId AS accession_id FROM GenoFreeze AS gf "
            "INNER JOIN InfoFiles AS ifiles ON gf.Name=ifiles.InfoPageName "
            f"WHERE gf.Id IN ({paramstr})",
            dataset_ids)
        res = cursor.fetchall()
        if res:
            return tuple(dict(row) for row in res)
        raise NotFoundError("Could not find Genotype data with the given ID.")

def __fetch_pheno_data_by_ids__(
        conn: gn3db.Connection, dataset_ids: tuple[str, ...]) -> tuple[
            dict, ...]:
    """
    Fetch phenotype data by ID.

    Returns one dict per joined row whose `PublishXRef.Id` is in
    `dataset_ids`, with the keys `Id`, `InbredSetName`, `dataset_id`,
    `dataset_name`, `dataset_fullname` and `accession_id`.

    Raises `NotFoundError` if `dataset_ids` is empty or no row matches.
    """
    if not dataset_ids:
        # An empty id list would render as `IN ()`, which MySQL rejects as a
        # syntax error; treat it like any other lookup that finds nothing.
        raise NotFoundError(
            "Could not find Phenotype/Publish data with the given IDs.")
    paramstr = ", ".join(["%s"] * len(dataset_ids))
    with conn.cursor(DictCursor) as cursor:
        cursor.execute(
            "SELECT pxf.Id, iset.InbredSetName, pf.Id AS dataset_id, "
            "pf.Name AS dataset_name, pf.FullName AS dataset_fullname, "
            "ifiles.GN_AccesionId AS accession_id "
            "FROM PublishXRef AS pxf "
            "INNER JOIN InbredSet AS iset ON pxf.InbredSetId=iset.InbredSetId "
            "INNER JOIN PublishFreeze AS pf ON iset.InbredSetId=pf.InbredSetId "
            "INNER JOIN InfoFiles AS ifiles ON pf.Name=ifiles.InfoPageName "
            f"WHERE pxf.Id IN ({paramstr})",
            dataset_ids)
        res = cursor.fetchall()
        if res:
            return tuple(dict(row) for row in res)
        raise NotFoundError(
            "Could not find Phenotype/Publish data with the given IDs.")

def __fetch_data_by_id(
        conn: gn3db.Connection, dataset_type: str,
        dataset_ids: tuple[str, ...]) -> tuple[dict, ...]:
    """
    Fetch data from MySQL by IDs.

    Picks the by-ID fetcher for `dataset_type` ('mrna', 'genotype' or
    'phenotype') and calls it with `conn` and `dataset_ids`. Any other
    `dataset_type` raises `KeyError`.
    """
    fetcher = {
        "mrna": __fetch_mrna_data_by_ids__,
        "genotype": __fetch_geno_data_by_ids__,
        "phenotype": __fetch_pheno_data_by_ids__
    }[dataset_type]
    return fetcher(conn, dataset_ids)

@authorised_p(("system:data:link-to-group",),
              error_description=(
                  "You do not have sufficient privileges to link data to (a) "
                  "group(s)."),
              oauth2_scope="profile group resource")
def link_data_to_group(
        authconn: authdb.DbConnection, gn3conn: gn3db.Connection,
        dataset_type: str, dataset_ids: tuple[str, ...], group: Group) -> tuple[
            dict, ...]:
    """
    Link the given data to the specified group.

    Looks up the rows for `dataset_ids` through `gn3conn`, then inserts one
    `linked_group_data` row per result through `authconn`.

    Returns the inserted rows as a tuple of dicts with the keys `group_id`,
    `dataset_type`, `dataset_or_trait_id`, `dataset_name`,
    `dataset_fullname` and `accession_id`.

    Raises `InvalidData` if `dataset_type` is not one of 'mrna', 'genotype'
    or 'phenotype'. Errors raised by the lookup propagate unchanged.
    """
    # Label stored in the `dataset_type` column for each accepted type.
    type_labels = {
        "mrna": "mRNA", "genotype": "Genotype", "phenotype": "Phenotype"
    }
    if dataset_type not in type_labels:
        # Reject bad input with the error `retrieve_ungrouped_data` uses,
        # instead of letting a bare KeyError escape from the dispatch table.
        raise InvalidData(
            "Requested dataset type is invalid. Expected one of "
            "'mrna', 'genotype' or 'phenotype'.")

    the_data = __fetch_data_by_id(gn3conn, dataset_type, dataset_ids)
    # Both values are the same for every row, so compute them once.
    group_id = str(group.group_id)
    type_label = type_labels[dataset_type]
    params = tuple({
        "group_id": group_id,
        "dataset_type": type_label,
        "dataset_or_trait_id": item["Id"],
        "dataset_name": item["dataset_name"],
        "dataset_fullname": item["dataset_fullname"],
        "accession_id": item["accession_id"]
    } for item in the_data)
    with authdb.cursor(authconn) as cursor:
        cursor.executemany(
            "INSERT INTO linked_group_data VALUES"
            "(:group_id, :dataset_type, :dataset_or_trait_id, :dataset_name, "
            ":dataset_fullname, :accession_id)",
            params)
        return params