Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions test/test_mixed_function_space_with_mesh_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,22 @@
Measure,
Mesh,
MeshSequence,
MixedFunctionSpace,
SpatialCoordinate,
TestFunction,
TestFunctions,
TrialFunction,
TrialFunctions,
derivative,
div,
dot,
dx,
grad,
hexahedron,
inner,
prism,
pyramid,
quadrilateral,
split,
tetrahedron,
triangle,
Expand Down Expand Up @@ -352,3 +360,37 @@ def test_mixed_function_space_with_mesh_sequence_tetrahedron_triangle():
assert fd.original_form.domain_numbering()[id0.domain] == 0
assert id0.integral_coefficients == set([f])
assert id0.enabled_coefficients == [True]


@pytest.mark.parametrize(
"cells",
[
[quadrilateral, triangle],
[tetrahedron, prism],
[tetrahedron, prism, pyramid],
[tetrahedron, prism, pyramid, hexahedron],
],
)
def test_form(cells):
cells = sorted(cells)
mesh = MeshSequence(
[Mesh(FiniteElement("Lagrange", cell, 1, (3,), identity_pullback, H1)) for cell in cells]
)

elements = {
cell: FiniteElement("Lagrange", cell, 2, (), identity_pullback, H1) for cell in cells
}
V = MixedFunctionSpace(*[FunctionSpace(m, elements[m.ufl_cell()]) for m in mesh])

# Define forms for cell types separately
u_parts = TrialFunctions(V)
v_parts = TestFunctions(V)
dx_parts = [Measure("dx", domain=domain) for domain in mesh.meshes]
a = sum(inner(u_, v_) * dx_ for u_, v_, dx_ in zip(u_parts, v_parts, dx_parts))

# Check that simpler syntax leads to the same form
u = TrialFunction(V)
v = TestFunction(V)
a2 = inner(u, v) * dx

