Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
selecting between the `wiggle150` and `folmsbee` reference datasets, and
refactor `run_model` to run a single batched inference call across all
conformers.
- Record JAX peak device memory (`peak_bytes_in_use`) per structure in the
scaling benchmark. Surfaced as `ScalingStructureResult.peak_memory_bytes`
and as a "Peak memory vs system size" chart on the scaling UI page. Falls
back to `None` on backends that do not expose `memory_stats()`.

## Release 0.1.2

Expand Down
27 changes: 10 additions & 17 deletions docs/source/benchmarks/general/scaling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,16 @@ optimization strategies for large-scale simulations.
Dataset
-------

The scaling dataset is composed of a series of protein structures, RNA fragments,
peptides and small-molecules experimental structures taken from the `PDB <https://www.rcsb.org/>`_ databank.
They have the following ids:

* 1JRS
* 1AY3
* 1UAO
* 1P79
* 5KGZ
* 7CI3
* 1AB7
* 1BIP
* 1A5E
* 1A7M
* 2BQV
* 1J7H
* 1VSQ
The scaling dataset is a size-stratified set of protein chains taken from the
`PDB <https://www.rcsb.org/>`_. Chains were sourced from a PISCES cull list
(non-redundant at 25% sequence identity, resolution ≤ 2.0 Å, no chain breaks)
and from a curated small-protein list, then filtered to keep only sequences
that are charge-neutral at pH 7:

* **2JOF** chain A — Trp-cage TC10b mini-protein (284 atoms)
* **1R0R** chain I — turkey ovomucoid third domain (OMTKY3) (748 atoms)
* **3TXS** chain A — bacteriophage 44RR small terminase gp16 (1513 atoms)
* **4QMD** chain A — human envoplakin plakin-repeat domain (3018 atoms)
* **6U1V** chain A — TcsD acyl-CoA dehydrogenase from FK506 biosynthesis (5964 atoms)

Interpretation
--------------
Expand Down
75 changes: 48 additions & 27 deletions src/mlipaudit/benchmarks/scaling/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,16 @@
from pathlib import Path
from typing import Any

import jax
from ase.io import read as ase_read
from mlip.simulation import SimulationState
from pydantic import BaseModel, ConfigDict, NonNegativeFloat, PositiveInt
from pydantic import (
BaseModel,
ConfigDict,
NonNegativeFloat,
NonNegativeInt,
PositiveInt,
)

from mlipaudit.benchmark import (
DEFAULT_CHARGE,
Expand All @@ -47,31 +54,26 @@
}
NUM_DEV_SYSTEMS = 2

# Total charge per structure (keyed by xyz file stem). The structure set is
# discovered at runtime from the data directory; any structure not listed here
# falls back to a neutral charge of 0. All systems are treated as closed-shell
# singlets (spin multiplicity = 1). Values are sequence-based estimates from
# the FASTA at pH 7 (K+R counted +1, D+E counted -1, HIS neutral, termini
# cancel per chain); cross-checked against the observed electron parity.
STRUCTURE_CHARGES: dict[str, float] = {
"71_1jrs_leupeptin": 1.0,
"121_1ay3": -1.0,
"138_1uao_chignolin": -2.0,
"168_1p79_RNA": -4.0,
"634_5kgz": -1.0,
"1061_7ci3": -1.0,
"1432_1ab7": -6.0,
"1818_1bip": 2.0,
"2301_1a5e": -5.0,
"2803_1a7m": 7.0,
"3346_2bqv": 2.0,
"5990_1j7h_atoms_removed": -3.0,
"6713_1vsq": -7.0,
}

logger = logging.getLogger("mlipaudit")


def _peak_device_bytes() -> int | None:
"""Return the JAX backend's peak device memory in bytes since process start.

Reads `peak_bytes_in_use` from the default device's `memory_stats()`. Returns
`None` when the active backend does not report memory stats (e.g. the JAX CPU
backend on some platforms), so callers can fall through without special-casing.
"""
try:
stats = jax.devices()[0].memory_stats()
except Exception:
return None
if not stats:
return None
peak = stats.get("peak_bytes_in_use")
return int(peak) if peak is not None else None


class ScalingModelOutput(ModelOutput):
"""Model output for the scaling benchmark.

Expand All @@ -84,11 +86,18 @@ class ScalingModelOutput(ModelOutput):
for each corresponding structure, excluding the first
episode to ignore the compilation time. None if the
simulation failed.
peak_memory_bytes: A list of peak device-memory readings (in bytes) taken
after each structure's simulation. The reading is the cumulative
JAX `peak_bytes_in_use` since process start, so for the size-sorted
structure sequence the value also bounds the per-system peak from
above. `None` if the simulation failed before producing a reading,
or if the backend does not expose memory stats.
"""

