"""Integration with the correlation computation implemented in Rust; see
https://github.com/Alexanderlacuna/correlation_rust
"""
import subprocess
import json
import csv
import os
from gn3.computations.qtlreaper import create_output_directory
from gn3.chancy import random_string
from gn3.settings import CORRELATION_COMMAND
from gn3.settings import TMPDIR
def generate_input_files(dataset: list[str],
                         output_dir: str = TMPDIR) -> tuple[str, str]:
    """Write ``dataset`` to a randomly named CSV file for the correlation tool.

    Args:
        dataset: Rows to write; the whole value is handed to
            ``csv.writer.writerows``, so each item becomes one
            comma-separated line.
        output_dir: Base directory.  The file is placed in its
            ``correlation`` sub-directory, which is first passed to
            ``create_output_directory``.

    Returns:
        ``(tmp_dir, tmp_file)``: the directory used and the path of the file
        that was written.
    """
    tmp_dir = f"{output_dir}/correlation"
    create_output_directory(tmp_dir)
    tmp_file = os.path.join(tmp_dir, f"{random_string(10)}.txt")
    # newline="" lets the csv module control line endings itself, as the csv
    # documentation recommends for files handed to csv.writer.
    with open(tmp_file, "w", encoding="utf-8", newline="") as op_file:
        # Quoting is disabled, so no quote character is needed.  It has to be
        # None rather than "": Python 3.11+ rejects an empty quotechar with a
        # TypeError, while None means the same thing on every version.
        writer = csv.writer(
            op_file, delimiter=",", dialect="unix", quotechar=None,
            quoting=csv.QUOTE_NONE, escapechar="\\")
        writer.writerows(dataset)
    return (tmp_dir, tmp_file)
def generate_json_file(
        tmp_dir, tmp_file, method, delimiter, x_vals) -> tuple[str, str]:
    """Write the JSON parameter file that is handed to the correlation command.

    Two randomly named paths are built under ``tmp_dir``: a ``.json`` file
    that receives the parameters and a ``.txt`` path that is recorded in the
    parameters under ``output_file``.

    Returns:
        ``(output_file, tmp_json_file)``: the results path named in the JSON
        and the path of the JSON file itself.
    """
    json_path = os.path.join(tmp_dir, f"{random_string(10)}.json")
    results_path = os.path.join(tmp_dir, f"{random_string(10)}.txt")
    params = {
        "method": method,
        "file_path": tmp_file,
        "x_vals": x_vals,
        "sample_values": "bxd1",
        "output_file": results_path,
        "file_delimiter": delimiter,
    }
    with open(json_path, "w", encoding="utf-8") as handle:
        json.dump(params, handle)
    return (results_path, json_path)
def run_correlation(
        dataset, trait_vals: str, method: str, delimiter: str,
        corr_type: str = "sample", top_n: int = 500):
    """Run the external correlation command on ``dataset`` and parse its output.

    The dataset is written to a CSV file, a JSON parameter file describing the
    run is generated, and ``CORRELATION_COMMAND`` is invoked with that JSON
    file and ``TMPDIR`` as arguments.

    Returns:
        The result of ``parse_correlation_output`` for the generated output
        file, using ``corr_type`` and ``top_n``.

    Raises:
        Exception: if the command exits with a non-zero status.  Its args are
            the command list, the command path (the link target when
            ``CORRELATION_COMMAND`` is a symlink) and the captured stdout.
    """
    # pylint: disable=too-many-arguments
    work_dir, data_file = generate_input_files(dataset)
    results_file, params_file = generate_json_file(
        tmp_dir=work_dir, tmp_file=data_file, method=method,
        delimiter=delimiter, x_vals=trait_vals)
    command_list = [CORRELATION_COMMAND, params_file, TMPDIR]
    try:
        subprocess.run(command_list, check=True, capture_output=True)
    except subprocess.CalledProcessError as cpe:
        # Report the link target when the configured command is a symlink.
        if os.path.islink(CORRELATION_COMMAND):
            actual_command = os.readlink(CORRELATION_COMMAND)
        else:
            actual_command = CORRELATION_COMMAND
        raise Exception(command_list, actual_command, cpe.stdout) from cpe
    return parse_correlation_output(results_file, corr_type, top_n)
def parse_correlation_output(result_file: str,
                             corr_type: str, top_n: int = 500) -> dict:
    """Parse a correlation result file into a dict keyed by trait name.

    Every non-blank line must hold exactly four comma-separated fields, in
    this order: trait name, correlation coefficient, p-value, number of
    overlapping values.  Field values are kept as the strings read from the
    file.

    Args:
        result_file: Path of the file to read.
        corr_type: ``"sample"`` or ``"tissue"`` selects the key names used in
            each per-trait dict; any other value maps every trait to an empty
            dict.
        top_n: Only the first ``top_n`` lines of the file are considered.

    Returns:
        A dict mapping trait name to a dict of its values.

    Raises:
        ValueError: if a non-blank line does not split into four fields.
    """
    def _parse_line(line):
        (trait_name, corr_coeff, p_val, num_overlap) = line.rstrip().split(",")
        if corr_type == "sample":
            return (trait_name, {
                "num_overlap": num_overlap,
                "corr_coefficient": corr_coeff,
                "p_value": p_val
            })
        if corr_type == "tissue":
            return (
                trait_name,
                {
                    "tissue_corr": corr_coeff,
                    "tissue_number": num_overlap,
                    "tissue_p_val": p_val
                })
        # A tuple literal is required here: tuple() accepts one iterable, so
        # tuple(trait_name, {}) raised TypeError for unknown corr_type values.
        return (trait_name, {})

    parsed = {}
    with open(result_file, "r", encoding="utf-8") as file_reader:
        for idx, line in enumerate(file_reader):
            if idx >= top_n:
                # Stop reading instead of scanning the rest of the file.
                break
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            trait_name, values = _parse_line(line)
            parsed[trait_name] = values
    return parsed
