-
Notifications
You must be signed in to change notification settings - Fork 75
Fix vsite fits #315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix vsite fits #315
Changes from 71 commits
d901d44
38b5d02
6977614
9abc5b4
f210213
1a8db42
f84cf66
c0fcaa4
bc464d3
276aa2e
040c763
99f616b
1c34afa
b8c2d35
bdff323
7057664
5b648aa
94a5ab2
2f1f4ab
0d527b4
9abf676
c324a8b
c054c25
8de0b00
c947c34
3d769c8
ba27332
23dc054
c3595f2
b794c1b
ba45a74
5db7e84
d7c3f42
b5feb5e
000b73f
8fe3107
a582e7a
ab34285
f174c38
ed807d9
689f59d
f2156b7
9bfd725
d95bcb7
44db401
ed76618
cb95c39
34daa23
db21dc9
45b8e45
471b403
ed5840b
c79515a
02ccda0
e70857c
606534c
2e57aba
8c1d0e1
b8b6086
7d56b3e
b74a0a6
751443d
3ddee47
083d81f
167b67d
0e9a701
12309cd
36c3ba0
a770175
9ac70c8
c991c2a
8b19797
def6686
0c637fd
a0937d9
913a63d
536146b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| name: forcebalance-test | ||
| channels: | ||
| - conda-forge | ||
| - openeye | ||
| dependencies: | ||
| # Base depends | ||
| - python | ||
| - pip | ||
| # Testing | ||
| - pytest | ||
| - pytest-cov | ||
| - codecov | ||
| - numpy | ||
| - scipy | ||
| - lxml | ||
| - networkx | ||
| - zlib | ||
| - swig | ||
| - future | ||
| - pymbar =3 | ||
| - openmm >= 8 | ||
| # ambertools has no Python 3.9 builds on conda-forge | ||
| - ndcctools | ||
| - geometric | ||
| # - gromacs =2019.1 | ||
| # openff packages require Python >= 3.11; tests are skipped on 3.9 and 3.10 | ||
| # - openff-toolkit-base | ||
| # - openff-evaluator-base | ||
| # - openff-recharge | ||
| # - openeye-toolkits (Don't have a license file to use with GH Actions.) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| import numpy as np | ||
| from forcebalance.nifty import warn_once, printcool, printcool_dictionary | ||
| from forcebalance.output import getLogger | ||
| from forcebalance.smirnoffio import select_virtual_site_parameter | ||
| from forcebalance.target import Target | ||
|
|
||
| try: | ||
|
|
@@ -308,20 +309,40 @@ def _parameter_value_from_gradient_key(self, gradient_key): | |
| bool | ||
| Returns True if the parameter is a cosmetic one. | ||
| """ | ||
| # try: | ||
| # import openmm.unit as simtk_unit | ||
| # except ImportError: | ||
| # import simtk.unit as simtk_unit | ||
| from openff.units import unit as openff_unit | ||
|
|
||
|
|
||
| parameter_handler = self.FF.openff_forcefield.get_parameter_handler( | ||
| gradient_key.tag | ||
| ) | ||
| parameter = ( | ||
| parameter_handler if gradient_key.smirks is None | ||
| else parameter_handler.parameters[gradient_key.smirks] | ||
| ) | ||
|
|
||
| if gradient_key.smirks is None: | ||
| parameter = parameter_handler | ||
| elif gradient_key.tag != "VirtualSites": | ||
| parameter = parameter_handler.parameters[gradient_key.smirks] | ||
| else: | ||
| # VirtualSite parameters are not uniquely identifiable by SMIRKS alone. | ||
| # Require explicit type/name/match metadata in every VirtualSites key. | ||
| if gradient_key.virtual_site_type is None: | ||
| raise KeyError( | ||
| f"Gradient key {gradient_key} is missing required virtual_site_type" | ||
| ) | ||
| if gradient_key.virtual_site_name is None: | ||
| raise KeyError( | ||
| f"Gradient key {gradient_key} is missing required virtual_site_name" | ||
| ) | ||
| if gradient_key.virtual_site_match is None: | ||
| raise KeyError( | ||
| f"Gradient key {gradient_key} is missing required virtual_site_match" | ||
| ) | ||
|
|
||
| parameter = select_virtual_site_parameter( | ||
| parameters=parameter_handler.parameters, | ||
| smirks=gradient_key.smirks, | ||
| virtual_site_type=gradient_key.virtual_site_type, | ||
| virtual_site_name=gradient_key.virtual_site_name, | ||
| virtual_site_match=gradient_key.virtual_site_match, | ||
| error_context=f"gradient key {gradient_key}", | ||
| ) | ||
|
|
||
| attribute_split = re.split(r"(\d+)", gradient_key.attribute) | ||
| attribute_split = list(filter(None, attribute_split)) | ||
|
|
@@ -474,14 +495,27 @@ def submit_jobs(self, mvals, AGrad=True, AHess=True): | |
| string_key = field_list[0] | ||
| key_split = string_key.split("/") | ||
|
|
||
| virtual_site_kwargs = {} | ||
|
|
||
| if len(key_split) == 3 and key_split[0] == "": | ||
| parameter_tag = key_split[1].strip() | ||
| parameter_smirks = None | ||
| parameter_attribute = key_split[2].strip() | ||
| elif len(key_split) == 4: | ||
| elif len(key_split) >= 4: | ||
| parameter_tag = key_split[0].strip() | ||
| parameter_smirks = key_split[3].strip() | ||
| parameter_attribute = key_split[2].strip() | ||
|
|
||
| if parameter_tag == "VirtualSites": | ||
| # VirtualSites keys must include positional identity metadata: | ||
| # VirtualSites/<tag>/<attribute>/<smirks>/<type>/<name>/<match> | ||
| if len(key_split) != 7: | ||
| raise KeyError( | ||
| f"VirtualSites parameter key must include type/name/match: {string_key}" | ||
| ) | ||
| virtual_site_kwargs["virtual_site_type"] = key_split[4].strip() | ||
| virtual_site_kwargs["virtual_site_name"] = key_split[5].strip() | ||
| virtual_site_kwargs["virtual_site_match"] = key_split[6].strip() | ||
| else: | ||
| raise NotImplementedError() | ||
|
|
||
|
|
@@ -490,6 +524,7 @@ def submit_jobs(self, mvals, AGrad=True, AHess=True): | |
| tag=parameter_tag, | ||
| smirks=parameter_smirks, | ||
| attribute=parameter_attribute, | ||
| **virtual_site_kwargs, | ||
| ) | ||
|
|
||
| # Find the unit of the gradient parameter. | ||
|
|
@@ -525,9 +560,7 @@ def submit_jobs(self, mvals, AGrad=True, AHess=True): | |
| ) | ||
|
|
||
| if ( | ||
| self._pending_estimate_request.results( | ||
| True, polling_interval=self._options.polling_interval | ||
| )[0] is None | ||
| self._pending_estimate_request.results(False)[0] is None | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Having this blocking request here originally meant that other targets had to wait until Evaluator was done computing to even start -- which can be a while ... |
||
| ): | ||
|
|
||
| raise RuntimeError( | ||
|
|
@@ -738,6 +771,19 @@ def get(self, mvals, AGrad=True, AHess=True): | |
| AGrad = bool(AGrad) | ||
| AHess = bool(AHess) | ||
|
|
||
| # Block until the Evaluator computation is finished. This is intentionally | ||
| # placed here (not in submit_jobs) so that Work Queue tasks submitted by other | ||
| # targets can run concurrently while we wait. | ||
| estimation_results, _ = self._pending_estimate_request.results( | ||
| True, polling_interval=self._options.polling_interval | ||
| ) | ||
| if estimation_results is None: | ||
| raise RuntimeError( | ||
| "No `EvaluatorServer` could be found to retrieve results from. " | ||
| "Please double check that a server is running, and that the connection " | ||
| "settings specified in the input script are correct." | ||
| ) | ||
|
|
||
|
Comment on lines
+774
to
+786
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I moved (and added the same runtimeerror check) this from the submit_jobs function above. |
||
| # Extract the properties estimated using the unperturbed parameters. | ||
| estimated_data_set, estimated_gradients = self._extract_property_data( | ||
| self._pending_estimate_request, mvals, AGrad | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -747,7 +747,7 @@ def addff_xml(self, ffname): | |
| res = re.search(r'^[-+]?[0-9]*\.?[0-9]*([eEdD][-+]?[0-9]+)?', quantity_str) | ||
| value_str, unit_str = quantity_str[:res.end()], quantity_str[res.end():] | ||
| # LPW 2023-01-23: Behavior of parameter unit string for "evaluated" parameter is undefined. | ||
| unit_str = "" | ||
| # unit_str = "" | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixes parameter_eval |
||
| quantity_str = e.get(parameter_name) | ||
| self.offxml_unit_strs[dest] = unit_str | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -270,6 +270,11 @@ def GetVirtualSiteParameters(system): | |
| vsprm.append(_openmm.OutOfPlaneSite_getWeight12(vs)) | ||
| vsprm.append(_openmm.OutOfPlaneSite_getWeight13(vs)) | ||
| vsprm.append(_openmm.OutOfPlaneSite_getWeightCross(vs)) | ||
| elif vstype == 'LocalCoordinatesSite': | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OpenFF vsites are convereted to localcoordinatesites |
||
| vsprm.extend(_openmm.LocalCoordinatesSite_getOriginWeights(vs)) | ||
| vsprm.extend(_openmm.LocalCoordinatesSite_getXWeights(vs)) | ||
| vsprm.extend(_openmm.LocalCoordinatesSite_getYWeights(vs)) | ||
| vsprm.extend(_openmm.LocalCoordinatesSite_getLocalPosition(vs)) | ||
| return np.array(vsprm) | ||
|
|
||
| def GetDrudeParameters(system): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We will aim to get this merged and into a release ASAP, but this branch needed to be developed in tandem with the Evaluator one for testing to work (hence the same name)