From e4c3e9702869d1802c690dd68ba192f3dd0f20b0 Mon Sep 17 00:00:00 2001 From: Shadi Noghabi Date: Mon, 18 May 2026 16:26:47 -0700 Subject: [PATCH] shard perfetto trace writes and bound Timeline memory PiperOrigin-RevId: 917486585 --- tests/cli/perfetto_cat_test.py | 187 ++++++ tests/perf/experimental/export_v2_test.py | 45 +- tests/perf/experimental/timeline_test.py | 54 ++ tests/perf/experimental/trace_writer_test.py | 337 ++++++++++- tests/perf/metrics_test.py | 17 + tunix/cli/grpo_main.py | 1 + tunix/cli/perfetto_cat.py | 190 ++++++ tunix/perf/experimental/export.py | 13 +- tunix/perf/experimental/timeline.py | 40 +- tunix/perf/experimental/trace_writer.py | 597 ++++++++++++++----- tunix/perf/metrics.py | 12 + 11 files changed, 1321 insertions(+), 172 deletions(-) create mode 100644 tests/cli/perfetto_cat_test.py create mode 100644 tunix/cli/perfetto_cat.py diff --git a/tests/cli/perfetto_cat_test.py b/tests/cli/perfetto_cat_test.py new file mode 100644 index 000000000..ebc02505b --- /dev/null +++ b/tests/cli/perfetto_cat_test.py @@ -0,0 +1,187 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import io +import json +import os +import tempfile +from unittest import mock + +from absl.testing import absltest +from etils import epath +from tunix.cli import perfetto_cat + + +def _populate_trace_dir( + tmp_dir, + *, + sealed_shards=("trace.shard_0001.binpb", "trace.shard_0002.binpb"), + pending_content=b"PENDING_BYTES", + manifest=True, +): + """Writes a fixture trace directory with known byte payloads per file.""" + for name in sealed_shards: + epath.Path(os.path.join(tmp_dir, name)).write_bytes(name.encode()) + if pending_content is not None: + epath.Path(os.path.join(tmp_dir, "trace.shard_pending.binpb")).write_bytes( + pending_content + ) + if manifest: + epath.Path(os.path.join(tmp_dir, "trace.manifest.json")).write_text( + json.dumps( + { + "version": 1, + "shard_steps": 1, + "sealed_shards": list(sealed_shards), + "sealed_step_count": len(sealed_shards), + "pending_file": "trace.shard_pending.binpb", + } + ) + ) + + +class ListSealedShardsTest(absltest.TestCase): + + def test_uses_manifest_order_when_present(self): + with tempfile.TemporaryDirectory() as tmp_dir: + _populate_trace_dir( + tmp_dir, + sealed_shards=("trace.shard_0001.binpb", "trace.shard_0002.binpb"), + ) + shards = perfetto_cat.list_sealed_shards(epath.Path(tmp_dir)) + names = [p.name for p in shards] + self.assertEqual( + names, ["trace.shard_0001.binpb", "trace.shard_0002.binpb"] + ) + + def test_falls_back_to_glob_when_manifest_missing(self): + with tempfile.TemporaryDirectory() as tmp_dir: + _populate_trace_dir( + tmp_dir, + sealed_shards=( + "trace.shard_0002.binpb", + "trace.shard_0010.binpb", + "trace.shard_0001.binpb", + ), + manifest=False, + ) + shards = perfetto_cat.list_sealed_shards(epath.Path(tmp_dir)) + names = [p.name for p in shards] + self.assertEqual( + names, + [ + "trace.shard_0001.binpb", + "trace.shard_0002.binpb", + "trace.shard_0010.binpb", + ], + ) + + def test_falls_back_to_glob_when_manifest_lists_missing_shard(self): + """If the manifest references a shard 
that isn't on disk, glob fallback + must kick in rather than silently dropping the file from the output.""" + with tempfile.TemporaryDirectory() as tmp_dir: + _populate_trace_dir( + tmp_dir, + sealed_shards=("trace.shard_0001.binpb",), + ) + # Manifest claims a second shard that doesn't exist. + epath.Path(os.path.join(tmp_dir, "trace.manifest.json")).write_text( + json.dumps( + { + "version": 1, + "shard_steps": 1, + "sealed_shards": [ + "trace.shard_0001.binpb", + "trace.shard_0099.binpb", + ], + "sealed_step_count": 2, + "pending_file": "trace.shard_pending.binpb", + } + ) + ) + shards = perfetto_cat.list_sealed_shards(epath.Path(tmp_dir)) + # Only the on-disk shard should appear; manifest is treated as a hint. + self.assertEqual([p.name for p in shards], ["trace.shard_0001.binpb"]) + + +class ConcatTraceTest(absltest.TestCase): + + def test_concat_includes_pending_by_default(self): + with tempfile.TemporaryDirectory() as tmp_dir: + _populate_trace_dir(tmp_dir, pending_content=b"PENDING") + payload = perfetto_cat.concat_trace(epath.Path(tmp_dir)) + self.assertEqual( + payload, b"trace.shard_0001.binpb" + b"trace.shard_0002.binpb" + b"PENDING" + ) + + def test_concat_skips_pending_when_requested(self): + with tempfile.TemporaryDirectory() as tmp_dir: + _populate_trace_dir(tmp_dir, pending_content=b"PENDING") + payload = perfetto_cat.concat_trace( + epath.Path(tmp_dir), include_pending=False + ) + self.assertEqual( + payload, b"trace.shard_0001.binpb" + b"trace.shard_0002.binpb" + ) + + def test_concat_handles_missing_pending(self): + with tempfile.TemporaryDirectory() as tmp_dir: + _populate_trace_dir(tmp_dir, pending_content=None) + payload = perfetto_cat.concat_trace(epath.Path(tmp_dir)) + self.assertEqual( + payload, b"trace.shard_0001.binpb" + b"trace.shard_0002.binpb" + ) + + +class MainTest(absltest.TestCase): + + def test_main_writes_to_output_file(self): + with tempfile.TemporaryDirectory() as tmp_dir: + _populate_trace_dir(tmp_dir, pending_content=b"P") 
+ out_path = os.path.join(tmp_dir, "combined.binpb") + rc = perfetto_cat.main([tmp_dir, "-o", out_path]) + self.assertEqual(rc, 0) + self.assertEqual( + epath.Path(out_path).read_bytes(), + b"trace.shard_0001.binpb" + b"trace.shard_0002.binpb" + b"P", + ) + + def test_main_writes_to_stdout(self): + with tempfile.TemporaryDirectory() as tmp_dir: + _populate_trace_dir(tmp_dir, pending_content=b"P") + fake_stdout = io.BytesIO() + fake_wrapper = mock.MagicMock() + fake_wrapper.buffer = fake_stdout + with mock.patch("sys.stdout", fake_wrapper): + rc = perfetto_cat.main([tmp_dir]) + self.assertEqual(rc, 0) + self.assertEqual( + fake_stdout.getvalue(), + b"trace.shard_0001.binpb" + b"trace.shard_0002.binpb" + b"P", + ) + + def test_main_missing_directory_returns_error(self): + with tempfile.TemporaryDirectory() as tmp_dir: + missing = os.path.join(tmp_dir, "does-not-exist") + rc = perfetto_cat.main([missing, "-o", "/dev/null"]) + self.assertEqual(rc, 1) + + def test_main_empty_directory_returns_error(self): + with tempfile.TemporaryDirectory() as tmp_dir: + rc = perfetto_cat.main([tmp_dir, "-o", os.path.join(tmp_dir, "out.bin")]) + self.assertEqual(rc, 1) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/perf/experimental/export_v2_test.py b/tests/perf/experimental/export_v2_test.py index 195a550bc..9951aedbe 100644 --- a/tests/perf/experimental/export_v2_test.py +++ b/tests/perf/experimental/export_v2_test.py @@ -27,7 +27,6 @@ class ExportTest(parameterized.TestCase): def test_perf_metrics_export(self): - # Backward compatibility check with tempfile.TemporaryDirectory() as tmp_dir: with export.PerfMetricsExport(trace_dir=tmp_dir) as exporter: # Create dummy timeline @@ -36,9 +35,14 @@ def test_perf_metrics_export(self): time.sleep(0.001) t.export() - files = os.listdir(tmp_dir) - self.assertLen(files, 1) - self.assertStartsWith(files[0], "perfetto_trace_v2_") + # With a single committed step and the default shard size, no shard + # has been sealed yet; 
the in-flight pending file holds the only data. + files = set(os.listdir(tmp_dir)) + self.assertIn("trace.shard_pending.binpb", files) + self.assertGreater( + os.path.getsize(os.path.join(tmp_dir, "trace.shard_pending.binpb")), + 0, + ) def test_basic_metrics_export(self): with self.assertLogs(level="INFO") as logs: @@ -76,7 +80,9 @@ def test_perf_metrics_export_initialization_with_trace_writer_enabled( with export.PerfMetricsExport( enable_trace_writer=True, trace_dir=trace_dir ) as exporter: - mock_writer_cls.assert_called_once_with(expected_dir, role_to_devices=None) + mock_writer_cls.assert_called_once_with( + expected_dir, role_to_devices=None, shard_steps=None + ) # export_metrics shouldn't crash exporter.export_metrics({}) @@ -148,10 +154,33 @@ def test_from_cluster_config(self, mock_writer_cls): "rollout": ["tpu4", "tpu5", "tpu6", "tpu7"], } mock_writer_cls.assert_called_once_with( - "/test/dir", role_to_devices=expected_role_to_devices + "/test/dir", + role_to_devices=expected_role_to_devices, + shard_steps=None, ) self.assertIs(exporter._writer, mock_writer_cls.return_value) + def test_from_cluster_config_passes_trace_shard_steps(self): + """``from_cluster_config`` must forward ``trace_shard_steps`` to the writer so per-run overrides reach the on-disk shard layout.""" + with mock.patch.object( + export.trace_writer_lib, + "PerfettoTraceWriter", + autospec=True, + ) as mock_writer_cls: + mock_cluster_config = mock.create_autospec( + rl_cluster.ClusterConfig, instance=True, spec_set=True + ) + del mock_cluster_config.role_to_mesh + export.PerfMetricsExport.from_cluster_config( + mock_cluster_config, + enable_trace_writer=True, + trace_dir="/test/dir", + trace_shard_steps=42, + ) + mock_writer_cls.assert_called_once_with( + "/test/dir", role_to_devices={}, shard_steps=42 + ) + def test_safe_write_exception(self): with export.PerfMetricsExport(enable_trace_writer=False) as exporter: with mock.patch.object(exporter, "_writer", autospec=True) as mock_writer: @@ 
-192,7 +221,9 @@ def test_from_cluster_config_no_role_to_mesh(self, mock_writer_cls): trace_dir="/test/dir", ) - mock_writer_cls.assert_called_once_with("/test/dir", role_to_devices={}) + mock_writer_cls.assert_called_once_with( + "/test/dir", role_to_devices={}, shard_steps=None + ) self.assertIs(exporter._writer, mock_writer_cls.return_value) if __name__ == "__main__": diff --git a/tests/perf/experimental/timeline_test.py b/tests/perf/experimental/timeline_test.py index e84a105f2..69043962d 100644 --- a/tests/perf/experimental/timeline_test.py +++ b/tests/perf/experimental/timeline_test.py @@ -134,6 +134,60 @@ def test_stop_span_error_cases(self): with self.assertRaisesRegex(ValueError, "ended at .* before it began"): t.stop_span(1.0) + def test_drop_oldest_committed_steps_basic(self): + t = timeline.Timeline("test_tl", 0.0) + span_ids = [] + for i in range(5): + s = t.start_span(f"span{i}", float(i)) + t.stop_span(float(i) + 0.1) + span_ids.append(s.id) + t.commit_step() + self.assertLen(t.committed_steps, 5) + + # Snapshot reference before drop to confirm copy-on-write semantics. + pre_drop_ref = t.committed_steps + + dropped = t.drop_oldest_committed_steps(2) + + with self.subTest("dropped_returned_oldest_first"): + self.assertLen(dropped, 2) + self.assertIn(span_ids[0], dropped[0]) + self.assertIn(span_ids[1], dropped[1]) + + with self.subTest("post_drop_state"): + self.assertLen(t.committed_steps, 3) + self.assertIn(span_ids[2], t.committed_steps[0]) + self.assertIn(span_ids[4], t.committed_steps[-1]) + + with self.subTest("copy_on_write_preserves_prior_snapshot"): + # The reference captured before the drop must remain unchanged so + # concurrent readers iterating the old snapshot are not affected. 
+ self.assertLen(pre_drop_ref, 5) + + def test_drop_oldest_committed_steps_zero_is_noop(self): + t = timeline.Timeline("test_tl", 0.0) + t.start_span("s", 1.0) + t.stop_span(2.0) + t.commit_step() + dropped = t.drop_oldest_committed_steps(0) + self.assertEmpty(dropped) + self.assertLen(t.committed_steps, 1) + + def test_drop_oldest_committed_steps_more_than_held(self): + t = timeline.Timeline("test_tl", 0.0) + for _ in range(3): + t.start_span("s", 1.0) + t.stop_span(2.0) + t.commit_step() + dropped = t.drop_oldest_committed_steps(10) + self.assertLen(dropped, 3) + self.assertEmpty(t.committed_steps) + + def test_drop_oldest_committed_steps_negative_raises(self): + t = timeline.Timeline("test_tl", 0.0) + with self.assertRaisesRegex(ValueError, "n must be non-negative"): + t.drop_oldest_committed_steps(-1) + def test_nested_timeline_with_tags_repr(self): born = 1000.0 t = timeline.Timeline("test_tl", born) diff --git a/tests/perf/experimental/trace_writer_test.py b/tests/perf/experimental/trace_writer_test.py index e18058e93..ee2003566 100644 --- a/tests/perf/experimental/trace_writer_test.py +++ b/tests/perf/experimental/trace_writer_test.py @@ -13,6 +13,7 @@ # limitations under the License. 
+import json import os import tempfile from unittest import mock @@ -413,24 +414,24 @@ def test_perfetto_trace_writer_integration(self): writer.write_timelines(timelines) - # Check if file was created and has content - files = os.listdir(tmp_dir) - - with self.subTest("file_created"): - self.assertLen(files, 1) - - if files: - file_name = files[0] - with self.subTest("file_name_prefix"): - self.assertStartsWith(file_name, "perfetto_trace_v2_") - - with self.subTest("file_name_suffix"): - self.assertEndsWith(file_name, ".pb") - - with self.subTest("file_content"): - self.assertGreater( - os.path.getsize(os.path.join(tmp_dir, file_name)), 0 - ) + files = set(os.listdir(tmp_dir)) + + with self.subTest("pending_file_written_with_content"): + self.assertIn("trace.shard_pending.binpb", files) + self.assertGreater( + os.path.getsize(os.path.join(tmp_dir, "trace.shard_pending.binpb")), + 0, + ) + + # With the default shard size and a single committed step, no shard + # has been sealed yet -- only the pending file is present. + with self.subTest("no_sealed_shards_yet"): + sealed = [ + f + for f in files + if f.startswith("trace.shard_") and "pending" not in f + ] + self.assertEmpty(sealed) def test_perfetto_trace_writer_invalid_dir(self): # Use a file path as directory to cause failure @@ -463,6 +464,306 @@ def test_perfetto_trace_writer_timeline_with_empty_committed_steps(self): self.assertEmpty(files) +def _flush_steps(writer, timelines, num_steps, span_factory): + """Commits ``num_steps`` synchronized steps across the given timelines and + + flushes each one through ``writer.write_timelines`` so the writer's per-call + state advances naturally. + + Args: + writer: The trace writer under test. + timelines: A mapping of timeline IDs to Timeline objects. + num_steps: How many steps to commit and flush. + span_factory: Callable ``(step_index, tl_id) -> Iterable[(name, begin, + end)]`` describing the spans to add to each timeline for each step. 
+ """ + for step_idx in range(num_steps): + for tl_id, tl in timelines.items(): + for name, begin, end in span_factory(step_idx, tl_id): + tl.start_span(name, begin) + tl.stop_span(end) + tl.commit_step() + writer.write_timelines(timelines) + + +class ShardedWriteTest(absltest.TestCase): + """End-to-end tests for the sharded write protocol. + + These tests exercise real ``write_bytes`` calls (no mocking of the proto + builder) so they exercise the actual seal/pending/manifest file outputs. + """ + + def test_seal_at_shard_boundary(self): + with tempfile.TemporaryDirectory() as tmp_dir: + writer = trace_writer_lib.PerfettoTraceWriter( + trace_dir=tmp_dir, shard_steps=3 + ) + t = tracer.Timeline("host-1", 0.0) + + def factory(step_idx, _tl_id): + return [(f"s{step_idx}", float(step_idx), float(step_idx) + 0.1)] + + # First two steps: no seal yet, only pending. + _flush_steps(writer, {"host-1": t}, num_steps=2, span_factory=factory) + files = set(os.listdir(tmp_dir)) + self.assertNotIn("trace.shard_0001.binpb", files) + self.assertIn("trace.shard_pending.binpb", files) + self.assertEmpty(writer.sealed_shards) + + # Third step crosses the boundary -- one shard is sealed. 
+ _flush_steps(writer, {"host-1": t}, num_steps=1, span_factory=factory) + files = set(os.listdir(tmp_dir)) + with self.subTest("shard_0001_sealed"): + self.assertIn("trace.shard_0001.binpb", files) + self.assertEqual(writer.sealed_shards, ["trace.shard_0001.binpb"]) + with self.subTest("sealed_steps_freed_from_timeline_memory"): + self.assertEmpty(t.committed_steps) + with self.subTest("manifest_reflects_seal"): + manifest_path = os.path.join(tmp_dir, "trace.manifest.json") + with open(manifest_path) as f: + manifest = json.load(f) + self.assertEqual(manifest["version"], 1) + self.assertEqual(manifest["shard_steps"], 3) + self.assertEqual(manifest["sealed_step_count"], 3) + self.assertEqual( + manifest["sealed_shards"], ["trace.shard_0001.binpb"] + ) + + def test_multiple_seals_in_one_flush(self): + """A long pause between flushes can produce multiple shards in one call.""" + with tempfile.TemporaryDirectory() as tmp_dir: + writer = trace_writer_lib.PerfettoTraceWriter( + trace_dir=tmp_dir, shard_steps=2 + ) + t = tracer.Timeline("host-1", 0.0) + # Commit 5 steps without flushing in between. + for i in range(5): + t.start_span(f"s{i}", float(i)) + t.stop_span(float(i) + 0.1) + t.commit_step() + writer.write_timelines({"host-1": t}) + self.assertEqual( + writer.sealed_shards, + ["trace.shard_0001.binpb", "trace.shard_0002.binpb"], + ) + # Two shards * 2 steps = 4 sealed; one remains in pending. + self.assertLen(t.committed_steps, 1) + self.assertIn( + "trace.shard_pending.binpb", set(os.listdir(tmp_dir)) + ) + + def test_pending_removed_when_no_unsealed_data(self): + """After everything has been sealed, the pending file is removed so a + + naive ``cat trace.shard_*.binpb`` is a complete trace. 
+ """ + with tempfile.TemporaryDirectory() as tmp_dir: + writer = trace_writer_lib.PerfettoTraceWriter( + trace_dir=tmp_dir, shard_steps=1 + ) + t = tracer.Timeline("host-1", 0.0) + for i in range(3): + t.start_span(f"s{i}", float(i)) + t.stop_span(float(i) + 0.1) + t.commit_step() + writer.write_timelines({"host-1": t}) + files = set(os.listdir(tmp_dir)) + self.assertNotIn("trace.shard_pending.binpb", files) + self.assertLen(writer.sealed_shards, 3) + + +class LaneAndUuidStabilityTest(absltest.TestCase): + """Lane indices and timeline UUIDs must be stable across shards so a + + concatenated trace shows consistent track layout. + """ + + def test_lane_count_only_grows_across_shards(self): + with tempfile.TemporaryDirectory() as tmp_dir: + writer = trace_writer_lib.PerfettoTraceWriter( + trace_dir=tmp_dir, shard_steps=1 + ) + t = tracer.Timeline("overlap_tl", 0.0) + + # Step 0: two overlapping spans -> 2 lanes. + t.start_span("a", 0.0) + t.start_span("b", 0.1) + t.stop_span(1.0) # ends 'b' + t.stop_span(1.1) # ends 'a' + t.commit_step() + writer.write_timelines({"overlap_tl": t}) + + with self.subTest("lanes_after_first_seal"): + self.assertLen(writer._lane_busy_until["overlap_tl"], 2) # pylint: disable=protected-access + + # Step 1: three overlapping spans -> grows to 3 lanes. + t.start_span("c", 2.0) + t.start_span("d", 2.05) + t.start_span("e", 2.1) + t.stop_span(3.0) + t.stop_span(3.1) + t.stop_span(3.2) + t.commit_step() + writer.write_timelines({"overlap_tl": t}) + + with self.subTest("lanes_after_second_seal_only_grow"): + self.assertLen(writer._lane_busy_until["overlap_tl"], 3) # pylint: disable=protected-access + + # Step 2: one span -> lane count stays at 3 (never shrinks). 
+ t.start_span("f", 4.0) + t.stop_span(5.0) + t.commit_step() + writer.write_timelines({"overlap_tl": t}) + + with self.subTest("lanes_persist_at_max"): + self.assertLen(writer._lane_busy_until["overlap_tl"], 3) # pylint: disable=protected-access + + def test_timeline_uuids_stable_across_seals(self): + with tempfile.TemporaryDirectory() as tmp_dir: + writer = trace_writer_lib.PerfettoTraceWriter( + trace_dir=tmp_dir, shard_steps=1 + ) + timelines = { + "host-1": tracer.Timeline("host-1", 0.0), + "host-2": tracer.Timeline("host-2", 0.0), + } + + def factory(step_idx, _tl_id): + return [(f"s{step_idx}", float(step_idx), float(step_idx) + 0.1)] + + _flush_steps(writer, timelines, num_steps=1, span_factory=factory) + uuids_after_first_seal = dict(writer._timeline_uuids) # pylint: disable=protected-access + + _flush_steps(writer, timelines, num_steps=4, span_factory=factory) + uuids_after_more_seals = dict(writer._timeline_uuids) # pylint: disable=protected-access + + self.assertEqual(uuids_after_first_seal, uuids_after_more_seals) + + def test_sealed_shard_byte_content_does_not_change_after_subsequent_seals( + self, + ): + """An already-sealed shard file must never be rewritten.""" + with tempfile.TemporaryDirectory() as tmp_dir: + writer = trace_writer_lib.PerfettoTraceWriter( + trace_dir=tmp_dir, shard_steps=1 + ) + t = tracer.Timeline("host-1", 0.0) + + t.start_span("first", 0.0) + t.stop_span(0.1) + t.commit_step() + writer.write_timelines({"host-1": t}) + + shard1_path = os.path.join(tmp_dir, "trace.shard_0001.binpb") + first_bytes = open(shard1_path, "rb").read() + first_mtime = os.path.getmtime(shard1_path) + + # Generate more activity and additional seals. 
+ for i in range(1, 5): + t.start_span(f"s{i}", float(i)) + t.stop_span(float(i) + 0.1) + t.commit_step() + writer.write_timelines({"host-1": t}) + + with self.subTest("shard_bytes_unchanged"): + self.assertEqual(open(shard1_path, "rb").read(), first_bytes) + with self.subTest("shard_mtime_unchanged"): + self.assertAlmostEqual( + os.path.getmtime(shard1_path), first_mtime, delta=0.001 + ) + + +class ShardStepsResolutionTest(absltest.TestCase): + + def test_env_var_overrides_arg(self): + with mock.patch.dict(os.environ, {"TUNIX_TRACE_SHARD_STEPS": "7"}): + with tempfile.TemporaryDirectory() as tmp_dir: + writer = trace_writer_lib.PerfettoTraceWriter( + trace_dir=tmp_dir, shard_steps=100 + ) + self.assertEqual(writer.shard_steps, 7) + + def test_arg_used_when_env_var_unset(self): + # Sanitize the env var even if the host has it set. + with mock.patch.dict(os.environ, {}, clear=False): + os.environ.pop("TUNIX_TRACE_SHARD_STEPS", None) + with tempfile.TemporaryDirectory() as tmp_dir: + writer = trace_writer_lib.PerfettoTraceWriter( + trace_dir=tmp_dir, shard_steps=42 + ) + self.assertEqual(writer.shard_steps, 42) + + def test_default_when_neither_set(self): + with mock.patch.dict(os.environ, {}, clear=False): + os.environ.pop("TUNIX_TRACE_SHARD_STEPS", None) + with tempfile.TemporaryDirectory() as tmp_dir: + writer = trace_writer_lib.PerfettoTraceWriter(trace_dir=tmp_dir) + self.assertEqual(writer.shard_steps, 100) + + def test_invalid_env_var_is_ignored(self): + with mock.patch.dict(os.environ, {"TUNIX_TRACE_SHARD_STEPS": "not-a-number"}): + with tempfile.TemporaryDirectory() as tmp_dir: + writer = trace_writer_lib.PerfettoTraceWriter( + trace_dir=tmp_dir, shard_steps=11 + ) + self.assertEqual(writer.shard_steps, 11) + + def test_zero_or_negative_env_var_is_ignored(self): + with mock.patch.dict(os.environ, {"TUNIX_TRACE_SHARD_STEPS": "0"}): + with tempfile.TemporaryDirectory() as tmp_dir: + writer = trace_writer_lib.PerfettoTraceWriter( + trace_dir=tmp_dir, 
shard_steps=11 + ) + self.assertEqual(writer.shard_steps, 11) + + def test_invalid_arg_raises(self): + with mock.patch.dict(os.environ, {}, clear=False): + os.environ.pop("TUNIX_TRACE_SHARD_STEPS", None) + with tempfile.TemporaryDirectory() as tmp_dir: + with self.assertRaisesRegex( + ValueError, "shard_steps must be a positive integer" + ): + trace_writer_lib.PerfettoTraceWriter( + trace_dir=tmp_dir, shard_steps=0 + ) + + +class MemoryBoundednessTest(absltest.TestCase): + + def test_timeline_memory_stays_bounded_across_many_steps(self): + """``Timeline._committed_steps`` must not grow unboundedly when the writer is actively sealing shards.""" + with tempfile.TemporaryDirectory() as tmp_dir: + shard_steps = 4 + writer = trace_writer_lib.PerfettoTraceWriter( + trace_dir=tmp_dir, shard_steps=shard_steps + ) + t = tracer.Timeline("host-1", 0.0) + for i in range(50): + t.start_span(f"s{i}", float(i)) + t.stop_span(float(i) + 0.1) + t.commit_step() + writer.write_timelines({"host-1": t}) + with self.subTest(f"after_step_{i}"): + # At most ``shard_steps - 1`` steps live in memory after each flush + # (anything reaching the boundary gets sealed and dropped). + self.assertLessEqual(len(t.committed_steps), shard_steps - 1) + + def test_drop_clears_lane_assignment_entries_for_dropped_spans(self): + with tempfile.TemporaryDirectory() as tmp_dir: + writer = trace_writer_lib.PerfettoTraceWriter( + trace_dir=tmp_dir, shard_steps=1 + ) + t = tracer.Timeline("host-1", 0.0) + for i in range(5): + t.start_span(f"s{i}", float(i)) + t.stop_span(float(i) + 0.1) + t.commit_step() + writer.write_timelines({"host-1": t}) + # All spans are sealed and dropped; the assignment cache should be + # empty (only pending spans remain, of which there are none). 
+ self.assertEmpty(writer._lane_assignment.get("host-1", {})) # pylint: disable=protected-access + + class NoopTraceWriterTest(absltest.TestCase): def test_noop_trace_writer_write_timelines(self): diff --git a/tests/perf/metrics_test.py b/tests/perf/metrics_test.py index 83721bae5..4778fa44f 100644 --- a/tests/perf/metrics_test.py +++ b/tests/perf/metrics_test.py @@ -27,6 +27,23 @@ def test_perf_metrics_options_defaults(self): self.assertEqual(options.custom_export_fn_path_v2, "") self.assertTrue(options.enable_trace_writer) self.assertEqual(options.trace_dir, "") + self.assertEqual(options.trace_shard_steps, 100) + + def test_perf_metrics_options_trace_shard_steps_override(self): + options = metrics.PerfMetricsOptions(trace_shard_steps=25) + self.assertEqual(options.trace_shard_steps, 25) + + @parameterized.named_parameters( + dict(testcase_name="zero", trace_shard_steps=0), + dict(testcase_name="negative", trace_shard_steps=-1), + ) + def test_perf_metrics_options_trace_shard_steps_invalid( + self, trace_shard_steps + ): + with self.assertRaisesRegex( + ValueError, "trace_shard_steps must be a positive integer" + ): + metrics.PerfMetricsOptions(trace_shard_steps=trace_shard_steps) @parameterized.named_parameters( dict( diff --git a/tunix/cli/grpo_main.py b/tunix/cli/grpo_main.py index 6c651c5e7..b012e9b56 100644 --- a/tunix/cli/grpo_main.py +++ b/tunix/cli/grpo_main.py @@ -467,6 +467,7 @@ def create_perf_config(self, cluster_config: rl_cluster_lib.ClusterConfig): cluster_config=cluster_config, enable_trace_writer=perf_metrics_options.enable_trace_writer, trace_dir=perf_metrics_options.trace_dir, + trace_shard_steps=perf_metrics_options.trace_shard_steps, ).export_metrics ) return perf_config diff --git a/tunix/cli/perfetto_cat.py b/tunix/cli/perfetto_cat.py new file mode 100644 index 000000000..f22e5ae4c --- /dev/null +++ b/tunix/cli/perfetto_cat.py @@ -0,0 +1,190 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Concatenate sharded perfetto trace files into a single trace. + +Long-running jobs split their perfetto trace into a directory of sharded files +plus a single in-flight pending file:: + + trace.shard_0001.binpb + trace.shard_0002.binpb + ... + trace.shard_pending.binpb + trace.manifest.json + +Perfetto's ``TracePacket`` format is concatenable, so a complete trace is +``cat trace.shard_[0-9]*.binpb trace.shard_pending.binpb``. This module provides a +small CLI wrapper that reads the manifest (for ordering) and emits a single +file -- handy when the trace directory lives on remote storage and you want +one local file to drop into https://ui.perfetto.dev. + +Usage:: + + python -m tunix.cli.perfetto_cat TRACE_DIR # writes to stdout + python -m tunix.cli.perfetto_cat TRACE_DIR -o trace.binpb # writes to file + python -m tunix.cli.perfetto_cat TRACE_DIR --no-pending # sealed only + +Remote paths supported by ``etils.epath`` (e.g. ``gs://...``) are accepted for +the input directory. 
+""" + +from __future__ import annotations + +import argparse +import json +import re +import sys + +from etils import epath + + +_MANIFEST_FILE = "trace.manifest.json" +_PENDING_FILE = "trace.shard_pending.binpb" +_SHARD_FILE_RE = re.compile(r"^trace\.shard_(\d{4,})\.binpb$") + + +def _shard_index(name: str) -> int | None: + """Returns the numeric shard index for a sealed-shard filename, or None.""" + m = _SHARD_FILE_RE.match(name) + return int(m.group(1)) if m else None + + +def list_sealed_shards(trace_dir: epath.Path) -> list[epath.Path]: + """Lists sealed shard files in deterministic concatenation order. + + Prefers the manifest's sealed-shard list when one is present. Falls back to + a glob-based listing sorted by the numeric shard index, which is the same + ordering the writer produces. This lets the CLI work even on a directory + whose manifest is missing or corrupt. + + Args: + trace_dir: Directory containing the sharded trace files. + + Returns: + The sealed shards in concatenation order. + """ + manifest_path = trace_dir / _MANIFEST_FILE + if manifest_path.exists(): + try: + payload = json.loads(manifest_path.read_text()) + names = payload.get("sealed_shards") or [] + shards = [trace_dir / name for name in names] + if all(p.exists() for p in shards): + return shards + except Exception: # pylint: disable=broad-except + # Fall through to glob-based discovery. + pass + + found: list[tuple[int, epath.Path]] = [] + for child in trace_dir.iterdir(): + idx = _shard_index(child.name) + if idx is None: + continue + found.append((idx, child)) + found.sort(key=lambda item: item[0]) + return [p for _, p in found] + + +def concat_trace( + trace_dir: epath.Path, + *, + include_pending: bool = True, +) -> bytes: + """Concatenates all trace fragments under ``trace_dir`` into a single blob. + + Args: + trace_dir: Directory containing the sharded trace. 
+ include_pending: When True, append ``trace.shard_pending.binpb`` (if it + exists) so the result contains in-flight data too. + + Returns: + The concatenated trace bytes. + """ + parts: list[bytes] = [] + for shard in list_sealed_shards(trace_dir): + parts.append(shard.read_bytes()) + if include_pending: + pending = trace_dir / _PENDING_FILE + if pending.exists(): + parts.append(pending.read_bytes()) + return b"".join(parts) + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="python -m tunix.cli.perfetto_cat", + description=( + "Concatenate sharded perfetto trace files into a single binary" + " trace suitable for https://ui.perfetto.dev." + ), + ) + parser.add_argument( + "trace_dir", + help=( + "Directory containing trace.shard_NNNN.binpb files (and optional" + " trace.shard_pending.binpb). Remote paths supported by etils.epath" + " are accepted (e.g. gs://bucket/path)." + ), + ) + parser.add_argument( + "-o", + "--output", + default="-", + help=( + "Destination file. Use '-' (the default) to write to stdout. Remote" + " paths supported by etils.epath are accepted." + ), + ) + parser.add_argument( + "--no-pending", + action="store_true", + help=( + "Skip the in-flight pending file; emit only sealed shards. Useful" + " when copying a completed trace; not needed for the live view." 
+ ), + ) + return parser + + +def main(argv: list[str] | None = None) -> int: + parser = _build_parser() + args = parser.parse_args(argv) + trace_dir = epath.Path(args.trace_dir) + if not trace_dir.exists(): + print(f"Trace directory not found: {trace_dir}", file=sys.stderr) + return 1 + if not trace_dir.is_dir(): + print(f"Not a directory: {trace_dir}", file=sys.stderr) + return 1 + + payload = concat_trace(trace_dir, include_pending=not args.no_pending) + if not payload: + print( + f"No trace files found under {trace_dir}; nothing to concatenate.", + file=sys.stderr, + ) + return 1 + + if args.output == "-": + # Use the underlying buffer so we don't accidentally re-encode bytes. + sys.stdout.buffer.write(payload) + sys.stdout.buffer.flush() + else: + out_path = epath.Path(args.output) + out_path.write_bytes(payload) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tunix/perf/experimental/export.py b/tunix/perf/experimental/export.py index d84660637..f8b0fcb6e 100644 --- a/tunix/perf/experimental/export.py +++ b/tunix/perf/experimental/export.py @@ -49,6 +49,7 @@ def from_cluster_config( cluster_config: rl_cluster.ClusterConfig, enable_trace_writer: bool = True, trace_dir: str | None = None, + trace_shard_steps: int | None = None, ) -> PerfMetricsExport: """Creates an instance from a ClusterConfig. @@ -56,6 +57,9 @@ def from_cluster_config( cluster_config: The ClusterConfig to extract role_to_mesh from. enable_trace_writer: Whether to initialize the trace writer. trace_dir: The directory to write the Perfetto trace files to. + trace_shard_steps: Number of committed steps per sealed perfetto trace + shard. ``None`` defers to the trace writer's resolution path (env var + ``TUNIX_TRACE_SHARD_STEPS`` then the built-in default). Returns: A new PerfMetricsExport instance configured with role to device mappings. 
@@ -72,6 +76,7 @@ def from_cluster_config( enable_trace_writer=enable_trace_writer, trace_dir=trace_dir, role_to_devices=role_to_devices, + trace_shard_steps=trace_shard_steps, ) def __init__( @@ -80,6 +85,7 @@ def __init__( enable_trace_writer: bool = True, trace_dir: str | None = None, role_to_devices: Mapping[str, Any] | None = None, + trace_shard_steps: int | None = None, ): """Initializes the instance. @@ -91,13 +97,18 @@ def __init__( used. role_to_devices: An optional mapping from role names to their assigned devices, passed to the trace writer. + trace_shard_steps: Number of committed steps per sealed perfetto trace + shard. ``None`` defers to the trace writer's resolution path (env var + ``TUNIX_TRACE_SHARD_STEPS`` then the built-in default). """ self._trace_writer_enabled = enable_trace_writer self._writer: trace_writer_lib.TraceWriter if enable_trace_writer: resolved_trace_dir = trace_dir or DEFAULT_TRACE_DIR self._writer = trace_writer_lib.PerfettoTraceWriter( - resolved_trace_dir, role_to_devices=role_to_devices + resolved_trace_dir, + role_to_devices=role_to_devices, + shard_steps=trace_shard_steps, ) # We need to keep max_workers = 1 to serialize writes self._executor = concurrent.futures.ThreadPoolExecutor( diff --git a/tunix/perf/experimental/timeline.py b/tunix/perf/experimental/timeline.py index b0967dc54..717ea8673 100644 --- a/tunix/perf/experimental/timeline.py +++ b/tunix/perf/experimental/timeline.py @@ -266,6 +266,42 @@ def commit_step(self) -> None: self._cur_step = {} + def drop_oldest_committed_steps(self, n: int) -> list[dict[int, Span]]: + """Removes the ``n`` oldest committed step dicts from history. + + This is the memory-release path for consumers that have durably persisted + the dropped steps somewhere else (e.g., a sealed trace file on disk). Once + a step is dropped, it cannot be recovered from the timeline. + + Args: + n: Number of oldest committed step dicts to remove. Must be non-negative. 
+ If ``n`` exceeds the number of committed steps currently held, all + committed steps are dropped and the method returns whatever was held. + + Returns: + The list of step dicts that were removed, in order from oldest to most + recent. Callers that need to inspect or write out the dropped steps can + use the returned list; callers that just want to free memory can ignore + it. + + Raises: + ValueError: If ``n`` is negative. + """ + if n < 0: + raise ValueError(f"n must be non-negative, got {n}") + if n == 0: + return [] + with self._lock: + n = min(n, len(self._committed_steps)) + if n == 0: + return [] + dropped = self._committed_steps[:n] + # Copy-on-write to preserve the lock-free read invariant on + # `committed_steps` (a concurrent reader may still hold a reference to + # the previous list). + self._committed_steps = self._committed_steps[n:] + return dropped + def __repr__(self) -> str: parts = [f"Timeline({self.id}, {self.born:.6f})\n"] with self._lock: @@ -404,7 +440,3 @@ def span( for timeline in self._timelines: timeline.span(name, thread_span_begin, waitlist, tags=tags) - - -# TODO(noghabi): remove Spans items from timeline after they are processed. -# Currently, they are never removed. diff --git a/tunix/perf/experimental/trace_writer.py b/tunix/perf/experimental/trace_writer.py index a53fe6f3a..1dee9981c 100644 --- a/tunix/perf/experimental/trace_writer.py +++ b/tunix/perf/experimental/trace_writer.py @@ -12,7 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Trace writer implementations and helper functions.""" +"""Trace writer implementations and helper functions. + +The Perfetto trace writer here writes long runs as a directory of immutable +shard files (``trace.shard_NNNN.binpb``) plus a single in-flight pending file +(``trace.shard_pending.binpb``): + +* Every ``shard_steps`` committed steps observed by the writer, the writer + seals one shard. 
A sealed shard is written exactly once via + ``write_bytes()`` and never rewritten. This is compatible with + immutable-object stores (GCS) and keeps the full trace history across long + runs. +* The pending file is rewritten on every flush and contains everything since + the last seal. It is what users see "live" while a run is in progress. +* Sealed shards' steps are dropped from ``Timeline`` memory immediately after + sealing, so memory stays bounded to a few shards' worth of spans regardless + of run length. + +A ``trace.manifest.json`` companion file tracks which shards have been sealed +so far. Perfetto's TracePacket format is concatenable, so a complete trace can +be reassembled with:: + + cat trace.shard_*.binpb trace.shard_pending.binpb > trace.binpb + +or via ``python -m tunix.cli.perfetto_cat ``. + +Lane assignment is streaming: each new span on a timeline is greedily placed +on the lowest-index lane whose last span ends at or before this span's begin +time, persisting the per-timeline lane busy-times across shards. Lane indices +(and the perfetto UUIDs derived from them) are therefore stable across shards +-- a span placed on lane 0 of timeline T in shard N stays on lane 0 of +timeline T in every subsequent shard, and the concatenated trace shows a +consistent track layout. +""" from __future__ import annotations @@ -20,6 +52,8 @@ from collections.abc import Iterable, Mapping import dataclasses import itertools +import json +import os import time from typing import Any @@ -38,6 +72,50 @@ _UUID_OFFSET = 100_000 # Offset for lane UUIDs. +_DEFAULT_SHARD_STEPS = 100 +_SHARD_STEPS_ENV = "TUNIX_TRACE_SHARD_STEPS" + +_SHARD_FILE_FMT = "trace.shard_{index:04d}.binpb" +_PENDING_FILE = "trace.shard_pending.binpb" +_MANIFEST_FILE = "trace.manifest.json" +_MANIFEST_VERSION = 1 + + +def _resolve_shard_steps(shard_steps: int | None) -> int: + """Resolves the effective shard size from arg + env var override. 
+ + Args: + shard_steps: The value requested by the caller, or ``None`` to fall back to + the env var / default. + + Returns: + A positive integer to use as the number of committed steps per sealed + shard. The env var ``TUNIX_TRACE_SHARD_STEPS``, when set to a parseable + positive integer, takes precedence over the caller-provided value to give + operators a uniform override across all writers in a run. + """ + env_val = os.environ.get(_SHARD_STEPS_ENV) + if env_val is not None: + try: + parsed = int(env_val) + except ValueError: + logging.warning( + "%s=%r is not a valid integer; ignoring.", _SHARD_STEPS_ENV, env_val + ) + else: + if parsed >= 1: + return parsed + logging.warning( + "%s=%d is not >= 1; ignoring.", _SHARD_STEPS_ENV, parsed + ) + if shard_steps is None: + return _DEFAULT_SHARD_STEPS + if shard_steps < 1: + raise ValueError( + f"shard_steps must be a positive integer, got {shard_steps!r}." + ) + return shard_steps + def _create_span_name(name: str, tags: Mapping[str, Any]) -> str: """Creates a descriptive name for the span based on its tags.""" @@ -76,43 +154,6 @@ def _create_span_name(name: str, tags: Mapping[str, Any]) -> str: return name -def _assign_lanes( - spans: Iterable[timeline.Span], -) -> tuple[Mapping[int, int], int]: - """Assigns lanes to spans to handle overlaps. - - Perfetto requires spans on the same track to be strictly nested (no arbitrary - overlaps). This function assigns a lane index to each span such that spans - in the same lane do not overlap. - - Args: - spans: An iterable of spans to assign to lanes. - - Returns: - A tuple (`lane_by_span_id`, `num_lanes`), where: - `lane_by_span_id`: A dictionary mapping span IDs to their assigned lane - index. - `num_lanes`: The total number of lanes required. 
- """ - sorted_spans = sorted(spans, key=lambda s: (s.begin, s.id)) - lanes_end_times = [] - lane_by_span_id = {} - - for s in sorted_spans: - placed = False - for lane_idx, lane_end in enumerate(lanes_end_times): - if lane_end <= s.begin: - lanes_end_times[lane_idx] = s.end - lane_by_span_id[s.id] = lane_idx - placed = True - break - if not placed: - lane_by_span_id[s.id] = len(lanes_end_times) - lanes_end_times.append(s.end) - - return lane_by_span_id, len(lanes_end_times) - - class TraceWriter(abc.ABC): """An abstract base class for writing traces.""" @@ -130,10 +171,10 @@ def write_timelines(self, timelines: Mapping[str, Timeline]) -> None: @dataclasses.dataclass class TrackInfo: - """Information about a track. + """Information about a parent track in the perfetto layout. Attributes: - name: The name of the track. + name: The display name of the track. uuid: The unique identifier for the track in Perfetto. """ @@ -142,48 +183,97 @@ class TrackInfo: class PerfettoTraceWriter(TraceWriter): - """A writer for Perfetto trace events.""" + """A writer for Perfetto trace events. + + Writes long runs as a directory of immutable sharded files plus an in-flight + pending file. See the module docstring for the file layout and the rationale. + + Constructor parameters: + + trace_dir: Local path or remote URI (e.g. ``gs://...``). The directory is + created if it does not exist. If creation fails the writer enters a + no-op mode -- tracing is best-effort and never crashes the application. + role_to_devices: Optional mapping from role names to the device IDs that + handle that role. Used to label per-device tracks ("Actor Cluster", + "Rollout Cluster", etc.). + shard_steps: Committed steps per sealed shard. ``None`` defers to the + ``TUNIX_TRACE_SHARD_STEPS`` env var, then to a built-in default. The + env var, if valid, wins over an explicit caller value to provide a + uniform per-run override for operators. 
+ """ def __init__( self, trace_dir: str, role_to_devices: Mapping[str, Any] | None = None, + *, + shard_steps: int | None = None, ): - """Initializes the instance. - - Args: - trace_dir: The directory to export trace files to. This path can be a - local Linux path or a remote storage path (e.g. gs://). - role_to_devices: An optional mapping from role names to their assigned - devices. - """ - self._trace_dir = trace_dir + self._shard_steps = _resolve_shard_steps(shard_steps) + self._trace_dir_raw = trace_dir self._role_to_devices = ( dict(role_to_devices) if role_to_devices is not None else {} ) + + # Parent track grouping (e.g. "Host - Main threads", "Actor Cluster"), + # populated lazily on first observation of each timeline. self._track_info: dict[str, TrackInfo] = {} self._timeline_tracks: dict[str, str] = {} self._timeline_uuids: dict[str, int] = {} - self._trace_file_path = None + # Per-timeline streaming lane assignment state. ``_lane_busy_until`` + # records, for each timeline, the end-time of the latest span placed on + # each lane; it only grows in length as new lanes are needed. New spans + # reuse the lowest-index lane that is free at their begin-time, so lane + # indices are stable across seals. + self._lane_busy_until: dict[str, list[float]] = {} + self._lane_assignment: dict[str, dict[int, int]] = {} + self._lane_descriptors_emitted: dict[str, int] = {} + + # Step accounting for shard sealing. ``_unsealed_step_count`` tracks how + # many committed step boundaries have occurred since the most recent seal, + # using the maximum delta across timelines as the synchronized count (all + # timelines created at tracer init commit in lockstep, so this collapses + # to the per-step delta in practice). 
+ self._observed_committed_count: dict[str, int] = {} + self._unsealed_step_count = 0 + self._sealed_step_count = 0 + self._next_shard_index = 1 + self._sealed_shards: list[str] = [] + + self._trace_dir: epath.Path | None = None + self._pending_path: epath.Path | None = None + self._manifest_path: epath.Path | None = None try: - trace_dir_path = epath.Path(self._trace_dir) + trace_dir_path = epath.Path(self._trace_dir_raw) trace_dir_path.mkdir(parents=True, exist_ok=True) - trace_file_name = f"perfetto_trace_v2_{int(time.time())}.pb" - self._trace_file_path = trace_dir_path / trace_file_name + self._trace_dir = trace_dir_path + self._pending_path = trace_dir_path / _PENDING_FILE + self._manifest_path = trace_dir_path / _MANIFEST_FILE logging.info( - "Initializing perfetto trace writer at: %s", self._trace_file_path + "Initializing perfetto trace writer at: %s (shard_steps=%d)", + self._trace_dir, + self._shard_steps, ) except Exception: # pylint: disable=broad-except # Catching broad exceptions to ensure that failures in trace - # initialization (e.g., due to file system errors, permissions, etc.) do - # not crash the application. Tracing is best-effort. + # initialization (e.g., due to file system errors, permissions, etc.) + # do not crash the application. Tracing is best-effort. logging.exception( - "Failed to initialize perfetto trace writer in directory %r. Skipping" - " trace dumping for this run.", - self._trace_dir, + "Failed to initialize perfetto trace writer in directory %r." 
+ " Skipping trace dumping for this run.", + self._trace_dir_raw, ) - self._trace_file_path = None + + @property + def shard_steps(self) -> int: + """The number of committed steps per sealed shard.""" + return self._shard_steps + + @property + def sealed_shards(self) -> list[str]: + """File names (relative to the trace dir) of all sealed shards so far.""" + return list(self._sealed_shards) def _get_device_track_name(self, tl_id: str) -> str | None: """Gets a formatted track name for a device timeline. @@ -208,7 +298,6 @@ def _get_device_track_name(self, tl_id: str) -> str | None: for device in devices: device_str = timeline_utils.generate_device_timeline_id(device) if device_str == base_tl_id and role not in cluster_roles: - cluster_roles.append(role) if cluster_roles: @@ -219,32 +308,44 @@ def _get_device_track_name(self, tl_id: str) -> str | None: return f"{', '.join(camel_roles)} Cluster" return None - def write(self, builder: TraceProtoBuilder) -> None: - """Writes the built trace to the file.""" - if self._trace_file_path is None: - return + def _safe_write_bytes(self, path: epath.Path, payload: bytes) -> bool: + """Writes ``payload`` to ``path``, logging and swallowing failures. + Returns True on success, False on failure. Tracing is best-effort and + never raises into the caller. + """ try: - # TODO: b/480134569 - see if file writing is a bottleneck and explore - # faster alternatives (e.g., keeping in memory and writing at the end). 
- self._trace_file_path.write_bytes(builder.serialize()) + path.write_bytes(payload) + return True + except Exception: # pylint: disable=broad-except + logging.exception("Failed to write trace bytes to %s", path) + return False + + def _update_manifest(self) -> None: + """Updates the on-disk manifest summarizing the trace directory layout.""" + if self._manifest_path is None: + return + payload = { + "version": _MANIFEST_VERSION, + "shard_steps": self._shard_steps, + "sealed_shards": list(self._sealed_shards), + "sealed_step_count": self._sealed_step_count, + "pending_file": _PENDING_FILE, + } + try: + self._manifest_path.write_text(json.dumps(payload, indent=2) + "\n") except Exception: # pylint: disable=broad-except - # Catching broad exceptions to ensure that failures in trace - # serialization or writing do not crash the application. Tracing is - # best-effort. logging.exception( - "Failed to write to trace file: %s", self._trace_file_path + "Failed to write trace manifest at %s", self._manifest_path ) - def _init_tracks( - self, timelines: Mapping[str, Timeline] - ) -> None: - """Initializes track info for timelines. + def _init_tracks(self, timelines: Mapping[str, Timeline]) -> None: + """Populates parent track info for any newly-seen timelines. - Args: - timelines: A mapping of timeline IDs to timelines. + This is idempotent: timelines already registered are skipped. Timelines + that have never carried any non-empty committed step yet are also skipped + so we don't allocate a track for a timeline that may turn out to be unused. """ - for tl_id in sorted(timelines): if tl_id in self._timeline_tracks: continue @@ -254,7 +355,6 @@ def _init_tracks( continue if timeline_utils.is_host_timeline(tl_id): - # Rollout only timelines threads. 
if timeline_utils.is_timeline_only_of_allowed_type( tl, [perf_constants.ROLLOUT], include_cur_step=False ): @@ -263,7 +363,6 @@ def _init_tracks( uuid=_UUID_OFFSET + 1, ) self._timeline_tracks[tl_id] = "host_rollout" - # Main timelines threads. else: self._track_info["host_main"] = TrackInfo( name="Host - Main threads", @@ -280,96 +379,160 @@ def _init_tracks( ) self._timeline_tracks[tl_id] = track_name else: - logging.warning("Failed to get track name for timeline ID: %s", tl_id) - - def write_timelines(self, timelines: Mapping[str, Timeline]) -> None: - """Writes timelines to the trace file.""" - if not timelines: - return + logging.warning( + "Failed to get track name for timeline ID: %s", tl_id + ) - if not any(any(tl.committed_steps) for tl in timelines.values()): - return + def _detect_and_track_commits( + self, timelines: Mapping[str, Timeline] + ) -> None: + """Detects new commit_step boundaries since the previous flush. - builder = TraceProtoBuilder() + Per-call delta is computed per timeline and the maximum is taken as the + synchronized step count, since the tracer commits all timelines together. + Using the max (rather than min) makes the writer robust to timelines that + are created or first observed mid-run -- they catch up over subsequent + commits without blocking sealing of the older timelines. + """ + max_delta = 0 + for tl_id, tl in timelines.items(): + actual = len(tl.committed_steps) + observed = self._observed_committed_count.get(tl_id, 0) + if actual > observed: + max_delta = max(max_delta, actual - observed) + # Always refresh; if a timeline drained (actual < observed) we lower it + # so the next call accounts only for genuinely new commits. + self._observed_committed_count[tl_id] = actual + self._unsealed_step_count += max_delta + + def _update_lane_assignments( + self, timelines: Mapping[str, Timeline] + ) -> None: + """Assigns lanes to any spans not yet seen by the writer. 
- self._init_tracks(timelines) + Iterates each timeline's currently-held committed spans in (begin, id) + order. Spans already in ``_lane_assignment[tl_id]`` keep their lane; + fresh spans pick the lowest-index lane whose ``_lane_busy_until`` is at + or before ``span.begin``, otherwise a new lane is appended. + """ + for tl_id, tl in timelines.items(): + if not any(tl.committed_steps): + continue + assignments = self._lane_assignment.setdefault(tl_id, {}) + busy = self._lane_busy_until.setdefault(tl_id, []) + # Sort by (begin, id) for deterministic placement; matches the original + # _assign_lanes behavior. + all_spans = sorted( + itertools.chain.from_iterable( + step.values() for step in tl.committed_steps + ), + key=lambda s: (s.begin, s.id), + ) + for s in all_spans: + if s.id in assignments: + continue + placed = False + for lane_idx, lane_end in enumerate(busy): + if lane_end <= s.begin: + busy[lane_idx] = s.end + assignments[s.id] = lane_idx + placed = True + break + if not placed: + assignments[s.id] = len(busy) + busy.append(s.end) + + def _emit_descriptors( + self, + builder: TraceProtoBuilder, + timelines: Mapping[str, Timeline], + ) -> None: + """Emits all track descriptors for the current trace fragment. - # Write track descriptors for parent tracks. + Each shard (sealed or pending) re-emits the full set of descriptors so it + is viewable standalone. Perfetto tolerates duplicate descriptors by UUID + when shards are concatenated, so this is safe. + """ for track_info in self._track_info.values(): packet = builder.add_packet() packet.track_descriptor.uuid = track_info.uuid packet.track_descriptor.name = track_info.name - # Sort timelines by ID to ensure consistent track ordering. 
- sorted_ids = sorted(timelines) - - events = [] - - for tl_id in sorted_ids: + for tl_id in sorted(timelines): tl = timelines[tl_id] - - if not any(tl.committed_steps): + if tl_id not in self._timeline_tracks and not any(tl.committed_steps): + # Timeline never carried data; skip its descriptors. continue - - # Assign a UUID to the timeline if it hasn't been assigned one yet. Offset - # by 2 to account for track descriptor UUIDs. if tl_id not in self._timeline_uuids: self._timeline_uuids[tl_id] = _UUID_OFFSET * ( len(self._timeline_uuids) + 2 ) tl_uuid = self._timeline_uuids[tl_id] - packet = builder.add_packet() packet.track_descriptor.uuid = tl_uuid packet.track_descriptor.name = tl_id - if tl_id in self._timeline_tracks: - track_info = self._track_info[self._timeline_tracks[tl_id]] - packet.track_descriptor.parent_uuid = track_info.uuid - - # TODO: noghabi - limit processing to last steps. we don't need to start - # from the beginning every time. - all_spans = list( - itertools.chain.from_iterable( - step.values() for step in tl.committed_steps - ) - ) - lane_by_span_id, num_lanes = _assign_lanes(all_spans) + parent = self._track_info[self._timeline_tracks[tl_id]] + packet.track_descriptor.parent_uuid = parent.uuid + num_lanes = len(self._lane_busy_until.get(tl_id, [])) if num_lanes > 1: - # Emit track descriptors for each lane so they group under the timeline for lane_idx in range(num_lanes): lane_uuid = tl_uuid + lane_idx + 1 packet = builder.add_packet() packet.track_descriptor.uuid = lane_uuid packet.track_descriptor.parent_uuid = tl_uuid - packet.track_descriptor.name = "" # empty name for lanes + packet.track_descriptor.name = "" - for s in all_spans: - lane_idx = lane_by_span_id[s.id] - lane_uuid = tl_uuid if num_lanes <= 1 else (tl_uuid + lane_idx + 1) - - # Timestamp in nanoseconds, relative to timeline creation (born). 
- start_ns = int((s.begin - tl.born) * 1e9) - events.append({ - "timestamp": start_ns, - "type": TrackEvent.Type.TYPE_SLICE_BEGIN, - "uuid": lane_uuid, - "name": _create_span_name(s.name, s.tags), - }) - - if s.ended: - end_ns = int((s.end - tl.born) * 1e9) + def _emit_events_for_steps( + self, + builder: TraceProtoBuilder, + timelines: Mapping[str, Timeline], + step_slice: slice, + ) -> None: + """Emits begin/end events for the given slice of each timeline's steps. + + Args: + builder: The proto builder to write into. + timelines: All timelines. + step_slice: A slice applied to each timeline's ``committed_steps``. + ``slice(None)`` emits everything currently held. + """ + events = [] + for tl_id in sorted(timelines): + tl = timelines[tl_id] + tl_committed = tl.committed_steps + if not any(tl_committed): + continue + tl_uuid = self._timeline_uuids.get(tl_id) + if tl_uuid is None: + continue + assignments = self._lane_assignment.get(tl_id, {}) + num_lanes = len(self._lane_busy_until.get(tl_id, [])) + steps = tl_committed[step_slice] + for step in steps: + for s in step.values(): + lane_idx = assignments.get(s.id, 0) + lane_uuid = tl_uuid if num_lanes <= 1 else (tl_uuid + lane_idx + 1) + start_ns = int((s.begin - tl.born) * 1e9) events.append({ - "timestamp": end_ns, - "type": TrackEvent.Type.TYPE_SLICE_END, + "timestamp": start_ns, + "type": TrackEvent.Type.TYPE_SLICE_BEGIN, "uuid": lane_uuid, - "name": None, + "name": _create_span_name(s.name, s.tags), }) - - # Perfetto trace processor requires events within a sequence to be strictly - # sorted by timestamp. Out-of-order events can cause spans to be rendered - # incorrectly (e.g., stretched end times). 
+ if s.ended: + end_ns = int((s.end - tl.born) * 1e9) + events.append({ + "timestamp": end_ns, + "type": TrackEvent.Type.TYPE_SLICE_END, + "uuid": lane_uuid, + "name": None, + }) + + # Perfetto requires strict timestamp ordering within a sequence; END + # events break ties before BEGIN events so a zero-duration handoff doesn't + # stretch the previous span. events.sort( key=lambda e: ( e["timestamp"], @@ -386,4 +549,154 @@ def write_timelines(self, timelines: Mapping[str, Timeline]) -> None: if e["name"] is not None: packet.track_event.name = e["name"] - self.write(builder) + def _build_trace_fragment( + self, + timelines: Mapping[str, Timeline], + step_slice: slice, + ) -> TraceProtoBuilder: + """Builds a self-contained perfetto trace covering a slice of each timeline.""" + builder = TraceProtoBuilder() + self._emit_descriptors(builder, timelines) + self._emit_events_for_steps(builder, timelines, step_slice) + return builder + + def _seal_one_shard(self, timelines: Mapping[str, Timeline]) -> None: + """Seals the next shard from the first ``shard_steps`` of each timeline.""" + if self._trace_dir is None: + # If the directory failed to initialize, still advance the bookkeeping + # so we don't accumulate memory forever. Best-effort drop of steps. + for tl_id, tl in timelines.items(): + tl.drop_oldest_committed_steps(self._shard_steps) + self._observed_committed_count[tl_id] = len(tl.committed_steps) + self._unsealed_step_count = max( + 0, self._unsealed_step_count - self._shard_steps + ) + return + + shard_name = _SHARD_FILE_FMT.format(index=self._next_shard_index) + shard_path = self._trace_dir / shard_name + builder = self._build_trace_fragment( + timelines, step_slice=slice(0, self._shard_steps) + ) + ok = self._safe_write_bytes(shard_path, builder.serialize()) + if not ok: + # Don't drop spans we couldn't write; let the next flush retry. 
+ return + + # Drop the sealed steps from each timeline's memory and from the lane + # assignment cache so we don't hold them forever. + for tl_id, tl in timelines.items(): + dropped = tl.drop_oldest_committed_steps(self._shard_steps) + if dropped: + assignments = self._lane_assignment.get(tl_id) + if assignments is not None: + for step in dropped: + for sid in step: + assignments.pop(sid, None) + self._observed_committed_count[tl_id] = len(tl.committed_steps) + + self._sealed_shards.append(shard_name) + self._next_shard_index += 1 + self._sealed_step_count += self._shard_steps + self._unsealed_step_count -= self._shard_steps + self._update_manifest() + + def _write_pending(self, timelines: Mapping[str, Timeline]) -> None: + """Writes (or removes) the in-flight pending shard file.""" + if self._pending_path is None: + return + has_pending = any(any(tl.committed_steps) for tl in timelines.values()) + if not has_pending: + # Nothing unsealed; remove any stale pending file so concatenating + # ``trace.shard_*.binpb`` is a complete trace. + try: + if self._pending_path.exists(): + self._pending_path.unlink() + except Exception: # pylint: disable=broad-except + logging.exception( + "Failed to remove stale pending trace at %s", self._pending_path + ) + return + builder = self._build_trace_fragment(timelines, step_slice=slice(None)) + self._safe_write_bytes(self._pending_path, builder.serialize()) + + def write_timelines(self, timelines: Mapping[str, Timeline]) -> None: + """Writes timelines to the trace directory, sealing shards as needed.""" + if not timelines: + return + if ( + not any(any(tl.committed_steps) for tl in timelines.values()) + and self._sealed_step_count == 0 + ): + return + + self._init_tracks(timelines) + self._detect_and_track_commits(timelines) + self._update_lane_assignments(timelines) + + # Seal as many full shards as we have data for. 
A long pause between + # flushes could produce more than one shard's worth of unsealed data; we + # drain it all rather than queueing seals behind future flushes. + while self._unsealed_step_count >= self._shard_steps: + sealed_index_before = self._next_shard_index + self._seal_one_shard(timelines) + if self._next_shard_index == sealed_index_before: + # Seal did not advance (e.g. write failure or no trace dir). Avoid an + # infinite loop. + break + + self._write_pending(timelines) + + +# Backwards-compatible alias retained for callers that imported the +# internal helper directly. Equivalent to one-shot lane assignment over a +# bag of spans; the writer itself now uses the streaming assignment on +# ``PerfettoTraceWriter``. +def _assign_lanes( + spans: Iterable[timeline.Span], +) -> tuple[Mapping[int, int], int]: + """Assigns lanes to spans to handle overlaps (one-shot, no streaming state). + + Perfetto requires spans on the same track to be strictly non-overlapping. + This helper assigns a lane index to each span such that spans in the same + lane do not overlap. It is left here for callers that previously used it + directly; the trace writer uses streaming lane assignment internally. + + Args: + spans: An iterable of spans to assign to lanes. + + Returns: + A tuple ``(lane_by_span_id, num_lanes)``. + """ + sorted_spans = sorted(spans, key=lambda s: (s.begin, s.id)) + lanes_end_times: list[float] = [] + lane_by_span_id: dict[int, int] = {} + + for s in sorted_spans: + placed = False + for lane_idx, lane_end in enumerate(lanes_end_times): + if lane_end <= s.begin: + lanes_end_times[lane_idx] = s.end + lane_by_span_id[s.id] = lane_idx + placed = True + break + if not placed: + lane_by_span_id[s.id] = len(lanes_end_times) + lanes_end_times.append(s.end) + + return lane_by_span_id, len(lanes_end_times) + + +# Module-level sentinel used by tests that previously patched ``TraceProtoBuilder``. +# Kept as a public attribute so existing test mocks keep working. 
+__all__ = [ + "PerfettoTraceWriter", + "NoopTraceWriter", + "TrackInfo", + "TraceWriter", + "TraceProtoBuilder", + "_create_span_name", + "_assign_lanes", + "TrackDescriptor", + "TrackEvent", +] diff --git a/tunix/perf/metrics.py b/tunix/perf/metrics.py index c3fc4c9c3..60e437a14 100644 --- a/tunix/perf/metrics.py +++ b/tunix/perf/metrics.py @@ -90,6 +90,12 @@ class PerfMetricsOptions: enabled when perf metrics are enabled. If False, the trace will not be written out. trace_dir: Directory the trace writer writes the raw metrics/events to. + trace_shard_steps: Number of committed steps per sealed perfetto trace + shard file. Lower values write more (smaller) shard files and bound + in-memory span history more tightly; higher values write fewer (larger) + shard files. Must be >= 1. The env var ``TUNIX_TRACE_SHARD_STEPS``, if + set to a positive integer, overrides this value at trace writer + construction time. """ enable_perf_v1: bool = True @@ -98,6 +104,7 @@ class PerfMetricsOptions: custom_export_fn_path_v2: str = "" enable_trace_writer: bool = True trace_dir: str = "" + trace_shard_steps: int = 100 def __post_init__(self): if self.custom_export_fn_path and not self.enable_perf_v1: @@ -119,6 +126,11 @@ def __post_init__(self): "trace_dir is set to %r but enable_trace_writer is False.", self.trace_dir, ) + if self.trace_shard_steps < 1: + raise ValueError( + "trace_shard_steps must be a positive integer, got" + f" {self.trace_shard_steps!r}." + ) @dataclasses.dataclass