"""SQLAlchemy column type for the PostgreSQL ``sparsevec`` type."""
from sqlalchemy.dialects.postgresql.base import ischema_names
from sqlalchemy.types import UserDefinedType, Float, String
from ..utils import SparseVector


class SPARSEVEC(UserDefinedType):
    """SQLAlchemy column type mapped to the PostgreSQL ``sparsevec`` type.

    Outgoing values are serialized with ``SparseVector._to_db`` and incoming
    values are deserialized with ``SparseVector._from_db``.  The optional
    ``dim`` is rendered into the column DDL and forwarded to ``_to_db``.
    """

    cache_ok = True
    # Plain String type whose literal processor is reused to quote the
    # serialized value when a sparse vector is rendered as a SQL literal.
    _string = String()

    def __init__(self, dim=None):
        super(UserDefinedType, self).__init__()
        self.dim = dim

    def get_col_spec(self, **kw):
        """Return the DDL type name, with the dimension appended when set."""
        return 'SPARSEVEC' if self.dim is None else 'SPARSEVEC(%d)' % self.dim

    def bind_processor(self, dialect):
        """Return a callable that serializes a Python value for the driver."""
        return lambda value: SparseVector._to_db(value, self.dim)

    def literal_processor(self, dialect):
        """Return a callable that renders a value as a quoted SQL literal."""
        quote_string = self._string._cached_literal_processor(dialect)
        return lambda value: quote_string(SparseVector._to_db(value, self.dim))

    def result_processor(self, dialect, coltype):
        """Return a callable that converts a fetched DB value to Python."""
        return lambda value: SparseVector._from_db(value)

    class comparator_factory(UserDefinedType.Comparator):
        """Column-level distance operators; every result is typed as Float."""

        def _distance(self, operator, other):
            # Build a binary expression `<column> <operator> <other>` whose
            # SQLAlchemy result type is Float.
            return self.op(operator, return_type=Float)(other)

        def l2_distance(self, other):
            return self._distance('<->', other)

        def max_inner_product(self, other):
            return self._distance('<#>', other)

        def cosine_distance(self, other):
            return self._distance('<=>', other)

        def l1_distance(self, other):
            return self._distance('<+>', other)


# Register the type in the PostgreSQL dialect's reflection lookup table so that
# columns whose type name is reported as 'sparsevec' reflect as SPARSEVEC.
ischema_names.update({'sparsevec': SPARSEVEC})