Source code for openfisca_core.simulations.simulation

from __future__ import annotations

from typing import Dict, NamedTuple, Optional, Set

import tempfile
import warnings

import numpy

from openfisca_core import commons, periods
from openfisca_core.errors import CycleError, SpiralError, VariableNotFoundError
from openfisca_core.indexed_enums import Enum, EnumArray
from openfisca_core.periods import Period
from openfisca_core.tracers import FullTracer, SimpleTracer, TracingParameterNodeAtInstant
from openfisca_core.types import Population, TaxBenefitSystem, Variable
from openfisca_core.warnings import TempfileWarning


class Simulation:
    """
    Represents a simulation, and handles the calculation logic.

    A simulation binds a tax and benefit system to a set of populations (one per
    entity) and computes variables on them, using the holders of each population
    as a cache.
    """

    # The tax and benefit system providing variables and parameters.
    tax_benefit_system: TaxBenefitSystem
    # The populations of the simulation, indexed by entity key.
    populations: Dict[str, Population]
    # Cache entries (variable name, period) to delete once the current calculation ends.
    invalidated_caches: Set[Cache]

    def __init__(
            self,
            tax_benefit_system: TaxBenefitSystem,
            populations: Dict[str, Population],
            ):
        """
        This constructor is reserved for internal use; see :any:`SimulationBuilder`,
        which is the preferred way to obtain a Simulation initialized with a consistent
        set of Entities.
        """
        self.tax_benefit_system = tax_benefit_system
        assert tax_benefit_system is not None

        self.populations = populations
        self.persons = self.populations[tax_benefit_system.person_entity.key]
        self.link_to_entities_instances()
        self.create_shortcuts()

        self.invalidated_caches = set()

        self.debug = False
        # Goes through the ``trace`` setter, which also installs the matching tracer
        self.trace = False
        self.tracer = SimpleTracer()
        self.opt_out_cache = False

        # controls the spirals detection; check for performance impact if > 1
        self.max_spiral_loops: int = 1
        self.memory_config = None
        self._data_storage_dir = None

    @property
    def trace(self):
        """Whether the simulation uses a ``FullTracer`` rather than a ``SimpleTracer``."""
        return self._trace

    @trace.setter
    def trace(self, trace):
        # Changing the flag always replaces the tracer by a fresh one of the matching kind
        self._trace = trace
        self.tracer = FullTracer() if trace else SimpleTracer()

    def link_to_entities_instances(self):
        """Make every population of the simulation point back to this simulation."""
        for population in self.populations.values():
            population.simulation = self

    def create_shortcuts(self):
        """Expose each population as an attribute named after its entity key."""
        for population in self.populations.values():
            # create shortcut simulation.person and simulation.household (for instance)
            setattr(self, population.entity.key, population)

    @property
    def data_storage_dir(self):
        """
        Temporary folder used to store intermediate calculation data in case the memory is saturated

        The folder is created on first access, and a ``TempfileWarning`` is emitted
        at that moment; subsequent accesses return the same path.
        """
        if self._data_storage_dir is None:
            self._data_storage_dir = tempfile.mkdtemp(prefix = "openfisca_")
            message = [
                ("Intermediate results will be stored on disk in {} in case of memory overflow.").format(self._data_storage_dir),
                "You should remove this directory once you're done with your simulation."
                ]
            warnings.warn(" ".join(message), TempfileWarning)
        return self._data_storage_dir

    # ----- Calculation methods ----- #

    def calculate(self, variable_name: str, period):
        """
        Calculate ``variable_name`` for ``period``.

        ``period`` is converted with ``periods.period`` unless it is ``None`` or
        already a ``Period``. The calculation is recorded by the tracer, and the
        invalidated cache entries are purged once it ends, whether it succeeded
        or raised.
        """
        if period is not None and not isinstance(period, Period):
            period = periods.period(period)

        self.tracer.record_calculation_start(variable_name, period)

        try:
            result = self._calculate(variable_name, period)
            self.tracer.record_calculation_result(result)
            return result

        finally:
            self.tracer.record_calculation_end()
            self.purge_cache_of_invalid_values()

    def _calculate(self, variable_name: str, period: Period):
        """
        Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists.

        :returns: A numpy array containing the result of the calculation
        :raises VariableNotFoundError: if the tax and benefit system returns no variable for ``variable_name``
        """
        variable: Optional[Variable]

        population = self.get_variable_population(variable_name)
        holder = population.get_holder(variable_name)
        variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True)

        if variable is None:
            raise VariableNotFoundError(variable_name, self.tax_benefit_system)

        self._check_period_consistency(period, variable)

        # First look for a value already cached
        cached_array = holder.get_array(period)
        if cached_array is not None:
            return cached_array

        array = None

        # First, try to run a formula
        try:
            self._check_for_cycle(variable.name, period)
            array = self._run_formula(variable, population, period)

            # If no result, use the default value and cache it
            if array is None:
                array = holder.default_array()

            array = self._cast_formula_result(array, variable)
            holder.put_in_cache(array, period)

        except SpiralError:
            # On a spiral, the default value is returned but not put in the cache
            array = holder.default_array()

        return array

    def purge_cache_of_invalid_values(self):
        """Delete the cache entries marked as invalid, once the outermost calculation has ended."""
        # We wait for the end of calculate(), signalled by an empty stack, before purging the cache
        if self.tracer.stack:
            return
        for (_name, _period) in self.invalidated_caches:
            holder = self.get_holder(_name)
            holder.delete_arrays(_period)
        self.invalidated_caches = set()

    def calculate_add(self, variable_name: str, period):
        """
        Calculate ``variable_name`` on each sub-period of ``period`` and sum the results.

        The sub-periods are those returned by ``period.get_subperiods`` for the
        definition period of the variable.

        :raises VariableNotFoundError: if the tax and benefit system returns no variable for ``variable_name``
        :raises ValueError: if the definition period of the variable weighs more than the unit of
            ``period``, or if it is not a day, a month or a year
        """
        variable: Optional[Variable]

        variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True)

        if variable is None:
            raise VariableNotFoundError(variable_name, self.tax_benefit_system)

        if period is not None and not isinstance(period, Period):
            period = periods.period(period)

        # Check that the requested period matches definition_period
        if periods.unit_weight(variable.definition_period) > periods.unit_weight(period.unit):
            raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' can only be computed for {2}-long periods. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'.".format(
                variable.name,
                period,
                variable.definition_period
                ))

        if variable.definition_period not in [periods.DAY, periods.MONTH, periods.YEAR]:
            raise ValueError("Unable to sum constant variable '{}' over period {}: only variables defined daily, monthly, or yearly can be summed over time.".format(
                variable.name,
                period))

        return sum(
            self.calculate(variable_name, sub_period)
            for sub_period in period.get_subperiods(variable.definition_period)
            )

    def calculate_divide(self, variable_name: str, period):
        """
        Calculate a yearly variable for a one-month or a one-year period.

        For a one-month period, the variable is calculated on ``period.this_year``
        and the result is divided by 12; for a one-year period, the value of the
        variable is returned as is.

        :raises VariableNotFoundError: if the tax and benefit system returns no variable for ``variable_name``
        :raises ValueError: if the variable is not defined yearly, if ``period.size`` is not 1,
            or if the unit of ``period`` is neither a month nor a year
        """
        variable: Optional[Variable]

        variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True)

        if variable is None:
            raise VariableNotFoundError(variable_name, self.tax_benefit_system)

        if period is not None and not isinstance(period, Period):
            period = periods.period(period)

        # Check that the requested period matches definition_period
        if variable.definition_period != periods.YEAR:
            raise ValueError("Unable to divide the value of '{}' over time on period {}: only variables defined yearly can be divided over time.".format(
                variable_name,
                period))

        if period.size != 1:
            raise ValueError("DIVIDE option can only be used for a one-year or a one-month requested period")

        if period.unit == periods.MONTH:
            computation_period = period.this_year
            return self.calculate(variable_name, period = computation_period) / 12.
        elif period.unit == periods.YEAR:
            return self.calculate(variable_name, period)

        raise ValueError("Unable to divide the value of '{}' to match period {}.".format(
            variable_name,
            period))

    def calculate_output(self, variable_name: str, period):
        """
        Calculate the value of a variable using the ``calculate_output`` attribute of the variable.

        When the variable has no ``calculate_output``, this is equivalent to :meth:`.calculate`.

        :raises VariableNotFoundError: if the tax and benefit system returns no variable for ``variable_name``
        """
        variable: Optional[Variable]

        variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True)

        if variable is None:
            raise VariableNotFoundError(variable_name, self.tax_benefit_system)

        if variable.calculate_output is None:
            return self.calculate(variable_name, period)

        return variable.calculate_output(self, variable_name, period)

    def trace_parameters_at_instant(self, formula_period):
        """Return the parameters at ``formula_period``, wrapped so that the tracer is given access to them."""
        return TracingParameterNodeAtInstant(
            self.tax_benefit_system.get_parameters_at_instant(formula_period),
            self.tracer
            )

    def _run_formula(self, variable, population, period):
        """
        Find the ``variable`` formula for the given ``period`` if it exists, and apply it to ``population``.

        :returns: The result of the formula, or ``None`` when the variable has no formula for ``period``
        """
        formula = variable.get_formula(period)
        if formula is None:
            return None

        if self.trace:
            parameters_at = self.trace_parameters_at_instant
        else:
            parameters_at = self.tax_benefit_system.get_parameters_at_instant

        # Formulas declared with two arguments do not receive the parameters
        if formula.__code__.co_argcount == 2:
            return formula(population, period)

        return formula(population, period, parameters_at)

    def _check_period_consistency(self, period, variable):
        """
        Check that a period matches the variable definition_period

        :raises ValueError: if a monthly (resp. yearly) variable is requested for a period whose
            unit is not a month (resp. a year), or if ``period.size`` is not 1
        """
        if variable.definition_period == periods.ETERNITY:
            return  # For variables which values are constant in time, all periods are accepted

        if variable.definition_period == periods.MONTH and period.unit != periods.MONTH:
            raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole month. You can use the ADD option to sum '{0}' over the requested period, or change the requested period to 'period.first_month'.".format(
                variable.name,
                period
                ))

        if variable.definition_period == periods.YEAR and period.unit != periods.YEAR:
            raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole year. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'.".format(
                variable.name,
                period
                ))

        if period.size != 1:
            # Name the unit actually expected: day-defined variables used to be
            # reported as requiring "a whole year"
            if variable.definition_period == periods.MONTH:
                unit_name = 'month'
            elif variable.definition_period == periods.DAY:
                unit_name = 'day'
            else:
                unit_name = 'year'

            raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole {2}. You can use the ADD option to sum '{0}' over the requested period.".format(
                variable.name,
                period,
                unit_name
                ))

    def _cast_formula_result(self, value, variable):
        """
        Convert the result of a formula to the array type expected for ``variable``.

        Values of ``Enum`` variables that are not ``EnumArray`` are encoded with the
        possible values of the variable; values that are not numpy arrays are turned
        into filled arrays by the population of the variable; arrays of another dtype
        than the one of the variable are converted.
        """
        if variable.value_type == Enum and not isinstance(value, EnumArray):
            return variable.possible_values.encode(value)

        if not isinstance(value, numpy.ndarray):
            population = self.get_variable_population(variable.name)
            value = population.filled_array(value)

        if value.dtype != variable.dtype:
            return value.astype(variable.dtype)

        return value

    # ----- Handle circular dependencies in a calculation ----- #

    def _check_for_cycle(self, variable: str, period):
        """
        Raise an exception in the case of a circular definition, where evaluating a variable for
        a given period loops around to evaluating the same variable/period pair. Also guards, as
        a heuristic, against "quasicircles", where the evaluation of a variable at a period involves
        the same variable at a different period.

        :raises CycleError: if the same variable/period pair is already being calculated
        :raises SpiralError: if the variable is already being calculated for at least
            ``max_spiral_loops`` other periods
        """
        # The last frame is the current calculation, so it should be ignored from cycle detection
        previous_periods = [frame['period'] for frame in self.tracer.stack[:-1] if frame['name'] == variable]
        if period in previous_periods:
            raise CycleError("Circular definition detected on formula {}@{}".format(variable, period))
        spiral = len(previous_periods) >= self.max_spiral_loops
        if spiral:
            self.invalidate_spiral_variables(variable)
            message = "Quasicircular definition detected on formula {}@{} involving {}".format(variable, period, self.tracer.stack)
            raise SpiralError(message, variable)

    def invalidate_cache_entry(self, variable: str, period):
        """Mark the cached value of ``variable`` for ``period`` for deletion at the end of the calculation."""
        self.invalidated_caches.add(Cache(variable, period))

    def invalidate_spiral_variables(self, variable: str):
        """Mark for deletion the cache entries of the calculations involved in a spiral on ``variable``."""
        # Visit the stack, from the bottom (most recent) up; we know that we'll find
        # the variable implicated in the spiral (max_spiral_loops+1) times; we keep the
        # intermediate values computed (to avoid impacting performance) but we mark them
        # for deletion from the cache once the calculation ends.
        count = 0
        for frame in reversed(self.tracer.stack):
            self.invalidate_cache_entry(str(frame['name']), frame['period'])
            if frame['name'] == variable:
                count += 1
                if count > self.max_spiral_loops:
                    break

    # ----- Methods to access stored values ----- #

    def get_array(self, variable_name: str, period):
        """
        Return the value of ``variable_name`` for ``period``, if this value is already in the cache (if it has been set as an input or previously calculated).

        Unlike :meth:`.calculate`, this method *does not* trigger calculations and *does not* use any formula.
        """
        if period is not None and not isinstance(period, Period):
            period = periods.period(period)
        return self.get_holder(variable_name).get_array(period)

    def get_holder(self, variable_name: str):
        """Get the holder associated with the variable."""
        return self.get_variable_population(variable_name).get_holder(variable_name)

    def get_memory_usage(self, variables = None):
        """
        Get data about the virtual memory usage of the simulation

        :param variables: forwarded as is to the ``get_memory_usage`` method of each population
        :returns: A dict with the summed ``total_nb_bytes`` of the populations, and their merged ``by_variable`` dicts
        """
        result = dict(
            total_nb_bytes = 0,
            by_variable = {}
            )
        for entity in self.populations.values():
            entity_memory_usage = entity.get_memory_usage(variables = variables)
            result['total_nb_bytes'] += entity_memory_usage['total_nb_bytes']
            result['by_variable'].update(entity_memory_usage['by_variable'])
        return result

    # ----- Misc ----- #

    def delete_arrays(self, variable, period = None):
        """
        Delete a variable's value for a given period

        :param variable: the variable whose value should be deleted
        :param period: the period for which the value should be deleted

        NOTE(review): the example below builds the simulation with a single
        argument, whereas the constructor also requires ``populations``.

        Example:

        >>> from openfisca_country_template import CountryTaxBenefitSystem
        >>> simulation = Simulation(CountryTaxBenefitSystem())
        >>> simulation.set_input('age', '2018-04', [12, 14])
        >>> simulation.set_input('age', '2018-05', [13, 14])
        >>> simulation.get_array('age', '2018-05')
        array([13, 14], dtype=int32)
        >>> simulation.delete_arrays('age', '2018-05')
        >>> simulation.get_array('age', '2018-04')
        array([12, 14], dtype=int32)
        >>> simulation.get_array('age', '2018-05') is None
        True
        >>> simulation.set_input('age', '2018-05', [13, 14])
        >>> simulation.delete_arrays('age')
        >>> simulation.get_array('age', '2018-04') is None
        True
        >>> simulation.get_array('age', '2018-05') is None
        True
        """
        self.get_holder(variable).delete_arrays(period)

    def get_known_periods(self, variable):
        """
        Get the known periods of a variable, as reported by its holder (presumably
        the periods for which a value has been set or calculated).

        :param variable: the variable whose known periods are requested

        NOTE(review): the example below builds the simulation with a single
        argument, whereas the constructor also requires ``populations``.

        Example:

        >>> from openfisca_country_template import CountryTaxBenefitSystem
        >>> simulation = Simulation(CountryTaxBenefitSystem())
        >>> simulation.set_input('age', '2018-04', [12, 14])
        >>> simulation.set_input('age', '2018-05', [13, 14])
        >>> simulation.get_known_periods('age')
        [Period((u'month', Instant((2018, 5, 1)), 1)), Period((u'month', Instant((2018, 4, 1)), 1))]
        """
        return self.get_holder(variable).get_known_periods()

    def set_input(self, variable_name: str, period, value):
        """
        Set a variable's value for a given period

        :param variable_name: the name of the variable to be set
        :param value: the input value for the variable
        :param period: the period for which the value is set

        The input is silently ignored when the variable has an ``end`` and the
        period starts after it.

        NOTE(review): the example below builds the simulation with a single
        argument, whereas the constructor also requires ``populations``.

        Example:

        >>> from openfisca_country_template import CountryTaxBenefitSystem
        >>> simulation = Simulation(CountryTaxBenefitSystem())
        >>> simulation.set_input('age', '2018-04', [12, 14])
        >>> simulation.get_array('age', '2018-04')
        array([12, 14], dtype=int32)

        If a ``set_input`` property has been set for the variable, this method may accept inputs for periods not matching the ``definition_period`` of the variable. To read more about this, check the `documentation <https://openfisca.org/doc/coding-the-legislation/35_periods.html#automatically-process-variable-inputs-defined-for-periods-not-matching-the-definitionperiod>`_.
        """
        variable: Optional[Variable]

        variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True)

        if variable is None:
            raise VariableNotFoundError(variable_name, self.tax_benefit_system)

        period = periods.period(period)
        if ((variable.end is not None) and (period.start.date > variable.end)):
            return
        self.get_holder(variable_name).set_input(period, value)

    def get_variable_population(self, variable_name: str) -> Population:
        """
        Get the population of the entity on which ``variable_name`` is defined.

        :raises VariableNotFoundError: if the tax and benefit system returns no variable for ``variable_name``
        """
        variable: Optional[Variable]

        variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True)

        if variable is None:
            raise VariableNotFoundError(variable_name, self.tax_benefit_system)

        return self.populations[variable.entity.key]

    def get_population(self, plural: Optional[str] = None) -> Optional[Population]:
        """Get the first population whose entity plural is ``plural``, or ``None`` if there is none."""
        return next((population for population in self.populations.values() if population.entity.plural == plural), None)

    def get_entity(
            self,
            plural: Optional[str] = None,
            ) -> Optional[Population]:
        """
        Get the entity of the population whose entity plural is ``plural``, or ``None`` if there is none.

        NOTE(review): the return annotation says ``Population`` while the value
        returned is ``population.entity``; it is kept as is for compatibility.
        """
        population = self.get_population(plural)
        # Explicit None check: a population evaluating to False must still yield its entity
        if population is None:
            return None
        return population.entity

    def describe_entities(self):
        """Return the ids of each population, indexed by the plural of its entity."""
        return {population.entity.plural: population.ids for population in self.populations.values()}

    def clone(self, debug = False, trace = False):
        """
        Copy the simulation just enough to be able to run the copy without modifying the original simulation

        :param debug: value of the ``debug`` attribute of the copy
        :param trace: value of the ``trace`` attribute of the copy (which gets a fresh tracer)
        :returns: The new simulation, with its own clones of the populations
        """
        new = commons.empty_clone(self)
        new_dict = new.__dict__

        for key, value in self.__dict__.items():
            if key in ('debug', 'trace', 'tracer'):
                continue
            if key == 'invalidated_caches':
                # The copy gets its own set, so that cache entries invalidated by
                # one simulation are not purged from the other one
                value = set(value)
            new_dict[key] = value

        new.persons = self.persons.clone(new)
        setattr(new, new.persons.entity.key, new.persons)
        new.populations = {new.persons.entity.key: new.persons}

        for entity in self.tax_benefit_system.group_entities:
            population = self.populations[entity.key].clone(new)
            new.populations[entity.key] = population
            setattr(new, entity.key, population)  # create shortcut simulation.household (for instance)

        new.debug = debug
        new.trace = trace

        return new
class Cache(NamedTuple):
    """A cache entry, identified by the name of a variable and a period."""

    # Name of the variable whose cached value is designated
    variable: str
    # Period of the cached value
    period: Period