Skip to content

pytest_adbc_replay.plugin

pytest_adbc_replay.plugin

pytest-adbc-replay plugin: hooks, CLI option, fixture registration.

pytest_addoption

pytest_addoption(parser: Parser) -> None

Register --adbc-record CLI option and ini configuration keys.

Source code in src/pytest_adbc_replay/plugin.py
def pytest_addoption(parser: pytest.Parser) -> None:
    """Declare the plugin's command-line flag and its ini-file settings.

    Adds ``--adbc-record`` to the "adbc-replay" option group (choices come from
    ``_RECORD_MODES``) and registers six ini keys: ``adbc_cassette_dir``,
    ``adbc_record_mode``, ``adbc_dialect``, ``adbc_auto_patch``,
    ``adbc_scrub_keys`` and ``adbc_cassette_differentiator_keys``.

    Args:
        parser: The pytest option parser handed to this hook.
    """
    record_help = (
        "ADBC cassette record mode. "
        "none (default): replay only, fail on miss. "
        "once: record if cassette absent, replay if present. "
        "new_episodes: replay existing, record new. "
        "all: re-record everything."
    )
    dialect_help = (
        "SQL dialect for sqlglot normalisation. Bare value = global fallback "
        "(e.g. 'snowflake'). Per-driver form: 'driver_name: dialect' "
        "(e.g. 'adbc_driver_duckdb: duckdb'). Empty = auto-detect."
    )
    auto_patch_help = (
        "List of ADBC driver module names whose connect() is intercepted "
        "automatically for tests with @pytest.mark.adbc_cassette. "
        "One driver per line (pytest.ini) or a TOML array (pyproject.toml). "
        "Example: adbc_driver_snowflake"
    )
    scrub_keys_help = (
        "Lines of param key names to auto-redact from recorded cassette .json files. "
        "Global form: space-separated key names (e.g. 'token password api_key'). "
        "Per-driver form: 'driver_module_name: key1 key2' (e.g. "
        "'adbc_driver_snowflake: account_id warehouse'). "
        "Both forms can coexist. Matched dict param values become 'REDACTED'."
    )
    differentiator_help = (
        "db_kwargs key names whose values are appended as extra cassette path "
        "segments. Used to disambiguate drivers sharing a single Python module "
        "(e.g. ADBC Foundry drivers via adbc_driver_manager.dbapi). "
        "Space-separated key names (e.g. 'driver'). Default: 'driver'."
    )

    option_group = parser.getgroup("adbc-replay", "ADBC cassette record/replay")
    option_group.addoption(
        "--adbc-record",
        action="store",
        default=None,
        choices=list(_RECORD_MODES),
        help=record_help,
    )

    # One row per ini key, in registration order: (name, help, type, default).
    ini_settings = [
        (
            "adbc_cassette_dir",
            "Directory for ADBC cassette files (default: tests/cassettes).",
            "string",
            "tests/cassettes",
        ),
        (
            "adbc_record_mode",
            "Default ADBC record mode when --adbc-record is not supplied (default: none).",
            "string",
            "none",
        ),
        ("adbc_dialect", dialect_help, "linelist", []),
        ("adbc_auto_patch", auto_patch_help, "linelist", []),
        ("adbc_scrub_keys", scrub_keys_help, "linelist", []),
        ("adbc_cassette_differentiator_keys", differentiator_help, "linelist", ["driver"]),
    ]
    for ini_name, ini_help, ini_type, ini_default in ini_settings:
        parser.addini(ini_name, help=ini_help, type=ini_type, default=ini_default)

pytest_configure

pytest_configure(config: Config) -> None

Register adbc_cassette marker to suppress PytestUnknownMarkWarning.

