From cdd4dc456e56bb4eb055e1cb7f2518d45fb3bfb9 Mon Sep 17 00:00:00 2001
From: Frederick Muriuki Muriithi
Date: Sat, 20 Jan 2024 09:57:23 +0300
Subject: Fetch sample/case names from database

Fetch the sample/case names from the database rather than from a static file in the repository.

Issue: https://issues.genenetwork.org/issues/quality-control/read-samples-from-database-by-species
---
 quality_control/parsing.py | 27 ++++++++++++---------------
 1 file changed, 12 insertions(+), 15 deletions(-)

(limited to 'quality_control/parsing.py')

diff --git a/quality_control/parsing.py b/quality_control/parsing.py
index c545937..f7a664f 100644
--- a/quality_control/parsing.py
+++ b/quality_control/parsing.py
@@ -4,6 +4,9 @@ from enum import Enum
 from functools import partial
 from typing import Tuple, Union, Generator, Callable, Optional
 
+import MySQLdb as mdb
+from MySQLdb.cursors import DictCursor
+
 import quality_control.average as avg
 from quality_control.file_utils import open_file
 import quality_control.standard_error as se
@@ -17,21 +20,15 @@ class FileType(Enum):
     AVERAGE = 1
     STANDARD_ERROR = 2
 
-def strain_names(filepath):
-    """Retrieve the strains names from given file"""
-    strains = set()
-    with open(filepath, encoding="utf8") as strains_file:
-        for idx, line in enumerate(strains_file.readlines()):
-            if idx > 0:
-                parts = line.split()
-                for name in (parts[1], parts[2]):
-                    strains.add(name.strip())
-                if len(parts) >= 6:
-                    alias = parts[5].strip()
-                    if alias != "" and alias not in ("P", "\\N"):
-                        strains.add(alias)
-
-    return strains
+def strain_names(dbconn: mdb.Connection, speciesid: int) -> tuple[str, ...]:
+    """Retrieve samples/cases from database."""
+    with dbconn.cursor(cursorclass=DictCursor) as cursor:
+        cursor.execute("SELECT * FROM Strain WHERE SpeciesId=%s",
+                       (speciesid,))
+        samplenames = ((row["Name"], row["Name2"]) for row in cursor.fetchall())
+        return tuple(set(filter(
+            lambda item: bool(item.strip() if item is not None else item),
+            (name for names in samplenames for name in names))))
 
 def header_errors(line_number, fields, strains):
     """Gather all header row errors."""
-- 
cgit v1.2.3