about summary refs log tree commit diff
path: root/src/lmm.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/lmm.cpp')
-rw-r--r--src/lmm.cpp484
1 files changed, 459 insertions, 25 deletions
diff --git a/src/lmm.cpp b/src/lmm.cpp
index 87b9e1e..2b730f3 100644
--- a/src/lmm.cpp
+++ b/src/lmm.cpp
@@ -2,7 +2,7 @@
     Genome-wide Efficient Mixed Model Association (GEMMA)
     Copyright © 2011-2017, Xiang Zhou
     Copyright © 2017, Peter Carbonetto
-    Copyright © 2017-2022 Pjotr Prins
+    Copyright © 2017-2025 Pjotr Prins
 
     This program is free software: you can redistribute it and/or modify
     it under the terms of the GNU General Public License as published by
@@ -40,6 +40,9 @@
 #include "gsl/gsl_min.h"
 #include "gsl/gsl_roots.h"
 #include "gsl/gsl_vector.h"
+#include <lmdb++.h>
+// #include <lmdb.h>
+#include <sys/mman.h>
 
 #include "gzstream.h"
 #include "gemma.h"
@@ -55,6 +58,7 @@
 using namespace std;
 
 void LMM::CopyFromParam(PARAM &cPar) {
+  checkpoint_nofile("lmm-copy-from-param");
   a_mode = cPar.a_mode;
   d_pace = cPar.d_pace;
 
@@ -91,10 +95,12 @@ void LMM::CopyFromParam(PARAM &cPar) {
 }
 
 void LMM::CopyToParam(PARAM &cPar) {
+  checkpoint_nofile("lmm-copy-to-param");
   cPar.time_UtX = time_UtX;
   cPar.time_opt = time_opt;
-
-  cPar.ng_test = ng_test;
+  cPar.ns_total = ns_total;
+  cPar.ns_test  = ns_test;
+  cPar.ng_test  = ng_test; // number of markers tested
 
   return;
 }
@@ -105,10 +111,11 @@ void LMM::WriteFiles() {
 
   file_str = path_out + "/" + file_out;
   file_str += ".assoc.txt";
+  checkpoint("lmm-write-files",file_str);
 
   ofstream outfile(file_str.c_str(), ofstream::out);
   if (!outfile) {
-    cout << "error writing file: " << file_str.c_str() << endl;
+    cout << "error writing file: " << file_str << endl;
     return;
   }
 
@@ -148,7 +155,7 @@ void LMM::WriteFiles() {
     outfile << scientific << setprecision(6);
 
     if (a_mode != M_LMM2) {
-      outfile << st.beta << "\t";
+      outfile << scientific << st.beta << "\t";
       outfile << st.se << "\t";
     }
 
@@ -202,21 +209,29 @@ void LMM::WriteFiles() {
 
     common_header();
 
-    size_t t = 0;
-    for (size_t i = 0; i < snpInfo.size(); ++i) {
-      if (indicator_snp[i] == 0)
-        continue;
-      auto snp = snpInfo[i].rs_number;
-      if (process_gwasnps && setGWASnps.count(snp) == 0)
-        continue;
-      // cout << t << endl;
-      outfile << snpInfo[i].chr << "\t" << snpInfo[i].rs_number << "\t"
-              << snpInfo[i].base_position << "\t" << snpInfo[i].n_miss << "\t"
-              << snpInfo[i].a_minor << "\t" << snpInfo[i].a_major << "\t"
-              << fixed << setprecision(3) << snpInfo[i].maf << "\t";
-
-      sumstats(sumStat[t]);
-      t++;
+    if (snpInfo.size()) {
+      size_t t = 0;
+      for (size_t i = 0; i < snpInfo.size(); ++i) {
+        if (indicator_snp[i] == 0)
+          continue;
+        auto snp = snpInfo[i].rs_number;
+        if (process_gwasnps && setGWASnps.count(snp) == 0)
+          continue;
+        // cout << t << endl;
+        outfile << snpInfo[i].chr << "\t" << snpInfo[i].rs_number << "\t"
+                << snpInfo[i].base_position << "\t" << snpInfo[i].n_miss << "\t"
+                << snpInfo[i].a_minor << "\t" << snpInfo[i].a_major << "\t"
+                << fixed << setprecision(3) << snpInfo[i].maf << "\t";
+
+        sumstats(sumStat[t]);
+        t++;
+      }
+    }
+    else
+    {
+      for (auto &s : sumStat) {
+        sumstats(s);
+      }
     }
   }
 
@@ -1662,12 +1677,9 @@ void LMM::Analyze(std::function< SnpNameValues(size_t) >& fetch_snp,
                   const set<string> gwasnps) {
   clock_t time_start = clock();
 
-  write(W, "W");
-  write(y, "y");
+  checkpoint_nofile("start-lmm-analyze");
   // Subset/LOCO support
   bool process_gwasnps = gwasnps.size();
-  if (process_gwasnps)
-    debug_msg("Analyze subset of SNPs (LOCO)");
 
   // Calculate basic quantities.
   size_t n_index = (n_cvt + 2 + 1) * (n_cvt + 2) / 2;
@@ -1704,6 +1716,7 @@ void LMM::Analyze(std::function< SnpNameValues(size_t) >& fetch_snp,
   auto batch_compute = [&](size_t l) { // using a C++ closure
     // Compute SNPs in batch, note the computations are independent per SNP
     // debug_msg("enter batch_compute");
+    checkpoint_nofile("\nstart-batch_compute");
     gsl_matrix_view Xlarge_sub = gsl_matrix_submatrix(Xlarge, 0, 0, inds, l);
     gsl_matrix_view UtXlarge_sub =
         gsl_matrix_submatrix(UtXlarge, 0, 0, inds, l);
@@ -1763,6 +1776,7 @@ void LMM::Analyze(std::function< SnpNameValues(size_t) >& fetch_snp,
       sumStat.push_back(SNPs);
     }
     // debug_msg("exit batch_compute");
+    checkpoint_nofile("end-batch_compute");
   };
 
   const auto num_snps = indicator_snp.size();
@@ -1844,9 +1858,10 @@ void LMM::Analyze(std::function< SnpNameValues(size_t) >& fetch_snp,
   }
 
   batch_compute(c % msize);
-  ProgressBar("Reading SNPs", num_snps - 1, num_snps - 1);
+  ProgressBar("Reading SNPS", num_snps - 1, num_snps - 1);
   // cout << "Counted SNPs " << c << " sumStat " << sumStat.size() << endl;
   cout << endl;
+  checkpoint_nofile("end-lmm-analyze");
 
   gsl_vector_safe_free(x);
   gsl_vector_safe_free(x_miss);
@@ -1860,6 +1875,425 @@ void LMM::Analyze(std::function< SnpNameValues(size_t) >& fetch_snp,
 }
 
 /*
+  This is the mirror function of LMM::Analyze and AnalyzeBimbam, but uses mdb input instead.
+ */
+
+
+
+void LMM::mdb_analyze(std::function< SnpNameValues2(size_t) >& fetch_snp,
+                      const gsl_matrix *U, const gsl_vector *eval,
+                      const gsl_matrix *UtW, const gsl_vector *Uty,
+                      const gsl_matrix *W, const gsl_vector *y,
+                      size_t num_markers) {
+  vector<SUMSTAT2> sumstat2;
+  clock_t time_start = clock();
+
+  checkpoint_nofile("start-lmm-mdb-analyze");
+
+  // Calculate basic quantities.
+  size_t n_index = (n_cvt + 2 + 1) * (n_cvt + 2) / 2;
+
+  const size_t inds = U->size1;
+  enforce(inds == ni_test);
+  assert(inds > 0);
+
+  gsl_vector *x = gsl_vector_safe_alloc(inds); // #inds
+  gsl_vector *x_miss = gsl_vector_safe_alloc(inds);
+  assert(ni_test == U->size2);
+  assert(ni_test > 0);
+  assert(ni_total > 0);
+  assert(n_index > 0);
+  gsl_vector *Utx = gsl_vector_safe_alloc(ni_test);
+  gsl_matrix *Uab = gsl_matrix_safe_alloc(ni_test, n_index);
+  gsl_vector *ab = gsl_vector_safe_alloc(n_index);
+
+  const size_t msize = LMM_BATCH_SIZE;
+  gsl_matrix *Xlarge = gsl_matrix_safe_alloc(inds, msize);
+  gsl_matrix *UtXlarge = gsl_matrix_safe_alloc(inds, msize);
+  enforce_msg(Xlarge && UtXlarge, "Xlarge memory check"); // just to be sure
+  enforce(Xlarge->size1 == inds);
+  gsl_matrix_set_zero(Xlarge);
+  gsl_matrix_set_zero(Uab);
+  CalcUab(UtW, Uty, Uab);
+
+  // start reading genotypes and analyze
+  size_t c = 0;
+
+  /*
+    batch_compute(l) takes l x SNPs that have been loaded into Xlarge,
+    transforms them all at once using the eigenvector matrix U, then
+    loops through each transformed SNP to compute association
+    statistics (beta, standard errors, p-values) and stores results in
+    sumStat.
+  */
+  auto batch_compute = [&](size_t l, const Markers &markers) { // using a C++ closure
+    // Compute SNPs in batch, note the computations are independent per SNP
+    debug_msg("enter batch_compute");
+    assert(l>0);
+    assert(inds>0);
+    gsl_matrix_view Xlarge_sub = gsl_matrix_submatrix(Xlarge, 0, 0, inds, l);
+    gsl_matrix_view UtXlarge_sub =
+        gsl_matrix_submatrix(UtXlarge, 0, 0, inds, l);
+
+    time_start = clock();
+    // Transforms all l SNPs in the batch at once: UtXlarge = U^T × Xlarge
+    // This is much faster than doing l separate matrix-vector products
+    // U is the eigenvector matrix from the spectral decomposition of the kinship matrix
+    fast_dgemm("T", "N", 1.0, U, &Xlarge_sub.matrix, 0.0,
+                   &UtXlarge_sub.matrix);
+    time_UtX += (clock() - time_start) / (double(CLOCKS_PER_SEC) * 60.0);
+
+    gsl_matrix_set_zero(Xlarge);
+    for (size_t i = 0; i < l; i++) {
+      // for each snp batch item extract transformed genotype:
+      gsl_vector_view UtXlarge_col = gsl_matrix_column(UtXlarge, i);
+      gsl_vector_safe_memcpy(Utx, &UtXlarge_col.vector);
+
+      // Calculate design matrix components and compute sufficient statistics for the regression model
+      CalcUab(UtW, Uty, Utx, Uab);
+
+      time_start = clock();
+      FUNC_PARAM param1 = {false, ni_test, n_cvt, eval, Uab, ab, 0};
+
+      double lambda_mle = 0.0, lambda_remle = 0.0, beta = 0.0, se = 0.0, p_wald = 0.0;
+      double p_lrt = 0.0, p_score = 0.0;
+      double logl_H1 = 0.0;
+
+      // Run statistical tests based on analysis mode
+      // 3 is before 1.
+      if (a_mode == M_LMM3 || a_mode == M_LMM4 || a_mode == M_LMM9 ) {
+        CalcRLScore(l_mle_null, param1, beta, se, p_score);
+      }
+
+      // Computes Wald statistic for testing association
+      if (a_mode == M_LMM1 || a_mode == M_LMM4) {
+        // for univariate a_mode is 1
+        // Estimates variance component (lambda) via REML
+        CalcLambda('R', param1, l_min, l_max, n_region, lambda_remle, logl_H1);
+        CalcRLWald(lambda_remle, param1, beta, se, p_wald);
+      }
+
+      // Estimates variance component (lambda) via REML
+      // Likelihood Ratio Test (modes 2, 4, 9):
+      // Estimates variance component via MLE
+      // Compares log-likelihood under alternative vs null hypothesis
+      if (a_mode == M_LMM2 || a_mode == M_LMM4 || a_mode == M_LMM9) {
+        CalcLambda('L', param1, l_min, l_max, n_region, lambda_mle, logl_H1);
+        p_lrt = gsl_cdf_chisq_Q(2.0 * (logl_H1 - logl_mle_H0), 1);
+      }
+
+      time_opt += (clock() - time_start) / (double(CLOCKS_PER_SEC) * 60.0);
+
+      auto markerinfo = markers[i];
+      // Store summary data.
+      SUMSTAT2 st = {markerinfo, beta, se, lambda_remle, lambda_mle, p_wald, p_lrt, p_score, logl_H1};
+      sumstat2.push_back(st);
+    }
+  };
+
+  /*
+  const auto num_markers = indicator_snp.size();
+  enforce_msg(num_markers > 0,"Zero SNPs to process - data corrupt?");
+  if (num_markers < 50) {
+    cerr << num_markers << " SNPs" << endl;
+    warning_msg("very few SNPs processed");
+  }
+  */
+  const size_t progress_step = (num_markers/50>d_pace ? num_markers/50 : d_pace);
+  Markers markers;
+
+  assert(num_markers > 0);
+  for (size_t t = 0; t < num_markers; ++t) {
+  // for (size_t t = 0; t < 2; ++t) {
+    if (t % progress_step == 0 || t == (num_markers - 1)) {
+      ProgressBar("Reading markers", t, num_markers - 1);
+    }
+    // if (indicator_snp[t] == 0)
+    // continue;
+
+    auto tup = fetch_snp(t); // use the callback
+    auto state = get<0>(tup);
+    if (state == SKIP)
+      continue;
+    if (state == LAST)
+      break; // marker loop because of LOCO
+
+    auto markerinfo = get<1>(tup);
+    auto gs = get<2>(tup);
+
+    markers.push_back(markerinfo);
+
+    // drop missing idv and plug mean values for missing geno
+    double x_total = 0.0; // sum genotype values to compute x_mean
+    uint vpos = 0;        // position in target vector
+    uint n_miss = 0;      // count NA genotypes
+    gsl_vector_set_zero(x_miss);
+    for (size_t i = 0; i < ni_total; ++i) {
+      // get the genotypes per individual and compute stats per SNP
+      if (indicator_idv[i] == 0) // skip individual
+        continue;
+
+      double geno = gs[i];
+      if (isnan(geno)) {
+        gsl_vector_set(x_miss, vpos, 1.0);
+        n_miss++;
+      } else {
+        gsl_vector_set(x, vpos, geno);
+        x_total += geno;
+      }
+      vpos++;
+    }
+    enforce(vpos == ni_test);
+
+    const double x_mean = x_total/(double)(ni_test - n_miss);
+
+    // plug x_mean back into missing values
+    for (size_t i = 0; i < ni_test; ++i) {
+      if (gsl_vector_get(x_miss, i) == 1.0) {
+        gsl_vector_set(x, i, x_mean);
+      }
+    }
+
+    enforce(x->size == ni_test);
+
+    // copy genotype values for SNP into Xlarge cache
+    gsl_vector_view Xlarge_col = gsl_matrix_column(Xlarge, c % msize);
+    gsl_vector_safe_memcpy(&Xlarge_col.vector, x);
+    c++; // count markers going in
+
+    if (c % msize == 0) {
+      batch_compute(msize,markers);
+      markers.clear();
+      markers.reserve(msize);
+    }
+  }
+
+  if (c % msize)
+    batch_compute(c % msize,markers);
+  ProgressBar("Reading markers", num_markers - 1, num_markers - 1);
+  cout << endl;
+  cout << "Counted markers " << c << " sumStat " << sumstat2.size() << endl;
+  checkpoint_nofile("end-lmm-mdb-analyze");
+
+  string file_str;
+  debug_msg("LMM::WriteFiles");
+  file_str = path_out + "/" + file_out;
+  file_str += ".assoc.txt";
+  checkpoint("lmm-write-files",file_str);
+
+  ofstream outfile(file_str.c_str(), ofstream::out);
+  if (!outfile) {
+    cout << "error writing file: " << file_str << endl;
+    return;
+  }
+
+  auto sumstats = [&] (SUMSTAT2 st) {
+    outfile << scientific << setprecision(6);
+    auto m = st.markerinfo;
+    auto name = m.name;
+    auto chr  = m.chr;
+    string chr_s;
+    if (chr == CHR_X)
+      chr_s = "X";
+    else if (chr == CHR_Y)
+      chr_s = "Y";
+    else if (chr == CHR_M)
+      chr_s = "M";
+    else
+      chr_s = to_string(chr);
+
+    outfile << chr_s << "\t";
+    outfile << name << "\t";
+    outfile << m.pos << "\t";
+    outfile << m.n_miss << "\t-\t-\t"; // n_miss column + allele1 + allele0
+    outfile << fixed << setprecision(3) << m.maf << "\t";
+    outfile << scientific << setprecision(6);
+    if (a_mode != M_LMM2) {
+      outfile << st.beta << "\t";
+      outfile << st.se << "\t";
+    }
+
+    if (a_mode != M_LMM3 && a_mode != M_LMM9)
+      outfile << st.logl_H1 << "\t";
+
+    switch(a_mode) {
+    case M_LMM1:
+      outfile << st.lambda_remle << "\t"
+              << st.p_wald << endl;
+      break;
+    case M_LMM2:
+    case M_LMM9:
+      outfile << st.lambda_mle << "\t"
+              << st.p_lrt << endl;
+      break;
+    case M_LMM3:
+      outfile << st.p_score << endl;
+      break;
+    case M_LMM4:
+      outfile << st.lambda_remle << "\t"
+              << st.lambda_mle << "\t"
+              << st.p_wald << "\t"
+              << st.p_lrt << "\t"
+              << st.p_score << endl;
+      break;
+    }
+  };
+
+  for (auto &s : sumstat2) {
+    sumstats(s);
+  }
+
+  gsl_vector_safe_free(x);
+  gsl_vector_safe_free(x_miss);
+  gsl_vector_safe_free(Utx);
+  gsl_matrix_safe_free(Uab);
+  gsl_vector_free(ab); // unused
+
+  gsl_matrix_safe_free(Xlarge);
+  gsl_matrix_safe_free(UtXlarge);
+
+}
+
+
+void LMM::mdb_calc_gwa(const gsl_matrix *U, const gsl_vector *eval,
+                          const gsl_matrix *UtW, const gsl_vector *Uty,
+                       const gsl_matrix *W, const gsl_vector *y, const string loco) {
+  checkpoint("mdb-calc-gwa",file_geno);
+  bool is_loco = !loco.empty();
+
+  // Convert loco string to what we use in the chrpos index
+  uint8_t loco_chr;
+  if (is_loco) {
+    if (loco == "X") {
+      loco_chr = CHR_X;
+    } else if (loco == "Y") {
+      loco_chr = CHR_Y;
+    } else if (loco == "M") {
+      loco_chr = CHR_M;
+    } else {
+      try {
+        loco_chr = static_cast<uint8_t>(stoi(loco));
+      } catch (...) {
+        loco_chr = 0;
+      }
+    }
+  }
+
+  auto env = lmdb::env::create();
+  env.set_mapsize(1UL * 1024UL * 1024UL * 1024UL * 1024UL);
+  env.set_max_dbs(10);
+  env.open(file_geno.c_str(), MDB_RDONLY | MDB_NOSUBDIR, 0664);
+  // Get mmap info using lmdb++ wrapper
+  MDB_envinfo info;
+  mdb_env_info(env.handle(), &info);
+  // Linux kernel aggressive readahead hints
+  madvise(info.me_mapaddr, info.me_mapsize, MADV_SEQUENTIAL);
+  madvise(info.me_mapaddr, info.me_mapsize, MADV_WILLNEED);
+
+  std::cout << "## LMDB opened with optimized readahead; map size = " << (info.me_mapsize / 1024 / 1024) << " MB" << std::endl;
+
+  auto rtxn = lmdb::txn::begin(env, nullptr, MDB_RDONLY);
+  auto geno_mdb = lmdb::dbi::open(rtxn, "geno");
+
+  auto marker_mdb = lmdb::dbi::open(rtxn, "marker");
+  auto info_mdb = lmdb::dbi::open(rtxn, "info");
+  string_view info_key,info_value;
+  info_mdb.get(rtxn,"format",info_value);
+  auto format = string(info_value);
+
+  MDB_stat stat;
+  mdb_stat(rtxn, geno_mdb, &stat);
+  auto num_markers = stat.ms_entries;
+
+  auto mdb_fetch = MDB_FIRST;
+
+  auto cursor = lmdb::cursor::open(rtxn, geno_mdb);
+  cout << "## number of total individuals = " << ni_total << endl;
+  cout << "## number of analyzed individuals = " << ni_test << endl;
+  cout << "## number of analyzed SNPs/var = " << num_markers << endl;
+
+  std::function<SnpNameValues2(size_t)>  fetch_snp = [&](size_t num) {
+
+    string_view key,value;
+
+    auto mdb_success = cursor.get(key, value, mdb_fetch);
+    mdb_fetch = MDB_NEXT;
+
+    // uint8_t chr;
+    vector<double> gs;
+    MarkerInfo markerinfo;
+
+    if (mdb_success) {
+      size_t size = 0;
+      // ---- Depending on the format we get different buffers - currently float and byte are supported:
+      if (format == "Gb") {
+        size_t num_bytes = value.size() / sizeof(uint8_t);
+        assert(num_bytes == ni_total);
+        size = num_bytes;
+        const uint8_t* gs_bbuf = reinterpret_cast<const uint8_t*>(value.data());
+        gs.reserve(size);
+        for (size_t i = 0; i < size; ++i) {
+          double g = static_cast<double>(gs_bbuf[i])/127.0;
+          gs.push_back(g);
+        }
+      } else {
+        size_t num_floats = value.size() / sizeof(float);
+        assert(num_floats == ni_total);
+        size = num_floats;
+        const float* gs_fbuf = reinterpret_cast<const float*>(value.data());
+        gs.reserve(size);
+        for (size_t i = 0; i < size; ++i) {
+          gs.push_back(static_cast<double>(gs_fbuf[i]));
+        }
+      }
+      // Start unpacking key chr+pos
+      if (key.size() != 10) {
+        cerr << "key.size=" << key.size() << endl;
+        throw std::runtime_error("Invalid key size");
+      }
+      // "S>L>L>"
+      const uint8_t* data = reinterpret_cast<const uint8_t*>(key.data());
+      auto chr = static_cast<uint8_t>(data[1]);
+      // Extract big-endian uint32
+      // uint32_t rest = static_cast<uint32_t>(data[2]);
+      uint32_t pos =  (data[2] << 24) | (data[3] << 16) |
+        (data[4] << 8) | data[5];
+
+      uint32_t num = (data[6] << 24) | (data[7] << 16) |
+        (data[8] << 8) | data[9];
+
+      // printf("%#02x %#02x\n", chr, loco_chr);
+
+      if (is_loco && loco_chr != chr) {
+        if (chr > loco_chr)
+            return make_tuple(LAST, MarkerInfo { .name="", .chr=chr, .pos=pos } , gs);
+          else
+            return make_tuple(SKIP, MarkerInfo { .name="", .chr=chr, .pos=pos } , gs);
+      }
+
+      string_view value2;
+      marker_mdb.get(rtxn,key,value2);
+      auto marker = string(value2);
+      // 1       rs13476251      174792257
+      // cout << static_cast<int>(chr) << ":" << pos2 << " line " << rest2 << ":" << marker << endl ;
+
+      // compute maf and n_miss (NAs)
+      size_t n_miss = 0; // count NAs: FIXME
+      double maf = compute_maf(ni_total, ni_test, n_miss, gs.data(), indicator_idv);
+
+      markerinfo = MarkerInfo { .name=marker,.chr=chr,.pos=pos,.line_no=num,.n_miss=n_miss,.maf=maf };
+
+      // cout << "!!!!" << size << marker << ": af" << maf << " " << gs[0] << "," << gs[1] << "," << gs[2] << "," << gs[3] << endl;
+    }
+    return make_tuple(COMPUTE, markerinfo, gs);
+  };
+  LMM::mdb_analyze(fetch_snp,U,eval,UtW,Uty,W,y,num_markers);
+
+  ns_total = ns_test = num_markers; // track global number of snps in original gemma (goes to cPar)
+}
+
+
+/*
 
 Looking at the `LMM::AnalyzeBimbam` function, here are the parameters that get modified: