about summary refs log tree commit diff
path: root/gn3
diff options
context:
space:
mode:
author Arun Isaac 2022-06-09 16:46:08 +0530
committer Arun Isaac 2022-06-09 16:46:08 +0530
commit 603a86c60869ff2017003f4d46b3de932e879c93 (patch)
tree 1247e36e68456e515994704c75b813978c001446 /gn3
parent 2edd7e65a373e3037c554af4b196016ace154d5a (diff)
download genenetwork3-603a86c60869ff2017003f4d46b3de932e879c93.tar.gz
gn3: genodb: Rewrite without classes.
We rewrite genodb using only functions. This makes for much more readable
code.

* gn3/genodb.py: Rewrite without classes.
Diffstat (limited to 'gn3')
-rw-r--r-- gn3/genodb.py 107
1 files changed, 71 insertions, 36 deletions
diff --git a/gn3/genodb.py b/gn3/genodb.py
index c2decb8..2a29098 100644
--- a/gn3/genodb.py
+++ b/gn3/genodb.py
@@ -1,39 +1,74 @@
+'''Genotype database reader
+
+This module is a tiny Python library to read a GeneNetwork genotype
+database. It exports the following functions.
+
+* open - Open a genotype database
+* matrix - Get current matrix
+* row - Get row of matrix
+* column - Get column of matrix
+
+Here is a typical invocation to read row 17 and column 13 from a genotype
+database at `/tmp/bxd`.
+
+from gn3 import genodb
+
+with genodb.open('/tmp/bxd') as db:
+    matrix = genodb.matrix(db)
+    print(genodb.row(matrix, 17))
+    print(genodb.column(matrix, 13))
+
+'''
+
+from collections import namedtuple
+from contextlib import contextmanager
 import lmdb
 import numpy as np
 
-class GenotypeDatabase:
-    def __init__(self, path):
-        self.env = lmdb.open(path, readonly=True, create=False)
-        self.txn = self.env.begin()
-        # 32 bytes in a SHA256 hash
-        self.hash_length = 32
-    def __enter__(self):
-        return self
-    def __exit__(self, type, value, traceback):
-        self.txn.abort()
-        self.env.close()
-    def get(self, hash):
-        return self.txn.get(hash)
-    def get_metadata(self, hash, metadata):
-        return self.txn.get(hash + b':' + metadata.encode())
-    def matrix(self):
-        hash = self.get(b'current')[0:self.hash_length]
-        return Matrix(self, hash)
-
-class Matrix():
-    def __init__(self, db, hash):
-        self.nrows = int.from_bytes(db.get_metadata(hash, 'nrows'), byteorder='little')
-        self.ncols = int.from_bytes(db.get_metadata(hash, 'ncols'), byteorder='little')
-        row_column_pointers = db.get(hash)
-        self.row_pointers = row_column_pointers[0 : self.nrows*db.hash_length]
-        self.column_pointers = row_column_pointers[self.nrows*db.hash_length :]
-        self.db = db
-    def __vector(self, index, pointers):
-        start = index * self.db.hash_length
-        end = start + self.db.hash_length
-        return np.frombuffer(self.db.get(pointers[start:end]),
-                             dtype=np.uint8)
-    def row(self, index):
-        return self.__vector(index, self.row_pointers)
-    def column(self, index):
-        return self.__vector(index, self.column_pointers)
+# pylint: disable=invalid-name,redefined-builtin
+
+# Handle on an open genotype database. `txn` is the transaction that all
+# lookups go through; `hash_length` is the length in bytes of the hashes
+# used as keys and pointers.
+GenotypeDatabase = namedtuple('GenotypeDatabase', 'txn hash_length')
+# A matrix read from the database. `db` is the GenotypeDatabase it came
+# from; `nrows` and `ncols` are its dimensions; `row_pointers` and
+# `column_pointers` are byte arrays of concatenated hashes, one
+# `hash_length`-byte hash per row or column vector.
+Matrix = namedtuple('Matrix', 'db nrows ncols row_pointers column_pointers')
+
+@contextmanager
+def open(path):
+    '''Open genotype database.
+
+    Context manager that opens the LMDB environment at `path` read-only
+    (it is never created), begins a transaction and yields a
+    GenotypeDatabase wrapping that transaction. The transaction is
+    aborted and the environment closed when the `with` block exits,
+    including when it exits because of an exception.
+    '''
+    env = lmdb.open(path, readonly=True, create=False)
+    try:
+        txn = env.begin()
+        try:
+            yield GenotypeDatabase(txn, 32) # 32 bytes in a SHA256 hash
+        finally:
+            # An exception raised in the caller's `with` body is re-raised
+            # at the yield above; without try/finally the cleanup would be
+            # skipped and the transaction and environment would leak.
+            txn.abort()
+    finally:
+        env.close()
+
+def get(db, key):
+    '''Look up key in the genotype database and return its value.
+
+    The lookup goes through the transaction held by db.
+    '''
+    txn = db.txn
+    return txn.get(key)
+
+def get_metadata(db, hash, metadata):
+    '''Fetch a metadata entry of the object identified by hash.
+
+    The entry is stored under the key formed by hash, a colon and the
+    encoded metadata name.
+    '''
+    key = hash + b':' + metadata.encode()
+    return db.txn.get(key)
+
+def matrix(db):
+    '''Read the current matrix from the genotype database.
+
+    The value stored under `current` begins with the hash of the matrix.
+    The value stored under that hash is the row pointers followed by the
+    column pointers, each pointer being hash_length bytes long. The
+    matrix dimensions are stored as little-endian integers in the
+    `nrows` and `ncols` metadata of the hash.
+    '''
+    digest = get(db, b'current')[:db.hash_length]
+    nrows, ncols = (int.from_bytes(get_metadata(db, digest, field),
+                                   byteorder='little')
+                    for field in ('nrows', 'ncols'))
+    pointers = get(db, digest)
+    # Row pointers occupy the first nrows hashes; the rest are columns.
+    boundary = nrows * db.hash_length
+    return Matrix(db, nrows, ncols,
+                  pointers[:boundary], pointers[boundary:])
+
+def vector_ref(db, index, pointers):
+    '''Dereference the index-th pointer in a byte array of pointers.
+
+    Each pointer is hash_length bytes long and is the database key of a
+    vector, which is returned as a numpy array of unsigned bytes.
+    '''
+    width = db.hash_length
+    offset = index * width
+    pointer = pointers[offset : offset + width]
+    return np.frombuffer(get(db, pointer), dtype=np.uint8)
+
+def row(matrix, index):
+    '''Return the index-th row vector of matrix.'''
+    # pylint: disable=redefined-outer-name
+    db, pointers = matrix.db, matrix.row_pointers
+    return vector_ref(db, index, pointers)
+
+def column(matrix, index):
+    '''Return the index-th column vector of matrix.'''
+    # pylint: disable=redefined-outer-name
+    db, pointers = matrix.db, matrix.column_pointers
+    return vector_ref(db, index, pointers)