diff options
| -rw-r--r-- | src/lmm.cpp | 211 | ||||
| -rw-r--r-- | src/lmm.h | 5 |
2 files changed, 213 insertions, 3 deletions
diff --git a/src/lmm.cpp b/src/lmm.cpp
index 2ac9835..ee7ba42 100644
--- a/src/lmm.cpp
+++ b/src/lmm.cpp
@@ -1867,14 +1867,219 @@ void LMM::Analyze(std::function< SnpNameValues(size_t) >& fetch_snp,
 This is the mirror function of below AnalyzeBimbam, but uses mdb input instead.
 */
+
+// Batched LMM association scan over the SNPs supplied by fetch_snp: missing
+// genotypes are mean-imputed, each batch is rotated by U^T in one product and
+// one SUMSTAT per analysed SNP is appended to sumStat. W and y are not read here.
+void LMM::mdb_analyze(std::function< SnpNameValues(size_t) >& fetch_snp,
+                      const gsl_matrix *U, const gsl_vector *eval,
+                      const gsl_matrix *UtW, const gsl_vector *Uty,
+                      const gsl_matrix *W, const gsl_vector *y,
+                      const set<string> gwasnps) {
+  clock_t time_start = clock();
+
+  checkpoint_nofile("start-lmm-analyze");
+  // Subset/LOCO support
+  const bool process_gwasnps = !gwasnps.empty();
+
+  // Calculate basic quantities.
+  size_t n_index = (n_cvt + 2 + 1) * (n_cvt + 2) / 2;
+
+  const size_t inds = U->size1;
+  enforce(inds == ni_test);
+  gsl_vector *x = gsl_vector_safe_alloc(inds); // #inds
+  gsl_vector *x_miss = gsl_vector_safe_alloc(inds);
+  gsl_vector *Utx = gsl_vector_safe_alloc(U->size2);
+  gsl_matrix *Uab = gsl_matrix_safe_alloc(U->size2, n_index);
+  gsl_vector *ab = gsl_vector_safe_alloc(n_index);
+
+  // Create a large matrix with LMM_BATCH_SIZE columns for batched processing
+  // const size_t msize=(process_gwasnps ? 1 : LMM_BATCH_SIZE);
+  const size_t msize = LMM_BATCH_SIZE;
+  gsl_matrix *Xlarge = gsl_matrix_safe_alloc(inds, msize);
+  gsl_matrix *UtXlarge = gsl_matrix_safe_alloc(inds, msize);
+  enforce_msg(Xlarge && UtXlarge, "Xlarge memory check"); // just to be sure
+  enforce(Xlarge->size1 == inds);
+  gsl_matrix_set_zero(Xlarge);
+  gsl_matrix_set_zero(Uab);
+  CalcUab(UtW, Uty, Uab);
+
+  // start reading genotypes and analyze
+  size_t c = 0;
+
+  /*
+    batch_compute(l) takes l SNPs that have been loaded into Xlarge,
+    transforms them all at once using the eigenvector matrix U, then
+    loops through each transformed SNP to compute association
+    statistics (beta, standard errors, p-values) and stores results in
+    sumStat.
+  */
+  auto batch_compute = [&](size_t l) { // using a C++ closure
+    // Compute SNPs in batch, note the computations are independent per SNP
+    checkpoint_nofile("\nstart-batch_compute");
+    gsl_matrix_view Xlarge_sub = gsl_matrix_submatrix(Xlarge, 0, 0, inds, l);
+    gsl_matrix_view UtXlarge_sub =
+        gsl_matrix_submatrix(UtXlarge, 0, 0, inds, l);
+
+    time_start = clock();
+    // Transforms all l SNPs in the batch at once: UtXlarge = U^T × Xlarge
+    // This is much faster than doing l separate matrix-vector products
+    // U is the eigenvector matrix from the spectral decomposition of the kinship matrix
+    fast_dgemm("T", "N", 1.0, U, &Xlarge_sub.matrix, 0.0,
+               &UtXlarge_sub.matrix);
+    time_UtX += (clock() - time_start) / (double(CLOCKS_PER_SEC) * 60.0);
+
+    gsl_matrix_set_zero(Xlarge);
+    for (size_t i = 0; i < l; i++) {
+      // for each snp batch item extract transformed genotype:
+      gsl_vector_view UtXlarge_col = gsl_matrix_column(UtXlarge, i);
+      gsl_vector_safe_memcpy(Utx, &UtXlarge_col.vector);
+
+      // Calculate design matrix components and compute sufficient statistics for the regression model
+      CalcUab(UtW, Uty, Utx, Uab);
+
+      time_start = clock();
+      FUNC_PARAM param1 = {false, ni_test, n_cvt, eval, Uab, ab, 0};
+
+      double lambda_mle = 0.0, lambda_remle = 0.0, beta = 0.0, se = 0.0, p_wald = 0.0;
+      double p_lrt = 0.0, p_score = 0.0;
+      double logl_H1 = 0.0;
+
+      // Run statistical tests based on analysis mode
+      // Score test (M_LMM3/4/9) runs before Wald (M_LMM1/4); both are handed beta/se.
+      if (a_mode == M_LMM3 || a_mode == M_LMM4 || a_mode == M_LMM9 ) {
+        CalcRLScore(l_mle_null, param1, beta, se, p_score);
+      }
+
+      // Computes Wald statistic for testing association
+      if (a_mode == M_LMM1 || a_mode == M_LMM4) {
+        // for univariate a_mode is 1
+        // Estimates variance component (lambda) via REML
+        CalcLambda('R', param1, l_min, l_max, n_region, lambda_remle, logl_H1);
+        CalcRLWald(lambda_remle, param1, beta, se, p_wald);
+      }
+
+      // Likelihood Ratio Test (modes 2, 4, 9):
+      // Estimates variance component via MLE (CalcLambda with 'L'), then
+      // compares log-likelihood under alternative vs null hypothesis:
+      // p_lrt is the chi-square (1 df) upper tail of 2 * (logl_H1 - logl_mle_H0)
+      if (a_mode == M_LMM2 || a_mode == M_LMM4 || a_mode == M_LMM9) {
+        CalcLambda('L', param1, l_min, l_max, n_region, lambda_mle, logl_H1);
+        p_lrt = gsl_cdf_chisq_Q(2.0 * (logl_H1 - logl_mle_H0), 1);
+      }
+
+      time_opt += (clock() - time_start) / (double(CLOCKS_PER_SEC) * 60.0);
+
+      // Store summary data.
+      SUMSTAT SNPs = {beta, se, lambda_remle, lambda_mle,
+                      p_wald, p_lrt, p_score, logl_H1};
+      sumStat.push_back(SNPs);
+    }
+    checkpoint_nofile("end-batch_compute");
+  };
+
+  const auto num_snps = indicator_snp.size();
+  enforce_msg(num_snps > 0,"Zero SNPs to process - data corrupt?");
+  if (num_snps < 50) {
+    cerr << num_snps << " SNPs" << endl;
+    warning_msg("very few SNPs processed");
+  }
+  const size_t progress_step = (num_snps/50>d_pace ? num_snps/50 : d_pace);
+
+  for (size_t t = 0; t < num_snps; ++t) {
+    if (t % progress_step == 0 || t == (num_snps - 1)) {
+      ProgressBar("Reading SNPs", t, num_snps - 1);
+    }
+    if (indicator_snp[t] == 0)
+      continue;
+
+    auto tup = fetch_snp(t);
+    auto snp = get<0>(tup);
+    auto gs = get<1>(tup);
+
+    // check whether SNP is included in gwasnps (used by LOCO)
+    if (process_gwasnps && gwasnps.count(snp) == 0)
+      continue;
+
+    // drop missing idv and plug mean values for missing geno
+    double x_total = 0.0; // sum genotype values to compute x_mean
+    size_t pos = 0;       // position in target vector
+    size_t n_miss = 0;
+    gsl_vector_set_zero(x_miss);
+    for (size_t i = 0; i < ni_total; ++i) {
+      // get the genotypes per individual and compute stats per SNP
+      if (indicator_idv[i] == 0) // skip individual
+        continue;
+
+      double geno = gs[i];
+      if (isnan(geno)) {
+        gsl_vector_set(x_miss, pos, 1.0);
+        n_miss++;
+      } else {
+        gsl_vector_set(x, pos, geno);
+        x_total += geno;
+      }
+      pos++;
+    }
+    enforce(pos == ni_test);
+
+    const double x_mean = x_total/(double)(ni_test - n_miss); // 0/0 (NaN) if every genotype is missing
+
+    // plug x_mean back into missing values
+    for (size_t i = 0; i < ni_test; ++i) {
+      if (gsl_vector_get(x_miss, i) == 1.0) {
+        gsl_vector_set(x, i, x_mean);
+      }
+    }
+
+    /* this is what below GxE does
+    for (size_t i = 0; i < ni_test; ++i) {
+      auto geno = gsl_vector_get(x, i);
+      if (std::isnan(geno)) {
+        gsl_vector_set(x, i, x_mean);
+        geno = x_mean;
+      }
+      if (x_mean > 1.0) {
+        gsl_vector_set(x, i, 2 - geno);
+      }
+    }
+    */
+    enforce(x->size == ni_test);
+
+    // copy genotype values for SNP into Xlarge cache
+    gsl_vector_view Xlarge_col = gsl_matrix_column(Xlarge, c % msize);
+    gsl_vector_safe_memcpy(&Xlarge_col.vector, x);
+    c++; // count SNPs going in
+
+    if (c % msize == 0) {
+      batch_compute(msize);
+    }
+  }
+
+  if (c % msize != 0) batch_compute(c % msize); // flush the last partial batch; skip an empty one
+  ProgressBar("Reading SNPs", num_snps - 1, num_snps - 1);
+  cout << endl;
+  checkpoint_nofile("end-lmm-analyze");
+
+  gsl_vector_safe_free(x);
+  gsl_vector_safe_free(x_miss);
+  gsl_vector_safe_free(Utx);
+  gsl_matrix_safe_free(Uab);
+  gsl_vector_free(ab); // unused
+
+  gsl_matrix_safe_free(Xlarge);
+  gsl_matrix_safe_free(UtXlarge);
+
+}
+
 void LMM::mdb_calc_gwa(const gsl_matrix *U, const gsl_vector *eval,
                        const gsl_matrix *UtW, const gsl_vector *Uty,
                        const gsl_matrix *W, const gsl_vector *y,
                        const set<string> gwasnps) {
   checkpoint("mdb-calc-gwa",file_geno);
-  const auto num_snps = indicator_snp.size();
-  enforce_msg(num_snps > 0,"Zero SNPs to process - data corrupt?");
+  // const auto num_snps = indicator_snp.size();
+  // enforce_msg(num_snps > 0,"Zero SNPs to process - data corrupt?");
   auto env = lmdb::env::create();
@@ -1912,7 +2117,7 @@ void LMM::mdb_calc_gwa(const gsl_matrix *U, const gsl_vector *eval,
     return std::make_tuple(snp,gs);
   };
-  LMM::Analyze(fetch_snp,U,eval,UtW,Uty,W,y,gwasnps);
+  LMM::mdb_analyze(fetch_snp,U,eval,UtW,Uty,W,y,gwasnps);
 }
diff --git a/src/lmm.h b/src/lmm.h
index cedbf38..c628baa 100644
--- a/src/lmm.h
+++ b/src/lmm.h
@@ -100,6 +100,11 @@ public:
                const gsl_matrix *UtW, const gsl_vector *Uty,
                const gsl_matrix *W, const gsl_vector *y,
                const set<string> gwasnps);
+  void mdb_analyze(std::function< SnpNameValues(size_t) >& fetch_snp,
+                   const gsl_matrix *U, const gsl_vector *eval,
+                   const gsl_matrix *UtW, const gsl_vector *Uty,
+                   const gsl_matrix *W, const gsl_vector *y,
+                   const set<string> gwasnps);
   void mdb_calc_gwa(const gsl_matrix *U, const gsl_vector *eval,
                     const gsl_matrix *UtW, const gsl_vector *Uty,
                     const gsl_matrix *W, const gsl_vector *y,
