Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 187 additions & 0 deletions tests/cli/perfetto_cat_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import json
import os
import tempfile
from unittest import mock

from absl.testing import absltest
from etils import epath
from tunix.cli import perfetto_cat


def _populate_trace_dir(
tmp_dir,
*,
sealed_shards=("trace.shard_0001.binpb", "trace.shard_0002.binpb"),
pending_content=b"PENDING_BYTES",
manifest=True,
):
"""Writes a fixture trace directory with known byte payloads per file."""
for name in sealed_shards:
epath.Path(os.path.join(tmp_dir, name)).write_bytes(name.encode())
if pending_content is not None:
epath.Path(os.path.join(tmp_dir, "trace.shard_pending.binpb")).write_bytes(
pending_content
)
if manifest:
epath.Path(os.path.join(tmp_dir, "trace.manifest.json")).write_text(
json.dumps(
{
"version": 1,
"shard_steps": 1,
"sealed_shards": list(sealed_shards),
"sealed_step_count": len(sealed_shards),
"pending_file": "trace.shard_pending.binpb",
}
)
)


class ListSealedShardsTest(absltest.TestCase):

def test_uses_manifest_order_when_present(self):
with tempfile.TemporaryDirectory() as tmp_dir:
_populate_trace_dir(
tmp_dir,
sealed_shards=("trace.shard_0001.binpb", "trace.shard_0002.binpb"),
)
shards = perfetto_cat.list_sealed_shards(epath.Path(tmp_dir))
names = [p.name for p in shards]
self.assertEqual(
names, ["trace.shard_0001.binpb", "trace.shard_0002.binpb"]
)

def test_falls_back_to_glob_when_manifest_missing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
_populate_trace_dir(
tmp_dir,
sealed_shards=(
"trace.shard_0002.binpb",
"trace.shard_0010.binpb",
"trace.shard_0001.binpb",
),
manifest=False,
)
shards = perfetto_cat.list_sealed_shards(epath.Path(tmp_dir))
names = [p.name for p in shards]
self.assertEqual(
names,
[
"trace.shard_0001.binpb",
"trace.shard_0002.binpb",
"trace.shard_0010.binpb",
],
)

def test_falls_back_to_glob_when_manifest_lists_missing_shard(self):
"""If the manifest references a shard that isn't on disk, glob fallback
must kick in rather than silently dropping the file from the output."""
with tempfile.TemporaryDirectory() as tmp_dir:
_populate_trace_dir(
tmp_dir,
sealed_shards=("trace.shard_0001.binpb",),
)
# Manifest claims a second shard that doesn't exist.
epath.Path(os.path.join(tmp_dir, "trace.manifest.json")).write_text(
json.dumps(
{
"version": 1,
"shard_steps": 1,
"sealed_shards": [
"trace.shard_0001.binpb",
"trace.shard_0099.binpb",
],
"sealed_step_count": 2,
"pending_file": "trace.shard_pending.binpb",
}
)
)
shards = perfetto_cat.list_sealed_shards(epath.Path(tmp_dir))
# Only the on-disk shard should appear; manifest is treated as a hint.
self.assertEqual([p.name for p in shards], ["trace.shard_0001.binpb"])


class ConcatTraceTest(absltest.TestCase):

def test_concat_includes_pending_by_default(self):
with tempfile.TemporaryDirectory() as tmp_dir:
_populate_trace_dir(tmp_dir, pending_content=b"PENDING")
payload = perfetto_cat.concat_trace(epath.Path(tmp_dir))
self.assertEqual(
payload, b"trace.shard_0001.binpb" + b"trace.shard_0002.binpb" + b"PENDING"
)

def test_concat_skips_pending_when_requested(self):
with tempfile.TemporaryDirectory() as tmp_dir:
_populate_trace_dir(tmp_dir, pending_content=b"PENDING")
payload = perfetto_cat.concat_trace(
epath.Path(tmp_dir), include_pending=False
)
self.assertEqual(
payload, b"trace.shard_0001.binpb" + b"trace.shard_0002.binpb"
)