Source code in src/pytest_adbc_replay/plugin.py
def pytest_configure(config: pytest.Config) -> None:
    """Declare the ``adbc_cassette`` marker so pytest does not warn about it.

    Appends one line to the ``markers`` ini value describing the marker's
    signature and its ``name`` / ``dialect`` arguments.

    Args:
        config: The pytest configuration object handed to this hook.
    """
    marker_description = (
        "adbc_cassette(name, *, dialect=None): "
        "Set cassette name and SQL dialect for this test. "
        "name: cassette directory name (default: derived from node ID). "
        "dialect: sqlglot dialect string for SQL normalisation (e.g. 'snowflake')."
    )
    config.addinivalue_line("markers", marker_description)

pytest_sessionstart

pytest_sessionstart(session: Session) -> None

Monkeypatch ADBC driver connect() for each driver in adbc_auto_patch.

Source code in src/pytest_adbc_replay/plugin.py
def pytest_sessionstart(session: pytest.Session) -> None:
    """Monkeypatch ADBC driver connect() for each driver in adbc_auto_patch.

    For every importable module named in the ``adbc_auto_patch`` ini value, the
    module's ``connect`` attribute is saved in ``_ORIGINAL_CONNECTS`` and
    replaced by a wrapper. The wrapper:

    * forwards the call unchanged (positional and keyword arguments) to the
      real ``connect`` when no test item is current or the current item has no
      ``adbc_cassette`` marker;
    * otherwise routes the call through the session's ``ReplaySession`` and
      records the resulting connection in ``_OPEN_CONNECTIONS`` under
      ``id(item)``.

    Does nothing when ``adbc_auto_patch`` is empty.

    Args:
        session: The pytest session that is starting.
    """
    driver_names: list[str] = cast("list[str]", session.config.getini("adbc_auto_patch"))

    if not driver_names:
        return

    # Initialize the session state eagerly from config so it's available before
    # the adbc_replay fixture is first requested. The adbc_replay fixture will
    # overwrite this with an instance that includes param_serialisers/scrubber.
    _auto_patch_state["session_state"] = _build_session_from_config(session.config)

    def _make_patched(dn: str, orig: Any) -> Any:
        """Build the replacement connect() for driver module ``dn``."""

        def _patched_connect(*args: Any, **kwargs: Any) -> Any:
            with _ITEM_LOCK:
                item = _auto_patch_state["current_item"]

            if item is None:
                # Called outside a test — pass through to real driver untouched,
                # including any positional arguments the caller supplied.
                return orig(*args, **kwargs)

            marker = item.get_closest_marker("adbc_cassette")
            if marker is None:
                # No cassette marker — pass through to real driver untouched.
                return orig(*args, **kwargs)

            if args:
                # Only keyword arguments can be handed on as db_kwargs, so
                # reject positional ones explicitly rather than dropping them.
                raise TypeError(
                    f"{dn}.connect() must be called with keyword arguments only "
                    f"in a test marked with adbc_cassette "
                    f"(got {len(args)} positional argument(s))"
                )

            # Retrieve the session-scoped ReplaySession (always set above)
            session_obj: ReplaySession = _auto_patch_state["session_state"]

            conn = session_obj.wrap_from_item(dn, item, db_kwargs=dict(kwargs), connect_fn=orig)
            with _ITEM_LOCK:
                _OPEN_CONNECTIONS.setdefault(id(item), []).append(conn)
            return conn

        return _patched_connect

    # dict.fromkeys de-duplicates while preserving order: patching the same
    # module twice would wrap the wrapper and store it as the "original".
    for driver_name in dict.fromkeys(driver_names):
        try:
            driver_mod = importlib.import_module(driver_name)
        except ImportError:
            # Driver not installed — skip silently (supports replay-only environments)
            continue

        original_connect = driver_mod.connect
        _ORIGINAL_CONNECTS[driver_name] = original_connect
        setattr(driver_mod, "connect", _make_patched(driver_name, original_connect))  # noqa: B010

pytest_runtest_setup

pytest_runtest_setup(item: Item) -> None

Track the currently-running test item for monkeypatched connect() resolution.

