about summary refs log tree commit diff
path: root/gn/packages/machine-learning.scm
blob: 3e511328723619dd2399c13ef3d2ba6aa63eaa08 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
(define-module (gn packages machine-learning)
  #:use-module ((guix licenses) #:prefix license:)
  #:use-module (guix packages)
  #:use-module (guix utils)
  #:use-module (guix download)
  #:use-module (guix build-system python)
  #:use-module (gnu packages check)
  #:use-module (gnu packages graphviz)
  #:use-module (gnu packages machine-learning)
  #:use-module (gnu packages python-check)
  #:use-module (gnu packages python-science)
  #:use-module (gnu packages python-web)
  #:use-module (gnu packages python-xyz)
  #:use-module (gnu packages sphinx))

;; Variant of `tensorflow' whose configure flags request
;; "-Dtensorflow_OPTIMIZE_FOR_NATIVE_ARCH=ON" instead of "...=OFF".
;; The package is also marked non-substitutable.
(define-public tensorflow-native
  (let ((base tensorflow))
    (package
      (inherit base)
      (name "tensorflow-native")
      (arguments
       (substitute-keyword-arguments (package-arguments base)
         ;; Force #:substitutable? to #f regardless of what BASE specifies.
         ((#:substitutable? _ #f)
          #f)
         ;; Build-side expression: drop the "=OFF" flag from the inherited
         ;; flag list and put the "=ON" flag in front of what remains.
         ((#:configure-flags flags)
          `(cons "-Dtensorflow_OPTIMIZE_FOR_NATIVE_ARCH=ON"
                 (delete "-Dtensorflow_OPTIMIZE_FOR_NATIVE_ARCH=OFF"
                         ,flags))))))))

;; Package transformation procedure: given a package, rewrite its dependency
;; graph so that every input matching the spec "tensorflow" is replaced by
;; `tensorflow-native'.
(define-public tensorflow-native-instead-of-tensorflow
  (package-input-rewriting/spec `(("tensorflow" . ,(const tensorflow-native)))))

;; Original, misspelled name ("tensowflow").  Kept as an alias of the
;; procedure above so that existing users of this name keep working.
(define-public tensowflow-native-instead-of-tensorflow
  tensorflow-native-instead-of-tensorflow)


;; Keras 2.3.1 built from the PyPI "Keras" release with its test suite
;; disabled.  The source is patched in the origin snippet and in three extra
;; phases (see the comments on each), and the Theano and CNTK backend modules
;; are deleted.
(define-public python-keras-no-tests
  (package
    (name "python-keras-no-tests")
    (version "2.3.1")
    (source
     (origin
       (method url-fetch)
       (uri (pypi-uri "Keras" version))
       (sha256
        (base32
         "1k68xd8n2y9ldijggjc8nn4d6d1axw0p98gfb0fmm8h641vl679j"))
       (modules '((guix build utils)))
       (snippet
        ;; Rewrite references to Iterable, Container, Mapping and Sequence
        ;; so that they go through `collections.abc' rather than
        ;; `collections'.
        '(substitute* '("keras/callbacks/callbacks.py"
                        "keras/engine/training_utils.py"
                        "keras/engine/training.py"
                        "keras/engine/training_generator.py"
                        "keras/utils/generic_utils.py")
           (("from collections import Iterable")
            "from collections.abc import Iterable")
           (("collections.Container")
            "collections.abc.Container")
           (("collections.Mapping")
            "collections.abc.Mapping")
           (("collections.Sequence")
            "collections.abc.Sequence")))))
    (build-system python-build-system)
    (arguments
     ;; Tests are switched off with #:tests? rather than by deleting the
     ;; 'check phase, which is the conventional way to skip a test suite.
     `(#:tests? #f
       #:phases
       (modify-phases %standard-phases
         ;; Adjust the TensorFlow backend to the TensorFlow API it is built
         ;; against.
         (add-after 'unpack 'tf-compatibility
           (lambda _
             (substitute* "keras/backend/tensorflow_backend.py"
               (("^get_graph = .*")
                "get_graph = tf.get_default_graph")
               (("tf.compat.v1.nn.fused_batch_norm")
                "tf.nn.fused_batch_norm")
               ;; categorical_crossentropy does not support axis
               (("from_logits=from_logits, axis=axis")
                "from_logits=from_logits")
               ;; dropout accepts a level number, not a named rate argument.
               (("dropout\\(x, rate=level,")
                "dropout(x, level,")
               (("return x.shape.rank")
                "return len(x.shape)"))))
         ;; Adjust h5py usage: pass an explicit 'w' mode to h5py.File, use
         ;; `.id' instead of `.fid', and drop `.decode('utf8')' calls.
         (add-after 'unpack 'hdf5-compatibility
           (lambda _
             ;; The truth value of an array with more than one element is ambiguous.
             (substitute* "tests/keras/utils/io_utils_test.py"
               ((" *assert .* == \\[b'(asd|efg).*") ""))
             (substitute* "tests/test_model_saving.py"
               (("h5py.File\\('does not matter',")
                "h5py.File('does not matter', 'w',"))
             (substitute* "keras/utils/io_utils.py"
               (("h5py.File\\('in-memory-h5py', driver='core', backing_store=False\\)")
                "h5py.File('in-memory-h5py', 'w', driver='core', backing_store=False)")
               (("h5file.fid.get_file_image")
                "h5file.id.get_file_image"))
             (substitute* "keras/engine/saving.py"
               (("\\.decode\\('utf-?8'\\)") ""))))
         ;; Remove the Theano and CNTK backend modules.
         (add-after 'unpack 'delete-unavailable-backends
           (lambda _
             (delete-file "keras/backend/theano_backend.py")
             (delete-file "keras/backend/cntk_backend.py"))))))
    (propagated-inputs
     (list python-h5py
           python-keras-applications
           python-keras-preprocessing
           python-numpy
           python-pydot
           python-pyyaml
           python-scipy
           python-six
           tensorflow
           graphviz))
    ;; NOTE(review): most of these look like test/documentation tooling and
    ;; are presumably unnecessary now that the test suite is not run --
    ;; confirm before trimming the list.
    (native-inputs
     (list python-flaky
           python-markdown
           python-pandas
           python-pytest
           python-pytest-cov
           python-pytest-timeout
           python-pytest-xdist
           python-pyux
           python-sphinx
           python-requests))
    (home-page "https://keras.io/")
    (synopsis "High-level deep learning framework")
    (description "Keras is a high-level neural networks API, written in Python
and capable of running on top of TensorFlow.  It was developed with a focus on
enabling fast experimentation.  Use Keras if you need a deep learning library
that:
@itemize
@item Allows for easy and fast prototyping (through user friendliness,
  modularity, and extensibility).
@item Supports both convolutional networks and recurrent networks, as well as
  combinations of the two.
@item Runs seamlessly on CPU and GPU.
@end itemize\n")
    (license license:expat)))