Diffstat (limited to 'gn3/computations')
 gn3/computations/correlations.py          |   9
 gn3/computations/ctl.py                   |   8
 gn3/computations/gemma.py                 |  13
 gn3/computations/partial_correlations.py  |  41
 gn3/computations/pca.py                   |   4
 gn3/computations/qtlreaper.py             |  11
 gn3/computations/rqtl.py                  |  43
 gn3/computations/rqtl2.py                 | 228
 gn3/computations/rust_correlation.py      |  37
 gn3/computations/streaming.py             |  62
 gn3/computations/wgcna.py                 |  11
11 files changed, 392 insertions, 75 deletions
diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py
index d805af7..95bd957 100644
--- a/gn3/computations/correlations.py
+++ b/gn3/computations/correlations.py
@@ -6,6 +6,7 @@ from multiprocessing import Pool, cpu_count
 from typing import List
 from typing import Tuple
+from typing import Sequence
 from typing import Optional
 from typing import Callable
 from typing import Generator
 
@@ -52,8 +53,10 @@ def normalize_values(a_values: List, b_values: List) -> Generator:
             yield a_val, b_val
 
 
-def compute_corr_coeff_p_value(primary_values: List, target_values: List,
-                               corr_method: str) -> Tuple[float, float]:
+def compute_corr_coeff_p_value(
+        primary_values: Sequence,
+        target_values: Sequence,
+        corr_method: str) -> Tuple[float, float]:
     """Given array like inputs calculate the primary and target_value methods
     -> pearson,spearman and biweight mid correlation
     return value is rho and p_value
@@ -196,7 +199,7 @@ def compute_all_sample_correlation(this_trait,
     """
     this_trait_samples = this_trait["trait_sample_data"]
 
-    with Pool(processes=(cpu_count() - 1)) as pool:
+    with Pool(processes=cpu_count() - 1) as pool:
         return sorted(
             (
                 corr for corr in
diff --git a/gn3/computations/ctl.py b/gn3/computations/ctl.py
index f881410..5c004ea 100644
--- a/gn3/computations/ctl.py
+++ b/gn3/computations/ctl.py
@@ -6,13 +6,11 @@ from gn3.computations.wgcna import dump_wgcna_data
 from gn3.computations.wgcna import compose_wgcna_cmd
 from gn3.computations.wgcna import process_image
 
-from gn3.settings import TMPDIR
-
 
-def call_ctl_script(data):
+def call_ctl_script(data, tmpdir):
     """function to call ctl script"""
-    data["imgDir"] = TMPDIR
-    temp_file_name = dump_wgcna_data(data)
+    data["imgDir"] = tmpdir
+    temp_file_name = dump_wgcna_data(data, tmpdir)
     cmd = compose_wgcna_cmd("ctl_analysis.R", temp_file_name)
 
     cmd_results = run_cmd(cmd)
diff --git a/gn3/computations/gemma.py b/gn3/computations/gemma.py
index 6c53ecc..f07628f 100644
--- a/gn3/computations/gemma.py
+++ b/gn3/computations/gemma.py
@@ -41,12 +41,13 @@ def generate_pheno_txt_file(trait_filename: str,
 
 
 # pylint: disable=R0913
-def generate_gemma_cmd(gemma_cmd: str,
-                       output_dir: str,
-                       token: str,
-                       gemma_kwargs: Dict,
-                       gemma_wrapper_kwargs: Optional[Dict] = None,
-                       chromosomes: Optional[str] = None) -> Dict:
+def generate_gemma_cmd(# pylint: disable=[too-many-positional-arguments]
+        gemma_cmd: str,
+        output_dir: str,
+        token: str,
+        gemma_kwargs: Dict,
+        gemma_wrapper_kwargs: Optional[Dict] = None,
+        chromosomes: Optional[str] = None) -> Dict:
     """Compute k values"""
     _hash = get_hash_of_files(
         [v for k, v in gemma_kwargs.items() if k in ["g", "p", "a", "c"]])
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 6eee299..8674910 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -16,7 +16,6 @@ import pandas
 import pingouin
 from scipy.stats import pearsonr, spearmanr
 
-from gn3.settings import TEXTDIR
 from gn3.chancy import random_string
 from gn3.function_helpers import compose
 from gn3.data_helpers import parse_csv_line
@@ -99,7 +98,7 @@ def fix_samples(
         primary_samples,
         tuple(primary_trait_data["data"][sample]["value"]
               for sample in primary_samples),
-        control_vals_vars[0],
+        (control_vals_vars[0],),
         tuple(primary_trait_data["data"][sample]["variance"]
               for sample in primary_samples),
         control_vals_vars[1])
@@ -209,7 +208,7 @@ def good_dataset_samples_indexes(
         samples_from_file.index(good) for good in
         set(samples).intersection(set(samples_from_file))))
 
 
-def partial_correlations_fast(# pylint: disable=[R0913, R0914]
+def partial_correlations_fast(# pylint: disable=[R0913, R0914, too-many-positional-arguments]
         samples, primary_vals, control_vals, database_filename,
         fetched_correlations, method: str, correlation_type: str) -> Generator:
     """
@@ -334,7 +333,7 @@ def compute_partial(
     This implementation reworks the child function `compute_partial` which will
     then be used in the place of `determinPartialsByR`.
     """
-    with Pool(processes=(cpu_count() - 1)) as pool:
+    with Pool(processes=cpu_count() - 1) as pool:
         return (
             result for result in (
                 pool.starmap(
@@ -345,7 +344,7 @@ def compute_partial(
                 for target in targets)))
             if result is not None)
 
-def partial_correlations_normal(# pylint: disable=R0913
+def partial_correlations_normal(# pylint: disable=[R0913, too-many-positional-arguments]
        primary_vals, control_vals, input_trait_gene_id, trait_database,
        data_start_pos: int, db_type: str, method: str) -> Generator:
     """
@@ -381,7 +380,7 @@ def partial_correlations_normal(# pylint: disable=R0913
 
     return all_correlations
 
-def partial_corrs(# pylint: disable=[R0913]
+def partial_corrs(# pylint: disable=[R0913, too-many-positional-arguments]
        conn, samples, primary_vals, control_vals, return_number, species,
        input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id,
        method, dataset, database_filename):
@@ -667,10 +666,15 @@ def check_for_common_errors(# pylint: disable=[R0914]
     return non_error_result
 
 
-def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911]
-        conn: Any, primary_trait_name: str,
-        control_trait_names: Tuple[str, ...], method: str,
-        criteria: int, target_db_name: str) -> dict:
+def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911 too-many-positional-arguments]
+        conn: Any,
+        primary_trait_name: str,
+        control_trait_names: Tuple[str, ...],
+        method: str,
+        criteria: int,
+        target_db_name: str,
+        textdir: str
+) -> dict:
     """
     This is the 'ochestration' function for the partial-correlation feature.
@@ -755,7 +759,7 @@ def partial_correlations_with_target_db(# pylint: disable=[R0913, R0914, R0911]
         threshold, conn)
 
-    database_filename = get_filename(conn, target_db_name, TEXTDIR)
+    database_filename = get_filename(conn, target_db_name, textdir)
     all_correlations = partial_corrs(
         conn, check_res["common_primary_control_samples"],
         check_res["fixed_primary_values"], check_res["fixed_control_values"],
@@ -837,7 +841,7 @@ def partial_correlations_with_target_traits(
         return check_res
 
     target_traits = {
-        trait["name"]: trait
+        trait["trait_name"]: trait
        for trait in traits_info(conn, threshold, target_trait_names)}
     target_traits_data = traits_data(conn, tuple(target_traits.values()))
 
@@ -854,12 +858,13 @@ def partial_correlations_with_target_traits(
             __merge(
                 target_traits[target_name],
                 compute_trait_info(
-                    check_res["primary_values"], check_res["fixed_control_values"],
-                    (export_trait_data(
-                        target_data,
-                        samplelist=check_res["common_primary_control_samples"]),
-                     target_name),
-                    method))
+                    check_res["primary_values"],
+                    check_res["fixed_control_values"],
+                    (export_trait_data(
+                        target_data,
+                        samplelist=check_res["common_primary_control_samples"]),
+                     target_name),
+                    method))
             for target_name, target_data in target_traits_data.items())
 
     return {
diff --git a/gn3/computations/pca.py b/gn3/computations/pca.py
index 35c9f03..3b3041a 100644
--- a/gn3/computations/pca.py
+++ b/gn3/computations/pca.py
@@ -13,7 +13,7 @@ import redis
 from typing_extensions import TypeAlias
 
-fArray: TypeAlias = list[float]
+fArray: TypeAlias = list[float]  # pylint: disable=[invalid-name]
 
 
 def compute_pca(array: list[fArray]) -> dict[str, Any]:
@@ -133,7 +133,7 @@ def generate_pca_temp_traits(
 
     """
 
-    # pylint: disable=too-many-arguments
+    # pylint: disable=[too-many-arguments, too-many-positional-arguments]
 
     pca_trait_dict = {}
 
diff --git a/gn3/computations/qtlreaper.py b/gn3/computations/qtlreaper.py
index 08c387f..ff83b33 100644
--- a/gn3/computations/qtlreaper.py
+++ b/gn3/computations/qtlreaper.py
@@ -7,7 +7,6 @@ import subprocess
 from typing import Union
 
 from gn3.chancy import random_string
-from gn3.settings import TMPDIR
 
 def generate_traits_file(samples, trait_values, traits_filename):
     """
@@ -38,13 +37,15 @@ def create_output_directory(path: str):
         # If the directory already exists, do nothing.
         pass
 
-# pylint: disable=too-many-arguments
+# pylint: disable=[too-many-arguments, too-many-positional-arguments]
 def run_reaper(
         reaper_cmd: str,
-        genotype_filename: str, traits_filename: str,
+        genotype_filename: str,
+        traits_filename: str,
+        output_dir: str,
         other_options: tuple = ("--n_permutations", "1000"),
-        separate_nperm_output: bool = False,
-        output_dir: str = TMPDIR):
+        separate_nperm_output: bool = False
+):
     """
     Run the QTLReaper command to compute the QTLs.
diff --git a/gn3/computations/rqtl.py b/gn3/computations/rqtl.py
index 16f1398..3dd8fb2 100644
--- a/gn3/computations/rqtl.py
+++ b/gn3/computations/rqtl.py
@@ -1,5 +1,6 @@
 """Procedures related to R/qtl computations"""
 import os
+import csv
 
 from bisect import bisect
 from typing import Dict, List, Tuple, Union
@@ -67,8 +68,8 @@ def process_rqtl_mapping(file_name: str) -> List:
     # Later I should probably redo this using csv.read to avoid the
     # awkwardness with removing quotes with [1:-1]
     outdir = os.path.join(get_tmpdir(),"gn3")
-
-    with open( os.path.join(outdir,file_name),"r",encoding="utf-8") as the_file:
+    with open(os.path.join(outdir,file_name),"r",encoding="utf-8") as the_file:
+        column_count = len(the_file.readline().strip().split(","))
         for line in the_file:
             line_items = line.split(",")
             if line_items[1][1:-1] == "chr" or not line_items:
@@ -88,6 +89,16 @@ def process_rqtl_mapping(file_name: str) -> List:
                 "Mb": float(line_items[2]),
                 "lod_score": float(line_items[3]),
             }
+            # If 4-way, get extra effect columns
+            if column_count > 4:
+                this_marker['mean1'] = line_items[4][1:-1].split(' ± ')[0]
+                this_marker['se1'] = line_items[4][1:-1].split(' ± ')[1]
+                this_marker['mean2'] = line_items[5][1:-1].split(' ± ')[0]
+                this_marker['se2'] = line_items[5][1:-1].split(' ± ')[1]
+                this_marker['mean3'] = line_items[6][1:-1].split(' ± ')[0]
+                this_marker['se3'] = line_items[6][1:-1].split(' ± ')[1]
+                this_marker['mean4'] = line_items[7][1:-1].split(' ± ')[0]
+                this_marker['se4'] = line_items[7][1:-1].split(' ± ')[1]
             marker_obs.append(this_marker)
 
     return marker_obs
@@ -111,7 +122,7 @@ def pairscan_for_figure(file_name: str) -> Dict:
     # Open the file with the actual results, written as a list of lists
     outdir = os.path.join(get_tmpdir(),"gn3")
 
-    with open( os.path.join(outdir,file_name),"r",encoding="utf-8") as the_file:
+    with open(os.path.join(outdir, file_name), "r",encoding="utf-8") as the_file:
         lod_results = []
         for i, line in enumerate(the_file):
             if i == 0:  # Skip first line
@@ -134,14 +145,17 @@
     ) as the_file:
         chr_list = []  # type: List
         pos_list = []  # type: List
+        markers = []  # type: List
         for i, line in enumerate(the_file):
             if i == 0:  # Skip first line
                 continue
             line_items = [item.rstrip("\n") for item in line.split(",")]
             chr_list.append(line_items[1][1:-1])
             pos_list.append(line_items[2])
+            markers.append(line_items[0])
     figure_data["chr"] = chr_list
     figure_data["pos"] = pos_list
+    figure_data["name"] = markers
 
     return figure_data
 
@@ -312,18 +326,13 @@ def process_perm_output(file_name: str) -> Tuple[List, float, float]:
     suggestive and significant thresholds"""
     perm_results = []
 
-    outdir = os.path.join(get_tmpdir(),"gn3")
-
-    with open( os.path.join(outdir,file_name),"r",encoding="utf-8") as the_file:
-        for i, line in enumerate(the_file):
-            if i == 0:
-                # Skip header line
-                continue
-
-            line_items = line.split(",")
-            perm_results.append(float(line_items[1]))
-
-    suggestive = np.percentile(np.array(perm_results), 67)
-    significant = np.percentile(np.array(perm_results), 95)
-
+    outdir = os.path.join(get_tmpdir(), "gn3")
+
+    with open(os.path.join(outdir, file_name),
+              "r", encoding="utf-8") as file_handler:
+        reader = csv.reader(file_handler)
+        next(reader)
+        perm_results = [float(row[1]) for row in reader]  # Extract LOD values
+    suggestive = np.percentile(np.array(perm_results), 67)
+    significant = np.percentile(np.array(perm_results), 95)
     return perm_results, suggestive, significant
 
diff --git a/gn3/computations/rqtl2.py b/gn3/computations/rqtl2.py
new file mode 100644
index 0000000..5d5f68e
--- /dev/null
+++ b/gn3/computations/rqtl2.py
@@ -0,0 +1,228 @@
+"""Module contains functions to parse and process rqtl2 input and output"""
+import os
+import csv
+import uuid
+import json
+from pathlib import Path
+from typing import List
+from typing import Dict
+from typing import Any
+
+def generate_rqtl2_files(data, workspace_dir):
+    """Prepare data and generate necessary CSV files
+    required to write to control_file
+    """
+    file_to_name_map = {
+        "geno_file": "geno_data",
+        "pheno_file": "pheno_data",
+        "geno_map_file": "geno_map_data",
+        "physical_map_file": "physical_map_data",
+        "phenocovar_file": "phenocovar_data",
+        "founder_geno_file" : "founder_geno_data",
+        "covar_file" : "covar_data"
+    }
+    parsed_files = {}
+    for file_name, data_key in file_to_name_map.items():
+        if data_key in data:
+            file_path = write_to_csv(
+                workspace_dir, f"{file_name}.csv", data[data_key])
+            if file_path:
+                parsed_files[file_name] = file_path
+    return {**data, **parsed_files}
+
+
+def write_to_csv(work_dir, file_name, data: list[dict],
+                 headers=None, delimiter=","):
+    """Function to write a data list to a csv file;
+    if headers are not provided, use the keys of the first object.
+    """
+    if not data:
+        return ""
+    if headers is None:
+        headers = data[0].keys()
+    file_path = os.path.join(work_dir, file_name)
+    with open(file_path, "w", encoding="utf-8") as file_handler:
+        writer = csv.DictWriter(file_handler, fieldnames=headers,
+                                delimiter=delimiter)
+        writer.writeheader()
+        for row in data:
+            writer.writerow(row)
+    # return the file name relative to the workspace; see rqtl2 docs
+    return file_name
+
+
+def validate_required_keys(required_keys: list, data: dict) -> tuple[bool, str]:
+    """Check for missing keys in the data object"""
+    missing_keys = [key for key in required_keys if key not in data]
+    if missing_keys:
+        return False, f"Required key(s) missing: {', '.join(missing_keys)}"
+    return True, ""
+
+
+def compose_rqtl2_cmd(# pylint: disable=[too-many-positional-arguments]
+        rqtl_path, input_file, output_file, workspace_dir, data, config):
+    """Compose the command for running the R/QTL2 analysis."""
+    # pylint: disable=R0913
+    params = {
+        "input_file": input_file,
+        "directory": workspace_dir,
+        "output_file": output_file,
+        "nperm": data.get("nperm", 0),
+        "method": data.get("method", "HK"),
+        "threshold": data.get("threshold", 1),
+        "cores": config.get('MULTIPROCESSOR_PROCS', 1)
+    }
+    rscript_path = config.get("RSCRIPT", os.environ.get("RSCRIPT", "Rscript"))
+    return f"{rscript_path} { rqtl_path } " + " ".join(
+        [f"--{key} {val}" for key, val in params.items()])
+
+
+def create_file(file_path):
+    """Utility function to create a file at the given file_path"""
+    try:
+        with open(file_path, "x", encoding="utf-8") as _file_handler:
+            return True, f"File created at {file_path}"
+    except FileExistsError:
+        return False, "File Already Exists"
+
+
+def prepare_files(tmpdir):
+    """Prepare necessary files and workspace dir for computation."""
+    workspace_dir = os.path.join(tmpdir, str(uuid.uuid4()))
+    Path(workspace_dir).mkdir(parents=False, exist_ok=True)
+    input_file = os.path.join(
+        workspace_dir, f"rqtl2-input-{uuid.uuid4()}.json")
+    output_file = os.path.join(
+        workspace_dir, f"rqtl2-output-{uuid.uuid4()}.json")
+
+    # to ensure the streaming api has access to the file even after computation ends,
+    # create the log file outside the workspace_dir
+    log_file = os.path.join(tmpdir, f"rqtl2-log-{uuid.uuid4()}")
+    for file_path in [input_file, output_file, log_file]:
+        create_file(file_path)
+    return workspace_dir, input_file, output_file, log_file
+
+
+def write_input_file(input_file, workspace_dir, data):
+    """
+    Write input data to a json file to be passed
+    as input to the rqtl2 script
+    """
+    with open(input_file, "w+", encoding="UTF-8") as file_handler:
+        # todo choose a better variable name
+        rqtl2_files = generate_rqtl2_files(data, workspace_dir)
+        json.dump(rqtl2_files, file_handler)
+
+
+def read_output_file(output_path: str) -> dict:
+    """function to read the json output file generated from rqtl2;
+    see the rqtl2_wrapper.R script for the expected output
+    """
+    with open(output_path, "r", encoding="utf-8") as file_handler:
+        results = json.load(file_handler)
+        return results
+
+
+def process_permutation(data):
+    """This function processes output data from the permutation results.
+    input: data object extracted from the output_file
+    returns:
+        dict: A dict containing
+            * phenotypes array
+            * permutations as a dict keyed by permutation_id
+            * significance_results keyed by threshold values
+    """
+
+    perm_file = data.get("permutation_file")
+    with open(perm_file, "r", encoding="utf-8") as file_handler:
+        reader = csv.reader(file_handler)
+        phenotypes = next(reader)[1:]
+        perm_results = {_id: float(val) for _id, val, *_ in reader}
+    _, significance = fetch_significance_results(data.get("significance_file"))
+    return {
+        "phenotypes": phenotypes,
+        "perm_results": perm_results,
+        "significance": significance,
+    }
+
+
+def fetch_significance_results(file_path: str):
+    """
+    Processes the 'significance_file' from the given data object to extract
+    phenotypes and significance values.
+    threshold values are: (0.05, 0.01)
+    Args:
+        file_path (str): file path for the significance output
+
+    Returns:
+        tuple: A tuple containing
+            * phenotypes (list): List of phenotypes
+            * significances (dict): A dictionary where keys
+              are threshold values and values are lists
+              of significant results corresponding to each threshold.
+    """
+    with open(file_path, "r", encoding="utf-8") as file_handler:
+        reader = csv.reader(file_handler)
+        results = {}
+        phenotypes = next(reader)[1:]
+        for line in reader:
+            threshold, significance = line[0], line[1:]
+            results[threshold] = significance
+        return (phenotypes, results)
+
+
+def process_scan_results(qtl_file_path: str, map_file_path: str) -> List[Dict[str, Any]]:
+    """Function to process genome scanning results and obtain marker_name, LOD score,
+    marker_position, and chromosome.
+    Args:
+        qtl_file_path (str): Path to the QTL scan results CSV file.
+        map_file_path (str): Path to the map file from the script.
+
+    Returns:
+        List[Dict[str, str]]: A list of dictionaries containing the marker data.
+ """ + map_data = {} + # read the genetic map + with open(map_file_path, "r", encoding="utf-8") as file_handler: + reader = csv.reader(file_handler) + next(reader) + for line in reader: + marker, chr_, cm_, mb_ = line + cm: float | None = float(cm_) if cm_ and cm_ != "NA" else None + mb: float | None = float(mb_) if mb_ and mb_ != "NA" else None + map_data[marker] = {"chr": chr_, "cM": cm, "Mb": mb} + + # Process QTL scan results and merge the positional data + results = [] + with open(qtl_file_path, "r", encoding="utf-8") as file_handler: + reader = csv.reader(file_handler) + next(reader) + for line in reader: + marker = line[0] + lod_score = line[1] + results.append({ + "name": marker, + "lod_score": float(lod_score), + **map_data.get(marker, {}) # Add chromosome and positions if available + }) + return results + + +def process_qtl2_results(output_file: str) -> Dict[str, Any]: + """Function provides abstraction for processing all QTL2 mapping results. + + Args: * File path to to the output generated + + Returns: + Dict[str, any]: A dictionary containing both QTL + and permutation results along with input data. + """ + results = read_output_file(output_file) + qtl_results = process_scan_results(results["scan_file"], + results["map_file"]) + permutation_results = process_permutation(results) if results["permutations"] > 0 else {} + return { + **results, + "qtl_results": qtl_results, + "permutation_results": permutation_results + } diff --git a/gn3/computations/rust_correlation.py b/gn3/computations/rust_correlation.py index 5ce097d..359b73a 100644 --- a/gn3/computations/rust_correlation.py +++ b/gn3/computations/rust_correlation.py @@ -3,27 +3,27 @@ https://github.com/Alexanderlacuna/correlation_rust """ -import subprocess -import json -import csv import os +import csv +import json +import traceback +import subprocess from flask import current_app from gn3.computations.qtlreaper import create_output_directory from gn3.chancy import random_string -from gn3.settings import TMPDIR -def generate_input_files(dataset: list[str], - output_dir: str = TMPDIR) -> tuple[str, str]: +def generate_input_files( + dataset: list[str], output_dir: str) -> tuple[str, str]: """function generates outputfiles and inputfiles""" tmp_dir = f"{output_dir}/correlation" create_output_directory(tmp_dir) tmp_file = os.path.join(tmp_dir, f"{random_string(10)}.txt") with open(tmp_file, "w", encoding="utf-8") as op_file: writer = csv.writer( - op_file, delimiter=",", dialect="unix", quotechar="", + op_file, delimiter=",", dialect="unix", quoting=csv.QUOTE_NONE, escapechar="\\") writer.writerows(dataset) @@ -49,17 +49,23 @@ def generate_json_file( def run_correlation( - dataset, trait_vals: str, method: str, delimiter: str, - corr_type: str = "sample", top_n: int = 500): + dataset, + trait_vals: str, + method: str, + delimiter: str, + tmpdir: str, + corr_type: str = "sample", + top_n: int = 500 +): """entry function to call rust correlation""" - # pylint: disable=too-many-arguments + # pylint: disable=[too-many-arguments, too-many-positional-arguments] correlation_command = current_app.config["CORRELATION_COMMAND"] # make arg? 
diff --git a/gn3/computations/rust_correlation.py b/gn3/computations/rust_correlation.py
index 5ce097d..359b73a 100644
--- a/gn3/computations/rust_correlation.py
+++ b/gn3/computations/rust_correlation.py
@@ -3,27 +3,27 @@
 https://github.com/Alexanderlacuna/correlation_rust
 """
 
-import subprocess
-import json
-import csv
 import os
+import csv
+import json
+import traceback
+import subprocess
 
 from flask import current_app
 
 from gn3.computations.qtlreaper import create_output_directory
 from gn3.chancy import random_string
-from gn3.settings import TMPDIR
 
-def generate_input_files(dataset: list[str],
-                         output_dir: str = TMPDIR) -> tuple[str, str]:
+def generate_input_files(
+        dataset: list[str], output_dir: str) -> tuple[str, str]:
     """function generates outputfiles and inputfiles"""
     tmp_dir = f"{output_dir}/correlation"
     create_output_directory(tmp_dir)
     tmp_file = os.path.join(tmp_dir, f"{random_string(10)}.txt")
     with open(tmp_file, "w", encoding="utf-8") as op_file:
         writer = csv.writer(
-            op_file, delimiter=",", dialect="unix", quotechar="",
+            op_file, delimiter=",", dialect="unix", quoting=csv.QUOTE_NONE,
             escapechar="\\")
         writer.writerows(dataset)
@@ -49,17 +49,23 @@ def generate_json_file(
 
 
 def run_correlation(
-        dataset, trait_vals: str, method: str, delimiter: str,
-        corr_type: str = "sample", top_n: int = 500):
+        dataset,
+        trait_vals: str,
+        method: str,
+        delimiter: str,
+        tmpdir: str,
+        corr_type: str = "sample",
+        top_n: int = 500
+):
     """entry function to call rust correlation"""
-    # pylint: disable=too-many-arguments
+    # pylint: disable=[too-many-arguments, too-many-positional-arguments]
 
     correlation_command = current_app.config["CORRELATION_COMMAND"]  # make arg?
-    (tmp_dir, tmp_file) = generate_input_files(dataset)
+    (tmp_dir, tmp_file) = generate_input_files(dataset, tmpdir)
     (output_file, json_file) = generate_json_file(
         tmp_dir=tmp_dir, tmp_file=tmp_file, method=method, delimiter=delimiter,
         x_vals=trait_vals)
 
-    command_list = [correlation_command, json_file, TMPDIR]
+    command_list = [correlation_command, json_file, tmpdir]
 
     try:
         subprocess.run(command_list, check=True, capture_output=True)
     except subprocess.CalledProcessError as cpe:
@@ -67,7 +73,12 @@ def run_correlation(
             os.readlink(correlation_command)
             if os.path.islink(correlation_command)
             else correlation_command)
-        raise Exception(command_list, actual_command, cpe.stdout) from cpe
+        raise Exception(# pylint: disable=[broad-exception-raised]
+            command_list,
+            actual_command,
+            cpe.stdout,
+            traceback.format_exc().split()
+        ) from cpe
 
     return parse_correlation_output(output_file, corr_type, top_n)
 
diff --git a/gn3/computations/streaming.py b/gn3/computations/streaming.py
new file mode 100644
index 0000000..6e02694
--- /dev/null
+++ b/gn3/computations/streaming.py
@@ -0,0 +1,62 @@
+"""Module contains streaming procedures for genenetwork."""
+import os
+import subprocess
+from functools import wraps
+from flask import current_app, request
+
+
+def read_file(file_path):
+    """Utility function to read files"""
+    with open(file_path, "r", encoding="UTF-8") as file_handler:
+        return file_handler.read()
+
+def run_process(cmd, log_file, run_id):
+    """Function to execute an external process and
+    capture the stdout in a file
+    input:
+        cmd: the command to execute as a list of args.
+        log_file: abs file path to write the stdout.
+        run_id: unique id to identify the process
+
+    output:
+        Dict with the results for either success or failure.
+    """
+    try:
+        # phase: execute the rscript cmd
+        with subprocess.Popen(
+                cmd,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.STDOUT,
+        ) as process:
+            for line in iter(process.stdout.readline, b""):
+                # phase: capture the stdout for each line allowing read and write
+                with open(log_file, "a+", encoding="utf-8") as file_handler:
+                    file_handler.write(line.decode("utf-8"))
+            process.wait()
+        return {"msg": "success" if process.returncode == 0 else "Process failed",
+                "run_id": run_id,
+                "log" : read_file(log_file),
+                "code": process.returncode}
+    except subprocess.CalledProcessError as error:
+        return {"msg": "error occurred",
+                "code": error.returncode,
+                "error": str(error),
+                "run_id": run_id,
+                "log" : read_file(log_file)}
+
+
+def enable_streaming(func):
+    """Decorator function to enable streaming for an endpoint
+    Note: should only be used in an app context
+    """
+    @wraps(func)
+    def decorated_function(*args, **kwargs):
+        run_id = request.args.get("id")
+        stream_output_file = os.path.join(current_app.config.get("TMPDIR"),
+                                          f"{run_id}.txt")
+        with open(stream_output_file, "w+", encoding="utf-8",
+                  ) as file_handler:
+            file_handler.write("File created for streaming\n")
+        return func(stream_output_file, *args, **kwargs)
+    return decorated_function
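A minimal sketch (not part of this commit) of how run_process and enable_streaming are meant to be combined in a Flask view; the blueprint, route and Rscript command are illustrative assumptions. enable_streaming reads the ?id=... query argument, creates TMPDIR/<id>.txt and passes that path to the wrapped view as its first argument:

    from flask import Blueprint, jsonify, request
    from gn3.computations.streaming import enable_streaming, run_process

    rqtl2 = Blueprint("rqtl2", __name__)  # hypothetical blueprint

    @rqtl2.route("/compute", methods=["POST"])
    @enable_streaming
    def compute(log_file):
        # log_file is the per-run stream file injected by the decorator
        run_id = request.args.get("id")
        cmd = ["Rscript", "scripts/rqtl2_wrapper.R"]  # hypothetical command
        return jsonify(run_process(cmd, log_file, run_id))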
diff --git a/gn3/computations/wgcna.py b/gn3/computations/wgcna.py
index d1f7b32..3229a0e 100644
--- a/gn3/computations/wgcna.py
+++ b/gn3/computations/wgcna.py
@@ -7,17 +7,16 @@ import subprocess
 
 from pathlib import Path
 
-from gn3.settings import TMPDIR
 from gn3.commands import run_cmd
 
 
-def dump_wgcna_data(request_data: dict):
+def dump_wgcna_data(request_data: dict, tmpdir: str):
     """function to dump request data to json file"""
     filename = f"{str(uuid.uuid4())}.json"
 
-    temp_file_path = os.path.join(TMPDIR, filename)
+    temp_file_path = os.path.join(tmpdir, filename)
 
-    request_data["TMPDIR"] = TMPDIR
+    request_data["TMPDIR"] = tmpdir
 
     with open(temp_file_path, "w", encoding="utf-8") as output_file:
         json.dump(request_data, output_file)
@@ -65,9 +64,9 @@ def compose_wgcna_cmd(rscript_path: str, temp_file_path: str):
     return cmd
 
 
-def call_wgcna_script(rscript_path: str, request_data: dict):
+def call_wgcna_script(rscript_path: str, request_data: dict, tmpdir: str):
     """function to call wgcna script"""
-    generated_file = dump_wgcna_data(request_data)
+    generated_file = dump_wgcna_data(request_data, tmpdir)
     cmd = compose_wgcna_cmd(rscript_path, generated_file)
 
     # stream_cmd_output(request_data, cmd) disable streaming of data
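The same dependency-injection pattern applies to the WGCNA/CTL entry points: with gn3.settings.TMPDIR gone, callers now pass the temporary directory explicitly. A minimal sketch (not part of this commit); the config lookup, the R script path and the request payload are illustrative assumptions:

    from flask import current_app
    from gn3.computations.wgcna import call_wgcna_script
    from gn3.computations.ctl import call_ctl_script

    tmpdir = current_app.config["TMPDIR"]  # assumed config key
    request_data = {"trait_names": ["t1"], "trait_sample_data": []}  # hypothetical payload
    wgcna_results = call_wgcna_script("scripts/wgcna_analysis.R", request_data, tmpdir)  # hypothetical script path
    ctl_results = call_ctl_script(request_data, tmpdir)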
