1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
|
from sqlalchemy.dialects.postgresql.base import ischema_names
from sqlalchemy.types import UserDefinedType, Float, String
from ..utils import Vector
class VECTOR(UserDefinedType):
    """SQLAlchemy column type for the PostgreSQL ``vector`` type.

    The DDL type is rendered as ``VECTOR`` when no dimension is given and as
    ``VECTOR(<dim>)`` otherwise. Python values are converted for the database
    with ``Vector._to_db`` and converted back from result rows with
    ``Vector._from_db`` (both defined in ``..utils``).

    :param dim: optional number of dimensions. It is rendered into the column
        spec with ``%d`` and is also forwarded to ``Vector._to_db`` whenever a
        value is bound or rendered as a literal.
    """

    # Opt in to SQLAlchemy's compiled-statement caching for this type.
    cache_ok = True
    # A plain String type instance, used only to borrow the dialect's
    # string-literal quoting in literal_processor().
    _string = String()

    def __init__(self, dim=None):
        # Zero-argument super() starts the MRO lookup right after VECTOR, so
        # UserDefinedType (and every class after it) is not skipped.
        super().__init__()
        self.dim = dim

    def get_col_spec(self, **kw):
        """Return the DDL type string: ``VECTOR`` or ``VECTOR(<dim>)``."""
        if self.dim is None:
            return 'VECTOR'
        return 'VECTOR(%d)' % self.dim

    def bind_processor(self, dialect):
        """Return a callable converting a Python value into a bound parameter.

        The conversion is delegated to ``Vector._to_db(value, self.dim)``.
        """
        def process(value):
            # NOTE(review): presumably _to_db passes None through and checks
            # the value's length against dim -- confirm in ..utils.
            return Vector._to_db(value, self.dim)
        return process

    def literal_processor(self, dialect):
        """Return a callable rendering a value as an inline SQL literal.

        The value is first converted with ``Vector._to_db`` and then quoted
        using the dialect's string-literal processor.
        """
        # _cached_literal_processor is a private SQLAlchemy helper; it yields
        # the dialect-specific string quoting function for String.
        string_literal_processor = self._string._cached_literal_processor(dialect)

        def process(value):
            return string_literal_processor(Vector._to_db(value, self.dim))
        return process

    def result_processor(self, dialect, coltype):
        """Return a callable converting a fetched value via ``Vector._from_db``."""
        def process(value):
            return Vector._from_db(value)
        return process

    class comparator_factory(UserDefinedType.Comparator):
        """Distance operators available on ``VECTOR`` column expressions.

        Each method builds a binary SQL expression ``<column> <op> other``
        whose result type is ``Float``.
        """

        def l2_distance(self, other):
            """Return the expression ``self <-> other`` (L2 distance)."""
            return self.op('<->', return_type=Float)(other)

        def max_inner_product(self, other):
            """Return the expression ``self <#> other``.

            NOTE(review): pgvector documents ``<#>`` as the *negative* inner
            product, so ascending order means largest inner product first --
            confirm against the installed pgvector version.
            """
            return self.op('<#>', return_type=Float)(other)

        def cosine_distance(self, other):
            """Return the expression ``self <=> other`` (cosine distance)."""
            return self.op('<=>', return_type=Float)(other)

        def l1_distance(self, other):
            """Return the expression ``self <+> other`` (L1 distance)."""
            return self.op('<+>', return_type=Float)(other)
# Register VECTOR in the PostgreSQL dialect's reflection type map, so that a
# reflected column whose type name is 'vector' is mapped to this class.
ischema_names.update({'vector': VECTOR})
|