about summary refs log tree commit diff
diff options
context:
space:
mode:
authorPjotr Prins2017-12-06 08:49:48 +0000
committerPjotr Prins2017-12-06 08:50:00 +0000
commit6042155a4b8d8fad3e5bacdb8fa3fe2f072b80fe (patch)
tree2fe66d57a2965331ca0dd126742705f6a5cfcee4
parent4aa3540f0be29f7db73d2c18dd1f91e514ead490 (diff)
downloadpangemma-6042155a4b8d8fad3e5bacdb8fa3fe2f072b80fe.tar.gz
Integer overflow checking for matrix dgemm and adding path for OpenBlas include files
-rw-r--r--Makefile6
-rw-r--r--src/fastblas.cpp15
2 files changed, 19 insertions, 2 deletions
diff --git a/Makefile b/Makefile
index cb82e46..72e43f8 100644
--- a/Makefile
+++ b/Makefile
@@ -53,6 +53,7 @@ FORCE_STATIC           =                  # Static linking of libraries
 GCC_FLAGS              = -Wall -O3 -std=gnu++11 # extra flags -Wl,--allow-multiple-definition
 TRAVIS_CI              =                  # used by TRAVIS for testing
 EIGEN_INCLUDE_PATH     = /usr/include/eigen3
+OPENBLAS_INCLUDE_PATH  = /usr/local/opt/openblas/include
 
 # --------------------------------------------------------------------
 # Edit below this line with caution
@@ -71,14 +72,15 @@ endif
 
 ifeq ($(CPP), clang++)
   # macOS Homebrew settings (as used on Travis-CI)
-  GCC_FLAGS=-O3 -std=c++11 -stdlib=libc++ -isystem//usr/local/opt/openblas/include -isystem//usr/local/include/eigen3 -Wl,-L/usr/local/opt/openblas/lib
+  GCC_FLAGS=-O3 -std=c++11 -stdlib=libc++ -isystem/$(OPENBLAS_INCLUDE_PATH) -isystem//usr/local/include/eigen3 -Wl,-L/usr/local/opt/openblas/lib
 endif
 
 ifdef WITH_OPENBLAS
   OPENBLAS=1
   # WITH_LAPACK =  # OPENBLAS usually includes LAPACK
-  CPPFLAGS += -DOPENBLAS
+  CPPFLAGS += -DOPENBLAS -isystem/$(OPENBLAS_INCLUDE_PATH)
   ifdef OPENBLAS_LEGACY
+    # Legacy version (mostly for Travis-CI)
     CPPFLAGS += -DOPENBLAS_LEGACY
   endif
 endif
diff --git a/src/fastblas.cpp b/src/fastblas.cpp
index 20456ef..7b10852 100644
--- a/src/fastblas.cpp
+++ b/src/fastblas.cpp
@@ -124,6 +124,21 @@ void fast_cblas_dgemm(const enum CBLAS_ORDER Order,
   }
 #endif // NDEBUG
 
+  // Check for (integer) overflows
+  enforce(M>0);
+  enforce(N>0);
+  enforce(K>0);
+
+  blasint mi = M;
+  blasint ni = N;
+  blasint ki = K;
+  auto x1 = mi * ki;
+  enforce_msg(x1 / mi == ki, "matrix integer overflow - please use long int");
+  auto x2 = ni * ki;
+  enforce_msg(x2 / ni == ki, "matrix integer overflow - please use long int");
+  auto x3 = mi * ni;
+  enforce_msg(x3 / mi == ni, "matrix integer overflow - please use long int");
+
   cblas_dgemm(Order,TransA,TransB,M,N,K,alpha,A,lda,B,ldb,beta,C,ldc);
 
 #ifndef NDEBUG