diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/pgvector/django/functions.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/pgvector/django/functions.py | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/pgvector/django/functions.py b/.venv/lib/python3.12/site-packages/pgvector/django/functions.py new file mode 100644 index 00000000..da9fbf83 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/pgvector/django/functions.py @@ -0,0 +1,55 @@ +from django.db.models import FloatField, Func, Value +from ..utils import Vector, HalfVector, SparseVector + + +class DistanceBase(Func): + output_field = FloatField() + + def __init__(self, expression, vector, **extra): + if not hasattr(vector, 'resolve_expression'): + if isinstance(vector, HalfVector): + vector = Value(HalfVector._to_db(vector)) + elif isinstance(vector, SparseVector): + vector = Value(SparseVector._to_db(vector)) + else: + vector = Value(Vector._to_db(vector)) + super().__init__(expression, vector, **extra) + + +class BitDistanceBase(Func): + output_field = FloatField() + + def __init__(self, expression, vector, **extra): + if not hasattr(vector, 'resolve_expression'): + vector = Value(vector) + super().__init__(expression, vector, **extra) + + +class L2Distance(DistanceBase): + function = '' + arg_joiner = ' <-> ' + + +class MaxInnerProduct(DistanceBase): + function = '' + arg_joiner = ' <#> ' + + +class CosineDistance(DistanceBase): + function = '' + arg_joiner = ' <=> ' + + +class L1Distance(DistanceBase): + function = '' + arg_joiner = ' <+> ' + + +class HammingDistance(BitDistanceBase): + function = '' + arg_joiner = ' <~> ' + + +class JaccardDistance(BitDistanceBase): + function = '' + arg_joiner = ' <%%> ' |