# Source code for openfisca_core.tools.test_runner

from __future__ import annotations

from collections.abc import Sequence
from typing import Any
from typing_extensions import Literal, TypedDict

from openfisca_core.types import TaxBenefitSystem

import dataclasses
import os
import pathlib
import sys
import textwrap
import traceback
import warnings

import pytest

from openfisca_core.errors import SituationParsingError, VariableNotFound
from openfisca_core.simulation_builder import SimulationBuilder
from openfisca_core.tools import assert_near
from openfisca_core.warnings import LibYAMLWarning


class Options(TypedDict, total=False):
    """Options controlling how YAML tests are collected, run and reported.

    Every key is optional (``total=False``); the runner reads them with
    ``dict.get``, so a missing key behaves like ``None``.
    """

    # Forwarded to the tracer's computation-log printer (only when verbose).
    aggregate: bool
    # Variables whose expected values are not checked.
    ignore_variables: Sequence[str] | None
    # Forwarded to the tracer's computation-log printer (only when verbose).
    max_depth: int | None
    # Keep only tests whose file base name or test name contains this
    # string, or whose keywords include it.
    name_filter: str | None
    # When set, only the expected values of these variables are checked.
    only_variables: Sequence[str] | None
    # Pass ``--pdb`` to pytest.
    pdb: bool
    # Trace simulations and generate a performance graph (tracer called with ".").
    performance_graph: bool
    # Trace simulations and generate performance tables (tracer called with ".").
    performance_tables: bool
    # Pass ``--verbose`` to pytest, trace simulations and print their
    # computation logs.
    verbose: bool
@dataclasses.dataclass(frozen=True)
class ErrorMargin:
    """Error margins of a test, keyed by variable name.

    The optional ``"default"`` entry applies to every variable that has no
    margin of its own.
    """

    __root__: dict[str | Literal["default"], float | None]

    def __getitem__(self, key: str) -> float | None:
        """Return the margin for variable ``key``.

        Args:
            key: The name of the variable being checked.

        Returns:
            The variable's own margin if one is listed, otherwise the
            ``"default"`` margin, otherwise ``None``.
        """
        if key in self.__root__:
            return self.__root__[key]

        # A mapping without a "default" entry (e.g. ``{salary: 1}`` in YAML)
        # leaves every other variable unspecified. Use None, the same value
        # build_test stores when no margin is given at all, instead of
        # raising KeyError for each unlisted variable.
        return self.__root__.get("default")
@dataclasses.dataclass
class Test:
    """One YAML test case, built from a parsed YAML mapping by ``build_test``."""

    # Margins used when comparing computed and expected values.
    absolute_error_margin: ErrorMargin
    relative_error_margin: ErrorMargin
    # Copied to the pytest item's name when the test runs; shown in failures.
    name: str = ""
    # Situation passed to ``SimulationBuilder.build_from_dict``.
    input: dict[str, float | dict[str, float]] = dataclasses.field(default_factory=dict)
    # Expected values; running a test whose output is None raises ValueError.
    output: dict[str, float | dict[str, float]] | None = None
    # Default period of the simulation and of the checked variables.
    period: str | None = None
    # Reforms applied, in order, to a clone of the baseline system.
    reforms: Sequence[str] = dataclasses.field(default_factory=list)
    # NOTE(review): not read by the visible runner code; name filtering reads
    # the raw YAML "keywords" key instead.
    keywords: Sequence[str] | None = None
    # Extensions loaded into the test's tax-benefit system.
    extensions: Sequence[str] = dataclasses.field(default_factory=list)
    # Free-text description; not read by the visible runner code.
    description: str | None = None
    # When truthy, overrides the simulation's ``max_spiral_loops``.
    max_spiral_loops: int | None = None
