from __future__ import annotations
from collections.abc import Mapping
from typing import NamedTuple
from openfisca_core.types import (
    CorePopulation as Population,
    TaxBenefitSystem,
    Variable,
)
import tempfile
import warnings
import numpy
from openfisca_core import (
    commons,
    errors,
    indexed_enums,
    periods,
    tracers,
    warnings as core_warnings,
)
class Simulation:
    """Represents a simulation, and handles the calculation logic.

    A simulation binds a tax and benefit system to a set of populations (one
    per entity key) and calculates variables for given periods, storing the
    results in the holders of those populations.
    """

    tax_benefit_system: TaxBenefitSystem
    populations: dict[str, Population]
    invalidated_caches: set[Cache]

    def __init__(
        self,
        tax_benefit_system: TaxBenefitSystem,
        populations: Mapping[str, Population],
    ) -> None:
        """This constructor is reserved for internal use; see :any:`SimulationBuilder`,
        which is the preferred way to obtain a Simulation initialized with a consistent
        set of Entities.

        :param tax_benefit_system: the tax and benefit system providing the variables.
        :param populations: the populations of the simulation, keyed by entity key.
            It must contain the population of the system's person entity.
        """
        # Validate before storing anything on the instance.
        assert tax_benefit_system is not None
        self.tax_benefit_system = tax_benefit_system
        self.populations = populations
        self.persons = self.populations[tax_benefit_system.person_entity.key]
        self.link_to_entities_instances()
        self.create_shortcuts()
        # Cache entries (variable name, period) to delete once the outermost
        # calculation has ended; see purge_cache_of_invalid_values().
        self.invalidated_caches = set()
        self.debug = False
        # The ``trace`` setter also initialises ``self.tracer``.
        self.trace = False
        self.opt_out_cache = False
        # controls the spirals detection; check for performance impact if > 1
        self.max_spiral_loops: int = 1
        self.memory_config = None
        self._data_storage_dir = None
        self.start_computation_period = None

    @property
    def trace(self):
        """Whether the simulation uses a :class:`tracers.FullTracer`."""
        return self._trace

    @trace.setter
    def trace(self, trace) -> None:
        # Swap the tracer implementation to match the requested mode.
        self._trace = trace
        if trace:
            self.tracer = tracers.FullTracer()
        else:
            self.tracer = tracers.SimpleTracer()

    def link_to_entities_instances(self) -> None:
        """Point every population back to this simulation."""
        for entity_instance in self.populations.values():
            entity_instance.simulation = self

    def create_shortcuts(self) -> None:
        """Expose each population as an attribute named after its entity key."""
        for population in self.populations.values():
            # create shortcut simulation.person and simulation.household (for instance)
            setattr(self, population.entity.key, population)

    @property
    def data_storage_dir(self):
        """Temporary folder used to store intermediate calculation data in case the memory is saturated.

        The folder is created lazily on first access, and a
        :class:`core_warnings.TempfileWarning` is emitted at that time.
        """
        if self._data_storage_dir is None:
            self._data_storage_dir = tempfile.mkdtemp(prefix="openfisca_")
            message = [
                (
                    f"Intermediate results will be stored on disk in {self._data_storage_dir} in case of memory overflow."
                ),
                "You should remove this directory once you're done with your simulation.",
            ]
            warnings.warn(
                " ".join(message),
                core_warnings.TempfileWarning,
                stacklevel=2,
            )
        return self._data_storage_dir

    # ----- Calculation methods ----- #

    def calculate(self, variable_name: str, period):
        """Calculate ``variable_name`` for ``period``.

        :param variable_name: the name of the variable to calculate.
        :param period: a :class:`periods.Period`, or anything accepted by
            :func:`periods.period`.
        :returns: the calculated (or previously cached) array of values.
        """
        if period is not None and not isinstance(period, periods.Period):
            period = periods.period(period)

        self.tracer.record_calculation_start(variable_name, period)

        try:
            result = self._calculate(variable_name, period)
            self.tracer.record_calculation_result(result)
            return result

        finally:
            self.tracer.record_calculation_end()
            # No-op until the outermost calculation has ended.
            self.purge_cache_of_invalid_values()

    def _calculate(self, variable_name: str, period: periods.Period):
        """Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists.

        A cached value is returned as is. When ``start_computation_period`` is
        set, periods before it get the default value without running any
        formula. When a spiral (quasi-circular definition) is detected, the
        default value is returned without being cached.

        :returns: A numpy array containing the result of the calculation
        :raises errors.VariableNotFoundError: if the variable does not exist.
        :raises ValueError: if ``period`` does not match the variable's definition period.
        :raises errors.CycleError: on a circular definition.
        """
        variable: Variable | None

        population = self.get_variable_population(variable_name)
        holder = population.get_holder(variable_name)
        variable = self.tax_benefit_system.get_variable(
            variable_name,
            check_existence=True,
        )

        if variable is None:
            raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system)

        self._check_period_consistency(period, variable)

        # First look for a value already cached
        cached_array = holder.get_array(period)
        if cached_array is not None:
            return cached_array

        # Nothing is computed before the first computation period, if one is set
        if self.start_computation_period is not None:
            if not isinstance(self.start_computation_period, periods.Period):
                self.start_computation_period = periods.period(
                    self.start_computation_period,
                )
            if period < self.start_computation_period:
                return holder.default_array()

        # Otherwise, try to run a formula
        try:
            self._check_for_cycle(variable.name, period)
            array = self._run_formula(variable, population, period)

            # If no result, use the default value and cache it
            if array is None:
                array = holder.default_array()

            array = self._cast_formula_result(array, variable)
            holder.put_in_cache(array, period)

        except errors.SpiralError:
            # A spiral was detected: fall back to the default value, uncached.
            array = holder.default_array()

        return array

    def purge_cache_of_invalid_values(self) -> None:
        """Delete the cache entries listed in ``invalidated_caches``, then reset it.

        Does nothing while a calculation is still in progress.
        """
        # We wait for the end of calculate(), signalled by an empty stack, before purging the cache
        if self.tracer.stack:
            return
        for _name, _period in self.invalidated_caches:
            holder = self.get_holder(_name)
            holder.delete_arrays(_period)
        self.invalidated_caches = set()

    def calculate_add(self, variable_name: str, period):
        """Calculate ``variable_name`` over ``period`` by summing its values over the
        sub-periods of ``period`` matching the variable's ``definition_period``.

        :raises errors.VariableNotFoundError: if the variable does not exist.
        :raises ValueError: if the variable is eternal, or if ``period`` is
            shorter than the variable's ``definition_period``.
        """
        variable: Variable | None

        variable = self.tax_benefit_system.get_variable(
            variable_name,
            check_existence=True,
        )

        if variable is None:
            raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system)

        if period is not None and not isinstance(period, periods.Period):
            period = periods.period(period)

        # Eternal variables are rejected first: the generic check below would
        # otherwise suggest DIVIDE, which does not work for them either.
        if variable.definition_period not in (
            periods.DateUnit.isoformat + periods.DateUnit.isocalendar
        ):
            msg = (
                f"Unable to ADD constant variable '{variable.name}' over "
                f"the period {period}: eternal variables can't be summed "
                "over time."
            )
            raise ValueError(
                msg,
            )

        # Check that the requested period matches definition_period
        if periods.unit_weight(variable.definition_period) > periods.unit_weight(
            period.unit,
        ):
            msg = (
                f"Unable to compute variable '{variable.name}' for period "
                f"{period}: '{variable.name}' can only be computed for "
                f"{variable.definition_period}-long periods. You can use the "
                f"DIVIDE option to get an estimate of {variable.name}."
            )
            raise ValueError(
                msg,
            )

        return sum(
            self.calculate(variable_name, sub_period)
            for sub_period in period.get_subperiods(variable.definition_period)
        )

    def calculate_divide(self, variable_name: str, period):
        """Estimate ``variable_name`` for ``period`` by dividing its value over the
        enclosing ``definition_period`` by the number of ``period`` units it contains.

        :raises errors.VariableNotFoundError: if the variable does not exist.
        :raises ValueError: if the variable is eternal, if ``period`` is eternal,
            or if ``period`` is longer than the variable's ``definition_period``
            or spans more than one unit.
        """
        variable: Variable | None

        variable = self.tax_benefit_system.get_variable(
            variable_name,
            check_existence=True,
        )

        if variable is None:
            raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system)

        if period is not None and not isinstance(period, periods.Period):
            period = periods.period(period)

        # Eternal variables are rejected first, with their specific message.
        if variable.definition_period not in (
            periods.DateUnit.isoformat + periods.DateUnit.isocalendar
        ):
            msg = (
                f"Unable to DIVIDE constant variable '{variable.name}' over "
                f"the period {period}: eternal variables can't be divided "
                "over time."
            )
            raise ValueError(
                msg,
            )

        if (
            periods.unit_weight(variable.definition_period)
            < periods.unit_weight(period.unit)
            or period.size > 1
        ):
            msg = (
                f"Can't calculate variable '{variable.name}' for period "
                f"{period}: '{variable.name}' can only be computed for "
                f"{variable.definition_period}-long periods. You can use the "
                f"ADD option to get an estimate of {variable.name}."
            )
            raise ValueError(
                msg,
            )

        if (
            period.unit
            not in (periods.DateUnit.isoformat + periods.DateUnit.isocalendar)
            or period.size != 1
        ):
            msg = (
                f"Unable to DIVIDE constant variable '{variable.name}' over "
                f"the period {period}: eternal variables can't be used "
                "as a denominator to divide a variable over time."
            )
            raise ValueError(
                msg,
            )

        # The period over which the variable is actually calculated.
        if variable.definition_period == periods.DateUnit.YEAR:
            calculation_period = period.this_year

        elif variable.definition_period == periods.DateUnit.MONTH:
            calculation_period = period.first_month

        elif variable.definition_period == periods.DateUnit.DAY:
            calculation_period = period.first_day

        elif variable.definition_period == periods.DateUnit.WEEK:
            calculation_period = period.first_week

        else:
            calculation_period = period.first_weekday

        # How many ``period`` units fit in the calculation period.
        if period.unit == periods.DateUnit.YEAR:
            denominator = calculation_period.size_in_years

        elif period.unit == periods.DateUnit.MONTH:
            denominator = calculation_period.size_in_months

        elif period.unit == periods.DateUnit.DAY:
            denominator = calculation_period.size_in_days

        elif period.unit == periods.DateUnit.WEEK:
            denominator = calculation_period.size_in_weeks

        else:
            denominator = calculation_period.size_in_weekdays

        return self.calculate(variable_name, calculation_period) / denominator

    def calculate_output(self, variable_name: str, period):
        """Calculate the value of a variable using the ``calculate_output`` attribute of the variable.

        Falls back to :meth:`calculate` when the variable defines no
        ``calculate_output``.
        """
        variable: Variable | None

        variable = self.tax_benefit_system.get_variable(
            variable_name,
            check_existence=True,
        )

        if variable is None:
            raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system)

        if variable.calculate_output is None:
            return self.calculate(variable_name, period)

        return variable.calculate_output(self, variable_name, period)

    def trace_parameters_at_instant(self, formula_period):
        """Return the parameters at ``formula_period``, wrapped so that accesses are recorded by the tracer."""
        return tracers.TracingParameterNodeAtInstant(
            self.tax_benefit_system.get_parameters_at_instant(formula_period),
            self.tracer,
        )

    def _run_formula(self, variable, population, period):
        """Find the ``variable`` formula for the given ``period`` if it exists, and apply it to ``population``.

        Formulas take either ``(population, period)`` or
        ``(population, period, parameters)``; the arity is read from the
        formula's code object.

        :returns: the formula result, or None if no formula applies to ``period``.
        """
        formula = variable.get_formula(period)
        if formula is None:
            return None

        if self.trace:
            parameters_at = self.trace_parameters_at_instant
        else:
            parameters_at = self.tax_benefit_system.get_parameters_at_instant

        if formula.__code__.co_argcount == 2:
            array = formula(population, period)
        else:
            array = formula(population, period, parameters_at)

        return array

    def _check_period_consistency(self, period, variable) -> None:
        """Check that a period matches the variable definition_period.

        :raises ValueError: if ``period``'s unit differs from the variable's
            ``definition_period``, or if ``period`` spans more than one unit.
            Eternal variables accept any period.
        """
        if variable.definition_period == periods.DateUnit.ETERNITY:
            return  # For variables which values are constant in time, all periods are accepted

        if (
            variable.definition_period == periods.DateUnit.YEAR
            and period.unit != periods.DateUnit.YEAR
        ):
            msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole year. You can use the DIVIDE option to get an estimate of {variable.name} by dividing the yearly value by 12, or change the requested period to 'period.this_year'."
            raise ValueError(
                msg,
            )

        if (
            variable.definition_period == periods.DateUnit.MONTH
            and period.unit != periods.DateUnit.MONTH
        ):
            msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole month. You can use the ADD option to sum '{variable.name}' over the requested period, or change the requested period to 'period.first_month'."
            raise ValueError(
                msg,
            )

        if (
            variable.definition_period == periods.DateUnit.WEEK
            and period.unit != periods.DateUnit.WEEK
        ):
            msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole week. You can use the ADD option to sum '{variable.name}' over the requested period, or change the requested period to 'period.first_week'."
            raise ValueError(
                msg,
            )

        # Remaining definition periods (e.g. day, weekday) must match exactly too.
        if period.unit != variable.definition_period:
            msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole {variable.definition_period}. You can use the ADD option to sum '{variable.name}' over the requested period."
            raise ValueError(
                msg,
            )

        if period.size != 1:
            msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole {variable.definition_period}. You can use the ADD option to sum '{variable.name}' over the requested period."
            raise ValueError(
                msg,
            )

    def _cast_formula_result(self, value, variable):
        """Coerce a formula result to the variable's storage type.

        Enum results are encoded with ``possible_values``; scalars are broadcast
        to a population-sized array; arrays are cast to ``variable.dtype``.
        """
        if variable.value_type == indexed_enums.Enum and not isinstance(
            value,
            indexed_enums.EnumArray,
        ):
            return variable.possible_values.encode(value)

        if not isinstance(value, numpy.ndarray):
            population = self.get_variable_population(variable.name)
            value = population.filled_array(value)

        if value.dtype != variable.dtype:
            return value.astype(variable.dtype)

        return value

    # ----- Handle circular dependencies in a calculation ----- #

    def _check_for_cycle(self, variable: str, period) -> None:
        """Raise an exception in the case of a circular definition, where evaluating a variable for
        a given period loops around to evaluating the same variable/period pair. Also guards, as
        a heuristic, against "quasicircles", where the evaluation of a variable at a period involves
        the same variable at a different period.

        :raises errors.CycleError: if ``variable`` is already being computed for ``period``.
        :raises errors.SpiralError: if ``variable`` already appears at least
            ``max_spiral_loops`` times in the stack; the involved cache entries
            are marked for invalidation first.
        """
        # The last frame is the current calculation, so it should be ignored from cycle detection
        previous_periods = [
            frame["period"]
            for frame in self.tracer.stack[:-1]
            if frame["name"] == variable
        ]
        if period in previous_periods:
            msg = f"Circular definition detected on formula {variable}@{period}"
            raise errors.CycleError(
                msg,
            )
        spiral = len(previous_periods) >= self.max_spiral_loops
        if spiral:
            self.invalidate_spiral_variables(variable)
            message = f"Quasicircular definition detected on formula {variable}@{period} involving {self.tracer.stack}"
            raise errors.SpiralError(message, variable)

    def invalidate_cache_entry(self, variable: str, period) -> None:
        """Mark the cached value of ``variable`` for ``period`` for deletion at the end of the calculation."""
        self.invalidated_caches.add(Cache(variable, period))

    def invalidate_spiral_variables(self, variable: str) -> None:
        """Mark for deletion the cache entries of the stack frames involved in a spiral on ``variable``."""
        # Visit the stack, from the bottom (most recent) up; we know that we'll find
        # the variable implicated in the spiral (max_spiral_loops+1) times; we keep the
        # intermediate values computed (to avoid impacting performance) but we mark them
        # for deletion from the cache once the calculation ends.
        count = 0
        for frame in reversed(self.tracer.stack):
            self.invalidate_cache_entry(str(frame["name"]), frame["period"])
            if frame["name"] == variable:
                count += 1
                if count > self.max_spiral_loops:
                    break

    # ----- Methods to access stored values ----- #

    def get_array(self, variable_name: str, period):
        """Return the value of ``variable_name`` for ``period``, if this value is already in the cache (if it has been set as an input or previously calculated).

        Unlike :meth:`.calculate`, this method *does not* trigger calculations and *does not* use any formula.
        """
        if period is not None and not isinstance(period, periods.Period):
            period = periods.period(period)
        return self.get_holder(variable_name).get_array(period)

    def get_holder(self, variable_name: str):
        """Get the holder associated with the variable."""
        return self.get_variable_population(variable_name).get_holder(variable_name)

    def get_memory_usage(self, variables=None):
        """Get data about the virtual memory usage of the simulation.

        :param variables: forwarded to each population's ``get_memory_usage``.
        :returns: a dict with the overall ``"total_nb_bytes"`` and the merged
            per-variable details under ``"by_variable"``.
        """
        result = {"total_nb_bytes": 0, "by_variable": {}}
        for entity in self.populations.values():
            entity_memory_usage = entity.get_memory_usage(variables=variables)
            result["total_nb_bytes"] += entity_memory_usage["total_nb_bytes"]
            result["by_variable"].update(entity_memory_usage["by_variable"])
        return result

    # ----- Misc ----- #

    def delete_arrays(self, variable, period=None) -> None:
        """Delete a variable's value for a given period.

        :param variable: the name of the variable whose values are deleted
        :param period: the period for which the value should be deleted;
            if None, the values for all periods are deleted

        Example:
        >>> from openfisca_country_template import CountryTaxBenefitSystem
        >>> simulation = Simulation(CountryTaxBenefitSystem())
        >>> simulation.set_input("age", "2018-04", [12, 14])
        >>> simulation.set_input("age", "2018-05", [13, 14])
        >>> simulation.get_array("age", "2018-05")
        array([13, 14], dtype=int32)
        >>> simulation.delete_arrays("age", "2018-05")
        >>> simulation.get_array("age", "2018-04")
        array([12, 14], dtype=int32)
        >>> simulation.get_array("age", "2018-05") is None
        True
        >>> simulation.set_input("age", "2018-05", [13, 14])
        >>> simulation.delete_arrays("age")
        >>> simulation.get_array("age", "2018-04") is None
        True
        >>> simulation.get_array("age", "2018-05") is None
        True
        """
        self.get_holder(variable).delete_arrays(period)

    def get_known_periods(self, variable):
        """Get the list of a variable's known periods, i.e. the periods for which its holder has a value.

        :param variable: the name of the variable

        Example:
        >>> from openfisca_country_template import CountryTaxBenefitSystem
        >>> simulation = Simulation(CountryTaxBenefitSystem())
        >>> simulation.set_input("age", "2018-04", [12, 14])
        >>> simulation.set_input("age", "2018-05", [13, 14])
        >>> simulation.get_known_periods("age")
        [Period((u'month', Instant((2018, 5, 1)), 1)), Period((u'month', Instant((2018, 4, 1)), 1))]
        """
        return self.get_holder(variable).get_known_periods()

    def set_input(self, variable_name: str, period, value) -> None:
        """Set a variable's value for a given period.

        :param variable_name: the name of the variable to be set
        :param period: the period for which the value is set
        :param value: the input value for the variable

        Inputs for periods starting after the variable's ``end`` date are
        silently ignored.

        Example:
        >>> from openfisca_country_template import CountryTaxBenefitSystem
        >>> simulation = Simulation(CountryTaxBenefitSystem())
        >>> simulation.set_input("age", "2018-04", [12, 14])
        >>> simulation.get_array("age", "2018-04")
        array([12, 14], dtype=int32)

        If a ``set_input`` property has been set for the variable, this method may accept inputs for periods not matching the ``definition_period`` of the variable. To read more about this, check the `documentation <https://openfisca.org/doc/coding-the-legislation/35_periods.html#automatically-process-variable-inputs-defined-for-periods-not-matching-the-definitionperiod>`_.
        """
        variable: Variable | None

        variable = self.tax_benefit_system.get_variable(
            variable_name,
            check_existence=True,
        )

        if variable is None:
            raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system)

        period = periods.period(period)
        if (variable.end is not None) and (period.start.date > variable.end):
            return
        self.get_holder(variable_name).set_input(period, value)

    def get_variable_population(self, variable_name: str) -> Population:
        """Return the population of the entity ``variable_name`` is defined for.

        :raises errors.VariableNotFoundError: if the variable does not exist.
        """
        variable: Variable | None

        variable = self.tax_benefit_system.get_variable(
            variable_name,
            check_existence=True,
        )

        if variable is None:
            raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system)

        return self.populations[variable.entity.key]

    def get_population(self, plural: str | None = None) -> Population | None:
        """Return the population whose entity's plural is ``plural``, or None if there is none."""
        return next(
            (
                population
                for population in self.populations.values()
                if population.entity.plural == plural
            ),
            None,
        )

    # NOTE(review): despite the annotation, this returns the population's entity.
    def get_entity(
        self,
        plural: str | None = None,
    ) -> Population | None:
        """Return the entity whose plural is ``plural``, or None if there is none."""
        population = self.get_population(plural)
        # Explicit None check: a falsy (e.g. empty) population still has an entity.
        if population is None:
            return None
        return population.entity

    def describe_entities(self):
        """Map each entity's plural to the ids of its population."""
        return {
            population.entity.plural: population.ids
            for population in self.populations.values()
        }

    def clone(self, debug=False, trace=False):
        """Copy the simulation just enough to be able to run the copy without modifying the original simulation.

        Populations are cloned and re-linked to the copy; the tracer is rebuilt
        from ``trace``; the set of pending cache invalidations is copied. Other
        attributes (e.g. the tax and benefit system) are shared with the original.
        """
        new = commons.empty_clone(self)
        new_dict = new.__dict__

        for key, value in self.__dict__.items():
            if key not in ("debug", "trace", "tracer"):
                new_dict[key] = value

        # A shared set would let invalidations recorded by one simulation purge
        # the other's cache.
        new.invalidated_caches = set(self.invalidated_caches)

        new.persons = self.persons.clone(new)
        setattr(new, new.persons.entity.key, new.persons)
        new.populations = {new.persons.entity.key: new.persons}

        for entity in self.tax_benefit_system.group_entities:
            population = self.populations[entity.key].clone(new)
            new.populations[entity.key] = population
            setattr(
                new,
                entity.key,
                population,
            )  # create shortcut simulation.household (for instance)

        new.debug = debug
        new.trace = trace

        return new

class Cache(NamedTuple):
    """Key of a cached value: a variable name paired with a period."""

    # Name of the variable whose value is cached.
    variable: str
    # Period the cached value refers to.
    period: periods.Period