Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions firedrake/adjoint_utils/blocks/assembly.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import ufl
import firedrake
from ufl.domain import as_domain
from ufl.domain import extract_domains
from ufl.formatting.ufl2unicode import ufl2unicode
from pyadjoint import Block, AdjFloat, create_overloaded_object
from firedrake.adjoint_utils.checkpointing import maybe_disk_checkpoint
Expand All @@ -11,14 +11,15 @@ class AssembleBlock(Block):
def __init__(self, form, ad_block_tag=None):
super(AssembleBlock, self).__init__(ad_block_tag=ad_block_tag)
self.form = form
try:
mesh = as_domain(form)
try: # form can have multiple meshes
meshes = tuple(extract_domains(form))
except AttributeError:
mesh = None
meshes = None

if mesh and not isinstance(self.form, ufl.Interpolate):
if meshes and not isinstance(self.form, ufl.Interpolate):
# Interpolation differentiation wrt spatial coordinates is currently not supported.
self.add_dependency(mesh)
for mesh in meshes: # add all meshes as dependency
self.add_dependency(mesh)

for c in self.form.coefficients():
self.add_dependency(c, no_duplicates=True)
Expand Down Expand Up @@ -104,8 +105,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
arity_form = len(form.arguments())

if isconstant(c):
mesh = as_domain(self.form)
space = c._ad_function_space(mesh)
space = c.function_space()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How come this needs to change? c.function_space() returns a Firedrake type but c._ad_function_space returns the ufl_function_space.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this is a holder from previous Firedrake/FEniCS compatibility. But I have no idea really.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it is then we should address that in a separate PR

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In which case someone should open an issue. This is a non-core dev submission and saying "leave that for another PR" makes it very likely to never be done.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's two questions here:

  1. Does it need changing for the fix in this PR to work? @KarsKnook does it work if you don't change how you get the function space? If not then you can revert back to the original.
  2. Is this a holdover that we could update? I don't know either, but I'm just saying that unless its essential for this PR then this isn't the place to start changing it, lets keep this one as simple as possible.

elif isinstance(c, (firedrake.Function, firedrake.Cofunction)):
space = c.function_space()
elif isinstance(c, firedrake.MeshGeometry):
Expand Down Expand Up @@ -162,8 +162,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
c1_rep = block_variable.saved_output

if isconstant(c1):
mesh = as_domain(form)
space = c1._ad_function_space(mesh)
space = c1.function_space()
elif isinstance(c1, (firedrake.Function, firedrake.Cofunction)):
space = c1.function_space()
elif isinstance(c1, firedrake.MeshGeometry):
Expand Down
25 changes: 18 additions & 7 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy
import ufl
from ufl.domain import extract_domains, extract_unique_domain
from ufl import replace
from ufl.formatting.ufl2unicode import ufl2unicode
from enum import Enum
Expand Down Expand Up @@ -74,8 +75,17 @@ def __init__(self, lhs, rhs, func, bcs, *args, **kwargs):
for bc in self.bcs:
self.add_dependency(bc, no_duplicates=True)

mesh = self.lhs.ufl_domain()
self.add_dependency(mesh)
try: # add all meshes as dependency
for mesh in tuple(extract_domains(self.lhs)):
Comment thread
KarsKnook marked this conversation as resolved.
Outdated
self.add_dependency(mesh, no_duplicates=True)
except AttributeError:
pass

if isinstance(self.rhs, (ufl.Form, ufl.Cofunction)):
# add all meshes as dependency
for mesh in tuple(extract_domains(self.rhs)):
self.add_dependency(mesh, no_duplicates=True)

self._init_solver_parameters(args, kwargs)

def _init_solver_parameters(self, args, kwargs):
Expand Down Expand Up @@ -243,9 +253,8 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
c_rep = block_variable.saved_output

if isconstant(c):
mesh = F_form.ufl_domain()
trial_function = firedrake.TrialFunction(
c._ad_function_space(mesh)
c.function_space()
Comment thread
connorjward marked this conversation as resolved.
)
elif isinstance(c, (firedrake.Function, firedrake.Cofunction)):
trial_function = firedrake.TrialFunction(c.function_space())
Expand Down Expand Up @@ -455,8 +464,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
return [tmp_bc]

if isconstant(c_rep):
mesh = F_form.ufl_domain()
W = c._ad_function_space(mesh)
W = c_rep.function_space()
elif isinstance(c, firedrake.MeshGeometry):
X = firedrake.SpatialCoordinate(c)
W = c._ad_function_space()
Expand Down Expand Up @@ -759,7 +767,10 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
if isinstance(c, (firedrake.Function, firedrake.Cofunction)):
trial_function = firedrake.TrialFunction(c.function_space())
elif isinstance(c, firedrake.Constant):
mesh = F_form.ufl_domain()
try:
mesh = extract_unique_domain(F_form)
except ValueError:
raise ValueError("Expecting a single mesh")
trial_function = firedrake.TrialFunction(
c._ad_function_space(mesh)
)
Expand Down
Loading