8 changes: 5 additions & 3 deletions src/imitation/algorithms/preference_comparisons.py
@@ -1158,6 +1158,7 @@ def __init__(
# for keeping track of the global iteration, in case train() is called
# multiple times
self._iteration = 0
self.trajectory_generator_num_steps = 0

self.model = reward_model

@@ -1202,7 +1203,7 @@ def train(
self,
total_timesteps: int,
total_comparisons: int,
callback: Optional[Callable[[int], None]] = None,
callback: Optional[Callable[[int, int], None]] = None,
Member:
Probably should add in the docstring what the callback type signature represents.
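One possible wording for such a docstring entry, purely illustrative and not part of this PR, is sketched below; the first argument is the number of iterations completed so far and the second is the cumulative number of trajectory-generator timesteps:

```python
"""Train the reward model and the policy if applicable.

Args:
    ...
    callback: if given, called at the end of each training iteration with
        two arguments: the number of iterations completed so far, and the
        cumulative number of timesteps the trajectory generator has been
        trained for.
"""
```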

) -> Mapping[str, Any]:
"""Train the reward model and the policy if applicable.

Expand Down Expand Up @@ -1286,14 +1287,15 @@ def train(
with self.logger.accumulate_means("agent"):
self.logger.log(f"Training agent for {num_steps} timesteps")
self.trajectory_generator.train(steps=num_steps)
self.trajectory_generator_num_steps += num_steps

self.logger.dump(self._iteration)

########################
# Additional Callbacks #
########################
if callback:
callback(self._iteration)
self._iteration += 1
if callback:
callback(self._iteration, self.trajectory_generator_num_steps)

return {"reward_loss": reward_loss, "reward_accuracy": reward_accuracy}
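For reference, a caller adapting to the new two-argument callback might look like the minimal sketch below; `main_trainer` and the keyword values are placeholders, assuming an already-constructed `PreferenceComparisons` instance as in the training script further down this diff.

```python
def log_progress(num_iterations_done: int, generator_timesteps: int) -> None:
    # Receives the running totals maintained in train() above.
    print(
        f"finished iteration {num_iterations_done}; "
        f"{generator_timesteps} trajectory-generator timesteps so far",
    )

main_trainer.train(
    total_timesteps=1_000_000,
    total_comparisons=5_000,
    callback=log_progress,
)
```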
4 changes: 2 additions & 2 deletions src/imitation/scripts/common/common.py
@@ -2,7 +2,7 @@

import logging
import os
from typing import Any, Mapping, Sequence, Tuple, Union
from typing import Any, Mapping, Optional, Sequence, Tuple, Union

import sacred
from stable_baselines3.common import vec_env
@@ -131,7 +131,7 @@ def make_venv(
env_name: str,
num_vec: int,
parallel: bool,
log_dir: str,
log_dir: Optional[str],
max_episode_steps: int,
env_make_kwargs: Mapping[str, Any],
**kwargs,
4 changes: 4 additions & 0 deletions src/imitation/scripts/common/train.py
@@ -23,6 +23,10 @@ def config():
# Evaluation
n_episodes_eval = 50 # Num of episodes for final mean ground truth return

# Visualization
videos = False # save video files
video_kwargs = {} # arguments to VideoWrapper

locals() # quieten flake8


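Because `videos` and `video_kwargs` live in the shared `train` ingredient, they can be switched on like any other Sacred config value, e.g. `with train.videos=True` on the command line. A programmatic sketch, assuming the experiment's remaining defaults are otherwise usable:

```python
from imitation.scripts.train_imitation import train_imitation_ex

# Enable video saving for a run; video_kwargs are forwarded to VideoWrapper.
run = train_imitation_ex.run(
    config_updates={"train": {"videos": True, "video_kwargs": {}}},
)
```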
3 changes: 2 additions & 1 deletion src/imitation/scripts/config/train_preference_comparisons.py
@@ -28,7 +28,7 @@ def train_defaults():
fragment_length = 100 # timesteps per fragment used for comparisons
total_timesteps = int(1e6) # total number of environment timesteps
total_comparisons = 5000 # total number of comparisons to elicit
num_iterations = 5 # Arbitrary, should be tuned for the task
num_iterations = 50 # Arbitrary, should be tuned for the task
Member:
Apologies if this has been discussed, but why are you doing this?

comparison_queue_size = None
# factor by which to oversample transitions before creating fragments
transition_oversampling = 1
Expand All @@ -39,6 +39,7 @@ def train_defaults():
cross_entropy_loss_kwargs = {}
reward_trainer_kwargs = {
"epochs": 3,
"weight_decay": 0.0,
Member:
I'll have to remember changing this as I have a PR that replaces weight decay with a general regularization API (#481). @AdamGleave what do you think, should we merge my PR or this one first?

Member:
Probably best to merge your PR first, though really depends which one is ready earlier.

Member:
#481 is ready and passing all the tests AFAIK.

Contributor Author:
Thanks for proposing #481; it looks like exactly the feature we want here. I'll make changes accordingly.

}
save_preferences = False # save preference dataset at the end?
agent_path = None # path to a (partially) trained agent to load at the beginning
32 changes: 25 additions & 7 deletions src/imitation/scripts/train_adversarial.py
@@ -9,6 +9,7 @@
import sacred.commands
import torch as th
from sacred.observers import FileStorageObserver
from stable_baselines3.common import vec_env

from imitation.algorithms.adversarial import airl as airl_algo
from imitation.algorithms.adversarial import common
@@ -18,21 +19,33 @@
from imitation.scripts.common import common as common_config
from imitation.scripts.common import demonstrations, reward, rl, train
from imitation.scripts.config.train_adversarial import train_adversarial_ex
from imitation.util import video_wrapper

logger = logging.getLogger("imitation.scripts.train_adversarial")


def save(trainer, save_path):
def save(
_config: Mapping[str, Any],
trainer: common.AdversarialTrainer,
save_path: str,
eval_venv: vec_env.VecEnv,
) -> None:
"""Save discriminator and generator."""
# We implement this here and not in Trainer since we do not want to actually
# serialize the whole Trainer (including e.g. expert demonstrations).
os.makedirs(save_path, exist_ok=True)
th.save(trainer.reward_train, os.path.join(save_path, "reward_train.pt"))
th.save(trainer.reward_test, os.path.join(save_path, "reward_test.pt"))
serialize.save_stable_model(
os.path.join(save_path, "gen_policy"),
trainer.gen_algo,
)
policy_path = os.path.join(save_path, "gen_policy")
serialize.save_stable_model(policy_path, trainer.gen_algo)
if _config["train"]["videos"]:
video_wrapper.record_and_save_video(
output_dir=policy_path,
policy=trainer.gen_algo.policy,
eval_venv=eval_venv,
video_kwargs=_config["train"]["video_kwargs"],
logger=trainer.logger,
)


def _add_hook(ingredient: sacred.Ingredient) -> None:
@@ -68,6 +81,7 @@ def dummy_config()
def train_adversarial(
_run,
_seed: int,
_config: Mapping[str, Any],
show_config: bool,
algo_cls: Type[common.AdversarialTrainer],
algorithm_kwargs: Mapping[str, Any],
Expand All @@ -85,6 +99,7 @@ def train_adversarial(

Args:
_seed: Random seed.
_config: Sacred configuration dict.
show_config: Print the merged config before starting training. This is
analogous to the print_config command, but will show config after
rather than before merging `algorithm_specific` arguments.
@@ -117,6 +132,7 @@
expert_trajs = demonstrations.load_expert_trajs()

venv = common_config.make_venv()
eval_venv = common_config.make_venv(log_dir=None)

reward_net = reward.make_reward_net(venv)
relabel_reward_fn = functools.partial(
@@ -153,13 +169,15 @@

def callback(round_num):
if checkpoint_interval > 0 and round_num % checkpoint_interval == 0:
save(trainer, os.path.join(log_dir, "checkpoints", f"{round_num:05d}"))
save_path = os.path.join(log_dir, "checkpoints", f"{round_num:05d}")
save(_config, trainer, save_path, eval_venv)

trainer.train(total_timesteps, callback)

# Save final artifacts.
if checkpoint_interval >= 0:
save(trainer, os.path.join(log_dir, "checkpoints", "final"))
save_path = os.path.join(log_dir, "checkpoints", "final")
save(_config, trainer, save_path, eval_venv)

return {
"imit_stats": train.eval_policy(trainer.policy, trainer.venv_train),
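The new `imitation.util.video_wrapper.record_and_save_video` helper itself is not shown in this diff; from its call sites it appears to take an output directory, a policy, an evaluation VecEnv, `video_kwargs`, and a logger. The sketch below is an assumption about what such a helper could look like, written against stable-baselines3's `VecVideoRecorder` rather than imitation's `VideoWrapper` (whose interface is not visible here); it is not the PR's implementation.

```python
import os
from typing import Any, Mapping, Optional

from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.vec_env import VecEnv, VecVideoRecorder


def record_and_save_video_sketch(
    output_dir: str,
    policy: BasePolicy,
    eval_venv: VecEnv,
    video_kwargs: Mapping[str, Any],
    logger: Optional[Any] = None,
    video_length: int = 500,
) -> None:
    """Roll out `policy` on `eval_venv` and save a video under `output_dir`.

    Assumes `eval_venv` can render rgb_array frames. In the real helper,
    `video_kwargs` would presumably be passed to imitation's VideoWrapper
    and `logger` used for progress messages; both are ignored here.
    """
    video_dir = os.path.join(output_dir, "videos")
    os.makedirs(video_dir, exist_ok=True)
    venv = VecVideoRecorder(
        eval_venv,
        video_dir,
        record_video_trigger=lambda step: step == 0,
        video_length=video_length,
    )
    obs = venv.reset()
    for _ in range(video_length):
        actions, _ = policy.predict(obs, deterministic=True)
        obs, _, _, _ = venv.step(actions)
    venv.close()  # flushes the recorded video to disk
```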
12 changes: 12 additions & 0 deletions src/imitation/scripts/train_imitation.py
@@ -14,6 +14,7 @@
from imitation.policies import serialize
from imitation.scripts.common import common, demonstrations, train
from imitation.scripts.config.train_imitation import train_imitation_ex
from imitation.util import video_wrapper

logger = logging.getLogger(__name__)

@@ -98,6 +99,7 @@ def load_expert_policy(
@train_imitation_ex.capture
def train_imitation(
_run,
_config: Mapping[str, Any],
bc_kwargs: Mapping[str, Any],
bc_train_kwargs: Mapping[str, Any],
dagger: Mapping[str, Any],
@@ -120,6 +122,7 @@
"""
custom_logger, log_dir = common.setup_logging()
venv = common.make_venv()
eval_venv = common.make_venv(log_dir=None)
imit_policy = make_policy(venv, agent_path=agent_path)

expert_trajs = None
@@ -163,6 +166,15 @@
# TODO(adam): add checkpointing to BC?
bc_trainer.save_policy(policy_path=osp.join(log_dir, "final.th"))

if _config["train"]["videos"]:
video_wrapper.record_and_save_video(
output_dir=log_dir,
policy=imit_policy,
eval_venv=eval_venv,
video_kwargs=_config["train"]["video_kwargs"],
logger=custom_logger,
)

return {
"imit_stats": train.eval_policy(imit_policy, venv),
"expert_stats": rollout.rollout_stats(
57 changes: 37 additions & 20 deletions src/imitation/scripts/train_preference_comparisons.py
@@ -6,11 +6,12 @@

import functools
import os
import os.path as osp
from typing import Any, Mapping, Optional, Type, Union

import torch as th
from sacred.observers import FileStorageObserver
from stable_baselines3.common import type_aliases
from stable_baselines3.common import type_aliases, vec_env

from imitation.algorithms import preference_comparisons
from imitation.data import types
@@ -21,33 +22,37 @@
from imitation.scripts.config.train_preference_comparisons import (
train_preference_comparisons_ex,
)


def save_model(
agent_trainer: preference_comparisons.AgentTrainer,
save_path: str,
):
"""Save the model as model.pkl."""
serialize.save_stable_model(
output_dir=os.path.join(save_path, "policy"),
model=agent_trainer.algorithm,
)
from imitation.util import video_wrapper


def save_checkpoint(
_config: Mapping[str, Any],
trainer: preference_comparisons.PreferenceComparisons,
save_path: str,
allow_save_policy: Optional[bool],
):
eval_venv: vec_env.VecEnv,
) -> None:
"""Save reward model and optionally policy."""
os.makedirs(save_path, exist_ok=True)
th.save(trainer.model, os.path.join(save_path, "reward_net.pt"))
th.save(trainer.model, osp.join(save_path, "reward_net.pt"))
if allow_save_policy:
# Note: We should only save the model as model.pkl if `trajectory_generator`
# contains one. Specifically we check if the `trajectory_generator` contains an
# `algorithm` attribute.
assert hasattr(trainer.trajectory_generator, "algorithm")
save_model(trainer.trajectory_generator, save_path)
policy_dir = osp.join(save_path, "policy")
serialize.save_stable_model(
output_dir=policy_dir,
model=trainer.trajectory_generator.algorithm,
)
if _config["train"]["videos"]:
video_wrapper.record_and_save_video(
output_dir=policy_dir,
policy=trainer.trajectory_generator.algorithm.policy,
eval_venv=eval_venv,
video_kwargs=_config["train"]["video_kwargs"],
logger=trainer.logger,
)
else:
trainer.logger.warn(
"trainer.trajectory_generator doesn't contain a policy to save.",
@@ -57,6 +62,7 @@ def save_checkpoint(
@train_preference_comparisons_ex.main
def train_preference_comparisons(
_seed: int,
_config: Mapping[str, Any],
total_timesteps: int,
total_comparisons: int,
num_iterations: int,
@@ -82,6 +88,7 @@

Args:
_seed: Random seed.
_config: Sacred configuration dict.
total_timesteps: number of environment interaction steps
total_comparisons: number of preferences to gather in total
num_iterations: number of times to train the agent against the reward model
@@ -140,6 +147,7 @@
"""
custom_logger, log_dir = common.setup_logging()
venv = common.make_venv()
eval_venv = common.make_venv(log_dir=None)

reward_net = reward.make_reward_net(venv)
relabel_reward_fn = functools.partial(
@@ -220,12 +228,19 @@
query_schedule=query_schedule,
)

def save_callback(iteration_num):
def save_callback(iteration_num, traj_generator_num_steps):
if checkpoint_interval > 0 and iteration_num % checkpoint_interval == 0:
save_path = osp.join(
log_dir,
"checkpoints",
f"iter_{iteration_num:04d}_step_{traj_generator_num_steps:08d}",
)
save_checkpoint(
_config,
trainer=main_trainer,
save_path=os.path.join(log_dir, "checkpoints", f"{iteration_num:04d}"),
save_path=save_path,
allow_save_policy=bool(trajectory_path is None),
eval_venv=eval_venv,
)

results = main_trainer.train(
Expand All @@ -235,14 +250,16 @@ def save_callback(iteration_num):
)

if save_preferences:
main_trainer.dataset.save(os.path.join(log_dir, "preferences.pkl"))
main_trainer.dataset.save(osp.join(log_dir, "preferences.pkl"))

# Save final artifacts.
if checkpoint_interval >= 0:
save_checkpoint(
_config,
trainer=main_trainer,
save_path=os.path.join(log_dir, "checkpoints", "final"),
save_path=osp.join(log_dir, "checkpoints", "final"),
allow_save_policy=bool(trajectory_path is None),
eval_venv=eval_venv,
)

# Storing and evaluating policy only useful if we actually generate trajectory data
@@ -255,7 +272,7 @@

def main_console():
observer = FileStorageObserver(
os.path.join("output", "sacred", "train_preference_comparisons"),
osp.join("output", "sacred", "train_preference_comparisons"),
)
train_preference_comparisons_ex.observers.append(observer)
train_preference_comparisons_ex.run_commandline()