from __future__ import annotations
from collections.abc import Sequence
from typing import Any
from typing_extensions import Literal, TypedDict
from openfisca_core.types import TaxBenefitSystem
import dataclasses
import os
import pathlib
import sys
import textwrap
import traceback
import warnings
import pytest
from openfisca_core.errors import SituationParsingError, VariableNotFound
from openfisca_core.simulations import SimulationBuilder
from openfisca_core.tools import assert_near
from openfisca_core.warnings import LibYAMLWarning
[docs]
class Options(TypedDict, total=False):
    """Options controlling how YAML tests are collected, run and reported.

    Every key is optional (``total=False``). Consumers read them with
    ``options.get(...)``, so a missing key behaves like ``None``.
    """

    # Forwarded to ``tracer.print_computation_log`` when ``verbose`` is set.
    aggregate: bool
    # Variables whose expected outputs are not checked.
    ignore_variables: Sequence[str] | None
    # Forwarded to ``tracer.print_computation_log`` when ``verbose`` is set.
    max_depth: int
    # Keep only tests whose file stem or name contains this string, or whose
    # keywords include it.
    name_filter: str | None
    # When set, only the expected outputs of these variables are checked.
    only_variables: Sequence[str] | None
    # Adds ``--pdb`` to the pytest command line.
    pdb: bool
    # Calls ``tracer.generate_performance_graph(".")`` after each test.
    performance_graph: bool
    # Calls ``tracer.generate_performance_tables(".")`` after each test.
    performance_tables: bool
    # Adds ``--verbose`` to pytest and prints each test's computation log.
    verbose: bool
[docs]
@dataclasses.dataclass(frozen=True)
class ErrorMargin:
    __root__: dict[str | Literal["default"], float | None]
    def __getitem__(self, key: str) -> float | None:
        if key in self.__root__:
            return self.__root__[key]
        return self.__root__["default"] 
[docs]
@dataclasses.dataclass
class Test:
    """A single YAML test case.

    Field names mirror the keys of a YAML test: ``build_test`` passes the
    parsed mapping straight through as keyword arguments.
    """

    # Margins looked up per variable name when comparing outputs.
    absolute_error_margin: ErrorMargin
    relative_error_margin: ErrorMargin
    # Test name; copied onto the pytest item when the test runs.
    name: str = ""
    # Situation passed to ``SimulationBuilder.build_from_dict``.
    input: dict[str, float | dict[str, float]] = dataclasses.field(default_factory=dict)
    # Expected values, keyed by variable, entity or entity plural. A test
    # whose output is None is rejected when run.
    output: dict[str, float | dict[str, float]] | None = None
    # Default period of the simulation and of the output checks.
    period: str | None = None
    # Reforms applied, in order, on top of the baseline tax-benefit system.
    reforms: Sequence[str] = dataclasses.field(default_factory=list)
    # Matched element-wise against the ``name_filter`` option.
    keywords: Sequence[str] | None = None
    # Extensions loaded into the tax-benefit system used by the test.
    extensions: Sequence[str] = dataclasses.field(default_factory=list)
    description: str | None = None
    # Overrides ``simulation.max_spiral_loops`` when truthy.
    max_spiral_loops: int | None = None
def build_test(params: dict[str, Any]) -> Test:
    """Build a :class:`Test` from the mapping parsed out of a YAML test.

    Error margins may be:

    - omitted;
    - a scalar, applied to every variable;
    - a mapping of variable name to margin, optionally with a ``"default"``
      key;
    - an already built :class:`ErrorMargin`.

    They are normalised into :class:`ErrorMargin` instances. The mapping
    passed in is not modified.

    Args:
        params: The keys of a single YAML test.

    Returns:
        The corresponding :class:`Test`.

    Raises:
        TypeError: if ``params`` holds a key :class:`Test` does not accept.
        ValueError: if a scalar margin given as a string is not numeric.
    """
    # Work on a shallow copy. Mutating the caller's dict would store
    # ErrorMargin objects in it, and building again from the same dict would
    # then nest an ErrorMargin inside another one.
    params = dict(params)
    for key in ("absolute_error_margin", "relative_error_margin"):
        value = params.get(key)
        if isinstance(value, ErrorMargin):
            continue
        if value is None:
            value = {"default": None}
        elif isinstance(value, (float, int, str)):
            value = {"default": float(value)}
        elif isinstance(value, dict):
            # Per-variable margins: variables not listed fall back to "no
            # explicit margin" instead of failing the lookup.
            value = {"default": None, **value}
        params[key] = ErrorMargin(value)
    return Test(**params)
