about summary refs log tree commit diff
path: root/gn/packages/machine-learning.scm
diff options
context:
space:
mode:
Diffstat (limited to 'gn/packages/machine-learning.scm')
-rw-r--r--gn/packages/machine-learning.scm107
1 files changed, 106 insertions, 1 deletions
diff --git a/gn/packages/machine-learning.scm b/gn/packages/machine-learning.scm
index 8f9f1f0..3e51132 100644
--- a/gn/packages/machine-learning.scm
+++ b/gn/packages/machine-learning.scm
@@ -2,7 +2,9 @@
   #:use-module ((guix licenses) #:prefix license:)
   #:use-module (guix packages)
   #:use-module (guix utils)
-  #:use-module (gnu packages machine-learning))
+  #:use-module (gnu packages machine-learning)
+  #:use-module (guix download)
+  #:use-module (guix build-system python))
 
 (define-public tensorflow-native
   (package
@@ -19,3 +21,106 @@
 
 (define-public tensowflow-native-instead-of-tensorflow
   (package-input-rewriting/spec `(("tensorflow" . ,(const tensorflow-native)))))
+
+
+(define-public python-keras-no-tests
+  (package
+    (name "python-keras-no-tests")
+    (version "2.3.1")
+    (source
+     (origin
+       (method url-fetch)
+       (uri (pypi-uri "Keras" version))
+       (sha256
+        (base32
+         "1k68xd8n2y9ldijggjc8nn4d6d1axw0p98gfb0fmm8h641vl679j"))
+       (modules '((guix build utils)))
+       (snippet
+        '(substitute* '("keras/callbacks/callbacks.py"
+                        "keras/engine/training_utils.py"
+                        "keras/engine/training.py"
+                        "keras/engine/training_generator.py"
+                        "keras/utils/generic_utils.py")
+           (("from collections import Iterable")
+            "from collections.abc import Iterable")
+           (("collections.Container")
+            "collections.abc.Container")
+           (("collections.Mapping")
+            "collections.abc.Mapping")
+           (("collections.Sequence")
+            "collections.abc.Sequence")))))
+    (build-system python-build-system)
+    (arguments
+     `(#:phases
+       (modify-phases %standard-phases
+         (add-after 'unpack 'tf-compatibility
+           (lambda _
+             (substitute* "keras/backend/tensorflow_backend.py"
+               (("^get_graph = .*")
+                "get_graph = tf.get_default_graph")
+               (("tf.compat.v1.nn.fused_batch_norm")
+                "tf.nn.fused_batch_norm")
+               ;; categorical_crossentropy does not support axis
+               (("from_logits=from_logits, axis=axis")
+                "from_logits=from_logits")
+               ;; dropout accepts a level number, not a named rate argument.
+               (("dropout\\(x, rate=level,")
+                "dropout(x, level,")
+               (("return x.shape.rank")
+                "return len(x.shape)"))))
+         (add-after 'unpack 'hdf5-compatibility
+           (lambda _
+             ;; The truth value of an array with more than one element is ambiguous.
+             (substitute* "tests/keras/utils/io_utils_test.py"
+               ((" *assert .* == \\[b'(asd|efg).*") ""))
+             (substitute* "tests/test_model_saving.py"
+               (("h5py.File\\('does not matter',")
+                "h5py.File('does not matter', 'w',"))
+             (substitute* "keras/utils/io_utils.py"
+               (("h5py.File\\('in-memory-h5py', driver='core', backing_store=False\\)")
+                "h5py.File('in-memory-h5py', 'w', driver='core', backing_store=False)")
+               (("h5file.fid.get_file_image")
+                "h5file.id.get_file_image"))
+             (substitute* "keras/engine/saving.py"
+               (("\\.decode\\('utf-?8'\\)") ""))))
+         (add-after 'unpack 'delete-unavailable-backends
+           (lambda _
+             (delete-file "keras/backend/theano_backend.py")
+             (delete-file "keras/backend/cntk_backend.py")))
+         (delete 'check))))
+    (propagated-inputs
+     (list python-h5py
+           python-keras-applications
+           python-keras-preprocessing
+           python-numpy
+           python-pydot
+           python-pyyaml
+           python-scipy
+           python-six
+           tensorflow
+           graphviz))
+    (native-inputs
+     (list python-flaky
+           python-markdown
+           python-pandas
+           python-pytest
+           python-pytest-cov
+           python-pytest-timeout
+           python-pytest-xdist
+           python-pyux
+           python-sphinx
+           python-requests))
+    (home-page "https://keras.io/")
+    (synopsis "High-level deep learning framework")
+    (description "Keras is a high-level neural networks API, written in Python
+and capable of running on top of TensorFlow.  It was developed with a focus on
+enabling fast experimentation.  Use Keras if you need a deep learning library
+that:
+@itemize
+@item Allows for easy and fast prototyping (through user friendliness,
+  modularity, and extensibility).
+@item Supports both convolutional networks and recurrent networks, as well as
+  combinations of the two.
+@item Runs seamlessly on CPU and GPU.
+@end itemize\n")
+    (license license:expat)))