Source code in src/pytest_adbc_replay/plugin.py
def pytest_runtest_setup(item: pytest.Item) -> None:
    """Track the currently-running test item for monkeypatched connect() resolution.

    Stores ``item`` under ``_auto_patch_state["current_item"]`` while holding
    ``_ITEM_LOCK``.

    Args:
        item: The test item whose setup phase is starting.
    """
    with _ITEM_LOCK:
        _auto_patch_state["current_item"] = item

pytest_runtest_teardown

pytest_runtest_teardown(
    item: Item, nextitem: Item | None
) -> None

Clear current item and close all connections opened during this test.

Source code in src/pytest_adbc_replay/plugin.py
def pytest_runtest_teardown(item: pytest.Item, nextitem: pytest.Item | None) -> None:  # noqa: ARG001
    """Forget the current test item and close the connections it opened.

    Args:
        item: The test item being torn down; its ``id()`` keys the entry in
            ``_OPEN_CONNECTIONS``.
        nextitem: Unused; present because the hook signature provides it.
    """
    with _ITEM_LOCK:
        _auto_patch_state["current_item"] = None

    # Closing is best-effort: an error from one connection must not stop the
    # remaining ones from being closed or fail the teardown.
    for tracked_conn in _OPEN_CONNECTIONS.pop(id(item), []):
        with contextlib.suppress(Exception):
            tracked_conn.close()

pytest_report_header

pytest_report_header(config: Config) -> str

Display active record mode in the pytest session header (DX-01).

Source code in src/pytest_adbc_replay/plugin.py
def pytest_report_header(config: pytest.Config) -> str:
    """Return the session-header line naming the active record mode (DX-01).

    The ``--adbc-record`` command-line value is used when given; otherwise the
    ``adbc_record_mode`` ini value, with an empty ini value treated as "none".

    Args:
        config: The pytest configuration object handed to this hook.

    Returns:
        A single header line of the form ``adbc-replay: record mode = <mode>``.
    """
    from_cli = cast("str | None", config.getoption("--adbc-record"))
    from_ini: str = cast("str", config.getini("adbc_record_mode")) or "none"

    active_mode: str = from_ini
    if from_cli is not None:
        active_mode = from_cli
    return f"adbc-replay: record mode = {active_mode}"

adbc_param_serialisers

adbc_param_serialisers() -> (
    dict[Any, dict[str, Any]] | None
)

Session-scoped fixture providing custom parameter serialisers for ADBC replay.

Override this fixture in your conftest.py to register custom serialisers for non-JSON-native parameter types (e.g. numpy arrays, custom date wrappers).

Returns:

Type: dict[Any, dict[str, Any]] | None

Description: A dict mapping Python types to serialiser dicts, or None to use defaults. Each serialiser dict must have "serialise" and "type_tag" keys, and optionally a "deserialise" key.

Example::

import pytest
import numpy as np
from pytest_adbc_replay import NO_DEFAULT_SERIALISERS

@pytest.fixture(scope="session")
def adbc_param_serialisers():
    return {
        np.int64: {
            "type_tag": "numpy.int64",
            "serialise": lambda v: {"value": int(v)},
        },
    }
Source code in src/pytest_adbc_replay/plugin.py
@pytest.fixture(scope="session")
def adbc_param_serialisers() -> dict[Any, dict[str, Any]] | None:
    """
    Session-scoped fixture providing custom parameter serialisers for ADBC replay.

    Override this fixture in your conftest.py to register custom serialisers
    for non-JSON-native parameter types (e.g. numpy arrays, custom date wrappers).
    The ``adbc_replay`` fixture passes the returned value to ``ReplaySession``
    as ``param_serialisers``. This default implementation returns ``None``.

    Returns:
        A dict mapping Python types to serialiser dicts, or None to use defaults.
        Each serialiser dict must have "serialise" and "type_tag" keys, and
        optionally a "deserialise" key.

    Example::

        import pytest
        import numpy as np
        from pytest_adbc_replay import NO_DEFAULT_SERIALISERS

        @pytest.fixture(scope="session")
        def adbc_param_serialisers():
            return {
                np.int64: {
                    "type_tag": "numpy.int64",
                    "serialise": lambda v: {"value": int(v)},
                },
            }
    """
    return None

