aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/pgvector/psycopg2/register.py
blob: 7752852090cf8e5f771748f1eb76971371ff7015 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import psycopg2
from psycopg2.extensions import cursor
from .halfvec import register_halfvec_info
from .sparsevec import register_sparsevec_info
from .vector import register_vector_info


# TODO make globally False by default in 0.4.0
# note: register_adapter is always global
# TODO make arrays True by default in 0.4.0
def register_vector(conn_or_curs=None, globally=True, arrays=False):
    """Register pgvector's types for a psycopg2 connection or cursor.

    Looks up the OIDs of ``vector`` (required) and, when present in the
    database, ``halfvec`` and ``sparsevec``. It then hands each OID (plus the
    array-type OID when ``arrays`` is true) to the matching
    ``register_*_info`` helper.

    :param conn_or_curs: a psycopg2 connection (anything with ``.cursor``)
        or a cursor (anything with ``.connection``). Although it defaults to
        ``None``, a value is required: ``None`` fails with ``AttributeError``.
    :param globally: when true, ``None`` is passed to the helpers as the
        scope; when false, ``conn_or_curs`` itself is passed as the scope.
        Adapters registered via ``register_adapter`` are always global
        regardless of this flag.
    :param arrays: when true, also pass the OIDs of the array types
        (``_vector``, ``_halfvec``, ``_sparsevec``); otherwise ``None`` is
        passed in their place.
    :raises psycopg2.ProgrammingError: if no ``vector`` type is found.
    """
    # Accept either a connection or a cursor; a cursor exposes its
    # connection via ``.connection``.
    conn = conn_or_curs if hasattr(conn_or_curs, 'cursor') else conn_or_curs.connection
    scope = None if globally else conn_or_curs

    # Force the plain base cursor class, presumably so a custom
    # connection-level cursor_factory cannot change the (typname, oid) row
    # shape that dict(fetchall()) relies on. The ``with`` block closes this
    # short-lived lookup cursor instead of leaving it open.
    with conn.cursor(cursor_factory=cursor) as cur:
        # to_regtype resolves each name through the search path (first
        # match wins). Names it cannot resolve are simply absent from
        # type_info, which the membership checks below rely on.
        cur.execute(
            "SELECT typname, oid FROM pg_type WHERE oid IN ("
            "to_regtype('vector'), to_regtype('_vector'), "
            "to_regtype('halfvec'), to_regtype('_halfvec'), "
            "to_regtype('sparsevec'), to_regtype('_sparsevec'))"
        )
        type_info = dict(cur.fetchall())

    if 'vector' not in type_info:
        raise psycopg2.ProgrammingError('vector type not found in the database')

    register_vector_info(type_info['vector'], type_info['_vector'] if arrays else None, scope)

    # halfvec and sparsevec are optional; register them only if found.
    if 'halfvec' in type_info:
        register_halfvec_info(type_info['halfvec'], type_info['_halfvec'] if arrays else None, scope)

    if 'sparsevec' in type_info:
        register_sparsevec_info(type_info['sparsevec'], type_info['_sparsevec'] if arrays else None, scope)