aboutsummaryrefslogtreecommitdiff
path: root/src/fastblas.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/fastblas.cpp')
-rw-r--r--src/fastblas.cpp64
1 files changed, 34 insertions, 30 deletions
diff --git a/src/fastblas.cpp b/src/fastblas.cpp
index b25d964..0a4eba3 100644
--- a/src/fastblas.cpp
+++ b/src/fastblas.cpp
@@ -44,14 +44,16 @@ gsl_matrix *fast_copy(gsl_matrix *m, const double *mem) {
}
}
} else { // faster goes by row
+ auto v = gsl_vector_calloc(cols);
+ enforce(v); // just to be sure
for (auto r=0; r<rows; r++) {
- auto v = gsl_vector_calloc(cols);
assert(v->size == cols);
assert(v->block->size == cols);
assert(v->stride == 1);
memcpy(v->block->data,&mem[r*cols],cols*sizeof(double));
gsl_matrix_set_row(m,r,v);
}
+ gsl_vector_free(v);
}
return m;
}
@@ -160,28 +162,25 @@ static void fast_cblas_dgemm(const char *TransA, const char *TransB, const doubl
// C++ is row-major
auto transA = (*TransA == 'N' || *TransA == 'n' ? CblasNoTrans : CblasTrans);
auto transB = (*TransB == 'N' || *TransB == 'n' ? CblasNoTrans : CblasTrans);
- // A(m x k) * B(k x n) = C(m x n))
- auto rowsA = A->size1;
- auto colsA = A->size2;
- blasint M = A->size1;
- blasint K = B->size1;
- assert(K == colsA);
- blasint N = B->size2;
- // cout << M << "," << N "," << K << endl;
- // Layout = CblasRowMajor: Trans: K , NoTrans M
- blasint lda = (transA==CblasNoTrans ? K : M );
- blasint ldb = (transB==CblasNoTrans ? N : K );
- blasint ldc = N;
-
- fast_cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha,
- /* A */ A->data,
- /* lda */ lda,
- /* B */ B->data,
- /* ldb */ ldb,
- /* beta */ beta,
- /* C */ C->data, ldc);
+ const size_t M = C->size1;
+ const size_t N = C->size2;
+ const size_t MA = (transA == CblasNoTrans) ? A->size1 : A->size2;
+ const size_t NA = (transA == CblasNoTrans) ? A->size2 : A->size1;
+ const size_t MB = (transB == CblasNoTrans) ? B->size1 : B->size2;
+ const size_t NB = (transB == CblasNoTrans) ? B->size2 : B->size1;
+
+ if (M == MA && N == NB && NA == MB) { /* [MxN] = [MAxNA][MBxNB] */
+
+ cblas_dgemm (CblasRowMajor, transA, transB, M, N,NA,
+ alpha, A->data, A->tda, B->data, B->tda, beta,
+ C->data, C->tda);
+
+ } else {
+ throw invalid_argument("Range error in dgemm");
+ }
}
+
/*
Use the fasted/supported way to call BLAS dgemm
*/
@@ -189,14 +188,19 @@ static void fast_cblas_dgemm(const char *TransA, const char *TransB, const doubl
void fast_dgemm(const char *TransA, const char *TransB, const double alpha,
const gsl_matrix *A, const gsl_matrix *B, const double beta,
gsl_matrix *C) {
- fast_cblas_dgemm(TransA,TransB,alpha,A,B,beta,C);
-
- #ifndef NDEBUG
- if (is_strict_mode() && !is_no_check_mode()) {
- // ---- validate with original implementation
- gsl_matrix *C1 = gsl_matrix_alloc(C->size1,C->size2);
- eigenlib_dgemm(TransA,TransB,alpha,A,B,beta,C1);
- enforce_msg(gsl_matrix_equal(C,C1),"dgemm outcomes are not equal for fast & eigenlib");
+ if (is_legacy_mode()) {
+ eigenlib_dgemm(TransA,TransB,alpha,A,B,beta,C);
+ } else {
+ fast_cblas_dgemm(TransA,TransB,alpha,A,B,beta,C);
+
+ #ifndef NDEBUG
+ if (is_check_mode()) {
+ // ---- validate with original implementation
+ gsl_matrix *C1 = gsl_matrix_alloc(C->size1,C->size2);
+ eigenlib_dgemm(TransA,TransB,alpha,A,B,beta,C1);
+ enforce_msg(gsl_matrix_equal(C,C1),"dgemm outcomes are not equal for fast & eigenlib");
+ gsl_matrix_free(C1);
+ }
+ #endif
}
- #endif
}