def import_yaml():
    """Import PyYAML and pick the fastest available *safe* loader.

    Returns:
        A ``(yaml_module, Loader)`` pair. ``Loader`` is the libyaml-backed
        ``CSafeLoader`` when it can be imported. Otherwise it is
        ``SafeLoader``, and a :class:`LibYAMLWarning` is emitted.
    """
    import yaml

    try:
        # CSafeLoader is the libyaml counterpart of SafeLoader. The previous
        # CLoader uses the full (unsafe) constructor, so test files were
        # parsed with different rules depending on whether libyaml was
        # installed.
        from yaml import CSafeLoader as Loader
    except ImportError:
        message = [
            "libyaml is not installed in your environment.",
            "This can make your test suite slower to run. Once you have installed libyaml, ",
            "run 'pip uninstall pyyaml && pip install pyyaml --no-cache-dir'",
            "so that it is used in your Python environment.",
        ]
        # stacklevel=2 attributes the warning to the caller of import_yaml.
        warnings.warn(" ".join(message), LibYAMLWarning, stacklevel=2)
        from yaml import SafeLoader as Loader
    return yaml, Loader
# Presumably the keys a YAML test may define; not referenced in this part of
# the file. NOTE(review): ``Test`` has no ``ignore_variables`` or
# ``only_variables`` field, so ``Test(**params)`` in ``build_test`` would
# reject those two keys with a TypeError. Confirm which list is authoritative.
TEST_KEYWORDS = {
    "absolute_error_margin",
    "description",
    "extensions",
    "ignore_variables",
    "input",
    "keywords",
    "max_spiral_loops",
    "name",
    "only_variables",
    "output",
    "period",
    "reforms",
    "relative_error_margin",
}
# Resolved once, at import time. This may emit a LibYAMLWarning attributed to
# this line.
yaml, Loader = import_yaml()
# Tax-benefit systems with reforms/extensions applied, reused across tests
# that request the same baseline, reforms and extensions.
_tax_benefit_system_cache: dict = {}
# Default value of ``run_tests(options=...)``. The same object is shared by
# every call relying on the default, so treat it as read-only.
options: Options = Options()
[docs]
def run_tests(
    tax_benefit_system: TaxBenefitSystem,
    paths: str | os.PathLike | Sequence[str | os.PathLike],
    options: Options = options,
) -> int:
    """Runs all the YAML tests contained in a file or a directory.

    If a path is a directory, subdirectories will be recursively explored.

    Args:
        tax_benefit_system: the tax-benefit system to use to run the tests.
        paths: A path, or a list of paths, towards the files or directories
            containing the tests to run. Paths may be strings or
            :class:`os.PathLike` objects such as :class:`pathlib.Path`. If a
            path is a directory, subdirectories will be recursively explored.
        options: See more details below.

    Returns:
        The exit code returned by :func:`pytest.main`. It is ``0`` when every
        collected test passed and non-zero otherwise.

    Note:
        A failing test does not raise. pytest reports it, and the failure is
        reflected in the returned exit code.

    **Testing options**:

    +-------------------------------+-----------+-------------------------------------------+
    | Key                           | Type      | Role                                      |
    +===============================+===========+===========================================+
    | verbose                       | ``bool``  |                                           |
    +-------------------------------+-----------+ See :any:`openfisca_test` options doc     |
    | name_filter                   | ``str``   |                                           |
    +-------------------------------+-----------+-------------------------------------------+

    The other :class:`Options` keys are honoured too: ``pdb``, ``aggregate``,
    ``max_depth``, ``ignore_variables``, ``only_variables``,
    ``performance_graph`` and ``performance_tables``.
    """
    argv: list[str] = []
    # The plugin collects *.yaml / *.yml files and turns them into test items.
    plugins = [OpenFiscaPlugin(tax_benefit_system, options)]
    if options.get("pdb"):
        argv.append("--pdb")
    if options.get("verbose"):
        argv.append("--verbose")
    # A single path (str or PathLike) is wrapped in a list. A bare
    # pathlib.Path used to be unpacked as if it were a sequence.
    if isinstance(paths, (str, os.PathLike)):
        paths = [paths]
    # pytest.main expects string arguments.
    return pytest.main([*argv, *map(os.fspath, paths)], plugins=plugins)
