Skip to content
24 changes: 9 additions & 15 deletions firedrake/adjoint_utils/blocks/assembly.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
import ufl
import firedrake
from ufl.domain import as_domain
from ufl.domain import extract_domains
from ufl.formatting.ufl2unicode import ufl2unicode
from pyadjoint import Block, AdjFloat, create_overloaded_object
from firedrake.adjoint_utils.checkpointing import maybe_disk_checkpoint
from .block_utils import isconstant


class AssembleBlock(Block):
def __init__(self, form, ad_block_tag=None):
super(AssembleBlock, self).__init__(ad_block_tag=ad_block_tag)
self.form = form
try:
mesh = as_domain(form)
try: # form can have multiple meshes
meshes = tuple(extract_domains(form))
except AttributeError:
mesh = None
meshes = None

if mesh and not isinstance(self.form, ufl.Interpolate):
if meshes and not isinstance(self.form, ufl.Interpolate):
# Interpolation differentiation wrt spatial coordinates is currently not supported.
self.add_dependency(mesh)
for mesh in meshes: # add all meshes as dependency
self.add_dependency(mesh, no_duplicates=True)

for c in self.form.coefficients():
self.add_dependency(c, no_duplicates=True)
Expand Down Expand Up @@ -103,10 +103,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,

arity_form = len(form.arguments())

if isconstant(c):
mesh = as_domain(self.form)
space = c._ad_function_space(mesh)
elif isinstance(c, (firedrake.Function, firedrake.Cofunction)):
if isinstance(c, (firedrake.Function, firedrake.Cofunction)):
space = c.function_space()
elif isinstance(c, firedrake.MeshGeometry):
c_rep = firedrake.SpatialCoordinate(c_rep)
Expand Down Expand Up @@ -161,10 +158,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
c1 = block_variable.output
c1_rep = block_variable.saved_output

if isconstant(c1):
mesh = as_domain(form)
space = c1._ad_function_space(mesh)
elif isinstance(c1, (firedrake.Function, firedrake.Cofunction)):
if isinstance(c1, (firedrake.Function, firedrake.Cofunction)):
space = c1.function_space()
elif isinstance(c1, firedrake.MeshGeometry):
c1_rep = firedrake.SpatialCoordinate(c1)
Expand Down
32 changes: 18 additions & 14 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy
import ufl
from ufl.domain import extract_domains, extract_unique_domain
from ufl import replace
from ufl.formatting.ufl2unicode import ufl2unicode
from enum import Enum
Expand All @@ -8,7 +9,6 @@
from pyadjoint.enlisting import Enlist
import firedrake
from firedrake.adjoint_utils.checkpointing import maybe_disk_checkpoint
from .block_utils import isconstant


def extract_subfunction(u, V):
Expand Down Expand Up @@ -74,8 +74,17 @@ def __init__(self, lhs, rhs, func, bcs, *args, **kwargs):
for bc in self.bcs:
self.add_dependency(bc, no_duplicates=True)

mesh = self.lhs.ufl_domain()
self.add_dependency(mesh)
try: # add all meshes as dependency
for mesh in extract_domains(self.lhs):
self.add_dependency(mesh, no_duplicates=True)
Comment thread
connorjward marked this conversation as resolved.
except AttributeError:
pass

if isinstance(self.rhs, ufl.BaseForm):
# add all meshes as dependency
for mesh in extract_domains(self.rhs):
self.add_dependency(mesh, no_duplicates=True)

self._init_solver_parameters(args, kwargs)

def _init_solver_parameters(self, args, kwargs):
Expand Down Expand Up @@ -242,12 +251,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
c = block_variable.output
c_rep = block_variable.saved_output

if isconstant(c):
mesh = F_form.ufl_domain()
trial_function = firedrake.TrialFunction(
c._ad_function_space(mesh)
)
elif isinstance(c, (firedrake.Function, firedrake.Cofunction)):
if isinstance(c, (firedrake.Function, firedrake.Cofunction)):
trial_function = firedrake.TrialFunction(c.function_space())
elif isinstance(c, firedrake.DirichletBC):
tmp_bc = c.reconstruct(
Expand Down Expand Up @@ -454,10 +458,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
)
return [tmp_bc]

