Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ defaults:

jobs:
test:
name: Test on ${{ matrix.os }} (py=${{ matrix.python-version }}, gmx=${{ matrix.gmx-version }})
name: Test on ${{ matrix.os }} (py=${{ matrix.python-version }}, gmx=${{ matrix.gmx-version }}, pymbar=${{ matrix.pymbar-version }})
runs-on: ${{ matrix.os }}
timeout-minutes: 120
strategy:
Expand All @@ -35,6 +35,9 @@ jobs:
- "2019"
- "2024"
- "2025"
pymbar-version:
- "3"
- "4"

exclude:
# a 2022 build doesn't seem to exist for mac
Expand Down Expand Up @@ -81,7 +84,7 @@ jobs:
with:
environment-file: devtools/conda-envs/test_env.yaml
create-args: >- # beware the >- instead of |, we don't split on newlines but on spaces
python=${{ matrix.python-version }} gromacs=${{ matrix.gmx-version }}.*=nompi_dblprec*
python=${{ matrix.python-version }} gromacs=${{ matrix.gmx-version }}.*=nompi_dblprec* pymbar=${{ matrix.pymbar-version }}.*

- name: Set up conda environment without gromacs
if: matrix.gmx-version == '5' || matrix.gmx-version == '4'
Expand All @@ -90,7 +93,7 @@ jobs:
# Use the legacy env for Python 3.9/3.10, full env otherwise.
environment-file: ${{ (matrix.python-version == '3.9' || matrix.python-version == '3.10') && 'devtools/conda-envs/test_env_py39.yaml' || 'devtools/conda-envs/test_env.yaml' }}
create-args: >- # beware the >- instead of |, we don't split on newlines but on spaces
python=${{ matrix.python-version }}
python=${{ matrix.python-version }} pymbar=${{ matrix.pymbar-version }}.*

- name: Set up conda environment with single-precision gromacs (rejected versions)
if: matrix.gmx-version == '2022' || matrix.gmx-version == '2023'
Expand All @@ -101,13 +104,13 @@ jobs:
# Use the legacy env for Python 3.9/3.10, full env otherwise.
environment-file: ${{ (matrix.python-version == '3.9' || matrix.python-version == '3.10') && 'devtools/conda-envs/test_env_py39.yaml' || 'devtools/conda-envs/test_env.yaml' }}
create-args: >-
python=${{ matrix.python-version }} gromacs=${{ matrix.gmx-version }}
python=${{ matrix.python-version }} gromacs=${{ matrix.gmx-version }} pymbar=${{ matrix.pymbar-version }}.*

- name: Pin setuptools for Python <= 3.10
if: matrix.python-version == '3.9' || matrix.python-version == '3.10'
shell: bash -l {0}
run: pip install "setuptools<82"

- name: Additional info about the build
shell: bash
run: |
Expand Down
2 changes: 1 addition & 1 deletion devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies:
- zlib
- swig
- future
- pymbar =3
- pymbar
- openmm >= 8
- ambertools
- ndcctools
Expand Down
2 changes: 1 addition & 1 deletion devtools/conda-envs/test_env_py39.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies:
- zlib
- swig
- future
- pymbar =3
- pymbar
- openmm >= 8
# ambertools has no Python 3.9 builds on conda-forge
- ndcctools
Expand Down
5 changes: 2 additions & 3 deletions src/lipid.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
try:
from lxml import etree
except: pass
from pymbar import pymbar
from forcebalance.liquid import _mbar_weights
import itertools
from collections import defaultdict, namedtuple, OrderedDict
import csv
Expand Down Expand Up @@ -663,8 +663,7 @@ def get(self, mvals, AGrad=True, AHess=True):
W1 = None
if len(BPoints) > 1:
logger.info("Running MBAR analysis on %i states...\n" % len(BPoints))
mbar = pymbar.MBAR(U_kln, N_k, verbose=mbar_verbose, relative_tolerance=5.0e-8)
W1 = mbar.getWeights()
W1 = _mbar_weights(U_kln, N_k, verbose=mbar_verbose)
logger.info("Done\n")
elif len(BPoints) == 1:
W1 = np.ones((BPoints*Shots,BPoints))
Expand Down
19 changes: 14 additions & 5 deletions src/liquid.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,18 @@
try:
from lxml import etree
except: pass
from pymbar import pymbar
try:
from pymbar import pymbar # pymbar 3: MBAR lives in pymbar.pymbar submodule
_MBAR_SOLVER_KW = {} # v3: self-consistent-iteration default is fine
except ImportError:
import pymbar # pymbar 4: MBAR lives at pymbar top-level
_MBAR_SOLVER_KW = {'solver_protocol': 'robust'} # v4: default hybr diverges on some data

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

