diff --git a/pcs/common/interface/dto.py b/pcs/common/interface/dto.py index 75c92f726..7661e48a7 100644 --- a/pcs/common/interface/dto.py +++ b/pcs/common/interface/dto.py @@ -1,18 +1,6 @@ -from dataclasses import asdict, fields, is_dataclass -from enum import Enum, EnumType -from types import NoneType, UnionType -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterable, - NewType, - TypeVar, - Union, - get_type_hints, -) -from typing import get_args as get_type_args -from typing import get_origin as get_type_origin +from dataclasses import asdict +from enum import Enum +from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, Union import dacite @@ -35,23 +23,18 @@ class DataclassInstance: PrimitiveType, DtoPayload, Iterable["SerializableType"] ] -T = TypeVar("T") -E = TypeVar("E", bound=Enum) - -ToDictMetaKey = NewType("ToDictMetaKey", str) -META_NAME = ToDictMetaKey("META_NAME") - class PayloadConversionError(Exception): pass -class _UnionNotAllowed(Exception): +class DataTransferObject(DataclassInstance): pass -class DataTransferObject(DataclassInstance): - pass +T = TypeVar("T") +E = TypeVar("E", bound=Enum) +DTOTYPE = TypeVar("DTOTYPE", bound=DataTransferObject) def _safe_enum_cast(enum_class: type[E]) -> Callable[[Any], E]: @@ -109,186 +92,12 @@ def _cast_value(value: Any) -> E: } -def meta(name: str) -> dict[str, str]: - metadata: dict[str, str] = {} - if name: - metadata[META_NAME] = name - return metadata - - -# _type is Any - in reality, it is either one of: -# * type -# * enum.EnumType -# * something defined in typing module, e.g. typing._GenericAlias, typing.Union -# Especially the typing module changes with new Python versions. -# Properly typing (rather metatyping, since its input and output are types) -# this function doesn't bring any benefits. -def _extract_type_from_optional(_type: Any) -> Any: - # Dataclass fields may be typed as 'Optional[some_type]' or - # 'Union[some_type, None]' or 'some_type | None'. This function extracts - # the inner type from an Optional, and thus allows to properly detect types - # of such dataclass fields. It raises an exception if a Union contains more - # than one type other than None, because in that case it is unclear which - # one is the correct type. However, such a field should never be defined in - # a dataclass, because field type must be unambiguous. - - # Internal representation of Union and Optional is different in Python 3.12 - # and 3.14. To be able to handle the differences, typing.get_origin is - # used. It transforms all the representations to Union or UnionType. - # https://docs.python.org/3/library/typing.html#typing.Union - _type_origin = get_type_origin(_type) - if not (_type_origin is Union or _type_origin is UnionType): - return _type - - inner_types_without_none = [ - inner_type - for inner_type in get_type_args(_type) - if inner_type is not NoneType - ] - if len(inner_types_without_none) == 1: - return inner_types_without_none[0] - raise _UnionNotAllowed() - - -# _type is Any - in reality, it is either one of: -# * type -# * enum.EnumType -# * something defined in typing module, e.g. typing._GenericAlias, typing.Union -# Especially the typing module changes with new Python versions. -# Properly typing (rather metatyping, since its input and output are types) -# this function doesn't bring any benefits. -def _is_compatible_type(_type: Any, arg_index: int) -> bool: - return ( - hasattr(_type, "__args__") - and len(_type.__args__) >= arg_index - and is_dataclass(_type.__args__[arg_index]) - ) - - -# _type is Any - in reality, it is either one of: -# * type -# * enum.EnumType -# * something defined in typing module, e.g. typing._GenericAlias, typing.Union -# Especially the typing module changes with new Python versions. -# Properly typing (rather metatyping, since its input and output are types) -# this function doesn't bring any benefits. -def _is_enum_type(_type: Any, arg_index: int) -> bool: - return ( - hasattr(_type, "__args__") - and len(_type.__args__) >= arg_index - and type(_type.__args__[arg_index]) is EnumType - ) - - -# returns Any as the type of enum value can be anything and it can be different -# for each Enum -def _convert_enum(value: Enum) -> Any: - return value.value - - -def _convert_dict( - klass: type[DataTransferObject], obj_dict: DtoPayload -) -> DtoPayload: - new_dict = {} - # resolve forward references in type hints, because type-detecting - # functions do not work with forward references - type_hints = get_type_hints(klass) - for _field in fields(klass): - try: - _type = _extract_type_from_optional(type_hints[_field.name]) - except _UnionNotAllowed as e: - raise AssertionError( - f"Field '{_field.name}' in class '{klass}' is a Union: " - f"{_field.type}. " - "Dataclass fields cannot be Unions, unless they are a Union of " - "one type and None (which is equal to Optional)." - ) from e - value = obj_dict[_field.name] - - new_value: SerializableType - if value is None: - # None must be handled here, other checks fail if they get None - new_value = value - elif is_dataclass(_type): - new_value = _convert_dict(_type, value) # type: ignore - elif isinstance(value, list) and _is_compatible_type(_type, 0): - new_value = [ - _convert_dict(_type.__args__[0], item) for item in value - ] - elif isinstance(value, list) and _is_enum_type(_type, 0): - new_value = [_convert_enum(item) for item in value] - elif isinstance(value, dict) and _is_compatible_type(_type, 1): - new_value = { - item_key: _convert_dict(_type.__args__[1], item_val) # type: ignore[arg-type] - for item_key, item_val in value.items() - } - elif isinstance(value, Enum): - new_value = _convert_enum(value) - else: - new_value = value - new_dict[_field.metadata.get(META_NAME, _field.name)] = new_value - return new_dict - - -def to_dict(obj: DataTransferObject) -> DtoPayload: - return _convert_dict(obj.__class__, asdict(obj)) - - -DTOTYPE = TypeVar("DTOTYPE", bound=DataTransferObject) - - -def _convert_payload(klass: type[DTOTYPE], data: DtoPayload) -> DtoPayload: - try: - new_dict = dict(data) - except ValueError as e: - raise PayloadConversionError() from e - # resolve forward references in type hints, because type-detecting - # functions do not work with forward references - type_hints = get_type_hints(klass) - for _field in fields(klass): - new_name = _field.metadata.get(META_NAME, _field.name) - if new_name not in data: - continue - - try: - _type = _extract_type_from_optional(type_hints[_field.name]) - except _UnionNotAllowed as e: - raise AssertionError( - f"Field '{_field.name}' in class '{klass}' is a Union: " - f"{_field.type}. " - "Dataclass fields cannot be Unions, unless they are a Union of " - "one type and None (which is equal to Optional)." - ) from e - value = data[new_name] - - new_value: SerializableType - if value is None: - # None must be handled here, other checks fail if they get None - new_value = value - elif is_dataclass(_type): - new_value = _convert_payload(_type, value) # type: ignore - elif isinstance(value, list) and _is_compatible_type(_type, 0): - new_value = [ - _convert_payload(_type.__args__[0], item) for item in value - ] - elif isinstance(value, dict) and _is_compatible_type(_type, 1): - new_value = { - item_key: _convert_payload(_type.__args__[1], item_val) # type: ignore[arg-type] - for item_key, item_val in value.items() - } - else: - new_value = value - del new_dict[new_name] - new_dict[_field.name] = new_value - return new_dict - - def from_dict( cls: type[DTOTYPE], data: DtoPayload, strict: bool = False ) -> DTOTYPE: return dacite.from_dict( data_class=cls, - data=_convert_payload(cls, data), + data=data, config=dacite.Config( type_hooks=DTO_TYPE_HOOKS_MAP, strict=strict, @@ -296,6 +105,10 @@ def from_dict( ) +def to_dict(obj: DataTransferObject) -> DtoPayload: + return asdict(obj) + + class ImplementsToDto: def to_dto(self) -> Any: raise NotImplementedError() diff --git a/pcs/common/permissions/types.py b/pcs/common/permissions/types.py index 5d726a647..8cc4f8b7c 100644 --- a/pcs/common/permissions/types.py +++ b/pcs/common/permissions/types.py @@ -1,12 +1,12 @@ -from enum import Enum +from enum import StrEnum -class PermissionTargetType(str, Enum): +class PermissionTargetType(StrEnum): USER = "user" GROUP = "group" -class PermissionGrantedType(str, Enum): +class PermissionGrantedType(StrEnum): READ = "read" WRITE = "write" GRANT = "grant" diff --git a/pcs/common/resource_agent/dto.py b/pcs/common/resource_agent/dto.py index 7f1a67587..60ea1cc51 100644 --- a/pcs/common/resource_agent/dto.py +++ b/pcs/common/resource_agent/dto.py @@ -1,16 +1,7 @@ -from dataclasses import ( - dataclass, - field, -) -from typing import ( - List, - Optional, -) +from dataclasses import dataclass +from typing import Optional -from pcs.common.interface.dto import ( - DataTransferObject, - meta, -) +from pcs.common.interface.dto import DataTransferObject @dataclass(frozen=True) @@ -30,7 +21,7 @@ def get_resource_agent_full_name(agent_name: ResourceAgentNameDto) -> str: @dataclass(frozen=True) class ListResourceAgentNameDto(DataTransferObject): - names: List[ResourceAgentNameDto] + names: list[ResourceAgentNameDto] @dataclass(frozen=True) @@ -47,7 +38,7 @@ class ResourceAgentActionDto(DataTransferObject): # not allowed by OCF 1.0, defined in OCF 1.0 agents anyway role: Optional[str] # OCF name: 'start-delay', optional by both OCF 1.0 and 1.1 - start_delay: Optional[str] = field(metadata=meta(name="start-delay")) + start_delay: Optional[str] # optional by both OCF 1.0 and 1.1 depth: Optional[str] # not allowed by any OCF, defined in OCF 1.0 agents anyway @@ -71,7 +62,7 @@ class ResourceAgentParameterDto(DataTransferObject): # default value of the parameter default: Optional[str] # allowed values, only defined if type == 'select' - enum_values: Optional[List[str]] + enum_values: Optional[list[str]] # True if it is a required parameter, False otherwise required: bool # True if the parameter is meant for advanced users @@ -79,7 +70,7 @@ class ResourceAgentParameterDto(DataTransferObject): # True if the parameter is deprecated, False otherwise deprecated: bool # list of parameters deprecating this one - deprecated_by: List[str] + deprecated_by: list[str] # text describing / explaining the deprecation deprecated_desc: Optional[str] # should the parameter's value be unique across same agent resources? @@ -93,8 +84,8 @@ class ResourceAgentMetadataDto(DataTransferObject): name: ResourceAgentNameDto shortdesc: Optional[str] longdesc: Optional[str] - parameters: List[ResourceAgentParameterDto] - actions: List[ResourceAgentActionDto] + parameters: list[ResourceAgentParameterDto] + actions: list[ResourceAgentActionDto] @dataclass(frozen=True) diff --git a/pcs/common/types.py b/pcs/common/types.py index 41f14ff32..deac999db 100644 --- a/pcs/common/types.py +++ b/pcs/common/types.py @@ -1,46 +1,21 @@ from collections.abc import Set -from enum import ( - Enum, - auto, -) -from typing import ( - Generator, - Literal, - MutableSequence, - Optional, - Type, - TypeVar, - Union, -) +from enum import StrEnum, auto +from typing import Generator, Literal, MutableSequence, Union StringSequence = Union[MutableSequence[str], tuple[str, ...]] StringCollection = Union[StringSequence, Set[str]] StringIterable = Union[StringCollection, Generator[str, None, None]] -class AutoNameEnum(str, Enum): +class AutoNameEnum(StrEnum): @staticmethod def _generate_next_value_( - name: str, - start: int, - count: int, - last_values: list[int], + name: str, start: int, count: int, last_values: list[str] ) -> str: del start, count, last_values return name -T = TypeVar("T", bound=AutoNameEnum) - - -def str_to_enum(enum_type: Type[T], value: Optional[str]) -> Optional[T]: - if value: - value = value.upper() - if value in {item.value for item in enum_type}: - return enum_type(value) - return None - - PcmkScore = Union[int, Literal["INFINITY", "+INFINITY", "-INFINITY"]] @@ -95,7 +70,7 @@ def from_str(cls, transport: str) -> "CorosyncTransportType": raise UnknownCorosyncTransportTypeException(transport) from None -class CorosyncNodeAddressType(str, Enum): +class CorosyncNodeAddressType(StrEnum): IPV4 = "IPv4" IPV6 = "IPv6" FQDN = "FQDN" diff --git a/pcs/lib/permissions/types.py b/pcs/lib/permissions/types.py index e0a6d266a..db415b219 100644 --- a/pcs/lib/permissions/types.py +++ b/pcs/lib/permissions/types.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class PermissionRequiredType(str, Enum): +class PermissionRequiredType(StrEnum): NONE = "none" READ = "read" WRITE = "write" diff --git a/pcs_test/tier0/common/interface/test_dto.py b/pcs_test/tier0/common/interface/test_dto.py index 35718ca6b..5a80e5dc8 100644 --- a/pcs_test/tier0/common/interface/test_dto.py +++ b/pcs_test/tier0/common/interface/test_dto.py @@ -1,7 +1,7 @@ import importlib import pkgutil from collections.abc import Sequence -from dataclasses import dataclass, field, is_dataclass +from dataclasses import dataclass, is_dataclass from typing import Any, Optional from unittest import TestCase @@ -12,7 +12,6 @@ DataTransferObject, PayloadConversionError, from_dict, - meta, to_dict, ) from pcs.common.types import CorosyncNodeAddressType @@ -44,22 +43,21 @@ def test_has_all_subclasses_are_dataclasses(self): @dataclass class MyDto1(DataTransferObject): field_a: int - field_b: int = field(metadata=meta(name="field-b")) - field_c: int + field_b: int @dataclass class MyDto2(DataTransferObject): field_d: int - field_e: MyDto1 = field(metadata=meta(name="field-e")) + field_e: MyDto1 field_f: CorosyncNodeAddressType # tests converting an Enum class @dataclass class MyDto3(DataTransferObject): - field_g: MyDto2 = field(metadata=meta(name="field-g")) + field_g: MyDto2 field_h: list[MyDto2] - field_i: int = field(metadata=meta(name="field-i")) + field_i: int @dataclass @@ -71,41 +69,35 @@ class TypeHooksDto(DataTransferObject): class DictName(TestCase): maxDiff = None - simple_dto = MyDto1(1, 2, 3) - simple_dict = {"field_a": 1, "field-b": 2, "field_c": 3} + simple_dto = MyDto1(1, 2) + simple_dict = {"field_a": 1, "field_b": 2} nested_dto = MyDto3( - MyDto2(0, MyDto1(1, 2, 3), CorosyncNodeAddressType.IPV4), + MyDto2(0, MyDto1(1, 2), CorosyncNodeAddressType.IPV4), [ - MyDto2(5, MyDto1(6, 7, 8), CorosyncNodeAddressType.FQDN), - MyDto2( - 10, MyDto1(11, 12, 13), CorosyncNodeAddressType.UNRESOLVABLE - ), + MyDto2(3, MyDto1(4, 5), CorosyncNodeAddressType.FQDN), + MyDto2(6, MyDto1(7, 8), CorosyncNodeAddressType.UNRESOLVABLE), ], - 15, + 9, ) nested_dict = { - "field-g": { + "field_g": { "field_d": 0, - "field-e": {"field_a": 1, "field-b": 2, "field_c": 3}, + "field_e": {"field_a": 1, "field_b": 2}, "field_f": "IPv4", }, "field_h": [ { - "field_d": 5, - "field-e": {"field_a": 6, "field-b": 7, "field_c": 8}, + "field_d": 3, + "field_e": {"field_a": 4, "field_b": 5}, "field_f": "FQDN", }, { - "field_d": 10, - "field-e": { - "field_a": 11, - "field-b": 12, - "field_c": 13, - }, + "field_d": 6, + "field_e": {"field_a": 7, "field_b": 8}, "field_f": "unresolvable", }, ], - "field-i": 15, + "field_i": 9, } def test_simple_to_dict(self): @@ -195,26 +187,27 @@ class EnumDto(DataTransferObject): field_d: Optional[CorosyncNodeAddressType] -class FromDictEnumConversion(TestCase): +class EnumConversion(TestCase): + _DTO = EnumDto( + field_a=CorosyncNodeAddressType.IPV4, + field_b=[CorosyncNodeAddressType.IPV6, CorosyncNodeAddressType.FQDN], + field_c={"foo": CorosyncNodeAddressType.UNRESOLVABLE}, + field_d=CorosyncNodeAddressType.IPV4, + ) _VALID_PAYLOAD = dict( field_a="IPv4", field_b=["IPv6", "FQDN"], - field_c=dict(foo="unresolvable"), + field_c={"foo": "unresolvable"}, field_d="IPv4", ) - def test_success_from_raw_values(self): - self.assertEqual( - EnumDto( - CorosyncNodeAddressType.IPV4, - [CorosyncNodeAddressType.IPV6, CorosyncNodeAddressType.FQDN], - {"foo": CorosyncNodeAddressType.UNRESOLVABLE}, - CorosyncNodeAddressType.IPV4, - ), - from_dict(EnumDto, self._VALID_PAYLOAD), - ) + def test_success_from_dict(self): + self.assertEqual(self._DTO, from_dict(EnumDto, self._VALID_PAYLOAD)) + + def test_success_to_dict(self): + self.assertEqual(self._VALID_PAYLOAD, to_dict(self._DTO)) - def test_error_bad_value(self): + def test_from_dict_error_bad_value(self): bad_values = dict( field_a="bad value", field_b=["IPv6", "bad value"], diff --git a/pcs_test/tier0/daemon/async_tasks/test_command_mapping.py b/pcs_test/tier0/daemon/async_tasks/test_command_mapping.py index 66aa77925..f2b66f37e 100644 --- a/pcs_test/tier0/daemon/async_tasks/test_command_mapping.py +++ b/pcs_test/tier0/daemon/async_tasks/test_command_mapping.py @@ -1,18 +1,12 @@ +import importlib import inspect -from dataclasses import ( - fields, - is_dataclass, -) +import pkgutil +from dataclasses import fields, is_dataclass from enum import EnumType -from typing import ( - Container, - Iterable, - get_args, - get_origin, - get_type_hints, -) +from typing import Container, Iterable, get_args, get_origin, get_type_hints from unittest import TestCase +import pcs.lib.commands as lib_command_package from pcs.common.interface.dto import DTO_TYPE_HOOKS_MAP from pcs.daemon.async_tasks.worker.command_mapping import COMMAND_MAP @@ -39,10 +33,6 @@ def prohibited_types_used(_type, prohibited_types): return False -def _get_generic(annotation): - return getattr(annotation, "__origin__", None) - - def _find_disallowed_types(_type, allowed_types, _seen=None): if _seen is None: _seen = set() @@ -121,3 +111,80 @@ def test_check_type_hooks_map_types_in_commands(self): "and update FromDictConversion tests in " "test_dto.py accordingly.", ) + + +def _find_disallowed_return_type_enums(_type, allowed_enum_bases, _seen=None): + disallowed = set() + + if _seen is None: + _seen = set() + type_id = id(_type) + if type_id in _seen: + return disallowed + _seen.add(type_id) + + generic = get_origin(_type) + if generic is None: + if isinstance(_type, EnumType) and not any( + issubclass(_type, base) for base in allowed_enum_bases + ): + disallowed.add(_type) + else: + for arg in get_args(_type): + disallowed.update( + _find_disallowed_return_type_enums( + arg, allowed_enum_bases, _seen + ) + ) + + if is_dataclass(_type): + # resolve forward references in type hints, because type-detecting + # functions do not work with forward references + type_hints = get_type_hints(_type) + for field in fields(_type): + disallowed.update( + _find_disallowed_return_type_enums( + type_hints[field.name], allowed_enum_bases, _seen + ) + ) + + return disallowed + + +class ReturnTypeCompatibilityTest(TestCase): + def test_return_value_enums(self): + allowed_enum_bases = (int, str, float) + + for _, module_name, _ in pkgutil.walk_packages( + lib_command_package.__path__, lib_command_package.__name__ + "." + ): + try: + module = importlib.import_module(module_name) + except ImportError: + continue + + for cmd_name, cmd in inspect.getmembers(module, inspect.isfunction): + if cmd_name.startswith("_"): + continue + + return_type = inspect.signature(cmd).return_annotation + if ( + return_type == inspect.Parameter.empty + or return_type is None + ): + continue + + with self.subTest(value=cmd_name): + disallowed = _find_disallowed_return_type_enums( + return_type, allowed_enum_bases + ) + if disallowed: + raise AssertionError( + f"Type(s) {disallowed} in return type of command: {cmd_name}\n" + f"All Enums must also be subclasses of {allowed_enum_bases} " + "to allow for easy serialization.\n" + "Either use 'pcs.common.types.AutoNameEnum' or make sure " + "your enum is also a subclass of one of the allowed " + "types: use enum.StrEnum for strings, or " + "MyEnum(, Enum) for other types" + )