if isconstant(c_rep):
mesh = F_form.ufl_domain()
W = c._ad_function_space(mesh)
elif isinstance(c, firedrake.MeshGeometry):
if isinstance(c, firedrake.MeshGeometry):
X = firedrake.SpatialCoordinate(c)
W = c._ad_function_space()
else:
Expand Down Expand Up @@ -759,7 +760,10 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
if isinstance(c, (firedrake.Function, firedrake.Cofunction)):
trial_function = firedrake.TrialFunction(c.function_space())
elif isinstance(c, firedrake.Constant):
mesh = F_form.ufl_domain()
try:
mesh = extract_unique_domain(F_form)
except ValueError:
raise ValueError("Expecting a single mesh")
trial_function = firedrake.TrialFunction(
c._ad_function_space(mesh)
)
Expand Down
76 changes: 76 additions & 0 deletions tests/firedrake/adjoint/test_solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,82 @@ def test_two_nonlinear_solves():
assert rf.tape.recompute_count == 5


@pytest.mark.skipcomplex
def test_multiple_meshes(rg):
mesh1 = UnitSquareMesh(4, 4)
mesh2 = RectangleMesh(nx=4, ny=4, Lx=3, Ly=1, originX=2, originY=0)

V1 = FunctionSpace(mesh1, "CG", 1)
V2 = FunctionSpace(mesh2, "CG", 1)
V = V1*V2

u = Function(V)
u1, u2 = split(u)
v1, v2 = TestFunctions(V)

f = Function(V).assign(10.)
f1, f2 = split(f)

a = inner(grad(u1), grad(v1))*dx(mesh1) + inner(grad(u2), grad(v2))*dx(mesh2)
L = inner(f1, v1)*dx(mesh1) + inner(f2, v2)*dx(mesh2)

bc1 = DirichletBC(V.sub(0), 0, "on_boundary")
bc2 = DirichletBC(V.sub(1), 0, "on_boundary")
bcs = [bc1, bc2]

solve(a - L == 0, u, bcs)

J = assemble(u1**4*dx(mesh1) + u2**4*dx(mesh2))
rf = ReducedFunctional(J, Control(f))
df = rg.uniform(V)

taylor = taylor_to_dict(rf, f, df)

assert min(taylor['R0']['Rate']) > 0.95, taylor['R0']
assert min(taylor['R1']['Rate']) > 1.95, taylor['R1']
assert min(taylor['R2']['Rate']) > 2.95, taylor['R2']


@pytest.mark.skipcomplex
def test_submesh(rg):
mesh = UnitSquareMesh(4, 4)
x, y = SpatialCoordinate(mesh)

DG = FunctionSpace(mesh, "DG", 0)
ind = Function(DG).interpolate(conditional(y > 0.5, 1, 0))
relabeled_mesh = RelabeledMesh(mesh, [ind], [10])
submesh = Submesh(relabeled_mesh, 2, 10)
dx_sub = Measure("dx", domain=submesh, intersect_measures=(Measure("dx", relabeled_mesh),))

V1 = FunctionSpace(relabeled_mesh, "CG", 1)
V2 = FunctionSpace(submesh, "CG", 1)
V = V1*V2

u = Function(V)
u1, u2 = split(u)
v1, v2 = TestFunctions(V)

f = Function(V1).assign(10.)

a = inner(grad(u1), grad(v1))*dx(relabeled_mesh)
a += inner(u1 - u2, v2)*dx_sub
L = inner(f, v1)*dx(relabeled_mesh)

bcs = [DirichletBC(V.sub(0), 0, "on_boundary")]

solve(a - L == 0, u, bcs)

J = assemble(u2**4*dx_sub)
rf = ReducedFunctional(J, Control(f))
df = rg.uniform(V1)

taylor = taylor_to_dict(rf, f, df)

assert min(taylor['R0']['Rate']) > 0.95, taylor['R0']
assert min(taylor['R1']['Rate']) > 1.95, taylor['R1']
assert min(taylor['R2']['Rate']) > 2.95, taylor['R2']


def convergence_rates(E_values, eps_values):
from numpy import log
r = []
Expand Down
Loading