Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
967dadc
feat(python): add flatten_union_request_bodies config flag (off by de…
patrickthornton Apr 28, 2026
cccb50a
test(python): cover flatten_union_request_bodies config parsing
patrickthornton Apr 28, 2026
5b37cc3
feat(python): accept flatten_union_request_bodies in python-v2 config…
patrickthornton Apr 28, 2026
1f43750
feat(python): add union-flattening helper for request bodies
patrickthornton Apr 28, 2026
981c72a
feat(python): route discriminated-union request bodies through flatte…
patrickthornton Apr 28, 2026
e5a5b3d
feat(python): unwrap optional + dedup hints in union flattening helper
patrickthornton Apr 28, 2026
a514a89
test(python): add seed fixture for flatten_union_request_bodies
patrickthornton Apr 28, 2026
a9a0f30
fix(python): include union-level extends in flattened parameters
patrickthornton Apr 28, 2026
fe28692
chore(python): add changelog entry for flatten_union_request_bodies
patrickthornton Apr 28, 2026
8755dfc
fix(python): annotate Set[str] in union flattening helper for mypy
patrickthornton Apr 28, 2026
3ae71be
chore(python): apply ruff-format to union flattening helper
patrickthornton Apr 28, 2026
a074eb0
Merge branch 'main' into patrick/python/flatten-union-request-bodies
patrickthornton Apr 28, 2026
328bf44
fix(python): emit flat kwargs in examples for flatten_union_request_b…
patrickthornton Apr 29, 2026
9be77a4
refactor(python): simplify flatten_union_request_bodies snippet helper
patrickthornton Apr 29, 2026
cd9f1f3
refactor(python): tighten union flattening type resolution
patrickthornton May 5, 2026
494be3a
chore(python): satisfy pre-commit hooks
patrickthornton May 5, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ export const BasePythonCustomConfigSchema = z.object({
flat_layout: z.boolean().optional(),
include_legacy_wire_tests: z.boolean().optional(),
inline_request_params: z.boolean().optional(),
flatten_union_request_bodies: z.boolean().optional(),
use_api_name_in_package: z.boolean().optional()
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -633,8 +633,23 @@ export class EndpointSnippetGenerator {
switch (named.type) {
case "alias":
return this.getBodyRequestArgsForTypeReference({ typeReference: named.typeReference, value });
case "enum":
case "discriminatedUnion":
if (this.context.customConfig.flatten_union_request_bodies === true) {
const flatKwargs = this.context.dynamicTypeLiteralMapper.convertDiscriminatedUnionToFlatKwargs({
discriminatedUnion: named,
value
});
if (flatKwargs != null) {
return flatKwargs;
}
}
return [
{
name: REQUEST_BODY_ARG_NAME,
value: this.context.dynamicTypeLiteralMapper.convert({ typeReference, value })
}
];
case "enum":
case "undiscriminatedUnion":
return [
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,35 @@ export class DynamicTypeLiteralMapper {
);
}

public convertDiscriminatedUnionToFlatKwargs({
discriminatedUnion,
value
}: {
discriminatedUnion: FernIr.dynamic.DiscriminatedUnionType;
value: unknown;
}): python.NamedValue[] | undefined {
const discriminatedUnionTypeInstance = this.context.resolveDiscriminatedUnionTypeInstance({
discriminatedUnion,
value
});
if (discriminatedUnionTypeInstance == null) {
return undefined;
}
const unionVariant = discriminatedUnionTypeInstance.singleDiscriminatedUnionType;
const unionProperties = this.convertDiscriminatedUnionProperties({
discriminatedUnionTypeInstance,
unionVariant
});
if (unionProperties == null) {
return undefined;
}
const discriminantEntry: python.NamedValue = {
name: this.context.getPropertyName(discriminatedUnion.discriminant.name),
value: python.TypeInstantiation.str(discriminatedUnionTypeInstance.discriminantValue.wireValue)
};
return [...unionProperties, discriminantEntry];
}

private convertDiscriminatedUnionProperties({
discriminatedUnionTypeInstance,
unionVariant
Expand Down
1 change: 1 addition & 0 deletions generators/python-v2/sdk/src/SdkCustomConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ export const SdkCustomConfigSchema = z.object({
client: ClientConfigSchema.optional(),
client_class_name: z.string().optional(),
inline_request_params: z.boolean().optional(),
flatten_union_request_bodies: z.boolean().optional(),
wire_tests: WireTestsConfigSchema.optional(),
custom_readme_sections: z.array(CustomReadmeSectionSchema).optional(),
/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# yaml-language-server: $schema=../../../../../fern-changes-yml.schema.json

- summary: |
New `flatten_union_request_bodies` config flag (default `false`). When
enabled, endpoints whose referenced request body is a discriminated
`oneOf`/union expose all variants' fields as flat kwargs (Stripe-style),
rather than a single `request: Union[...]` parameter. The discriminator
becomes a `Union[Literal[...], ...]` of every variant value; per-variant
fields with the same wire name are merged (with conflicting types
unioned). Existing customers see no change unless they opt in.
type: feat
Original file line number Diff line number Diff line change
Expand Up @@ -2186,17 +2186,52 @@ def _get_snippet_for_request_reference(
is_optional: bool,
request_parameter_names: Dict[Union[str, ir_types.Name], str],
) -> List[AST.Expression]:
if self.context.custom_config.inline_request_params and not is_optional:
flatten_union = self.context.custom_config.flatten_union_request_bodies
flatten_object = self.context.custom_config.inline_request_params and not is_optional
if flatten_union or flatten_object:
union = example_type_reference.shape.get_as_union()
if union.type == "named":
shape = union.shape.get_as_union()
if shape.type == "alias":
return self._get_snippet_for_request_reference(shape.value, is_optional, request_parameter_names)
if shape.type == "object":
if flatten_union and shape.type == "union":
return self._get_snippet_for_request_reference_flattened_union(shape, request_parameter_names)
if flatten_object and shape.type == "object":
return self._get_snippet_for_request_reference_flattened(shape, request_parameter_names)
return self._get_snippet_for_request_reference_default(example_type_reference)
else:
return self._get_snippet_for_request_reference_default(example_type_reference)
return self._get_snippet_for_request_reference_default(example_type_reference)

def _get_snippet_for_request_reference_flattened_union(
self,
example_union: ir_types.ExampleUnionType,
request_parameter_names: Dict[Union[str, ir_types.Name], str],
) -> List[AST.Expression]:
sut = example_union.single_union_type
discriminator_param = (
request_parameter_names.get(get_name_from_wire_value(example_union.discriminant))
or resolve_name(get_name_from_wire_value(example_union.discriminant)).snake_case.safe_name
)
discriminator_arg = self.snippet_writer.get_snippet_for_named_parameter(
parameter_name=discriminator_param,
value=AST.Expression(f'"{get_wire_value(sut.wire_discriminant_value)}"'),
)

# single_property variants need the field name from the union type declaration,
# which isn't available on the example alone — emit only the discriminator until
# we plumb the declaration in. Customers using single_property unions still get a
# callable example, just without the variant payload.
variant_args: List[AST.Expression] = sut.shape.visit(
same_properties_as_object=lambda example_object: self.snippet_writer.get_snippet_for_object_properties(
example_object.object,
request_parameter_names,
as_request=True,
use_typeddict_request=self.context.custom_config.pydantic_config.use_typeddict_requests,
in_typeddict=False,
),
single_property=lambda _example_type_reference: [],
no_properties=lambda: [],
)

return [*variant_args, discriminator_arg]

def _get_request_parameter_name(self) -> str:
if self.endpoint.sdk_request is None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union

from ...context.sdk_generator_context import SdkGeneratorContext
from ..constants import DEFAULT_BODY_PARAMETER_VALUE
from .abstract_request_body_parameters import AbstractRequestBodyParameters
from .flattened_request_body_parameter_utils import get_json_body_for_inlined_request
from .union_flattening_utils import build_flattened_union_parameters
from fern_python.codegen import AST
from fern_python.codegen.ast.nodes.declarations.function.named_function_parameter import (
NamedFunctionParameter,
Expand All @@ -24,38 +25,68 @@ def __init__(
self._endpoint = endpoint
self._request_body = request_body
self._context = context
self._type_id = self._get_type_id_from_type_reference(self._request_body.request_body_type)
resolved = self._resolve_type_reference(self._request_body.request_body_type)
self._type_id: Optional[ir_types.TypeId] = resolved[0] if resolved else None
self._is_union_request_body: bool = resolved[1] if resolved else False

self.should_inline_request_parameters = (
context.custom_config.inline_request_params and self._type_id is not None
)
self._are_any_properties_optional = self.should_inline_request_parameters
self.parameter_name_rewrites: Dict[Union[str, ir_types.Name], str] = {}

def _get_type_id_from_type_reference(self, type_reference: ir_types.TypeReference) -> Optional[ir_types.TypeId]:
def _resolve_type_reference(self, type_reference: ir_types.TypeReference) -> Optional[Tuple[ir_types.TypeId, bool]]:
"""Resolve to ``(type_id, is_union)`` if the body inlines into kwargs, else ``None``.

Walks aliases and only reports a type_id for shapes we know how to inline:
``object`` always, ``union`` only when ``flatten_union_request_bodies`` is on.
"""
return type_reference.visit(
container=lambda _: None,
named=lambda t: self._get_type_id_from_type(t.type_id),
named=lambda t: self._resolve_type_id(t.type_id),
primitive=lambda _: None,
unknown=lambda: None,
)

def _get_type_id_from_type(self, type_id: ir_types.TypeId) -> Optional[ir_types.TypeId]:
def _resolve_type_id(self, type_id: ir_types.TypeId) -> Optional[Tuple[ir_types.TypeId, bool]]:
declaration = self._context.pydantic_generator_context.get_declaration_for_type_id(type_id)
return declaration.shape.visit(
alias=lambda atd: self._get_type_id_from_type_reference(atd.alias_of),
alias=lambda atd: self._resolve_type_reference(atd.alias_of),
enum=lambda _: None,
object=lambda o: type_id,
object=lambda _: (type_id, False),
undiscriminated_union=lambda _: None,
union=lambda _: None,
union=lambda _: (type_id, True) if self._context.custom_config.flatten_union_request_bodies else None,
)

def get_parameters(self, names_to_deconflict: Optional[List[str]] = None) -> List[AST.NamedFunctionParameter]:
if self.should_inline_request_parameters:
if self._is_union_request_body:
return self._get_flattened_union_parameters(names_to_deconflict)
return self._get_inlined_request_parameters(names_to_deconflict)
else:
return self._get_default_referenced_parameters()

def _get_flattened_union_parameters(
self, names_to_deconflict: Optional[List[str]]
) -> List[AST.NamedFunctionParameter]:
if self._type_id is None:
raise RuntimeError("Request body type is not defined, this should never happen.")
declaration = self._context.pydantic_generator_context.get_declaration_for_type_id(self._type_id)
union_decl = declaration.shape.visit(
alias=lambda _: None,
enum=lambda _: None,
object=lambda _: None,
undiscriminated_union=lambda _: None,
union=lambda u: u,
)
if union_decl is None:
raise RuntimeError("Expected union shape for flattened union request body.")
return build_flattened_union_parameters(
union_decl=union_decl,
context=self._context,
names_to_deconflict=names_to_deconflict,
)

def _is_type_literal(self, type_reference: ir_types.TypeReference) -> bool:
return self._context.get_literal_value(reference=type_reference) is not None

Expand Down Expand Up @@ -97,6 +128,11 @@ def _get_inlined_request_parameters(
return parameters

def _get_non_parameter_properties(self) -> List[AST.NamedFunctionParameter]:
# Union request bodies emit every wire field via the flattening helper —
# nothing extra goes in the JSON body beyond what get_parameters returned.
if self._is_union_request_body:
return []

non_param_properties = []

parameters: List[AST.NamedFunctionParameter] = self.get_parameters()
Expand Down
Loading