aboutsummaryrefslogtreecommitdiff
path: root/scripts/update_rif_table.py
blob: 24edf3d4a45affcaea94c4363a7965c2d48b9b20 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#!/usr/bin/env python3

"""
Script responsible for updating the GeneRIF_BASIC table
"""

import argparse
import csv
import datetime
import gzip
import logging
import pathlib
import os
from tempfile import TemporaryDirectory
from typing import Dict, Generator

import requests
from MySQLdb.cursors import DictCursor

from gn3.db_utils import database_connection

# Maps the taxonomy id strings found in the "#tax_id" / "#Tax ID" columns of
# the NCBI files to the SpeciesId value stored in GeneRIF_BASIC. Rows whose
# taxonomy id is not a key of this dict are ignored by the script.
# NOTE(review): the keys look like the NCBI Taxonomy ids for mouse (10090),
# human (9606), rat (10116) and Arabidopsis thaliana (3702) -- confirm.
TAX_IDS = {"10090": 1, "9606": 4, "10116": 2, "3702": 3}

# Gzipped, tab-separated source files downloaded on every run.
GENE_INFO_URL = "https://ftp.ncbi.nlm.nih.gov/gene/DATA/gene_info.gz"
GENERIFS_BASIC_URL = "https://ftp.ncbi.nih.gov/gene/GeneRIF/generifs_basic.gz"

# Written to the VersionId column of every row this script inserts, so the
# rows it added can be located afterwards.
VERSION_ID = 5


# Parameterized insert statement; the eight placeholders are filled, in
# column order, by the tuple built in update_rif().
INSERT_QUERY = """ INSERT INTO GeneRIF_BASIC
(SpeciesId, GeneId, symbol, PubMed_Id, createtime, comment, TaxID, VersionId)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
"""


def download_file(url: str, dest: pathlib.Path, timeout=(10, 300)):
    """Stream the contents of url into the file dest.

    The response body is written in 8 KiB chunks so the (large) NCBI dumps
    never have to fit in memory.

    Args:
        url: Address to fetch.
        dest: File to create (or overwrite) with the downloaded bytes.
        timeout: Passed unchanged to ``requests.get``: either one number of
            seconds or a ``(connect, read)`` pair. The read value bounds the
            wait between received bytes, not the whole transfer, so a slow
            but steady download is not interrupted. Without a timeout a
            stalled server would block the script forever.

    Raises:
        requests.HTTPError: the server replied with an error status.
        requests.Timeout: the server did not respond within ``timeout``.
    """
    with requests.get(url, stream=True, timeout=timeout) as resp:
        resp.raise_for_status()
        with open(dest, "wb") as downloaded_file:
            for chunk in resp.iter_content(chunk_size=8192):
                downloaded_file.write(chunk)


def read_tsv_file(fname: pathlib.Path) -> Generator:
    """Lazily yield the rows of a gzipped, tab-separated NCBI file.

    The first line of the file is taken as the header, so every yielded row
    is a dict keyed by the column names. Quote characters are treated as
    ordinary data (``csv.QUOTE_NONE``).

    Args:
        fname: Path to the ``.gz`` file to read.

    Yields:
        One dict per data line, mapping column name to the raw string value.
    """
    # Decode explicitly as UTF-8 instead of relying on the locale's default
    # encoding, and pass newline="" as the csv module documentation requires
    # so the reader does its own line-ending handling.
    with gzip.open(fname, mode="rt", encoding="utf-8", newline="") as gz_file:
        reader = csv.DictReader(gz_file, delimiter="\t", quoting=csv.QUOTE_NONE)
        yield from reader


def parse_gene_info_from_ncbi(fname: pathlib.Path) -> Dict[str, str]:
    """Build a GeneID -> Symbol lookup from an NCBI gene_info file.

    Only rows whose "#tax_id" is one of the keys of TAX_IDS are kept. If a
    GeneID occurs more than once, the symbol from the last occurrence wins.
    """
    return {
        record["GeneID"]: record["Symbol"]
        for record in read_tsv_file(fname)
        if record["#tax_id"] in TAX_IDS
    }


def build_already_exists_cache(conn) -> dict:
    """
    Collect every (GeneId, SpeciesId, createtime, PubMed_ID) combination that
    is already stored in GeneRIF_BASIC.

    The result maps each combination to the number of rows that share it.
    GeneId, SpeciesId and PubMed_ID are converted to str in the key;
    createtime is kept as returned by the database driver. Callers use the
    mapping to avoid inserting duplicates.
    """
    query = """SELECT
        COUNT(*) as cnt, GeneId, SpeciesId, createtime, PubMed_ID
        from GeneRIF_BASIC
        GROUP BY GeneId, SpeciesId, createtime, PubMed_Id """
    seen = {}

    with conn.cursor(DictCursor) as cursor:
        cursor.execute(query)
        # Pull rows one at a time until the cursor has nothing left to give.
        record = cursor.fetchone()
        while record:
            seen[
                (
                    str(record["GeneId"]),
                    str(record["SpeciesId"]),
                    record["createtime"],
                    str(record["PubMed_ID"]),
                )
            ] = record["cnt"]
            record = cursor.fetchone()
    return seen