assert a == a2
2 changes: 2 additions & 0 deletions ufl/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"preprocess_form",
"read_ufl_file",
"replace",
"replace_function_spaces",
"replace_terminal_data",
"sort_elements",
"strip_terminal_data",
Expand Down Expand Up @@ -82,6 +83,7 @@
compute_form_rhs,
)
from ufl.algorithms.replace import replace
from ufl.algorithms.replace_function_spaces import replace_function_spaces
from ufl.algorithms.signature import compute_form_signature
from ufl.algorithms.strip_terminal_data import replace_terminal_data, strip_terminal_data
from ufl.algorithms.transformer import (
Expand Down
69 changes: 69 additions & 0 deletions ufl/algorithms/replace_function_spaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Replace function spaces in all arguments."""

from functools import singledispatchmethod

from ufl.algorithms.map_integrands import map_integrands
from ufl.argument import Argument
from ufl.classes import Expr
from ufl.corealg.dag_traverser import DAGTraverser
from ufl.functionspace import AbstractFunctionSpace


class FunctionSpaceReplacer(DAGTraverser):
"""Dispatcher."""

def __init__(
self,
replacements: dict[AbstractFunctionSpace, AbstractFunctionSpace],
part: int,
compress: bool | None = True,
visited_cache: dict[tuple, Expr] | None = None,
result_cache: dict[Expr, Expr] | None = None,
) -> None:
"""Initialise."""
super().__init__(compress=compress, visited_cache=visited_cache, result_cache=result_cache)
self._dag_traverser_cache: dict[tuple[type, Expr, Expr, Expr], DAGTraverser] = {}
self.replacements = replacements
self.part = part

@singledispatchmethod
def process(self, o: Expr) -> Expr:
"""Process ``o``.

Args:
o: `Expr` to be processed.

Returns:
Processed object.

"""
return super().process(o)

@process.register(Expr)
def _(self, o: Expr) -> Expr:
"""Do nothing."""
return self.reuse_if_untouched(o)

@process.register(Argument)
def _(self, o: Argument) -> Expr:
"""Apply to argument."""
if o.ufl_function_space() in self.replacements:
return Argument(self.replacements[o.ufl_function_space()], o._number, self.part)
return self.reuse_if_untouched(o)


def replace_function_spaces(
integrand: Expr,
replacements: dict[AbstractFunctionSpace, AbstractFunctionSpace],
part: int = 0,
) -> Expr:
"""Replace all instances of function spaces in an integrand.

Args:
integrand: The integrand to do the replacements in.
replacements: A dictionary mapping function spaces to
the spaces they should be replaced with.
part: The part to use in the replacement arguments.
"""
dag_traverser = FunctionSpaceReplacer(replacements, part)
return map_integrands(dag_traverser, integrand)
18 changes: 17 additions & 1 deletion ufl/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,23 @@ def __init__(self, function_space, number, part=None):
raise ValueError("Expecting a FunctionSpace.")

self._ufl_function_space = function_space
self._ufl_shape = function_space.value_shape

if isinstance(function_space, MixedFunctionSpace):
subspaces = function_space.ufl_sub_spaces()
cells = [s.ufl_domain().ufl_cell() for s in subspaces]
if len(set(cells)) != len(cells):
raise ValueError(
"Can only use mixed function spaces where subspaces are all defined "
"on different cell types"
)
self._ufl_shape = subspaces[0].value_shape
for s in subspaces[1:]:
if self._ufl_shape != s.value_shape:
raise ValueError(
"Cannot use mixed function space where subspaces have different shapes"
)
else:
self._ufl_shape = function_space.value_shape

if not isinstance(number, numbers.Integral):
raise ValueError(f"Expecting an int for number, not {number}")
Expand Down
7 changes: 4 additions & 3 deletions ufl/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,10 @@ def __init__(self, domain):
"""Initialise."""
Terminal.__init__(self)
if isinstance(domain, MeshSequence) and len(set(domain)) > 1:
# Can not make GeometricQuantity if multiple domains exist.
raise TypeError(f"Can not create a GeometricQuantity on {domain}")
self._domain = as_domain(domain)
self._domain = domain
# raise TypeError(f"Can not create a GeometricQuantity on {domain}")
else:
self._domain = as_domain(domain)

def ufl_domains(self):
"""Get the UFL domains."""
Expand Down
64 changes: 49 additions & 15 deletions ufl/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
import numbers
from itertools import chain

from ufl.algorithms.traversal import iter_expressions
from ufl.checks import is_true_ufl_scalar
from ufl.constantvalue import as_ufl
from ufl.core.expr import Expr
from ufl.corealg.traversal import unique_pre_traversal
from ufl.domain import AbstractDomain, as_domain, extract_domains
from ufl.protocols import id_or_none

Expand Down Expand Up @@ -430,6 +432,7 @@ def __rmul__(self, integrand):
Integration properties are taken from this Measure object.
"""
# Avoid circular imports
from ufl.algorithms import replace_function_spaces
from ufl.form import Form
from ufl.integral import Integral

Expand Down Expand Up @@ -483,23 +486,54 @@ def __rmul__(self, integrand):
elif len(domains) == 0:
raise ValueError("This integral is missing an integration domain.")
else:
raise ValueError(
"Multiple domains found, making the choice of integration domain ambiguous."
assert len(set(d.ufl_cell() for d in domains)) == len(domains)

if domain is None:
mixed_spaces = {
o.ufl_function_space()
for e in iter_expressions(integrand)
for o in unique_pre_traversal(e)
if hasattr(o, "ufl_function_space")
}
cells = {e.cell for space in mixed_spaces for e in space.ufl_elements()}
cells = sorted(list(cells))

integrals = []
for i, cell in enumerate(cells):
cell_domain = next(filter(lambda d: d.ufl_cell() == cell, domains))
replacements = {
m: next(filter(lambda s: s.ufl_element().cell == cell, m.ufl_sub_spaces()))
for m in mixed_spaces
}
integrals.append(
Integral(
integrand=replace_function_spaces(integrand, replacements, i),
integral_type=self.integral_type(),
domain=cell_domain,
subdomain_id=subdomain_id,
metadata=self.metadata(),
subdomain_data=self.subdomain_data(),
extra_domain_integral_type_map={
m.ufl_domain(): m.integral_type() for m in self.intersect_measures()
},
)
)

# Otherwise create and return a one-integral form
integral = Integral(
integrand=integrand,
integral_type=self.integral_type(),
domain=domain,
subdomain_id=subdomain_id,
metadata=self.metadata(),
subdomain_data=self.subdomain_data(),
extra_domain_integral_type_map={
m.ufl_domain(): m.integral_type() for m in self.intersect_measures()
},
)
return Form([integral])
return sum(Form([integral]) for integral in integrals)
else:
# Create and return a one-integral form
integral = Integral(
integrand=integrand,
integral_type=self.integral_type(),
domain=domain,
subdomain_id=subdomain_id,
metadata=self.metadata(),
subdomain_data=self.subdomain_data(),
extra_domain_integral_type_map={
m.ufl_domain(): m.integral_type() for m in self.intersect_measures()
},
)
return Form([integral])


class MeasureSum:
Expand Down