Source code for openfisca_core.tools.test_runner

from __future__ import annotations

from typing import Any, Dict, Optional, Sequence, Union
from typing_extensions import Literal, TypedDict

import dataclasses
import os
import sys
import textwrap
import traceback
import warnings

import pytest

from openfisca_core.errors import SituationParsingError, VariableNotFound
from openfisca_core.simulation_builder import SimulationBuilder
from openfisca_core.tools import assert_near
from openfisca_core.types import TaxBenefitSystem
from openfisca_core.warnings import LibYAMLWarning


class Options(TypedDict, total=False):
    """Settings understood by the YAML test runner.

    The dictionary is declared with ``total=False``: every key is optional
    and the runner reads each of them with ``options.get(...)``.
    """

    # Passed to the tracer when the computation log is printed.
    aggregate: bool

    # Names of the variables whose expected values must not be checked.
    ignore_variables: Optional[Sequence[str]]

    # Passed to the tracer when the computation log is printed.
    max_depth: Optional[int]

    # Only keep tests whose file name, name or keywords match this filter.
    name_filter: Optional[str]

    # When set, only the variables listed here are checked.
    only_variables: Optional[Sequence[str]]

    # Forwarded to pytest as ``--pdb``.
    pdb: bool

    # Ask the tracer to generate a performance graph after each test.
    performance_graph: bool

    # Ask the tracer to generate performance tables after each test.
    performance_tables: bool

    # Forwarded to pytest as ``--verbose``; also prints the computation log.
    verbose: bool
@dataclasses.dataclass(frozen=True)
class ErrorMargin:
    """Immutable per-variable error margins with a ``"default"`` fallback.

    ``__root__`` holds the raw mapping. Indexing an instance with a key that
    has no entry of its own yields the value stored under ``"default"``.
    """

    __root__: Dict[Union[str, Literal["default"]], Optional[float]]

    def __getitem__(self, key: str) -> Optional[float]:
        """Return the margin registered for ``key``, else the default one."""
        margins = self.__root__
        return margins[key] if key in margins else margins["default"]
@dataclasses.dataclass
class Test:
    """Description of a single YAML test case.

    Only the two error margins are mandatory; every other field has a
    default, and mutable defaults are created per instance.
    """

    # Margins handed to ``assert_near`` when an output value is compared.
    absolute_error_margin: ErrorMargin
    relative_error_margin: ErrorMargin

    # Name used to identify the test in reports.
    name: str = ""

    # Situation from which the simulation is built.
    input: Dict[str, Union[float, Dict[str, float]]] = dataclasses.field(
        default_factory=dict,
    )

    # Expected values; a test without output cannot be run.
    output: Optional[Dict[str, Union[float, Dict[str, float]]]] = None

    # Default period used to build the simulation and to check outputs.
    period: Optional[str] = None

    # Reforms applied, in order, on top of the baseline system.
    reforms: Sequence[str] = dataclasses.field(default_factory=list)

    # Keywords attached to the test.
    keywords: Optional[Sequence[str]] = None

    # Extensions loaded into the tax-benefit system.
    extensions: Sequence[str] = dataclasses.field(default_factory=list)

    # Free-text description of the test.
    description: Optional[str] = None

    # When set (and non-zero), overrides the simulation's ``max_spiral_loops``.
    max_spiral_loops: Optional[int] = None
def build_test(params: Dict[str, Any]) -> Test:
    """Build a :class:`Test` from the raw mapping describing one test.

    The two error margins are normalised into :class:`ErrorMargin` objects:
    a missing (or ``None``) margin becomes ``{"default": None}``, a scalar
    becomes ``{"default": float(value)}`` and any other value is wrapped
    as is.

    Args:
        params: Keyword arguments for :class:`Test`. The mapping itself is
            not modified.

    Returns:
        The corresponding :class:`Test`.

    """

    # Work on a shallow copy so the caller's mapping is left untouched.
    arguments = dict(params)

    for key in ("absolute_error_margin", "relative_error_margin"):
        value = arguments.get(key)

        if value is None:
            value = {"default": None}

        elif isinstance(value, (float, int, str)):
            value = {"default": float(value)}

        arguments[key] = ErrorMargin(value)

    return Test(**arguments)


