Source code for openfisca_core.tools.test_runner

# Avoid T201 warning for print statements in this test runner
# flake8: noqa: T201
from __future__ import annotations

from collections.abc import Sequence
from typing import Any
from typing_extensions import Literal, TypedDict

from openfisca_core.types import TaxBenefitSystem

import dataclasses
import json
import os
import pathlib
import shutil
import subprocess
import sys
import textwrap
import time
import traceback
import warnings

# Unix-specific modules for parallel testing (not available on Windows)
try:
    import pty
    import select

    PARALLEL_AVAILABLE = True
except ImportError:
    PARALLEL_AVAILABLE = False

import pytest

from openfisca_core.errors import SituationParsingError, VariableNotFound
from openfisca_core.simulations import SimulationBuilder
from openfisca_core.tools import assert_near
from openfisca_core.warnings import LibYAMLWarning


[docs] class Options(TypedDict, total=False): aggregate: bool ignore_variables: Sequence[str] | None max_depth: int name_filter: str | None only_variables: Sequence[str] | None pdb: bool performance_graph: bool performance_tables: bool verbose: bool
[docs] @dataclasses.dataclass(frozen=True) class ErrorMargin: __root__: dict[str | Literal["default"], float | None] def __getitem__(self, key: str) -> float | None: if key in self.__root__: return self.__root__[key] return self.__root__["default"]
@dataclasses.dataclass
class Test:
    # One test case, as built by ``build_test`` from a YAML test mapping.

    # Margins handed to ``assert_near`` for each checked variable.
    absolute_error_margin: ErrorMargin
    relative_error_margin: ErrorMargin
    # Used as the pytest item name when the test runs.
    name: str = ""
    # Situation passed to ``SimulationBuilder.build_from_dict``.
    input: dict[str, float | dict[str, float]] = dataclasses.field(default_factory=dict)
    # Expected values; ``YamlItem.runtest`` raises ``ValueError`` when missing.
    output: dict[str, float | dict[str, float]] | None = None
    # Default period of the simulation builder and of the output checks.
    period: str | None = None
    # Reforms applied, in order, on top of the baseline tax-benefit system.
    reforms: Sequence[str] = dataclasses.field(default_factory=list)
    # Matched against ``name_filter`` by ``YamlFile.should_ignore``.
    keywords: Sequence[str] | None = None
    # Extensions loaded into the tax-benefit system used by the test.
    extensions: Sequence[str] = dataclasses.field(default_factory=list)
    # Free-form text; not read by the code visible in this module.
    description: str | None = None
    # When truthy, copied to ``simulation.max_spiral_loops``.
    max_spiral_loops: int | None = None
def build_test(params: dict[str, Any]) -> Test:
    """Build a :class:`Test` from the mapping loaded from a YAML test.

    The two error-margin entries are normalised into :class:`ErrorMargin`
    objects: a missing value becomes ``{"default": None}`` and a scalar becomes
    ``{"default": float(value)}``; mappings are wrapped as they are.

    Args:
        params: Raw test description. It is not modified.

    Returns:
        The corresponding :class:`Test`.

    Raises:
        TypeError: If ``params`` contains a key that is not a ``Test`` field.
    """
    # Work on a shallow copy so the caller's mapping is left untouched.
    params = dict(params)

    for key in ("absolute_error_margin", "relative_error_margin"):
        value = params.get(key)

        if value is None:
            value = {"default": None}

        elif isinstance(value, (float, int, str)):
            value = {"default": float(value)}

        params[key] = ErrorMargin(value)

    return Test(**params)


def import_yaml():
    """Import PyYAML and pick the loader used to parse test files.

    Returns:
        A ``(yaml, Loader)`` tuple. ``Loader`` is ``yaml.CLoader`` when the
        libyaml bindings are available, ``yaml.SafeLoader`` otherwise (a
        :class:`LibYAMLWarning` is emitted in that case).
    """
    import yaml

    try:
        # NOTE(review): ``CLoader`` is the C counterpart of the full
        # ``yaml.Loader`` and can construct arbitrary Python objects, whereas
        # the fallback below is the restricted ``SafeLoader``. If test files
        # may come from untrusted sources, ``CSafeLoader`` should be used
        # here. Left unchanged because it would alter which documents load.
        from yaml import CLoader as Loader
    except ImportError:
        message = [
            "libyaml is not installed in your environment.",
            "This can make your test suite slower to run. Once you have installed libyaml, ",
            "run 'pip uninstall pyyaml && pip install pyyaml --no-cache-dir'",
            "so that it is used in your Python environment.",
        ]
        warnings.warn(" ".join(message), LibYAMLWarning, stacklevel=2)
        from yaml import SafeLoader as Loader
    return yaml, Loader


