Source code for openfisca_core.simulations.simulation

from __future__ import annotations

from collections.abc import Mapping
from typing import NamedTuple

from openfisca_core.types import (
    CorePopulation as Population,
    TaxBenefitSystem,
    Variable,
)

import tempfile
import warnings

import numpy

from openfisca_core import (
    commons,
    errors,
    indexed_enums,
    periods,
    tracers,
    warnings as core_warnings,
)


[docs] class Simulation: """Represents a simulation, and handles the calculation logic.""" tax_benefit_system: TaxBenefitSystem populations: dict[str, Population] invalidated_caches: set[Cache] def __init__( self, tax_benefit_system: TaxBenefitSystem, populations: Mapping[str, Population], ) -> None: """This constructor is reserved for internal use; see :any:`SimulationBuilder`, which is the preferred way to obtain a Simulation initialized with a consistent set of Entities. """ self.tax_benefit_system = tax_benefit_system assert tax_benefit_system is not None self.populations = populations self.persons = self.populations[tax_benefit_system.person_entity.key] self.link_to_entities_instances() self.create_shortcuts() self.invalidated_caches = set() self.debug = False self.trace = False self.tracer = tracers.SimpleTracer() self.opt_out_cache = False # controls the spirals detection; check for performance impact if > 1 self.max_spiral_loops: int = 1 self.memory_config = None self._data_storage_dir = None @property def trace(self): return self._trace @trace.setter def trace(self, trace) -> None: self._trace = trace if trace: self.tracer = tracers.FullTracer() else: self.tracer = tracers.SimpleTracer() def link_to_entities_instances(self) -> None: for entity_instance in self.populations.values(): entity_instance.simulation = self def create_shortcuts(self) -> None: for population in self.populations.values(): # create shortcut simulation.person and simulation.household (for instance) setattr(self, population.entity.key, population) @property def data_storage_dir(self): """Temporary folder used to store intermediate calculation data in case the memory is saturated.""" if self._data_storage_dir is None: self._data_storage_dir = tempfile.mkdtemp(prefix="openfisca_") message = [ ( f"Intermediate results will be stored on disk in {self._data_storage_dir} in case of memory overflow." ), "You should remove this directory once you're done with your simulation.", ] warnings.warn( " ".join(message), core_warnings.TempfileWarning, stacklevel=2, ) return self._data_storage_dir # ----- Calculation methods ----- #
[docs] def calculate(self, variable_name: str, period): """Calculate ``variable_name`` for ``period``.""" if period is not None and not isinstance(period, periods.Period): period = periods.period(period) self.tracer.record_calculation_start(variable_name, period) try: result = self._calculate(variable_name, period) self.tracer.record_calculation_result(result) return result finally: self.tracer.record_calculation_end() self.purge_cache_of_invalid_values()
def _calculate(self, variable_name: str, period: periods.Period): """Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists. :returns: A numpy array containing the result of the calculation """ variable: Variable | None population = self.get_variable_population(variable_name) holder = population.get_holder(variable_name) variable = self.tax_benefit_system.get_variable( variable_name, check_existence=True, ) if variable is None: raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) self._check_period_consistency(period, variable) # First look for a value already cached cached_array = holder.get_array(period) if cached_array is not None: return cached_array array = None # First, try to run a formula try: self._check_for_cycle(variable.name, period) array = self._run_formula(variable, population, period) # If no result, use the default value and cache it if array is None: array = holder.default_array() array = self._cast_formula_result(array, variable) holder.put_in_cache(array, period) except errors.SpiralError: array = holder.default_array() return array def purge_cache_of_invalid_values(self) -> None: # We wait for the end of calculate(), signalled by an empty stack, before purging the cache if self.tracer.stack: return for _name, _period in self.invalidated_caches: holder = self.get_holder(_name) holder.delete_arrays(_period) self.invalidated_caches = set() def calculate_add(self, variable_name: str, period): variable: Variable | None variable = self.tax_benefit_system.get_variable( variable_name, check_existence=True, ) if variable is None: raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) if period is not None and not isinstance(period, periods.Period): period = periods.period(period) # Check that the requested period matches definition_period if periods.unit_weight(variable.definition_period) > periods.unit_weight( period.unit, ): msg = ( f"Unable to compute variable '{variable.name}' for period " f"{period}: '{variable.name}' can only be computed for " f"{variable.definition_period}-long periods. You can use the " f"DIVIDE option to get an estimate of {variable.name}." ) raise ValueError( msg, ) if variable.definition_period not in ( periods.DateUnit.isoformat + periods.DateUnit.isocalendar ): msg = ( f"Unable to ADD constant variable '{variable.name}' over " f"the period {period}: eternal variables can't be summed " "over time." ) raise ValueError( msg, ) return sum( self.calculate(variable_name, sub_period) for sub_period in period.get_subperiods(variable.definition_period) ) def calculate_divide(self, variable_name: str, period): variable: Variable | None variable = self.tax_benefit_system.get_variable( variable_name, check_existence=True, ) if variable is None: raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) if period is not None and not isinstance(period, periods.Period): period = periods.period(period) if ( periods.unit_weight(variable.definition_period) < periods.unit_weight(period.unit) or period.size > 1 ): msg = ( f"Can't calculate variable '{variable.name}' for period " f"{period}: '{variable.name}' can only be computed for " f"{variable.definition_period}-long periods. You can use the " f"ADD option to get an estimate of {variable.name}." ) raise ValueError( msg, ) if variable.definition_period not in ( periods.DateUnit.isoformat + periods.DateUnit.isocalendar ): msg = ( f"Unable to DIVIDE constant variable '{variable.name}' over " f"the period {period}: eternal variables can't be divided " "over time." ) raise ValueError( msg, ) if ( period.unit not in (periods.DateUnit.isoformat + periods.DateUnit.isocalendar) or period.size != 1 ): msg = ( f"Unable to DIVIDE constant variable '{variable.name}' over " f"the period {period}: eternal variables can't be used " "as a denominator to divide a variable over time." ) raise ValueError( msg, ) if variable.definition_period == periods.DateUnit.YEAR: calculation_period = period.this_year elif variable.definition_period == periods.DateUnit.MONTH: calculation_period = period.first_month elif variable.definition_period == periods.DateUnit.DAY: calculation_period = period.first_day elif variable.definition_period == periods.DateUnit.WEEK: calculation_period = period.first_week else: calculation_period = period.first_weekday if period.unit == periods.DateUnit.YEAR: denominator = calculation_period.size_in_years elif period.unit == periods.DateUnit.MONTH: denominator = calculation_period.size_in_months elif period.unit == periods.DateUnit.DAY: denominator = calculation_period.size_in_days elif period.unit == periods.DateUnit.WEEK: denominator = calculation_period.size_in_weeks else: denominator = calculation_period.size_in_weekdays return self.calculate(variable_name, calculation_period) / denominator def calculate_output(self, variable_name: str, period): """Calculate the value of a variable using the ``calculate_output`` attribute of the variable.""" variable: Variable | None variable = self.tax_benefit_system.get_variable( variable_name, check_existence=True, ) if variable is None: raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) if variable.calculate_output is None: return self.calculate(variable_name, period) return variable.calculate_output(self, variable_name, period) def trace_parameters_at_instant(self, formula_period): return tracers.TracingParameterNodeAtInstant( self.tax_benefit_system.get_parameters_at_instant(formula_period), self.tracer, ) def _run_formula(self, variable, population, period): """Find the ``variable`` formula for the given ``period`` if it exists, and apply it to ``population``.""" formula = variable.get_formula(period) if formula is None: return None if self.trace: parameters_at = self.trace_parameters_at_instant else: parameters_at = self.tax_benefit_system.get_parameters_at_instant if formula.__code__.co_argcount == 2: array = formula(population, period) else: array = formula(population, period, parameters_at) return array def _check_period_consistency(self, period, variable) -> None: """Check that a period matches the variable definition_period.""" if variable.definition_period == periods.DateUnit.ETERNITY: return # For variables which values are constant in time, all periods are accepted if ( variable.definition_period == periods.DateUnit.YEAR and period.unit != periods.DateUnit.YEAR ): msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole year. You can use the DIVIDE option to get an estimate of {variable.name} by dividing the yearly value by 12, or change the requested period to 'period.this_year'." raise ValueError( msg, ) if ( variable.definition_period == periods.DateUnit.MONTH and period.unit != periods.DateUnit.MONTH ): msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole month. You can use the ADD option to sum '{variable.name}' over the requested period, or change the requested period to 'period.first_month'." raise ValueError( msg, ) if ( variable.definition_period == periods.DateUnit.WEEK and period.unit != periods.DateUnit.WEEK ): msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole week. You can use the ADD option to sum '{variable.name}' over the requested period, or change the requested period to 'period.first_week'." raise ValueError( msg, ) if period.size != 1: msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole {variable.definition_period}. You can use the ADD option to sum '{variable.name}' over the requested period." raise ValueError( msg, ) def _cast_formula_result(self, value, variable): if variable.value_type == indexed_enums.Enum and not isinstance( value, indexed_enums.EnumArray, ): return variable.possible_values.encode(value) if not isinstance(value, numpy.ndarray): population = self.get_variable_population(variable.name) value = population.filled_array(value) if value.dtype != variable.dtype: return value.astype(variable.dtype) return value # ----- Handle circular dependencies in a calculation ----- # def _check_for_cycle(self, variable: str, period) -> None: """Raise an exception in the case of a circular definition, where evaluating a variable for a given period loops around to evaluating the same variable/period pair. Also guards, as a heuristic, against "quasicircles", where the evaluation of a variable at a period involves the same variable at a different period. """ # The last frame is the current calculation, so it should be ignored from cycle detection previous_periods = [ frame["period"] for frame in self.tracer.stack[:-1] if frame["name"] == variable ] if period in previous_periods: msg = f"Circular definition detected on formula {variable}@{period}" raise errors.CycleError( msg, ) spiral = len(previous_periods) >= self.max_spiral_loops if spiral: self.invalidate_spiral_variables(variable) message = f"Quasicircular definition detected on formula {variable}@{period} involving {self.tracer.stack}" raise errors.SpiralError(message, variable) def invalidate_cache_entry(self, variable: str, period) -> None: self.invalidated_caches.add(Cache(variable, period)) def invalidate_spiral_variables(self, variable: str) -> None: # Visit the stack, from the bottom (most recent) up; we know that we'll find # the variable implicated in the spiral (max_spiral_loops+1) times; we keep the # intermediate values computed (to avoid impacting performance) but we mark them # for deletion from the cache once the calculation ends. count = 0 for frame in reversed(self.tracer.stack): self.invalidate_cache_entry(str(frame["name"]), frame["period"]) if frame["name"] == variable: count += 1 if count > self.max_spiral_loops: break # ----- Methods to access stored values ----- #
[docs] def get_array(self, variable_name: str, period): """Return the value of ``variable_name`` for ``period``, if this value is alreay in the cache (if it has been set as an input or previously calculated). Unlike :meth:`.calculate`, this method *does not* trigger calculations and *does not* use any formula. """ if period is not None and not isinstance(period, periods.Period): period = periods.period(period) return self.get_holder(variable_name).get_array(period)
[docs] def get_holder(self, variable_name: str): """Get the holder associated with the variable.""" return self.get_variable_population(variable_name).get_holder(variable_name)
[docs] def get_memory_usage(self, variables=None): """Get data about the virtual memory usage of the simulation.""" result = {"total_nb_bytes": 0, "by_variable": {}} for entity in self.populations.values(): entity_memory_usage = entity.get_memory_usage(variables=variables) result["total_nb_bytes"] += entity_memory_usage["total_nb_bytes"] result["by_variable"].update(entity_memory_usage["by_variable"]) return result
# ----- Misc ----- # def delete_arrays(self, variable, period=None) -> None: """Delete a variable's value for a given period. :param variable: the variable to be set :param period: the period for which the value should be deleted Example: >>> from openfisca_country_template import CountryTaxBenefitSystem >>> simulation = Simulation(CountryTaxBenefitSystem()) >>> simulation.set_input("age", "2018-04", [12, 14]) >>> simulation.set_input("age", "2018-05", [13, 14]) >>> simulation.get_array("age", "2018-05") array([13, 14], dtype=int32) >>> simulation.delete_arrays("age", "2018-05") >>> simulation.get_array("age", "2018-04") array([12, 14], dtype=int32) >>> simulation.get_array("age", "2018-05") is None True >>> simulation.set_input("age", "2018-05", [13, 14]) >>> simulation.delete_arrays("age") >>> simulation.get_array("age", "2018-04") is None True >>> simulation.get_array("age", "2018-05") is None True """ self.get_holder(variable).delete_arrays(period) def get_known_periods(self, variable): """Get a list variable's known period, i.e. the periods where a value has been initialized and. :param variable: the variable to be set Example: >>> from openfisca_country_template import CountryTaxBenefitSystem >>> simulation = Simulation(CountryTaxBenefitSystem()) >>> simulation.set_input("age", "2018-04", [12, 14]) >>> simulation.set_input("age", "2018-05", [13, 14]) >>> simulation.get_known_periods("age") [Period((u'month', Instant((2018, 5, 1)), 1)), Period((u'month', Instant((2018, 4, 1)), 1))] """ return self.get_holder(variable).get_known_periods() def set_input(self, variable_name: str, period, value) -> None: """Set a variable's value for a given period. :param variable: the variable to be set :param value: the input value for the variable :param period: the period for which the value is setted Example: >>> from openfisca_country_template import CountryTaxBenefitSystem >>> simulation = Simulation(CountryTaxBenefitSystem()) >>> simulation.set_input("age", "2018-04", [12, 14]) >>> simulation.get_array("age", "2018-04") array([12, 14], dtype=int32) If a ``set_input`` property has been set for the variable, this method may accept inputs for periods not matching the ``definition_period`` of the variable. To read more about this, check the `documentation <https://openfisca.org/doc/coding-the-legislation/35_periods.html#automatically-process-variable-inputs-defined-for-periods-not-matching-the-definitionperiod>`_. """ variable: Variable | None variable = self.tax_benefit_system.get_variable( variable_name, check_existence=True, ) if variable is None: raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) period = periods.period(period) if (variable.end is not None) and (period.start.date > variable.end): return self.get_holder(variable_name).set_input(period, value) def get_variable_population(self, variable_name: str) -> Population: variable: Variable | None variable = self.tax_benefit_system.get_variable( variable_name, check_existence=True, ) if variable is None: raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) return self.populations[variable.entity.key] def get_population(self, plural: str | None = None) -> Population | None: return next( ( population for population in self.populations.values() if population.entity.plural == plural ), None, ) def get_entity( self, plural: str | None = None, ) -> Population | None: population = self.get_population(plural) return population and population.entity def describe_entities(self): return { population.entity.plural: population.ids for population in self.populations.values() } def clone(self, debug=False, trace=False): """Copy the simulation just enough to be able to run the copy without modifying the original simulation.""" new = commons.empty_clone(self) new_dict = new.__dict__ for key, value in self.__dict__.items(): if key not in ("debug", "trace", "tracer"): new_dict[key] = value new.persons = self.persons.clone(new) setattr(new, new.persons.entity.key, new.persons) new.populations = {new.persons.entity.key: new.persons} for entity in self.tax_benefit_system.group_entities: population = self.populations[entity.key].clone(new) new.populations[entity.key] = population setattr( new, entity.key, population, ) # create shortcut simulation.household (for instance) new.debug = debug new.trace = trace return new
class Cache(NamedTuple): variable: str period: periods.Period