v3: used the adaptive solver by default. (https://pymbar.readthedocs.io/en/3.1.1/mbar.html#pymbar.MBAR)
v4: 'robust' uses adaptive, followed by L-BFGS (https://deepwiki.com/choderalab/pymbar/2.1-mbar-implementation) .

I'm not sure why, but plain swapping this out for the plain adaptive protocol results in tests failing, so maybe the fallback to L-BFGS is necessary. @mattwthompson I saw you managed to swap it out for adaptive in your PR, did you notice any differences?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately I don't recall, I was probably trying each option until one worked without really understanding the underlying behavior


def _mbar_weights(U_kln, N_k, verbose=False):
"""Run MBAR and return weight matrix W[n, m] (shape (sum(N_k), K))."""
mbar = pymbar.MBAR(U_kln, N_k, verbose=verbose, relative_tolerance=5.0e-8, **_MBAR_SOLVER_KW)
return mbar.weights() if hasattr(mbar, 'weights') else mbar.getWeights()

import itertools
from forcebalance.optimizer import Counter
from collections import defaultdict, namedtuple, OrderedDict
Expand Down Expand Up @@ -894,8 +905,7 @@ def get_normal(self, mvals, AGrad=True, AHess=True):
W1 = None
if len(BPoints) > 1:
logger.info("Running MBAR analysis on %i states...\n" % len(BPoints))
mbar = pymbar.MBAR(U_kln, N_k, verbose=mbar_verbose, relative_tolerance=5.0e-8)
W1 = mbar.getWeights()
W1 = _mbar_weights(U_kln, N_k, verbose=mbar_verbose)
logger.info("Done\n")
elif len(BPoints) == 1:
W1 = np.ones((Shots,1))
Expand Down Expand Up @@ -935,8 +945,7 @@ def fill_weights(weights, phase_points, mbar_points, snapshots):
mU_kln[k, m, :] = mE[mE_idx]
mU_kln[k, m, :] *= beta
if np.abs(np.std(mE)) > 1e-6 and mBSims > 1:
mmbar = pymbar.MBAR(mU_kln, mN_k, verbose=False, relative_tolerance=5.0e-8, method='self-consistent-iteration')
mW1 = mmbar.getWeights()
mW1 = _mbar_weights(mU_kln, mN_k)
elif len(mBPoints) == 1:
mW1 = np.ones((mShots,1))
mW1 /= mShots
Expand Down
Binary file added src/tests/files/test_liquid/N_k.npy
Binary file not shown.
Binary file added src/tests/files/test_liquid/U_kln.npy
Binary file not shown.
Binary file added src/tests/files/test_liquid/mbar_weights_ref.npy
Binary file not shown.
50 changes: 50 additions & 0 deletions src/tests/test_liquid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,63 @@
import os
import sys
import shutil
import numpy as np
import pytest
from forcebalance.parser import parse_inputs
from forcebalance.forcefield import FF
from forcebalance.objective import Objective
from forcebalance.optimizer import Optimizer
from forcebalance.liquid import _mbar_weights
from .__init__ import ForceBalanceTestCase, check_for_openmm

FIXTURE_DIR = os.path.join(os.path.dirname(__file__), 'files', 'test_liquid')

# U_kln.npy — shape (6, 6, 801), float64
# Reduced potential energy matrix U_kln[k, m, n] = (E_k[n] + P_m * V_k[n] * pvkj) * beta_m
# where k = source simulation index, m = evaluation state index, n = snapshot index.
# Built from the six npt_result.p pickles in files/test_liquid/single.tmp/Liquid/iter_0000/
# (water at 249.15 K/1 atm, 273.15 K/1 atm, 298.15 K/1 atm, 373.15 K/1 atm,
# 298.15 K/20 bar, 298.15 K/2000 bar; 801 snapshots each).
# Regenerate with: conda run -n <env> python tools/mbar_mre.py --save-ref
U_KLN_PATH = os.path.join(FIXTURE_DIR, 'U_kln.npy')

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From here downwards: added tests comparing direct pymbar execution. The script referenced here is attached, not sure if I should add it to the repo or if it's clutter.

mbar_mre.py


# N_k.npy — shape (6,), int64, all entries = 801
# Number of uncorrelated snapshots per simulation state; matches the first axis of U_kln.
N_K_PATH = os.path.join(FIXTURE_DIR, 'N_k.npy')

# mbar_weights_ref.npy — shape (4806, 6), float64
# MBAR weight matrix W[n, m] produced by pymbar 3.0.5 on U_kln above.
# Used as the reference for cross-version agreement checks (atol=1e-4).
# Regenerate with: conda run -n <env> python tools/mbar_mre.py --save-ref
REF_PATH = os.path.join(FIXTURE_DIR, 'mbar_weights_ref.npy')


@pytest.fixture(scope='module')
def mbar_weights():
if not os.path.exists(U_KLN_PATH) or not os.path.exists(N_K_PATH):
pytest.skip(f"MBAR fixtures not found in {FIXTURE_DIR}; run tools/mbar_mre.py --save-ref")
return _mbar_weights(np.load(U_KLN_PATH), np.load(N_K_PATH))


def test_mbar_weights_normalized(mbar_weights):
"""MBAR weight matrix columns must each sum to 1 (pymbar v3 and v4)."""
col_sums = mbar_weights.sum(axis=0)
assert np.allclose(col_sums, 1.0, atol=1e-6), (
f"MBAR weight columns do not sum to 1: {col_sums}"
)


def test_mbar_weights_match_reference(mbar_weights):
"""MBAR weights must agree with the pymbar-v3 reference within 1e-4."""
if not os.path.exists(REF_PATH):
pytest.skip(f"Reference weights not found at {REF_PATH}; run tools/mbar_mre.py --save-ref")
W_ref = np.load(REF_PATH)
assert mbar_weights.shape == W_ref.shape, (
f"Weight matrix shape mismatch: got {mbar_weights.shape}, expected {W_ref.shape}"
)
np.testing.assert_allclose(mbar_weights, W_ref, atol=1e-7, rtol=1e-2,
err_msg="MBAR weights differ from v3 reference beyond rtol=1e-2")

class TestWaterTutorial(ForceBalanceTestCase):
def setup_method(self, method):
if not check_for_openmm(): pytest.skip("No OpenMM modules found.")
Expand Down
Loading