def get_samples(all_samples: dict[str, str],
                base_samples: list[str],
                excluded: list[str]):
    """Collect numeric sample values, skipping excluded and "x" samples.

    When ``base_samples`` is non-empty only those names are considered (in
    that order, and only if present in ``all_samples``); otherwise every key
    of ``all_samples`` is considered.  A name listed in ``excluded`` is always
    skipped, as is any sample whose stripped value equals "x" in any case.

    Returns:
        A dict mapping each remaining sample name to its value as a float.
    """
    # An empty base list means "look at every sample in all_samples".
    candidates = base_samples if base_samples else all_samples
    selected = {}
    for name in candidates:
        if name in excluded or name not in all_samples:
            continue
        raw_value = all_samples[name].strip()
        if raw_value.lower() == "x":
            continue
        selected[name] = float(raw_value)
    return selected
def get_sample_corr_data(sample_type: str,
                         sample_data: dict[str, str],
                         dataset_samples: list[str]) -> dict[str, float]:
    """Select the sample values to correlate, depending on ``sample_type``.

    Args:
        sample_type: ``"samples_primary"`` restricts the result to
            ``dataset_samples``; ``"samples_other"`` excludes
            ``dataset_samples``; any other value applies no restriction.
        sample_data: Mapping of sample name to its raw string value.
        dataset_samples: Sample names belonging to the dataset.

    Returns:
        The dict produced by ``get_samples``: sample name to float value
        (the previous ``dict[str, str]`` annotation did not match that).
    """
    if sample_type == "samples_primary":
        base_samples, excluded = dataset_samples, []
    elif sample_type == "samples_other":
        base_samples, excluded = [], dataset_samples
    else:
        base_samples, excluded = [], []
    return get_samples(
        all_samples=sample_data,
        base_samples=base_samples,
        excluded=excluded)
def parse_tissue_corr_data(symbol_name: str,
                           symbol_dict: dict,
                           dataset_symbols: dict,
                           dataset_vals: dict):
    """Build the primary values and dataset rows for a tissue correlation.

    Returns ``None`` when ``symbol_name`` is empty or its lower-cased form is
    not a key of ``symbol_dict``.  Otherwise returns ``(x_vals, data)``:

    * ``x_vals``: the ``symbol_dict`` entry for the symbol, converted to
      floats.
    * ``data``: one row per item of ``dataset_symbols`` whose lower-cased
      symbol has a non-empty entry in ``dataset_vals``; each row is the trait
      as a string followed by that entry's values.
    """
    if not symbol_name:
        return None
    lookup_key = symbol_name.lower()
    if lookup_key not in symbol_dict:
        return None

    x_vals = [float(val) for val in symbol_dict[lookup_key]]
    rows = []
    for trait, symbol in dataset_symbols.items():
        try:
            values = dataset_vals.get(symbol.lower())
            if values:
                rows.append([str(trait)] + values)
        except AttributeError:
            # Raised e.g. when the symbol is None; such traits are skipped.
            continue
    return (x_vals, rows)
def run_lmdb_correlation(lmdb_info: dict, corr_type: str = "sample",
                         top_n: int = 500):
    """Run the correlation command on LMDB input and parse its output.

    How this differs from ``run_correlation``:
      1) all the pre-parsing is done in Rust, which is considerably faster;
      2) files are read directly; the format of the file is LMDB.
    It can also compute correlations on unprocessed CSV files.

    Args:
        lmdb_info: Parameters for the command.  They are written to the JSON
            parameter file together with an ``outputfile`` entry; a value for
            ``outputfile`` given here takes precedence over the generated one.
        corr_type: Passed to ``parse_correlation_output``.
        top_n: Passed to ``parse_correlation_output``.

    Returns:
        The result of ``parse_correlation_output`` for the output file.

    Raises:
        Exception: if the command exits with a non-zero status.  Its args are
            the command list, the command path (the link target when
            ``CORRELATION_COMMAND`` is a symlink) and the captured stdout.
    """
    # Same working directory layout as generate_input_files uses.
    tmp_dir = f"{TMPDIR}/correlation"
    create_output_directory(tmp_dir)
    tmp_json_file = os.path.join(tmp_dir, f"{random_string(10)}.json")
    output_file = os.path.join(tmp_dir, f"{random_string(10)}.txt")
    # NOTE(review): generate_json_file uses the key "output_file" while this
    # function uses "outputfile" - confirm which one the command expects here.
    params = {"outputfile": output_file, **lmdb_info}
    with open(tmp_json_file, "w", encoding="utf-8") as outputfile:
        json.dump(params, outputfile)
    command_list = [CORRELATION_COMMAND, tmp_json_file, TMPDIR]
    try:
        subprocess.run(command_list, check=True, capture_output=True)
    except subprocess.CalledProcessError as cpe:
        actual_command = (
            os.readlink(CORRELATION_COMMAND)
            if os.path.islink(CORRELATION_COMMAND)
            else CORRELATION_COMMAND)
        raise Exception(command_list, actual_command, cpe.stdout) from cpe
    # Read back whichever path was actually written into the parameter file.
    return parse_correlation_output(params["outputfile"], corr_type, top_n)