adbc_scrubber

adbc_scrubber() -> object

Session-scoped fixture providing a scrubbing callback for recorded data.

Override this fixture in your conftest.py to register a callback that scrubs sensitive values before they are written to cassette files.

The callback receives (params, driver_name) where params is the already config-scrubbed parameter dict (after adbc_scrub_keys is applied) and driver_name is the ADBC driver module name string.

If the callback returns None, the config-scrubbed params are used unchanged. If it returns a dict, that dict replaces the params.

Returns:

Type: object

Description: A callable scrub(params, driver_name) -> dict | None, or None to use no fixture-level scrubbing.

Example::

@pytest.fixture(scope="session")
def adbc_scrubber():
    def scrub(params, driver_name):
        if isinstance(params, dict):
            return {k: "REDACTED" if k == "secret" else v
                    for k, v in params.items()}
        return params
    return scrub
Source code in src/pytest_adbc_replay/plugin.py
@pytest.fixture(scope="session")
def adbc_scrubber() -> object:
    """
    Session-scoped fixture providing a scrubbing callback for recorded data.

    Override this fixture in your conftest.py to register a callback that
    scrubs sensitive values before they are written to cassette files.
    The ``adbc_replay`` fixture passes the returned value to ``ReplaySession``
    as ``scrubber``. This default implementation returns ``None``.

    The callback receives ``(params, driver_name)`` where ``params`` is the
    already config-scrubbed parameter dict (after ``adbc_scrub_keys`` is
    applied) and ``driver_name`` is the ADBC driver module name string.

    If the callback returns ``None``, the config-scrubbed params are used
    unchanged. If it returns a dict, that dict replaces the params.

    Returns:
        A callable ``scrub(params, driver_name) -> dict | None``, or ``None``
        to use no fixture-level scrubbing.

    Example::

        @pytest.fixture(scope="session")
        def adbc_scrubber():
            def scrub(params, driver_name):
                if isinstance(params, dict):
                    return {k: "REDACTED" if k == "secret" else v
                            for k, v in params.items()}
                return params
            return scrub
    """
    return None

adbc_replay

adbc_replay(
    request: FixtureRequest,
    adbc_param_serialisers: dict[Any, dict[str, Any]]
    | None,
    adbc_scrubber: object,
) -> ReplaySession

Session-scoped fixture providing ADBC record/replay state.

Returns a ReplaySession whose .wrap() method creates per-test ReplayConnection instances. Call .wrap() from your function-scoped fixture -- it reads the adbc_cassette marker from request.node.

Example::

@pytest.fixture
def my_connection(adbc_replay, request):
    return adbc_replay.wrap(
        "adbc_driver_snowflake",
        db_kwargs={"uri": os.environ["SNOWFLAKE_URI"]},
        request=request,
    )