def import_yaml():
    """Import ``yaml`` and pick the loader used to parse test files.

    ``CLoader`` is used when PyYAML was built with libyaml; otherwise a
    ``LibYAMLWarning`` is emitted and the pure-Python ``SafeLoader`` is used.

    Returns:
        A ``(yaml module, loader class)`` tuple.

    """

    import yaml

    try:
        # NOTE(security): ``CLoader`` is not a "safe" loader, so test files
        # must come from a trusted source.
        from yaml import CLoader as Loader

    except ImportError:
        message = [
            "libyaml is not installed in your environment.",
            "This can make your test suite slower to run. Once you have installed libyaml, ",
            "run 'pip uninstall pyyaml && pip install pyyaml --no-cache-dir'",
            "so that it is used in your Python environment."
        ]
        warnings.warn(" ".join(message), LibYAMLWarning)
        from yaml import SafeLoader as Loader

    return yaml, Loader


# Keys a YAML test is expected to use.
TEST_KEYWORDS = {'absolute_error_margin', 'description', 'extensions', 'ignore_variables', 'input', 'keywords', 'max_spiral_loops', 'name', 'only_variables', 'output', 'period', 'reforms', 'relative_error_margin'}

yaml, Loader = import_yaml()

# Tax-benefit systems already built for a (baseline, reforms, extensions) combination.
_tax_benefit_system_cache: Dict = {}

# Default (empty) runner options.
options: Options = Options()
def run_tests(
    tax_benefit_system: TaxBenefitSystem,
    paths: Union[str, Sequence[str]],
    options: Options = options,
) -> int:
    """Run, through pytest, the YAML tests found under ``paths``.

    Args:
        tax_benefit_system: The tax-benefit system to use to run the tests.
        paths: A path, or a list of paths, towards the files or directories
            containing the tests to run; they are handed over to pytest
            unchanged.
        options: Runner options. ``pdb`` and ``verbose`` are additionally
            forwarded to pytest as ``--pdb`` and ``--verbose``. See
            :any:`openfisca_test` options doc for the other keys.

    Returns:
        The value returned by ``pytest.main`` (pytest's exit status).

    """

    plugin = OpenFiscaPlugin(tax_benefit_system, options)

    # Translate runner options into the matching pytest command-line flags.
    flags = [
        flag
        for option, flag in (("pdb", "--pdb"), ("verbose", "--verbose"))
        if options.get(option)
    ]

    # A lone path is accepted as a convenience for a one-element list.
    targets = [paths] if isinstance(paths, str) else paths

    return pytest.main([*flags, *targets], plugins=[plugin])
class YamlFile(pytest.File):
    """Collector yielding one :class:`YamlItem` per test of a YAML file."""

    def __init__(self, path, fspath, parent, tax_benefit_system, options):
        super().__init__(path, parent)
        self.tax_benefit_system = tax_benefit_system
        self.options = options

    def collect(self):
        """Parse the YAML file and yield an item for every test it holds.

        A document that is not a list is treated as a single test. Tests
        rejected by :meth:`should_ignore` are skipped.

        Raises:
            ValueError: If the file cannot be parsed as YAML.

        """

        try:
            # Close the file as soon as it is parsed instead of leaking the
            # handle until garbage collection.
            # NOTE(security): depending on the loader selected at import
            # time, this may not be a "safe" load; only run trusted files.
            with self.fspath.open() as stream:
                tests = yaml.load(stream, Loader = Loader)

        except (yaml.scanner.ScannerError, yaml.parser.ParserError, TypeError):
            message = os.linesep.join([
                traceback.format_exc(),
                f"'{self.fspath}' is not a valid YAML file. Check the stack trace above for more details.",
            ])
            raise ValueError(message)

        if not isinstance(tests, list):
            tests = [tests]

        for test in tests:
            if not self.should_ignore(test):
                yield YamlItem.from_parent(
                    self,
                    name = '',
                    baseline_tax_benefit_system = self.tax_benefit_system,
                    test = test,
                    options = self.options,
                )

    def should_ignore(self, test):
        """Tell whether ``test`` is filtered out by the ``name_filter`` option.

        A test is kept when no filter is set, or when the filter matches the
        file name (without extension), the test name, or one of its keywords.
        """

        name_filter = self.options.get('name_filter')

        if name_filter is None:
            return False

        # ``name:`` / ``keywords:`` may be present but empty (None) in the
        # YAML; fall back to empty values rather than raising a TypeError.
        file_name = os.path.splitext(self.fspath.basename)[0]
        test_name = test.get('name') or ''
        keywords = test.get('keywords') or []

        return (
            name_filter not in file_name
            and name_filter not in test_name
            and name_filter not in keywords
        )
