-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathexample_data_loading.py
More file actions
121 lines (96 loc) · 5.15 KB
/
Copy pathexample_data_loading.py
File metadata and controls
121 lines (96 loc) · 5.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
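Minimal example showing how to use TrajectoryDatasetManager.

Usage:
python example_data_loading.py --dataset_dir /path/to/traj_data
# Save a GIF of the first trajectory:
python example_data_loading.py --dataset_dir /path/to/traj_data --gif_out out.gif
# Save GIFs of the first 4 trajectories (out_000.gif ... out_003.gif):
python example_data_loading.py --dataset_dir /path/to/traj_data --gif_out out.gif --n_gifs 4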
"""
import argparse
from pathlib import Path
import jax
import numpy as np
import imageio
from kinetix.data import TrajectoryDatasetManager
from kinetix.environment import EnvParams, StaticEnvParams
from kinetix.render import make_render_pixels
def static_env_params_from_batch(env_state, downscale: int = 1) -> StaticEnvParams:
    """Build a StaticEnvParams whose structural sizes match a loaded batch.

    The sizes are read from trailing array dimensions of ``env_state``, so any
    leading batch/time axes are ignored:

    * polygon / circle counts come from axis -2 of the ``position`` arrays;
    * joint / thruster counts come from axis -1 of the ``active`` arrays;
    * the per-polygon vertex limit comes from axis -2 of ``polygon.vertices``.

    Fields that cannot be recovered from shapes (e.g. num_static_fixated_polys,
    num_motor_bindings, num_thruster_bindings) are not passed and therefore keep
    their StaticEnvParams defaults.

    Args:
        env_state: environment state pytree from a loaded batch.
        downscale: rendering downscale factor, forwarded unchanged.
    """
    polygons = env_state.polygon
    circles = env_state.circle
    sizes_from_shapes = {
        "num_polygons": polygons.position.shape[-2],
        "num_circles": circles.position.shape[-2],
        "num_joints": env_state.joint.active.shape[-1],
        "num_thrusters": env_state.thruster.active.shape[-1],
        "max_polygon_vertices": polygons.vertices.shape[-2],
    }
    return StaticEnvParams(**sizes_from_shapes, downscale=downscale)
def render_and_save_gifs(batch, gif_out: str, n_gifs: int) -> None:
    """Render the first trajectories of ``batch`` and write each one as a GIF.

    At most ``min(n_gifs, batch.action.shape[0])`` trajectories are rendered.
    One jitted pixel renderer (downscale=2) is built up front and reused for
    every trajectory.

    File naming: if exactly one GIF is written it goes to ``gif_out`` as given;
    otherwise a zero-padded index is appended to the stem, e.g.
    ``out.gif`` -> ``out_000.gif``, ``out_001.gif``, ...
    """
    count = min(n_gifs, batch.action.shape[0])
    params_static = static_env_params_from_batch(batch.env_state, downscale=2)
    renderer = jax.jit(make_render_pixels(EnvParams(), params_static))
    target = Path(gif_out)
    for traj_idx in range(count):
        if count == 1:
            out_path = str(target)
        else:
            out_path = str(target.with_stem(f"{target.stem}_{traj_idx:03d}"))
        frames = render_trajectory(batch, traj_idx, renderer)
        save_gif(frames, out_path)
        print(f"GIF saved: {out_path}")
def render_trajectory(batch, traj_idx: int, pixel_renderer) -> np.ndarray:
    """Render every step of trajectory ``traj_idx`` into a uint8 frame stack.

    Every leaf of ``batch.env_state`` is indexed at ``traj_idx`` on its leading
    axis. ``pixel_renderer`` is then vmapped over the next axis (presumably
    time — TODO confirm). Pixel values are clipped to [0, 255] and cast to
    uint8. Finally axes 1 and 2 are swapped and axis 1 is reversed, giving the
    (T, H, W, C) upright layout noted in the original code. This assumes the
    renderer emits (W, H, C) images with a bottom-up row order — verify
    against make_render_pixels.
    """
    def pick(leaf):
        return leaf[traj_idx]

    single_traj = jax.tree.map(pick, batch.env_state)
    rendered = np.array(jax.vmap(pixel_renderer)(single_traj))
    as_bytes = np.clip(rendered, 0, 255).astype(np.uint8)
    swapped = np.swapaxes(as_bytes, 1, 2)
    return swapped[:, ::-1]  # flip rows so the image is upright
def save_gif(frames: np.ndarray, path: str, fps: int = 10) -> None:
    """Write ``frames`` to ``path`` as a GIF via ``imageio.mimsave``.

    Args:
        frames: frame stack passed straight through to imageio.
        path: destination file path.
        fps: playback rate, forwarded as imageio's ``fps`` option.
    """
    gif_options = {"fps": fps, "loop": 0}  # loop=0: GIF repeats indefinitely
    imageio.mimsave(path, frames, **gif_options)
def _parse_args() -> argparse.Namespace:
    """Define this example's command-line options and parse ``sys.argv``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_dir", required=True)
    parser.add_argument("--batch_size", type=int, default=1024)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--maximum_number_of_shards", type=int, default=-1)
    parser.add_argument("--n_val_shards", type=int, default=1)
    parser.add_argument(
        "--dataset_proportions",
        type=float,
        nargs="+",
        default=[1.0],
        help="Accepted for CLI compatibility, but NOT forwarded to TrajectoryDatasetManager by this example.",
    )
    parser.add_argument(
        "--gif_out",
        default=None,
        help="Path for saved GIF(s). With --n_gifs > 1, index is inserted before the extension.",
    )
    parser.add_argument("--n_gifs", type=int, default=1, help="Number of trajectories to render as GIFs.")
    return parser.parse_args()


def main():
    """Load batches from ``--dataset_dir`` with TrajectoryDatasetManager and print their shapes.

    Steps: build the manager, load one training batch, optionally render it to
    GIF(s) when ``--gif_out`` is set, report the validation batch (if any), then
    load three more batches and print per-batch statistics.
    """
    config = vars(_parse_args())

    # --dataset_proportions is parsed but never handed to the dataset manager
    # below, so warn rather than silently ignore a value the user supplied.
    if config["dataset_proportions"] != [1.0]:
        print("Warning: --dataset_proportions is not used by this example and will be ignored.")

    _dataset_common = dict(
        dataset_dir=config["dataset_dir"],
        seed=config["seed"],
        maximum_number_of_shards=config["maximum_number_of_shards"],
        n_val_shards=config["n_val_shards"],
    )
    # TrajectoryDatasetManager: batch_size is in trajectories, not timesteps.
    # Each batch has shape (batch_size, T, *dims).
    dataset_manager = TrajectoryDatasetManager(
        batch_size=config["batch_size"],
        val_batch_size=config["batch_size"],
        **_dataset_common,
    )
    print(f"Dataset length (estimated minibatches): {dataset_manager.length}")

    # ── Load a training batch ─────────────────────────────────────────────────
    batch = dataset_manager.load_next_batch()
    print(f"Training batch — action shape: {batch.action.shape}")
    print(f"Training batch — mask shape: {batch.mask.shape}")

    # ── Optional GIF rendering ────────────────────────────────────────────────
    if config["gif_out"]:
        render_and_save_gifs(batch, config["gif_out"], config["n_gifs"])

    # ── Inspect the validation batch ─────────────────────────────────────────
    if dataset_manager.validation_batch is not None:
        val = dataset_manager.validation_batch
        print(f"Validation batch — action shape: {val.action.shape}")
        print(f"Validation batch — mask shape: {val.mask.shape}")
    else:
        print("No validation batch (set n_val_shards >= 1 to enable)")

    # ── Iterate a few batches ─────────────────────────────────────────────────
    for i in range(3):
        b = dataset_manager.load_next_batch()
        # Product of the first two action dims; this includes any padded steps
        # (see the mask fraction for the share of valid ones).
        n_transitions = b.action.shape[0] * b.action.shape[1]
        print(f"Batch {i + 1}: {n_transitions} transitions, mask fraction = {b.mask.mean():.3f}")
if __name__ == "__main__":
main()