From 1df8481873a195678cd6e13db27f8250e62a1725 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 21 Apr 2026 11:44:41 +0100 Subject: [PATCH 1/8] make forms work more nicely with MeshSequence --- ufl/algorithms/__init__.py | 2 + ufl/algorithms/replace_function_spaces.py | 61 +++++++++++++++++++++ ufl/argument.py | 18 ++++++- ufl/geometry.py | 7 +-- ufl/measure.py | 64 +++++++++++++++++------ 5 files changed, 133 insertions(+), 19 deletions(-) create mode 100644 ufl/algorithms/replace_function_spaces.py diff --git a/ufl/algorithms/__init__.py b/ufl/algorithms/__init__.py index 541abbfb5..5a3fb22e1 100644 --- a/ufl/algorithms/__init__.py +++ b/ufl/algorithms/__init__.py @@ -46,6 +46,7 @@ "preprocess_form", "read_ufl_file", "replace", + "replace_function_spaces", "replace_terminal_data", "sort_elements", "strip_terminal_data", @@ -82,6 +83,7 @@ compute_form_rhs, ) from ufl.algorithms.replace import replace +from ufl.algorithms.replace_function_spaces import replace_function_spaces from ufl.algorithms.signature import compute_form_signature from ufl.algorithms.strip_terminal_data import replace_terminal_data, strip_terminal_data from ufl.algorithms.transformer import ( diff --git a/ufl/algorithms/replace_function_spaces.py b/ufl/algorithms/replace_function_spaces.py new file mode 100644 index 000000000..2e9896685 --- /dev/null +++ b/ufl/algorithms/replace_function_spaces.py @@ -0,0 +1,61 @@ +"""Replace function spaces in all arguments.""" + +from functools import singledispatchmethod + +from ufl import Argument +from ufl.algorithms.map_integrands import map_integrands +from ufl.classes import Expr +from ufl.corealg.dag_traverser import DAGTraverser + + +class FunctionSpaceReplacer(DAGTraverser): + """Dispatcher.""" + + def __init__( + self, + replacements: dict, + part: int = 0, + compress: bool | None = True, + visited_cache: dict[tuple, Expr] | None = None, + result_cache: dict[Expr, Expr] | None = None, + ) -> None: + """Initialise.""" + super().__init__(compress=compress, visited_cache=visited_cache, result_cache=result_cache) + self._dag_traverser_cache: dict[tuple[type, Expr, Expr, Expr], DAGTraverser] = {} + self.replacements = replacements + self.part = part + + @singledispatchmethod + def process(self, o: Expr) -> Expr: + """Process ``o``. + + Args: + o: `Expr` to be processed. + + Returns: + Processed object. + + """ + return super().process(o) + + @process.register(Expr) + def _(self, o: Expr) -> Expr: + """Do nothing.""" + return self.reuse_if_untouched(o) + + @process.register(Argument) + def _(self, o: Argument) -> Argument: + """Apply to argument.""" + if o.ufl_function_space() in self.replacements: + return Argument(self.replacements[o.ufl_function_space()], o._number, self.part) + return self.reuse_if_untouched(o) + + +def replace_function_spaces(integrand, replacements: dict, offset): + """Replace all instances of function spaces in an integrand. + + replacements should be a dictionary mapping from function spaces to + what the spaces should be replaced with. + """ + dag_traverser = FunctionSpaceReplacer(replacements, offset) + return map_integrands(dag_traverser, integrand) diff --git a/ufl/argument.py b/ufl/argument.py index 043f23dc8..e744f6517 100644 --- a/ufl/argument.py +++ b/ufl/argument.py @@ -46,7 +46,23 @@ def __init__(self, function_space, number, part=None): raise ValueError("Expecting a FunctionSpace.") self._ufl_function_space = function_space - self._ufl_shape = function_space.value_shape + + if isinstance(function_space, MixedFunctionSpace): + subspaces = function_space.ufl_sub_spaces() + cells = [s.ufl_domain().ufl_cell() for s in subspaces] + if len(set(cells)) != len(cells): + raise ValueError( + "Can only use mixed function spaces where subspaces are all defined " + "on different cell types" + ) + self._ufl_shape = subspaces[0].value_shape + for s in subspaces[1:]: + if self._ufl_shape != s.value_shape: + raise ValueError( + "Cannot use mixed function space where subspaces have different shapes" + ) + else: + self._ufl_shape = function_space.value_shape if not isinstance(number, numbers.Integral): raise ValueError(f"Expecting an int for number, not {number}") diff --git a/ufl/geometry.py b/ufl/geometry.py index 120d6ac42..716579cb4 100644 --- a/ufl/geometry.py +++ b/ufl/geometry.py @@ -117,9 +117,10 @@ def __init__(self, domain): """Initialise.""" Terminal.__init__(self) if isinstance(domain, MeshSequence) and len(set(domain)) > 1: - # Can not make GeometricQuantity if multiple domains exist. - raise TypeError(f"Can not create a GeometricQuantity on {domain}") - self._domain = as_domain(domain) + self._domain = domain + # raise TypeError(f"Can not create a GeometricQuantity on {domain}") + else: + self._domain = as_domain(domain) def ufl_domains(self): """Get the UFL domains.""" diff --git a/ufl/measure.py b/ufl/measure.py index 6726be200..98501bb2f 100644 --- a/ufl/measure.py +++ b/ufl/measure.py @@ -483,23 +483,57 @@ def __rmul__(self, integrand): elif len(domains) == 0: raise ValueError("This integral is missing an integration domain.") else: - raise ValueError( - "Multiple domains found, making the choice of integration domain ambiguous." + assert len(set(d.ufl_cell() for d in domains)) == len(domains) + + if domain is None: + from ufl.algorithms import replace_function_spaces + from ufl.algorithms.traversal import iter_expressions + from ufl.corealg.traversal import unique_pre_traversal + + mixed_spaces = { + o.ufl_function_space() + for e in iter_expressions(integrand) + for o in unique_pre_traversal(e) + if hasattr(o, "ufl_function_space") + } + cells = {e.cell_type for space in mixed_spaces for e in space.ufl_elements()} + + integrals = [] + for i, cell in enumerate(cells): + cell_domain = next(d for d in domains if d.ufl_cell().cellname == cell.name) + replacements = { + m: next(s for s in m.ufl_sub_spaces() if s.ufl_element().cell_type == cell) + for m in mixed_spaces + } + integrals.append( + Integral( + integrand=replace_function_spaces(integrand, replacements, i), + integral_type=self.integral_type(), + domain=cell_domain, + subdomain_id=subdomain_id, + metadata=self.metadata(), + subdomain_data=self.subdomain_data(), + extra_domain_integral_type_map={ + m.ufl_domain(): m.integral_type() for m in self.intersect_measures() + }, + ) ) - # Otherwise create and return a one-integral form - integral = Integral( - integrand=integrand, - integral_type=self.integral_type(), - domain=domain, - subdomain_id=subdomain_id, - metadata=self.metadata(), - subdomain_data=self.subdomain_data(), - extra_domain_integral_type_map={ - m.ufl_domain(): m.integral_type() for m in self.intersect_measures() - }, - ) - return Form([integral]) + return sum(Form([integral]) for integral in integrals) + else: + # Create and return a one-integral form + integral = Integral( + integrand=integrand, + integral_type=self.integral_type(), + domain=domain, + subdomain_id=subdomain_id, + metadata=self.metadata(), + subdomain_data=self.subdomain_data(), + extra_domain_integral_type_map={ + m.ufl_domain(): m.integral_type() for m in self.intersect_measures() + }, + ) + return Form([integral]) class MeasureSum: From 523d53e90fbc481534d4b479ac4268971046cdf0 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 21 Apr 2026 12:03:48 +0100 Subject: [PATCH 2/8] typing --- ufl/algorithms/replace_function_spaces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ufl/algorithms/replace_function_spaces.py b/ufl/algorithms/replace_function_spaces.py index 2e9896685..a9330dc5b 100644 --- a/ufl/algorithms/replace_function_spaces.py +++ b/ufl/algorithms/replace_function_spaces.py @@ -44,7 +44,7 @@ def _(self, o: Expr) -> Expr: return self.reuse_if_untouched(o) @process.register(Argument) - def _(self, o: Argument) -> Argument: + def _(self, o: Argument) -> Expr: """Apply to argument.""" if o.ufl_function_space() in self.replacements: return Argument(self.replacements[o.ufl_function_space()], o._number, self.part) From 31fd0fc212de749b946b1f973856b588cc286d6c Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 21 Apr 2026 12:06:02 +0100 Subject: [PATCH 3/8] move imports up --- ufl/measure.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ufl/measure.py b/ufl/measure.py index 98501bb2f..4ab4b878a 100644 --- a/ufl/measure.py +++ b/ufl/measure.py @@ -12,9 +12,11 @@ import numbers from itertools import chain +from ufl.algorithms.traversal import iter_expressions from ufl.checks import is_true_ufl_scalar from ufl.constantvalue import as_ufl from ufl.core.expr import Expr +from ufl.corealg.traversal import unique_pre_traversal from ufl.domain import AbstractDomain, as_domain, extract_domains from ufl.protocols import id_or_none @@ -430,6 +432,7 @@ def __rmul__(self, integrand): Integration properties are taken from this Measure object. """ # Avoid circular imports + from ufl.algorithms import replace_function_spaces from ufl.form import Form from ufl.integral import Integral @@ -486,10 +489,6 @@ def __rmul__(self, integrand): assert len(set(d.ufl_cell() for d in domains)) == len(domains) if domain is None: - from ufl.algorithms import replace_function_spaces - from ufl.algorithms.traversal import iter_expressions - from ufl.corealg.traversal import unique_pre_traversal - mixed_spaces = { o.ufl_function_space() for e in iter_expressions(integrand) From 99f248a6da855e42939b55b947827430ec19de4b Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 21 Apr 2026 12:08:53 +0100 Subject: [PATCH 4/8] use filter --- ufl/measure.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ufl/measure.py b/ufl/measure.py index 4ab4b878a..4f45f59c1 100644 --- a/ufl/measure.py +++ b/ufl/measure.py @@ -499,9 +499,9 @@ def __rmul__(self, integrand): integrals = [] for i, cell in enumerate(cells): - cell_domain = next(d for d in domains if d.ufl_cell().cellname == cell.name) + cell_domain = next(filter(lambda d: d.ufl_cell().cellname == cell.name, domains)) replacements = { - m: next(s for s in m.ufl_sub_spaces() if s.ufl_element().cell_type == cell) + m: next(filter(lambda s: s.ufl_element().cell_type == cell, m.ufl_sub_spaces())) for m in mixed_spaces } integrals.append( From ce2d2d85f0d746db1c7796a60d0d6f0db30ae4fc Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 21 Apr 2026 13:19:34 +0100 Subject: [PATCH 5/8] skip baffling mypy error --- ufl/formoperators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ufl/formoperators.py b/ufl/formoperators.py index 59d6ded47..657afed3b 100644 --- a/ufl/formoperators.py +++ b/ufl/formoperators.py @@ -13,7 +13,7 @@ from ufl.action import Action from ufl.adjoint import Adjoint -from ufl.algorithms import ( +from ufl.algorithms import ( # type: ignore compute_energy_norm, compute_form_action, compute_form_adjoint, From 8b26f4c7534215ecf4d97647d78ffb3e6f59052b Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 21 Apr 2026 13:19:50 +0100 Subject: [PATCH 6/8] docs --- ufl/algorithms/replace_function_spaces.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/ufl/algorithms/replace_function_spaces.py b/ufl/algorithms/replace_function_spaces.py index a9330dc5b..881d5ab01 100644 --- a/ufl/algorithms/replace_function_spaces.py +++ b/ufl/algorithms/replace_function_spaces.py @@ -6,6 +6,7 @@ from ufl.algorithms.map_integrands import map_integrands from ufl.classes import Expr from ufl.corealg.dag_traverser import DAGTraverser +from ufl.functionspace import AbstractFunctionSpace class FunctionSpaceReplacer(DAGTraverser): @@ -13,8 +14,8 @@ class FunctionSpaceReplacer(DAGTraverser): def __init__( self, - replacements: dict, - part: int = 0, + replacements: dict[AbstractFunctionSpace, AbstractFunctionSpace], + part: int, compress: bool | None = True, visited_cache: dict[tuple, Expr] | None = None, result_cache: dict[Expr, Expr] | None = None, @@ -51,11 +52,18 @@ def _(self, o: Argument) -> Expr: return self.reuse_if_untouched(o) -def replace_function_spaces(integrand, replacements: dict, offset): +def replace_function_spaces( + integrand: Expr, + replacements: dict[AbstractFunctionSpace, AbstractFunctionSpace], + part: int = 0, +) -> Expr: """Replace all instances of function spaces in an integrand. - replacements should be a dictionary mapping from function spaces to - what the spaces should be replaced with. + Args: + integrand: The integrand to do the replacements in. + replacements: A dictionary mapping function spaces to + the spaces they should be replaced with. + part: The part to use in the replacement arguments. """ - dag_traverser = FunctionSpaceReplacer(replacements, offset) + dag_traverser = FunctionSpaceReplacer(replacements, part) return map_integrands(dag_traverser, integrand) From a4e4e8d46f65f6e08a2aa3adb3a13db08afa594d Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 21 Apr 2026 13:34:09 +0100 Subject: [PATCH 7/8] .argument --- ufl/algorithms/replace_function_spaces.py | 2 +- ufl/formoperators.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ufl/algorithms/replace_function_spaces.py b/ufl/algorithms/replace_function_spaces.py index 881d5ab01..ae35bbf10 100644 --- a/ufl/algorithms/replace_function_spaces.py +++ b/ufl/algorithms/replace_function_spaces.py @@ -2,7 +2,7 @@ from functools import singledispatchmethod -from ufl import Argument +from ufl.argument import Argument from ufl.algorithms.map_integrands import map_integrands from ufl.classes import Expr from ufl.corealg.dag_traverser import DAGTraverser diff --git a/ufl/formoperators.py b/ufl/formoperators.py index 657afed3b..59d6ded47 100644 --- a/ufl/formoperators.py +++ b/ufl/formoperators.py @@ -13,7 +13,7 @@ from ufl.action import Action from ufl.adjoint import Adjoint -from ufl.algorithms import ( # type: ignore +from ufl.algorithms import ( compute_energy_norm, compute_form_action, compute_form_adjoint, From 56f6f43f61b0298b9f6c3fec7aa7b4e78737d1af Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 21 Apr 2026 16:53:47 +0100 Subject: [PATCH 8/8] add test that long and short versions of form are equal --- ...mixed_function_space_with_mesh_sequence.py | 42 +++++++++++++++++++ ufl/algorithms/replace_function_spaces.py | 2 +- ufl/measure.py | 7 ++-- 3 files changed, 47 insertions(+), 4 deletions(-) diff --git a/test/test_mixed_function_space_with_mesh_sequence.py b/test/test_mixed_function_space_with_mesh_sequence.py index 6f39bb841..2d8ef6ee6 100644 --- a/test/test_mixed_function_space_with_mesh_sequence.py +++ b/test/test_mixed_function_space_with_mesh_sequence.py @@ -10,14 +10,22 @@ Measure, Mesh, MeshSequence, + MixedFunctionSpace, SpatialCoordinate, TestFunction, + TestFunctions, TrialFunction, + TrialFunctions, derivative, div, dot, + dx, grad, + hexahedron, inner, + prism, + pyramid, + quadrilateral, split, tetrahedron, triangle, @@ -352,3 +360,37 @@ def test_mixed_function_space_with_mesh_sequence_tetrahedron_triangle(): assert fd.original_form.domain_numbering()[id0.domain] == 0 assert id0.integral_coefficients == set([f]) assert id0.enabled_coefficients == [True] + + +@pytest.mark.parametrize( + "cells", + [ + [quadrilateral, triangle], + [tetrahedron, prism], + [tetrahedron, prism, pyramid], + [tetrahedron, prism, pyramid, hexahedron], + ], +) +def test_form(cells): + cells = sorted(cells) + mesh = MeshSequence( + [Mesh(FiniteElement("Lagrange", cell, 1, (3,), identity_pullback, H1)) for cell in cells] + ) + + elements = { + cell: FiniteElement("Lagrange", cell, 2, (), identity_pullback, H1) for cell in cells + } + V = MixedFunctionSpace(*[FunctionSpace(m, elements[m.ufl_cell()]) for m in mesh]) + + # Define forms for cell types separately + u_parts = TrialFunctions(V) + v_parts = TestFunctions(V) + dx_parts = [Measure("dx", domain=domain) for domain in mesh.meshes] + a = sum(inner(u_, v_) * dx_ for u_, v_, dx_ in zip(u_parts, v_parts, dx_parts)) + + # Check that simpler syntax leads to the same form + u = TrialFunction(V) + v = TestFunction(V) + a2 = inner(u, v) * dx + + assert a == a2 diff --git a/ufl/algorithms/replace_function_spaces.py b/ufl/algorithms/replace_function_spaces.py index ae35bbf10..ed43c2457 100644 --- a/ufl/algorithms/replace_function_spaces.py +++ b/ufl/algorithms/replace_function_spaces.py @@ -2,8 +2,8 @@ from functools import singledispatchmethod -from ufl.argument import Argument from ufl.algorithms.map_integrands import map_integrands +from ufl.argument import Argument from ufl.classes import Expr from ufl.corealg.dag_traverser import DAGTraverser from ufl.functionspace import AbstractFunctionSpace diff --git a/ufl/measure.py b/ufl/measure.py index 4f45f59c1..70db42498 100644 --- a/ufl/measure.py +++ b/ufl/measure.py @@ -495,13 +495,14 @@ def __rmul__(self, integrand): for o in unique_pre_traversal(e) if hasattr(o, "ufl_function_space") } - cells = {e.cell_type for space in mixed_spaces for e in space.ufl_elements()} + cells = {e.cell for space in mixed_spaces for e in space.ufl_elements()} + cells = sorted(list(cells)) integrals = [] for i, cell in enumerate(cells): - cell_domain = next(filter(lambda d: d.ufl_cell().cellname == cell.name, domains)) + cell_domain = next(filter(lambda d: d.ufl_cell() == cell, domains)) replacements = { - m: next(filter(lambda s: s.ufl_element().cell_type == cell, m.ufl_sub_spaces())) + m: next(filter(lambda s: s.ufl_element().cell == cell, m.ufl_sub_spaces())) for m in mixed_spaces } integrals.append(