# Keys accepted in a YAML test description.
TEST_KEYWORDS = {
    "absolute_error_margin",
    "description",
    "extensions",
    "ignore_variables",
    "input",
    "keywords",
    "max_spiral_loops",
    "name",
    "only_variables",
    "output",
    "period",
    "reforms",
    "relative_error_margin",
}

yaml, Loader = import_yaml()

# Tax-benefit systems already built by ``_get_tax_benefit_system``.
_tax_benefit_system_cache: dict = {}

# Default (empty) options, used as the default argument of ``run_tests``.
options: Options = Options()


def _create_worker_environment(options: Options) -> dict:
    """Create environment variables for worker processes.

    Args:
        options: Test options to pass to workers. They are serialised with
            ``json.dumps`` and must therefore be JSON-serialisable.

    Returns:
        A copy of ``os.environ`` extended with the OpenFisca configuration
        (``OPENFISCA_COUNTRY_PACKAGE`` when set, ``OPENFISCA_EXTENSIONS``,
        ``OPENFISCA_REFORMS`` and ``OPENFISCA_OPTIONS``).
    """
    env = os.environ.copy()
    env["PYTHONUNBUFFERED"] = "1"  # Ensure output is not buffered

    country_package = options.get("country_package")
    if country_package:
        env["OPENFISCA_COUNTRY_PACKAGE"] = country_package

    env["OPENFISCA_EXTENSIONS"] = json.dumps(options.get("extensions") or [])
    env["OPENFISCA_REFORMS"] = json.dumps(options.get("reforms") or [])
    env["OPENFISCA_OPTIONS"] = json.dumps(options)
    return env


def _spawn_worker(
    batch: list[str], python_bin: str, env: dict, verbose: bool
) -> tuple[subprocess.Popen, int]:
    """Spawn a single worker process with PTY.

    Args:
        batch: List of test files for this worker
        python_bin: Path to Python executable
        env: Environment variables
        verbose: Whether to enable verbose output

    Returns:
        Tuple of (subprocess.Popen, master_fd). The caller owns ``master_fd``
        and must close it.

    Raises:
        OSError: If the worker process cannot be started. Both ends of the
            PTY are closed before the error propagates.
    """
    # Create PTY (pseudo-terminal) for real-time output capture
    master_fd, slave_fd = pty.openpty()

    # Build pytest command for this worker
    cmd = [
        python_bin,
        "-m",
        "pytest",
        "-p",
        "openfisca_core.tools.parallel_plugin",  # Load our custom plugin
        "--maxfail=1",  # Stop on first failure within this worker
        "--disable-warnings",  # Reduce noise in output
    ]
    if verbose:
        cmd.append("-vv")
    cmd.extend(batch)  # Add test files for this worker

    try:
        # Spawn worker process with PTY output
        p = subprocess.Popen(
            cmd, stdout=slave_fd, stderr=subprocess.STDOUT, env=env, close_fds=True
        )
    except BaseException:
        # Nobody will ever read from the master end: do not leak it.
        os.close(master_fd)
        raise
    finally:
        # The parent never uses the slave end, whether the spawn worked or not.
        os.close(slave_fd)

    return p, master_fd


def _format_file_list(file_names: list[str], max_display: int = 3) -> str:
    """Format list of files for display.

    Args:
        file_names: List of file names
        max_display: Maximum number of files to show before truncating

    Returns:
        The names joined with ``", "``; when there are more than
        ``max_display`` names, only the first ones are shown, followed by
        ``(+N more)``.
    """
    if len(file_names) <= max_display:
        return ", ".join(file_names)

    shown = ", ".join(file_names[:max_display])
    hidden = len(file_names) - max_display
    return f"{shown} (+{hidden} more)"


