about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-09-09 06:44:59 +0300
committerFrederick Muriuki Muriithi2022-09-09 06:44:59 +0300
commitf9ce0bc32db3bdd8a5947a18039c557c750f2957 (patch)
tree4d59941d11700ad9a3e4480dd14152c6ecb8815d
parentff160e5c3dc05e28a98050200b089470c3233e11 (diff)
downloadgenenetwork2-f9ce0bc32db3bdd8a5947a18039c557c750f2957.tar.gz
Refactor: Add tests and handle edge case
Remove mutation of state, and handle the edge case where the
sub-sequence could be an empty sequence.
-rw-r--r--wqflask/tests/unit/wqflask/show_trait/test_get_max_digits.py13
-rw-r--r--wqflask/wqflask/show_trait/show_trait.py11
2 files changed, 18 insertions, 6 deletions
diff --git a/wqflask/tests/unit/wqflask/show_trait/test_get_max_digits.py b/wqflask/tests/unit/wqflask/show_trait/test_get_max_digits.py
new file mode 100644
index 00000000..509f6c3a
--- /dev/null
+++ b/wqflask/tests/unit/wqflask/show_trait/test_get_max_digits.py
@@ -0,0 +1,13 @@
+import pytest
+
+from wqflask.show_trait.show_trait import get_max_digits
+
+@pytest.mark.parametrize(
+    "trait_vals,expected",
+    (((
+        (0, 1345, 92, 734),
+        (234253, 33, 153, 5352),
+        (3542, 24, 135)),
+      [3, 5, 3]),))
+def test_get_max_digits(trait_vals, expected):
+    assert get_max_digits(trait_vals) == expected
diff --git a/wqflask/wqflask/show_trait/show_trait.py b/wqflask/wqflask/show_trait/show_trait.py
index f7dbf8df..ae6cf0cf 100644
--- a/wqflask/wqflask/show_trait/show_trait.py
+++ b/wqflask/wqflask/show_trait/show_trait.py
@@ -555,13 +555,12 @@ def get_trait_vals(sample_list):
     return trait_vals
 
 def get_max_digits(trait_vals):
-    max_digits = []
-    for these_vals in trait_vals:
-        max_val = max(these_vals)
-        digits = len(str(max_val))
-        max_digits.append(digits - 1)
+    def __max_digits__(these_vals):
+        if not bool(these_vals):
+            return None
+        return len(str(max(these_vals))) - 1
 
-    return max_digits
+    return [__max_digits__(val) for val in trait_vals]
 
 def normf(trait_vals):
     ranked_vals = ss.rankdata(trait_vals)