def build_test(params: dict[str, Any]) -> Test:
    """Build a :class:`Test` from one parsed YAML test mapping.

    Error margins may be given as a scalar (applied to every variable), as a
    mapping from variable names to margins (optionally with a ``"default"``
    entry), or omitted. They are normalised into :class:`ErrorMargin`
    objects before all keys are passed to :class:`Test`.

    Args:
        params: The YAML mapping describing the test. It is not modified.

    Returns:
        The corresponding :class:`Test`.

    Raises:
        TypeError: If ``params`` has a key that is not a :class:`Test` field.
        ValueError: If an error margin cannot be converted to a float.
    """
    # Work on a copy. Rewriting the caller's mapping (the raw YAML data) with
    # ErrorMargin objects would make a second build from the same mapping
    # wrap an ErrorMargin inside another ErrorMargin.
    params = dict(params)

    for key in ("absolute_error_margin", "relative_error_margin"):
        value = params.get(key)

        if isinstance(value, ErrorMargin):
            continue

        if value is None:
            value = {"default": None}
        elif isinstance(value, (float, int, str)):
            value = {"default": float(value)}
        elif isinstance(value, dict):
            # Normalise per-variable margins the same way as scalar ones.
            value = {
                name: None if margin is None else float(margin)
                for name, margin in value.items()
            }

        params[key] = ErrorMargin(value)

    return Test(**params)


def import_yaml():
    """Import PyYAML and pick the fastest *safe* loader available.

    Returns:
        A ``(yaml, Loader)`` pair. ``Loader`` is libyaml's ``CSafeLoader``
        when PyYAML was built with libyaml, otherwise the pure-Python
        ``SafeLoader``; a :class:`LibYAMLWarning` is emitted in that case.
    """
    import yaml

    try:
        # CSafeLoader is the libyaml counterpart of SafeLoader. CLoader would
        # be the C version of the unsafe full loader: it accepts arbitrary
        # Python object tags and would parse differently from the fallback.
        from yaml import CSafeLoader as Loader
    except ImportError:
        message = [
            "libyaml is not installed in your environment.",
            "This can make your test suite slower to run. Once you have installed libyaml, ",
            "run 'pip uninstall pyyaml && pip install pyyaml --no-cache-dir'",
            "so that it is used in your Python environment.",
        ]
        warnings.warn(" ".join(message), LibYAMLWarning, stacklevel=2)
        from yaml import SafeLoader as Loader

    return yaml, Loader


# Keys recognised in a YAML test.
# NOTE(review): not referenced by the visible code; "ignore_variables" and
# "only_variables" are not Test fields, so build_test would reject them.
TEST_KEYWORDS = {
    "absolute_error_margin",
    "description",
    "extensions",
    "ignore_variables",
    "input",
    "keywords",
    "max_spiral_loops",
    "name",
    "only_variables",
    "output",
    "period",
    "reforms",
    "relative_error_margin",
}

yaml, Loader = import_yaml()

# Memoises the systems built by ``_get_tax_benefit_system``.
_tax_benefit_system_cache: dict = {}

# Default options of ``run_tests``: every option unset.
options: Options = Options()
def run_tests(
    tax_benefit_system: TaxBenefitSystem,
    paths: str | os.PathLike | Sequence[str | os.PathLike],
    options: Options = options,
) -> int:
    """Run all the YAML tests contained in files or directories.

    If a path is a directory, pytest explores its subdirectories recursively.

    Args:
        tax_benefit_system: the tax-benefit system to use to run the tests.
        paths: A path, or a list of paths, towards the files or directories
            containing the tests to run. Strings and :class:`os.PathLike`
            objects (such as :class:`pathlib.Path`) are accepted.
        options: See more details below.

    Returns:
        The exit code returned by :func:`pytest.main` (``0`` when all tests
        passed). Failing tests are reported by pytest; their
        :exc:`AssertionError` does not propagate out of this function.

    **Testing options** (see :any:`openfisca_test` options doc):

    - ``verbose`` (``bool``): passes ``--verbose`` to pytest.
    - ``pdb`` (``bool``): passes ``--pdb`` to pytest.
    - ``name_filter`` (``str``): only runs the matching tests.

    The other :class:`Options` keys are read by the YAML test items.
    """
    # Only the options pytest itself understands become command-line flags;
    # the plugin forwards the whole mapping to the collected test items.
    argv = []
    if options.get("pdb"):
        argv.append("--pdb")
    if options.get("verbose"):
        argv.append("--verbose")

    # A single path may be passed bare. A str is itself a sequence of
    # characters, and a PathLike is not iterable, so wrap both.
    if isinstance(paths, (str, os.PathLike)):
        paths = [paths]

    plugins = [OpenFiscaPlugin(tax_benefit_system, options)]
    return pytest.main(
        [*argv, *(os.fspath(path) for path in paths)],
        plugins=plugins,
    )