[docs]
class YamlFile(pytest.File):
    """Collector turning one YAML file into :class:`YamlItem` test items."""

    def __init__(self, *, tax_benefit_system, options, **kwargs) -> None:
        super().__init__(**kwargs)
        self.tax_benefit_system = tax_benefit_system
        self.options = options

    def collect(self):
        """Parse the YAML file and yield one item per test it declares.

        A file may hold a single test (a mapping) or a list of tests. An
        empty file yields no test. Tests excluded by the ``name_filter``
        option are skipped.

        Raises:
            ValueError: if the file is not valid YAML.
        """
        try:
            # ``with`` closes the file even when parsing fails. YAML is read
            # as UTF-8 rather than with the platform's locale encoding.
            with open(self.path, encoding="utf-8") as stream:
                tests = yaml.load(stream, Loader=Loader)
        except (yaml.YAMLError, TypeError):
            message = os.linesep.join(
                [
                    traceback.format_exc(),
                    f"'{self.path}' is not a valid YAML file. Check the stack trace above for more details.",
                ],
            )
            raise ValueError(message)
        if tests is None:
            # An empty document parses to None: there is nothing to collect.
            # Wrapping it would make build_test fail on ``None.get``.
            return
        if not isinstance(tests, list):
            tests = [tests]
        for test in tests:
            if not self.should_ignore(test):
                yield YamlItem.from_parent(
                    self,
                    name="",
                    baseline_tax_benefit_system=self.tax_benefit_system,
                    test=test,
                    options=self.options,
                )

    def should_ignore(self, test):
        """Tell whether the ``name_filter`` option excludes ``test``.

        Without a filter every test is kept. Otherwise a test is kept when
        the filter is a substring of the file stem or of the test name, or is
        one of the test's keywords. A missing or null ``name``/``keywords``
        counts as empty.
        """
        name_filter = self.options.get("name_filter")
        if name_filter is None:
            return False
        file_stem = os.path.splitext(os.path.basename(self.path))[0]
        return (
            name_filter not in file_stem
            and name_filter not in (test.get("name") or "")
            and name_filter not in (test.get("keywords") or [])
        )
