about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/sqlalchemy/testing/fixtures/mypy.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/sqlalchemy/testing/fixtures/mypy.py')
-rw-r--r--.venv/lib/python3.12/site-packages/sqlalchemy/testing/fixtures/mypy.py312
1 files changed, 312 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/sqlalchemy/testing/fixtures/mypy.py b/.venv/lib/python3.12/site-packages/sqlalchemy/testing/fixtures/mypy.py
new file mode 100644
index 00000000..0832d892
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/sqlalchemy/testing/fixtures/mypy.py
@@ -0,0 +1,312 @@
+# testing/fixtures/mypy.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+from __future__ import annotations
+
+import inspect
+import os
+from pathlib import Path
+import re
+import shutil
+import sys
+import tempfile
+
+from .base import TestBase
+from .. import config
+from ..assertions import eq_
+from ... import util
+
+
+@config.add_to_marker.mypy
+class MypyTest(TestBase):
+    __requires__ = ("no_sqlalchemy2_stubs",)
+
+    @config.fixture(scope="function")
+    def per_func_cachedir(self):
+        yield from self._cachedir()
+
+    @config.fixture(scope="class")
+    def cachedir(self):
+        yield from self._cachedir()
+
+    def _cachedir(self):
+        # as of mypy 0.971 i think we need to keep mypy_path empty
+        mypy_path = ""
+
+        with tempfile.TemporaryDirectory() as cachedir:
+            with open(
+                Path(cachedir) / "sqla_mypy_config.cfg", "w"
+            ) as config_file:
+                config_file.write(
+                    f"""
+                    [mypy]\n
+                    plugins = sqlalchemy.ext.mypy.plugin\n
+                    show_error_codes = True\n
+                    {mypy_path}
+                    disable_error_code = no-untyped-call
+
+                    [mypy-sqlalchemy.*]
+                    ignore_errors = True
+
+                    """
+                )
+            with open(
+                Path(cachedir) / "plain_mypy_config.cfg", "w"
+            ) as config_file:
+                config_file.write(
+                    f"""
+                    [mypy]\n
+                    show_error_codes = True\n
+                    {mypy_path}
+                    disable_error_code = var-annotated,no-untyped-call
+                    [mypy-sqlalchemy.*]
+                    ignore_errors = True
+
+                    """
+                )
+            yield cachedir
+
+    @config.fixture()
+    def mypy_runner(self, cachedir):
+        from mypy import api
+
+        def run(path, use_plugin=False, use_cachedir=None):
+            if use_cachedir is None:
+                use_cachedir = cachedir
+            args = [
+                "--strict",
+                "--raise-exceptions",
+                "--cache-dir",
+                use_cachedir,
+                "--config-file",
+                os.path.join(
+                    use_cachedir,
+                    (
+                        "sqla_mypy_config.cfg"
+                        if use_plugin
+                        else "plain_mypy_config.cfg"
+                    ),
+                ),
+            ]
+
+            # mypy as of 0.990 is more aggressively blocking messaging
+            # for paths that are in sys.path, and as pytest puts currdir,
+            # test/ etc in sys.path, just copy the source file to the
+            # tempdir we are working in so that we don't have to try to
+            # manipulate sys.path and/or guess what mypy is doing
+            filename = os.path.basename(path)
+            test_program = os.path.join(use_cachedir, filename)
+            if path != test_program:
+                shutil.copyfile(path, test_program)
+            args.append(test_program)
+
+            # I set this locally but for the suite here needs to be
+            # disabled
+            os.environ.pop("MYPY_FORCE_COLOR", None)
+
+            stdout, stderr, exitcode = api.run(args)
+            return stdout, stderr, exitcode
+
+        return run
+
+    @config.fixture
+    def mypy_typecheck_file(self, mypy_runner):
+        def run(path, use_plugin=False):
+            expected_messages = self._collect_messages(path)
+            stdout, stderr, exitcode = mypy_runner(path, use_plugin=use_plugin)
+            self._check_output(
+                path, expected_messages, stdout, stderr, exitcode
+            )
+
+        return run
+
+    @staticmethod
+    def file_combinations(dirname):
+        if os.path.isabs(dirname):
+            path = dirname
+        else:
+            caller_path = inspect.stack()[1].filename
+            path = os.path.join(os.path.dirname(caller_path), dirname)
+        files = list(Path(path).glob("**/*.py"))
+
+        for extra_dir in config.options.mypy_extra_test_paths:
+            if extra_dir and os.path.isdir(extra_dir):
+                files.extend((Path(extra_dir) / dirname).glob("**/*.py"))
+        return files
+
+    def _collect_messages(self, path):
+        from sqlalchemy.ext.mypy.util import mypy_14
+
+        expected_messages = []
+        expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)")
+        py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)")
+        with open(path) as file_:
+            current_assert_messages = []
+            for num, line in enumerate(file_, 1):
+                m = py_ver_re.match(line)
+                if m:
+                    major, _, minor = m.group(1).partition(".")
+                    if sys.version_info < (int(major), int(minor)):
+                        config.skip_test(
+                            "Requires python >= %s" % (m.group(1))
+                        )
+                    continue
+
+                m = expected_re.match(line)
+                if m:
+                    is_mypy = bool(m.group(1))
+                    is_re = bool(m.group(2))
+                    is_type = bool(m.group(3))
+
+                    expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4))
+                    if is_type:
+                        if not is_re:
+                            # the goal here is that we can cut-and-paste
+                            # from vscode -> pylance into the
+                            # EXPECTED_TYPE: line, then the test suite will
+                            # validate that line against what mypy produces
+                            expected_msg = re.sub(
+                                r"([\[\]])",
+                                lambda m: rf"\{m.group(0)}",
+                                expected_msg,
+                            )
+
+                            # note making sure preceding text matches
+                            # with a dot, so that an expect for "Select"
+                            # does not match "TypedSelect"
+                            expected_msg = re.sub(
+                                r"([\w_]+)",
+                                lambda m: rf"(?:.*\.)?{m.group(1)}\*?",
+                                expected_msg,
+                            )
+
+                            expected_msg = re.sub(
+                                "List", "builtins.list", expected_msg
+                            )
+
+                            expected_msg = re.sub(
+                                r"\b(int|str|float|bool)\b",
+                                lambda m: rf"builtins.{m.group(0)}\*?",
+                                expected_msg,
+                            )
+                            # expected_msg = re.sub(
+                            #     r"(Sequence|Tuple|List|Union)",
+                            #     lambda m: fr"typing.{m.group(0)}\*?",
+                            #     expected_msg,
+                            # )
+
+                        is_mypy = is_re = True
+                        expected_msg = f'Revealed type is "{expected_msg}"'
+
+                    if mypy_14 and util.py39:
+                        # use_lowercase_names, py39 and above
+                        # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L363  # noqa: E501
+
+                        # skip first character which could be capitalized
+                        # "List item x not found" type of message
+                        expected_msg = expected_msg[0] + re.sub(
+                            (
+                                r"\b(List|Tuple|Dict|Set)\b"
+                                if is_type
+                                else r"\b(List|Tuple|Dict|Set|Type)\b"
+                            ),
+                            lambda m: m.group(1).lower(),
+                            expected_msg[1:],
+                        )
+
+                    if mypy_14 and util.py310:
+                        # use_or_syntax, py310 and above
+                        # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368  # noqa: E501
+                        expected_msg = re.sub(
+                            r"Optional\[(.*?)\]",
+                            lambda m: f"{m.group(1)} | None",
+                            expected_msg,
+                        )
+                    current_assert_messages.append(
+                        (is_mypy, is_re, expected_msg.strip())
+                    )
+                elif current_assert_messages:
+                    expected_messages.extend(
+                        (num, is_mypy, is_re, expected_msg)
+                        for (
+                            is_mypy,
+                            is_re,
+                            expected_msg,
+                        ) in current_assert_messages
+                    )
+                    current_assert_messages[:] = []
+
+        return expected_messages
+
+    def _check_output(self, path, expected_messages, stdout, stderr, exitcode):
+        not_located = []
+        filename = os.path.basename(path)
+        if expected_messages:
+            # mypy 0.990 changed how return codes work, so don't assume a
+            # 1 or a 0 return code here, could be either depending on if
+            # errors were generated or not
+
+            output = []
+
+            raw_lines = stdout.split("\n")
+            while raw_lines:
+                e = raw_lines.pop(0)
+                if re.match(r".+\.py:\d+: error: .*", e):
+                    output.append(("error", e))
+                elif re.match(
+                    r".+\.py:\d+: note: +(?:Possible overload|def ).*", e
+                ):
+                    while raw_lines:
+                        ol = raw_lines.pop(0)
+                        if not re.match(r".+\.py:\d+: note: +def \[.*", ol):
+                            break
+                elif re.match(
+                    r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I
+                ):
+                    pass
+                elif re.match(r".+\.py:\d+: note: .*", e):
+                    output.append(("note", e))
+
+            for num, is_mypy, is_re, msg in expected_messages:
+                msg = msg.replace("'", '"')
+                prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else ""
+                for idx, (typ, errmsg) in enumerate(output):
+                    if is_re:
+                        if re.match(
+                            rf".*{filename}\:{num}\: {typ}\: {prefix}{msg}",
+                            errmsg,
+                        ):
+                            break
+                    elif (
+                        f"{filename}:{num}: {typ}: {prefix}{msg}"
+                        in errmsg.replace("'", '"')
+                    ):
+                        break
+                else:
+                    not_located.append(msg)
+                    continue
+                del output[idx]
+
+            if not_located:
+                missing = "\n".join(not_located)
+                print("Couldn't locate expected messages:", missing, sep="\n")
+                if output:
+                    extra = "\n".join(msg for _, msg in output)
+                    print("Remaining messages:", extra, sep="\n")
+                assert False, "expected messages not found, see stdout"
+
+            if output:
+                print(f"{len(output)} messages from mypy were not consumed:")
+                print("\n".join(msg for _, msg in output))
+                assert False, "errors and/or notes remain, see stdout"
+
+        else:
+            if exitcode != 0:
+                print(stdout, stderr, sep="\n")
+
+            eq_(exitcode, 0, msg=stdout)