-
Notifications
You must be signed in to change notification settings - Fork 189
Adjoint + submesh #5081
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Adjoint + submesh #5081
Changes from 3 commits
4bbefa7
8b196b5
a2ae02e
734b4f0
cef0895
183a1a3
319b120
b65c230
eecbdd0
afebcca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| import ufl | ||
| import firedrake | ||
| from ufl.domain import as_domain | ||
| from ufl.domain import extract_domains | ||
| from ufl.formatting.ufl2unicode import ufl2unicode | ||
| from pyadjoint import Block, AdjFloat, create_overloaded_object | ||
| from firedrake.adjoint_utils.checkpointing import maybe_disk_checkpoint | ||
|
|
@@ -11,14 +11,15 @@ class AssembleBlock(Block): | |
| def __init__(self, form, ad_block_tag=None): | ||
| super(AssembleBlock, self).__init__(ad_block_tag=ad_block_tag) | ||
| self.form = form | ||
| try: | ||
| mesh = as_domain(form) | ||
| try: # form can have multiple meshes | ||
| meshes = tuple(extract_domains(form)) | ||
| except AttributeError: | ||
| mesh = None | ||
| meshes = None | ||
|
|
||
| if mesh and not isinstance(self.form, ufl.Interpolate): | ||
| if meshes and not isinstance(self.form, ufl.Interpolate): | ||
| # Interpolation differentiation wrt spatial coordinates is currently not supported. | ||
| self.add_dependency(mesh) | ||
| for mesh in meshes: # add all meshes as dependency | ||
| self.add_dependency(mesh, no_duplicates=True) | ||
|
|
||
| for c in self.form.coefficients(): | ||
| self.add_dependency(c, no_duplicates=True) | ||
|
|
@@ -104,8 +105,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, | |
| arity_form = len(form.arguments()) | ||
|
|
||
| if isconstant(c): | ||
| mesh = as_domain(self.form) | ||
| space = c._ad_function_space(mesh) | ||
| space = c.function_space() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How come this needs to change?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if this is a holder from previous Firedrake/FEniCS compatibility. But I have no idea really.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it is then we should address that in a separate PR
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In which case someone should open an issue. This is a non-core dev submission and saying "leave that for another PR" makes it very likely to never be done.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's two questions here:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can keep using |
||
| elif isinstance(c, (firedrake.Function, firedrake.Cofunction)): | ||
| space = c.function_space() | ||
| elif isinstance(c, firedrake.MeshGeometry): | ||
|
|
@@ -162,8 +162,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, | |
| c1_rep = block_variable.saved_output | ||
|
|
||
| if isconstant(c1): | ||
| mesh = as_domain(form) | ||
| space = c1._ad_function_space(mesh) | ||
| space = c1.function_space() | ||
| elif isinstance(c1, (firedrake.Function, firedrake.Cofunction)): | ||
| space = c1.function_space() | ||
| elif isinstance(c1, firedrake.MeshGeometry): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| import numpy | ||
| import ufl | ||
| from ufl.domain import extract_domains | ||
| from ufl import replace | ||
| from ufl.formatting.ufl2unicode import ufl2unicode | ||
| from enum import Enum | ||
|
|
@@ -74,8 +75,17 @@ def __init__(self, lhs, rhs, func, bcs, *args, **kwargs): | |
| for bc in self.bcs: | ||
| self.add_dependency(bc, no_duplicates=True) | ||
|
|
||
| mesh = self.lhs.ufl_domain() | ||
| self.add_dependency(mesh) | ||
| try: # add all meshes as dependency | ||
| for mesh in tuple(extract_domains(self.lhs)): | ||
|
KarsKnook marked this conversation as resolved.
Outdated
|
||
| self.add_dependency(mesh, no_duplicates=True) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment about
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some test fails if you don't include no_duplicates. I don't understand why we wouldn't want to use no_duplicates for meshes if we use it for each other dependency |
||
| except AttributeError: | ||
| pass | ||
|
|
||
| if isinstance(self.rhs, (ufl.Form, ufl.Cofunction)): | ||
|
KarsKnook marked this conversation as resolved.
|
||
| # add all meshes as dependency | ||
| for mesh in tuple(extract_domains(self.rhs)): | ||
| self.add_dependency(mesh, no_duplicates=True) | ||
|
|
||
| self._init_solver_parameters(args, kwargs) | ||
|
|
||
| def _init_solver_parameters(self, args, kwargs): | ||
|
|
@@ -243,9 +253,8 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, | |
| c_rep = block_variable.saved_output | ||
|
|
||
| if isconstant(c): | ||
| mesh = F_form.ufl_domain() | ||
| trial_function = firedrake.TrialFunction( | ||
| c._ad_function_space(mesh) | ||
| c.function_space() | ||
|
connorjward marked this conversation as resolved.
|
||
| ) | ||
| elif isinstance(c, (firedrake.Function, firedrake.Cofunction)): | ||
| trial_function = firedrake.TrialFunction(c.function_space()) | ||
|
|
@@ -455,8 +464,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, | |
| return [tmp_bc] | ||
|
|
||
| if isconstant(c_rep): | ||
| mesh = F_form.ufl_domain() | ||
| W = c._ad_function_space(mesh) | ||
| W = c_rep.function_space() | ||
| elif isinstance(c, firedrake.MeshGeometry): | ||
| X = firedrake.SpatialCoordinate(c) | ||
| W = c._ad_function_space() | ||
|
|
@@ -759,7 +767,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, | |
| if isinstance(c, (firedrake.Function, firedrake.Cofunction)): | ||
| trial_function = firedrake.TrialFunction(c.function_space()) | ||
| elif isinstance(c, firedrake.Constant): | ||
| mesh = F_form.ufl_domain() | ||
| mesh = extract_domains(F_form)[0] # I don't like this | ||
|
KarsKnook marked this conversation as resolved.
Outdated
|
||
| trial_function = firedrake.TrialFunction( | ||
| c._ad_function_space(mesh) | ||
| ) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.