class YamlFile(pytest.File):
    """Pytest collector for one YAML file of OpenFisca tests."""

    def __init__(self, *, tax_benefit_system, options, **kwargs) -> None:
        super().__init__(**kwargs)
        self.tax_benefit_system = tax_benefit_system
        self.options = options

    def collect(self):
        """Yield one ``YamlItem`` per test found in the file.

        The file may contain a single test mapping or a list of them. An
        empty file yields no test. Tests excluded by the ``name_filter``
        option are skipped.

        Raises:
            ValueError: If the file cannot be parsed as YAML.
        """
        try:
            # Binary mode lets PyYAML detect the encoding (UTF-8/UTF-16)
            # instead of depending on the platform's locale encoding; the
            # ``with`` block closes the file as soon as it is parsed.
            with open(self.path, "rb") as stream:
                tests = yaml.load(stream, Loader=Loader)
        except (yaml.YAMLError, TypeError) as error:
            # YAMLError is the base of scanner, parser, composer, constructor
            # and reader errors, so every malformed file gets this message.
            message = os.linesep.join(
                [
                    traceback.format_exc(),
                    f"'{self.path}' is not a valid YAML file. Check the stack trace above for more details.",
                ],
            )
            raise ValueError(message) from error

        if tests is None:
            # Empty document: nothing to collect.
            return

        if not isinstance(tests, list):
            tests = [tests]

        for test in tests:
            if not self.should_ignore(test):
                yield YamlItem.from_parent(
                    self,
                    name="",
                    baseline_tax_benefit_system=self.tax_benefit_system,
                    test=test,
                    options=self.options,
                )

    def should_ignore(self, test):
        """Tell whether ``test`` is excluded by the ``name_filter`` option.

        A test is kept when no filter is set, or when the filter is a
        substring of the file's base name (without extension) or of the
        test's name, or is one of the test's keywords.
        """
        name_filter = self.options.get("name_filter")
        if name_filter is None:
            return False

        file_stem = os.path.splitext(os.path.basename(self.path))[0]
        # ``or`` also covers keys present with an explicit null value, for
        # which the ``dict.get`` default is not used.
        return (
            name_filter not in file_stem
            and name_filter not in (test.get("name") or "")
            and name_filter not in (test.get("keywords") or [])
        )