def _read_worker_output(
    readable: list,
    fd_to_idx: dict,
    outputs: list[str],
    verbose: bool,
    decoders: dict | None = None,
) -> list:
    """Read output from workers that have data available.

    At most one chunk of 4096 bytes is read per file descriptor.

    Args:
        readable: List of file descriptors with data available
        fd_to_idx: Mapping from file descriptor to worker index
        outputs: List of accumulated outputs per worker (updated in place)
        verbose: Whether to print output in real-time
        decoders: Optional mapping from file descriptor to an incremental
            UTF-8 decoder. When given, a multi-byte character split across two
            reads is decoded correctly instead of being replaced.

    Returns:
        The file descriptors that reached end of output (empty read or read
        error). They should no longer be polled.
    """
    closed = []

    for fd in readable:
        idx = fd_to_idx.get(fd)
        if idx is None:
            continue

        decoder = decoders.get(fd) if decoders else None

        try:
            data = os.read(fd, 4096)
        except OSError:
            # Reading a PTY master typically fails once the other side has
            # been closed: treat it like an end of file.
            data = b""

        if data:
            if decoder is not None:
                chunk = decoder.decode(data)
            else:
                chunk = data.decode("utf-8", "replace")
        else:
            closed.append(fd)
            # Flush bytes of an incomplete character still held back.
            chunk = decoder.decode(b"", final=True) if decoder is not None else ""

        if not chunk:
            continue

        outputs[idx] += chunk
        if verbose:
            print(f"[Worker {idx}] {chunk}", end="", flush=True)

    return closed


def _drain_worker_output(
    fd: int,
    fd_to_idx: dict,
    outputs: list[str],
    verbose: bool,
    decoders: dict | None = None,
) -> None:
    """Read everything still buffered for ``fd`` without blocking.

    Called once a worker has exited, so that the end of its output (usually
    the part describing the failure) is not lost.
    """
    if fd not in fd_to_idx:
        return

    while True:
        ready, _, _ = select.select([fd], [], [], 0)
        if not ready:
            return
        if _read_worker_output(ready, fd_to_idx, outputs, verbose, decoders):
            return


def _terminate_workers(procs: list, running: set, timeout: float = 5.0) -> None:
    """Terminate all running workers and wait for them to exit.

    Args:
        procs: List of (worker_id, subprocess.Popen) tuples
        running: Set of worker IDs still running
        timeout: Seconds to wait for each worker after ``terminate()`` before
            it is killed.
    """
    targets = [p for idx, p in procs if idx in running]

    for p in targets:
        try:
            p.terminate()
        except (ProcessLookupError, PermissionError):
            pass  # Process already dead or permission denied

    # Reap the children so that they do not linger as zombies.
    for p in targets:
        try:
            p.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            try:
                p.kill()
            except (ProcessLookupError, PermissionError):
                continue
            p.wait()