def test_concat_handles_missing_pending(self):
with tempfile.TemporaryDirectory() as tmp_dir:
_populate_trace_dir(tmp_dir, pending_content=None)
payload = perfetto_cat.concat_trace(epath.Path(tmp_dir))
self.assertEqual(
payload, b"trace.shard_0001.binpb" + b"trace.shard_0002.binpb"
)


class MainTest(absltest.TestCase):

def test_main_writes_to_output_file(self):
with tempfile.TemporaryDirectory() as tmp_dir:
_populate_trace_dir(tmp_dir, pending_content=b"P")
out_path = os.path.join(tmp_dir, "combined.binpb")
rc = perfetto_cat.main([tmp_dir, "-o", out_path])
self.assertEqual(rc, 0)
self.assertEqual(
epath.Path(out_path).read_bytes(),
b"trace.shard_0001.binpb" + b"trace.shard_0002.binpb" + b"P",
)

def test_main_writes_to_stdout(self):
with tempfile.TemporaryDirectory() as tmp_dir:
_populate_trace_dir(tmp_dir, pending_content=b"P")
fake_stdout = io.BytesIO()
fake_wrapper = mock.MagicMock()
fake_wrapper.buffer = fake_stdout
with mock.patch("sys.stdout", fake_wrapper):
rc = perfetto_cat.main([tmp_dir])
self.assertEqual(rc, 0)
self.assertEqual(
fake_stdout.getvalue(),
b"trace.shard_0001.binpb" + b"trace.shard_0002.binpb" + b"P",
)

def test_main_missing_directory_returns_error(self):
with tempfile.TemporaryDirectory() as tmp_dir:
missing = os.path.join(tmp_dir, "does-not-exist")
rc = perfetto_cat.main([missing, "-o", "/dev/null"])
self.assertEqual(rc, 1)

def test_main_empty_directory_returns_error(self):
with tempfile.TemporaryDirectory() as tmp_dir:
rc = perfetto_cat.main([tmp_dir, "-o", os.path.join(tmp_dir, "out.bin")])
self.assertEqual(rc, 1)


if __name__ == "__main__":
absltest.main()
45 changes: 38 additions & 7 deletions tests/perf/experimental/export_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
class ExportTest(parameterized.TestCase):

def test_perf_metrics_export(self):
# Backward compatibility check
with tempfile.TemporaryDirectory() as tmp_dir:
with export.PerfMetricsExport(trace_dir=tmp_dir) as exporter:
# Create dummy timeline
Expand All @@ -36,9 +35,14 @@ def test_perf_metrics_export(self):
time.sleep(0.001)
t.export()

files = os.listdir(tmp_dir)
self.assertLen(files, 1)
self.assertStartsWith(files[0], "perfetto_trace_v2_")
# With a single committed step and the default shard size, no shard
# has been sealed yet; the in-flight pending file holds the only data.
files = set(os.listdir(tmp_dir))
self.assertIn("trace.shard_pending.binpb", files)
self.assertGreater(
os.path.getsize(os.path.join(tmp_dir, "trace.shard_pending.binpb")),
0,
)

def test_basic_metrics_export(self):
with self.assertLogs(level="INFO") as logs:
Expand Down Expand Up @@ -76,7 +80,9 @@ def test_perf_metrics_export_initialization_with_trace_writer_enabled(
with export.PerfMetricsExport(
enable_trace_writer=True, trace_dir=trace_dir
) as exporter:
mock_writer_cls.assert_called_once_with(expected_dir, role_to_devices=None)
mock_writer_cls.assert_called_once_with(
expected_dir, role_to_devices=None, shard_steps=None
)
# export_metrics shouldn't crash
exporter.export_metrics({})

Expand Down Expand Up @@ -148,10 +154,33 @@ def test_from_cluster_config(self, mock_writer_cls):
"rollout": ["tpu4", "tpu5", "tpu6", "tpu7"],
}
mock_writer_cls.assert_called_once_with(
"/test/dir", role_to_devices=expected_role_to_devices
"/test/dir",
role_to_devices=expected_role_to_devices,
shard_steps=None,
)
self.assertIs(exporter._writer, mock_writer_cls.return_value)

