aboutsummaryrefslogtreecommitdiff
path: root/src/fastblas.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/fastblas.cpp')
-rw-r--r--src/fastblas.cpp28
1 files changed, 17 insertions, 11 deletions
diff --git a/src/fastblas.cpp b/src/fastblas.cpp
index a7adc44..c15e34f 100644
--- a/src/fastblas.cpp
+++ b/src/fastblas.cpp
@@ -129,16 +129,10 @@ void fast_cblas_dgemm(const enum CBLAS_ORDER Order,
enforce(N>0);
enforce(K>0);
- // cout << sizeof(blasint) << endl;
- blasint mi = M;
- blasint ni = N;
- blasint ki = K;
- auto x1 = mi * ki;
- enforce_msg(x1 / mi == ki, "matrix integer overflow - please use long int");
- auto x2 = ni * ki;
- enforce_msg(x2 / ni == ki, "matrix integer overflow - please use long int");
- auto x3 = mi * ni;
- enforce_msg(x3 / mi == ni, "matrix integer overflow - please use long int");
+ // check_int_mult_overflow(560000,8000); // fails on default int (32-bits)
+ check_int_mult_overflow(M,K);
+ check_int_mult_overflow(N,K);
+ check_int_mult_overflow(M,N);
cblas_dgemm(Order,TransA,TransB,M,N,K,alpha,A,lda,B,ldb,beta,C,ldc);
@@ -191,7 +185,19 @@ static void fast_cblas_dgemm(const char *TransA, const char *TransB, const doubl
if (M == MA && N == NB && NA == MBx) { /* [MxN] = [MAxNA][MBxNB] */
- cblas_dgemm (CblasRowMajor, transA, transB, M, N,NA,
+ auto K = NA;
+
+ // Check for (integer) overflows
+ enforce(M>0);
+ enforce(N>0);
+ enforce(K>0);
+
+ // check_int_mult_overflow(560000,8000);
+ check_int_mult_overflow(M,K);
+ check_int_mult_overflow(N,K);
+ check_int_mult_overflow(M,N);
+
+ cblas_dgemm (CblasRowMajor, transA, transB, M, N, NA,
alpha, A->data, A->tda, B->data, B->tda, beta,
C->data, C->tda);