class YamlItem(pytest.Item):
    """Terminal nodes of the test collection tree.

    Each item runs one YAML test. It builds a simulation from the test's
    ``input`` and compares computed values against its ``output``.
    """

    def __init__(self, *, baseline_tax_benefit_system, test, options, **kwargs) -> None:
        super().__init__(**kwargs)
        self.baseline_tax_benefit_system = baseline_tax_benefit_system
        self.options = options
        self.test = build_test(test)
        # Both are set by ``runtest``.
        self.simulation = None
        self.tax_benefit_system = None

    def runtest(self) -> None:
        """Build the test's simulation and check its expected outputs.

        Raises:
            ValueError: if the test has no ``output``, or if building the
                simulation fails with an unexpected error.
            VariableNotFound, SituationParsingError: re-raised unchanged.
        """
        self.name = self.test.name
        if self.test.output is None:
            msg = f"Missing key 'output' in test '{self.name}' in file '{self.path}'"
            raise ValueError(msg)
        self.tax_benefit_system = _get_tax_benefit_system(
            self.baseline_tax_benefit_system,
            self.test.reforms,
            self.test.extensions,
        )
        builder = SimulationBuilder()
        # Named ``situation`` to avoid shadowing the ``input`` builtin.
        situation = self.test.input
        period = self.test.period
        max_spiral_loops = self.test.max_spiral_loops
        verbose = self.options.get("verbose")
        aggregate = self.options.get("aggregate")
        max_depth = self.options.get("max_depth")
        performance_graph = self.options.get("performance_graph")
        performance_tables = self.options.get("performance_tables")
        try:
            builder.set_default_period(period)
            self.simulation = builder.build_from_dict(self.tax_benefit_system, situation)
        except (VariableNotFound, SituationParsingError):
            raise
        except Exception as e:
            error_message = os.linesep.join(
                [str(e), "", f"Unexpected error raised while parsing '{self.path}'"],
            )
            raise ValueError(error_message).with_traceback(
                sys.exc_info()[2],
            ) from e  # Keep the stack trace from the root error
        if max_spiral_loops:
            self.simulation.max_spiral_loops = max_spiral_loops
        try:
            # Tracing is only switched on when a log or report was requested.
            self.simulation.trace = verbose or performance_graph or performance_tables
            self.check_output()
        finally:
            # Logs and reports are produced even when an output check failed.
            tracer = self.simulation.tracer
            if verbose:
                self.print_computation_log(tracer, aggregate, max_depth)
            if performance_graph:
                self.generate_performance_graph(tracer)
            if performance_tables:
                self.generate_performance_tables(tracer)

    def print_computation_log(self, tracer, aggregate, max_depth) -> None:
        tracer.print_computation_log(aggregate, max_depth)

    def generate_performance_graph(self, tracer) -> None:
        tracer.generate_performance_graph(".")

    def generate_performance_tables(self, tracer) -> None:
        tracer.generate_performance_tables(".")

    def check_output(self) -> None:
        """Check every expected value listed in the test's ``output``.

        Keys may be a variable name, an entity's singular key (mapping
        variables to values) or an entity's plural key (mapping instance ids
        to variable/value mappings).

        Raises:
            VariableNotFound: if a key is none of the above.
        """
        output = self.test.output
        if output is None:
            return
        for key, expected_value in output.items():
            if self.tax_benefit_system.get_variable(key):  # If key is a variable
                self.check_variable(key, expected_value, self.test.period)
            elif self.simulation.populations.get(key):  # If key is an entity singular
                for variable_name, value in expected_value.items():
                    self.check_variable(variable_name, value, self.test.period)
            else:
                population = self.simulation.get_population(plural=key)
                if population is not None:  # If key is an entity plural
                    for instance_id, instance_values in expected_value.items():
                        # The index only depends on the instance: look it up
                        # once rather than once per checked variable.
                        entity_index = population.get_index(instance_id)
                        for variable_name, value in instance_values.items():
                            self.check_variable(
                                variable_name,
                                value,
                                self.test.period,
                                entity_index,
                            )
                else:
                    raise VariableNotFound(key, self.tax_benefit_system)

    def check_variable(
        self,
        variable_name: str,
        expected_value,
        period,
        entity_index=None,
    ):
        """Compare one variable's computed value with ``expected_value``.

        A dict ``expected_value`` maps periods to expected values and is
        checked period by period. Ignored or non-selected variables (see
        ``should_ignore_variable``) are skipped.
        """
        if self.should_ignore_variable(variable_name):
            return None
        if isinstance(expected_value, dict):
            for requested_period, expected_value_at_period in expected_value.items():
                self.check_variable(
                    variable_name,
                    expected_value_at_period,
                    requested_period,
                    entity_index,
                )
            return None
        actual_value = self.simulation.calculate(variable_name, period)
        if entity_index is not None:
            actual_value = actual_value[entity_index]
        return assert_near(
            actual_value,
            expected_value,
            self.test.absolute_error_margin[variable_name],
            f"{variable_name}@{period}: ",
            self.test.relative_error_margin[variable_name],
        )

    def should_ignore_variable(self, variable_name: str):
        """Tell whether ``ignore_variables``/``only_variables`` exclude a variable."""
        only_variables = self.options.get("only_variables")
        ignore_variables = self.options.get("ignore_variables")
        variable_ignored = (
            ignore_variables is not None and variable_name in ignore_variables
        )
        variable_not_tested = (
            only_variables is not None and variable_name not in only_variables
        )
        return variable_ignored or variable_not_tested

    def repr_failure(self, excinfo):
        """Render test failures as file path, test name and indented message.

        Only assertion, unknown-variable and situation-parsing errors get this
        compact form. Anything else uses pytest's default rendering.
        """
        error = excinfo.value
        if not isinstance(
            error,
            (AssertionError, VariableNotFound, SituationParsingError),
        ):
            return super().repr_failure(excinfo)
        # A bare ``assert`` carries no args, and a non-string args[0] would
        # break textwrap.indent: always work on a string.
        message = str(error.args[0]) if error.args else str(error)
        if isinstance(error, SituationParsingError):
            message = f"Could not parse situation described: {message}"
        return os.linesep.join(
            [
                f"{self.path!s}:",
                f"  Test '{self.name!s}':",
                textwrap.indent(message, "    "),
            ],
        )
