Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/firebase_functions/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _quote_if_string(literal: _T) -> _T:
return _obj_cel_name(literal) if not isinstance(literal, str) else f'"{literal}"'


_params: dict[str, Expression] = {}
_params: dict[str, "Param[_typing.Any] | SecretParam"] = {}


@_dataclasses.dataclass(frozen=True)
Expand Down
70 changes: 45 additions & 25 deletions src/firebase_functions/private/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,26 @@

import dataclasses as _dataclasses
import typing as _typing
from collections.abc import Mapping as _Mapping
from collections.abc import Sequence as _Sequence
from enum import Enum as _Enum
from zoneinfo import ZoneInfo as _ZoneInfo

import typing_extensions as _typing_extensions

import firebase_functions.params as _params
import firebase_functions.private.util as _util

ManifestParamBase = _params.Param | _params.SecretParam

SpecValue: _typing.TypeAlias = (
str | int | float | bool | _util.Sentinel | list["SpecValue"] | dict[str, "SpecValue"] | None
)


class _DataclassInstance(_typing.Protocol):
__dataclass_fields__: _typing.ClassVar[dict[str, _dataclasses.Field[object]]]


class SecretEnvironmentVariable(_typing.TypedDict):
key: _typing_extensions.Required[str]
Expand Down Expand Up @@ -148,7 +161,7 @@ class ManifestEndpoint:
"""A definition of a function as appears in the Manifest."""

entryPoint: str | None = None
region: list[str] | None = _dataclasses.field(default_factory=list[str])
region: list[str] | None = _dataclasses.field(default_factory=list)
platform: str | None = "gcfv2"
availableMemoryMb: int | _params.Expression[int] | _util.Sentinel | None = None
maxInstances: int | _params.Expression[int] | _util.Sentinel | None = None
Expand All @@ -161,7 +174,7 @@ class ManifestEndpoint:
labels: dict[str, str] | None = None
ingressSettings: str | None | _util.Sentinel = None
secretEnvironmentVariables: list[SecretEnvironmentVariable] | _util.Sentinel | None = (
_dataclasses.field(default_factory=list[SecretEnvironmentVariable])
_dataclasses.field(default_factory=list)
)
httpsTrigger: HttpsTrigger | None = None
callableTrigger: CallableTrigger | None = None
Expand All @@ -180,18 +193,16 @@ class ManifestRequiredApi(_typing.TypedDict):
class ManifestStack:
endpoints: dict[str, ManifestEndpoint]
specVersion: str = "v1alpha1"
params: list[_typing.Any] | None = _dataclasses.field(default_factory=list[_typing.Any])
requiredAPIs: list[ManifestRequiredApi] = _dataclasses.field(
default_factory=list[ManifestRequiredApi]
)
params: _Sequence[ManifestParamBase] | None = _dataclasses.field(default_factory=list)
requiredAPIs: list[ManifestRequiredApi] = _dataclasses.field(default_factory=list)