class YamlItem(pytest.Item):
    """Terminal node of the test collection tree: one YAML test."""

    def __init__(self, name, parent, baseline_tax_benefit_system, test, options):
        super().__init__(name, parent)
        self.baseline_tax_benefit_system = baseline_tax_benefit_system
        self.options = options
        self.test = build_test(test)
        # Both are set when the test is run.
        self.simulation = None
        self.tax_benefit_system = None

    def runtest(self):
        """Build the simulation described by the test and check its output.

        Raises:
            ValueError: If the test has no ``output``, or if an unexpected
                error occurs while building the simulation.
            VariableNotFound, SituationParsingError: Re-raised unchanged
                when raised while building the simulation.

        """

        self.name = self.test.name

        if self.test.output is None:
            msg = f"Missing key 'output' in test '{self.name}' in file '{self.fspath}'"
            raise ValueError(msg)

        self.tax_benefit_system = _get_tax_benefit_system(
            self.baseline_tax_benefit_system,
            self.test.reforms,
            self.test.extensions,
        )

        builder = SimulationBuilder()
        situation = self.test.input
        period = self.test.period
        max_spiral_loops = self.test.max_spiral_loops
        verbose = self.options.get('verbose')
        aggregate = self.options.get('aggregate')
        max_depth = self.options.get('max_depth')
        performance_graph = self.options.get('performance_graph')
        performance_tables = self.options.get('performance_tables')

        try:
            builder.set_default_period(period)
            self.simulation = builder.build_from_dict(self.tax_benefit_system, situation)

        except (VariableNotFound, SituationParsingError):
            raise

        except Exception as e:
            error_message = os.linesep.join([str(e), '', f"Unexpected error raised while parsing '{self.fspath}'"])
            raise ValueError(error_message).with_traceback(sys.exc_info()[2]) from e  # Keep the stack trace from the root error

        if max_spiral_loops:
            self.simulation.max_spiral_loops = max_spiral_loops

        try:
            self.simulation.trace = verbose or performance_graph or performance_tables
            self.check_output()

        finally:
            # Reports are produced whether or not the checks passed.
            tracer = self.simulation.tracer

            if verbose:
                self.print_computation_log(tracer, aggregate, max_depth)

            if performance_graph:
                self.generate_performance_graph(tracer)

            if performance_tables:
                self.generate_performance_tables(tracer)

    def print_computation_log(self, tracer, aggregate, max_depth):
        """Print the tracer's computation log to standard output."""
        print("Computation log:")  # noqa T001
        tracer.print_computation_log(aggregate, max_depth)

    def generate_performance_graph(self, tracer):
        """Ask the tracer to generate its performance graph in ``'.'``."""
        tracer.generate_performance_graph('.')

    def generate_performance_tables(self, tracer):
        """Ask the tracer to generate its performance tables in ``'.'``."""
        tracer.generate_performance_tables('.')

    def check_output(self):
        """Check every expected value listed in the test's ``output``.

        Raises:
            VariableNotFound: If a key is neither a variable nor an entity.

        """

        output = self.test.output

        if output is None:
            return

        for key, expected_value in output.items():
            if self.tax_benefit_system.get_variable(key):  # If key is a variable
                self.check_variable(key, expected_value, self.test.period)

            elif self.simulation.populations.get(key):  # If key is an entity singular
                for variable_name, value in expected_value.items():
                    self.check_variable(variable_name, value, self.test.period)

            else:
                population = self.simulation.get_population(plural = key)

                if population is not None:  # If key is an entity plural
                    for instance_id, instance_values in expected_value.items():
                        for variable_name, value in instance_values.items():
                            entity_index = population.get_index(instance_id)
                            self.check_variable(variable_name, value, self.test.period, entity_index)

                else:
                    raise VariableNotFound(key, self.tax_benefit_system)

    def check_variable(self, variable_name: str, expected_value, period, entity_index = None):
        """Compare the computed value of a variable with the expected one.

        When ``expected_value`` is a mapping, its keys are used as periods
        and each entry is checked in turn. Ignored variables are skipped.

        Args:
            variable_name: Name of the variable to calculate.
            expected_value: Expected value, or a mapping of period to value.
            period: Period for which the variable is calculated.
            entity_index: When given, only that index of the result is
                compared.

        Returns:
            The result of ``assert_near``, or ``None`` when nothing was
            compared directly.

        """

        if self.should_ignore_variable(variable_name):
            return

        if isinstance(expected_value, dict):
            for requested_period, expected_value_at_period in expected_value.items():
                self.check_variable(variable_name, expected_value_at_period, requested_period, entity_index)

            return None

        actual_value = self.simulation.calculate(variable_name, period)

        if entity_index is not None:
            actual_value = actual_value[entity_index]

        return assert_near(
            actual_value,
            expected_value,
            self.test.absolute_error_margin[variable_name],
            f"{variable_name}@{period}: ",
            self.test.relative_error_margin[variable_name],
        )

    def should_ignore_variable(self, variable_name: str):
        """Tell whether the options exclude ``variable_name`` from the checks."""

        only_variables = self.options.get('only_variables')
        ignore_variables = self.options.get('ignore_variables')
        variable_ignored = ignore_variables is not None and variable_name in ignore_variables
        variable_not_tested = only_variables is not None and variable_name not in only_variables

        return variable_ignored or variable_not_tested

    def repr_failure(self, excinfo):
        """Return a concise report for the failures the runner knows about.

        Any other exception is reported the default pytest way.
        """

        if not isinstance(excinfo.value, (AssertionError, VariableNotFound, SituationParsingError)):
            return super().repr_failure(excinfo)

        # An exception raised without a message (e.g. a bare ``assert``) has
        # empty ``args``, and the message may not be a string: normalise both
        # cases so that reporting the failure never fails itself.
        arguments = excinfo.value.args
        message = str(arguments[0]) if arguments else str(excinfo.value)

        if isinstance(excinfo.value, SituationParsingError):
            message = f"Could not parse situation described: {message}"

        return os.linesep.join([
            f"{str(self.fspath)}:",
            f" Test '{str(self.name)}':",
            textwrap.indent(message, ' ')
        ])
