Source code for openfisca_core.simulations.simulation_builder

import copy
import dpath.util
import typing

import numpy

from openfisca_core import periods
from openfisca_core.entities import Entity
from openfisca_core.errors import PeriodMismatchError, SituationParsingError, VariableNotFoundError
from openfisca_core.populations import Population
from openfisca_core.simulations import helpers, Simulation
from openfisca_core.variables import Variable


class SimulationBuilder:
    """
    Assemble a :any:`Simulation` from a situation description.

    Input values are buffered in ``input_buffer`` and only pushed to the
    populations' holders by ``finalize_variables_init``, which sets shorter
    periods before longer ones.
    """

    def __init__(self):
        # Simulation period used for variables when no period is defined
        self.default_period = None
        # Plural name for person entity in current tax and benefits system
        self.persons_plural = None

        # JSON input - Memory of known input values. Indexed by variable name,
        # then by normalized period string (``str(periods.period(...))``).
        self.input_buffer: typing.Dict[str, typing.Dict[str, numpy.ndarray]] = {}
        # Populations indexed by entity key (see ``create_entities``).
        self.populations: typing.Dict[str, Population] = {}
        # JSON input - Number of items of each entity type. Indexed by entities plural names.
        # Should be consistent with ``entity_ids``, including axes.
        self.entity_counts: typing.Dict[str, int] = {}
        # JSON input - List of items of each entity type. Indexed by entities plural names.
        # Should be consistent with ``entity_counts``.
        self.entity_ids: typing.Dict[str, typing.List[str]] = {}

        # Links entities with persons. For each person index in persons ids list, set entity index in entity ids id.
        # E.g.: self.memberships[entity.plural][person_index] = entity_ids.index(instance_id)
        # (``add_default_group_entity`` stores numpy arrays here, ``add_group_entity`` lists.)
        self.memberships: typing.Dict[str, typing.List[int]] = {}
        # Role object of each person in each group entity, indexed like ``memberships``.
        self.roles: typing.Dict[str, typing.List[typing.Any]] = {}

        # Entity of each registered variable (see ``register_variable``).
        self.variable_entities: typing.Dict[str, Entity] = {}

        # List of perpendicular dimensions, each a list of parallel axis dicts.
        self.axes = [[]]
        # Same as the attributes above, once the situation is replicated along axes.
        self.axes_entity_counts: typing.Dict[str, int] = {}
        self.axes_entity_ids: typing.Dict[str, typing.List[str]] = {}
        self.axes_memberships: typing.Dict[str, typing.List[int]] = {}
        self.axes_roles: typing.Dict[str, typing.List[typing.Any]] = {}

    def build_from_dict(self, tax_benefit_system, input_dict):
        """
        Build a simulation from ``input_dict``

        This method uses :any:`build_from_entities` if entities are fully specified, or :any:`build_from_variables` if not.

        :param tax_benefit_system: The tax and benefit system the simulation is built for
        :param dict input_dict: A dict representing the input of the simulation
        :return: A :any:`Simulation`
        """
        input_dict = self.explicit_singular_entities(tax_benefit_system, input_dict)
        plural_names = list(tax_benefit_system.entities_plural())
        if any(key in plural_names for key in input_dict.keys()):
            return self.build_from_entities(tax_benefit_system, input_dict)
        return self.build_from_variables(tax_benefit_system, input_dict)

    def build_from_entities(self, tax_benefit_system, input_dict):
        """
        Build a simulation from a Python dict ``input_dict`` fully specifying entities.

        Examples:

        >>> simulation_builder.build_from_entities(tax_benefit_system, {
            'persons': {'Javier': { 'salary': {'2018-11': 2000}}},
            'households': {'household': {'parents': ['Javier']}}
            })

        An optional top-level ``axes`` key is popped and expanded with :any:`expand_axes`.

        :raises SituationParsingError: if ``input_dict`` is not a dict, contains keys that are
            not entities of ``tax_benefit_system``, declares no person, or sets a variable for a
            period that does not match its definition period.
        """
        input_dict = copy.deepcopy(input_dict)

        simulation = Simulation(tax_benefit_system, tax_benefit_system.instantiate_entities())

        # Register variables so get_variable_entity can find them
        for variable_name in tax_benefit_system.variables.keys():
            self.register_variable(variable_name, simulation.get_variable_population(variable_name).entity)

        helpers.check_type(input_dict, dict, ['error'])
        axes = input_dict.pop('axes', None)

        plural_names = list(tax_benefit_system.entities_plural())
        unexpected_entities = [entity for entity in input_dict if entity not in plural_names]
        if unexpected_entities:
            unexpected_entity = unexpected_entities[0]
            # Sentences are joined with a space so the message stays readable.
            raise SituationParsingError([unexpected_entity],
                ' '.join([
                    "Some entities in the situation are not defined in the loaded tax and benefit system.",
                    "These entities are not found: {0}.",
                    "The defined entities are: {1}."]
                    )
                .format(
                    ', '.join(unexpected_entities),
                    ', '.join(plural_names)
                    )
                )

        persons_json = input_dict.get(tax_benefit_system.person_entity.plural, None)

        if not persons_json:
            raise SituationParsingError([tax_benefit_system.person_entity.plural],
                'No {0} found. At least one {0} must be defined to run a simulation.'.format(tax_benefit_system.person_entity.key))

        persons_ids = self.add_person_entity(simulation.persons.entity, persons_json)

        for entity_class in tax_benefit_system.group_entities:
            instances_json = input_dict.get(entity_class.plural)
            if instances_json is not None:
                self.add_group_entity(self.persons_plural, persons_ids, entity_class, instances_json)
            else:
                self.add_default_group_entity(persons_ids, entity_class)

        if axes:
            self.axes = axes
            self.expand_axes()

        try:
            self.finalize_variables_init(simulation.persons)
        except PeriodMismatchError as e:
            self.raise_period_mismatch(simulation.persons.entity, persons_json, e)

        for entity_class in tax_benefit_system.group_entities:
            population = simulation.populations[entity_class.key]
            try:
                self.finalize_variables_init(population)
            except PeriodMismatchError as e:
                # Look for the culprit in *this* entity's JSON; ``or {}`` covers
                # group entities absent from the input (built by add_default_group_entity).
                entity_json = input_dict.get(entity_class.plural) or {}
                self.raise_period_mismatch(population.entity, entity_json, e)

        return simulation

    def build_from_variables(self, tax_benefit_system, input_dict):
        """
        Build a simulation from a Python dict ``input_dict`` describing variables values without expliciting entities.

        This method uses :any:`build_default_simulation` to infer an entity structure

        Example:

        >>> simulation_builder.build_from_variables(
            tax_benefit_system,
            {'salary': {'2016-10': 12000}}
            )

        :raises SituationParsingError: if a value is not a ``{period: value}`` dict and no
            default period has been set.
        """
        count = helpers._get_person_count(input_dict)

        simulation = self.build_default_simulation(tax_benefit_system, count)

        for variable, value in input_dict.items():
            if not isinstance(value, dict):
                if self.default_period is None:
                    raise SituationParsingError([variable],
                        "Can't deal with type: expected object. Input variables should be set for specific periods. For instance: {'salary': {'2017-01': 2000, '2017-02': 2500}}, or {'birth_date': {'ETERNITY': '1980-01-01'}}.")
                simulation.set_input(variable, self.default_period, value)
            else:
                for period_str, dated_value in value.items():
                    simulation.set_input(variable, period_str, dated_value)

        return simulation

    def build_default_simulation(self, tax_benefit_system, count = 1):
        """
        Build a simulation where:
            - There are ``count`` persons
            - There are ``count`` instances of each group entity, containing one person
            - Every person has, in each entity, the first role
        """
        simulation = Simulation(tax_benefit_system, tax_benefit_system.instantiate_entities())
        for population in simulation.populations.values():
            population.count = count
            population.ids = numpy.array(range(count))
            if not population.entity.is_person:
                population.members_entity_id = population.ids  # Each person is its own group entity
        return simulation

    def create_entities(self, tax_benefit_system):
        """Instantiate the populations of ``tax_benefit_system`` into ``self.populations``."""
        self.populations = tax_benefit_system.instantiate_entities()

    def declare_person_entity(self, person_singular, persons_ids: typing.Iterable):
        """Set the ids and count of the person population keyed ``person_singular``."""
        person_instance = self.populations[person_singular]
        person_instance.ids = numpy.array(list(persons_ids))
        person_instance.count = len(person_instance.ids)

        self.persons_plural = person_instance.entity.plural

    def declare_entity(self, entity_singular, entity_ids: typing.Iterable):
        """Set the ids and count of the population keyed ``entity_singular`` and return it."""
        entity_instance = self.populations[entity_singular]
        entity_instance.ids = numpy.array(list(entity_ids))
        entity_instance.count = len(entity_instance.ids)
        return entity_instance

    def nb_persons(self, entity_singular, role = None):
        """Delegate to the ``nb_persons`` method of the population keyed ``entity_singular``."""
        return self.populations[entity_singular].nb_persons(role = role)

    def join_with_persons(self, group_population, persons_group_assignment, roles: typing.Iterable[str]):
        """
        Set the memberships and roles of ``group_population`` from per-person assignments.

        :param group_population: The group population to update.
        :param persons_group_assignment: For each person, the id of the group it belongs to.
        :param roles: For each person, either an index into the entity's flattened roles,
            or a role key.
        """
        # Maps group's identifiers to a 0-based integer range, for indexing into members_roles (see PR#876)
        group_sorted_indices = numpy.unique(persons_group_assignment, return_inverse = True)[1]
        group_population.members_entity_id = numpy.argsort(group_population.ids)[group_sorted_indices]

        flattened_roles = group_population.entity.flattened_roles
        roles_array = numpy.array(roles)
        if numpy.issubdtype(roles_array.dtype, numpy.integer):
            group_population.members_role = numpy.array(flattened_roles)[roles_array]
        elif len(flattened_roles) == 0:
            group_population.members_role = numpy.int64(0)
        else:
            group_population.members_role = numpy.select([roles_array == role.key for role in flattened_roles], flattened_roles)

    def build(self, tax_benefit_system):
        """Return a :any:`Simulation` over the populations declared on this builder."""
        return Simulation(tax_benefit_system, self.populations)

    def explicit_singular_entities(self, tax_benefit_system, input_dict):
        """
        Preprocess ``input_dict`` to explicit entities defined using the single-entity shortcut

        Only the singular-entity keys are rewritten. Every other top-level key (plural
        entities, ``axes``, unknown keys) is kept, so that :any:`build_from_entities` can
        still consume it or report it as unexpected. ``input_dict`` itself is not mutated.

        Example:

        >>> simulation_builder.explicit_singular_entities(
            tax_benefit_system,
            {'persons': {'Javier': {}, }, 'household': {'parents': ['Javier']}}
            )
        >>> {'persons': {'Javier': {}}, 'households': {'household': {'parents': ['Javier']}}}
        """
        entities_by_singular = tax_benefit_system.entities_by_singular()
        singular_keys = set(input_dict).intersection(entities_by_singular)

        if not singular_keys:
            return input_dict

        # Filter out the singular entities only; they are re-added below under their plural name.
        result = {
            key: value
            for key, value in input_dict.items()
            if key not in singular_keys
            }

        for singular in singular_keys:
            plural = entities_by_singular[singular].plural
            result[plural] = {singular: input_dict[singular]}

        return result

    def add_person_entity(self, entity, instances_json):
        """
        Add the simulation's instances of the persons entity as described in ``instances_json``.

        :return: The ids of the declared persons (as strings).
        """
        helpers.check_type(instances_json, dict, [entity.plural])
        entity_ids = list(map(str, instances_json.keys()))
        self.persons_plural = entity.plural
        self.entity_ids[self.persons_plural] = entity_ids
        self.entity_counts[self.persons_plural] = len(entity_ids)

        for instance_id, instance_object in instances_json.items():
            helpers.check_type(instance_object, dict, [entity.plural, instance_id])
            self.init_variable_values(entity, instance_object, str(instance_id))

        return self.get_ids(entity.plural)

    def add_default_group_entity(self, persons_ids, entity):
        """Declare one instance of ``entity`` per person, each person holding the first role."""
        persons_count = len(persons_ids)
        self.entity_ids[entity.plural] = persons_ids
        self.entity_counts[entity.plural] = persons_count
        self.memberships[entity.plural] = numpy.arange(0, persons_count, dtype = numpy.int32)
        self.roles[entity.plural] = numpy.repeat(entity.flattened_roles[0], persons_count)

    def add_group_entity(self, persons_plural, persons_ids, entity, instances_json):
        """
        Add all instances of one of the model's entities as described in ``instances_json``.

        Persons not listed in any instance each get their own instance, whose id is the
        person's id and whose role is the entity's first flattened role. Those instances are
        appended in ``persons_ids`` order.

        :param persons_plural: Plural name of the person entity, used in error messages.
        :param persons_ids: Ids of the declared persons.
        :param entity: The group entity to declare.
        :param dict instances_json: Maps instance ids to their description (role lists and variables).
        :raises SituationParsingError: on malformed input, unknown or doubly-allocated persons,
            or roles with more members than allowed.
        """
        helpers.check_type(instances_json, dict, [entity.plural])
        entity_ids = list(map(str, instances_json.keys()))
        self.entity_ids[entity.plural] = entity_ids
        self.entity_counts[entity.plural] = len(entity_ids)

        persons_count = len(persons_ids)
        persons_to_allocate = set(persons_ids)
        # Position of each person, to avoid an O(n) ``list.index`` per person.
        # ``setdefault`` keeps the first position, as ``list.index`` would.
        person_index_by_id = {}
        for person_index, person_id in enumerate(persons_ids):
            person_index_by_id.setdefault(person_id, person_index)

        self.memberships[entity.plural] = numpy.empty(persons_count, dtype = numpy.int32)
        self.roles[entity.plural] = numpy.empty(persons_count, dtype = object)

        role_by_plural = {role.plural or role.key: role for role in entity.roles}

        # The position in ``instances_json`` is the instance's index in ``entity_ids``.
        for entity_index, (instance_id, instance_object) in enumerate(instances_json.items()):
            helpers.check_type(instance_object, dict, [entity.plural, instance_id])

            variables_json = instance_object.copy()  # Don't mutate function input

            roles_json = {
                role.plural or role.key: helpers.transform_to_strict_syntax(variables_json.pop(role.plural or role.key, []))
                for role in entity.roles
                }

            for role_id, role_definition in roles_json.items():
                helpers.check_type(role_definition, list, [entity.plural, instance_id, role_id])
                for index, person_id in enumerate(role_definition):
                    self.check_persons_to_allocate(persons_plural, entity.plural,
                                                   persons_ids,
                                                   person_id, instance_id, role_id,
                                                   persons_to_allocate, index)
                    persons_to_allocate.discard(person_id)

            for role_plural, persons_with_role in roles_json.items():
                role = role_by_plural[role_plural]

                if role.max is not None and len(persons_with_role) > role.max:
                    raise SituationParsingError([entity.plural, instance_id, role_plural], f"There can be at most {role.max} {role_plural} in a {entity.key}. {len(persons_with_role)} were declared in '{instance_id}'.")

                for index_within_role, person_id in enumerate(persons_with_role):
                    person_index = person_index_by_id[person_id]
                    self.memberships[entity.plural][person_index] = entity_index
                    person_role = role.subroles[index_within_role] if role.subroles else role
                    self.roles[entity.plural][person_index] = person_role

            # ``entity_ids`` holds string ids, as in add_person_entity.
            self.init_variable_values(entity, variables_json, str(instance_id))

        if persons_to_allocate:
            # Iterate in declaration order (not set order, which depends on string hashing)
            # so the generated ids and memberships are the same from one run to the next.
            unallocated_ids = [person_id for person_id in person_index_by_id if person_id in persons_to_allocate]
            first_new_index = len(entity_ids)
            for offset, person_id in enumerate(unallocated_ids):
                person_index = person_index_by_id[person_id]
                # Positional index: stays correct even if a person id equals a declared instance id.
                self.memberships[entity.plural][person_index] = first_new_index + offset
                self.roles[entity.plural][person_index] = entity.flattened_roles[0]
            entity_ids = entity_ids + unallocated_ids
            # Adjust previously computed ids and counts
            self.entity_ids[entity.plural] = entity_ids
            self.entity_counts[entity.plural] = len(entity_ids)

        # Convert back to Python array
        self.roles[entity.plural] = self.roles[entity.plural].tolist()
        self.memberships[entity.plural] = self.memberships[entity.plural].tolist()

    def set_default_period(self, period_str):
        """Set ``default_period`` to the normalized form of ``period_str``, if it is truthy."""
        if period_str:
            self.default_period = str(periods.period(period_str))

    def get_input(self, variable, period_str):
        """
        Return the buffered input array of ``variable`` for the buffer key ``period_str``, or ``None``.

        Also creates the (empty) per-period buffer of ``variable``, which callers rely on
        when they store an array back into ``input_buffer[variable]``.
        """
        return self.input_buffer.setdefault(variable, {}).get(period_str)

    def check_persons_to_allocate(self, persons_plural, entity_plural,
                                  persons_ids,
                                  person_id, entity_id, role_id,
                                  persons_to_allocate, index):
        """
        Validate that ``person_id`` is a declared, not-yet-allocated person id.

        :raises SituationParsingError: if ``person_id`` is not a string, was not declared
            in ``persons_ids``, or is no longer in ``persons_to_allocate``.
        """
        helpers.check_type(person_id, str, [entity_plural, entity_id, role_id, str(index)])
        if person_id not in persons_ids:
            raise SituationParsingError([entity_plural, entity_id, role_id],
                "Unexpected value: {0}. {0} has been declared in {1} {2}, but has not been declared in {3}.".format(
                    person_id, entity_id, role_id, persons_plural)
                )
        if person_id not in persons_to_allocate:
            raise SituationParsingError([entity_plural, entity_id, role_id],
                "{} has been declared more than once in {}".format(
                    person_id, entity_plural)
                )

    def init_variable_values(self, entity, instance_object, instance_id):
        """
        Buffer every variable value declared in ``instance_object`` for instance ``instance_id``.

        A non-dict value is attributed to ``default_period``.

        :raises SituationParsingError: if a variable is unknown (code 404) or belongs to another
            entity, if a period is invalid, or if a bare value is given without a default period.
        """
        instance_index = self.get_ids(entity.plural).index(instance_id)
        for variable_name, variable_values in instance_object.items():
            path_in_json = [entity.plural, instance_id, variable_name]
            try:
                entity.check_variable_defined_for_entity(variable_name)
            except ValueError as e:  # The variable is defined for another entity
                raise SituationParsingError(path_in_json, e.args[0]) from e
            except VariableNotFoundError as e:  # The variable doesn't exist
                raise SituationParsingError(path_in_json, str(e), code = 404) from e

            if not isinstance(variable_values, dict):
                if self.default_period is None:
                    raise SituationParsingError(path_in_json,
                        "Can't deal with type: expected object. Input variables should be set for specific periods. For instance: {'salary': {'2017-01': 2000, '2017-02': 2500}}, or {'birth_date': {'ETERNITY': '1980-01-01'}}.")
                variable_values = {self.default_period: variable_values}

            variable = entity.get_variable(variable_name)
            for period_str, value in variable_values.items():
                try:
                    periods.period(period_str)
                except ValueError as e:
                    raise SituationParsingError(path_in_json, e.args[0]) from e
                self.add_variable_value(entity, variable, instance_index, instance_id, period_str, value)

    def add_variable_value(self, entity, variable, instance_index, instance_id, period_str, value):
        """
        Store ``value`` at ``instance_index`` in the buffered array of ``variable`` for ``period_str``.

        ``None`` values are ignored.

        :raises SituationParsingError: if ``variable.check_set_value`` rejects ``value``.
        """
        path_in_json = [entity.plural, instance_id, variable.name, period_str]

        if value is None:
            return

        # Use the normalized key for both lookup and storage, so different spellings
        # of one period share a buffer; finalize_variables_init reads ``buffer[str(period)]``.
        period_key = str(periods.period(period_str))
        array = self.get_input(variable.name, period_key)

        if array is None:
            array_size = self.get_count(entity.plural)
            array = variable.default_array(array_size)

        try:
            value = variable.check_set_value(value)
        except ValueError as error:
            raise SituationParsingError(path_in_json, *error.args) from error

        array[instance_index] = value

        self.input_buffer[variable.name][period_key] = array

    def finalize_variables_init(self, population):
        """
        Push counts, ids, memberships, roles and buffered inputs into ``population``.

        Due to set_input mechanism, we must bufferize all inputs, then actually set them,
        so that the months are set first and the years last.
        """
        plural_key = population.entity.plural
        if plural_key in self.entity_counts:
            population.count = self.get_count(plural_key)
            population.ids = self.get_ids(plural_key)
        if plural_key in self.memberships:
            population.members_entity_id = numpy.array(self.get_memberships(plural_key))
            population.members_role = numpy.array(self.get_roles(plural_key))
        for variable_name, buffer in self.input_buffer.items():
            try:
                holder = population.get_holder(variable_name)
            except ValueError:  # Wrong entity, we can just ignore that
                continue
            variable = holder.variable
            unsorted_periods = [periods.period(period_str) for period_str in buffer.keys()]
            # We need to handle small periods first for set_input to work
            sorted_periods = sorted(unsorted_periods, key = periods.key_period_size)
            for period_value in sorted_periods:
                values = buffer[str(period_value)]
                # Hack to replicate the values in the persons entity
                # when we have an axis along a group entity but not persons
                array = numpy.tile(values, population.count // len(values))
                # TODO - this duplicates the check in Simulation.set_input, but
                # fixing that requires improving Simulation's handling of entities
                if (variable.end is None) or (period_value.start.date <= variable.end):
                    holder.set_input(period_value, array)

    def raise_period_mismatch(self, entity, json, e):
        """
        Re-raise the ``PeriodMismatchError`` ``e`` as a ``SituationParsingError`` located in ``json``.

        This error happens when we try to set a variable value for a period that doesn't match its
        definition period. It is only raised when we consume the buffer, so we don't know which
        exact key caused the error; we do a basic search to find the culprit path.
        """
        culprit_path = next(
            dpath.util.search(json, "*/{}/{}".format(e.variable_name, str(e.period)), yielded = True),
            None)
        if culprit_path:
            path = [entity.plural] + culprit_path[0].split('/')
        else:
            path = [entity.plural]  # Fallback: if we can't find the culprit, just set the error at the entities level

        raise SituationParsingError(path, e.message) from e

    def get_count(self, entity_name):
        """Return the total number of instances of this entity, including when there is replication along axes."""
        return self.axes_entity_counts.get(entity_name, self.entity_counts[entity_name])

    def get_ids(self, entity_name):
        """Return the ids of instances of this entity, including when there is replication along axes."""
        return self.axes_entity_ids.get(entity_name, self.entity_ids[entity_name])

    def get_memberships(self, entity_name):
        """Return the memberships of individuals in this entity, including when there is replication along axes."""
        # Return empty array for the "persons" entity
        return self.axes_memberships.get(entity_name, self.memberships.get(entity_name, []))

    def get_roles(self, entity_name):
        """Return the roles of individuals in this entity, including when there is replication along axes."""
        # Return empty array for the "persons" entity
        return self.axes_roles.get(entity_name, self.roles.get(entity_name, []))

    def add_parallel_axis(self, axis):
        """Add ``axis`` to the first dimension; parallel axes share its count and entity."""
        # All parallel axes have the same count and entity.
        # Search for a compatible axis, if none exists, error out
        self.axes[0].append(axis)

    def add_perpendicular_axis(self, axis):
        """Add ``axis`` as a new dimension, perpendicular to all previous ones."""
        self.axes.append([axis])

    def expand_axes(self):
        """
        Replicate the declared situation along ``self.axes`` and buffer the axis values.

        Each axis dict is read for ``name``, ``count``, ``min``, ``max`` and optionally ``index``
        (default ``0``) and ``period`` (default ``self.default_period``). Axes within one inner
        list of ``self.axes`` are parallel and use the first axis' count and entity; each inner
        list is a dimension perpendicular to the others.
        """
        # This method should be idempotent & allow change in axes
        perpendicular_dimensions = self.axes

        # Number of copies of the situation: the product of every dimension's count
        cell_count = 1
        for parallel_axes in perpendicular_dimensions:
            cell_count *= parallel_axes[0]['count']

        # Scale the "prototype" situation, repeating it cell_count times.
        # ``list(...)`` matters: default group entities store numpy arrays, for which
        # ``* cell_count`` would multiply element-wise instead of repeating.
        for entity_name in self.entity_counts.keys():
            # Adjust counts
            self.axes_entity_counts[entity_name] = self.get_count(entity_name) * cell_count
            # Adjust ids
            original_ids = list(self.get_ids(entity_name)) * cell_count
            indices = numpy.arange(0, cell_count * self.entity_counts[entity_name])
            adjusted_ids = [instance_id + str(ix) for instance_id, ix in zip(original_ids, indices)]
            self.axes_entity_ids[entity_name] = adjusted_ids
            # Adjust roles
            original_roles = list(self.get_roles(entity_name))
            self.axes_roles[entity_name] = original_roles * cell_count
            # Adjust memberships, for group entities only
            if entity_name != self.persons_plural:
                original_memberships = list(self.get_memberships(entity_name))
                repeated_memberships = original_memberships * cell_count
                # Copy k of the situation points to group instances offset by k * group count
                indices = numpy.repeat(numpy.arange(0, cell_count), len(original_memberships)) * self.entity_counts[entity_name]
                adjusted_memberships = (numpy.array(repeated_memberships) + indices).tolist()
                self.axes_memberships[entity_name] = adjusted_memberships

        # Now generate input values along the specified axes
        if len(self.axes) == 1 and len(self.axes[0]):
            parallel_axes = self.axes[0]
            first_axis = parallel_axes[0]
            axis_count: int = first_axis['count']
            axis_entity = self.get_variable_entity(first_axis['name'])
            axis_entity_step_size = self.entity_counts[axis_entity.plural]
            # Distribute values along axes
            for axis in parallel_axes:
                axis_index = axis.get('index', 0)
                axis_name = axis['name']
                period_key = self._axis_period_key(axis)
                variable = axis_entity.get_variable(axis_name)
                array = self._axis_input_array(axis_name, variable, period_key, axis_count, axis_entity_step_size)
                array[axis_index:: axis_entity_step_size] = numpy.linspace(
                    axis['min'],
                    axis['max'],
                    num = axis_count,
                    )
                # Set input
                self.input_buffer[axis_name][period_key] = array
        else:
            first_axes_count: typing.List[int] = [
                parallel_axes[0]["count"]
                for parallel_axes in self.axes
                ]
            axes_linspaces = [
                numpy.linspace(0, axis_count - 1, num = axis_count)
                for axis_count in first_axes_count
                ]
            axes_meshes = numpy.meshgrid(*axes_linspaces)
            for parallel_axes, mesh in zip(self.axes, axes_meshes):
                first_axis = parallel_axes[0]
                axis_count = first_axis['count']
                axis_entity = self.get_variable_entity(first_axis['name'])
                axis_entity_step_size = self.entity_counts[axis_entity.plural]
                # Position of each cell along this dimension (0 .. axis_count - 1)
                positions = mesh.reshape(cell_count)
                # Distribute values along the grid
                for axis in parallel_axes:
                    axis_index = axis.get('index', 0)
                    axis_name = axis['name']
                    period_key = self._axis_period_key(axis)
                    variable = axis_entity.get_variable(axis_name)
                    array = self._axis_input_array(axis_name, variable, period_key, cell_count, axis_entity_step_size)
                    if axis_count > 1:
                        values = axis['min'] + positions * (axis['max'] - axis['min']) / (axis_count - 1)
                    else:
                        # A single-point dimension only takes ``min`` (as numpy.linspace with num=1
                        # does) instead of dividing by zero.
                        values = numpy.full(cell_count, axis['min'], dtype = float)
                    array[axis_index:: axis_entity_step_size] = values
                    self.input_buffer[axis_name][period_key] = array

    def _axis_period_key(self, axis):
        """Return the normalized input-buffer key for ``axis``' period (``default_period`` if unset)."""
        return str(periods.period(axis.get('period') or self.default_period))

    def _axis_input_array(self, axis_name, variable, period_key, repetitions, step_size):
        """
        Return the array that receives axis values for ``axis_name`` at ``period_key``.

        An already-buffered array of the prototype size ``step_size`` is tiled ``repetitions``
        times; an already-expanded one is reused; otherwise ``variable``'s default array of
        size ``repetitions * step_size`` is created.
        """
        array = self.get_input(axis_name, period_key)
        if array is None:
            return variable.default_array(repetitions * step_size)
        if array.size == step_size:
            return numpy.tile(array, repetitions)
        return array

    def get_variable_entity(self, variable_name):
        """Return the entity registered for ``variable_name``."""
        return self.variable_entities[variable_name]

    def register_variable(self, variable_name, entity):
        """Record that ``variable_name`` belongs to ``entity``."""
        self.variable_entities[variable_name] = entity