def _param_input_to_spec(
param_input: _params.TextInput
| _params.ResourceInput
| _params.SelectInput
| _params.MultiSelectInput,
) -> dict[str, _typing.Any]:
) -> dict[str, SpecValue]:
if isinstance(param_input, _params.TextInput):
return {
"text": {
Expand Down Expand Up @@ -233,8 +244,8 @@ def _param_input_to_spec(
return {}


def _param_to_spec(param: _params.Param | _params.SecretParam) -> dict[str, _typing.Any]:
spec_dict: dict[str, _typing.Any] = {
def _param_to_spec(param: ManifestParamBase) -> dict[str, SpecValue]:
spec_dict: dict[str, SpecValue] = {
"name": param.name,
"label": param.label,
"description": param.description,
Expand Down Expand Up @@ -266,45 +277,54 @@ def _param_to_spec(param: _params.Param | _params.SecretParam) -> dict[str, _typ
return _dict_to_spec(spec_dict)


def _object_to_spec(data) -> object:
def _object_to_spec(data: object) -> SpecValue:
if isinstance(data, _Enum):
return data.value
result: SpecValue = data.value
elif isinstance(data, _params.Expression):
return f"{data}"
result = f"{data}"
elif isinstance(data, _ZoneInfo):
result = data.key
elif _dataclasses.is_dataclass(data):
return _dataclass_to_spec(data)
elif isinstance(data, list):
return list(map(_object_to_spec, data))
elif isinstance(data, dict):
return _dict_to_spec(data)
result = _dataclass_to_spec(_typing.cast(_DataclassInstance, data))
elif isinstance(data, _Mapping):
result = _dict_to_spec(data)
elif isinstance(data, _Sequence) and not isinstance(data, str | bytes | bytearray):
result = list(map(_object_to_spec, data))
elif data is None:
result = None
elif isinstance(data, _util.Sentinel):
result = data
elif isinstance(data, str | int | float | bool):
result = data
else:
return data
raise TypeError(f"Unsupported manifest spec value: {type(data)!r}")
Comment thread
IzaakGough marked this conversation as resolved.
return result


def _dict_factory(data: list[tuple[str, _typing.Any]]) -> dict:
out: dict = {}
def _dict_factory(data: list[tuple[str, object]]) -> dict[str, SpecValue]:
out: dict[str, SpecValue] = {}
for key, value in data:
if value is not None:
out[key] = _object_to_spec(value)
return out


def _dataclass_to_spec(data) -> dict:
out: dict = {}
def _dataclass_to_spec(data: _DataclassInstance) -> dict[str, SpecValue]:
out: dict[str, SpecValue] = {}
for field in _dataclasses.fields(data):
value = _object_to_spec(getattr(data, field.name))
if value is not None:
out[field.name] = value
return out


def _dict_to_spec(data: dict) -> dict:
def _dict_to_spec(data: _Mapping[str, object]) -> dict[str, SpecValue]:
return _dict_factory(list(data.items()))


def manifest_to_spec_dict(manifest: ManifestStack) -> dict:
def manifest_to_spec_dict(manifest: ManifestStack) -> dict[str, SpecValue]:
params = manifest.params
out: dict = _dataclass_to_spec(manifest)
out: dict[str, SpecValue] = _dataclass_to_spec(manifest)
if params is not None:
out["params"] = list(map(_param_to_spec, params))
return out
44 changes: 44 additions & 0 deletions tests/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,33 @@
# limitations under the License.
"""Manifest unit tests."""

from collections.abc import Mapping as _Mapping
from zoneinfo import ZoneInfo

from pytest import raises

import firebase_functions.params as _params
import firebase_functions.private.manifest as _manifest


class _CustomMapping(_Mapping):
def __init__(self, data):
self._data = data

def __getitem__(self, key):
return self._data[key]

def __iter__(self):
return iter(self._data)

def __len__(self):
return len(self._data)


class _UnsupportedManifestValue:
pass


full_endpoint = _manifest.ManifestEndpoint(
platform="gcfv2",
region=["us-west1"],
Expand Down Expand Up @@ -160,3 +184,23 @@ def test_endpoint_nones(self):
assert expressions_actual_dict == expressions_expected_dict, (
"Generated endpoint spec dict does not match expected dict."
)

def test_object_to_spec_converts_tuple_to_list(self):
"""Check tuple values are converted to manifest lists."""
actual = _manifest._object_to_spec(("hello", 1, True))
assert actual == ["hello", 1, True]

def test_object_to_spec_converts_custom_mapping_to_dict(self):
"""Check Mapping implementations are converted via dict serialization."""
actual = _manifest._object_to_spec(_CustomMapping({"hello": "world"}))
assert actual == {"hello": "world"}

def test_object_to_spec_converts_zoneinfo_to_key(self):
"""Check ZoneInfo values serialize to their key."""
actual = _manifest._object_to_spec(ZoneInfo("America/Los_Angeles"))
assert actual == "America/Los_Angeles"

def test_object_to_spec_raises_for_unsupported_value(self):
"""Check unsupported values fail fast."""
with raises(TypeError, match="Unsupported manifest spec value"):
_manifest._object_to_spec(_UnsupportedManifestValue())
146 changes: 145 additions & 1 deletion tests/test_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
Options unit tests.
"""

from pytest import raises
import typing as _typing

from pytest import mark, raises

import firebase_functions.private.manifest as _manifest
from firebase_functions import alerts_fn, https_fn, options, params
from firebase_functions.alerts import (
app_distribution_fn,
Expand All @@ -31,6 +34,30 @@
ALERT_SECRET = params.SecretParam("GITLAB_PERSONAL_ACCESS_TOKEN")


class _UnsupportedManifestValue:
pass


class _FakePattern:
def __init__(self, value, has_wildcards=False):
self.value = value
self.has_wildcards = has_wildcards


def _assert_endpoint_manifest_type_error(endpoint):
with raises(TypeError, match="Unsupported manifest spec value"):
_manifest.manifest_to_spec_dict(_manifest.ManifestStack(endpoints={"test": endpoint}))


def _assert_option_builder_type_error(builder):
try:
endpoint = builder()
except TypeError as exc:
assert "Unsupported manifest spec value" in str(exc)
else:
_assert_endpoint_manifest_type_error(endpoint)


@https_fn.on_call()
def asamplefunction(_):
return "hello world"
Expand Down Expand Up @@ -305,3 +332,120 @@ def sample(_event):
"crashlytics.newFatalIssue",
expect_app_id="app-123",
)


@mark.parametrize(
("builder"),
[
lambda: options.RuntimeOptions(
region=_typing.cast(str, _UnsupportedManifestValue())
)._endpoint(func_name="test"),
lambda: options.EventHandlerOptions(
retry=_typing.cast(bool, _UnsupportedManifestValue())
)._endpoint(
func_name="test",
event_filters={},
event_type="google.cloud.pubsub.topic.v1.messagePublished",
),
lambda: options.TaskQueueOptions(
retry_config=options.RetryConfig(
max_attempts=_typing.cast(int, _UnsupportedManifestValue())
)
)._endpoint(func_name="test"),
lambda: options.PubSubOptions(
topic=_typing.cast(str, _UnsupportedManifestValue())
)._endpoint(func_name="test"),
lambda: options.FirebaseAlertOptions(
alert_type=_typing.cast(str, _UnsupportedManifestValue())
)._endpoint(func_name="test"),
lambda: options.AppDistributionOptions(
app_id=_typing.cast(str, _UnsupportedManifestValue())
)._endpoint(
func_name="test",
alert_type=options.AlertType.APP_DISTRIBUTION_NEW_TESTER_IOS_DEVICE,
),
lambda: options.PerformanceOptions(
app_id=_typing.cast(str, _UnsupportedManifestValue())
)._endpoint(
func_name="test",
alert_type=options.AlertType.PERFORMANCE_THRESHOLD,
),
lambda: options.CrashlyticsOptions(
app_id=_typing.cast(str, _UnsupportedManifestValue())
)._endpoint(
func_name="test",
alert_type=options.AlertType.CRASHLYTICS_NEW_FATAL_ISSUE,
),
lambda: options.BillingOptions()._endpoint(
func_name="test",
alert_type=_typing.cast(str, _UnsupportedManifestValue()),
),
lambda: options.EventarcTriggerOptions(
event_type="firebase.extensions.storage-resize-images.v1.complete",
filters={"subject": _typing.cast(str, _UnsupportedManifestValue())},
)._endpoint(func_name="test"),
lambda: options.ScheduleOptions(
schedule=_typing.cast(str, _UnsupportedManifestValue())
)._endpoint(func_name="test"),
lambda: options.StorageOptions(
bucket=_typing.cast(str, _UnsupportedManifestValue())
)._endpoint(
func_name="test",
event_type="google.cloud.storage.object.v1.finalized",
),
lambda: options.DatabaseOptions(reference="/foo/{bar}")._endpoint(
func_name="test",
event_type="google.firebase.database.ref.v1.written",
instance_pattern=_FakePattern(_UnsupportedManifestValue()),
),
lambda: options.BlockingOptions(
id_token=_typing.cast(bool, _UnsupportedManifestValue())
)._endpoint(
func_name="test",
event_type="providers/cloud.auth/eventTypes/user.beforeSignIn",
),
lambda: options.FirestoreOptions(document="foo/{bar}")._endpoint(
func_name="test",
event_type="google.cloud.firestore.document.v1.written",
document_pattern=_FakePattern(_UnsupportedManifestValue(), has_wildcards=True),
),
lambda: options.HttpsOptions(
invoker=[_typing.cast(str, _UnsupportedManifestValue())]
)._endpoint(func_name="test"),
lambda: options.HttpsOptions(
labels={"broken": _typing.cast(str, _UnsupportedManifestValue())}
)._endpoint(
func_name="test",
callable=True,
),
lambda: options.DataConnectOptions(service="service")._endpoint(
func_name="test",
event_type="google.firebase.dataconnect.connector.v1.mutationExecuted",
service_pattern=_FakePattern(_UnsupportedManifestValue()),
connector_pattern=_FakePattern("connector"),
operation_pattern=_FakePattern("operation"),
),
],
ids=[
"runtime",
"event_handler",
"task_queue",
"pubsub",
"firebase_alert",
"app_distribution",
"performance",
"crashlytics",
"billing",
"eventarc",
"schedule",
"storage",
"database",
"blocking",
"firestore",
"https",
"callable_https",
"dataconnect",
],
)
def test_manifest_to_spec_rejects_unsupported_values_across_option_types(builder):
_assert_option_builder_type_error(builder)
Loading