about summary refs log tree commit diff
path: root/src/fastblas.cpp
diff options
context:
space:
mode:
author Pjotr Prins 2017-10-13 15:23:00 +0000
committer Pjotr Prins 2017-10-13 15:27:24 +0000
commit fdb48997ee3ed2b326a92f8e0cc7f72a4b38d8c8 (patch)
tree 5f62e06dbacdec2d4ee60da9112508615d42fc1f /src/fastblas.cpp
parent 7eca3c49b7790007a4190b73209cab9ffb2bb117 (diff)
download pangemma-fdb48997ee3ed2b326a92f8e0cc7f72a4b38d8c8.tar.gz
Refactored debug settings
Replaced eigenlib_dgemm with fast_dgemm - 10-30% speed gain for GEMMA
Diffstat (limited to 'src/fastblas.cpp')
-rw-r--r-- src/fastblas.cpp 64
1 files changed, 34 insertions, 30 deletions
diff --git a/src/fastblas.cpp b/src/fastblas.cpp
index b25d964..0a4eba3 100644
--- a/src/fastblas.cpp
+++ b/src/fastblas.cpp
@@ -44,14 +44,16 @@ gsl_matrix *fast_copy(gsl_matrix *m, const double *mem) {
       }
     }
   } else { // faster goes by row
+    auto v = gsl_vector_calloc(cols);
+    enforce(v); // just to be sure
     for (auto r=0; r<rows; r++) {
-      auto v = gsl_vector_calloc(cols);
       assert(v->size == cols);
       assert(v->block->size == cols);
       assert(v->stride == 1);
       memcpy(v->block->data,&mem[r*cols],cols*sizeof(double));
       gsl_matrix_set_row(m,r,v);
     }
+    gsl_vector_free(v);
   }
   return m;
 }
@@ -160,28 +162,25 @@ static void fast_cblas_dgemm(const char *TransA, const char *TransB, const doubl
   // C++ is row-major
   auto transA = (*TransA == 'N' || *TransA == 'n' ? CblasNoTrans : CblasTrans);
   auto transB = (*TransB == 'N' || *TransB == 'n' ? CblasNoTrans : CblasTrans);
-  // A(m x k) * B(k x n) = C(m x n))
-  auto rowsA = A->size1;
-  auto colsA = A->size2;
-  blasint M = A->size1;
-  blasint K = B->size1;
-  assert(K == colsA);
-  blasint N = B->size2;
-  // cout << M << "," << N "," << K << endl;
-  // Layout = CblasRowMajor: Trans: K , NoTrans M
-  blasint lda = (transA==CblasNoTrans ? K : M );
-  blasint ldb = (transB==CblasNoTrans ? N : K );
-  blasint ldc = N;
-
-  fast_cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha,
-              /* A */ A->data,
-              /* lda */ lda,
-              /* B */ B->data,
-              /* ldb */ ldb,
-              /* beta */ beta,
-              /* C */ C->data, ldc);
+  const size_t M = C->size1;
+  const size_t N = C->size2;
+  const size_t MA = (transA == CblasNoTrans) ? A->size1 : A->size2;
+  const size_t NA = (transA == CblasNoTrans) ? A->size2 : A->size1;
+  const size_t MB = (transB == CblasNoTrans) ? B->size1 : B->size2;
+  const size_t NB = (transB == CblasNoTrans) ? B->size2 : B->size1;
+
+  if (M == MA && N == NB && NA == MB) {  /* [MxN] = [MAxNA][MBxNB] */
+
+    cblas_dgemm (CblasRowMajor, transA, transB, M, N,NA,
+                 alpha, A->data, A->tda, B->data, B->tda, beta,
+                 C->data, C->tda);
+
+  } else {
+    throw invalid_argument("Range error in dgemm");
+  }
 }
 
+
 /*
    Use the fasted/supported way to call BLAS dgemm
 */
@@ -189,14 +188,19 @@ static void fast_cblas_dgemm(const char *TransA, const char *TransB, const doubl
 void fast_dgemm(const char *TransA, const char *TransB, const double alpha,
                 const gsl_matrix *A, const gsl_matrix *B, const double beta,
                 gsl_matrix *C) {
-  fast_cblas_dgemm(TransA,TransB,alpha,A,B,beta,C);
-
-  #ifndef NDEBUG
-  if (is_strict_mode() && !is_no_check_mode()) {
-    // ---- validate with original implementation
-    gsl_matrix *C1 = gsl_matrix_alloc(C->size1,C->size2);
-    eigenlib_dgemm(TransA,TransB,alpha,A,B,beta,C1);
-    enforce_msg(gsl_matrix_equal(C,C1),"dgemm outcomes are not equal for fast & eigenlib");
+  if (is_legacy_mode()) {
+    eigenlib_dgemm(TransA,TransB,alpha,A,B,beta,C);
+  } else {
+    fast_cblas_dgemm(TransA,TransB,alpha,A,B,beta,C);
+
+    #ifndef NDEBUG
+    if (is_check_mode()) {
+      // ---- validate with original implementation
+      gsl_matrix *C1 = gsl_matrix_alloc(C->size1,C->size2);
+      eigenlib_dgemm(TransA,TransB,alpha,A,B,beta,C1);
+      enforce_msg(gsl_matrix_equal(C,C1),"dgemm outcomes are not equal for fast & eigenlib");
+      gsl_matrix_free(C1);
+    }
+    #endif
   }
-  #endif
 }