class OpenFiscaPlugin:
    """Pytest plugin that collects OpenFisca YAML test files."""

    def __init__(self, tax_benefit_system, options) -> None:
        self.tax_benefit_system = tax_benefit_system
        self.options = options

    def pytest_collect_file(self, parent, file_path):
        """Called by pytest for all plugins.

        Uses the :class:`pathlib.Path` ``file_path`` hook argument (pytest
        >= 7, already required by ``from_parent(path=...)``) instead of the
        deprecated py.path ``path`` argument.

        :return: A :class:`YamlFile` collector for ``.yaml``/``.yml`` files,
            ``None`` for any other file.
        """
        # pathlib.Path also accepts str and other os.PathLike objects, such
        # as a legacy py.path.local passed positionally.
        file_path = pathlib.Path(file_path)
        if file_path.suffix in (".yaml", ".yml"):
            return YamlFile.from_parent(
                parent,
                path=file_path,
                tax_benefit_system=self.tax_benefit_system,
                options=self.options,
            )
        return None
def _get_tax_benefit_system(baseline, reforms, extensions):
    """Return ``baseline`` with ``reforms`` applied and ``extensions`` loaded.

    The baseline is cloned before anything is applied. Results are memoised
    in ``_tax_benefit_system_cache``, so tests sharing the same baseline,
    reforms and extensions reuse one system. Reform order is part of the
    cache key; extension order is not.

    Args:
        baseline: The tax-benefit system the tests start from.
        reforms: A reform path, a list/tuple of them, or ``None``.
        extensions: An extension, a list/tuple of them, or ``None``.

    Returns:
        The (possibly cached) reformed and extended tax-benefit system.
    """
    reforms = _as_list(reforms)
    extensions = _as_list(extensions)
    # Keep reforms order in the cache key, ignore extensions order. The key
    # is the tuple itself, not its hash(): two colliding hashes would
    # otherwise share a system. Joining reforms with ":" also made ["a:b"]
    # and ["a", "b"] indistinguishable.
    key = (id(baseline), tuple(reforms), frozenset(extensions))
    cached = _tax_benefit_system_cache.get(key)
    if cached is not None:
        return cached
    current_tax_benefit_system = baseline.clone()
    for reform_path in reforms:
        current_tax_benefit_system = current_tax_benefit_system.apply_reform(
            reform_path,
        )
    for extension in extensions:
        current_tax_benefit_system.load_extension(extension)
    _tax_benefit_system_cache[key] = current_tax_benefit_system
    return current_tax_benefit_system


def _as_list(value):
    """Normalise ``None``, a single item or a list/tuple into a new list."""
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]