Source code in src/pytest_adbc_replay/plugin.py
@pytest.fixture(scope="session")
def adbc_replay(
    request: pytest.FixtureRequest,
    adbc_param_serialisers: dict[Any, dict[str, Any]] | None,
    adbc_scrubber: object,
) -> ReplaySession:
    """
    Build the session-wide ``ReplaySession`` from CLI options, ini values and fixtures.

    The returned object's ``.wrap()`` method creates per-test ReplayConnection
    instances. Call ``.wrap()`` from a function-scoped fixture of your own -- it
    reads the ``adbc_cassette`` marker from ``request.node``.

    The record mode is the ``--adbc-record`` value when supplied, otherwise the
    ``adbc_record_mode`` ini value (blank means "none"). The new session is
    also stored in ``_auto_patch_state["session_state"]``.

    Example::

        @pytest.fixture
        def my_connection(adbc_replay, request):
            return adbc_replay.wrap(
                "adbc_driver_snowflake",
                db_kwargs={"uri": os.environ["SNOWFLAKE_URI"]},
                request=request,
            )
    """
    config = request.config

    # Record mode: the command-line flag takes precedence over the ini value.
    from_cli = cast("str | None", config.getoption("--adbc-record"))
    from_ini: str = cast("str", config.getini("adbc_record_mode")) or "none"
    record_mode: str = from_ini
    if from_cli is not None:
        record_mode = from_cli

    cassette_root = Path(cast("str", config.getini("adbc_cassette_dir")) or "tests/cassettes")

    dialect_lines: list[str] = cast("list[str]", config.getini("adbc_dialect")) or []
    global_dialect, driver_dialects = _parse_dialect(dialect_lines)

    scrub_lines: list[str] = cast("list[str]", config.getini("adbc_scrub_keys")) or []
    scrub_global, scrub_by_driver = _parse_scrub_keys(scrub_lines)

    diff_lines: list[str] = (
        cast("list[str]", config.getini("adbc_cassette_differentiator_keys")) or []
    )
    diff_keys = _parse_differentiator_keys(diff_lines)

    session_kwargs: dict[str, Any] = {
        "mode": record_mode,
        "cassette_dir": cassette_root,
        "param_serialisers": adbc_param_serialisers,
        "scrubber": adbc_scrubber,
        "dialect_global": global_dialect,
        "dialect_per_driver": driver_dialects,
        "scrub_keys_global": scrub_global,
        "scrub_keys_per_driver": scrub_by_driver,
        "differentiator_keys_default": diff_keys,
    }
    replay_session = ReplaySession(**session_kwargs)

    # Replace the eagerly-built session_state from pytest_sessionstart with this
    # instance, which additionally carries param_serialisers and scrubber.
    _auto_patch_state["session_state"] = replay_session
    return replay_session

adbc_connect

adbc_connect(
    adbc_replay: ReplaySession, request: FixtureRequest
) -> Generator[Any, None, None]

Function-scoped factory fixture for creating ADBC replay connections explicitly.

Use this as the escape hatch when adbc_auto_patch is not appropriate -- for example, when you need a session-scoped or module-scoped connection, or when you prefer explicit control over connection creation.

Returns a callable: (driver_module_name: str, **db_kwargs) -> ReplayConnection

The fixture closes all opened connections when the test finishes. Cassette paths follow the per-driver subdirectory layout used by auto-patch.

Example::

@pytest.mark.adbc_cassette("my_test")
def test_my_query(adbc_connect):
    conn = adbc_connect("adbc_driver_snowflake.dbapi", uri=os.environ["SF_URI"])
    cursor = conn.cursor()
    cursor.execute("SELECT 1")
Source code in src/pytest_adbc_replay/plugin.py
@pytest.fixture
def adbc_connect(
    adbc_replay: ReplaySession,
    request: pytest.FixtureRequest,
) -> Generator[Any, None, None]:
    """
    Function-scoped fixture yielding a factory for explicit ADBC replay connections.

    Use it when ``adbc_auto_patch`` is not appropriate or when you prefer
    explicit control over connection creation.

    The yielded callable has the shape
    ``(driver_module_name: str, **db_kwargs) -> ReplayConnection`` and delegates
    to ``adbc_replay.wrap_from_item`` with the requesting test's node.

    Every connection handed out is closed once the test finishes; errors raised
    while closing are suppressed. Cassette paths follow the per-driver
    subdirectory layout used by auto-patch.

    Example::

        @pytest.mark.adbc_cassette("my_test")
        def test_my_query(adbc_connect):
            conn = adbc_connect("adbc_driver_snowflake.dbapi", uri=os.environ["SF_URI"])
            cursor = conn.cursor()
            cursor.execute("SELECT 1")
    """
    handed_out: list[Any] = []

    def _connect(driver_module_name: str, **db_kwargs: Any) -> Any:
        # Remember each connection so it can be closed after the test.
        handed_out.append(
            adbc_replay.wrap_from_item(driver_module_name, request.node, db_kwargs=db_kwargs)
        )
        return handed_out[-1]

    yield _connect

    # Teardown: best-effort close of everything this test opened.
    for live_conn in handed_out:
        with contextlib.suppress(Exception):
            live_conn.close()