diff --git a/src/firebase_functions/params.py b/src/firebase_functions/params.py index 32853f0..dcf0e70 100644 --- a/src/firebase_functions/params.py +++ b/src/firebase_functions/params.py @@ -54,7 +54,7 @@ def _quote_if_string(literal: _T) -> _T: return _obj_cel_name(literal) if not isinstance(literal, str) else f'"{literal}"' -_params: dict[str, Expression] = {} +_params: dict[str, "Param[_typing.Any] | SecretParam"] = {} @_dataclasses.dataclass(frozen=True) diff --git a/src/firebase_functions/private/manifest.py b/src/firebase_functions/private/manifest.py index 7672a9f..5ab1543 100644 --- a/src/firebase_functions/private/manifest.py +++ b/src/firebase_functions/private/manifest.py @@ -19,13 +19,26 @@ import dataclasses as _dataclasses import typing as _typing +from collections.abc import Mapping as _Mapping +from collections.abc import Sequence as _Sequence from enum import Enum as _Enum +from zoneinfo import ZoneInfo as _ZoneInfo import typing_extensions as _typing_extensions import firebase_functions.params as _params import firebase_functions.private.util as _util +ManifestParamBase = _params.Param | _params.SecretParam + +SpecValue: _typing.TypeAlias = ( + str | int | float | bool | _util.Sentinel | list["SpecValue"] | dict[str, "SpecValue"] | None +) + + +class _DataclassInstance(_typing.Protocol): + __dataclass_fields__: _typing.ClassVar[dict[str, _dataclasses.Field[object]]] + class SecretEnvironmentVariable(_typing.TypedDict): key: _typing_extensions.Required[str] @@ -148,7 +161,7 @@ class ManifestEndpoint: """A definition of a function as appears in the Manifest.""" entryPoint: str | None = None - region: list[str] | None = _dataclasses.field(default_factory=list[str]) + region: list[str] | None = _dataclasses.field(default_factory=list) platform: str | None = "gcfv2" availableMemoryMb: int | _params.Expression[int] | _util.Sentinel | None = None maxInstances: int | _params.Expression[int] | _util.Sentinel | None = None @@ -161,7 +174,7 @@ class ManifestEndpoint: labels: dict[str, str] | None = None ingressSettings: str | None | _util.Sentinel = None secretEnvironmentVariables: list[SecretEnvironmentVariable] | _util.Sentinel | None = ( - _dataclasses.field(default_factory=list[SecretEnvironmentVariable]) + _dataclasses.field(default_factory=list) ) httpsTrigger: HttpsTrigger | None = None callableTrigger: CallableTrigger | None = None @@ -180,10 +193,8 @@ class ManifestRequiredApi(_typing.TypedDict): class ManifestStack: endpoints: dict[str, ManifestEndpoint] specVersion: str = "v1alpha1" - params: list[_typing.Any] | None = _dataclasses.field(default_factory=list[_typing.Any]) - requiredAPIs: list[ManifestRequiredApi] = _dataclasses.field( - default_factory=list[ManifestRequiredApi] - ) + params: _Sequence[ManifestParamBase] | None = _dataclasses.field(default_factory=list) + requiredAPIs: list[ManifestRequiredApi] = _dataclasses.field(default_factory=list) def _param_input_to_spec( @@ -191,7 +202,7 @@ def _param_input_to_spec( | _params.ResourceInput | _params.SelectInput | _params.MultiSelectInput, -) -> dict[str, _typing.Any]: +) -> dict[str, SpecValue]: if isinstance(param_input, _params.TextInput): return { "text": { @@ -233,8 +244,8 @@ def _param_input_to_spec( return {} -def _param_to_spec(param: _params.Param | _params.SecretParam) -> dict[str, _typing.Any]: - spec_dict: dict[str, _typing.Any] = { +def _param_to_spec(param: ManifestParamBase) -> dict[str, SpecValue]: + spec_dict: dict[str, SpecValue] = { "name": param.name, "label": param.label, "description": param.description, @@ -266,31 +277,40 @@ def _param_to_spec(param: _params.Param | _params.SecretParam) -> dict[str, _typ return _dict_to_spec(spec_dict) -def _object_to_spec(data) -> object: +def _object_to_spec(data: object) -> SpecValue: if isinstance(data, _Enum): - return data.value + result: SpecValue = data.value elif isinstance(data, _params.Expression): - return f"{data}" + result = f"{data}" + elif isinstance(data, _ZoneInfo): + result = data.key elif _dataclasses.is_dataclass(data): - return _dataclass_to_spec(data) - elif isinstance(data, list): - return list(map(_object_to_spec, data)) - elif isinstance(data, dict): - return _dict_to_spec(data) + result = _dataclass_to_spec(_typing.cast(_DataclassInstance, data)) + elif isinstance(data, _Mapping): + result = _dict_to_spec(data) + elif isinstance(data, _Sequence) and not isinstance(data, str | bytes | bytearray): + result = list(map(_object_to_spec, data)) + elif data is None: + result = None + elif isinstance(data, _util.Sentinel): + result = data + elif isinstance(data, str | int | float | bool): + result = data else: - return data + raise TypeError(f"Unsupported manifest spec value: {type(data)!r}") + return result -def _dict_factory(data: list[tuple[str, _typing.Any]]) -> dict: - out: dict = {} +def _dict_factory(data: list[tuple[str, object]]) -> dict[str, SpecValue]: + out: dict[str, SpecValue] = {} for key, value in data: if value is not None: out[key] = _object_to_spec(value) return out -def _dataclass_to_spec(data) -> dict: - out: dict = {} +def _dataclass_to_spec(data: _DataclassInstance) -> dict[str, SpecValue]: + out: dict[str, SpecValue] = {} for field in _dataclasses.fields(data): value = _object_to_spec(getattr(data, field.name)) if value is not None: @@ -298,13 +318,13 @@ def _dataclass_to_spec(data) -> dict: return out -def _dict_to_spec(data: dict) -> dict: +def _dict_to_spec(data: _Mapping[str, object]) -> dict[str, SpecValue]: return _dict_factory(list(data.items())) -def manifest_to_spec_dict(manifest: ManifestStack) -> dict: +def manifest_to_spec_dict(manifest: ManifestStack) -> dict[str, SpecValue]: params = manifest.params - out: dict = _dataclass_to_spec(manifest) + out: dict[str, SpecValue] = _dataclass_to_spec(manifest) if params is not None: out["params"] = list(map(_param_to_spec, params)) return out diff --git a/tests/test_manifest.py b/tests/test_manifest.py index 681d90f..2dfd385 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -13,9 +13,33 @@ # limitations under the License. """Manifest unit tests.""" +from collections.abc import Mapping as _Mapping +from zoneinfo import ZoneInfo + +from pytest import raises + import firebase_functions.params as _params import firebase_functions.private.manifest as _manifest + +class _CustomMapping(_Mapping): + def __init__(self, data): + self._data = data + + def __getitem__(self, key): + return self._data[key] + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) + + +class _UnsupportedManifestValue: + pass + + full_endpoint = _manifest.ManifestEndpoint( platform="gcfv2", region=["us-west1"], @@ -160,3 +184,23 @@ def test_endpoint_nones(self): assert expressions_actual_dict == expressions_expected_dict, ( "Generated endpoint spec dict does not match expected dict." ) + + def test_object_to_spec_converts_tuple_to_list(self): + """Check tuple values are converted to manifest lists.""" + actual = _manifest._object_to_spec(("hello", 1, True)) + assert actual == ["hello", 1, True] + + def test_object_to_spec_converts_custom_mapping_to_dict(self): + """Check Mapping implementations are converted via dict serialization.""" + actual = _manifest._object_to_spec(_CustomMapping({"hello": "world"})) + assert actual == {"hello": "world"} + + def test_object_to_spec_converts_zoneinfo_to_key(self): + """Check ZoneInfo values serialize to their key.""" + actual = _manifest._object_to_spec(ZoneInfo("America/Los_Angeles")) + assert actual == "America/Los_Angeles" + + def test_object_to_spec_raises_for_unsupported_value(self): + """Check unsupported values fail fast.""" + with raises(TypeError, match="Unsupported manifest spec value"): + _manifest._object_to_spec(_UnsupportedManifestValue()) diff --git a/tests/test_options.py b/tests/test_options.py index 3e9cb52..51fed90 100644 --- a/tests/test_options.py +++ b/tests/test_options.py @@ -15,8 +15,11 @@ Options unit tests. """ -from pytest import raises +import typing as _typing +from pytest import mark, raises + +import firebase_functions.private.manifest as _manifest from firebase_functions import alerts_fn, https_fn, options, params from firebase_functions.alerts import ( app_distribution_fn, @@ -31,6 +34,30 @@ ALERT_SECRET = params.SecretParam("GITLAB_PERSONAL_ACCESS_TOKEN") +class _UnsupportedManifestValue: + pass + + +class _FakePattern: + def __init__(self, value, has_wildcards=False): + self.value = value + self.has_wildcards = has_wildcards + + +def _assert_endpoint_manifest_type_error(endpoint): + with raises(TypeError, match="Unsupported manifest spec value"): + _manifest.manifest_to_spec_dict(_manifest.ManifestStack(endpoints={"test": endpoint})) + + +def _assert_option_builder_type_error(builder): + try: + endpoint = builder() + except TypeError as exc: + assert "Unsupported manifest spec value" in str(exc) + else: + _assert_endpoint_manifest_type_error(endpoint) + + @https_fn.on_call() def asamplefunction(_): return "hello world" @@ -305,3 +332,120 @@ def sample(_event): "crashlytics.newFatalIssue", expect_app_id="app-123", ) + + +@mark.parametrize( + ("builder"), + [ + lambda: options.RuntimeOptions( + region=_typing.cast(str, _UnsupportedManifestValue()) + )._endpoint(func_name="test"), + lambda: options.EventHandlerOptions( + retry=_typing.cast(bool, _UnsupportedManifestValue()) + )._endpoint( + func_name="test", + event_filters={}, + event_type="google.cloud.pubsub.topic.v1.messagePublished", + ), + lambda: options.TaskQueueOptions( + retry_config=options.RetryConfig( + max_attempts=_typing.cast(int, _UnsupportedManifestValue()) + ) + )._endpoint(func_name="test"), + lambda: options.PubSubOptions( + topic=_typing.cast(str, _UnsupportedManifestValue()) + )._endpoint(func_name="test"), + lambda: options.FirebaseAlertOptions( + alert_type=_typing.cast(str, _UnsupportedManifestValue()) + )._endpoint(func_name="test"), + lambda: options.AppDistributionOptions( + app_id=_typing.cast(str, _UnsupportedManifestValue()) + )._endpoint( + func_name="test", + alert_type=options.AlertType.APP_DISTRIBUTION_NEW_TESTER_IOS_DEVICE, + ), + lambda: options.PerformanceOptions( + app_id=_typing.cast(str, _UnsupportedManifestValue()) + )._endpoint( + func_name="test", + alert_type=options.AlertType.PERFORMANCE_THRESHOLD, + ), + lambda: options.CrashlyticsOptions( + app_id=_typing.cast(str, _UnsupportedManifestValue()) + )._endpoint( + func_name="test", + alert_type=options.AlertType.CRASHLYTICS_NEW_FATAL_ISSUE, + ), + lambda: options.BillingOptions()._endpoint( + func_name="test", + alert_type=_typing.cast(str, _UnsupportedManifestValue()), + ), + lambda: options.EventarcTriggerOptions( + event_type="firebase.extensions.storage-resize-images.v1.complete", + filters={"subject": _typing.cast(str, _UnsupportedManifestValue())}, + )._endpoint(func_name="test"), + lambda: options.ScheduleOptions( + schedule=_typing.cast(str, _UnsupportedManifestValue()) + )._endpoint(func_name="test"), + lambda: options.StorageOptions( + bucket=_typing.cast(str, _UnsupportedManifestValue()) + )._endpoint( + func_name="test", + event_type="google.cloud.storage.object.v1.finalized", + ), + lambda: options.DatabaseOptions(reference="/foo/{bar}")._endpoint( + func_name="test", + event_type="google.firebase.database.ref.v1.written", + instance_pattern=_FakePattern(_UnsupportedManifestValue()), + ), + lambda: options.BlockingOptions( + id_token=_typing.cast(bool, _UnsupportedManifestValue()) + )._endpoint( + func_name="test", + event_type="providers/cloud.auth/eventTypes/user.beforeSignIn", + ), + lambda: options.FirestoreOptions(document="foo/{bar}")._endpoint( + func_name="test", + event_type="google.cloud.firestore.document.v1.written", + document_pattern=_FakePattern(_UnsupportedManifestValue(), has_wildcards=True), + ), + lambda: options.HttpsOptions( + invoker=[_typing.cast(str, _UnsupportedManifestValue())] + )._endpoint(func_name="test"), + lambda: options.HttpsOptions( + labels={"broken": _typing.cast(str, _UnsupportedManifestValue())} + )._endpoint( + func_name="test", + callable=True, + ), + lambda: options.DataConnectOptions(service="service")._endpoint( + func_name="test", + event_type="google.firebase.dataconnect.connector.v1.mutationExecuted", + service_pattern=_FakePattern(_UnsupportedManifestValue()), + connector_pattern=_FakePattern("connector"), + operation_pattern=_FakePattern("operation"), + ), + ], + ids=[ + "runtime", + "event_handler", + "task_queue", + "pubsub", + "firebase_alert", + "app_distribution", + "performance", + "crashlytics", + "billing", + "eventarc", + "schedule", + "storage", + "database", + "blocking", + "firestore", + "https", + "callable_https", + "dataconnect", + ], +) +def test_manifest_to_spec_rejects_unsupported_values_across_option_types(builder): + _assert_option_builder_type_error(builder)