diff --git a/firedrake/adjoint_utils/blocks/assembly.py b/firedrake/adjoint_utils/blocks/assembly.py index 969be25e9d..c13d79c4ef 100644 --- a/firedrake/adjoint_utils/blocks/assembly.py +++ b/firedrake/adjoint_utils/blocks/assembly.py @@ -1,6 +1,6 @@ import ufl import firedrake -from ufl.domain import as_domain +from ufl.domain import extract_domains from ufl.formatting.ufl2unicode import ufl2unicode from pyadjoint import Block, AdjFloat, create_overloaded_object from firedrake.adjoint_utils.checkpointing import maybe_disk_checkpoint @@ -11,14 +11,15 @@ class AssembleBlock(Block): def __init__(self, form, ad_block_tag=None): super(AssembleBlock, self).__init__(ad_block_tag=ad_block_tag) self.form = form - try: - mesh = as_domain(form) + try: # form can have multiple meshes + meshes = tuple(extract_domains(form)) except AttributeError: - mesh = None + meshes = None - if mesh and not isinstance(self.form, ufl.Interpolate): + if meshes and not isinstance(self.form, ufl.Interpolate): # Interpolation differentiation wrt spatial coordinates is currently not supported. - self.add_dependency(mesh) + for mesh in meshes: # add all meshes as dependency + self.add_dependency(mesh) for c in self.form.coefficients(): self.add_dependency(c, no_duplicates=True) @@ -104,8 +105,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, arity_form = len(form.arguments()) if isconstant(c): - mesh = as_domain(self.form) - space = c._ad_function_space(mesh) + space = c.function_space() elif isinstance(c, (firedrake.Function, firedrake.Cofunction)): space = c.function_space() elif isinstance(c, firedrake.MeshGeometry): @@ -162,8 +162,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, c1_rep = block_variable.saved_output if isconstant(c1): - mesh = as_domain(form) - space = c1._ad_function_space(mesh) + space = c1.function_space() elif isinstance(c1, (firedrake.Function, firedrake.Cofunction)): space = c1.function_space() elif isinstance(c1, firedrake.MeshGeometry): diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index 1ace4c5222..fdf5ee1f6a 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -1,5 +1,6 @@ import numpy import ufl +from ufl.domain import extract_domains, extract_unique_domain from ufl import replace from ufl.formatting.ufl2unicode import ufl2unicode from enum import Enum @@ -74,8 +75,17 @@ def __init__(self, lhs, rhs, func, bcs, *args, **kwargs): for bc in self.bcs: self.add_dependency(bc, no_duplicates=True) - mesh = self.lhs.ufl_domain() - self.add_dependency(mesh) + try: # add all meshes as dependency + for mesh in extract_domains(self.lhs): + self.add_dependency(mesh, no_duplicates=True) + except AttributeError: + pass + + if isinstance(self.rhs, (ufl.Form, ufl.Cofunction)): + # add all meshes as dependency + for mesh in extract_domains(self.rhs): + self.add_dependency(mesh, no_duplicates=True) + self._init_solver_parameters(args, kwargs) def _init_solver_parameters(self, args, kwargs): @@ -243,9 +253,8 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, c_rep = block_variable.saved_output if isconstant(c): - mesh = F_form.ufl_domain() trial_function = firedrake.TrialFunction( - c._ad_function_space(mesh) + c.function_space() ) elif isinstance(c, (firedrake.Function, firedrake.Cofunction)): trial_function = firedrake.TrialFunction(c.function_space()) @@ -455,8 +464,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, return [tmp_bc] if isconstant(c_rep): - mesh = F_form.ufl_domain() - W = c._ad_function_space(mesh) + W = c_rep.function_space() elif isinstance(c, firedrake.MeshGeometry): X = firedrake.SpatialCoordinate(c) W = c._ad_function_space() @@ -759,7 +767,10 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, if isinstance(c, (firedrake.Function, firedrake.Cofunction)): trial_function = firedrake.TrialFunction(c.function_space()) elif isinstance(c, firedrake.Constant): - mesh = F_form.ufl_domain() + try: + mesh = extract_unique_domain(F_form) + except ValueError: + raise ValueError("Expecting a single mesh") trial_function = firedrake.TrialFunction( c._ad_function_space(mesh) ) diff --git a/tests/firedrake/adjoint/test_solving.py b/tests/firedrake/adjoint/test_solving.py index 81dfdc64ad..3dcb5abfb3 100644 --- a/tests/firedrake/adjoint/test_solving.py +++ b/tests/firedrake/adjoint/test_solving.py @@ -314,6 +314,74 @@ def test_two_nonlinear_solves(): assert rf.tape.recompute_count == 5 +@pytest.mark.skipcomplex +def test_multiple_meshes(rg): + mesh1 = UnitSquareMesh(4, 4) + mesh2 = RectangleMesh(nx=4, ny=4, Lx=3, Ly=1, originX=2, originY=0) + + V1 = FunctionSpace(mesh1, "CG", 1) + V2 = FunctionSpace(mesh2, "CG", 1) + V = V1*V2 + + u = Function(V) + u1, u2 = split(u) + v1, v2 = TestFunctions(V) + + f = Function(V).assign(10.) + f1, f2 = split(f) + + a = inner(grad(u1), grad(v1))*dx(mesh1) + inner(grad(u2), grad(v2))*dx(mesh2) + L = inner(f1, v1)*dx(mesh1) + inner(f2, v2)*dx(mesh2) + + bc1 = DirichletBC(V.sub(0), 0, "on_boundary") + bc2 = DirichletBC(V.sub(1), 0, "on_boundary") + bcs = [bc1, bc2] + + solve(a - L == 0, u, bcs) + + J = assemble(u1**2*dx(mesh1) + u2**2*dx(mesh2)) + rf = ReducedFunctional(J, Control(f)) + df = rg.uniform(V) + + assert taylor_test(rf, f, df) > 1.95 + + +@pytest.mark.skipcomplex +def test_submesh(rg): + mesh = UnitSquareMesh(4, 4) + x, y = SpatialCoordinate(mesh) + + DG = FunctionSpace(mesh, "DG", 0) + ind = Function(DG).interpolate(conditional(y > 0.5, 1, 0)) + relabeled_mesh = RelabeledMesh(mesh, [ind], [10]) + submesh = Submesh(relabeled_mesh, 2, 10) + dx_sub = Measure("dx", domain=submesh, intersect_measures=(Measure("dx", relabeled_mesh),)) + + V1 = FunctionSpace(relabeled_mesh, "CG", 1) + V2 = FunctionSpace(submesh, "CG", 1) + V = V1*V2 + + u = Function(V) + u1, u2 = split(u) + v1, v2 = TestFunctions(V) + + f = Function(V1).assign(10.) + + a = inner(grad(u1), grad(v1))*dx(relabeled_mesh) + a += inner(u1 - u2, v2)*dx_sub + L = inner(f, v1)*dx(relabeled_mesh) + + bcs = [DirichletBC(V.sub(0), 0, "on_boundary")] + + solve(a - L == 0, u, bcs) + + J = assemble(u2**2*dx_sub) + rf = ReducedFunctional(J, Control(f)) + df = rg.uniform(V1) + + assert taylor_test(rf, f, df) > 1.95 + + def convergence_rates(E_values, eps_values): from numpy import log r = []