Source code for openfisca_core.indexed_enums.enum_array
from __future__ import annotations
from typing import NoReturn
from typing_extensions import Self
import numpy
from . import types as t
[docs]
class EnumArray(t.EnumArray):
"""A subclass of :class:`~numpy.ndarray` of :class:`.Enum`.
:class:`.Enum` arrays are encoded as :class:`int` to improve performance.
Note:
Subclassing :class:`~numpy.ndarray` is a little tricky™. To read more
about the :meth:`.__new__` and :meth:`.__array_finalize__` methods
below, see `Subclassing ndarray`_.
Examples:
>>> import numpy
>>> from openfisca_core import indexed_enums as enum, variables
>>> class Housing(enum.Enum):
... OWNER = "Owner"
... TENANT = "Tenant"
... FREE_LODGER = "Free lodger"
... HOMELESS = "Homeless"
>>> array = numpy.array([1], dtype=numpy.int16)
>>> enum_array = enum.EnumArray(array, Housing)
>>> repr(enum.EnumArray)
"<class 'openfisca_core.indexed_enums.enum_array.EnumArray'>"
>>> repr(enum_array)
'EnumArray([Housing.TENANT])'
>>> str(enum_array)
"['TENANT']"
>>> list(map(int, enum_array))
[1]
>>> int(enum_array[0])
1
>>> enum_array[0] in enum_array
True
>>> len(enum_array)
1
>>> enum_array = enum.EnumArray(list(Housing), Housing)
Traceback (most recent call last):
AttributeError: 'list' object has no attribute 'view'
>>> class OccupancyStatus(variables.Variable):
... value_type = enum.Enum
... possible_values = Housing
>>> enum.EnumArray(array, OccupancyStatus.possible_values)
EnumArray([Housing.TENANT])
.. _Subclassing ndarray:
https://numpy.org/doc/stable/user/basics.subclassing.html
"""
#: Enum type of the array items.
possible_values: None | type[t.Enum]
[docs]
def __new__(
cls,
input_array: t.IndexArray,
possible_values: type[t.Enum],
) -> Self:
"""See comment above."""
obj = input_array.view(cls)
obj.possible_values = possible_values
return obj
[docs]
def __array_finalize__(self, obj: None | t.EnumArray | t.VarArray) -> None:
"""See comment above."""
if obj is None:
return
self.possible_values = getattr(obj, "possible_values", None)
[docs]
def __eq__(self, other: object) -> t.BoolArray: # type: ignore[override]
"""Compare equality with the item's :attr:`~.Enum.index`.
When comparing to an item of :attr:`.possible_values`, use the
item's :attr:`~.Enum.index`. to speed up the comparison.
Whenever possible, use :any:`numpy.ndarray.view` so that the result is
a classic :class:`~numpy.ndarray`, not an :obj:`.EnumArray`.
Args:
other: Another :class:`object` to compare to.
Returns:
bool: When ???
ndarray[bool_]: When ???
Examples:
>>> import numpy
>>> from openfisca_core import indexed_enums as enum
>>> class Housing(enum.Enum):
... OWNER = "Owner"
... TENANT = "Tenant"
>>> array = numpy.array([1])
>>> enum_array = enum.EnumArray(array, Housing)
>>> enum_array == Housing
array([False, True])
>>> enum_array == Housing.TENANT
array([ True])
>>> enum_array == 1
array([ True])
>>> enum_array == [1]
array([ True])
>>> enum_array == [2]
array([False])
>>> enum_array == "1"
array([False])
>>> enum_array is None
False
>>> enum_array == enum.EnumArray(numpy.array([1]), Housing)
array([ True])
Note:
This breaks the `Liskov substitution principle`_.
.. _Liskov substitution principle:
https://en.wikipedia.org/wiki/Liskov_substitution_principle
"""
result: t.BoolArray
if self.possible_values is None:
return NotImplemented
if other is None:
return NotImplemented
if (
isinstance(other, type(t.Enum))
and other.__name__ is self.possible_values.__name__
):
result = (
self.view(numpy.ndarray)
== self.possible_values.indices[
self.possible_values.indices <= max(self)
]
)
return result
if (
isinstance(other, t.Enum)
and other.__class__.__name__ is self.possible_values.__name__
):
result = self.view(numpy.ndarray) == other.index
return result
# For NumPy >=1.26.x.
if isinstance(is_equal := self.view(numpy.ndarray) == other, numpy.ndarray):
return is_equal
# For NumPy <1.26.x.
return numpy.array([is_equal], dtype=t.BoolDType)
[docs]
def __ne__(self, other: object) -> t.BoolArray: # type: ignore[override]
"""Inequality.
Args:
other: Another :class:`object` to compare to.
Returns:
bool: When ???
ndarray[bool_]: When ???
Examples:
>>> import numpy
>>> from openfisca_core import indexed_enums as enum
>>> class Housing(enum.Enum):
... OWNER = "Owner"
... TENANT = "Tenant"
>>> array = numpy.array([1])
>>> enum_array = enum.EnumArray(array, Housing)
>>> enum_array != Housing
array([ True, False])
>>> enum_array != Housing.TENANT
array([False])
>>> enum_array != 1
array([False])
>>> enum_array != [1]
array([False])
>>> enum_array != [2]
array([ True])
>>> enum_array != "1"
array([ True])
>>> enum_array is not None
True
Note:
This breaks the `Liskov substitution principle`_.
.. _Liskov substitution principle:
https://en.wikipedia.org/wiki/Liskov_substitution_principle
"""
return numpy.logical_not(self == other)
@staticmethod
def _forbidden_operation(*__args: object, **__kwds: object) -> NoReturn:
msg = (
"Forbidden operation. The only operations allowed on EnumArrays "
"are '==' and '!='."
)
raise TypeError(msg)
__add__ = _forbidden_operation
__mul__ = _forbidden_operation
__lt__ = _forbidden_operation
__le__ = _forbidden_operation
__gt__ = _forbidden_operation
__ge__ = _forbidden_operation
__and__ = _forbidden_operation
__or__ = _forbidden_operation
[docs]
def decode(self) -> t.ObjArray:
"""Decode itself to a normal array.
Returns:
ndarray[Enum]: The items of the :obj:`.EnumArray`.
Raises:
TypeError: When the :attr:`.possible_values` is not defined.
Examples:
>>> import numpy
>>> from openfisca_core import indexed_enums as enum
>>> class Housing(enum.Enum):
... OWNER = "Owner"
... TENANT = "Tenant"
>>> array = numpy.array([1])
>>> enum_array = enum.EnumArray(array, Housing)
>>> enum_array.decode()
array([Housing.TENANT], dtype=object)
"""
result: t.ObjArray
if self.possible_values is None:
msg = (
f"The possible values of the {self.__class__.__name__} are "
f"not defined."
)
raise TypeError(msg)
array = self.reshape(1).astype(t.EnumDType) if self.ndim == 0 else self
result = self.possible_values.enums[array]
return result
[docs]
def decode_to_str(self) -> t.StrArray:
"""Decode itself to an array of strings.
Returns:
ndarray[str_]: The string values of the :obj:`.EnumArray`.
Raises:
TypeError: When the :attr:`.possible_values` is not defined.
Examples:
>>> import numpy
>>> from openfisca_core import indexed_enums as enum
>>> class Housing(enum.Enum):
... OWNER = "Owner"
... TENANT = "Tenant"
>>> array = numpy.array([1])
>>> enum_array = enum.EnumArray(array, Housing)
>>> enum_array.decode_to_str()
array(['TENANT'], dtype='<U6')
"""
result: t.StrArray
if self.possible_values is None:
msg = (
f"The possible values of the {self.__class__.__name__} are "
f"not defined."
)
raise TypeError(msg)
array = self.reshape(1).astype(t.EnumDType) if self.ndim == 0 else self
result = self.possible_values.names[array]
return result
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.decode()!s})"
def __str__(self) -> str:
return str(self.decode_to_str())
__all__ = ["EnumArray"]