about summary refs log tree commit diff
diff options
context:
space:
mode:
authorPjotr Prins2017-10-13 13:07:40 +0000
committerPjotr Prins2017-10-13 15:27:24 +0000
commita5cee35c058b48f725a9eceabc205eb0cd1cbb07 (patch)
tree47c3d1f0b332490c9919033c7219f44010b2b886
parenta610dd723a233aed1abe31aa32e3137b23b5f983 (diff)
downloadpangemma-a5cee35c058b48f725a9eceabc205eb0cd1cbb07.tar.gz
Replacing first dgemm - tests fail
-rw-r--r--src/fastblas.cpp34
-rw-r--r--src/fastblas.h15
-rw-r--r--src/lmm.cpp3
-rwxr-xr-xtest/dev_test_suite.sh7
4 files changed, 18 insertions, 41 deletions
diff --git a/src/fastblas.cpp b/src/fastblas.cpp
index 38ca326..b25d964 100644
--- a/src/fastblas.cpp
+++ b/src/fastblas.cpp
@@ -24,6 +24,9 @@
 #include "debug.h"
 #include "fastblas.h"
 #include "mathfunc.h"
+#ifndef NDEBUG
+#include "eigenlib.h"
+#endif
 
 using namespace std;
 
@@ -169,28 +172,6 @@ static void fast_cblas_dgemm(const char *TransA, const char *TransB, const doubl
   blasint lda = (transA==CblasNoTrans ? K : M );
   blasint ldb = (transB==CblasNoTrans ? N : K );
   blasint ldc = N;
-  cout << rowsA << endl;
-  assert(transA == CblasNoTrans);
-  assert(transB == CblasNoTrans);
-  assert(rowsA == 2000);
-  assert(colsA == 200);
-  assert(lda == K);
-  assert(ldb == N);
-  assert(ldc == N);
-  assert(A->size2 == A->tda);
-  assert(B->size2 == B->tda);
-  assert(C->size2 == C->tda);
-
-  // cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
-  //             m, n, k, alpha, A, k, B, n, beta, C, n);
-  // m = 2000, k = 200, n = 1000;
-  assert(M==2000);
-  assert(K==200);
-  assert(N==1000);
-
-  auto k = K;
-  auto m = M;
-  auto n = N;
 
   fast_cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha,
               /* A */ A->data,
@@ -209,4 +190,13 @@ void fast_dgemm(const char *TransA, const char *TransB, const double alpha,
                 const gsl_matrix *A, const gsl_matrix *B, const double beta,
                 gsl_matrix *C) {
   fast_cblas_dgemm(TransA,TransB,alpha,A,B,beta,C);
+
+  #ifndef NDEBUG
+  if (is_strict_mode() && !is_no_check_mode()) {
+    // ---- validate with original implementation
+    gsl_matrix *C1 = gsl_matrix_alloc(C->size1,C->size2);
+    eigenlib_dgemm(TransA,TransB,alpha,A,B,beta,C1);
+    enforce_msg(gsl_matrix_equal(C,C1),"dgemm outcomes are not equal for fast & eigenlib");
+  }
+  #endif
 }
diff --git a/src/fastblas.h b/src/fastblas.h
index 3c28729..d0f5c14 100644
--- a/src/fastblas.h
+++ b/src/fastblas.h
@@ -7,21 +7,6 @@
 
 gsl_matrix *fast_copy(gsl_matrix *m, const double *mem);
 
-void fast_cblas_dgemm(OPENBLAS_CONST enum CBLAS_ORDER Order,
-                      OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA,
-                      OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB,
-                      OPENBLAS_CONST blasint M,
-                      OPENBLAS_CONST blasint N,
-                      OPENBLAS_CONST blasint K,
-                      OPENBLAS_CONST double alpha,
-                      OPENBLAS_CONST double *A,
-                      OPENBLAS_CONST blasint lda,
-                      OPENBLAS_CONST double *B,
-                      OPENBLAS_CONST blasint ldb,
-                      OPENBLAS_CONST double beta,
-                      double *C,
-                      OPENBLAS_CONST blasint ldc);
-
 void fast_dgemm(const char *TransA, const char *TransB, const double alpha,
                 const gsl_matrix *A, const gsl_matrix *B, const double beta,
                 gsl_matrix *C);
diff --git a/src/lmm.cpp b/src/lmm.cpp
index 71aa184..7b32330 100644
--- a/src/lmm.cpp
+++ b/src/lmm.cpp
@@ -42,6 +42,7 @@
 
 #include "gzstream.h"
 #include "io.h"
+#include "fastblas.h"
 #include "lapack.h"
 #include "lmm.h"
 
@@ -1315,7 +1316,7 @@ void LMM::Analyze(std::function< SnpNameValues(size_t) >& fetch_snp,
         gsl_matrix_submatrix(UtXlarge, 0, 0, inds, l);
 
     time_start = clock();
-    eigenlib_dgemm("T", "N", 1.0, U, &Xlarge_sub.matrix, 0.0,
+    fast_dgemm("T", "N", 1.0, U, &Xlarge_sub.matrix, 0.0,
                    &UtXlarge_sub.matrix);
     time_UtX += (clock() - time_start) / (double(CLOCKS_PER_SEC) * 60.0);
 
diff --git a/test/dev_test_suite.sh b/test/dev_test_suite.sh
index 136ba76..0ccc24c 100755
--- a/test/dev_test_suite.sh
+++ b/test/dev_test_suite.sh
@@ -11,7 +11,8 @@ testBXDStandardRelatednessMatrixKSingularError() {
            -c ../example/BXD_covariates.txt \
            -a ../example/BXD_snps.txt \
            -gk \
-           -debug -o $outn
+           -debug -strict \
+           -o $outn
     assertEquals 22 $? # should show singular error
 }
 
@@ -23,7 +24,7 @@ testBXDStandardRelatednessMatrixK() {
            -c ../example/BXD_covariates2.txt \
            -a ../example/BXD_snps.txt \
            -gk \
-           -debug \
+           -debug -strict \
            -o $outn
     assertEquals 0 $?
     outfn=output/$outn.cXX.txt
@@ -39,7 +40,7 @@ testBXDLMMLikelihoodRatio() {
            -a ../example/BXD_snps.txt \
            -k ./output/BXD.cXX.txt \
            -lmm 2 -maf 0.1 \
-           -debug \
+           -debug -strict \
            -o $outn
     assertEquals 0 $?