diff options
Diffstat (limited to 'gn/packages/machine-learning.scm')
-rw-r--r-- | gn/packages/machine-learning.scm | 107 |
1 files changed, 106 insertions, 1 deletions
diff --git a/gn/packages/machine-learning.scm b/gn/packages/machine-learning.scm index 8f9f1f0..3e51132 100644 --- a/gn/packages/machine-learning.scm +++ b/gn/packages/machine-learning.scm @@ -2,7 +2,9 @@ #:use-module ((guix licenses) #:prefix license:) #:use-module (guix packages) #:use-module (guix utils) - #:use-module (gnu packages machine-learning)) + #:use-module (gnu packages machine-learning) + #:use-module (guix download) + #:use-module (guix build-system python)) (define-public tensorflow-native (package @@ -19,3 +21,106 @@ (define-public tensowflow-native-instead-of-tensorflow (package-input-rewriting/spec `(("tensorflow" . ,(const tensorflow-native))))) + + +(define-public python-keras-no-tests + (package + (name "python-keras-no-tests") + (version "2.3.1") + (source + (origin + (method url-fetch) + (uri (pypi-uri "Keras" version)) + (sha256 + (base32 + "1k68xd8n2y9ldijggjc8nn4d6d1axw0p98gfb0fmm8h641vl679j")) + (modules '((guix build utils))) + (snippet + '(substitute* '("keras/callbacks/callbacks.py" + "keras/engine/training_utils.py" + "keras/engine/training.py" + "keras/engine/training_generator.py" + "keras/utils/generic_utils.py") + (("from collections import Iterable") + "from collections.abc import Iterable") + (("collections.Container") + "collections.abc.Container") + (("collections.Mapping") + "collections.abc.Mapping") + (("collections.Sequence") + "collections.abc.Sequence"))))) + (build-system python-build-system) + (arguments + `(#:phases + (modify-phases %standard-phases + (add-after 'unpack 'tf-compatibility + (lambda _ + (substitute* "keras/backend/tensorflow_backend.py" + (("^get_graph = .*") + "get_graph = tf.get_default_graph") + (("tf.compat.v1.nn.fused_batch_norm") + "tf.nn.fused_batch_norm") + ;; categorical_crossentropy does not support axis + (("from_logits=from_logits, axis=axis") + "from_logits=from_logits") + ;; dropout accepts a level number, not a named rate argument. + (("dropout\\(x, rate=level,") + "dropout(x, level,") + (("return x.shape.rank") + "return len(x.shape)")))) + (add-after 'unpack 'hdf5-compatibility + (lambda _ + ;; The truth value of an array with more than one element is ambiguous. + (substitute* "tests/keras/utils/io_utils_test.py" + ((" *assert .* == \\[b'(asd|efg).*") "")) + (substitute* "tests/test_model_saving.py" + (("h5py.File\\('does not matter',") + "h5py.File('does not matter', 'w',")) + (substitute* "keras/utils/io_utils.py" + (("h5py.File\\('in-memory-h5py', driver='core', backing_store=False\\)") + "h5py.File('in-memory-h5py', 'w', driver='core', backing_store=False)") + (("h5file.fid.get_file_image") + "h5file.id.get_file_image")) + (substitute* "keras/engine/saving.py" + (("\\.decode\\('utf-?8'\\)") "")))) + (add-after 'unpack 'delete-unavailable-backends + (lambda _ + (delete-file "keras/backend/theano_backend.py") + (delete-file "keras/backend/cntk_backend.py"))) + (delete 'check)))) + (propagated-inputs + (list python-h5py + python-keras-applications + python-keras-preprocessing + python-numpy + python-pydot + python-pyyaml + python-scipy + python-six + tensorflow + graphviz)) + (native-inputs + (list python-flaky + python-markdown + python-pandas + python-pytest + python-pytest-cov + python-pytest-timeout + python-pytest-xdist + python-pyux + python-sphinx + python-requests)) + (home-page "https://keras.io/") + (synopsis "High-level deep learning framework") + (description "Keras is a high-level neural networks API, written in Python +and capable of running on top of TensorFlow. It was developed with a focus on +enabling fast experimentation. Use Keras if you need a deep learning library +that: +@itemize +@item Allows for easy and fast prototyping (through user friendliness, + modularity, and extensibility). +@item Supports both convolutional networks and recurrent networks, as well as + combinations of the two. +@item Runs seamlessly on CPU and GPU. +@end itemize\n") + (license license:expat))) |