about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/pgvector/sqlalchemy/halfvec.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/pgvector/sqlalchemy/halfvec.py')
-rw-r--r--.venv/lib/python3.12/site-packages/pgvector/sqlalchemy/halfvec.py51
1 files changed, 51 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/pgvector/sqlalchemy/halfvec.py b/.venv/lib/python3.12/site-packages/pgvector/sqlalchemy/halfvec.py
new file mode 100644
index 00000000..639f77bd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/pgvector/sqlalchemy/halfvec.py
@@ -0,0 +1,51 @@
+from sqlalchemy.dialects.postgresql.base import ischema_names
+from sqlalchemy.types import UserDefinedType, Float, String
+from ..utils import HalfVector
+
+
+class HALFVEC(UserDefinedType):
+    cache_ok = True
+    _string = String()
+
+    def __init__(self, dim=None):
+        super(UserDefinedType, self).__init__()
+        self.dim = dim
+
+    def get_col_spec(self, **kw):
+        if self.dim is None:
+            return 'HALFVEC'
+        return 'HALFVEC(%d)' % self.dim
+
+    def bind_processor(self, dialect):
+        def process(value):
+            return HalfVector._to_db(value, self.dim)
+        return process
+
+    def literal_processor(self, dialect):
+        string_literal_processor = self._string._cached_literal_processor(dialect)
+
+        def process(value):
+            return string_literal_processor(HalfVector._to_db(value, self.dim))
+        return process
+
+    def result_processor(self, dialect, coltype):
+        def process(value):
+            return HalfVector._from_db(value)
+        return process
+
+    class comparator_factory(UserDefinedType.Comparator):
+        def l2_distance(self, other):
+            return self.op('<->', return_type=Float)(other)
+
+        def max_inner_product(self, other):
+            return self.op('<#>', return_type=Float)(other)
+
+        def cosine_distance(self, other):
+            return self.op('<=>', return_type=Float)(other)
+
+        def l1_distance(self, other):
+            return self.op('<+>', return_type=Float)(other)
+
+
+# for reflection
+ischema_names['halfvec'] = HALFVEC