def discover_test_files(
    paths: Sequence[str], name_filter: str | None = None
) -> list[str]:
    """Discover all test files (YAML and Python) in the given paths.

    Args:
        paths: List of file or directory paths to search for test files. A
            single path (``str`` or ``os.PathLike``) is accepted as well.
        name_filter: Optional string filter to match against file names

    Returns:
        Sorted list of unique absolute paths to test files

    Notes:
        - Accepts .yaml and .yml files
        - Only accepts Python files starting with 'test_'
        - Recursively explores directories
        - Paths that are neither a file nor a directory are ignored
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(paths, (str, os.PathLike)):
        paths = [paths]

    yaml_exts = {".yaml", ".yml"}
    found = set()

    for raw_path in paths:
        path = pathlib.Path(raw_path)

        if path.is_file():
            is_yaml = path.suffix in yaml_exts
            # keep only test_*.py files
            is_python_test = path.suffix == ".py" and path.name.startswith("test_")
            if is_yaml or is_python_test:
                found.add(str(path.resolve()))

        elif path.is_dir():
            for ext in yaml_exts:
                found.update(str(f.resolve()) for f in path.rglob(f"*{ext}"))
            # collect only python test files named test_*.py
            found.update(str(f.resolve()) for f in path.rglob("test_*.py"))

    if name_filter:
        found = {f for f in found if name_filter in pathlib.Path(f).name}

    return sorted(found)


def run_tests_in_parallel(tax_benefit_system, paths, options, num_workers, verbose):
    """Run OpenFisca tests in parallel across multiple worker processes.

    This function implements parallel test execution by:
    1. Discovering all test files in the given paths
    2. Splitting test files into batches for each worker
    3. Spawning pytest worker processes with the parallel_plugin
    4. Monitoring worker progress and collecting output
    5. Stopping all workers on first failure (fail-fast behavior)
    6. Reporting results with colored output and timing

    Architecture:
        - Main process: orchestrates workers via subprocess.Popen
        - Each worker: independent pytest process with parallel_plugin loaded
        - Communication: PTY for real-time output capture
        - Configuration: passed via environment variables

    Args:
        tax_benefit_system: The tax-benefit system to test. Only used for the
            single-process fallback; workers get their configuration from the
            environment.
        paths: List of paths containing test files
        options: Test options dict (name_filter, verbose, etc.)
        num_workers: Number of parallel workers (0 or less = auto-detect from
            CPU count)
        verbose: If True, print detailed output from each worker

    Returns:
        Exit code: 0 for success, 1 for failure. When parallel testing is not
        available, the value returned by ``run_tests``.

    Notes:
        - Uses fail-fast: stops all workers on first failure
        - Provides progress updates every 2 seconds
        - Captures and displays output from failed workers
        - Workers are terminated and PTY descriptors closed even if this
          function is interrupted
        - When PTY support is missing (``PARALLEL_AVAILABLE`` is False), falls
          back to single-threaded testing
    """
    if not PARALLEL_AVAILABLE:
        print(
            "Parallel testing not available on this platform (requires Unix PTY support)."
        )
        print("Falling back to single-threaded testing...")
        return run_tests(tax_benefit_system, paths, options)

    # NOTE(review): files are pre-selected on their file name only, whereas
    # ``YamlFile.should_ignore`` also matches test names and keywords; confirm
    # that this difference with the single-process runner is intended.
    test_files = discover_test_files(paths, options.get("name_filter"))

    if not test_files:
        print("No test files found")
        return 0

    # Auto-detect number of workers based on CPU count
    if not num_workers or num_workers <= 0:
        # Use N-1 CPUs to leave one for the system; ``os.cpu_count()`` is
        # None when the count cannot be determined.
        num_workers = max(1, (os.cpu_count() or 1) - 1)

    # Limit workers to number of test files (no point having idle workers)
    num_workers = min(num_workers, len(test_files))

    # Round-robin distribution: worker ``i`` gets files i, i + n, i + 2n, ...
    # No batch can be empty since there are at least as many files as workers.
    batches = [test_files[i::num_workers] for i in range(num_workers)]

    print(f"Running {len(test_files)} test files across {num_workers} workers...")

    # Prepare environment variables for pytest workers
    env = _create_worker_environment(options)

    # Get python executable with fallback
    python_bin = sys.executable or shutil.which("python3") or shutil.which("python")
    if not python_bin:
        print("Error: Could not find Python executable")
        return 1

    # Only needed here, hence the function-level import.
    import codecs

    procs = []  # List of (worker_id, subprocess.Popen) tuples
    fds = []  # List of (worker_id, master_fd) tuples for PTY communication
    outputs = ["" for _ in range(num_workers)]  # Accumulated output per worker
    worker_info = {}  # Metadata about each worker (files, timing, status)
    running = set()  # Worker IDs still running
    exit_codes = {}  # Map of worker_id -> exit_code
    start_time = time.time()

    try:
        # Launch worker processes
        for idx, batch in enumerate(batches):
            p, master_fd = _spawn_worker(batch, python_bin, env, verbose)
            procs.append((idx, p))
            fds.append((idx, master_fd))
            running.add(idx)

            # Store worker metadata for progress reporting
            file_str = _format_file_list([os.path.basename(f) for f in batch])
            worker_info[idx] = {
                "files": file_str,
                "batch": batch,
                "start_time": time.time(),
                "status": "running",
            }
            print(f" Worker {idx}: {file_str}")

        print()

        fd_to_idx = {fd: idx for idx, fd in fds}
        idx_to_fd = {idx: fd for idx, fd in fds}
        # One stateful decoder per PTY, so that characters split across two
        # reads are not mangled.
        make_decoder = codecs.getincrementaldecoder("utf-8")
        decoders = {fd: make_decoder(errors="replace") for fd in fd_to_idx}
        # Descriptors still worth polling. Finished ones are removed,
        # otherwise select() would return immediately on every iteration.
        open_fds = set(fd_to_idx)
        last_update = time.time()  # For throttling progress updates

        # Monitor workers until all complete or one fails
        while running:
            # Timeout of 0.1s allows periodic checking of process status
            readable, _, _ = select.select(list(open_fds), [], [], 0.1)
            open_fds.difference_update(
                _read_worker_output(readable, fd_to_idx, outputs, verbose, decoders)
            )

            # Print progress update every 2 seconds (avoid spamming)
            current_time = time.time()
            if current_time - last_update >= 2.0:
                completed = len(exit_codes)
                total = num_workers
                elapsed = current_time - start_time
                print(
                    f"\rProgress: {completed}/{total} workers completed ({elapsed:.1f}s elapsed)",
                    end="",
                    flush=True,
                )
                last_update = current_time

            # Check each worker for completion
            for idx2, p in procs:
                if idx2 not in running:
                    continue

                ret = p.poll()  # Non-blocking check if process finished
                if ret is None:
                    continue

                # Collect what the worker wrote just before exiting.
                worker_fd = idx_to_fd[idx2]
                if worker_fd in open_fds:
                    _drain_worker_output(
                        worker_fd, fd_to_idx, outputs, verbose, decoders
                    )
                    open_fds.discard(worker_fd)

                # Worker completed - record exit code and timing
                exit_codes[idx2] = ret
                duration = time.time() - worker_info[idx2]["start_time"]
                worker_info[idx2]["status"] = "passed" if ret == 0 else "failed"
                worker_info[idx2]["duration"] = duration

                # Print completion with colored status indicator
                status_symbol = "✓" if ret == 0 else "✗"
                status_color = "\033[32m" if ret == 0 else "\033[31m"
                reset_color = "\033[0m"
                print(
                    f"\r{status_color}{status_symbol}{reset_color} Worker {idx2}: {worker_info[idx2]['files']} ({duration:.1f}s)"
                )
                running.remove(idx2)

                # Fail-fast: if any worker fails, terminate all others
                if ret != 0:
                    _terminate_workers(procs, running)
                    running.clear()
                    break
    finally:
        # Only non-empty here when the loop was interrupted by an exception.
        if running:
            _terminate_workers(procs, running)

        # Close file descriptors
        for _, fd in fds:
            try:
                os.close(fd)
            except OSError:
                pass

    total_duration = time.time() - start_time
    print()

    # Report failures
    for idx, code in exit_codes.items():
        if code != 0:
            print(f"\n{'=' * 80}")
            print(f"Worker {idx} FAILED")
            print(f"{'=' * 80}")
            print(f"Files: {', '.join(worker_info[idx]['batch'])}")
            print(f"{'=' * 80}")
            print(outputs[idx])
            return 1

    # Success summary
    print(f"{'=' * 80}")
    print("✓ All tests passed!")
    print(
        f" {len(test_files)} test files across {num_workers} workers in {total_duration:.2f}s"
    )
    print(f"{'=' * 80}")
    return 0
def run_tests(
    tax_benefit_system: TaxBenefitSystem,
    paths: str | Sequence[str],
    options: Options = options,
) -> int:
    """Run all the YAML tests contained in a file or a directory.

    The tests are collected and executed by ``pytest.main``, with an
    :class:`OpenFiscaPlugin` registered to collect ``.yaml`` / ``.yml`` files.

    Args:
        tax_benefit_system: the tax-benefit system to use to run the tests.
        paths: A path, or a list of paths, towards the files or directories
            containing the tests to run. They are handed over to pytest,
            which does the exploration.
        options: See more details below.

    Returns:
        The value returned by ``pytest.main``, i.e. pytest's exit code and
        not a number of tests. Failing tests are reported through this exit
        code rather than by raising an exception here.

    **Testing options**:

    +-------------+----------+------------------------------------------+
    | Key         | Type     | Role                                     |
    +=============+==========+==========================================+
    | verbose     | ``bool`` |                                          |
    +-------------+----------+ See :any:`openfisca_test` options doc    |
    | name_filter | ``str``  |                                          |
    +-------------+----------+------------------------------------------+

    """
    argv = []
    plugins = [OpenFiscaPlugin(tax_benefit_system, options)]

    if options.get("pdb"):
        argv.append("--pdb")

    if options.get("verbose"):
        argv.append("--verbose")

    # A single path is accepted for convenience.
    if isinstance(paths, str):
        paths = [paths]

    return pytest.main([*argv, *paths], plugins=plugins)
class YamlFile(pytest.File):
    """Pytest collector producing one :class:`YamlItem` per test of a file."""

    def __init__(self, *, tax_benefit_system, options, **kwargs) -> None:
        """Store the baseline tax-benefit system and the test options.

        Remaining keyword arguments are forwarded to ``pytest.File``.
        """
        super().__init__(**kwargs)
        self.tax_benefit_system = tax_benefit_system
        self.options = options

    def collect(self):
        """Parse the YAML file and yield its tests.

        A document that is not a list is treated as a single test; an empty
        document yields nothing.

        Yields:
            A :class:`YamlItem` for every test not excluded by
            :meth:`should_ignore`.

        Raises:
            ValueError: If the file is not valid YAML.
        """
        try:
            # Binary mode lets PyYAML detect the encoding itself instead of
            # relying on the locale's default; the context manager makes
            # sure the handle is closed.
            with open(self.path, "rb") as stream:
                # NOTE(review): ``Loader`` may be the unrestricted
                # ``CLoader``; see ``import_yaml`` before loading files from
                # untrusted sources.
                tests = yaml.load(stream, Loader=Loader)
        except (yaml.scanner.ScannerError, yaml.parser.ParserError, TypeError) as error:
            message = os.linesep.join(
                [
                    traceback.format_exc(),
                    f"'{self.path}' is not a valid YAML file. Check the stack trace above for more details.",
                ],
            )
            raise ValueError(message) from error

        # An empty file loads as ``None``: there is nothing to collect.
        if tests is None:
            return

        if not isinstance(tests, list):
            tests = [tests]

        for test in tests:
            if not self.should_ignore(test):
                yield YamlItem.from_parent(
                    self,
                    name="",
                    baseline_tax_benefit_system=self.tax_benefit_system,
                    test=test,
                    options=self.options,
                )

    def should_ignore(self, test):
        """Tell whether ``test`` is excluded by the ``name_filter`` option.

        A test is kept when the filter is a substring of the file name
        (without extension) or of the test name, or when it is one of the
        test's keywords. Without a filter, nothing is ignored.
        """
        name_filter = self.options.get("name_filter")
        if name_filter is None:
            return False

        file_stem = os.path.splitext(os.path.basename(self.path))[0]
        # ``or`` guards against keys that are present but set to null.
        test_name = test.get("name") or ""
        keywords = test.get("keywords") or []

        return (
            name_filter not in file_stem
            and name_filter not in test_name
            and name_filter not in keywords
        )
class YamlItem(pytest.Item):
    """Terminal nodes of the test collection tree."""

    def __init__(self, *, baseline_tax_benefit_system, test, options, **kwargs) -> None:
        """Build the :class:`Test` described by the raw ``test`` mapping.

        Remaining keyword arguments are forwarded to ``pytest.Item``.
        """
        super().__init__(**kwargs)
        self.baseline_tax_benefit_system = baseline_tax_benefit_system
        self.options = options
        self.test = build_test(test)
        self.simulation = None
        self.tax_benefit_system = None

    def runtest(self) -> None:
        """Build the simulation described by the test and check its output.

        Raises:
            ValueError: If the test has no ``output``, or if an unexpected
                error occurs while building the simulation.
            VariableNotFound, SituationParsingError: Re-raised unchanged when
                the situation cannot be parsed.
        """
        self.name = self.test.name

        if self.test.output is None:
            msg = f"Missing key 'output' in test '{self.name}' in file '{self.path}'"
            raise ValueError(msg)

        self.tax_benefit_system = _get_tax_benefit_system(
            self.baseline_tax_benefit_system,
            self.test.reforms,
            self.test.extensions,
        )

        builder = SimulationBuilder()
        situation = self.test.input  # not named ``input``: that is a builtin
        period = self.test.period
        max_spiral_loops = self.test.max_spiral_loops
        verbose = self.options.get("verbose")
        aggregate = self.options.get("aggregate")
        max_depth = self.options.get("max_depth")
        performance_graph = self.options.get("performance_graph")
        performance_tables = self.options.get("performance_tables")

        try:
            builder.set_default_period(period)
            self.simulation = builder.build_from_dict(
                self.tax_benefit_system, situation
            )

        except (VariableNotFound, SituationParsingError):
            raise

        except Exception as e:
            error_message = os.linesep.join(
                [str(e), "", f"Unexpected error raised while parsing '{self.path}'"],
            )
            raise ValueError(error_message).with_traceback(
                sys.exc_info()[2],
            ) from e  # Keep the stack trace from the root error

        if max_spiral_loops:
            self.simulation.max_spiral_loops = max_spiral_loops

        try:
            self.simulation.trace = verbose or performance_graph or performance_tables
            self.check_output()
        finally:
            # Reports are produced even when the checks fail.
            tracer = self.simulation.tracer
            if verbose:
                self.print_computation_log(tracer, aggregate, max_depth)
            if performance_graph:
                self.generate_performance_graph(tracer)
            if performance_tables:
                self.generate_performance_tables(tracer)

    def print_computation_log(self, tracer, aggregate, max_depth) -> None:
        """Delegate to ``tracer.print_computation_log``."""
        tracer.print_computation_log(aggregate, max_depth)

    def generate_performance_graph(self, tracer) -> None:
        """Ask the tracer for a performance graph, passing ``"."``."""
        tracer.generate_performance_graph(".")

    def generate_performance_tables(self, tracer) -> None:
        """Ask the tracer for performance tables, passing ``"."``."""
        tracer.generate_performance_tables(".")

    def check_output(self) -> None:
        """Check every expected value listed in the test's ``output``.

        Each key may be a variable name, the singular name of an entity
        (mapping variables to values) or the plural name of an entity (mapping
        instance ids to such mappings).

        Raises:
            VariableNotFound: If a key is none of the above.
        """
        output = self.test.output

        if output is None:
            return

        for key, expected_value in output.items():
            if self.tax_benefit_system.get_variable(key):  # If key is a variable
                self.check_variable(key, expected_value, self.test.period)
            elif self.simulation.populations.get(key):
                # If key is an entity singular
                for variable_name, value in expected_value.items():
                    self.check_variable(variable_name, value, self.test.period)
            else:
                population = self.simulation.get_population(plural=key)
                if population is not None:
                    # If key is an entity plural
                    for instance_id, instance_values in expected_value.items():
                        for variable_name, value in instance_values.items():
                            entity_index = population.get_index(instance_id)
                            self.check_variable(
                                variable_name,
                                value,
                                self.test.period,
                                entity_index,
                            )
                else:
                    raise VariableNotFound(key, self.tax_benefit_system)

    def check_variable(
        self,
        variable_name: str,
        expected_value,
        period,
        entity_index=None,
    ):
        """Compare the computed value of a variable with the expected one.

        Args:
            variable_name: Variable to calculate.
            expected_value: Expected value, or a mapping from period to
                expected value (each period is then checked in turn).
            period: Period of the calculation.
            entity_index: When given, only this index of the result is
                compared.

        Returns:
            ``None`` when the variable is ignored or when ``expected_value``
            is a mapping; the result of ``assert_near`` otherwise.
        """
        if self.should_ignore_variable(variable_name):
            return None

        if isinstance(expected_value, dict):
            for requested_period, expected_value_at_period in expected_value.items():
                self.check_variable(
                    variable_name,
                    expected_value_at_period,
                    requested_period,
                    entity_index,
                )
            return None

        actual_value = self.simulation.calculate(variable_name, period)

        if entity_index is not None:
            actual_value = actual_value[entity_index]

        return assert_near(
            actual_value,
            expected_value,
            self.test.absolute_error_margin[variable_name],
            f"{variable_name}@{period}: ",
            self.test.relative_error_margin[variable_name],
        )

    def should_ignore_variable(self, variable_name: str):
        """Tell whether the options exclude ``variable_name`` from the checks.

        A variable is skipped when it is listed in ``ignore_variables``, or
        when ``only_variables`` is set and does not contain it.
        """
        only_variables = self.options.get("only_variables")
        ignore_variables = self.options.get("ignore_variables")
        variable_ignored = (
            ignore_variables is not None and variable_name in ignore_variables
        )
        variable_not_tested = (
            only_variables is not None and variable_name not in only_variables
        )

        return variable_ignored or variable_not_tested

    def repr_failure(self, excinfo):
        """Render a concise failure report for expected kinds of errors.

        Other exceptions are rendered by pytest's default implementation.
        """
        if not isinstance(
            excinfo.value,
            (AssertionError, VariableNotFound, SituationParsingError),
        ):
            return super().repr_failure(excinfo)

        # A bare ``assert`` carries no argument, and arguments are not
        # necessarily strings: never fail while reporting a failure.
        args = excinfo.value.args
        message = str(args[0]) if args else str(excinfo.value)
        if isinstance(excinfo.value, SituationParsingError):
            message = f"Could not parse situation described: {message}"

        return os.linesep.join(
            [
                f"{self.path!s}:",
                f" Test '{self.name!s}':",
                textwrap.indent(message, " "),
            ],
        )


class OpenFiscaPlugin:
    """Pytest plugin collecting YAML files as OpenFisca tests."""

    def __init__(self, tax_benefit_system, options) -> None:
        """Store what is handed over to every collected :class:`YamlFile`."""
        self.tax_benefit_system = tax_benefit_system
        self.options = options

    def pytest_collect_file(self, parent, path):
        """Called by pytest for all plugins.

        :return: A :class:`YamlFile` collector for ``.yaml`` / ``.yml``
            files, ``None`` for any other file.
        """
        # NOTE(review): ``path`` is the legacy hook argument (an object with
        # an ``ext`` attribute); recent pytest versions also offer
        # ``file_path``. Confirm against the supported pytest range before
        # switching.
        if path.ext in [".yaml", ".yml"]:
            return YamlFile.from_parent(
                parent,
                path=pathlib.Path(path),
                tax_benefit_system=self.tax_benefit_system,
                options=self.options,
            )
        return None


def _get_tax_benefit_system(baseline, reforms, extensions):
    """Return ``baseline`` with the reforms applied and extensions loaded.

    Results are memoised in ``_tax_benefit_system_cache`` so that tests
    sharing a configuration reuse the same tax-benefit system.

    Args:
        baseline: The tax-benefit system to start from; it is cloned, not
            modified.
        reforms: A reform path or a sequence of them (``None`` means none);
            applied in order.
        extensions: An extension or a sequence of them (``None`` means none).

    Returns:
        The resulting tax-benefit system.
    """

    def as_list(value):
        # A lone string is one element, not a sequence of characters.
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    reforms = as_list(reforms)
    extensions = as_list(extensions)

    # keep reforms order in cache, ignore extensions order. The tuple itself
    # is the key: using its hash would let two colliding configurations share
    # one entry.
    key = (id(baseline), tuple(reforms), frozenset(extensions))
    cached = _tax_benefit_system_cache.get(key)
    if cached is not None:
        return cached

    current_tax_benefit_system = baseline.clone()

    for reform_path in reforms:
        current_tax_benefit_system = current_tax_benefit_system.apply_reform(
            reform_path,
        )

    for extension in extensions:
        current_tax_benefit_system.load_extension(extension)

    _tax_benefit_system_cache[key] = current_tax_benefit_system

    return current_tax_benefit_system