about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/lmm.cpp211
-rw-r--r--src/lmm.h5
2 files changed, 213 insertions, 3 deletions
diff --git a/src/lmm.cpp b/src/lmm.cpp
index 2ac9835..ee7ba42 100644
--- a/src/lmm.cpp
+++ b/src/lmm.cpp
@@ -1867,14 +1867,219 @@ void LMM::Analyze(std::function< SnpNameValues(size_t) >& fetch_snp,
   This is the mirror function of below AnalyzeBimbam, but uses mdb input instead.
  */
 
+
+void LMM::mdb_analyze(std::function< SnpNameValues(size_t) >& fetch_snp,
+                  const gsl_matrix *U, const gsl_vector *eval,
+                  const gsl_matrix *UtW, const gsl_vector *Uty,
+                  const gsl_matrix *W, const gsl_vector *y,
+                  const set<string> gwasnps) {
+  clock_t time_start = clock();
+
+  checkpoint_nofile("start-lmm-analyze");
+  // Subset/LOCO support
+  bool process_gwasnps = gwasnps.size();
+
+  // Calculate basic quantities.
+  size_t n_index = (n_cvt + 2 + 1) * (n_cvt + 2) / 2;
+
+  const size_t inds = U->size1;
+  enforce(inds == ni_test);
+  gsl_vector *x = gsl_vector_safe_alloc(inds); // #inds
+  gsl_vector *x_miss = gsl_vector_safe_alloc(inds);
+  gsl_vector *Utx = gsl_vector_safe_alloc(U->size2);
+  gsl_matrix *Uab = gsl_matrix_safe_alloc(U->size2, n_index);
+  gsl_vector *ab = gsl_vector_safe_alloc(n_index);
+
+  // Create a large matrix with LMM_BATCH_SIZE columns for batched processing
+  // const size_t msize=(process_gwasnps ? 1 : LMM_BATCH_SIZE);
+  const size_t msize = LMM_BATCH_SIZE;
+  gsl_matrix *Xlarge = gsl_matrix_safe_alloc(inds, msize);
+  gsl_matrix *UtXlarge = gsl_matrix_safe_alloc(inds, msize);
+  enforce_msg(Xlarge && UtXlarge, "Xlarge memory check"); // just to be sure
+  enforce(Xlarge->size1 == inds);
+  gsl_matrix_set_zero(Xlarge);
+  gsl_matrix_set_zero(Uab);
+  CalcUab(UtW, Uty, Uab);
+
+  // start reading genotypes and analyze
+  size_t c = 0;
+
+  /*
+    batch_compute(l) takes l SNPs that have been loaded into Xlarge,
+    transforms them all at once using the eigenvector matrix U, then
+    loops through each transformed SNP to compute association
+    statistics (beta, standard errors, p-values) and stores results in
+    sumStat.
+  */
+  auto batch_compute = [&](size_t l) { // using a C++ closure
+    // Compute SNPs in batch, note the computations are independent per SNP
+    // debug_msg("enter batch_compute");
+    checkpoint_nofile("\nstart-batch_compute");
+    gsl_matrix_view Xlarge_sub = gsl_matrix_submatrix(Xlarge, 0, 0, inds, l);
+    gsl_matrix_view UtXlarge_sub =
+        gsl_matrix_submatrix(UtXlarge, 0, 0, inds, l);
+
+    time_start = clock();
+    // Transforms all l SNPs in the batch at once: UtXlarge = U^T × Xlarge
+    // This is much faster than doing l separate matrix-vector products
+    // U is the eigenvector matrix from the spectral decomposition of the kinship matrix
+    fast_dgemm("T", "N", 1.0, U, &Xlarge_sub.matrix, 0.0,
+                   &UtXlarge_sub.matrix);
+    time_UtX += (clock() - time_start) / (double(CLOCKS_PER_SEC) * 60.0);
+
+    gsl_matrix_set_zero(Xlarge);
+    for (size_t i = 0; i < l; i++) {
+      // for each snp batch item extract transformed genotype:
+      gsl_vector_view UtXlarge_col = gsl_matrix_column(UtXlarge, i);
+      gsl_vector_safe_memcpy(Utx, &UtXlarge_col.vector);
+
+      // Calculate design matrix components and compute sufficient statistics for the regression model
+      CalcUab(UtW, Uty, Utx, Uab);
+
+      time_start = clock();
+      FUNC_PARAM param1 = {false, ni_test, n_cvt, eval, Uab, ab, 0};
+
+      double lambda_mle = 0.0, lambda_remle = 0.0, beta = 0.0, se = 0.0, p_wald = 0.0;
+      double p_lrt = 0.0, p_score = 0.0;
+      double logl_H1 = 0.0;
+
+      // Run statistical tests based on analysis mode
+      // 3 is before 1.
+      if (a_mode == M_LMM3 || a_mode == M_LMM4 || a_mode == M_LMM9 ) {
+        CalcRLScore(l_mle_null, param1, beta, se, p_score);
+      }
+
+      // Computes Wald statistic for testing association
+      if (a_mode == M_LMM1 || a_mode == M_LMM4) {
+        // for univariate a_mode is 1
+        // Estimates variance component (lambda) via REML
+        CalcLambda('R', param1, l_min, l_max, n_region, lambda_remle, logl_H1);
+        CalcRLWald(lambda_remle, param1, beta, se, p_wald);
+      }
+
+      // Estimates variance component (lambda) via REML
+      // Likelihood Ratio Test (modes 2, 4, 9):
+      // Estimates variance component via MLE
+      // Compares log-likelihood under alternative vs null hypothesis
+      if (a_mode == M_LMM2 || a_mode == M_LMM4 || a_mode == M_LMM9) {
+        CalcLambda('L', param1, l_min, l_max, n_region, lambda_mle, logl_H1);
+        p_lrt = gsl_cdf_chisq_Q(2.0 * (logl_H1 - logl_mle_H0), 1);
+      }
+
+      time_opt += (clock() - time_start) / (double(CLOCKS_PER_SEC) * 60.0);
+
+      // Store summary data.
+      SUMSTAT SNPs = {beta,   se,    lambda_remle, lambda_mle,
+                      p_wald, p_lrt, p_score, logl_H1};
+      sumStat.push_back(SNPs);
+    }
+    // debug_msg("exit batch_compute");
+    checkpoint_nofile("end-batch_compute");
+  };
+
+  const auto num_snps = indicator_snp.size();
+  enforce_msg(num_snps > 0,"Zero SNPs to process - data corrupt?");
+  if (num_snps < 50) {
+    cerr << num_snps << " SNPs" << endl;
+    warning_msg("very few SNPs processed");
+  }
+  const size_t progress_step = (num_snps/50>d_pace ? num_snps/50 : d_pace);
+
+  for (size_t t = 0; t < num_snps; ++t) {
+    if (t % progress_step == 0 || t == (num_snps - 1)) {
+      ProgressBar("Reading SNPs", t, num_snps - 1);
+    }
+    if (indicator_snp[t] == 0)
+      continue;
+
+    auto tup = fetch_snp(t);
+    auto snp = get<0>(tup);
+    auto gs = get<1>(tup);
+
+    // check whether SNP is included in gwasnps (used by LOCO)
+    if (process_gwasnps && gwasnps.count(snp) == 0)
+      continue;
+
+    // drop missing idv and plug mean values for missing geno
+    double x_total = 0.0; // sum genotype values to compute x_mean
+    uint pos = 0;         // position in target vector
+    uint n_miss = 0;
+    gsl_vector_set_zero(x_miss);
+    for (size_t i = 0; i < ni_total; ++i) {
+      // get the genotypes per individual and compute stats per SNP
+      if (indicator_idv[i] == 0) // skip individual
+        continue;
+
+      double geno = gs[i];
+      if (isnan(geno)) {
+        gsl_vector_set(x_miss, pos, 1.0);
+        n_miss++;
+      } else {
+        gsl_vector_set(x, pos, geno);
+        x_total += geno;
+      }
+      pos++;
+    }
+    enforce(pos == ni_test);
+
+    const double x_mean = x_total/(double)(ni_test - n_miss);
+
+    // plug x_mean back into missing values
+    for (size_t i = 0; i < ni_test; ++i) {
+      if (gsl_vector_get(x_miss, i) == 1.0) {
+        gsl_vector_set(x, i, x_mean);
+      }
+    }
+
+    /* this is what below GxE does
+    for (size_t i = 0; i < ni_test; ++i) {
+      auto geno = gsl_vector_get(x, i);
+      if (std::isnan(geno)) {
+        gsl_vector_set(x, i, x_mean);
+        geno = x_mean;
+      }
+      if (x_mean > 1.0) {
+        gsl_vector_set(x, i, 2 - geno);
+      }
+    }
+    */
+    enforce(x->size == ni_test);
+
+    // copy genotype values for SNP into Xlarge cache
+    gsl_vector_view Xlarge_col = gsl_matrix_column(Xlarge, c % msize);
+    gsl_vector_safe_memcpy(&Xlarge_col.vector, x);
+    c++; // count SNPs going in
+
+    if (c % msize == 0) {
+      batch_compute(msize);
+    }
+  }
+
+  batch_compute(c % msize);
+  ProgressBar("Reading SNPS", num_snps - 1, num_snps - 1);
+  // cout << "Counted SNPs " << c << " sumStat " << sumStat.size() << endl;
+  cout << endl;
+  checkpoint_nofile("end-lmm-analyze");
+
+  gsl_vector_safe_free(x);
+  gsl_vector_safe_free(x_miss);
+  gsl_vector_safe_free(Utx);
+  gsl_matrix_safe_free(Uab);
+  gsl_vector_free(ab); // unused
+
+  gsl_matrix_safe_free(Xlarge);
+  gsl_matrix_safe_free(UtXlarge);
+
+}
+
 void LMM::mdb_calc_gwa(const gsl_matrix *U, const gsl_vector *eval,
                           const gsl_matrix *UtW, const gsl_vector *Uty,
                           const gsl_matrix *W, const gsl_vector *y,
                           const set<string> gwasnps) {
   checkpoint("mdb-calc-gwa",file_geno);
 
-  const auto num_snps = indicator_snp.size();
-  enforce_msg(num_snps > 0,"Zero SNPs to process - data corrupt?");
+  // const auto num_snps = indicator_snp.size();
+  // enforce_msg(num_snps > 0,"Zero SNPs to process - data corrupt?");
 
   auto env = lmdb::env::create();
 
@@ -1912,7 +2117,7 @@ void LMM::mdb_calc_gwa(const gsl_matrix *U, const gsl_vector *eval,
     return std::make_tuple(snp,gs);
   };
 
-  LMM::Analyze(fetch_snp,U,eval,UtW,Uty,W,y,gwasnps);
+  LMM::mdb_analyze(fetch_snp,U,eval,UtW,Uty,W,y,gwasnps);
 }
 
 
diff --git a/src/lmm.h b/src/lmm.h
index cedbf38..c628baa 100644
--- a/src/lmm.h
+++ b/src/lmm.h
@@ -100,6 +100,11 @@ public:
                const gsl_matrix *UtW, const gsl_vector *Uty,
                const gsl_matrix *W, const gsl_vector *y,
                const set<string> gwasnps);
+  void mdb_analyze(std::function< SnpNameValues(size_t) >& fetch_snp,
+               const gsl_matrix *U, const gsl_vector *eval,
+               const gsl_matrix *UtW, const gsl_vector *Uty,
+               const gsl_matrix *W, const gsl_vector *y,
+               const set<string> gwasnps);
   void mdb_calc_gwa(const gsl_matrix *U, const gsl_vector *eval,
                      const gsl_matrix *UtW, const gsl_vector *Uty,
                      const gsl_matrix *W, const gsl_vector *y,