aboutsummaryrefslogtreecommitdiff
path: root/gn3/db/matrix.py
blob: 24825f937fc1056ca8cbeabda670f7c58695a860 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""Methods for fetching data from the matrix stored in LMDB"""
import json
import struct
from contextlib import closing
from dataclasses import dataclass
from typing import Optional

import lmdb

# Width, in bytes, of each hash packed into the LMDB "versions" blob: that
# value is a back-to-back run of such hashes, so dividing its length by this
# constant yields the number of stored versions.
BLOB_HASH_DIGEST = 32


@dataclass
class Matrix:
    """A decoded matrix: its rows together with the metadata stored for it.

    Both fields hold plain Python containers built with ``json.loads`` by the
    reader functions in this module.
    """

    # One decoded JSON value per row, in row-pointer order.
    data: list
    # Decoded JSON metadata describing the whole matrix.
    metadata: dict


def get_total_versions(db_path: str) -> int:
    """Get the total number of versions in the matrix stored at DB_PATH.

    The value under the b"versions" key is a concatenation of version hashes,
    each BLOB_HASH_DIGEST bytes long, so the count is its length divided by
    the digest size.

    Returns 0 when the b"versions" key is absent or empty.
    """
    # closing() releases the environment handle when we are done instead of
    # leaving one open LMDB environment behind per call.
    with closing(lmdb.open(db_path, readonly=True)) as env:
        with env.begin(write=False) as txn:
            versions_hash = txn.get(b"versions")
            if not versions_hash:
                return 0
            # Integer division: no float round trip as with int(len / n).
            return len(versions_hash) // BLOB_HASH_DIGEST


def get_nth_matrix(index: int, db_path: str) -> Optional[Matrix]:
    """Get the NTH matrix from the DB_PATH.  The most recent matrix is 0.

    INDEX selects the INDEX-th BLOB_HASH_DIGEST-byte hash from the packed
    b"versions" blob.  That hash prefixes the ":row-pointers", ":nrows" and
    ":metadata" keys which together describe one matrix.

    Returns None when INDEX is negative or past the last stored version, when
    there is no b"versions" entry, or when any of the matrix's keys is
    missing.
    """
    if index < 0:
        # Negative slices would read hashes from the end of the blob (or an
        # empty slice for -1) instead of failing the range check below.
        return None
    start = index * BLOB_HASH_DIGEST
    end = start + BLOB_HASH_DIGEST
    # closing() releases the environment handle once the data is decoded.
    with closing(lmdb.open(db_path, readonly=True)) as env:
        with env.begin(write=False) as txn:
            versions_hash = txn.get(b"versions")
            # txn.get returns None for a missing key; treat that like an
            # empty version list rather than failing on len(None).
            if not versions_hash or end > len(versions_hash):
                return None
            matrix_hash = versions_hash[start:end]
            row_pointers = txn.get(matrix_hash + b":row-pointers")
            nrows_blob = txn.get(matrix_hash + b":nrows")
            metadata_blob = txn.get(matrix_hash + b":metadata")
            # An incomplete matrix is reported as "no matrix" instead of
            # crashing in struct.unpack / .rstrip on a None value.
            if not (row_pointers and nrows_blob and metadata_blob):
                return None
            # nrows is stored as a little-endian unsigned 64-bit integer.
            (nrows,) = struct.unpack("<Q", nrows_blob)
            # The row-pointer table is a run of BLOB_HASH_DIGEST-byte keys,
            # one per row, each naming a JSON-encoded row blob.
            pointers = [
                row_pointers[i: i + BLOB_HASH_DIGEST]
                for i in range(0, nrows * BLOB_HASH_DIGEST, BLOB_HASH_DIGEST)
            ]
            return Matrix(
                data=[json.loads(txn.get(ptr).decode()) for ptr in pointers],
                # Metadata may carry trailing NUL padding; strip it before
                # JSON decoding.
                metadata=json.loads(metadata_blob.rstrip(b"\x00").decode()),
            )


def get_current_matrix(db_path: str) -> Optional[Matrix]:
    """Get the most recent matrix from DB_PATH.

    The matrix is resolved through the b"current" key.  Its value, suffixed
    with ":matrix", names the matrix hash, which in turn prefixes the
    ":row-pointers", ":nrows" and ":metadata" keys.

    This is intended to be functionally equivalent to
    get_nth_matrix(0, db_path).  NOTE(review): that holds only if the writer
    keeps b"current" in sync with the head of b"versions"; confirm.

    Returns None when any key along that lookup chain is missing.
    """
    # closing() releases the environment handle once the data is decoded.
    with closing(lmdb.open(db_path, readonly=True)) as env:
        with env.begin(write=False) as txn:
            current_hash = txn.get(b"current")
            if not current_hash:
                return None
            matrix_hash = txn.get(current_hash + b":matrix")
            if not matrix_hash:
                return None
            row_pointers = txn.get(matrix_hash + b":row-pointers")
            nrows_blob = txn.get(matrix_hash + b":nrows")
            metadata_blob = txn.get(matrix_hash + b":metadata")
            # An incomplete matrix is reported as "no matrix" instead of
            # crashing in struct.unpack / .rstrip on a None value.
            if not (row_pointers and nrows_blob and metadata_blob):
                return None
            # nrows is stored as a little-endian unsigned 64-bit integer.
            (nrows,) = struct.unpack("<Q", nrows_blob)
            # The row-pointer table is a run of BLOB_HASH_DIGEST-byte keys,
            # one per row, each naming a JSON-encoded row blob.
            pointers = [
                row_pointers[i: i + BLOB_HASH_DIGEST]
                for i in range(0, nrows * BLOB_HASH_DIGEST, BLOB_HASH_DIGEST)
            ]
            return Matrix(
                data=[json.loads(txn.get(ptr).decode()) for ptr in pointers],
                # Metadata may carry trailing NUL padding; strip it before
                # JSON decoding.
                metadata=json.loads(metadata_blob.rstrip(b"\x00").decode()),
            )