def should_add_rif_row(row: dict, exists_cache: dict) -> bool:
    """Decide whether a GeneRIF row still has to be inserted.

    A key is built from the row's gene id, mapped species id, parsed
    "last update timestamp" and PubMed id list. When the key is already in
    exists_cache the row is a duplicate and False is returned. Otherwise the
    key is recorded in exists_cache (the dict is modified in place) and True
    is returned.
    """
    species = str(TAX_IDS[row["#Tax ID"]])
    updated_at = datetime.datetime.fromisoformat(row["last update timestamp"])
    candidate = (
        row["Gene ID"],
        species,
        updated_at,
        row["PubMed ID (PMID) list"],
    )
    if candidate in exists_cache:
        return False
    # Remember the key so a repeat later in the same input is skipped too.
    exists_cache[candidate] = 1
    return True


def update_rif(sqluri: str):
    """Update GeneRIF_BASIC table

    Downloads the NCBI gene_info and generifs_basic dumps into a temporary
    directory, then inserts every GeneRIF row that belongs to one of the
    species in TAX_IDS and is not already present in the table. Each inserted
    row is tagged with VERSION_ID.

    Args:
        sqluri: Database URI handed to ``database_connection``.
    """
    with TemporaryDirectory() as _tmpdir:
        tmpdir = pathlib.Path(_tmpdir)
        gene_info_path = tmpdir / "gene_info.gz"
        logging.debug("Fetching gene_info data from: %s", GENE_INFO_URL)
        download_file(GENE_INFO_URL, gene_info_path)

        logging.debug("Fetching gene_rif_basics data from: %s", GENERIFS_BASIC_URL)
        generif_basics_path = tmpdir / "generif_basics.gz"
        download_file(
            GENERIFS_BASIC_URL,
            generif_basics_path,
        )

        logging.debug("Parsing gene_info data")
        genedict = parse_gene_info_from_ncbi(gene_info_path)
        # Initialised before the connection is opened so the summary log at
        # the end can always be formatted.
        skipped_if_exists, added = 0, 0
        with database_connection(sql_uri=sqluri) as con:
            exists_cache = build_already_exists_cache(con)
            # Context-managed so the cursor is closed even if an insert fails.
            with con.cursor() as cursor:
                for row in read_tsv_file(generif_basics_path):
                    if row["#Tax ID"] not in TAX_IDS:
                        continue
                    if not should_add_rif_row(row, exists_cache):
                        skipped_if_exists += 1
                        continue
                    species_id = TAX_IDS[row["#Tax ID"]]
                    # Genes missing from gene_info get an empty symbol.
                    symbol = genedict.get(row["Gene ID"], "")
                    insert_values = (
                        species_id,  # SpeciesId
                        row["Gene ID"],  # GeneId
                        symbol,  # symbol
                        row["PubMed ID (PMID) list"],  # PubMed_ID
                        row["last update timestamp"],  # createtime
                        row["GeneRIF text"],  # comment
                        row["#Tax ID"],  # TaxID
                        VERSION_ID,  # VersionId
                    )
                    cursor.execute(INSERT_QUERY, insert_values)
                    added += 1
                    if added % 40_000 == 0:
                        logging.debug("Added 40,000 rows to database")
            # NOTE(review): committed explicitly in case database_connection
            # does not commit on exit (not visible from this file); this is
            # redundant but harmless if it does.
            con.commit()
        logging.info(
            "Generif_BASIC table updated. Added %s. Skipped %s because they "
            "already exists. In case of error, you can use VersionID=%s to find "
            "rows inserted with this script", added, skipped_if_exists,
            VERSION_ID
        )


if __name__ == "__main__":
    # Log level comes from the LOGLEVEL environment variable and falls back
    # to DEBUG when it is not set.
    logging.basicConfig(
        level=os.environ.get("LOGLEVEL", "DEBUG"),
        format="%(asctime)s %(levelname)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S %Z",
    )
    # The text must be passed as ``description``: the first positional
    # argument of ArgumentParser is ``prog``, which would replace the program
    # name in the usage line with this sentence.
    parser = argparse.ArgumentParser(description="Update Generif_BASIC table")
    parser.add_argument(
        "--sql-uri",
        required=True,
        help="MYSQL uri path in the form mysql://user:password@localhost/gn2",
    )
    args = parser.parse_args()
    update_rif(args.sql_uri)