about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/sqlalchemy/testing/provision.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/sqlalchemy/testing/provision.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/sqlalchemy/testing/provision.py')
-rw-r--r--.venv/lib/python3.12/site-packages/sqlalchemy/testing/provision.py502
1 files changed, 502 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/sqlalchemy/testing/provision.py b/.venv/lib/python3.12/site-packages/sqlalchemy/testing/provision.py
new file mode 100644
index 00000000..3afcf119
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/sqlalchemy/testing/provision.py
@@ -0,0 +1,502 @@
+# testing/provision.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+from __future__ import annotations
+
+import collections
+import logging
+
+from . import config
+from . import engines
+from . import util
+from .. import exc
+from .. import inspect
+from ..engine import url as sa_url
+from ..sql import ddl
+from ..sql import schema
+
+
+log = logging.getLogger(__name__)
+
+FOLLOWER_IDENT = None
+
+
+class register:
+    def __init__(self, decorator=None):
+        self.fns = {}
+        self.decorator = decorator
+
+    @classmethod
+    def init(cls, fn):
+        return register().for_db("*")(fn)
+
+    @classmethod
+    def init_decorator(cls, decorator):
+        return register(decorator).for_db("*")
+
+    def for_db(self, *dbnames):
+        def decorate(fn):
+            if self.decorator:
+                fn = self.decorator(fn)
+            for dbname in dbnames:
+                self.fns[dbname] = fn
+            return self
+
+        return decorate
+
+    def __call__(self, cfg, *arg, **kw):
+        if isinstance(cfg, str):
+            url = sa_url.make_url(cfg)
+        elif isinstance(cfg, sa_url.URL):
+            url = cfg
+        else:
+            url = cfg.db.url
+        backend = url.get_backend_name()
+        if backend in self.fns:
+            return self.fns[backend](cfg, *arg, **kw)
+        else:
+            return self.fns["*"](cfg, *arg, **kw)
+
+
+def create_follower_db(follower_ident):
+    for cfg in _configs_for_db_operation():
+        log.info("CREATE database %s, URI %r", follower_ident, cfg.db.url)
+        create_db(cfg, cfg.db, follower_ident)
+
+
+def setup_config(db_url, options, file_config, follower_ident):
+    # load the dialect, which should also have it set up its provision
+    # hooks
+
+    dialect = sa_url.make_url(db_url).get_dialect()
+
+    dialect.load_provisioning()
+
+    if follower_ident:
+        db_url = follower_url_from_main(db_url, follower_ident)
+    db_opts = {}
+    update_db_opts(db_url, db_opts, options)
+    db_opts["scope"] = "global"
+    eng = engines.testing_engine(db_url, db_opts)
+    post_configure_engine(db_url, eng, follower_ident)
+    eng.connect().close()
+
+    cfg = config.Config.register(eng, db_opts, options, file_config)
+
+    # a symbolic name that tests can use if they need to disambiguate
+    # names across databases
+    if follower_ident:
+        config.ident = follower_ident
+
+    if follower_ident:
+        configure_follower(cfg, follower_ident)
+    return cfg
+
+
+def drop_follower_db(follower_ident):
+    for cfg in _configs_for_db_operation():
+        log.info("DROP database %s, URI %r", follower_ident, cfg.db.url)
+        drop_db(cfg, cfg.db, follower_ident)
+
+
+def generate_db_urls(db_urls, extra_drivers):
+    """Generate a set of URLs to test given configured URLs plus additional
+    driver names.
+
+    Given:
+
+    .. sourcecode:: text
+
+        --dburi postgresql://db1  \
+        --dburi postgresql://db2  \
+        --dburi postgresql://db2  \
+        --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true
+
+    Noting that the default postgresql driver is psycopg2,  the output
+    would be:
+
+    .. sourcecode:: text
+
+        postgresql+psycopg2://db1
+        postgresql+asyncpg://db1
+        postgresql+psycopg2://db2
+        postgresql+psycopg2://db3
+
+    That is, for the driver in a --dburi, we want to keep that and use that
+    driver for each URL it's part of .   For a driver that is only
+    in --dbdrivers, we want to use it just once for one of the URLs.
+    for a driver that is both coming from --dburi as well as --dbdrivers,
+    we want to keep it in that dburi.
+
+    Driver specific query options can be specified by added them to the
+    driver name. For example, to enable the async fallback option for
+    asyncpg::
+
+    .. sourcecode:: text
+
+        --dburi postgresql://db1  \
+        --dbdriver=asyncpg?async_fallback=true
+
+    """
+    urls = set()
+
+    backend_to_driver_we_already_have = collections.defaultdict(set)
+
+    urls_plus_dialects = [
+        (url_obj, url_obj.get_dialect())
+        for url_obj in [sa_url.make_url(db_url) for db_url in db_urls]
+    ]
+
+    for url_obj, dialect in urls_plus_dialects:
+        # use get_driver_name instead of dialect.driver to account for
+        # "_async" virtual drivers like oracledb and psycopg
+        driver_name = url_obj.get_driver_name()
+        backend_to_driver_we_already_have[dialect.name].add(driver_name)
+
+    backend_to_driver_we_need = {}
+
+    for url_obj, dialect in urls_plus_dialects:
+        backend = dialect.name
+        dialect.load_provisioning()
+
+        if backend not in backend_to_driver_we_need:
+            backend_to_driver_we_need[backend] = extra_per_backend = set(
+                extra_drivers
+            ).difference(backend_to_driver_we_already_have[backend])
+        else:
+            extra_per_backend = backend_to_driver_we_need[backend]
+
+        for driver_url in _generate_driver_urls(url_obj, extra_per_backend):
+            if driver_url in urls:
+                continue
+            urls.add(driver_url)
+            yield driver_url
+
+
+def _generate_driver_urls(url, extra_drivers):
+    main_driver = url.get_driver_name()
+    extra_drivers.discard(main_driver)
+
+    url = generate_driver_url(url, main_driver, "")
+    yield url
+
+    for drv in list(extra_drivers):
+        if "?" in drv:
+            driver_only, query_str = drv.split("?", 1)
+
+        else:
+            driver_only = drv
+            query_str = None
+
+        new_url = generate_driver_url(url, driver_only, query_str)
+        if new_url:
+            extra_drivers.remove(drv)
+
+            yield new_url
+
+
+@register.init
+def generate_driver_url(url, driver, query_str):
+    backend = url.get_backend_name()
+
+    new_url = url.set(
+        drivername="%s+%s" % (backend, driver),
+    )
+    if query_str:
+        new_url = new_url.update_query_string(query_str)
+
+    try:
+        new_url.get_dialect()
+    except exc.NoSuchModuleError:
+        return None
+    else:
+        return new_url
+
+
+def _configs_for_db_operation():
+    hosts = set()
+
+    for cfg in config.Config.all_configs():
+        cfg.db.dispose()
+
+    for cfg in config.Config.all_configs():
+        url = cfg.db.url
+        backend = url.get_backend_name()
+        host_conf = (backend, url.username, url.host, url.database)
+
+        if host_conf not in hosts:
+            yield cfg
+            hosts.add(host_conf)
+
+    for cfg in config.Config.all_configs():
+        cfg.db.dispose()
+
+
+@register.init
+def drop_all_schema_objects_pre_tables(cfg, eng):
+    pass
+
+
+@register.init
+def drop_all_schema_objects_post_tables(cfg, eng):
+    pass
+
+
+def drop_all_schema_objects(cfg, eng):
+    drop_all_schema_objects_pre_tables(cfg, eng)
+
+    drop_views(cfg, eng)
+
+    if config.requirements.materialized_views.enabled:
+        drop_materialized_views(cfg, eng)
+
+    inspector = inspect(eng)
+
+    consider_schemas = (None,)
+    if config.requirements.schemas.enabled_for_config(cfg):
+        consider_schemas += (cfg.test_schema, cfg.test_schema_2)
+    util.drop_all_tables(eng, inspector, consider_schemas=consider_schemas)
+
+    drop_all_schema_objects_post_tables(cfg, eng)
+
+    if config.requirements.sequences.enabled_for_config(cfg):
+        with eng.begin() as conn:
+            for seq in inspector.get_sequence_names():
+                conn.execute(ddl.DropSequence(schema.Sequence(seq)))
+            if config.requirements.schemas.enabled_for_config(cfg):
+                for schema_name in [cfg.test_schema, cfg.test_schema_2]:
+                    for seq in inspector.get_sequence_names(
+                        schema=schema_name
+                    ):
+                        conn.execute(
+                            ddl.DropSequence(
+                                schema.Sequence(seq, schema=schema_name)
+                            )
+                        )
+
+
+def drop_views(cfg, eng):
+    inspector = inspect(eng)
+
+    try:
+        view_names = inspector.get_view_names()
+    except NotImplementedError:
+        pass
+    else:
+        with eng.begin() as conn:
+            for vname in view_names:
+                conn.execute(
+                    ddl._DropView(schema.Table(vname, schema.MetaData()))
+                )
+
+    if config.requirements.schemas.enabled_for_config(cfg):
+        try:
+            view_names = inspector.get_view_names(schema=cfg.test_schema)
+        except NotImplementedError:
+            pass
+        else:
+            with eng.begin() as conn:
+                for vname in view_names:
+                    conn.execute(
+                        ddl._DropView(
+                            schema.Table(
+                                vname,
+                                schema.MetaData(),
+                                schema=cfg.test_schema,
+                            )
+                        )
+                    )
+
+
+def drop_materialized_views(cfg, eng):
+    inspector = inspect(eng)
+
+    mview_names = inspector.get_materialized_view_names()
+
+    with eng.begin() as conn:
+        for vname in mview_names:
+            conn.exec_driver_sql(f"DROP MATERIALIZED VIEW {vname}")
+
+    if config.requirements.schemas.enabled_for_config(cfg):
+        mview_names = inspector.get_materialized_view_names(
+            schema=cfg.test_schema
+        )
+        with eng.begin() as conn:
+            for vname in mview_names:
+                conn.exec_driver_sql(
+                    f"DROP MATERIALIZED VIEW {cfg.test_schema}.{vname}"
+                )
+
+
+@register.init
+def create_db(cfg, eng, ident):
+    """Dynamically create a database for testing.
+
+    Used when a test run will employ multiple processes, e.g., when run
+    via `tox` or `pytest -n4`.
+    """
+    raise NotImplementedError(
+        "no DB creation routine for cfg: %s" % (eng.url,)
+    )
+
+
+@register.init
+def drop_db(cfg, eng, ident):
+    """Drop a database that we dynamically created for testing."""
+    raise NotImplementedError("no DB drop routine for cfg: %s" % (eng.url,))
+
+
+def _adapt_update_db_opts(fn):
+    insp = util.inspect_getfullargspec(fn)
+    if len(insp.args) == 3:
+        return fn
+    else:
+        return lambda db_url, db_opts, _options: fn(db_url, db_opts)
+
+
+@register.init_decorator(_adapt_update_db_opts)
+def update_db_opts(db_url, db_opts, options):
+    """Set database options (db_opts) for a test database that we created."""
+
+
+@register.init
+def post_configure_engine(url, engine, follower_ident):
+    """Perform extra steps after configuring an engine for testing.
+
+    (For the internal dialects, currently only used by sqlite, oracle, mssql)
+    """
+
+
+@register.init
+def follower_url_from_main(url, ident):
+    """Create a connection URL for a dynamically-created test database.
+
+    :param url: the connection URL specified when the test run was invoked
+    :param ident: the pytest-xdist "worker identifier" to be used as the
+                  database name
+    """
+    url = sa_url.make_url(url)
+    return url.set(database=ident)
+
+
+@register.init
+def configure_follower(cfg, ident):
+    """Create dialect-specific config settings for a follower database."""
+    pass
+
+
+@register.init
+def run_reap_dbs(url, ident):
+    """Remove databases that were created during the test process, after the
+    process has ended.
+
+    This is an optional step that is invoked for certain backends that do not
+    reliably release locks on the database as long as a process is still in
+    use. For the internal dialects, this is currently only necessary for
+    mssql and oracle.
+    """
+
+
+def reap_dbs(idents_file):
+    log.info("Reaping databases...")
+
+    urls = collections.defaultdict(set)
+    idents = collections.defaultdict(set)
+    dialects = {}
+
+    with open(idents_file) as file_:
+        for line in file_:
+            line = line.strip()
+            db_name, db_url = line.split(" ")
+            url_obj = sa_url.make_url(db_url)
+            if db_name not in dialects:
+                dialects[db_name] = url_obj.get_dialect()
+                dialects[db_name].load_provisioning()
+            url_key = (url_obj.get_backend_name(), url_obj.host)
+            urls[url_key].add(db_url)
+            idents[url_key].add(db_name)
+
+    for url_key in urls:
+        url = list(urls[url_key])[0]
+        ident = idents[url_key]
+        run_reap_dbs(url, ident)
+
+
+@register.init
+def temp_table_keyword_args(cfg, eng):
+    """Specify keyword arguments for creating a temporary Table.
+
+    Dialect-specific implementations of this method will return the
+    kwargs that are passed to the Table method when creating a temporary
+    table for testing, e.g., in the define_temp_tables method of the
+    ComponentReflectionTest class in suite/test_reflection.py
+    """
+    raise NotImplementedError(
+        "no temp table keyword args routine for cfg: %s" % (eng.url,)
+    )
+
+
+@register.init
+def prepare_for_drop_tables(config, connection):
+    pass
+
+
+@register.init
+def stop_test_class_outside_fixtures(config, db, testcls):
+    pass
+
+
+@register.init
+def get_temp_table_name(cfg, eng, base_name):
+    """Specify table name for creating a temporary Table.
+
+    Dialect-specific implementations of this method will return the
+    name to use when creating a temporary table for testing,
+    e.g., in the define_temp_tables method of the
+    ComponentReflectionTest class in suite/test_reflection.py
+
+    Default to just the base name since that's what most dialects will
+    use. The mssql dialect's implementation will need a "#" prepended.
+    """
+    return base_name
+
+
+@register.init
+def set_default_schema_on_connection(cfg, dbapi_connection, schema_name):
+    raise NotImplementedError(
+        "backend does not implement a schema name set function: %s"
+        % (cfg.db.url,)
+    )
+
+
+@register.init
+def upsert(
+    cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False
+):
+    """return the backends insert..on conflict / on dupe etc. construct.
+
+    while we should add a backend-neutral upsert construct as well, such as
+    insert().upsert(), it's important that we continue to test the
+    backend-specific insert() constructs since if we do implement
+    insert().upsert(), that would be using a different codepath for the things
+    we need to test like insertmanyvalues, etc.
+
+    """
+    raise NotImplementedError(
+        f"backend does not include an upsert implementation: {cfg.db.url}"
+    )
+
+
+@register.init
+def normalize_sequence(cfg, sequence):
+    """Normalize sequence parameters for dialect that don't start with 1
+    by default.
+
+    The default implementation does nothing
+    """
+    return sequence