class OpenFiscaPlugin(object):
    """Pytest plugin collecting YAML files as OpenFisca tests."""

    def __init__(self, tax_benefit_system, options):
        self.tax_benefit_system = tax_benefit_system
        self.options = options

    def pytest_collect_file(self, parent, path):
        """Return a :class:`YamlFile` collector for ``.yaml``/``.yml`` files.

        Called by pytest for all plugins; ``None`` is returned for any other
        file extension.
        """

        if path.ext in [".yaml", ".yml"]:
            return YamlFile.from_parent(
                parent,
                path = path,
                fspath = path,
                tax_benefit_system = self.tax_benefit_system,
                options = self.options,
            )

        return None


def _get_tax_benefit_system(baseline, reforms, extensions):
    """Return ``baseline`` with the reforms applied and the extensions loaded.

    Results are memoised in ``_tax_benefit_system_cache``, so a given
    combination is only built once.

    Args:
        baseline: The tax-benefit system to start from.
        reforms: A reform path, or a list of them, applied in order.
        extensions: An extension, or a list of them, to load.

    Returns:
        The resulting tax-benefit system (``baseline`` itself when there is
        nothing to apply).

    """

    if not isinstance(reforms, list):
        reforms = [reforms]

    if not isinstance(extensions, list):
        extensions = [extensions]

    # Keep reforms order in cache, ignore extensions order. The tuple itself
    # is the key (rather than its hash) so that two distinct combinations
    # can never be mistaken for one another on a hash collision.
    key = (id(baseline), tuple(reforms), frozenset(extensions))
    cached = _tax_benefit_system_cache.get(key)

    if cached is not None:
        return cached

    current_tax_benefit_system = baseline

    for reform_path in reforms:
        current_tax_benefit_system = current_tax_benefit_system.apply_reform(reform_path)

    for extension in extensions:
        # Load each extension into a clone rather than into the system the
        # clone was made from.
        current_tax_benefit_system = current_tax_benefit_system.clone()
        current_tax_benefit_system.load_extension(extension)

    _tax_benefit_system_cache[key] = current_tax_benefit_system

    return current_tax_benefit_system