structure_names: list[str]
simulation_states: list[SimulationState | None]
average_episode_times: list[float | None]
peak_memory_bytes: list[int | None]

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand All @@ -105,6 +114,12 @@ class ScalingStructureResult(BaseModel):
excluding the first episode to ignore the compilation time.
average_step_time: The average step time of the simulation,
excluding the first episode to ignore the compilation time.
peak_memory_bytes: The JAX backend's `peak_bytes_in_use` after the
simulation completed (or after it failed, when a reading was still
obtainable). Cumulative since process start, so plotting against
`num_atoms` for the size-sorted structure list traces the
high-water mark as a function of system size. `None` when the
active JAX backend does not expose memory stats.
failed: Whether the simulation failed.
"""

Expand All @@ -114,6 +129,7 @@ class ScalingStructureResult(BaseModel):
num_episodes: PositiveInt
average_episode_time: NonNegativeFloat | None = None
average_step_time: NonNegativeFloat | None = None
peak_memory_bytes: NonNegativeInt | None = None

failed: bool = False

Expand Down Expand Up @@ -205,7 +221,7 @@ class ScalingBenchmark(Benchmark):
result_class = ScalingResult
model_output_class = ScalingModelOutput

required_elements = {"N", "H", "O", "S", "P", "C"}
required_elements = {"N", "H", "O", "S", "C"}

def run_model(self) -> None:
"""Runs a short MD simulation for each structure, timing each
Expand All @@ -214,15 +230,14 @@ def run_model(self) -> None:
"""
simulation_states: list[SimulationState | None] = []
average_episode_times: list[float | None] = []
peak_memory_bytes: list[int | None] = []
for structure_name in self._structure_names:
try:
timer = Timer()
atoms = ase_read(
self.data_input_dir / self.name / f"{structure_name}.xyz"
)
atoms.info["charge"] = float(
STRUCTURE_CHARGES.get(structure_name, DEFAULT_CHARGE)
)
atoms.info["charge"] = DEFAULT_CHARGE
atoms.info["spin"] = DEFAULT_SPIN
md_engine = get_simulation_engine(
atoms=atoms,
Expand All @@ -242,11 +257,14 @@ def run_model(self) -> None:
)
simulation_states.append(None)
average_episode_times.append(None)
finally:
peak_memory_bytes.append(_peak_device_bytes())

self.model_output = ScalingModelOutput(
structure_names=self._structure_names,
simulation_states=simulation_states,
average_episode_times=average_episode_times,
peak_memory_bytes=peak_memory_bytes,
)

def analyze(self) -> ScalingResult:
Expand All @@ -263,13 +281,15 @@ def analyze(self) -> ScalingResult:

structure_results = []
for i, structure_name in enumerate(self._structure_names):
peak_memory = self.model_output.peak_memory_bytes[i]
if self.model_output.average_episode_times[i] is None:
structure_results.append(
ScalingStructureResult(
structure_name=structure_name,
num_atoms=get_molecule_size_from_name(structure_name),
num_steps=self._md_kwargs["num_steps"],
num_episodes=self._md_kwargs["num_episodes"],
peak_memory_bytes=peak_memory,
failed=True,
)
)
Expand All @@ -288,6 +308,7 @@ def analyze(self) -> ScalingResult:
num_episodes=self._md_kwargs["num_episodes"],
average_episode_time=average_episode_time,
average_step_time=average_step_time,
peak_memory_bytes=peak_memory,
)
)

Expand Down
48 changes: 48 additions & 0 deletions src/mlipaudit/ui/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ def _process_data_into_dataframe(
if structure_result.failed:
continue

peak_memory_mb = (
structure_result.peak_memory_bytes / (1024**2)
if structure_result.peak_memory_bytes is not None
else None
)
df_data.append({
"Model name": model_name,
"Structure": structure_result.structure_name,
Expand All @@ -49,6 +54,7 @@ def _process_data_into_dataframe(
"Num steps": structure_result.num_steps,
"Num episodes": structure_result.num_episodes,
"Average step time (s)": structure_result.average_step_time,
"Peak memory (MB)": peak_memory_mb,
})
return pd.DataFrame(df_data)

Expand Down Expand Up @@ -100,6 +106,40 @@ def plot_all_models_performance(df: pd.DataFrame) -> alt.Chart:
return chart


def plot_all_models_memory(df: pd.DataFrame) -> alt.Chart | None:
    """Plot peak device memory vs system size for all models together.

    When there is memory data to show, the chart is also rendered on the
    page through `st.altair_chart` before being returned.

    Args:
        df: The dataframe containing per-structure rows, including a
            "Peak memory (MB)" column. Rows where the column is null are
            dropped (the active JAX backend does not expose memory stats).
            A dataframe without that column at all (e.g. one built from zero
            rows) is treated as having no memory data.

    Returns:
        The Altair chart, or None if no rows have memory data.
    """
    memory_column = "Peak memory (MB)"

    # `dropna(subset=...)` raises KeyError for a label that is not a column,
    # which happens when the dataframe was built from an empty row list and
    # therefore has no columns. Report that as "no data" instead of crashing.
    if memory_column not in df.columns:
        return None

    df_mem = df.dropna(subset=[memory_column])
    if df_mem.empty:
        return None

    base = alt.Chart(df_mem).encode(
        x=alt.X("Num atoms:Q", title="System size (number of atoms)"),
        y=alt.Y("Peak memory (MB):Q", title="Peak device memory (MB)"),
        color=alt.Color(
            "Model name:N", title="Model", legend=alt.Legend(title="Model")
        ),
        tooltip=[
            alt.Tooltip("Model name:N", title="Model"),
            alt.Tooltip("Structure:N", title="Structure"),
            alt.Tooltip("Num atoms:Q", title="Number of atoms"),
            alt.Tooltip("Peak memory (MB):Q", title="Peak memory (MB)", format=".1f"),
        ],
    )

    chart = base.mark_point(size=60, opacity=0.7).properties(width=800, height=500)
    st.altair_chart(chart, use_container_width=True)
    return chart


def scaling_page(
data_func: Callable[[], BenchmarkResultForMultipleModels],
) -> None:
Expand Down Expand Up @@ -151,6 +191,14 @@ def scaling_page(

chart = plot_all_models_performance(df) # noqa: F841

st.markdown("## Peak device memory vs system size")
mem_chart = plot_all_models_memory(df) # noqa: F841
if mem_chart is None:
st.markdown(
"*No memory readings available — the active JAX backend does not "
"expose `memory_stats()` (e.g. CPU runs).*"
)


class ScalingPageWrapper(UIPageWrapper):
"""Page wrapper for scaling benchmark."""
Expand Down
Loading
Loading