# Peewee field type for ``vector`` columns.
from peewee import Expression, Field
from ..utils import Vector


class VectorField(Field):
    """Peewee field whose column type is ``vector``.

    Values are converted on the way to and from the database by
    ``Vector._to_db`` and ``Vector._from_db``.  The distance helpers
    return peewee ``Expression`` objects that pair this field with a
    given vector under a specific SQL operator.
    """

    field_type = 'vector'

    def __init__(self, dimensions=None, *args, **kwargs):
        """Create the field.

        Args:
            dimensions: Optional type modifier for the column; when
                falsy, no modifier is emitted (see ``get_modifiers``).
            *args: Forwarded unchanged to ``Field.__init__``.
            **kwargs: Forwarded unchanged to ``Field.__init__``.
        """
        # Stored before delegating so the attribute exists by the time
        # the base initializer runs.
        self.dimensions = dimensions
        super().__init__(*args, **kwargs)

    def get_modifiers(self):
        """Return ``[dimensions]`` when set to a truthy value, else ``None``."""
        if not self.dimensions:
            return None
        return [self.dimensions]

    def db_value(self, value):
        """Convert a Python value for storage via ``Vector._to_db``."""
        return Vector._to_db(value)

    def python_value(self, value):
        """Convert a database value to Python via ``Vector._from_db``."""
        return Vector._from_db(value)

    def _distance(self, op, vector):
        """Build ``<this field> <op> <vector>`` as a peewee ``Expression``.

        Args:
            op: SQL operator string placed between the two operands.
            vector: Right-hand operand; passed through ``self.to_value``
                before being placed in the expression.
        """
        rhs = self.to_value(vector)
        return Expression(lhs=self, op=op, rhs=rhs)

    def l2_distance(self, vector):
        """Return an expression using the ``<->`` operator against *vector*."""
        return self._distance('<->', vector)

    def max_inner_product(self, vector):
        """Return an expression using the ``<#>`` operator against *vector*.

        NOTE(review): in pgvector ``<#>`` is documented as the *negative*
        inner product -- confirm callers order results accordingly.
        """
        return self._distance('<#>', vector)

    def cosine_distance(self, vector):
        """Return an expression using the ``<=>`` operator against *vector*."""
        return self._distance('<=>', vector)

    def l1_distance(self, vector):
        """Return an expression using the ``<+>`` operator against *vector*."""
        return self._distance('<+>', vector)