def test_from_cluster_config_passes_trace_shard_steps(self):
"""``from_cluster_config`` must forward ``trace_shard_steps`` to the writer so per-run overrides reach the on-disk shard layout."""
with mock.patch.object(
export.trace_writer_lib,
"PerfettoTraceWriter",
autospec=True,
) as mock_writer_cls:
mock_cluster_config = mock.create_autospec(
rl_cluster.ClusterConfig, instance=True, spec_set=True
)
del mock_cluster_config.role_to_mesh
export.PerfMetricsExport.from_cluster_config(
mock_cluster_config,
enable_trace_writer=True,
trace_dir="/test/dir",
trace_shard_steps=42,
)
mock_writer_cls.assert_called_once_with(
"/test/dir", role_to_devices={}, shard_steps=42
)

def test_safe_write_exception(self):
with export.PerfMetricsExport(enable_trace_writer=False) as exporter:
with mock.patch.object(exporter, "_writer", autospec=True) as mock_writer:
Expand Down Expand Up @@ -192,7 +221,9 @@ def test_from_cluster_config_no_role_to_mesh(self, mock_writer_cls):
trace_dir="/test/dir",
)

mock_writer_cls.assert_called_once_with("/test/dir", role_to_devices={})
mock_writer_cls.assert_called_once_with(
"/test/dir", role_to_devices={}, shard_steps=None
)
self.assertIs(exporter._writer, mock_writer_cls.return_value)

if __name__ == "__main__":
Expand Down
54 changes: 54 additions & 0 deletions tests/perf/experimental/timeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,60 @@ def test_stop_span_error_cases(self):
with self.assertRaisesRegex(ValueError, "ended at .* before it began"):
t.stop_span(1.0)

def test_drop_oldest_committed_steps_basic(self):
t = timeline.Timeline("test_tl", 0.0)
span_ids = []
for i in range(5):
s = t.start_span(f"span{i}", float(i))
t.stop_span(float(i) + 0.1)
span_ids.append(s.id)
t.commit_step()
self.assertLen(t.committed_steps, 5)

# Snapshot reference before drop to confirm copy-on-write semantics.
pre_drop_ref = t.committed_steps

dropped = t.drop_oldest_committed_steps(2)

with self.subTest("dropped_returned_oldest_first"):
self.assertLen(dropped, 2)
self.assertIn(span_ids[0], dropped[0])
self.assertIn(span_ids[1], dropped[1])

with self.subTest("post_drop_state"):
self.assertLen(t.committed_steps, 3)
self.assertIn(span_ids[2], t.committed_steps[0])
self.assertIn(span_ids[4], t.committed_steps[-1])

with self.subTest("copy_on_write_preserves_prior_snapshot"):
# The reference captured before the drop must remain unchanged so
# concurrent readers iterating the old snapshot are not affected.
self.assertLen(pre_drop_ref, 5)

def test_drop_oldest_committed_steps_zero_is_noop(self):
t = timeline.Timeline("test_tl", 0.0)
t.start_span("s", 1.0)
t.stop_span(2.0)
t.commit_step()
dropped = t.drop_oldest_committed_steps(0)
self.assertEmpty(dropped)
self.assertLen(t.committed_steps, 1)

def test_drop_oldest_committed_steps_more_than_held(self):
t = timeline.Timeline("test_tl", 0.0)
for _ in range(3):
t.start_span("s", 1.0)
t.stop_span(2.0)
t.commit_step()
dropped = t.drop_oldest_committed_steps(10)
self.assertLen(dropped, 3)
self.assertEmpty(t.committed_steps)

def test_drop_oldest_committed_steps_negative_raises(self):
t = timeline.Timeline("test_tl", 0.0)
with self.assertRaisesRegex(ValueError, "n must be non-negative"):
t.drop_oldest_committed_steps(-1)

def test_nested_timeline_with_tags_repr(self):
born = 1000.0
t = timeline.Timeline("test_tl", born)
Expand Down
Loading
Loading