class YamlItem(pytest.Item):
    """Terminal nodes of the test collection tree."""

    def __init__(self, *, baseline_tax_benefit_system, test, options, **kwargs) -> None:
        super().__init__(**kwargs)
        self.baseline_tax_benefit_system = baseline_tax_benefit_system
        self.options = options
        self.test = build_test(test)
        self.simulation = None
        self.tax_benefit_system = None

    def runtest(self) -> None:
        """Build the test's simulation and check its expected outputs.

        Raises:
            ValueError: If the test has no ``output`` key, or if building the
                simulation fails with an unexpected error.
            VariableNotFound: Propagated unchanged from the builder or from
                :meth:`check_output`.
            SituationParsingError: Propagated unchanged from the builder.
            AssertionError: When a computed value does not match, as raised
                by ``assert_near``.
        """
        self.name = self.test.name

        if self.test.output is None:
            msg = f"Missing key 'output' in test '{self.name}' in file '{self.path}'"
            raise ValueError(msg)

        self.tax_benefit_system = _get_tax_benefit_system(
            self.baseline_tax_benefit_system,
            self.test.reforms,
            self.test.extensions,
        )

        builder = SimulationBuilder()
        situation = self.test.input
        period = self.test.period
        max_spiral_loops = self.test.max_spiral_loops
        verbose = self.options.get("verbose")
        aggregate = self.options.get("aggregate")
        max_depth = self.options.get("max_depth")
        performance_graph = self.options.get("performance_graph")
        performance_tables = self.options.get("performance_tables")

        try:
            builder.set_default_period(period)
            self.simulation = builder.build_from_dict(self.tax_benefit_system, situation)
        except (VariableNotFound, SituationParsingError):
            # Reported in a readable form by repr_failure.
            raise
        except Exception as e:
            error_message = os.linesep.join(
                [str(e), "", f"Unexpected error raised while parsing '{self.path}'"],
            )
            # Keep the stack trace from the root error
            raise ValueError(error_message).with_traceback(sys.exc_info()[2]) from e

        if max_spiral_loops:
            self.simulation.max_spiral_loops = max_spiral_loops

        try:
            # Tracing is required by every report produced below.
            self.simulation.trace = verbose or performance_graph or performance_tables
            self.check_output()
        finally:
            # Reports are produced even when a check fails, so that a failing
            # test can be investigated from its computation log.
            tracer = self.simulation.tracer
            if verbose:
                self.print_computation_log(tracer, aggregate, max_depth)
            if performance_graph:
                self.generate_performance_graph(tracer)
            if performance_tables:
                self.generate_performance_tables(tracer)

    def print_computation_log(self, tracer, aggregate, max_depth) -> None:
        tracer.print_computation_log(aggregate, max_depth)

    def generate_performance_graph(self, tracer) -> None:
        tracer.generate_performance_graph(".")

    def generate_performance_tables(self, tracer) -> None:
        tracer.generate_performance_tables(".")

    def check_output(self) -> None:
        """Compare every expected value of the test's ``output``.

        Each ``output`` key is tried, in this order, as a variable name, as a
        singular entity key (mapping variables to values), and as a plural
        entity key (mapping instance ids to variables and values).

        Raises:
            VariableNotFound: If a key matches none of the above.
        """
        output = self.test.output
        if output is None:
            return

        period = self.test.period
        for key, expected_value in output.items():
            if self.tax_benefit_system.get_variable(key):  # If key is a variable
                self.check_variable(key, expected_value, period)
            elif self.simulation.populations.get(key):  # If key is an entity singular
                for variable_name, value in expected_value.items():
                    self.check_variable(variable_name, value, period)
            else:
                population = self.simulation.get_population(plural=key)
                if population is None:
                    raise VariableNotFound(key, self.tax_benefit_system)

                # Key is an entity plural.
                for instance_id, instance_values in expected_value.items():
                    # The index depends only on the instance, not on the variable.
                    entity_index = population.get_index(instance_id)
                    for variable_name, value in instance_values.items():
                        self.check_variable(
                            variable_name,
                            value,
                            period,
                            entity_index,
                        )

    def check_variable(
        self,
        variable_name: str,
        expected_value,
        period,
        entity_index=None,
    ):
        """Check one variable against its expected value.

        A dict ``expected_value`` maps periods to expected values and is
        checked period by period. ``entity_index``, when given, selects one
        instance in the computed array.
        """
        if self.should_ignore_variable(variable_name):
            return None

        if isinstance(expected_value, dict):
            for requested_period, expected_value_at_period in expected_value.items():
                self.check_variable(
                    variable_name,
                    expected_value_at_period,
                    requested_period,
                    entity_index,
                )
            return None

        actual_value = self.simulation.calculate(variable_name, period)

        if entity_index is not None:
            actual_value = actual_value[entity_index]

        return assert_near(
            actual_value,
            expected_value,
            self.test.absolute_error_margin[variable_name],
            f"{variable_name}@{period}: ",
            self.test.relative_error_margin[variable_name],
        )

    def should_ignore_variable(self, variable_name: str):
        """Tell whether ``variable_name`` is excluded by the variable filters.

        A variable is skipped when it is listed in ``ignore_variables``, or
        when ``only_variables`` is set and does not list it.
        """
        only_variables = self.options.get("only_variables")
        ignore_variables = self.options.get("ignore_variables")
        variable_ignored = ignore_variables is not None and variable_name in ignore_variables
        variable_not_tested = only_variables is not None and variable_name not in only_variables
        return variable_ignored or variable_not_tested

    def repr_failure(self, excinfo):
        """Format expected failures as ``file: test name: message``.

        Other errors use pytest's default representation.
        """
        error = excinfo.value
        if not isinstance(error, (AssertionError, VariableNotFound, SituationParsingError)):
            return super().repr_failure(excinfo)

        # A bare ``assert`` (or any exception built without arguments) has
        # empty ``args``; indexing args[0] would crash while reporting.
        message = str(error.args[0]) if error.args else ""
        if isinstance(error, SituationParsingError):
            message = f"Could not parse situation described: {message}"

        return os.linesep.join(
            [
                f"{self.path!s}:",
                f"  Test '{self.name!s}':",
                textwrap.indent(message, "    "),
            ],
        )


