diff options
Diffstat (limited to 'src/fastblas.cpp')
-rw-r--r-- | src/fastblas.cpp | 64 |
1 files changed, 34 insertions, 30 deletions
diff --git a/src/fastblas.cpp b/src/fastblas.cpp index b25d964..0a4eba3 100644 --- a/src/fastblas.cpp +++ b/src/fastblas.cpp @@ -44,14 +44,16 @@ gsl_matrix *fast_copy(gsl_matrix *m, const double *mem) { } } } else { // faster goes by row + auto v = gsl_vector_calloc(cols); + enforce(v); // just to be sure for (auto r=0; r<rows; r++) { - auto v = gsl_vector_calloc(cols); assert(v->size == cols); assert(v->block->size == cols); assert(v->stride == 1); memcpy(v->block->data,&mem[r*cols],cols*sizeof(double)); gsl_matrix_set_row(m,r,v); } + gsl_vector_free(v); } return m; } @@ -160,28 +162,25 @@ static void fast_cblas_dgemm(const char *TransA, const char *TransB, const doubl // C++ is row-major auto transA = (*TransA == 'N' || *TransA == 'n' ? CblasNoTrans : CblasTrans); auto transB = (*TransB == 'N' || *TransB == 'n' ? CblasNoTrans : CblasTrans); - // A(m x k) * B(k x n) = C(m x n)) - auto rowsA = A->size1; - auto colsA = A->size2; - blasint M = A->size1; - blasint K = B->size1; - assert(K == colsA); - blasint N = B->size2; - // cout << M << "," << N "," << K << endl; - // Layout = CblasRowMajor: Trans: K , NoTrans M - blasint lda = (transA==CblasNoTrans ? K : M ); - blasint ldb = (transB==CblasNoTrans ? N : K ); - blasint ldc = N; - - fast_cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, - /* A */ A->data, - /* lda */ lda, - /* B */ B->data, - /* ldb */ ldb, - /* beta */ beta, - /* C */ C->data, ldc); + const size_t M = C->size1; + const size_t N = C->size2; + const size_t MA = (transA == CblasNoTrans) ? A->size1 : A->size2; + const size_t NA = (transA == CblasNoTrans) ? A->size2 : A->size1; + const size_t MB = (transB == CblasNoTrans) ? B->size1 : B->size2; + const size_t NB = (transB == CblasNoTrans) ? B->size2 : B->size1; + + if (M == MA && N == NB && NA == MB) { /* [MxN] = [MAxNA][MBxNB] */ + + cblas_dgemm (CblasRowMajor, transA, transB, M, N,NA, + alpha, A->data, A->tda, B->data, B->tda, beta, + C->data, C->tda); + + } else { + throw invalid_argument("Range error in dgemm"); + } } + /* Use the fasted/supported way to call BLAS dgemm */ @@ -189,14 +188,19 @@ static void fast_cblas_dgemm(const char *TransA, const char *TransB, const doubl void fast_dgemm(const char *TransA, const char *TransB, const double alpha, const gsl_matrix *A, const gsl_matrix *B, const double beta, gsl_matrix *C) { - fast_cblas_dgemm(TransA,TransB,alpha,A,B,beta,C); - - #ifndef NDEBUG - if (is_strict_mode() && !is_no_check_mode()) { - // ---- validate with original implementation - gsl_matrix *C1 = gsl_matrix_alloc(C->size1,C->size2); - eigenlib_dgemm(TransA,TransB,alpha,A,B,beta,C1); - enforce_msg(gsl_matrix_equal(C,C1),"dgemm outcomes are not equal for fast & eigenlib"); + if (is_legacy_mode()) { + eigenlib_dgemm(TransA,TransB,alpha,A,B,beta,C); + } else { + fast_cblas_dgemm(TransA,TransB,alpha,A,B,beta,C); + + #ifndef NDEBUG + if (is_check_mode()) { + // ---- validate with original implementation + gsl_matrix *C1 = gsl_matrix_alloc(C->size1,C->size2); + eigenlib_dgemm(TransA,TransB,alpha,A,B,beta,C1); + enforce_msg(gsl_matrix_equal(C,C1),"dgemm outcomes are not equal for fast & eigenlib"); + gsl_matrix_free(C1); + } + #endif } - #endif } |