diff --git a/AUTHORS.rst b/AUTHORS.rst index 9fc0ddbbe..4147e51fe 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -159,3 +159,4 @@ Contributors (chronological) - Stephen Rosen `@sirosen `_ - Vladimir Mikhaylov `@vemikhaylov `_ - Stephen Eaton `@madeinoz67 `_ +- Dor Meiri `@dormeiri `_ diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 6b7b5e221..1022273d2 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -164,6 +164,7 @@ def __init__( dump_only: bool = False, error_messages: typing.Optional[typing.Dict[str, str]] = None, metadata: typing.Optional[typing.Mapping[str, typing.Any]] = None, + take_default_on=None, **additional_metadata ) -> None: self.default = default @@ -191,6 +192,7 @@ def __init__( raise ValueError("'missing' must not be set for required fields.") self.required = required self.missing = missing + self.take_default_on = tuple() if take_default_on is None else take_default_on metadata = metadata or {} self.metadata = {**metadata, **additional_metadata} @@ -218,7 +220,8 @@ def __repr__(self) -> str: "validate={self.validate}, required={self.required}, " "load_only={self.load_only}, dump_only={self.dump_only}, " "missing={self.missing}, allow_none={self.allow_none}, " - "error_messages={self.error_messages})>".format( + "error_messages={self.error_messages}, " + "take_default_on={self.take_default_on})>".format( ClassName=self.__class__.__name__, self=self ) ) @@ -323,10 +326,10 @@ def serialize( """ if self._CHECK_ATTRIBUTE: value = self.get_value(obj, attr, accessor=accessor) - if value is missing_ and hasattr(self, "default"): + if self._should_take_default(value) and hasattr(self, "default"): default = self.default value = default() if callable(default) else default - if value is missing_: + if self._should_take_default(value): return value else: value = None @@ -416,6 +419,9 @@ def _deserialize( """ return value + def _should_take_default(self, value): + return value is missing_ or value in self.take_default_on + # Properties @property diff --git a/tests/test_fields.py b/tests/test_fields.py index 4e5e4cc01..a0b62ad9e 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -36,7 +36,7 @@ def test_repr(self): "validate=None, required=False, " "load_only=False, dump_only=False, " "missing={missing}, allow_none=False, " - "error_messages={error_messages})>".format( + "error_messages={error_messages}, take_default_on=())>".format( default, missing=missing, error_messages=field.error_messages ) ) diff --git a/tests/test_schema.py b/tests/test_schema.py index 4f5b871cb..8455d5ea2 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -2748,6 +2748,51 @@ class MySchema(Schema): assert errors["allow_none_field"][0] == "" +class TestTakeDefaultOn: + class MySchema(Schema): + int_take_default_on_none = fields.Int(default=42, take_default_on=[None]) + int_take_default_on_value = fields.Int(default=42, take_default_on=[0, 1]) + str_take_default_on_none = fields.Str(default="foo", take_default_on=[None]) + str_take_default_on_value = fields.Str( + default="foo", take_default_on=["", "bar"] + ) + + @pytest.fixture() + def schema(self): + return self.MySchema() + + @pytest.fixture() + def default_values(self): + return dict( + int_take_default_on_none=42, + int_take_default_on_value=42, + str_take_default_on_none="foo", + str_take_default_on_value="foo", + ) + + def test_default_taken_not_missing(self, schema, default_values): + data = dict( + int_take_default_on_none=None, + int_take_default_on_value=0, + str_take_default_on_none=None, + str_take_default_on_value="bar", + ) + assert schema.dump(data) == default_values + + def test_default_taken_missing(self, schema, default_values): + data = dict() + assert schema.dump(data) == default_values + + def test_default_not_taken(self, schema, default_values): + data = dict( + int_take_default_on_none=-1, + int_take_default_on_value=-1, + str_take_default_on_none="baz", + str_take_default_on_value="baz", + ) + assert schema.dump(data) == data + + class TestDefaults: class MySchema(Schema): int_no_default = fields.Int(allow_none=True) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 7824dc45d..f87cec35a 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -135,6 +135,17 @@ def test_integer_field_default_set_to_none(self, user): field = fields.Integer(default=None) assert field.serialize("age", user) is None + def test_integer_field_take_default_on(self, user): + field = fields.Integer(default=0, take_default_on=[None, 1]) + user.age = 42 + assert field.serialize("age", user) == 42 + del user.age + assert field.serialize("age", user) == 0 + user.age = None + assert field.serialize("age", user) == 0 + user.age = 1 + assert field.serialize("age", user) == 0 + def test_uuid_field(self, user): user.uuid1 = uuid.UUID("12345678123456781234567812345678") user.uuid2 = None