class OpenFiscaPlugin:
    """Pytest plugin collecting OpenFisca YAML test files."""

    def __init__(self, tax_benefit_system, options) -> None:
        self.tax_benefit_system = tax_benefit_system
        self.options = options

    def pytest_collect_file(self, parent, file_path):
        """Called by pytest for all plugins.

        :return: The collector for test methods.
        """
        # ``file_path`` (a pathlib.Path, pytest >= 7) replaces the deprecated
        # py.path ``path`` argument. Path() also accepts a py.path.local, so
        # direct positional callers keep working.
        file_path = pathlib.Path(file_path)
        if file_path.suffix in (".yaml", ".yml"):
            return YamlFile.from_parent(
                parent,
                path=file_path,
                tax_benefit_system=self.tax_benefit_system,
                options=self.options,
            )
        return None


def _get_tax_benefit_system(baseline, reforms, extensions):
    """Return a clone of ``baseline`` with ``reforms`` and ``extensions`` applied.

    ``reforms`` and ``extensions`` may each be a list, a single item or
    ``None`` (meaning none). Results are memoised in
    ``_tax_benefit_system_cache``: tests sharing the same baseline, the same
    reforms in the same order and the same extensions in any order reuse a
    single system.
    """
    if reforms is None:
        reforms = []
    elif not isinstance(reforms, list):
        reforms = [reforms]

    if extensions is None:
        extensions = []
    elif not isinstance(extensions, list):
        extensions = [extensions]

    # Keep reforms order in cache, ignore extensions order. Use the tuple
    # itself as the key rather than its hash(), which could collide between
    # two different configurations.
    key = (id(baseline), tuple(reforms), frozenset(extensions))
    cached = _tax_benefit_system_cache.get(key)
    # Each entry keeps its baseline alive, so that baseline's id() cannot be
    # reused by another object while the entry exists.
    if cached is not None and cached[0] is baseline:
        return cached[1]

    current_tax_benefit_system = baseline.clone()

    for reform_path in reforms:
        current_tax_benefit_system = current_tax_benefit_system.apply_reform(
            reform_path,
        )

    for extension in extensions:
        current_tax_benefit_system.load_extension(extension)

    _tax_benefit_system_cache[key] = (baseline, current_tax_benefit_system)

    return current_tax_benefit_system