From 87223540e63f6da7ea085c1379deacc23b6ba5dc Mon Sep 17 00:00:00 2001
From: Pjotr Prins
Date: Thu, 7 Dec 2017 09:16:25 +0000
Subject: Some more overflow checks

---
 src/debug.cpp    |  1 +
 src/debug.h      |  4 ++++
 src/fastblas.cpp | 28 +++++++++++++++++-----------
 src/lapack.cpp   |  4 ++++
 4 files changed, 26 insertions(+), 11 deletions(-)

(limited to 'src')

diff --git a/src/debug.cpp b/src/debug.cpp
index 4e57e3e..a4d2562 100644
--- a/src/debug.cpp
+++ b/src/debug.cpp
@@ -40,6 +40,7 @@ bool is_quiet_mode() { return debug_quiet; };
 bool is_issue(uint issue) { return issue == debug_issue; };
 bool is_legacy_mode() { return debug_legacy; };
 
+
 /*
   Helper function to make sure gsl allocations do their job because
   gsl_matrix_alloc does not initiatize values (behaviour that changed
diff --git a/src/debug.h b/src/debug.h
index 69b0a7c..e58c1d5 100644
--- a/src/debug.h
+++ b/src/debug.h
@@ -25,6 +25,10 @@ bool is_quiet_mode();
 bool is_issue(uint issue);
 bool is_legacy_mode();
 
+#define check_int_mult_overflow(m,n) \
+  { auto x = m * n;                                      \
+    enforce_msg(x / m == n, "multiply integer overflow"); }
+
 gsl_matrix *gsl_matrix_safe_alloc(size_t rows,size_t cols);
 int gsl_matrix_safe_memcpy (gsl_matrix *dest, const gsl_matrix *src);
 void gsl_matrix_safe_free (gsl_matrix *v);
diff --git a/src/fastblas.cpp b/src/fastblas.cpp
index a7adc44..c15e34f 100644
--- a/src/fastblas.cpp
+++ b/src/fastblas.cpp
@@ -129,16 +129,10 @@ void fast_cblas_dgemm(const enum CBLAS_ORDER Order,
   enforce(N>0);
   enforce(K>0);
 
-  // cout << sizeof(blasint) << endl;
-  blasint mi = M;
-  blasint ni = N;
-  blasint ki = K;
-  auto x1 = mi * ki;
-  enforce_msg(x1 / mi == ki, "matrix integer overflow - please use long int");
-  auto x2 = ni * ki;
-  enforce_msg(x2 / ni == ki, "matrix integer overflow - please use long int");
-  auto x3 = mi * ni;
-  enforce_msg(x3 / mi == ni, "matrix integer overflow - please use long int");
+  // check_int_mult_overflow(560000,8000); // fails on default int (32-bits)
+  check_int_mult_overflow(M,K);
+  check_int_mult_overflow(N,K);
+  check_int_mult_overflow(M,N);
 
   cblas_dgemm(Order,TransA,TransB,M,N,K,alpha,A,lda,B,ldb,beta,C,ldc);
 
@@ -191,7 +185,19 @@ static void fast_cblas_dgemm(const char *TransA, const char *TransB, const doubl
 
   if (M == MA && N == NB && NA == MBx) {  /* [MxN] = [MAxNA][MBxNB] */
 
-    cblas_dgemm (CblasRowMajor, transA, transB, M, N,NA,
+    auto K = NA;
+
+    // Check for (integer) overflows
+    enforce(M>0);
+    enforce(N>0);
+    enforce(K>0);
+
+    // check_int_mult_overflow(560000,8000);
+    check_int_mult_overflow(M,K);
+    check_int_mult_overflow(N,K);
+    check_int_mult_overflow(M,N);
+
+    cblas_dgemm (CblasRowMajor, transA, transB, M, N, NA,
                  alpha, A->data, A->tda, B->data, B->tda, beta,
                  C->data, C->tda);
 
diff --git a/src/lapack.cpp b/src/lapack.cpp
index ee0a497..26333b6 100644
--- a/src/lapack.cpp
+++ b/src/lapack.cpp
@@ -128,6 +128,10 @@ void lapack_dgemm(char *TransA, char *TransB, double alpha, const gsl_matrix *A,
   gsl_matrix *C_t = gsl_matrix_alloc(C->size2, C->size1);
   gsl_matrix_transpose_memcpy(C_t, C);
 
+  check_int_mult_overflow(M,K1);
+  check_int_mult_overflow(N,K1);
+  check_int_mult_overflow(M,N);
+
   dgemm_(TransA, TransB, &M, &N, &K1, &alpha, A_t->data, &LDA, B_t->data, &LDB,
          &beta, C_t->data, &LDC);
 
-- 
cgit v1.2.3