diff --git a/configs/config.yaml b/configs/config.yaml new file mode 100644 index 000000000..2446ff153 --- /dev/null +++ b/configs/config.yaml @@ -0,0 +1,12 @@ +defaults: + - _self_ + - env: extended_bin_pack # [bin_pack, cleaner, connector, cvrp, game_2048, graph_coloring, job_shop, knapsack, maze, minesweeper, mmst, multi_cvrp, robot_warehouse, rubiks_cube, snake, sudoku, tetris, tsp] + +agent: a2c # [random, a2c] + +seed: 0 + +logger: + type: terminal # [neptune, tensorboard, terminal] + save_checkpoint: false # [false, true] + name: ${agent}_${env.name} diff --git a/configs/env/bin_pack.yaml b/configs/env/bin_pack.yaml new file mode 100644 index 000000000..198f22080 --- /dev/null +++ b/configs/env/bin_pack.yaml @@ -0,0 +1,37 @@ +name: bin_pack +registered_version: BinPackValueBased-v0 + +env_settings: + reward_fn: ValueBasedDenseReward + generator: RandomValueProblemGenerator + +generator_settings: + max_num_items: 20 + max_num_ems: 40 + split_num_same_items: 2 + + +network: + num_transformer_layers: 2 + transformer_num_heads: 8 + transformer_key_size: 16 + transformer_mlp_units: [512] + +training: + num_epochs: 5 + num_learner_steps_per_epoch: 5 + n_steps: 30 + total_batch_size: 2 + +evaluation: + eval_total_batch_size: 2 + greedy_eval_total_batch_size: 2 + +a2c: + normalize_advantage: False + discount_factor: 1.0 + bootstrapping_factor: 0.95 + l_pg: 1.0 + l_td: 1.0 + l_en: 0.005 + learning_rate: 1e-4 diff --git a/configs/env/constrained_bin_pack.yaml b/configs/env/constrained_bin_pack.yaml new file mode 100644 index 000000000..6a3c8debd --- /dev/null +++ b/configs/env/constrained_bin_pack.yaml @@ -0,0 +1,26 @@ +name: constrained_bin_pack +registered_version: ConstrainedBinPack-v0 + +network: + num_transformer_layers: 2 + transformer_num_heads: 8 + transformer_key_size: 16 + transformer_mlp_units: [512] + +training: + num_epochs: 5 + num_learner_steps_per_epoch: 2 + n_steps: 3 + total_batch_size: 4 + +evaluation: + eval_total_batch_size: 0 + greedy_eval_total_batch_size: 0 +a2c: + normalize_advantage: False + discount_factor: 1.0 + bootstrapping_factor: 0.95 + l_pg: 1.0 + l_td: 1.0 + l_en: 0.005 + learning_rate: 1e-4 diff --git a/configs/env/extended_bin_pack.yaml b/configs/env/extended_bin_pack.yaml new file mode 100644 index 000000000..3922563b5 --- /dev/null +++ b/configs/env/extended_bin_pack.yaml @@ -0,0 +1,42 @@ +name: extended_bin_pack +registered_version: ExtendedBinPack-v0 + +env_settings: + reward_fn: ValueBasedDenseReward + generator: ExtendedTrainingGenerator + is_value_based: True + is_rotation_allowed: True + normalize_dimensions: True + +generator_settings: + max_num_items: 480 + max_num_ems: 80 + mean_item_value: 0 + std_item_value: 1 + min_target_volume: 2 + max_target_volume: 30 + +network: + num_transformer_layers: 2 + transformer_num_heads: 8 + transformer_key_size: 16 + transformer_mlp_units: [512] + +training: + num_epochs: 550 + num_learner_steps_per_epoch: 100 + n_steps: 30 + total_batch_size: 64 + +evaluation: + eval_total_batch_size: 10000 + greedy_eval_total_batch_size: 10000 + +a2c: + normalize_advantage: False + discount_factor: 1.0 + bootstrapping_factor: 0.95 + l_pg: 1.0 + l_td: 1.0 + l_en: 0.005 + learning_rate: 1e-4 diff --git a/jumanji/environments/__init__.py b/jumanji/environments/__init__.py index d69fbbf8e..5462c217c 100644 --- a/jumanji/environments/__init__.py +++ b/jumanji/environments/__init__.py @@ -29,7 +29,7 @@ from jumanji.environments.logic.sliding_tile_puzzle.env import SlidingTilePuzzle from jumanji.environments.logic.sudoku.env import Sudoku from jumanji.environments.packing import bin_pack, flat_pack, job_shop, knapsack, tetris -from jumanji.environments.packing.bin_pack.env import BinPack +from jumanji.environments.packing.bin_pack.env import BinPack, ExtendedBinPack from jumanji.environments.packing.flat_pack.env import FlatPack from jumanji.environments.packing.job_shop.env import JobShop from jumanji.environments.packing.knapsack.env import Knapsack diff --git a/jumanji/environments/packing/bin_pack/__init__.py b/jumanji/environments/packing/bin_pack/__init__.py index cbd2b8d67..738ec0741 100644 --- a/jumanji/environments/packing/bin_pack/__init__.py +++ b/jumanji/environments/packing/bin_pack/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jumanji.environments.packing.bin_pack.env import BinPack +from jumanji.environments.packing.bin_pack.env import BinPack, ExtendedBinPack from jumanji.environments.packing.bin_pack.types import Observation, State diff --git a/jumanji/environments/packing/bin_pack/conftest.py b/jumanji/environments/packing/bin_pack/conftest.py index 33ca6c437..ddb9911eb 100644 --- a/jumanji/environments/packing/bin_pack/conftest.py +++ b/jumanji/environments/packing/bin_pack/conftest.py @@ -18,17 +18,29 @@ import pytest from jumanji import specs -from jumanji.environments.packing.bin_pack.env import BinPack +from jumanji.environments.packing.bin_pack.env import BinPack, ExtendedBinPack from jumanji.environments.packing.bin_pack.generator import ( TWENTY_FOOT_DIMS, + ExtendedRandomGenerator, + ExtendedToyGenerator, Generator, RandomGenerator, ToyGenerator, make_container, ) -from jumanji.environments.packing.bin_pack.reward import DenseReward, SparseReward +from jumanji.environments.packing.bin_pack.reward import ( + DenseReward, + SparseReward, + ValueBasedDenseReward, + ValueBasedSparseReward, +) from jumanji.environments.packing.bin_pack.space import Space -from jumanji.environments.packing.bin_pack.types import Item, Location, State +from jumanji.environments.packing.bin_pack.types import ( + Item, + Location, + State, + ValuedItem, +) class DummyGenerator(Generator): @@ -67,6 +79,7 @@ def __call__(self, key: chex.PRNGKey) -> State: y_len=jnp.array([700, 700, 500], jnp.int32), z_len=jnp.array([900, 900, 600], jnp.int32), ), + nb_items=3, items_mask=jnp.array([True, True, True], bool), items_placed=jnp.array([False, False, False], bool), items_location=jax.tree_util.tree_map( @@ -74,6 +87,361 @@ def __call__(self, key: chex.PRNGKey) -> State: ), action_mask=None, sorted_ems_indexes=jnp.arange(self.max_num_ems, dtype=jnp.int32), + # For non value based optimisation set these to dummy values by default + instance_max_item_value_magnitude=0.0, + instance_total_value=0.0, + # For deterministic instance generators we always set the key to 0. + key=jax.random.PRNGKey(0), + ) + + +class DummyValueGenerator(Generator): + """Dummy instance generator used for testing. It outputs a constant instance with a 20-ft + container and 3 items: two identical items and a different third one to be able to + test item aggregation. + """ + + def __init__(self) -> None: + """Instantiate a dummy `Generator` with 3 items and 10 EMSs maximum.""" + super(DummyValueGenerator, self).__init__( + max_num_items=3, max_num_ems=10, container_dims=TWENTY_FOOT_DIMS + ) + + def __call__(self, key: chex.PRNGKey) -> State: + """Returns a fixed instance with 3 items, 10 EMSs and a 20-ft container. + + Args: + key: random key not used here but kept for consistency with parent signature. + + Returns: + State. + """ + del key + container = make_container(TWENTY_FOOT_DIMS) + return State( + container=container, + ems=jax.tree_util.tree_map( + lambda x: jnp.array([x] + (self.max_num_ems - 1) * [0], jnp.int32), + container, + ), + ems_mask=jnp.array([True] + (self.max_num_ems - 1) * [False], bool), + items=ValuedItem( + # The 1st and 2nd items have the same shape and value. + x_len=jnp.array([1000, 1000, 500], jnp.int32), + y_len=jnp.array([700, 700, 500], jnp.int32), + z_len=jnp.array([900, 900, 600], jnp.int32), + value=jnp.array([2.0, 2.0, 1.5], jnp.float32), + ), + items_mask=jnp.array([True, True, True], bool), + items_placed=jnp.array([False, False, False], bool), + items_location=jax.tree_util.tree_map( + lambda x: jnp.array(3 * [x], jnp.int32), Location(x=0, y=0, z=0) + ), + action_mask=None, + sorted_ems_indexes=jnp.arange(self.max_num_ems, dtype=jnp.int32), + # For non value based optimisation set these to dummy values by default + instance_max_item_value_magnitude=2.0, + instance_total_value=5.5, + # For deterministic instance generators we always set the key to 0. + key=jax.random.PRNGKey(0), + nb_items=3, + ) + + +class DummyExtendedGenerator(DummyGenerator): + """Dummy instance generator used for testing. It outputs a constant instance with a 20-ft + container and 3 items that can take all 6 possible orientations: two identical items and a + different third one. + """ + + def __init__(self) -> None: + """Instantiate a dummy `Generator` with 3 items and 10 EMSs maximum.""" + super(DummyGenerator, self).__init__( + max_num_items=3, max_num_ems=10, container_dims=TWENTY_FOOT_DIMS + ) + + def __call__(self, key: chex.PRNGKey) -> State: + """Returns a fixed instance with 3 items, 10 EMSs and a 20-ft container. + + Args: + key: random key not used here but kept for consistency with parent signature. + + Returns: + State. + """ + del key + container = make_container(TWENTY_FOOT_DIMS) + return State( + container=container, + ems=jax.tree_util.tree_map( + lambda x: jnp.array([x] + (self.max_num_ems - 1) * [0], jnp.int32), + container, + ), + ems_mask=jnp.array([True] + (self.max_num_ems - 1) * [False], bool), + items=ValuedItem( + # The 1st and 2nd items have the same shape. + x_len=jnp.array( + [ + [1000, 1000, 500], + [1000, 1000, 500], + [700, 700, 500], + [700, 700, 500], + [900, 900, 600], + [900, 900, 600], + ], + jnp.int32, + ), + y_len=jnp.array( + [ + [700, 700, 500], + [900, 900, 600], + [1000, 1000, 500], + [900, 900, 600], + [700, 700, 500], + [1000, 1000, 500], + ], + jnp.int32, + ), + z_len=jnp.array( + [ + [900, 900, 600], + [700, 700, 500], + [900, 900, 600], + [1000, 1000, 500], + [1000, 1000, 500], + [700, 700, 500], + ], + jnp.int32, + ), + value=jnp.array( + [ + [2.0, 2.0, 1.5], + [2.0, 2.0, 1.5], + [2.0, 2.0, 1.5], + [2.0, 2.0, 1.5], + [2.0, 2.0, 1.5], + [2.0, 2.0, 1.5], + ], + jnp.float32, + ), + ), + items_mask=jnp.array( + [ + [True, True, True], + [True, True, True], + [True, True, True], + [True, True, True], + [True, True, True], + [True, True, True], + ], + bool, + ), + items_placed=jnp.array( + [ + [False, False, False], + [False, False, False], + [False, False, False], + [False, False, False], + [False, False, False], + [False, False, False], + ], + bool, + ), + items_location=jax.tree_util.tree_map( + lambda x: jnp.array(3 * [x], jnp.int32), Location(x=0, y=0, z=0) + ), + instance_max_item_value_magnitude=2.0, + instance_total_value=5.5, + action_mask=None, + sorted_ems_indexes=jnp.arange(self.max_num_ems, dtype=jnp.int32), + # For deterministic instance generators we always set the key to 0. + key=jax.random.PRNGKey(0), + nb_items=3, + ) + + +class DummyRotationGenerator(DummyGenerator): + """Dummy instance generator used for testing. It outputs a constant instance with a 20-ft + container and 3 items: two identical items and a different third one to be able to + test item aggregation. + """ + + def __init__(self) -> None: + """Instantiate a dummy `Generator` with 3 items and 10 EMSs maximum.""" + super(DummyGenerator, self).__init__( + max_num_items=3, max_num_ems=10, container_dims=TWENTY_FOOT_DIMS + ) + + def __call__(self, key: chex.PRNGKey) -> State: + """Returns a fixed instance with 3 items, 10 EMSs and a 20-ft container. + + Args: + key: random key not used here but kept for consistency with parent signature. + + Returns: + State. + """ + del key + container = make_container(TWENTY_FOOT_DIMS) + return State( + container=container, + ems=jax.tree_util.tree_map( + lambda x: jnp.array([x] + (self.max_num_ems - 1) * [0], jnp.int32), + container, + ), + ems_mask=jnp.array([True] + (self.max_num_ems - 1) * [False], bool), + items=Item( + # The 1st and 2nd items have the same shape. + x_len=jnp.array( + [ + [1000, 1000, 500], + [1000, 1000, 500], + [700, 700, 500], + [700, 700, 500], + [900, 900, 600], + [900, 900, 600], + ], + jnp.int32, + ), + y_len=jnp.array( + [ + [700, 700, 500], + [900, 900, 600], + [1000, 1000, 500], + [900, 900, 600], + [700, 700, 500], + [1000, 1000, 500], + ], + jnp.int32, + ), + z_len=jnp.array( + [ + [900, 900, 600], + [700, 700, 500], + [900, 900, 600], + [1000, 1000, 500], + [1000, 1000, 500], + [700, 700, 500], + ], + jnp.int32, + ), + ), + items_mask=jnp.array( + [ + [True, True, True], + [True, True, True], + [True, True, True], + [True, True, True], + [True, True, True], + [True, True, True], + ], + bool, + ), + items_placed=jnp.array( + [ + [False, False, False], + [False, False, False], + [False, False, False], + [False, False, False], + [False, False, False], + [False, False, False], + ], + bool, + ), + items_location=jax.tree_util.tree_map( + lambda x: jnp.array(3 * [x], jnp.int32), Location(x=0, y=0, z=0) + ), + instance_max_item_value_magnitude=0, + instance_total_value=0, + action_mask=None, + sorted_ems_indexes=jnp.arange(self.max_num_ems, dtype=jnp.int32), + # For deterministic instance generators we always set the key to 0. + key=jax.random.PRNGKey(0), + nb_items=3, + ) + + +class FullSupportDummyGenerator(Generator): + """Dummy instance generator used for testing. It outputs a constant instance with a 20-ft + container and 11 items: 10 identical small items of size + (container_length/10, container_width, 300) and one big item of size + (container_length, container_width, 1900). + This instance is used to test the full support constraint by forcing the agent to start by + placing one of the small items. Using this instance allows us to test both that the agent isn't + able to place items if they're not fully supported, and make sure that + the merger of ems is correct. + """ + + def __init__(self) -> None: + """Instantiate a dummy `Generator` with 3 items and 10 EMSs maximum.""" + super(FullSupportDummyGenerator, self).__init__( + max_num_items=11, max_num_ems=40, container_dims=TWENTY_FOOT_DIMS + ) + + def __call__(self, key: chex.PRNGKey) -> State: + """Returns a fixed instance with 3 items, 10 EMSs and a 20-ft container. + + Args: + key: random key not used here but kept for consistency with parent signature. + + Returns: + State. + """ + del key + container = make_container(TWENTY_FOOT_DIMS) + + return State( + container=container, + instance_max_item_value_magnitude=0, + instance_total_value=0, + ems=jax.tree_util.tree_map( + lambda x: jnp.array([x] + (self.max_num_ems - 1) * [0], jnp.int32), + container, + ), + ems_mask=jnp.array([True] + (self.max_num_ems - 1) * [False], bool), + items=Item( + x_len=jnp.array([container.x2] + 10 * [container.x2 / 10], jnp.int32), + y_len=jnp.array(11 * [container.y2], jnp.int32), + z_len=jnp.array([1900] + 10 * [300], jnp.int32), + ), + items_mask=jnp.array( + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + ], + bool, + ), + items_placed=jnp.array( + [ + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + bool, + ), + items_location=jax.tree_util.tree_map( + lambda loc: jnp.array(11 * [loc], jnp.int32), Location(x=0, y=0, z=0) + ), + nb_items=11, + action_mask=None, + sorted_ems_indexes=jnp.arange(self.max_num_ems, dtype=jnp.int32), # For deterministic instance generators we always set the key to 0. key=jax.random.PRNGKey(0), ) @@ -84,6 +452,16 @@ def dummy_generator() -> DummyGenerator: return DummyGenerator() +@pytest.fixture +def dummy_rotation_generator() -> DummyRotationGenerator: + return DummyRotationGenerator() + + +@pytest.fixture +def dummy_extended_generator() -> DummyExtendedGenerator: + return DummyExtendedGenerator() + + @pytest.fixture def toy_generator() -> ToyGenerator: return ToyGenerator() @@ -95,6 +473,19 @@ def random_generator() -> RandomGenerator: return RandomGenerator(max_num_items=20, max_num_ems=80) +@pytest.fixture +def rotation_toy_generator() -> ExtendedToyGenerator: + return ExtendedToyGenerator() + + +@pytest.fixture +def rotation_random_generator() -> ExtendedRandomGenerator: + """Returns a `RandomGenerator` with up to 20 items and that can handle 80 EMSs.""" + return ExtendedRandomGenerator( + max_num_items=20, max_num_ems=80, is_rotation_allowed=True, is_value_based=False + ) + + @pytest.fixture def dummy_state(dummy_generator: DummyGenerator) -> State: state = dummy_generator(key=jax.random.PRNGKey(0)) @@ -104,11 +495,52 @@ def dummy_state(dummy_generator: DummyGenerator) -> State: return state +@pytest.fixture +def dummy_rotation_state( + dummy_rotation_generator: DummyRotationGenerator, +) -> State: + state = dummy_rotation_generator(key=jax.random.PRNGKey(0)) + num_ems = dummy_rotation_generator.max_num_ems + num_items = dummy_rotation_generator.max_num_items + state.action_mask = jnp.ones((6, num_ems, num_items), bool) + return state + + @pytest.fixture def bin_pack(dummy_generator: DummyGenerator) -> BinPack: return BinPack(generator=dummy_generator, obs_num_ems=5) +@pytest.fixture() +def rotation_bin_pack( + dummy_rotation_generator: DummyRotationGenerator, +) -> ExtendedBinPack: + """ + Bin pack environment where the items can be rotated. + """ + return ExtendedBinPack( + generator=dummy_rotation_generator, + obs_num_ems=5, + is_rotation_allowed=True, + is_value_based=False, + ) + + +@pytest.fixture() +def extended_bin_pack( + dummy_extended_generator: DummyExtendedGenerator, +) -> ExtendedBinPack: + """ + Bin pack environment where the items have a value and can be rotated. + """ + return ExtendedBinPack( + generator=dummy_extended_generator, + obs_num_ems=5, + is_rotation_allowed=True, + is_value_based=True, + ) + + @pytest.fixture def obs_spec(bin_pack: BinPack) -> specs.Spec: return bin_pack.observation_spec @@ -135,6 +567,17 @@ def bin_pack_dense_reward( ) +@pytest.fixture +def bin_pack_dense_value_reward() -> BinPack: + return ExtendedBinPack( + generator=DummyValueGenerator(), + obs_num_ems=5, + reward_fn=ValueBasedDenseReward(), + is_value_based=True, + is_rotation_allowed=False, + ) + + @pytest.fixture def sparse_reward() -> SparseReward: return SparseReward() @@ -149,3 +592,28 @@ def bin_pack_sparse_reward( obs_num_ems=5, reward_fn=sparse_reward, ) + + +@pytest.fixture +def bin_pack_sparse_value_reward() -> BinPack: + return ExtendedBinPack( + generator=DummyValueGenerator(), + obs_num_ems=5, + reward_fn=ValueBasedSparseReward(), + is_rotation_allowed=False, + is_value_based=True, + ) + + +@pytest.fixture +def full_support_dummy_generator() -> FullSupportDummyGenerator: + return FullSupportDummyGenerator() + + +@pytest.fixture +def full_support_bin_pack( + full_support_dummy_generator: FullSupportDummyGenerator, +) -> BinPack: + return BinPack( + generator=full_support_dummy_generator, full_support=True, debug=True + ) diff --git a/jumanji/environments/packing/bin_pack/env.py b/jumanji/environments/packing/bin_pack/env.py index 4f410af62..1d21e06f0 100644 --- a/jumanji/environments/packing/bin_pack/env.py +++ b/jumanji/environments/packing/bin_pack/env.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools import itertools from functools import cached_property -from typing import Dict, Optional, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple, cast import chex import jax @@ -24,21 +25,31 @@ from jumanji import specs from jumanji.env import Environment -from jumanji.environments.packing.bin_pack.generator import Generator, RandomGenerator +from jumanji.environments.packing.bin_pack.generator import ( + ExtendedRandomGenerator, + Generator, + RandomGenerator, +) from jumanji.environments.packing.bin_pack.reward import DenseReward, RewardFn from jumanji.environments.packing.bin_pack.space import Space from jumanji.environments.packing.bin_pack.types import ( EMS, Item, + ItemType, Location, Observation, State, + ValuedItem, item_fits_in_item, item_from_space, item_volume, space_from_item_and_location, + valued_item_from_space_and_max_value, +) +from jumanji.environments.packing.bin_pack.viewer import ( + BinPackViewer, + ExtendedBinPackViewer, ) -from jumanji.environments.packing.bin_pack.viewer import BinPackViewer from jumanji.tree_utils import tree_add_element, tree_slice from jumanji.types import TimeStep, restart, termination, transition from jumanji.viewer import Viewer @@ -62,16 +73,16 @@ class BinPack(Environment[State, specs.MultiDiscreteArray, Observation]): - ems_mask: jax array (bool) of shape (obs_num_ems,) indicates the EMSs that are valid. - items: `Item` tree of jax arrays (float if `normalize_dimensions` else int32) each of - shape (max_num_items,), + shape (max_num_items, 6), characteristics of all items for this instance. - - items_mask: jax array (bool) of shape (max_num_items,) + - items_mask: jax array (bool) of shape (max_num_items, 6) indicates the items that are valid. - - items_placed: jax array (bool) of shape (max_num_items,) + - items_placed: jax array (bool) of shape (max_num_items, 6) indicates the items that have been placed so far. - - action_mask: jax array (bool) of shape (obs_num_ems, max_num_items) + - action_mask: jax array (bool) of shape (obs_num_ems, max_num_items, 6) mask of the joint action space: `True` if the action (ems_id, item_id) is valid. - - action: `MultiDiscreteArray` (int32) of shape (obs_num_ems, max_num_items). + - action: `MultiDiscreteArray` (int32) of shape (obs_num_ems, max_num_items, 6). - ems_id: int between 0 and obs_num_ems - 1 (included). - item_id: int between 0 and max_num_items - 1 (included). @@ -121,6 +132,7 @@ def __init__( normalize_dimensions: bool = True, debug: bool = False, viewer: Optional[Viewer[State]] = None, + full_support: Optional[bool] = False, ): """Instantiates a `BinPack` environment. @@ -146,6 +158,9 @@ def __init__( this metric slows down the environment. Default to False. viewer: `Viewer` used for rendering. Defaults to `BinPackViewer` with "human" render mode. + full_support: if full_support is true a box can only be placed on top of a set of boxes + only if the bottom face of the box being placed is fully supported by the top face + of the set of supporting boxes. """ self.generator = generator or RandomGenerator( max_num_items=20, @@ -158,6 +173,7 @@ def __init__( super().__init__() self._viewer = viewer or BinPackViewer("BinPack", render_mode="human") self.debug = debug + self.full_support = full_support def __repr__(self) -> str: return "\n".join( @@ -189,7 +205,8 @@ def observation_spec(self) -> specs.Spec[Observation]: - if normalize_dimensions: tree of BoundedArray (float) of shape (max_num_items,). - else: - tree of BoundedArray (int32) of shape (max_num_items,). + tree of BoundedArray (int32) of shape (max_num_items,) though float is used for + values if they are valued items. - items_mask: BoundedArray (bool) of shape (max_num_items,). - items_placed: BoundedArray (bool) of shape (max_num_items,). - action_mask: BoundedArray (bool) of shape (obs_num_ems, max_num_items). @@ -226,7 +243,7 @@ def observation_spec(self) -> specs.Spec[Observation]: ) for axis in ["x_len", "y_len", "z_len"] } - items = specs.Spec(Item, "ItemsSpec", **items_dict) + items = specs.Spec(Item, "ItemsSpec", **items_dict) # type: ignore items_mask = specs.BoundedArray( (max_num_items,), bool, False, True, "items_mask" ) @@ -251,6 +268,55 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) + def _items_dict_for_valued_items(self, max_num_items: int, max_dim: int) -> Dict: + """Set the items_dict specs to the correct bounded array for valued items depending + on whether features are to be normalized or not. + + Args: + max_num_items: the maximum number of items that can be in an instance. + max_dim: The maximum dimension in this given instance. + + Returns: + A dictionary with string keys of the item features and specs BoundedArray as values. + """ + items_dict = self._items_dict_for_non_valued_items(max_num_items, max_dim) + if self.normalize_dimensions: + items_dict["value"] = specs.BoundedArray( + (self.generator.max_num_items,), float, -1.0, 1.0, "value" + ) + else: + items_dict["value"] = specs.BoundedArray( + (self.generator.max_num_items,), float, -jnp.inf, jnp.inf, "value" + ) + return items_dict + + def _items_dict_for_non_valued_items( + self, max_num_items: int, max_dim: int + ) -> Dict: + """Set the items_dict specs to the correct bounded array for non valued items depending + on whether dimensions are to be normalized or not. + + Args: + max_num_items: the maximum number of items that can be in an instance. + max_dim: The maximum dimension in this given instance. + + Returns: + A dictionary with string keys of the item features and specs BoundedArray as values. + """ + if self.normalize_dimensions: + return { + f"{axis}": specs.BoundedArray( + (self.generator.max_num_items,), float, 0.0, 1.0, axis + ) + for axis in ["x_len", "y_len", "z_len"] + } + return { + f"{axis}": specs.BoundedArray( + (self.generator.max_num_items,), jnp.int32, 0, max_dim, axis + ) + for axis in ["x_len", "y_len", "z_len"] + } + @cached_property def action_spec(self) -> specs.MultiDiscreteArray: """Specifications of the action expected by the `BinPack` environment. @@ -336,7 +402,6 @@ def step( # Make the observation. next_state, observation, extras = self._make_observation_and_extras(next_state) - done = ~jnp.any(next_state.action_mask) | ~action_is_valid reward = self.reward_fn(state, action, next_state, action_is_valid, done) @@ -448,7 +513,7 @@ def _get_extras(self, state: State) -> Dict: items_volume = jnp.sum(item_volume(state.items) * state.items_placed) volume_utilization = items_volume / state.container.volume() packed_items = jnp.sum(state.items_placed) - ratio_packed_items = packed_items / jnp.sum(state.items_mask) + ratio_packed_items = packed_items / state.nb_items active_ems = jnp.sum(state.ems_mask) extras = { "volume_utilization": volume_utilization, @@ -459,7 +524,7 @@ def _get_extras(self, state: State) -> Dict: return extras def _normalize_ems_and_items( - self, state: State, obs_ems: EMS, items: Item + self, state: State, obs_ems: EMS, items: ItemType ) -> Tuple[EMS, Item]: """Normalize the EMSs and items in the observation. Each dimension is divided by the container length so that they are all between 0.0 and 1.0. Hence, the ratio is not kept. @@ -506,7 +571,7 @@ def _get_action_mask( self, obs_ems: EMS, obs_ems_mask: chex.Array, - items: Item, + items: ItemType, items_mask: chex.Array, items_placed: chex.Array, ) -> chex.Array: @@ -526,7 +591,7 @@ def _get_action_mask( def is_action_allowed( ems: EMS, ems_mask: chex.Array, - item: Item, + item: ItemType, item_mask: chex.Array, item_placed: chex.Array, ) -> chex.Array: @@ -577,6 +642,8 @@ def _update_ems(self, state: State, item_id: chex.Numeric) -> State: state.ems = new_ems state.ems_mask = new_ems_mask + if self.full_support: + self.merge_same_height_ems(state) return state def _get_intersections_dict( @@ -587,9 +654,9 @@ def _get_intersections_dict( """ # Create new EMSs from EMSs that intersect the new item intersections_ems_dict: Dict[str, Space] = { - f"{axis}_{direction}": item_space.hyperplane(axis, direction).intersection( - state.ems - ) + f"{axis}_{direction}": item_space.hyperplane( + axis, direction, self.full_support + ).intersection(state.ems) for axis, direction in itertools.product( ["x", "y", "z"], ["lower", "upper"] ) @@ -664,6 +731,282 @@ def _get_intersections_dict( intersections_mask_dict[direction] &= ~to_remove return intersections_ems_dict, intersections_mask_dict + def merge_same_height_ems(self, state: State) -> None: + """ + Function that takes the state as input and merges all the EMS that are contiguous and + have the same height into one bigger EMS. + + """ + zero_vol_ems = Space(x1=0, x2=0, y1=0, y2=0, z1=0, z2=0) + max_nb_ems = len(state.ems.x1) + + def merge(direction: int, space1: "Space", space2: "Space") -> "Space": + """ + Function that takes two spaces and a direction and returns a merged space from + those two spaces along that direction. + + Args: + direction: 1: merge along y axis, 2: merge along the x axis. + Returns: + Space: Space obtained from the merger of the two input spaces. + """ + if direction == 1: + x1 = space1.x1 + x2 = space1.x2 + y1 = jnp.minimum(space1.y1, space2.y1) + y2 = jnp.maximum(space1.y2, space2.y2) + elif direction == 2: + x1 = jnp.minimum(space1.x1, space2.x1) + x2 = jnp.maximum(space1.x2, space2.x2) + y1 = space1.y1 + y2 = space1.y2 + return Space(x1=x1, x2=x2, y1=y1, y2=y2, z1=space1.z1, z2=space1.z2) + + def compute_merge_mask(args: Tuple) -> Tuple[chex.Array, chex.Array]: + """Computes a boolean matrix where element i, j is True if ems i and j can be merged. + + Two EMS can be merged if they start at the same z1, either have the same width, start + at the same y1 and they're continuous on the x axis of the container, or have the same + length, start at the same x1 and they're continuous on the y axis of the container. + Note, we do not verify the z2 values since ems merging in this way is only done in the + case of a full support constraint and z2 is always equal to container.height in that + case, ie. there are no overhanging items. + + Args: + ems_arr: Array EMS representing all the current EMSs of the state. + ems_mask: Mask over all the EMSs of the state. + + + Returns: + mask: 2D boolean array where mask[i,j] is True if tree_slice(EMS,i) and + tree_slice(EMS, j) have the same height and can be merged. + same_x : 2D boolean array where same_x[i,j] is True if + tree_slice(EMS,i).x1 == tree_slice(EMS,j).x1 and + tree_slice(EMS,i).x2 == tree_slice(EMS,j).x2. + If mask[i,j]==True and same_x[i,j] == True then we can merge these EMSs along + the y axis of the container and if False we can merge them along the x axis of + the container. + """ + + def isclose_matrix(a: chex.Array, b: chex.Array) -> chex.Array: + """ + This function takes two 1D vectors and returns a 2D boolean matrix where + (i,j) = True if a[i] is close to b[j]. + """ + return jnp.isclose( + jnp.expand_dims(a, -1) - jnp.expand_dims(b, -1).transpose(), 0 + ) + + ems_arr, ems_mask = args + + same_y = isclose_matrix(ems_arr.y1, ems_arr.y1) & isclose_matrix( + ems_arr.y2, ems_arr.y2 + ) + same_x = isclose_matrix(ems_arr.x1, ems_arr.x1) & isclose_matrix( + ems_arr.x2, ems_arr.x2 + ) + side_by_side_x = isclose_matrix(ems_arr.x1, ems_arr.x2) | isclose_matrix( + ems_arr.x2, ems_arr.x1 + ) + side_by_side_y = isclose_matrix(ems_arr.y1, ems_arr.y2) | isclose_matrix( + ems_arr.y2, ems_arr.y1 + ) + + # The ems have the same z1 and the emss exist. + mask = jnp.triu( + isclose_matrix(ems_arr.z1, ems_arr.z1) # [nb_ems, nb_ems] + & ems_mask # [nb_ems, ] + ) + # Can be merged along the y or x axis. + mask = mask & ( + same_x & side_by_side_y | same_y & side_by_side_x + ) # [nb_ems, nb_ems] (but only use the upper triangular part of the matrix). + return mask, same_x + + def merge_if_possible( + ems_arr: EMS, + flat_can_merge_ems: chex.Array, + same_x: chex.Array, + is_merged_ems: chex.Array, + mask_ind: chex.Array, + ) -> Tuple[chex.Array, chex.Array]: + """ + Function that merges two EMS if the merge_mask allows it. + Args: + ems_arr : initial array of EMSs. + flat_can_merge_ems: 1D array obtained from flattening a 2D array returned by the + compute_merge_mask function. flat_can_merge_ems[i] = True means that the EMSs + i//max_nb_ems and i%max_nb_ems can be merged. + same_x: 1D boolean array where same_x[i] is True if ( + tree_slice(EMS,i//max_nb_ems).x1 == tree_slice(EMS,i%max_nb_ems).x1 and + tree_slice(EMS,i//max_nb_ems).x2 == tree_slice(EMS,i%max_nb_ems).x2). + If flat_can_merge_ems[i]==True and same_x[i] == True then we can merge these + EMSs along the y axis of the container and if False we can merge them along the + x axis of the container. + is_merged_ems: 2D triangular boolean matrix where is_merged_ems[i,j] is False + if the i-th EMS and the j-th EMS are already merged. + mask_ind: the index of the element of the mask that the function will examine. + Returns: + Triangular boolean matrix is_merged_ems, and Space resulting from the merger of + the EMS at mask_ind// max_nb_ems and + mask_ind% max_nb_ems. + """ + + # Get the indeces of the 2 emss corresponding to mask_ind in the the non-flattened + # can_merge_ems matrix. + row = mask_ind // max_nb_ems + column = mask_ind % max_nb_ems + merged_ems = jax.lax.cond( + # make sure we can merge the two EMS located at mask_ind// max_nb_ems and + # mask_ind% max_nb_ems, and make sure that we haven't merged the ems located at + # row before this. + flat_can_merge_ems[mask_ind] & ~jnp.any(is_merged_ems[row]), + lambda _: jax.lax.cond( + same_x[mask_ind], + functools.partial(merge, 1), + functools.partial(merge, 2), + *( + tree_slice(ems_arr, row), + tree_slice(ems_arr, column), + ), + ), + lambda *_: zero_vol_ems, + (), + ) + is_merged_ems = is_merged_ems.at[row, column].set(~merged_ems.is_empty()) + return is_merged_ems, merged_ems + + def delete_merged_ems_and_add_new_ems( + flat_is_merged_ems: chex.Array, + flat_merged_ems: chex.Array, + merged_ems_indices: Tuple[chex.Array, chex.Array], + new_ems_and_mask: Tuple[EMS, chex.Array], + merged_ems_ind: Tuple[chex.Array], + ) -> Tuple[EMS, Any]: + """ + Function that takes the a list of EMS and the indices of two merged EMS + and removes those EMS from that list and puts in that list the EMS resulting + from merging those two EMS. + Args: + flat_is_merged_ems: 1D array obtained from the flattening of the 2D array + is_merged_ems. flat_is_merged_ems[i] == True if the two ems at + i//max_nb_ems and i%max_nb_ems have been merged. + flat_merged_ems is a Tree of Spaces of length max_nb_ems**2 and contains the + newly created EMSs from merging the initial EMS. + if flat_is_merged_ems[i] then flat_merged_ems[i] = EMS obtained from merging + the two EMSs at i//max_nb_ems and i%max_nb_ems. + merged_ems_indices: Tuple of Arrays, where merged_ems_indices[:,merged_ems_ind] + contains the indices in the initial ems array of the two ems that were + merged in order to create the EMS at flat_merged_ems[ + merged_ems_indices[0][merged_ems_ind] * max_nb_ems + + merged_ems_indices[1][merged_ems_ind] + ] + new_ems_and_mask: Tuple containing the Initial Array of EMSs and the mask + associated to it. + merged_ems_ind: indices of two merged EMS. + + Returns: + List of EMS where the merged EMS were deleted and the + newly created EMS added. + """ + ems_arr, ems_mask = new_ems_and_mask + new_ems_and_mask = jax.lax.cond( + # If the current EMS is the product of merging two EMSs together. + flat_is_merged_ems[merged_ems_ind], + lambda ems_arr, ems_mask: ( + tree_add_element( + # Add the new EMS at the place of the first EMS used to merge. + tree_add_element( + ems_arr, + merged_ems_indices[0][merged_ems_ind], + tree_slice( + flat_merged_ems, + # The Resulting EMS is at this index because the merged_ems + # array is flat and has a shape of max_nb_ems ** 2 + # (but only the first max_nb_ems make sense). + merged_ems_indices[0][merged_ems_ind] * max_nb_ems + + merged_ems_indices[1][merged_ems_ind], + ), + ), + # Add an empty EMS at the place of the second EMS used to merge. + merged_ems_indices[1][merged_ems_ind], + zero_vol_ems, + ), + # Set the mask to True at the index of the newly added merged EMS + # and to False at the index of the empty EMS. + ems_mask.at[merged_ems_indices[0][merged_ems_ind]] + .set(True) + .at[merged_ems_indices[1][merged_ems_ind]] + .set(False), + ), + lambda *_: _, + ems_arr, + ems_mask, + ) + return new_ems_and_mask, None + + def merge_ems(args: Tuple[EMS, chex.Array]) -> Tuple[EMS, chex.Array]: + """Function that merges all the ems it can. + + Args: + args: Tuple containing the EMS and the EMS mask arrays. + + Returns: + Updated EMS and EMS mask arrays after merging all the contiguous EMS having the same + z1. + """ + ems_arr, ems_mask = args + # can_merge_ems = True if emss i and j can be merged. + can_merge_ems, same_x = compute_merge_mask((ems_arr, ems_mask)) + + is_merged_ems = jnp.full_like(can_merge_ems, False) + + flat_same_x = same_x.flatten() + flat_can_merge_ems = can_merge_ems.flatten() + # Construct new emss from merging previous ones. + # - is_merged_ems[i,j] = True if the two ems at i and j have been merged. + # - flat_merged_ems is a Tree of Spaces of length max_nb_ems**2 and contains the newly + # created EMSs from merging the initial EMS. + is_merged_ems, flat_merged_ems = jax.lax.scan( + functools.partial( + merge_if_possible, ems_arr, flat_can_merge_ems, flat_same_x + ), + is_merged_ems, + jnp.arange(len(flat_can_merge_ems)), + ) + flat_is_merged_ems = is_merged_ems.flatten() + # Make sure that the indices of the merged EMS are at the top of this list. + keys = flat_is_merged_ems.argsort()[::-1] + flat_is_merged_ems = flat_is_merged_ems.sort()[::-1] + # Create a list of tuples that are used to access the list of merged EMS. + _, merged_ems_indices = jax.lax.scan( + lambda _, key: (_, (key // max_nb_ems, key % max_nb_ems)), + jnp.arange(max_nb_ems), + keys, + ) + # Go through the first max_nb_ems elements of the list of merged EMS since we know that + # it can contain at most number of initial ems given to this function. + # Loop through the original ems tree and delete all the EMSs that were merged and add + # the newly created EMSs. + (new_ems, new_ems_mask), _ = jax.lax.scan( + functools.partial( + delete_merged_ems_and_add_new_ems, + flat_is_merged_ems, + flat_merged_ems, + merged_ems_indices, + ), + (ems_arr, ems_mask), + jnp.arange(max_nb_ems), + ) + + return new_ems, new_ems_mask + + state.ems, state.ems_mask = jax.lax.while_loop( + lambda ems_and_mask: jnp.any(compute_merge_mask(ems_and_mask)[0]), + merge_ems, + (state.ems, state.ems_mask), + ) + def _add_ems( self, intersection_ems: EMS, @@ -703,3 +1046,553 @@ def inclusion_check(ems: EMS, ems_mask: chex.Array) -> chex.Array: add_one_ems, (ems, ems_mask), (intersection_ems, intersection_mask) ) return ems, ems_mask + + +class ExtendedBinPack(BinPack): + def __init__( + self, + is_rotation_allowed: bool, + is_value_based: bool, + generator: Optional[Generator] = None, + obs_num_ems: int = 40, + reward_fn: Optional[RewardFn] = None, + normalize_dimensions: bool = True, + debug: bool = False, + viewer: Optional[Viewer[State]] = None, + mean_item_value: Optional[float] = None, + std_item_value: Optional[float] = None, + full_support: Optional[bool] = False, + ): + generator = generator or ExtendedRandomGenerator( + is_rotation_allowed=is_rotation_allowed, + is_value_based=is_value_based, + max_num_items=20, + max_num_ems=40, + mean_item_value=mean_item_value, + std_item_value=std_item_value, + ) + viewer = viewer or ExtendedBinPackViewer( + "ExtendedBinPack", + is_rotation_allowed=is_rotation_allowed, + render_mode="human", + ) + super().__init__( + generator, + obs_num_ems, + reward_fn, + normalize_dimensions, + debug, + viewer, + full_support, + ) + self.is_value_based = is_value_based + self.is_rotation_allowed = is_rotation_allowed + + def observation_spec(self) -> specs.Spec[Observation]: + """Specifications of the observation of the `BinPack` environment. + + Returns: + Spec for the `Observation` whose fields are: + - ems: + - if normalize_dimensions: + tree of BoundedArray (float) of shape (obs_num_ems,). + - else: + tree of BoundedArray (int32) of shape (obs_num_ems,). + - ems_mask: BoundedArray (bool) of shape (obs_num_ems,). + - items: + - if normalize_dimensions: + tree of BoundedArray (float) of shape (max_num_items,). + - else: + tree of BoundedArray (int32) of shape (max_num_items,). + - items_mask: BoundedArray (bool) of shape (max_num_items,). + - items_placed: BoundedArray (bool) of shape (max_num_items,). + - action_mask: BoundedArray (bool) of shape (obs_num_ems, max_num_items). + """ + obs_num_ems = self.obs_num_ems + max_num_items = self.generator.max_num_items + max_dim = max(self.generator.container_dims) + + if self.is_value_based: + if self.is_rotation_allowed: + items_dict = self._items_dict_for_rotated_valued_items( + max_num_items, max_dim + ) + else: + items_dict = self._items_dict_for_valued_items(max_num_items, max_dim) + else: + if self.is_rotation_allowed: + items_dict = self._items_dict_for_rotated_items(max_num_items, max_dim) + else: + items_dict = self._items_dict_for_non_valued_items( + max_num_items, max_dim + ) + + items = specs.Spec( + ValuedItem if self.is_value_based else Item, "ItemsSpec", **items_dict + ) + nb_orientations = 1 + 5 * self.is_rotation_allowed + items_mask = specs.BoundedArray( + (nb_orientations * max_num_items,), bool, False, True, "items_mask" + ) + items_placed = specs.BoundedArray( + (nb_orientations * max_num_items,), bool, False, True, "items_placed" + ) + action_mask = specs.BoundedArray( + ( + obs_num_ems, + nb_orientations * max_num_items, + ), + bool, + False, + True, + "action_mask", + ) + return ( + super() + .observation_spec() + .replace( + items=items, + items_mask=items_mask, + items_placed=items_placed, + action_mask=action_mask, + ) + ) + + def _items_dict_for_valued_items(self, max_num_items: int, max_dim: int) -> Dict: + """Set the items_dict specs to the correct bounded array for valued items depending + on whether features are to be normalized or not. + + Args: + max_num_items: the maximum number of items that can be in an instance. + max_dim: The maximum dimension in this given instance. + + Returns: + A dictionary with string keys of the item features and specs BoundedArray as values. + """ + items_dict = self._items_dict_for_non_valued_items(max_num_items, max_dim) + if self.normalize_dimensions: + items_dict["value"] = specs.BoundedArray( + (max_num_items,), float, -1.0, 1.0, "value" + ) + else: + items_dict["value"] = specs.BoundedArray( + (max_num_items,), float, -jnp.inf, jnp.inf, "value" + ) + return items_dict + + def _items_dict_for_non_valued_items( + self, max_num_items: int, max_dim: int + ) -> Dict: + """Set the items_dict specs to the correct bounded array for non valued items depending + on whether dimensions are to be normalized or not. + + Args: + max_num_items: the maximum number of items that can be in an instance. + max_dim: The maximum dimension in this given instance. + + Returns: + A dictionary with string keys of the item features and specs BoundedArray as values. + """ + if self.normalize_dimensions: + return { + f"{axis}": specs.BoundedArray((max_num_items,), float, 0.0, 1.0, axis) + for axis in ["x_len", "y_len", "z_len"] + } + return { + f"{axis}": specs.BoundedArray((max_num_items,), jnp.int32, 0, max_dim, axis) + for axis in ["x_len", "y_len", "z_len"] + } + + def _items_dict_for_rotated_items(self, max_num_items: int, max_dim: int) -> Dict: + """Set the items_dict specs to the correct bounded array for items depending + on whether features are to be normalized or not. + + Args: + max_num_items: the maximum number of items that can be in an instance. + max_dim: The maximum dimension in this given instance. + + Returns: + A dictionary with string keys of the item features and specs BoundedArray as values. + """ + if self.normalize_dimensions: + items_dict = { + f"{axis}": specs.BoundedArray( + (6 * max_num_items,), float, 0.0, 1.0, axis + ) + for axis in ["x_len", "y_len", "z_len"] + } + + else: + items_dict = { + f"{axis}": specs.BoundedArray( + (6 * max_num_items,), jnp.int32, 0, max_dim, axis + ) + for axis in ["x_len", "y_len", "z_len"] + } + return items_dict + + def _items_dict_for_rotated_valued_items( + self, max_num_items: int, max_dim: int + ) -> Dict: + """Set the items_dict specs to the correct bounded array for valued items depending + on whether features are to be normalized or not. + + Args: + max_num_items: the maximum number of items that can be in an instance. + max_dim: The maximum dimension in this given instance. + + Returns: + A dictionary with string keys of the item features and specs BoundedArray as values. + """ + items_dict = self._items_dict_for_rotated_items(max_num_items, max_dim) + if self.normalize_dimensions: + items_dict["value"] = specs.BoundedArray( + (6 * max_num_items,), float, -1.0, 1.0, "value" + ) + else: + items_dict["value"] = specs.BoundedArray( + (6 * max_num_items,), float, -jnp.inf, jnp.inf, "value" + ) + + return items_dict + + def action_spec(self) -> specs.MultiDiscreteArray: + """Specifications of the action expected by the `BinPack` environment. + + Returns: + MultiDiscreteArray (int32) of shape (obs_num_ems, max_num_items). + - ems_id: int between 0 and obs_num_ems - 1 (included). + - item_id: int between 0 and max_num_items - 1 (included). + """ + if self.is_rotation_allowed: + num_values = jnp.array( + [6, self.obs_num_ems, self.generator.max_num_items], jnp.int32 + ) + else: + num_values = jnp.array( + [self.obs_num_ems, self.generator.max_num_items], jnp.int32 + ) + return specs.MultiDiscreteArray(num_values=num_values, name="action") + + def step( + self, state: State, action: chex.Array + ) -> Tuple[State, TimeStep[Observation]]: + """Run one timestep of the environment's dynamics. If the action is invalid, the state + is not updated, i.e. the action is not taken, and the episode terminates. + + Args: + state: `State` object containing the data of the current instance. + action: jax array (int32) of shape (2,): (ems_id, item_id). This means placing the given + item at the location of the given EMS. If the action is not valid, the flag + `invalid_action` will be set to True in `timestep.extras` and the episode + terminates. + + Returns: + state: `State` object corresponding to the next state of the environment. + timestep: `TimeStep` object corresponding to the timestep returned by the environment. + Also contains metrics in the `extras` field: + - volume_utilization: utilization (in [0, 1]) of the container. + - packed_items: number of items that are packed in the container. + - ratio_packed_items: ratio (in [0, 1]) of items that are packed in the container. + - active_ems: number of EMSs in the current instance. + - invalid_action: True if the action that was just taken was invalid. + - invalid_ems_from_env (optional): True if the environment produced an EMS that was + invalid. Only available in debug mode. + """ + action_is_valid = state.action_mask[tuple(action)] # type: ignore + orientation, obs_ems_id, item_id = None, None, None + if self.is_rotation_allowed: + orientation, obs_ems_id, item_id = action + else: + obs_ems_id, item_id = action + ems_id = state.sorted_ems_indexes[obs_ems_id] + + # Pack the item if the provided action is valid. + next_state = jax.lax.cond( + action_is_valid, + lambda s: self._pack_item(s, ems_id, item_id, orientation), + lambda s: s, + state, + ) + # Make the observation. + next_state, observation, extras = self._make_observation_and_extras(next_state) + + done = ~jnp.any(next_state.action_mask) | ~action_is_valid + reward = self.reward_fn(state, action, next_state, action_is_valid, done) + + extras.update(invalid_action=~action_is_valid) + + if self.debug: + ems_are_all_valid = self._ems_are_all_valid(next_state) + extras.update(invalid_ems_from_env=~ems_are_all_valid) + timestep = jax.lax.cond( + done, + lambda: termination( + reward=reward, + observation=observation, + extras=extras, + ), + lambda: transition( + reward=reward, + observation=observation, + extras=extras, + ), + ) + + return next_state, timestep + + def _make_observation_and_extras( + self, state: State + ) -> Tuple[State, Observation, Dict]: + """Computes the observation and the environment metrics to include in `timestep.extras`. + Also updates the `action_mask` and `sorted_ems_indexes` in the state. The observation is + obtained by selecting a subset of all EMSs, namely the `obs_num_ems` largest ones. + + Args: + state: a state of the ExtendedBinPack environment. + + Returns: + - State with the updated EMSs and EMS mask. + - Observation with the valid action_mask, item_placed and items_mask. + - extra metrics that can be used to measure the performance of the agent. + + """ + + def flatten_observation(observation: Observation) -> Observation: + """In the case where item rotation is allowed, this function is used to + + Args: + observation: Initial observation with items, items_placed and items_mask arrays of + shape (6, max_nb_items) and an action mask of shape (6, max_nb_ems, max_nb_items) + + Returns: + Observation where the items, items_placed and items_mask array have a shape + (6 * max_nb_items) and the action_mask has a shape of (max_nb_ems, 6 * max_nb_items) + """ + flattened_items_mask = observation.items_mask.flatten() + flattened_items_placed = observation.items_placed.flatten() + flattened_action_mask = observation.action_mask.reshape( + observation.action_mask.shape[1], + -1, + ) + if self.is_value_based: + items = cast( + ValuedItem, + observation.items, + ) + return Observation( + ems=observation.ems, + ems_mask=observation.ems_mask, + items=ValuedItem( + items.x_len.flatten(), + items.y_len.flatten(), + items.z_len.flatten(), + items.value.flatten(), + ), + items_mask=flattened_items_mask, + items_placed=flattened_items_placed, + action_mask=flattened_action_mask, + ) + + else: + return Observation( + ems=observation.ems, + ems_mask=observation.ems_mask, + items=Item( + observation.items.x_len.flatten(), + observation.items.y_len.flatten(), + observation.items.z_len.flatten(), + ), + items_mask=flattened_items_mask, + items_placed=flattened_items_placed, + action_mask=flattened_action_mask, + ) + + state, observation, extra = super()._make_observation_and_extras(state) + flat_obs = observation + if self.is_rotation_allowed: + flat_obs = flatten_observation(flat_obs) + return state, flat_obs, extra + + def _normalize_ems_and_items( + self, state: State, obs_ems: EMS, items: ItemType # type: ignore + ) -> Tuple[EMS, Item]: + """Normalize the EMSs and items in the observation. Each dimension is divided by the + container length so that they are all between 0.0 and 1.0. Hence, the ratio is not kept. + """ + # If items have the extra feature: value (for cases where we want to maximize the value + # packed into a container instead of the volume) we normalise by the largest valued item + # (observed better performances than normalising with respect to the total value of all + # items). + container_item: ItemType + if isinstance(items, ValuedItem): + items = cast(ValuedItem, items) + state.items = cast(ValuedItem, state.items) + ( + x_len, + y_len, + z_len, + _, + ) = container_item = valued_item_from_space_and_max_value( + state.container, state.instance_max_item_value_magnitude + ) + else: + items = cast(Item, items) + x_len, y_len, z_len = container_item = item_from_space(state.container) + + norm_space = Space(x1=x_len, x2=x_len, y1=y_len, y2=y_len, z1=z_len, z2=z_len) + obs_ems = jax.tree_util.tree_map( + lambda ems, normalization_ems: ems / normalization_ems, + obs_ems, + norm_space, + ) + items = jax.tree_util.tree_map( + lambda item, normalization_items: item / normalization_items, + items, + container_item, + ) + return obs_ems, items + + def _pack_item( # type: ignore + self, + state: State, + ems_id: int, + item_id: chex.Numeric, + item_orientation: Optional[int] = None, + ) -> State: + """This method assumes that the item can be placed correctly, i.e. the action is valid.""" + # Place the item in the bottom left corner of the EMS. + ems = tree_slice(state.ems, ems_id) + state.items_location = tree_add_element( + state.items_location, item_id, Location(ems.x1, ems.y1, ems.z1) + ) + if item_orientation is not None: + state.items_mask = state.items_mask.at[:, item_id].set(False) + state.items_placed = state.items_placed.at[item_orientation, item_id].set( + True + ) + else: + state.items_mask = state.items_mask.at[item_id].set(False) + state.items_placed = state.items_placed.at[item_id].set(True) + + state = self._update_ems(state, item_id, item_orientation) + return state + + def _get_action_mask( + self, + obs_ems: EMS, + obs_ems_mask: chex.Array, + items: ItemType, + items_mask: chex.Array, + items_placed: chex.Array, + ) -> chex.Array: + """Compute the mask of valid actions. + + Args: + obs_ems: tree of EMSs from the observation. + obs_ems_mask: mask of EMSs. + items: all items. + items_mask: mask of items. + items_placed: placing mask of items. + + Returns: + action_mask: jax array (bool) of shape (obs_num_ems, max_num_items,). + """ + + def is_action_allowed( + ems: EMS, + ems_mask: chex.Array, + item: Item, + item_mask: chex.Array, + item_placed: chex.Array, + ) -> chex.Array: + item_fits_in_ems = item_fits_in_item(item, item_from_space(ems)) + return ~item_placed & item_mask & ems_mask & item_fits_in_ems + + if self.is_rotation_allowed: + expanded_obs_state = jax.tree_util.tree_map( + functools.partial(jnp.expand_dims, axis=0), obs_ems + ) + expanded_obs_ems_mask = jax.tree_util.tree_map( + functools.partial(jnp.expand_dims, axis=0), obs_ems_mask + ) + action_mask = jax.vmap( + jax.vmap(is_action_allowed, in_axes=(None, None, 1, 1, 1)), + in_axes=(1, 1, None, None, None), + )( + expanded_obs_state, + expanded_obs_ems_mask, + items, + items_mask, + items_placed, + ) + action_mask = jnp.moveaxis(action_mask, -1, 0) + return jnp.asarray(action_mask, bool) + else: + return super()._get_action_mask( + obs_ems, obs_ems_mask, items, items_mask, items_placed + ) + + def _ems_are_all_valid(self, state: State) -> chex.Array: + """Checks if all EMSs are valid, i.e. they don't intersect items and do not stick out of the + container. This check is only done in debug mode. + """ + ems_intersection_with_items = jnp.zeros((state.ems_mask.shape), bool) + if self.is_rotation_allowed: + for o in range(6): + tmp_items = Item( + state.items[:][0][o], state.items[:][1][o], state.items[:][2][o] + ) + item_spaces = space_from_item_and_location( + tmp_items, state.items_location + ) + ems_intersect_items = jax.vmap(Space.intersect, in_axes=(0, None))( + state.ems, item_spaces + ) + ems_intersect_items &= jnp.outer(state.ems_mask, state.items_placed[o]) + ems_intersection_with_items |= jnp.any(ems_intersect_items) + ems_outside_container = jnp.any( + state.ems_mask & ~state.ems.is_included(state.container) + ) + return ~ems_intersection_with_items & ~ems_outside_container + else: + return super()._ems_are_all_valid(state) + + def _update_ems( # type: ignore + self, state: State, item_id: chex.Numeric, item_orientation + ) -> State: + """Update the EMSs after packing the item.""" + if item_orientation is not None: + item_space = space_from_item_and_location( + tree_slice(tree_slice(state.items, item_orientation), item_id), + tree_slice(state.items_location, item_id), + ) + else: + item_space = space_from_item_and_location( + tree_slice(state.items, item_id), + tree_slice(state.items_location, item_id), + ) + # Delete EMSs that intersect the new item. + ems_mask_after_intersect = ~item_space.intersect(state.ems) & state.ems_mask + + # Get the EMSs created by splitting the intersected EMSs. + intersections_ems_dict, intersections_mask_dict = self._get_intersections_dict( + state, item_space, ems_mask_after_intersect + ) + + # Loop over intersection EMSs from all directions to add them to the current set of EMSs. + new_ems = state.ems + new_ems_mask = ems_mask_after_intersect + for intersection_ems, intersection_mask in zip( + intersections_ems_dict.values(), intersections_mask_dict.values() + ): + new_ems, new_ems_mask = self._add_ems( + intersection_ems, intersection_mask, new_ems, new_ems_mask + ) + + state.ems = new_ems + state.ems_mask = new_ems_mask + if self.full_support: + self.merge_same_height_ems(state) + + return state diff --git a/jumanji/environments/packing/bin_pack/env_test.py b/jumanji/environments/packing/bin_pack/env_test.py index 967a56538..217339f87 100644 --- a/jumanji/environments/packing/bin_pack/env_test.py +++ b/jumanji/environments/packing/bin_pack/env_test.py @@ -12,17 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable +from typing import Callable, cast import chex import jax import jax.numpy as jnp import numpy as np import pytest +from pytest import FixtureRequest from jumanji import tree_utils -from jumanji.environments.packing.bin_pack.env import BinPack +from jumanji.environments.packing.bin_pack.env import BinPack, ExtendedBinPack from jumanji.environments.packing.bin_pack.generator import ( + ExtendedRandomGenerator, + ExtendedToyGenerator, RandomGenerator, ToyGenerator, ) @@ -30,6 +33,7 @@ from jumanji.environments.packing.bin_pack.types import ( Observation, State, + ValuedItem, item_from_space, location_from_space, ) @@ -42,6 +46,50 @@ from jumanji.types import TimeStep +def assert_type_bin_pack_state(state: State) -> None: + """Assert that all spaces or items are integers while all masks are boolean in the state.""" + jax.tree_util.tree_map( + lambda leaf: chex.assert_type(leaf, jnp.int32), + ( + state.container, + state.ems, + state.items, + state.items_location, + state.sorted_ems_indexes, + ), + ) + jax.tree_util.tree_map( + lambda leaf: chex.assert_type(leaf, bool), + (state.ems_mask, state.items_mask, state.items_placed, state.action_mask), + ) + + +def assert_type_extended_bin_pack_state(state: State) -> None: + """Assert that all spaces or items are integers while all masks are boolean in the state.""" + jax.tree_util.tree_map( + lambda leaf: chex.assert_type(leaf, jnp.int32), + ( + state.container, + state.ems, + state.items.x_len, + state.items.y_len, + state.items.z_len, + state.items_location, + state.sorted_ems_indexes, + ), + ) + if len(state.items) == 4: + state.items = cast(ValuedItem, state.items) + jax.tree_util.tree_map( + lambda leaf: chex.assert_type(leaf, float), + (state.items.value,), + ) + jax.tree_util.tree_map( + lambda leaf: chex.assert_type(leaf, bool), + (state.ems_mask, state.items_mask, state.items_placed, state.action_mask), + ) + + @pytest.fixture def bin_pack_random_select_action(bin_pack: BinPack) -> SelectActionFn: num_ems, num_items = np.asarray(bin_pack.action_spec.num_values) @@ -106,24 +154,6 @@ def select_action( # noqa: CCR001 return select_action -def assert_type_bin_pack_state(state: State) -> None: - """Assert that all spaces or items are integers while all masks are boolean in the state.""" - jax.tree_util.tree_map( - lambda leaf: chex.assert_type(leaf, jnp.int32), - ( - state.container, - state.ems, - state.items, - state.items_location, - state.sorted_ems_indexes, - ), - ) - jax.tree_util.tree_map( - lambda leaf: chex.assert_type(leaf, bool), - (state.ems_mask, state.items_mask, state.items_placed, state.action_mask), - ) - - def test_bin_pack__reset(bin_pack: BinPack) -> None: """Validates the jitted reset of the environment.""" chex.clear_trace_counter() @@ -265,3 +295,590 @@ def test_bin_pack__optimal_policy_random_instance( assert not timestep.extras["invalid_action"] assert not timestep.extras["invalid_ems_from_env"] assert jnp.array_equal(state.items_placed, solution.items_placed) + + +def test_full_support_bin_pack(full_support_bin_pack: BinPack) -> None: + """ + This test checks that no unsupported item can be placed by the agent + and that the merging of EMS with the same height works correctly. + """ + step_fn = jax.jit(full_support_bin_pack.step) + state, timestep = jax.jit(full_support_bin_pack.reset)(0) + # Start by forcing the agent to place one of the small items. + state, timestep = step_fn(state, jnp.array([0, 1])) + nb_remaning_items = full_support_bin_pack.generator.max_num_items - 1 + while not timestep.last() and nb_remaning_items > 1: + # Trick to select the EMS whose bottom is the floor of the container (when the number of + # items is less than 6 the biggest ems becomes the EMS whose bottom is the top face of + # the set of placed items). + # At first int(nb_remaning_items < 6) = 0 since the number of reamining items is bigger than + # 6 and so the EMS selected would be the biggest EMS which is the one whose bottom side is + # the floor. When the number of remaining items is less than 6 int(nb_remaning_items < 6) + # would be equal to 1 and so the selected EMS would be the second largest EMS which is + # the one with the bottom side on the floor. + action = jnp.array([int(nb_remaning_items < 6), nb_remaning_items]) + assert timestep.observation.action_mask[tuple(action)] + # Make sure that the big item can't be placed because it won't be supported + assert jnp.all(~timestep.observation.action_mask[:, 0]) + state, timestep = step_fn(state, action) + # Make sure that big piece isn't placeable because it can't be fully supported. + assert not timestep.extras["invalid_action"] + assert not timestep.extras["invalid_ems_from_env"] + nb_remaning_items -= 1 + action = jnp.array([0, 0]) + assert timestep.observation.action_mask[tuple(action)] + state, timestep = step_fn(state, action) + + # Make sure that all the items were placed and that the container was filled to 100% + assert jnp.array_equal(state.items_placed, jnp.array(11 * [True])) + assert jnp.isclose(timestep.extras["volume_utilization"], 1) + + +@pytest.mark.parametrize( + "normalize_dimensions, max_num_items, max_num_ems, obs_num_ems", + [ + (False, 5, 20, 10), + (True, 5, 20, 10), + (False, 20, 80, 50), + (True, 20, 80, 50), + ], +) +def test_full_support_bin_pack__optimal_policy_random_instance( + normalize_dimensions: bool, + bin_pack_optimal_policy_select_action: Callable[[Observation, State], chex.Array], + max_num_items: int, + max_num_ems: int, + obs_num_ems: int, +) -> None: + """Functional test to check that random instances can be optimally packed with an optimal + policy. Checks for both options: normalizing dimensions and not normalizing, and checks for + two different sizes: 5 items and 20 items, with respectively 20 and 80 max number of EMSs. + """ + num_trial_episodes = 3 + random_bin_pack = BinPack( + generator=RandomGenerator(max_num_items, max_num_ems), + obs_num_ems=obs_num_ems, + normalize_dimensions=normalize_dimensions, + debug=True, + full_support=True, + ) + reset_fn = jax.jit(random_bin_pack.reset) + generate_solution_fn = jax.jit(random_bin_pack.generator.generate_solution) + step_fn = jax.jit(random_bin_pack.step) + for key in jax.random.split(jax.random.PRNGKey(0), num_trial_episodes): + state, timestep = reset_fn(key) + solution = generate_solution_fn(key) + + while not timestep.last(): + action = bin_pack_optimal_policy_select_action( + timestep.observation, solution + ) + assert timestep.observation.action_mask[tuple(action)] + state, timestep = step_fn(state, action) + assert not timestep.extras["invalid_action"] + assert not timestep.extras["invalid_ems_from_env"] + assert jnp.array_equal(state.items_placed, solution.items_placed) + + +class TestExtendedBinPackRotationNoValue: + + """ + Class Used to test the Extended Bin pack environment when Items are allowed to take all + possible orientations but have no value associated to them. + """ + + @pytest.fixture + def rotation_bin_pack_random_select_action( + self, rotation_bin_pack: ExtendedBinPack + ) -> SelectActionFn: + num_orientations, num_ems, num_items = np.asarray( + rotation_bin_pack.action_spec().num_values + ) + + def select_action(key: chex.PRNGKey, observation: Observation) -> chex.Array: + """Randomly sample valid actions, as determined by `observation.action_mask`.""" + orientation_ems_item_id = jax.random.choice( + key=key, + a=num_orientations * num_ems * num_items, + p=observation.action_mask.flatten(), + ) + orientation_ems_id, item_id = jnp.divmod(orientation_ems_item_id, num_items) + orientation, ems_id = jnp.divmod(orientation_ems_id, num_ems) + action = jnp.array([orientation, ems_id, item_id], jnp.int32) + return action + + return jax.jit(select_action) # type: ignore + + @pytest.fixture # noqa: CCR001 + def rotation_bin_pack_optimal_policy_select_action( # noqa: CCR001 + self, + request: FixtureRequest, + ) -> Callable[[Observation, State], chex.Array]: + """Optimal policy for the BinPack environment. + WARNING: Requires `normalize_dimensions` from the BinPack environment. + """ + normalize_dimensions = request.param + + def unnormalize_obs_ems(obs_ems: Space, solution: State) -> Space: + x_len, y_len, z_len = item_from_space(solution.container) + norm_space = Space( + x1=x_len, x2=x_len, y1=y_len, y2=y_len, z1=z_len, z2=z_len + ) + obs_ems: Space = jax.tree_util.tree_map( + lambda x, c: jnp.round(x * c).astype(jnp.int32), + obs_ems, + norm_space, + ) + return obs_ems + + def select_action( # noqa: CCR001 + observation: Observation, solution: State + ) -> chex.Array: + """Outputs the best action to fully pack the container.""" + reshaped_action_mask = observation.action_mask.reshape( + 6, observation.action_mask.shape[0], -1 + ) + for obs_ems_id, obs_ems_action_mask in enumerate(reshaped_action_mask[0]): + if not obs_ems_action_mask.any(): + continue + obs_ems = tree_utils.tree_slice(observation.ems, obs_ems_id) + if normalize_dimensions: + obs_ems = unnormalize_obs_ems(obs_ems, solution) + obs_ems_location = location_from_space(obs_ems) + for item_id, action_feasible in enumerate(obs_ems_action_mask): + if not action_feasible: + continue + item_location = tree_utils.tree_slice( + solution.items_location, item_id + ) + if item_location == obs_ems_location: + return jnp.array([0, obs_ems_id, item_id], jnp.int32) + raise LookupError("Could not find the optimal action.") + + return select_action + + def test__rotation_bin_pack__reset( + self, rotation_bin_pack: ExtendedBinPack + ) -> None: + """Validates the jitted reset of the environment.""" + chex.clear_trace_counter() + reset_fn = jax.jit(chex.assert_max_traces(rotation_bin_pack.reset, n=1)) + + key = jax.random.PRNGKey(0) + _ = reset_fn(key) + # Call again to check it does not compile twice. + state, timestep = reset_fn(key) + assert isinstance(timestep, TimeStep) + assert isinstance(state, State) + # Check that the state is made of DeviceArrays, this is false for the non-jitted + # reset function since unpacking random.split returns numpy arrays and not device arrays. + assert_is_jax_array_tree(state) + assert_type_bin_pack_state(state) + assert state.ems_mask.any() + assert jnp.any(state.action_mask) + assert state.items_mask.any() + + def test_rotation_bin_pack_step__jit(self, rotation_bin_pack: BinPack) -> None: + """Validates jitting the environment step function.""" + chex.clear_trace_counter() + step_fn = jax.jit(chex.assert_max_traces(rotation_bin_pack.step, n=1)) + + key = jax.random.PRNGKey(0) + state, timestep = rotation_bin_pack.reset(key) + + action = rotation_bin_pack.action_spec().generate_value() + _ = step_fn(state, action) + # Call again to check it does not compile twice. + state, timestep = step_fn(state, action) + assert_type_bin_pack_state(state) + + def test_bin_pack__render_does_not_smoke( + self, rotation_bin_pack: ExtendedBinPack, dummy_rotation_state: State + ) -> None: + rotation_bin_pack.render(dummy_rotation_state) + rotation_bin_pack.close() + + def test_bin_pack__does_not_smoke( + self, + rotation_bin_pack: ExtendedBinPack, + rotation_bin_pack_random_select_action: SelectActionFn, + ) -> None: + """Test that we can run an episode without any errors.""" + check_env_does_not_smoke( + rotation_bin_pack, rotation_bin_pack_random_select_action + ) + + def test_bin_pack__pack_all_items_dummy_instance( + self, + rotation_bin_pack: ExtendedBinPack, + rotation_bin_pack_random_select_action: SelectActionFn, + ) -> None: + """Functional test to check that the dummy instance can be completed with a random agent.""" + step_fn = jax.jit(rotation_bin_pack.step) + key = jax.random.PRNGKey(0) + state, timestep = rotation_bin_pack.reset(key) + + while not timestep.last(): + action_key, key = jax.random.split(key) + action = rotation_bin_pack_random_select_action( + action_key, timestep.observation + ) + state, timestep = step_fn(state, action) + + assert jnp.array_equal(jnp.sum(state.items_placed), state.nb_items) + + @pytest.mark.parametrize( + "rotation_bin_pack_optimal_policy_select_action, normalize_dimensions", + [(False, False), (True, True)], + indirect=["rotation_bin_pack_optimal_policy_select_action"], + ) + def test_bin_pack__optimal_policy_toy_instance( + self, + rotation_bin_pack_optimal_policy_select_action: Callable[ + [Observation, State], chex.Array + ], + rotation_toy_generator: ExtendedToyGenerator, + normalize_dimensions: bool, + ) -> None: + """Functional test to check that the toy instance can be optimally packed with an optimal + policy. Checks for both options: normalizing dimensions and not normalizing. + """ + toy_bin_pack = ExtendedBinPack( + generator=rotation_toy_generator, + obs_num_ems=40, + normalize_dimensions=normalize_dimensions, + debug=True, + is_rotation_allowed=True, + is_value_based=False, + ) + key = jax.random.PRNGKey(0) + step_fn = jax.jit(toy_bin_pack.step) + state, timestep = toy_bin_pack.reset(key) + solution = toy_bin_pack.generator.generate_solution(key) + + while not timestep.last(): + action = rotation_bin_pack_optimal_policy_select_action( + timestep.observation, solution + ) + state, timestep = step_fn(state, action) + assert isinstance(timestep.extras, dict) + # This is not true anymore since there are items that can't + # fit in all the possible orientations + # assert not timestep.extras["invalid_action"] + assert not jnp.any(timestep.extras["invalid_ems_from_env"]) + if timestep.extras is not None: + assert timestep.extras["volume_utilization"] == 1 + assert timestep.extras["ratio_packed_items"] == 1 + + @pytest.mark.parametrize( + "rotation_bin_pack_optimal_policy_select_action, \ + normalize_dimensions, max_num_items, max_num_ems, obs_num_ems", + [ + (False, False, 5, 20, 10), + (True, True, 5, 20, 10), + (False, False, 20, 80, 50), + (True, True, 20, 80, 50), + ], + indirect=["rotation_bin_pack_optimal_policy_select_action"], + ) + def test_bin_pack__optimal_policy_random_instance( + self, + rotation_bin_pack_optimal_policy_select_action: Callable[ + [Observation, State], chex.Array + ], + normalize_dimensions: bool, + max_num_items: int, + max_num_ems: int, + obs_num_ems: int, + ) -> None: + """Functional test to check that random instances can be optimally packed with an optimal + policy. Checks for both options: normalizing dimensions and not normalizing, and checks for + two different sizes: 5 items and 20 items, with respectively 20 and 80 max number of EMSs. + """ + num_trial_episodes = 3 + random_bin_pack = ExtendedBinPack( + generator=ExtendedRandomGenerator( + max_num_items, + max_num_ems, + is_rotation_allowed=True, + is_value_based=False, + ), + obs_num_ems=obs_num_ems, + normalize_dimensions=normalize_dimensions, + debug=True, + is_rotation_allowed=True, + is_value_based=False, + ) + reset_fn = jax.jit(random_bin_pack.reset) + generate_solution_fn = jax.jit(random_bin_pack.generator.generate_solution) + step_fn = jax.jit(random_bin_pack.step) + for key in jax.random.split(jax.random.PRNGKey(0), num_trial_episodes): + state, timestep = reset_fn(key) + solution = generate_solution_fn(key) + + while not timestep.last(): + action = rotation_bin_pack_optimal_policy_select_action( + timestep.observation, solution + ) + reshaped_action_mask = timestep.observation.action_mask.reshape( + 6, timestep.observation.action_mask.shape[0], -1 + ) + assert reshaped_action_mask[tuple(action)] + state, timestep = step_fn(state, action) + # assert not timestep.extras["invalid_action"] + assert not jnp.any(timestep.extras["invalid_ems_from_env"]) + assert jnp.array_equal(state.items_placed, solution.items_placed) + assert round(timestep.extras["volume_utilization"]) == 1 + assert timestep.extras["ratio_packed_items"] == 1 + + +class TestExtendedBinPackRotationValue: + """ + Class Used to test the Extended Bin pack environment when Items are allowed to take all + possible orientations and have a value associated with them. + """ + + @pytest.fixture + def extended_bin_pack_random_select_action( + self, extended_bin_pack: ExtendedBinPack + ) -> SelectActionFn: + num_orientations, num_ems, num_items = np.asarray( + extended_bin_pack.action_spec().num_values + ) + + def select_action(key: chex.PRNGKey, observation: Observation) -> chex.Array: + """Randomly sample valid actions, as determined by `observation.action_mask`.""" + orientation_ems_item_id = jax.random.choice( + key=key, + a=num_orientations * num_ems * num_items, + p=observation.action_mask.flatten(), + ) + orientation_ems_id, item_id = jnp.divmod(orientation_ems_item_id, num_items) + orientation, ems_id = jnp.divmod(orientation_ems_id, num_ems) + action = jnp.array([orientation, ems_id, item_id], jnp.int32) + return action + + return jax.jit(select_action) # type: ignore + + @pytest.fixture # noqa: CCR001 + def extended_bin_pack_optimal_policy_select_action( # noqa: CCR001 + self, + request: FixtureRequest, + ) -> Callable[[Observation, State], chex.Array]: + """Optimal policy for the BinPack environment. + WARNING: Requires `normalize_dimensions` from the BinPack environment. + """ + normalize_dimensions = request.param + + def unnormalize_obs_ems(obs_ems: Space, solution: State) -> Space: + x_len, y_len, z_len = item_from_space(solution.container) + norm_space = Space( + x1=x_len, x2=x_len, y1=y_len, y2=y_len, z1=z_len, z2=z_len + ) + obs_ems: Space = jax.tree_util.tree_map( + lambda x, c: jnp.round(x * c).astype(jnp.int32), + obs_ems, + norm_space, + ) + return obs_ems + + def select_action( # noqa: CCR001 + observation: Observation, solution: State + ) -> chex.Array: + """Outputs the best action to fully pack the container.""" + reshaped_action_mask = observation.action_mask.reshape( + 6, observation.action_mask.shape[0], -1 + ) + for obs_ems_id, obs_ems_action_mask in enumerate(reshaped_action_mask[0]): + if not obs_ems_action_mask.any(): + continue + obs_ems = tree_utils.tree_slice(observation.ems, obs_ems_id) + if normalize_dimensions: + obs_ems = unnormalize_obs_ems(obs_ems, solution) + obs_ems_location = location_from_space(obs_ems) + for item_id, action_feasible in enumerate(obs_ems_action_mask): + if not action_feasible: + continue + item_location = tree_utils.tree_slice( + solution.items_location, item_id + ) + if item_location == obs_ems_location: + return jnp.array([0, obs_ems_id, item_id], jnp.int32) + raise LookupError("Could not find the optimal action.") + + return select_action + + def test__extended_bin_pack__reset( + self, extended_bin_pack: ExtendedBinPack + ) -> None: + """Validates the jitted reset of the environment.""" + chex.clear_trace_counter() + reset_fn = jax.jit(chex.assert_max_traces(extended_bin_pack.reset, n=1)) + + key = jax.random.PRNGKey(0) + _ = reset_fn(key) + # Call again to check it does not compile twice. + state, timestep = reset_fn(key) + assert isinstance(timestep, TimeStep) + assert isinstance(state, State) + # Check that the state is made of DeviceArrays, this is false for the non-jitted + # reset function since unpacking random.split returns numpy arrays and not device arrays. + assert_is_jax_array_tree(state) + assert_type_extended_bin_pack_state(state) + assert state.ems_mask.any() + assert jnp.any(state.action_mask) + assert state.items_mask.any() + + def test_extended_bin_pack_step__jit(self, extended_bin_pack: BinPack) -> None: + """Validates jitting the environment step function.""" + chex.clear_trace_counter() + step_fn = jax.jit(chex.assert_max_traces(extended_bin_pack.step, n=1)) + + key = jax.random.PRNGKey(0) + state, timestep = extended_bin_pack.reset(key) + + action = extended_bin_pack.action_spec().generate_value() + _ = step_fn(state, action) + # Call again to check it does not compile twice. + state, timestep = step_fn(state, action) + assert_type_extended_bin_pack_state(state) + + def test_bin_pack__render_does_not_smoke( + self, extended_bin_pack: ExtendedBinPack, dummy_rotation_state: State + ) -> None: + extended_bin_pack.render(dummy_rotation_state) + extended_bin_pack.close() + + def test_bin_pack__does_not_smoke( + self, + extended_bin_pack: ExtendedBinPack, + extended_bin_pack_random_select_action: SelectActionFn, + ) -> None: + """Test that we can run an episode without any errors.""" + check_env_does_not_smoke( + extended_bin_pack, extended_bin_pack_random_select_action + ) + + def test_bin_pack__pack_all_items_dummy_instance( + self, + extended_bin_pack: ExtendedBinPack, + extended_bin_pack_random_select_action: SelectActionFn, + ) -> None: + """Functional test to check that the dummy instance can be completed with a random agent.""" + step_fn = jax.jit(extended_bin_pack.step) + key = jax.random.PRNGKey(0) + state, timestep = extended_bin_pack.reset(key) + + while not timestep.last(): + action_key, key = jax.random.split(key) + action = extended_bin_pack_random_select_action( + action_key, timestep.observation + ) + state, timestep = step_fn(state, action) + + assert jnp.array_equal(jnp.sum(state.items_placed), state.nb_items) + + @pytest.mark.parametrize( + "extended_bin_pack_optimal_policy_select_action, normalize_dimensions", + [(False, False), (True, True)], + indirect=["extended_bin_pack_optimal_policy_select_action"], + ) + def test_bin_pack__optimal_policy_toy_instance( + self, + extended_bin_pack_optimal_policy_select_action: Callable[ + [Observation, State], chex.Array + ], + rotation_toy_generator: ExtendedToyGenerator, + normalize_dimensions: bool, + ) -> None: + """Functional test to check that the toy instance can be optimally packed with an optimal + policy. Checks for both options: normalizing dimensions and not normalizing. + """ + toy_bin_pack = ExtendedBinPack( + generator=rotation_toy_generator, + obs_num_ems=40, + normalize_dimensions=normalize_dimensions, + debug=True, + is_rotation_allowed=True, + is_value_based=False, + ) + key = jax.random.PRNGKey(0) + step_fn = jax.jit(toy_bin_pack.step) + state, timestep = toy_bin_pack.reset(key) + solution = toy_bin_pack.generator.generate_solution(key) + + while not timestep.last(): + action = extended_bin_pack_optimal_policy_select_action( + timestep.observation, solution + ) + state, timestep = step_fn(state, action) + assert isinstance(timestep.extras, dict) + # This is not true anymore since there are items that can't + # fit in all the possible orientations + # assert not timestep.extras["invalid_action"] + assert not jnp.any(timestep.extras["invalid_ems_from_env"]) + if timestep.extras is not None: + assert timestep.extras["volume_utilization"] == 1 + assert timestep.extras["ratio_packed_items"] == 1 + + @pytest.mark.parametrize( + "extended_bin_pack_optimal_policy_select_action, \ + normalize_dimensions, max_num_items, max_num_ems, obs_num_ems", + [ + (False, False, 5, 20, 10), + (True, True, 5, 20, 10), + (False, False, 20, 80, 50), + (True, True, 20, 80, 50), + ], + indirect=["extended_bin_pack_optimal_policy_select_action"], + ) + def test_bin_pack__optimal_policy_random_instance( + self, + extended_bin_pack_optimal_policy_select_action: Callable[ + [Observation, State], chex.Array + ], + normalize_dimensions: bool, + max_num_items: int, + max_num_ems: int, + obs_num_ems: int, + ) -> None: + """Functional test to check that random instances can be optimally packed with an optimal + policy. Checks for both options: normalizing dimensions and not normalizing, and checks for + two different sizes: 5 items and 20 items, with respectively 20 and 80 max number of EMSs. + """ + num_trial_episodes = 3 + random_bin_pack = ExtendedBinPack( + generator=ExtendedRandomGenerator( + max_num_items, + max_num_ems, + is_rotation_allowed=True, + is_value_based=True, + mean_item_value=1, + std_item_value=0.5, + ), + obs_num_ems=obs_num_ems, + normalize_dimensions=normalize_dimensions, + debug=True, + is_rotation_allowed=True, + is_value_based=False, + ) + reset_fn = jax.jit(random_bin_pack.reset) + generate_solution_fn = jax.jit(random_bin_pack.generator.generate_solution) + step_fn = jax.jit(random_bin_pack.step) + for key in jax.random.split(jax.random.PRNGKey(0), num_trial_episodes): + state, timestep = reset_fn(key) + solution = generate_solution_fn(key) + + while not timestep.last(): + action = extended_bin_pack_optimal_policy_select_action( + timestep.observation, solution + ) + reshaped_action_mask = timestep.observation.action_mask.reshape( + 6, timestep.observation.action_mask.shape[0], -1 + ) + assert reshaped_action_mask[tuple(action)] + state, timestep = step_fn(state, action) + # assert not timestep.extras["invalid_action"] + assert not jnp.any(timestep.extras["invalid_ems_from_env"]) + assert jnp.array_equal(state.items_placed, solution.items_placed) + assert round(timestep.extras["volume_utilization"]) == 1 diff --git a/jumanji/environments/packing/bin_pack/generator.py b/jumanji/environments/packing/bin_pack/generator.py index de32e635d..5f84484d4 100644 --- a/jumanji/environments/packing/bin_pack/generator.py +++ b/jumanji/environments/packing/bin_pack/generator.py @@ -16,8 +16,10 @@ import collections import csv import functools +import math import operator -from typing import List, Tuple +from random import randint +from typing import List, Optional, Tuple, cast import chex import jax @@ -29,9 +31,14 @@ Item, Location, State, + ValuedItem, empty_ems, item_from_space, + item_volume, location_from_space, + rotated_items_from_space, + space_from_item_and_location, + valued_item_from_space_and_max_value, ) from jumanji.tree_utils import tree_slice, tree_transpose @@ -41,6 +48,14 @@ TWENTY_FOOT_DIMS = (5870, 2330, 2200) CSV_COLUMNS = ["Item_Name", "Length", "Width", "Height", "Quantity"] +CSV_VALUE_PROBLEM_COLUMNS = [ + "Item_Name", + "Length", + "Width", + "Height", + "Quantity", + "Value", +] def make_container(container_dims: Tuple[int, int, int]) -> Container: @@ -364,7 +379,11 @@ def _generate_solved_instance(self, key: chex.PRNGKey) -> State: items_location=items_location, action_mask=None, sorted_ems_indexes=sorted_ems_indexes, + # For non value based optimisation set these to dummy values by default + instance_max_item_value_magnitude=0.0, + instance_total_value=0.0, key=jax.random.PRNGKey(0), + nb_items=20, ) return solution @@ -467,7 +486,11 @@ def _parse_csv_file( items_location=items_location, action_mask=None, sorted_ems_indexes=sorted_ems_indexes, + # For non value based optimisation set these to dummy values by default + instance_max_item_value_magnitude=0.0, + instance_total_value=0.0, key=jax.random.PRNGKey(0), + nb_items=num_items, ) return reset_state @@ -662,6 +685,7 @@ def _generate_solved_instance(self, key: chex.PRNGKey) -> State: items_spaces, items_mask = self._split_container_into_items_spaces( container, split_key ) + nb_items = jnp.sum(items_mask) items = item_from_space(items_spaces) sorted_ems_indexes = jnp.arange(0, self.max_num_ems, dtype=jnp.int32) @@ -676,23 +700,36 @@ def _generate_solved_instance(self, key: chex.PRNGKey) -> State: items_location=all_item_locations, action_mask=None, sorted_ems_indexes=sorted_ems_indexes, + # For non value based optimisation set these to dummy values by default + instance_max_item_value_magnitude=0.0, + instance_total_value=0.0, key=key, + nb_items=nb_items, ) return solution def _split_container_into_items_spaces( - self, container: Container, key: chex.PRNGKey + self, + container: Container, + key: chex.PRNGKey, + input_max_items_generated: Optional[int] = None, ) -> Tuple[Space, chex.Array]: """Split one space (the container) into several sub-spaces that will be identified as items. + + The output items_spaces and items_mask array will be self.max_num_items by default but can + be set to a custom value that is different from this (useful for + RandomValueProblemGenerator). """ + max_items_generated = cast(int, input_max_items_generated) or self.max_num_items + chex.assert_rank(list(container.__dict__.values()), 0) def cond_fun(val: Tuple[Space, chex.Array, chex.PRNGKey]) -> jnp.bool_: _, items_mask, _ = val num_placed_items = jnp.sum(items_mask) return ( - num_placed_items < self.max_num_items - self._split_num_same_items + 1 + num_placed_items < max_items_generated - self._split_num_same_items + 1 ) def body_fun( @@ -708,10 +745,10 @@ def body_fun( items_spaces = Space( **jax.tree_util.tree_map( - lambda x: x * jnp.ones(self.max_num_items, jnp.int32), container + lambda x: x * jnp.ones(max_items_generated, jnp.int32), container ).__dict__ ) - items_mask = jnp.zeros(self.max_num_items, bool).at[0].set(True) + items_mask = jnp.zeros(max_items_generated, bool).at[0].set(True) init_val = (items_spaces, items_mask, key) (items_spaces, items_mask, _) = jax.lax.while_loop(cond_fun, body_fun, init_val) @@ -875,3 +912,807 @@ def body_fn( ) return items_spaces, items_mask + + +class ValueProblemCSVGenerator(CSVGenerator): + """`Generator` that parses a CSV file to do active search on a single instance. It + always resets to the same instance defined by the CSV file. The generator can handle any + container dimensions but assumes a 20-ft container by default. + + The CSV file is expected to have the following columns: + - Item_Name + - Length + - Width + - Height + - Quantity + - Value + + Example with value: + Item_Name,Length,Width,Height,Quantity,Value + shape_1,1080,760,300,5,4.5 + shape_2,1100,430,250,3,3.4 + """ + + def __init__( + self, + csv_path: str, + max_num_ems: int, + container_dims: Tuple[int, int, int] = TWENTY_FOOT_DIMS, + ): + """Instantiate a `CSVGenerator` that generates the same instance (active search) + defined by a CSV file. + + Args: + csv_path: path to the CSV file defining the instance to reset to. + max_num_ems: maximum number of ems the environment will handle. This defines the shape + of the EMS buffer that is kept in the environment state. The good number heavily + depends on the number of items (given by the CSV file). + container_dims: (length, width, height) tuple of integers corresponding to the + dimensions of the container in millimeters. By default, assume a 20-ft container. + """ + super().__init__(csv_path, max_num_ems, container_dims) + + def _parse_csv_file( + self, csv_path: str, max_num_ems: int, container_dims: Tuple[int, int, int] + ) -> State: + """Create an instance by parsing a CSV file. + + Args: + csv_path: path to the CSV file to parse that defines the instance to reset to. + max_num_ems: maximum number of ems the environment will handle. This defines the shape + of the EMS buffer that is kept in the environment state. + container_dims: (length, width, height) tuple of integers corresponding to the + dimensions of the container in millimeters. + + Returns: + `BinPack` state that contains the instance defined in the CSV file. + """ + container = make_container(container_dims) + + # Initialize the EMSs + list_of_ems = [container] + (max_num_ems - 1) * [empty_ems()] + ems = tree_transpose(list_of_ems) + ems_mask = jnp.zeros(max_num_ems, bool).at[0].set(True) + + # Parse the CSV file to generate the items + rows = self._read_valued_csv(csv_path) + list_of_items = self._generate_list_of_valued_items(rows) + items = tree_transpose(list_of_items) + + # Initialize items mask and location + num_items = len(list_of_items) + items_mask = jnp.ones(num_items, bool) + items_placed = jnp.zeros(num_items, bool) + items_location = Location(*tuple(jnp.zeros((3, num_items), jnp.int32))) + + sorted_ems_indexes = jnp.arange(0, max_num_ems, dtype=jnp.int32) + + instance_total_value = jnp.sum(items.value * items_mask) + instance_max_item_value_magnitude = jnp.max(abs(items.value * items_mask)) + + reset_state = State( + container=container, + ems=ems, + ems_mask=ems_mask, + items=items, + nb_items=len(items.x_len), + items_mask=items_mask, + items_placed=items_placed, + items_location=items_location, + action_mask=None, + sorted_ems_indexes=sorted_ems_indexes, + instance_max_item_value_magnitude=instance_max_item_value_magnitude, + instance_total_value=instance_total_value, + key=jax.random.PRNGKey(0), + ) + + return reset_state + + def _read_valued_csv( + self, csv_path: str + ) -> List[Tuple[str, int, int, int, int, float]]: + rows = [] + with open(csv_path, newline="") as csvfile: + reader = csv.reader(csvfile) + for row_index, row in enumerate(reader): + if row_index == 0: + if len(row) != len(CSV_VALUE_PROBLEM_COLUMNS): + raise ValueError( + "Got wrong number of columns, expected: " + f"{', '.join(CSV_VALUE_PROBLEM_COLUMNS)}" + ) + elif row != CSV_VALUE_PROBLEM_COLUMNS: + raise ValueError("Columns in wrong order") + else: + # Column order: Item_Name, Length, Width, Height, Quantity, Value. + rows.append( + ( + row[0], + int(row[1]), + int(row[2]), + int(row[3]), + int(row[4]), + float(row[5]), + ) + ) + return rows + + def _generate_list_of_valued_items( + self, rows: List[Tuple[str, int, int, int, int, float]] + ) -> List[ValuedItem]: + """Generate the list of items from a Pandas DataFrame. + + Args: + rows: List[tuple] describing the items for the corresponding instance. + + Returns: + List of `ValuedItem` flattened so that identical items (quantity > 1) are copied + according to their quantity. + """ + list_of_items = [] + for (_, x_len, y_len, z_len, quantity, value) in rows: + identical_items = quantity * [ + ValuedItem( + x_len=jnp.array(x_len, jnp.int32), + y_len=jnp.array(y_len, jnp.int32), + z_len=jnp.array(z_len, jnp.int32), + value=jnp.array(value, jnp.float32), + ) + ] + list_of_items.extend(identical_items) + return list_of_items + + +class ExtendedToyGenerator(ToyGenerator): + def __init__(self) -> None: + super().__init__() + self.is_rotation_allowed = True + self.is_value_based = False + + def _generate_solved_instance(self, key: chex.PRNGKey) -> State: + solution = super()._generate_solved_instance(key) + x_len, y_len, z_len = ( + solution.items.x_len, + solution.items.y_len, + solution.items.z_len, + ) + solution.items = Item( + x_len=jnp.array([x_len, x_len, y_len, y_len, z_len, z_len]), + y_len=jnp.array([y_len, z_len, x_len, z_len, y_len, x_len]), + z_len=jnp.array([z_len, y_len, z_len, x_len, x_len, y_len]), + ) + + solution.items_mask = jnp.array( + [ + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + # Since only the items + # 2, 12, 17 and 18 can be placed with their length along any of the container axes, + # we mask these orientations of the other items. + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0], + ] + ) + return solution + + def _unpack_items(self, state: State) -> State: + state = super()._unpack_items(state) + state.items_placed = jnp.zeros((6, self.max_num_items), bool) + return state + + +class ExtendedRandomGenerator(RandomGenerator): + def __init__( + self, + max_num_items: int, + max_num_ems: int, + is_rotation_allowed: bool, + is_value_based: bool, + split_eps: float = 0.3, + prob_split_one_item: float = 0.7, + split_num_same_items: int = 5, + container_dims: Tuple[int, int, int] = TWENTY_FOOT_DIMS, + mean_item_value: Optional[float] = None, + std_item_value: Optional[float] = None, + ): + """ + Args: + is_rotation_allowed: whether the generator has to generate instances where item + rotation is possible. Defaults to False. + is_value_based: whether the generator has to generator has to generate + instance with values. Defaults to False. + mean_item_value: The mean value of the normal distribution from which the item values + will be sampled. + std_item_value: The standard deviation of the normal distribution from which + the item values will be sampled. + Raises: + ValueError: When the generator has to generate value based instances but the mean and + std aren't provided. + """ + super().__init__( + max_num_items, + max_num_ems, + split_eps, + prob_split_one_item, + split_num_same_items, + container_dims, + ) + self.is_rotation_allowed = is_rotation_allowed + self.is_value_based = is_value_based + self.mean_item_value = mean_item_value + self.std_item_value = std_item_value + if self.is_value_based and (mean_item_value is None or std_item_value is None): + std_name = "std_item_value" + raise ValueError( + "Value Based generator not provided with" + + f"{'mean_item_value attribute' if mean_item_value is None else ''}" + + f"{f'{std_name} attribute' if std_item_value is None else ''}" + ) + + def _generate_value_based_solved_instance( # type: ignore + self, key: chex.PRNGKey + ) -> State: + """Generate the random instance with half of all items correctly packed (the higher value + half). The other half of the generated items remain unpacked. + + The first half of items are generated by splitting the container in the same way as the + inherited RandomGenerator class, but with less than half the max_num_items. Values are + randomly generated, and assigned to these items. These items are duplicated and assigned + values that are the value of the original items + the sum of the values of all the items of + the previous half . This means that the optimal solution is to fill one container perfectly + with the higher valued items. This method may lead to fewer than max_num_items being + generated in total, but the items tree will be of size max_num_items and the indexes that + don't correspond to the items of the instance are masked out with the state's items_mask. + """ + key, split_key = jax.random.split(key) + container = make_container(self.container_dims) + + list_of_ems = [container] + (self.max_num_ems - 1) * [empty_ems()] + ems = tree_transpose(list_of_ems) + ems_mask = jnp.zeros(self.max_num_ems, bool) + + # Create less than half of max_num_items item spaces by splitting up a container. This + # will lead to nb_items_in_one_container items that fit perfectly into a single container. + nb_items_in_one_container = math.floor(0.5 * self.max_num_items) + # This will generate items_spaces and item_mask of size nb_items_in_one_container. + items_spaces, items_mask = self._split_container_into_items_spaces( + container, split_key, nb_items_in_one_container + ) + # Randomly generate values that will then be increased by a value of + # total_value_of_generated_values to generate a "perfect instance" with a known optimal + # solution. + key, split_key = jax.random.split(key) + item_values = self.mean_item_value + ( + self.std_item_value + * jax.random.normal(split_key, (len(items_mask),), jnp.float32) + ) + total_value_of_generated_values = sum(item_values) + # Assign values to the items that are packed in the optimal solution. To ensure an optimal + # solution, the total value of the duplicated item values are added to the original + # generated item values. + optimal_items = valued_item_from_space_and_max_value( + items_spaces, item_values + total_value_of_generated_values + ) + # Duplicate the above items and assign values to them that are half their counterparts + # above. + extra_items = valued_item_from_space_and_max_value(items_spaces, item_values) + + # If self.max_num_items is an odd number, the concatenation of items and extra_items would + # result in a tree size of < self.max_num_items. In this case, we add padding. + padding_of_int_ones = jnp.ones( + self.max_num_items - 2 * len(items_mask), jnp.int32 + ) + padding_of_float_ones = jnp.ones( + self.max_num_items - 2 * len(items_mask), jnp.float32 + ) + padding_of_bool_zeros = jnp.zeros( + self.max_num_items - 2 * len(items_mask), bool + ) + padding_items = ValuedItem( + padding_of_int_ones * container.x2, + padding_of_int_ones * container.y2, + padding_of_int_ones * container.z2, + padding_of_float_ones, + ) + + # Create the solution state by creating trees of size self.max_num_items for items, + # items_placable_at_beginning_mask and items_placed_mask. + items = jax.tree_map( + lambda x, y, z: jnp.concatenate((x, y, z)), + optimal_items, + extra_items, + padding_items, + ) + items_placable_at_beginning_mask = jnp.concatenate( + (items_mask, items_mask, padding_of_bool_zeros) + ) + zeros_of_size_nb_extra_items = jnp.zeros(items_mask.shape, bool) + items_placed_mask = jnp.concatenate( + (items_mask, zeros_of_size_nb_extra_items, padding_of_bool_zeros) + ) + + sorted_ems_indexes = jnp.arange(0, self.max_num_ems, dtype=jnp.int32) + + # Create locations for placed, unplaced and padded items. + placed_items_locations = location_from_space(items_spaces) + + remaining_items_locations = Location( + x=jnp.zeros(self.max_num_items - len(items_mask), jnp.int32), + y=jnp.zeros(self.max_num_items - len(items_mask), jnp.int32), + z=jnp.zeros(self.max_num_items - len(items_mask), jnp.int32), + ) + all_item_locations = jax.tree_map( + lambda x, y: jnp.concatenate((x, y)), + placed_items_locations, + remaining_items_locations, + ) + best_value_volume_ratio_item = jnp.argmax(items.value / item_volume(items)) + normalization_value = jnp.min( + jnp.array( + jnp.max( + jnp.array( + ( + jnp.sum(items.value * jnp.array(items.value > 0)), + -jnp.sum(items.value * jnp.array(items.value < 0)), + ) + ) + ), + items.value[best_value_volume_ratio_item] + * ( + container.volume() + // item_volume(tree_slice(items, best_value_volume_ratio_item)) + + 1 + ), + ) + ) + instance_max_item_value_magnitude = jnp.max( + abs(items.value * items_placable_at_beginning_mask) + ) + + solution = State( + container=container, + ems=ems, + ems_mask=ems_mask, + items=items, + nb_items=len(items.x_len), + items_mask=items_placable_at_beginning_mask, + items_placed=items_placed_mask, + items_location=all_item_locations, + action_mask=None, + sorted_ems_indexes=sorted_ems_indexes, + instance_max_item_value_magnitude=instance_max_item_value_magnitude, + instance_total_value=normalization_value, + key=key, + ) + return solution + + def _generate_rotated_solved_instance(self, key: chex.PRNGKey) -> State: + solved_instance = super()._generate_solved_instance(key) + solved_instance.items = rotated_items_from_space( + space_from_item_and_location( + solved_instance.items, solved_instance.items_location + ) + ) + tmp = jnp.zeros((6, solved_instance.items_placed.shape[0]), bool) + solved_instance.items_placed = tmp.at[0].set(solved_instance.items_placed) + solved_instance.items_mask = jnp.broadcast_to( + solved_instance.items_mask, (6, solved_instance.items_mask.shape[0]) + ) + return solved_instance + + def _generate_solved_instance(self, key: chex.PRNGKey) -> State: + if self.is_rotation_allowed and self.is_value_based: + solved_instance = self._generate_value_based_solved_instance(key) + solved_instance.items = cast(ValuedItem, solved_instance.items) + solved_instance.items = rotated_items_from_space( + space=space_from_item_and_location( + solved_instance.items, solved_instance.items_location + ), + value=solved_instance.items.value, + ) + tmp = jnp.zeros((6, solved_instance.items_placed.shape[0]), bool) + solved_instance.items_placed = tmp.at[0].set(solved_instance.items_placed) + solved_instance.items_mask = jnp.broadcast_to( + solved_instance.items_mask, (6, solved_instance.items_mask.shape[0]) + ) + return solved_instance + else: + if self.is_rotation_allowed: + return self._generate_rotated_solved_instance(key) + elif self.is_value_based: + return self._generate_value_based_solved_instance(key) + else: + return super()._generate_solved_instance(key) + + def _unpack_items(self, state: State) -> State: + state = super()._unpack_items(state) + placed_item_arr_dims = ( + (6, self.max_num_items) if self.is_rotation_allowed else self.max_num_items + ) + state.items_placed = jnp.zeros(placed_item_arr_dims, bool) + return state + + +class ExtendedTrainingGenerator(ExtendedRandomGenerator): + def __init__( + self, + max_num_items: int, + max_num_ems: int, + mean_item_value: float, + std_item_value: float, + min_target_volume: int, + max_target_volume: int, + split_eps: float = 0.3, + prob_split_one_item: float = 0.7, + split_num_same_items: int = 5, + container_dims: Tuple[int, int, int] = TWENTY_FOOT_DIMS, + is_evaluation: Optional[bool] = False, + ): + """ + Args: + min_target_volume: minimal volume of the items in an instance in terms of containers. + min_target_volume: maximum volume of the items in an instance in terms of containers. + Raises: + ValueError: When the generator has to generate value based instances but the mean and + std aren't provided. + """ + super().__init__( + max_num_items, + max_num_ems, + True, + True, + split_eps, + prob_split_one_item, + split_num_same_items, + container_dims, + mean_item_value, + std_item_value, + ) + self.generated_instance_optimal_value = jnp.inf + self.is_evaluation = is_evaluation + self.min_target_volume = min_target_volume + self.max_target_volume = max_target_volume + + def _generate_value_based_solved_instance( # type: ignore + self, key: chex.PRNGKey, target_volume: int + ) -> State: + """Generate the random instance with 1/target_volume of all items correctly packed. The + other part of the generated items remain unpacked. + """ + key, split_key = jax.random.split(key) + container = make_container(self.container_dims) + + list_of_ems = [container] + (self.max_num_ems - 1) * [empty_ems()] + ems = tree_transpose(list_of_ems) + ems_mask = jnp.zeros(self.max_num_ems, bool) + + nb_items_in_one_container = math.floor(1 / target_volume * self.max_num_items) + first_items_spaces, first_items_mask = self._split_container_into_items_spaces( + container, split_key, nb_items_in_one_container + ) + len_first_items_mask = len(first_items_mask) + final_items_mask = first_items_mask + all_items_spaces = first_items_spaces + for _ in range(1, target_volume): + tmp_items_spaces, tmp_items_mask = self._split_container_into_items_spaces( + container, split_key, nb_items_in_one_container + ) + all_items_spaces = jax.tree_map( + lambda x, y: jnp.concatenate((x, y)), + all_items_spaces, + tmp_items_spaces, + ) + final_items_mask = jnp.concatenate((final_items_mask, tmp_items_mask)) + + # In case nb_items_in_one_container * target_volume < max_nb_items, add the difference as + # masked zero volume items. + nb_extra_items = self.max_num_items - len(final_items_mask) + extra_items_spaces = Space( + x1=jnp.zeros(nb_extra_items), + x2=jnp.zeros(nb_extra_items), + y1=jnp.zeros(nb_extra_items), + y2=jnp.zeros(nb_extra_items), + z1=jnp.zeros(nb_extra_items), + z2=jnp.zeros(nb_extra_items), + ) + extra_items_mask = jnp.zeros(nb_extra_items, bool) + all_items_spaces = jax.tree_map( + lambda x, y: jnp.concatenate((x, y)), + all_items_spaces, + extra_items_spaces, + ) + final_items_mask = jnp.concatenate((final_items_mask, extra_items_mask)) + item_values = self.mean_item_value + ( + self.std_item_value + * jax.random.normal(split_key, (self.max_num_items,), jnp.float32) + ) + if self.is_evaluation: + total_instance_value = jnp.sum(item_values) + item_values, _ = jax.lax.scan( + lambda item_values, item_mask_ind: ( + item_values.at[item_mask_ind].set( + (item_values[item_mask_ind] + total_instance_value) + * first_items_mask[item_mask_ind] + ), + (item_values[item_mask_ind] + total_instance_value) + * first_items_mask[item_mask_ind], + ), + item_values, + jnp.arange(len(first_items_mask)), + ) + optimal_value = jnp.sum(_) + self.generated_instance_optimal_value = optimal_value + items = valued_item_from_space_and_max_value(all_items_spaces, item_values) + + # Only the items of the first container are placed. + items_placed_mask = jnp.concatenate( + ( + first_items_mask, + jnp.zeros(self.max_num_items - len_first_items_mask, bool), + ) + ) + sorted_ems_indexes = jnp.arange(0, self.max_num_ems, dtype=jnp.int32) + + # Create locations for placed, unplaced and padded items. + placed_items_locations = location_from_space(first_items_spaces) + + remaining_items_locations = Location( + x=jnp.zeros(self.max_num_items - len_first_items_mask, jnp.int32), + y=jnp.zeros(self.max_num_items - len_first_items_mask, jnp.int32), + z=jnp.zeros(self.max_num_items - len_first_items_mask, jnp.int32), + ) + all_item_locations = jax.tree_map( + lambda x, y: jnp.concatenate((x, y)), + placed_items_locations, + remaining_items_locations, + ) + best_value_volume_ratio_item = jnp.argmax(items.value / item_volume(items)) + normalization_value = jnp.min( + jnp.array( + jnp.max( + jnp.array( + ( + jnp.sum(items.value * jnp.array(items.value > 0)), + -jnp.sum(items.value * jnp.array(items.value < 0)), + ) + ) + ), + items.value[best_value_volume_ratio_item] + * ( + container.volume() + // item_volume(tree_slice(items, best_value_volume_ratio_item)) + + 1 + ), + ) + ) + instance_max_item_value_magnitude = jnp.max(abs(items.value * final_items_mask)) + + solution = State( + container=container, + ems=ems, + ems_mask=ems_mask, + items=items, + nb_items=len(items.x_len), + items_mask=final_items_mask, + items_placed=items_placed_mask, + items_location=all_item_locations, + action_mask=None, + sorted_ems_indexes=sorted_ems_indexes, + instance_max_item_value_magnitude=instance_max_item_value_magnitude, + instance_total_value=normalization_value, + key=key, + ) + return solution + + def _generate_solved_instance(self, key: chex.PRNGKey) -> State: + target_volume = randint(self.min_target_volume, self.max_target_volume) + solved_instance = self._generate_value_based_solved_instance(key, target_volume) + solved_instance.items = cast(ValuedItem, solved_instance.items) + solved_instance.items = rotated_items_from_space( + space=space_from_item_and_location( + solved_instance.items, solved_instance.items_location + ), + value=solved_instance.items.value, + ) + tmp = jnp.zeros((6, solved_instance.items_placed.shape[0]), bool) + solved_instance.items_placed = tmp.at[0].set(solved_instance.items_placed) + solved_instance.items_mask = jnp.broadcast_to( + solved_instance.items_mask, (6, solved_instance.items_mask.shape[0]) + ) + return solved_instance + + +class ExtendedCSVGenerator(CSVGenerator): + """`Generator` that parses a CSV file to do active search on a single instance. It + always resets to the same instance defined by the CSV file. The generator can handle any + container dimensions but assumes a 20-ft container by default. + + The CSV file is expected to have the following columns: + - Item_Name + - Length + - Width + - Height + - Quantity + - Value + + Example with value: + Item_Name,Length,Width,Height,Quantity,Value + shape_1,1080,760,300,5,4.5 + shape_2,1100,430,250,3,3.4 + """ + + def __init__( + self, + csv_path: str, + max_num_ems: int, + is_rotation_allowed: bool, + is_value_based: bool, + container_dims: Tuple[int, int, int] = TWENTY_FOOT_DIMS, + ): + """Instantiate a `CSVGenerator` that generates the same instance (active search) + defined by a CSV file. + + Args: + csv_path: path to the CSV file defining the instance to reset to. + max_num_ems: maximum number of ems the environment will handle. This defines the shape + of the EMS buffer that is kept in the environment state. The good number heavily + depends on the number of items (given by the CSV file). + container_dims: (length, width, height) tuple of integers corresponding to the + dimensions of the container in millimeters. By default, assume a 20-ft container. + """ + super().__init__(csv_path, max_num_ems, container_dims) + self.is_value_based = is_value_based + self.is_rotation_allowed = is_rotation_allowed + + def _parse_csv_file( + self, csv_path: str, max_num_ems: int, container_dims: Tuple[int, int, int] + ) -> State: + """Create an instance by parsing a CSV file. + + Args: + csv_path: path to the CSV file to parse that defines the instance to reset to. + max_num_ems: maximum number of ems the environment will handle. This defines the shape + of the EMS buffer that is kept in the environment state. + container_dims: (length, width, height) tuple of integers corresponding to the + dimensions of the container in millimeters. + + Returns: + `BinPack` state that contains the instance defined in the CSV file. + """ + if not (self.is_rotation_allowed or self.is_value_based): + return super()._parse_csv_file(csv_path, max_num_ems, container_dims) + if not self.is_rotation_allowed: + return self._generate_value_based_instance_from_csv( + csv_path, max_num_ems, container_dims + ) + + if not self.is_value_based: + reset_state = super()._parse_csv_file(csv_path, max_num_ems, container_dims) + reset_state.items = rotated_items_from_space( + space_from_item_and_location( + reset_state.items, reset_state.items_location + ) + ) + tmp = jnp.zeros((6, reset_state.items_placed.shape[0]), bool) + reset_state.items_placed = tmp.at[0].set(reset_state.items_placed) + reset_state.items_mask = jnp.broadcast_to( + reset_state.items_mask, (6, reset_state.items_mask.shape[0]) + ) + return reset_state + + reset_state = self._generate_value_based_instance_from_csv( + csv_path, max_num_ems, container_dims + ) + reset_state.items = cast(ValuedItem, reset_state.items) + reset_state.items = rotated_items_from_space( + space=space_from_item_and_location( + reset_state.items, reset_state.items_location + ), + value=reset_state.items.value, + ) + tmp = jnp.zeros((6, reset_state.items_placed.shape[0]), bool) + reset_state.items_placed = tmp.at[0].set(reset_state.items_placed) + reset_state.items_mask = jnp.broadcast_to( + reset_state.items_mask, (6, reset_state.items_mask.shape[0]) + ) + return reset_state + + def _generate_value_based_instance_from_csv( + self, csv_path: str, max_num_ems: int, container_dims: Tuple[int, int, int] + ) -> State: + container = make_container(container_dims) + + # Initialize the EMSs + list_of_ems = [container] + (max_num_ems - 1) * [empty_ems()] + ems = tree_transpose(list_of_ems) + ems_mask = jnp.zeros(max_num_ems, bool).at[0].set(True) + + # Parse the CSV file to generate the items + rows = self._read_valued_csv(csv_path) + list_of_items = self._generate_list_of_valued_items(rows) + items = tree_transpose(list_of_items) + + # Initialize items mask and location + num_items = len(list_of_items) + items_mask = jnp.ones(num_items, bool) + items_placed = jnp.zeros(num_items, bool) + items_location = Location(*tuple(jnp.zeros((3, num_items), jnp.int32))) + + sorted_ems_indexes = jnp.arange(0, max_num_ems, dtype=jnp.int32) + + instance_total_value = jnp.sum(items.value * items_mask) + instance_max_item_value_magnitude = jnp.max(abs(items.value * items_mask)) + + reset_state = State( + container=container, + ems=ems, + ems_mask=ems_mask, + items=items, + nb_items=len(items.x_len), + items_mask=items_mask, + items_placed=items_placed, + items_location=items_location, + action_mask=None, + sorted_ems_indexes=sorted_ems_indexes, + instance_max_item_value_magnitude=instance_max_item_value_magnitude, + instance_total_value=instance_total_value, + key=jax.random.PRNGKey(0), + ) + + return reset_state + + def _read_valued_csv( + self, csv_path: str + ) -> List[Tuple[str, int, int, int, int, float]]: + rows = [] + with open(csv_path, newline="") as csvfile: + reader = csv.reader(csvfile) + for row_index, row in enumerate(reader): + if row_index == 0: + if len(row) != len(CSV_VALUE_PROBLEM_COLUMNS): + raise ValueError( + "Got wrong number of columns, expected: " + f"{', '.join(CSV_VALUE_PROBLEM_COLUMNS)}" + ) + elif row != CSV_VALUE_PROBLEM_COLUMNS: + raise ValueError("Columns in wrong order") + else: + # Column order: Item_Name, Length, Width, Height, Quantity, Value. + rows.append( + ( + row[0], + int(row[1]), + int(row[2]), + int(row[3]), + int(row[4]), + float(row[5]), + ) + ) + return rows + + def _generate_list_of_valued_items( + self, rows: List[Tuple[str, int, int, int, int, float]] + ) -> List[ValuedItem]: + """Generate the list of items from a Pandas DataFrame. + + Args: + rows: List[tuple] describing the items for the corresponding instance. + + Returns: + List of `ValuedItem` flattened so that identical items (quantity > 1) are copied + according to their quantity. + """ + list_of_items = [] + for (_, x_len, y_len, z_len, quantity, value) in rows: + identical_items = quantity * [ + ValuedItem( + x_len=jnp.array(x_len, jnp.int32), + y_len=jnp.array(y_len, jnp.int32), + z_len=jnp.array(z_len, jnp.int32), + value=jnp.array(value, jnp.float32), + ) + ] + list_of_items.extend(identical_items) + return list_of_items diff --git a/jumanji/environments/packing/bin_pack/generator_test.py b/jumanji/environments/packing/bin_pack/generator_test.py index e93a4aa51..25d3039ac 100644 --- a/jumanji/environments/packing/bin_pack/generator_test.py +++ b/jumanji/environments/packing/bin_pack/generator_test.py @@ -21,12 +21,18 @@ from jumanji.environments.packing.bin_pack.conftest import DummyGenerator from jumanji.environments.packing.bin_pack.generator import ( CSVGenerator, + ExtendedRandomGenerator, + ExtendedToyGenerator, RandomGenerator, ToyGenerator, save_instance_to_csv, ) -from jumanji.environments.packing.bin_pack.types import State, item_volume -from jumanji.testing.pytrees import assert_trees_are_different, assert_trees_are_equal +from jumanji.environments.packing.bin_pack.types import Item, State, item_volume +from jumanji.testing.pytrees import ( + assert_trees_are_close, + assert_trees_are_different, + assert_trees_are_equal, +) def test_save_instance_to_csv(dummy_state: State, tmpdir: py.path.local) -> None: @@ -203,3 +209,296 @@ def test_random_generator__generate_solution( solution_state2 = generate_solution(jax.random.PRNGKey(2)) assert_trees_are_different(solution_state1, solution_state2) + + +class TestRandomValueProblemGenerator: + @pytest.fixture + def random_generator( + self, max_num_items: int = 12, max_num_ems: int = 20 + ) -> ExtendedRandomGenerator: + return ExtendedRandomGenerator( + max_num_items, + max_num_ems, + is_value_based=True, + is_rotation_allowed=False, + mean_item_value=1, + std_item_value=0.5, + ) + + def test_random_generator__properties( + self, + random_generator: ExtendedRandomGenerator, + ) -> None: + """Validate that the random instance generator has the correct properties.""" + assert random_generator.max_num_items == 12 + assert random_generator.max_num_ems == 20 + + def test_random_generator__call( + self, random_generator: ExtendedRandomGenerator + ) -> None: + """Validate that the random instance generator's call function is jittable and compiles + only once. Also check that giving two different keys results in two different instances. + """ + chex.clear_trace_counter() + call_fn = jax.jit(chex.assert_max_traces(random_generator.__call__, n=1)) + state1 = call_fn(key=jax.random.PRNGKey(1)) + assert isinstance(state1, State) + + state2 = call_fn(key=jax.random.PRNGKey(2)) + assert_trees_are_different(state1, state2) + + def test_random_generator__generate_solution( + self, + random_generator: ExtendedRandomGenerator, + ) -> None: + """Validate that the random instance generator's generate_solution method behaves correctly. + Also check that it is jittable and compiles only once. + """ + + # This will produce a starting state for an environment (no items packed). + state1 = random_generator(jax.random.PRNGKey(1)) + + chex.clear_trace_counter() + generate_solution = jax.jit( + chex.assert_max_traces(random_generator.generate_solution, n=1) + ) + + # This will produce a solution to the environment (all possible items packed). + solution_state1 = generate_solution(jax.random.PRNGKey(1)) + assert isinstance(solution_state1, State) + assert_trees_are_equal(solution_state1.ems, state1.ems) + # Should be different because there is only 1 available ems in state1, whereas there should + # be none available at the end with a solution. + assert_trees_are_different(solution_state1.ems_mask, state1.ems_mask) + # The items do not change whether they are packed or not, it is just the items-placed mask + # that is different. + assert_trees_are_close(solution_state1.items, state1.items) + assert_trees_are_equal(solution_state1.items_mask, state1.items_mask) + # There should be no items placed in state1 whereas half of them will be placed in the + # perfect solution. + assert_trees_are_different(solution_state1.items_placed, state1.items_placed) + # They should be different because state1 has no locations and the solution should have + # half of them placed. + assert_trees_are_different( + solution_state1.items_location, state1.items_location + ) + # Checks that the perfect solution fills the container exactly + placed_items_volume = ( + item_volume(solution_state1.items) * solution_state1.items_placed + ).sum() + assert jnp.isclose(placed_items_volume, solution_state1.container.volume()) + + # Checks that the total volume of all items in an instance is greater than the volume of the + # container + items_volume = ( + item_volume(solution_state1.items) * solution_state1.items_mask + ).sum() + assert not jnp.isclose(items_volume, solution_state1.container.volume()) + assert items_volume > solution_state1.container.volume() + + # Generates a solution to a new instance (will be new because a different random key). + solution_state2 = generate_solution(jax.random.PRNGKey(2)) + assert_trees_are_different(solution_state1, solution_state2) + + +class TestExtendedToyGenerator: + @pytest.fixture + def toy_generator(self) -> ExtendedToyGenerator: + return ExtendedToyGenerator() + + def test_toy_generator__properties( + self, + toy_generator: ToyGenerator, + ) -> None: + """Validate that the toy instance generator has the correct properties.""" + assert toy_generator.max_num_items == 20 + assert toy_generator.max_num_ems > 0 + + def test_toy_generator__call( + self, + toy_generator: ToyGenerator, + ) -> None: + """Validate that the toy instance generator's call function behaves correctly, that it + returns the same state for different keys. Also check that it is jittable and compiles only + once. + """ + chex.clear_trace_counter() + call_fn = jax.jit(chex.assert_max_traces(toy_generator.__call__, n=1)) + state1 = call_fn(jax.random.PRNGKey(1)) + state2 = call_fn(jax.random.PRNGKey(2)) + assert_trees_are_equal(state1, state2) + + def test_toy_generator__generate_solution( + self, + toy_generator: ToyGenerator, + ) -> None: + """Validate that the toy instance generator's generate_solution method behaves correctly. + Also check that it is jittable and compiles only once.""" + state1 = toy_generator(jax.random.PRNGKey(1)) + + chex.clear_trace_counter() + generate_solution = jax.jit( + chex.assert_max_traces(toy_generator.generate_solution, n=1) + ) + + solution_state1 = generate_solution(jax.random.PRNGKey(1)) + assert isinstance(solution_state1, State) + assert_trees_are_equal(solution_state1.ems, state1.ems) + assert_trees_are_different(solution_state1.ems_mask, state1.ems_mask) + assert_trees_are_equal(solution_state1.items, state1.items) + assert_trees_are_equal(solution_state1.items_mask, state1.items_mask) + assert_trees_are_different(solution_state1.items_placed, state1.items_placed) + assert_trees_are_different( + solution_state1.items_location, state1.items_location + ) + + assert jnp.all(solution_state1.items_placed) + solution_state2 = generate_solution(jax.random.PRNGKey(2)) + assert_trees_are_equal(solution_state1, solution_state2) + + +class TestRotationRandomGenerator: + @pytest.fixture + def random_generator( + self, max_num_items: int = 6, max_num_ems: int = 10 + ) -> RandomGenerator: + return ExtendedRandomGenerator( + max_num_items, max_num_ems, is_rotation_allowed=True, is_value_based=False + ) + + def test_random_generator__properties( + self, + random_generator: RandomGenerator, + ) -> None: + """Validate that the random instance generator has the correct properties.""" + assert random_generator.max_num_items == 6 + assert random_generator.max_num_ems == 10 + + def test_random_generator__call(self, random_generator: RandomGenerator) -> None: + """Validate that the random instance generator's call function is jittable and compiles + only once. Also check that giving two different keys results in two different instances. + """ + chex.clear_trace_counter() + call_fn = jax.jit(chex.assert_max_traces(random_generator.__call__, n=1)) + state1 = call_fn(key=jax.random.PRNGKey(1)) + assert isinstance(state1, State) + + state2 = call_fn(key=jax.random.PRNGKey(2)) + assert_trees_are_different(state1, state2) + + def test_random_generator__generate_solution( + self, + random_generator: RandomGenerator, + ) -> None: + """Validate that the random instance generator's generate_solution method behaves correctly. + Also check that it is jittable and compiles only once. + """ + state1 = random_generator(jax.random.PRNGKey(1)) + + chex.clear_trace_counter() + generate_solution = jax.jit( + chex.assert_max_traces(random_generator.generate_solution, n=1) + ) + + solution_state1 = generate_solution(jax.random.PRNGKey(1)) + assert isinstance(solution_state1, State) + assert_trees_are_equal(solution_state1.ems, state1.ems) + assert_trees_are_different(solution_state1.ems_mask, state1.ems_mask) + assert_trees_are_equal(solution_state1.items, state1.items) + assert_trees_are_equal(solution_state1.items_mask, state1.items_mask) + assert_trees_are_different(solution_state1.items_placed, state1.items_placed) + assert_trees_are_different( + solution_state1.items_location, state1.items_location + ) + # In the optimal solution all the generated items are placed in their initial orientation so + # an item is either placed in its initial orientation in the bin or it doesn't even exist. + assert jnp.all(solution_state1.items_placed[0] | ~solution_state1.items_mask) + non_rotated_items = Item( + solution_state1.items.x_len[0], + solution_state1.items.y_len[0], + solution_state1.items.z_len[0], + ) + items_volume = ( + item_volume(non_rotated_items) * solution_state1.items_mask[0] + ).sum() + assert jnp.isclose(items_volume, solution_state1.container.volume()) + + solution_state2 = generate_solution(jax.random.PRNGKey(2)) + assert_trees_are_different(solution_state1, solution_state2) + + +class TestExtendedRandomGenerator: + @pytest.fixture + def random_generator( + self, max_num_items: int = 12, max_num_ems: int = 20 + ) -> RandomGenerator: + return ExtendedRandomGenerator( + max_num_items, + max_num_ems, + is_rotation_allowed=True, + is_value_based=True, + mean_item_value=1, + std_item_value=0.5, + ) + + def test_random_generator__properties( + self, + random_generator: RandomGenerator, + ) -> None: + """Validate that the random instance generator has the correct properties.""" + assert random_generator.max_num_items == 12 + assert random_generator.max_num_ems == 20 + + def test_random_generator__call(self, random_generator: RandomGenerator) -> None: + """Validate that the random instance generator's call function is jittable and compiles + only once. Also check that giving two different keys results in two different instances. + """ + chex.clear_trace_counter() + call_fn = jax.jit(chex.assert_max_traces(random_generator.__call__, n=1)) + state1 = call_fn(key=jax.random.PRNGKey(1)) + assert isinstance(state1, State) + + state2 = call_fn(key=jax.random.PRNGKey(2)) + assert_trees_are_different(state1, state2) + + def test_random_generator__generate_solution( + self, + random_generator: RandomGenerator, + ) -> None: + """Validate that the random instance generator's generate_solution method behaves correctly. + Also check that it is jittable and compiles only once. + """ + state1 = random_generator(jax.random.PRNGKey(1)) + + chex.clear_trace_counter() + generate_solution = jax.jit( + chex.assert_max_traces(random_generator.generate_solution, n=1) + ) + + solution_state1 = generate_solution(jax.random.PRNGKey(1)) + assert isinstance(solution_state1, State) + assert_trees_are_equal(solution_state1.ems, state1.ems) + assert_trees_are_different(solution_state1.ems_mask, state1.ems_mask) + assert_trees_are_close(solution_state1.items, state1.items) + assert_trees_are_equal(solution_state1.items_mask, state1.items_mask) + assert_trees_are_different(solution_state1.items_placed, state1.items_placed) + assert_trees_are_different( + solution_state1.items_location, state1.items_location + ) + + placed_items_volume = ( + item_volume(solution_state1.items) * solution_state1.items_placed + ).sum() + assert jnp.isclose(placed_items_volume, solution_state1.container.volume()) + + # Checks that the total volume of all items in an instance is greater than the volume of the + # container + items_volume = ( + item_volume(solution_state1.items) * solution_state1.items_mask + ).sum() + assert not jnp.isclose(items_volume, solution_state1.container.volume()) + assert items_volume > solution_state1.container.volume() + + # Generates a solution to a new instance (will be new because a different random key). + solution_state2 = generate_solution(jax.random.PRNGKey(2)) + assert_trees_are_different(solution_state1, solution_state2) diff --git a/jumanji/environments/packing/bin_pack/reward.py b/jumanji/environments/packing/bin_pack/reward.py index 0be90b51b..2a983ee40 100644 --- a/jumanji/environments/packing/bin_pack/reward.py +++ b/jumanji/environments/packing/bin_pack/reward.py @@ -13,12 +13,13 @@ # limitations under the License. import abc +from typing import cast import chex import jax import jax.numpy as jnp -from jumanji.environments.packing.bin_pack.types import State, item_volume +from jumanji.environments.packing.bin_pack.types import State, ValuedItem, item_volume from jumanji.tree_utils import tree_slice @@ -53,8 +54,17 @@ def __call__( is_done: bool, ) -> float: del next_state, is_done - _, item_id = action - chosen_item_volume = item_volume(tree_slice(state.items, item_id)) + # Check if the environment is BinPack or ExtendedBinPack + # by checking whether the action consists only of (ems,item) + # or of (ems, item, orientation). + if len(action) == 2: + _, item_id = action + chosen_item_volume = item_volume(tree_slice(state.items, item_id)) + elif len(action) == 3: + orientation, _, item_id = action + chosen_item_volume = item_volume( + tree_slice(state.items, (orientation, item_id)) + ) container_volume = state.container.volume() reward = chosen_item_volume / container_volume reward: float = jax.lax.select(is_valid, reward, jnp.array(0, float)) @@ -91,3 +101,65 @@ def sparse_reward(state: State) -> jnp.float_: next_state, ) return reward + + +class ValueBasedDenseReward(RewardFn): + """Computes a reward at each timestep, equal to the normalized value (relative to the total + value of items in the instance) of the item packed by taking the chosen action. The computed + reward is equivalent to the increase in value percentage packed within the container due to + packing the chosen item. + If the action is invalid, the reward is 0.0 instead. + """ + + def __call__( + self, + state: State, + action: chex.Array, + next_state: State, + is_valid: bool, + is_done: bool, + ) -> float: + del next_state, is_done + state.items = cast(ValuedItem, state.items) + if len(action) == 2: + _, item_id = action + chosen_item_value = tree_slice(state.items, item_id).value + elif len(action) == 3: + orientation, _, item_id = action + chosen_item_value = tree_slice(state.items, (orientation, item_id)).value + + reward = chosen_item_value / state.instance_total_value + reward: float = jax.lax.select(is_valid, reward, jnp.array(0, float)) + return reward + + +class ValueBasedSparseReward(RewardFn): + """Computes a sparse reward at the end of the episode. Returns the percentage value packed + (between 0.0 and 1.0). + If the action is invalid, the action is ignored and the reward is still returned as the current + percentage value packed. + """ + + def __call__( + self, + state: State, + action: chex.Array, + next_state: State, + is_valid: bool, + is_done: bool, + ) -> float: + del state, action, is_valid + + def sparse_reward(state: State) -> jnp.float_: + """Returns volume utilization between 0.0 and 1.0.""" + state.items = cast(ValuedItem, state.items) + items_value = jnp.sum(state.items.value * state.items_placed) + return items_value / state.instance_total_value + + reward: float = jax.lax.cond( + is_done, + sparse_reward, + lambda _: jnp.array(0, float), + next_state, + ) + return reward diff --git a/jumanji/environments/packing/bin_pack/reward_test.py b/jumanji/environments/packing/bin_pack/reward_test.py index 1630fc292..e838457da 100644 --- a/jumanji/environments/packing/bin_pack/reward_test.py +++ b/jumanji/environments/packing/bin_pack/reward_test.py @@ -17,8 +17,13 @@ import jumanji.tree_utils from jumanji.environments.packing.bin_pack.env import BinPack -from jumanji.environments.packing.bin_pack.reward import DenseReward, SparseReward -from jumanji.environments.packing.bin_pack.types import item_volume +from jumanji.environments.packing.bin_pack.reward import ( + DenseReward, + SparseReward, + ValueBasedDenseReward, + ValueBasedSparseReward, +) +from jumanji.environments.packing.bin_pack.types import item_value, item_volume def test__sparse_reward( @@ -62,9 +67,60 @@ def test__sparse_reward( assert jnp.isclose(reward, item_volume(item)) +def test__sparse_value_reward(bin_pack_sparse_value_reward: BinPack) -> None: + """This test is the same as the regular sparse reward test but with value instead of volume.""" + bin_pack_sparse_reward = bin_pack_sparse_value_reward + sparse_reward = ValueBasedSparseReward() + + reward_fn = jax.jit(sparse_reward) + step_fn = jax.jit(bin_pack_sparse_reward.step) + state, timestep = bin_pack_sparse_reward.reset(jax.random.PRNGKey(0)) + + # Check that the reward is correct for the next item. + for item_id, is_valid in enumerate(timestep.observation.items_mask): + action = jnp.array([0, item_id], jnp.int32) + next_state, next_timestep = step_fn(state, action) + reward = reward_fn( + state, action, next_state, is_valid, is_done=next_timestep.last() + ) + assert reward == next_timestep.reward == 0 + + # Check that all other invalid actions lead to the 0 reward, any ems_id > 0 is not valid at + # the beginning of the episode. + for ems_id in range(1, timestep.observation.action_mask.shape[0]): + for item_id in range(timestep.observation.action_mask.shape[1]): + action = jnp.array([ems_id, item_id], jnp.int32) + next_state, next_timestep = step_fn(state, action) + is_valid = timestep.observation.action_mask[tuple(action)] + is_done = next_timestep.last() + assert ~is_valid and is_done + reward = reward_fn(state, action, next_state, is_valid, is_done) + assert reward == 0 == next_timestep.reward + + # Check that taking an invalid action after packing one item returns the utilization + # of the first item. + action = jnp.array([0, 0], jnp.int32) + state, timestep = step_fn(state, action) + assert timestep.reward == 0 + assert timestep.mid() + next_state, timestep = step_fn(state, action) + reward = reward_fn(state, action, next_state, is_valid=False, is_done=True) + assert timestep.last() + item = jumanji.tree_utils.tree_slice(timestep.observation.items, action[1]) + instance_total_value = state.instance_total_value + instance_max_item_value_magnitude = state.instance_max_item_value_magnitude + # Multiply by instance_max_item_value_magnitude to undo the value normalisation of the item and + # divide by instance_total_value since this is what is used for reward normalisation. + assert jnp.isclose( + reward, + item_value(item) * instance_max_item_value_magnitude / instance_total_value, + ) + + def test_dense_reward( bin_pack_dense_reward: BinPack, dense_reward: DenseReward ) -> None: + """This test is the same as the regular dense reward test but with value instead of volume.""" reward_fn = jax.jit(dense_reward) step_fn = jax.jit(bin_pack_dense_reward.step) state, timestep = bin_pack_dense_reward.reset(jax.random.PRNGKey(0)) @@ -94,3 +150,46 @@ def test_dense_reward( assert ~is_valid and is_done reward = reward_fn(state, action, next_state, is_valid, is_done) assert reward == 0 == next_timestep.reward + + +def test_dense_value_reward(bin_pack_dense_value_reward: BinPack) -> None: + bin_pack_dense_reward = bin_pack_dense_value_reward + reward_fn = jax.jit(ValueBasedDenseReward()) + step_fn = jax.jit(bin_pack_dense_reward.step) + state, timestep = bin_pack_dense_reward.reset(jax.random.PRNGKey(0)) + + # Check that the reward is correct for the next item. + for item_id, is_valid in enumerate(timestep.observation.items_mask): + action = jnp.array([0, item_id], jnp.int32) + next_state, next_timestep = step_fn(state, action) + reward = reward_fn( + state, action, next_state, is_valid, is_done=next_timestep.last() + ) + assert reward == next_timestep.reward + if is_valid: + item = jumanji.tree_utils.tree_slice(timestep.observation.items, item_id) + instance_total_value = state.instance_total_value + instance_max_item_value_magnitude = state.instance_max_item_value_magnitude + # Multiply by instance_max_item_value_magnitude to undo the value normalisation of the + # item and divide by instance_total_value since this is what is used for reward + # normalisation. + assert jnp.isclose( + reward, + item_value(item) + * instance_max_item_value_magnitude + / instance_total_value, + ) + else: + assert reward == 0 + assert next_timestep.last() + + # Check that all other invalid actions lead to the 0 reward. + for ems_id in range(1, timestep.observation.action_mask.shape[0]): + for item_id in range(timestep.observation.action_mask.shape[1]): + action = jnp.array([ems_id, item_id], jnp.int32) + next_state, next_timestep = step_fn(state, action) + is_valid = timestep.observation.action_mask[tuple(action)] + is_done = next_timestep.last() + assert ~is_valid and is_done + reward = reward_fn(state, action, next_state, is_valid, is_done) + assert reward == 0 == next_timestep.reward diff --git a/jumanji/environments/packing/bin_pack/space.py b/jumanji/environments/packing/bin_pack/space.py index 89f38dc45..50aeb327b 100644 --- a/jumanji/environments/packing/bin_pack/space.py +++ b/jumanji/environments/packing/bin_pack/space.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import chex import jax.numpy as jnp @@ -122,7 +122,9 @@ def is_included(self, space: Space) -> chex.Numeric: & (self.z2 <= space.z2) ) - def hyperplane(self, axis: str, direction: str) -> Space: + def hyperplane( + self, axis: str, direction: str, full_support: Optional[bool] = False + ) -> "Space": """Returns the hyperplane (e.g. lower hyperplane on the x axis) for EMS creation when packing an item. @@ -135,6 +137,13 @@ def hyperplane(self, axis: str, direction: str) -> Space: """ inf_ = jnp.inf axis_direction = f"{axis}_{direction}" + # Can set supporting item to self because this method is only ever called by item spaces. + if full_support: + supporting_item_space = self + else: + supporting_item_space = Space( + x1=-inf_, x2=inf_, y1=-inf_, y2=inf_, z1=-inf_, z2=inf_ + ) if axis_direction == "x_lower": return Space(x1=-inf_, x2=self.x1, y1=-inf_, y2=inf_, z1=-inf_, z2=inf_) elif axis_direction == "x_upper": @@ -146,7 +155,14 @@ def hyperplane(self, axis: str, direction: str) -> Space: elif axis_direction == "z_lower": return Space(x1=-inf_, x2=inf_, y1=-inf_, y2=inf_, z1=-inf_, z2=self.z1) elif axis_direction == "z_upper": - return Space(x1=-inf_, x2=inf_, y1=-inf_, y2=inf_, z1=self.z2, z2=inf_) + return Space( + x1=supporting_item_space.x1, + x2=supporting_item_space.x2, + y1=supporting_item_space.y1, + y2=supporting_item_space.y2, + z1=self.z2, + z2=inf_, + ) else: raise ValueError( f"arguments not valid, got axis: {axis} and direction: {direction}." diff --git a/jumanji/environments/packing/bin_pack/types.py b/jumanji/environments/packing/bin_pack/types.py index 5f6845a84..a365b8dc5 100644 --- a/jumanji/environments/packing/bin_pack/types.py +++ b/jumanji/environments/packing/bin_pack/types.py @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, NamedTuple, Optional +from typing import TYPE_CHECKING, NamedTuple, Optional, Union +import chex +import jax.numpy as jnp from typing_extensions import TypeAlias +from jumanji.environments.packing.bin_pack.space import Space + if TYPE_CHECKING: from dataclasses import dataclass else: from chex import dataclass -import chex -import jax.numpy as jnp - -from jumanji.environments.packing.bin_pack.space import Space Container: TypeAlias = Space EMS: TypeAlias = Space @@ -50,7 +50,79 @@ def item_from_space(space: Space) -> Item: ) -def item_fits_in_item(item: Item, other_item: Item) -> chex.Array: +class ValuedItem(NamedTuple): + x_len: chex.Numeric + y_len: chex.Numeric + z_len: chex.Numeric + value: chex.Numeric + + +def valued_item_from_space_and_max_value( + space: Space, value: chex.Numeric +) -> ValuedItem: + return ValuedItem( + x_len=space.x2 - space.x1, + y_len=space.y2 - space.y1, + z_len=space.z2 - space.z1, + value=value, + ) + + +ItemType: TypeAlias = Union[Item, ValuedItem] + + +def rotated_items_from_space( + space: Space, value: Optional[chex.Numeric] = None +) -> ItemType: + x_len = jnp.asarray( + [ + # x along X, y along Y, z along Z (orientation A of DeepPack) + space.x2 - space.x1, + # x along X, z along Y, y along Z (Orientation B of DeepPack) + space.x2 - space.x1, + # z along X, y along Y, x along Z (Orientation C of DeepPack) + space.z2 - space.z1, + # y along X, x along Y, z along Z (Orientation D of DeepPack) + space.y2 - space.y1, + # z along X, x along Y, y along Z (Orientation E of DeepPack) + space.z2 - space.z1, + # y along X, z along Y, x along Z (Orientation F of deepPack) + space.y2 - space.y1, + ], + ) + y_len = jnp.asarray( + [ + space.y2 - space.y1, + space.z2 - space.z1, + space.y2 - space.y1, + space.x2 - space.x1, + space.x2 - space.x1, + space.z2 - space.z1, + ] + ) + z_len = jnp.asarray( + [ + space.z2 - space.z1, + space.y2 - space.y1, + space.x2 - space.x1, + space.z2 - space.z1, + space.y2 - space.y1, + space.x2 - space.x1, + ] + ) + if value is None: + return Item( + x_len=x_len, + y_len=y_len, + z_len=z_len, + ) + else: + return ValuedItem( + x_len=x_len, y_len=y_len, z_len=z_len, value=jnp.asarray(6 * [value]) + ) + + +def item_fits_in_item(item: ItemType, other_item: ItemType) -> chex.Array: """Check if an item is smaller than another one.""" return ( (item.x_len <= other_item.x_len) @@ -59,7 +131,7 @@ def item_fits_in_item(item: Item, other_item: Item) -> chex.Array: ) -def item_volume(item: Item) -> chex.Array: +def item_volume(item: ItemType) -> chex.Array: """Returns the volume as a float to prevent from overflow with 32 bits.""" x_len = jnp.asarray(item.x_len, float) y_len = jnp.asarray(item.y_len, float) @@ -67,6 +139,12 @@ def item_volume(item: Item) -> chex.Array: return x_len * y_len * z_len +def item_value(item: ItemType) -> chex.Array: + if not isinstance(item, ValuedItem): + raise ValueError(f"Trying to obtain the value of an item of type {type(item)}") + return jnp.asarray(item.value, float) + + class Location(NamedTuple): x: chex.Numeric y: chex.Numeric @@ -90,7 +168,7 @@ def location_from_space(space: Space) -> Location: ) -def space_from_item_and_location(item: Item, location: Location) -> Space: +def space_from_item_and_location(item: ItemType, location: Location) -> Space: """Returns a space from an item at a particular location. The bottom left corner is given by the location while the top right is the location plus the item dimensions. """ @@ -123,13 +201,20 @@ class State: container: Container # leaves of shape () ems: EMS # leaves of shape (max_num_ems,) ems_mask: chex.Array # (max_num_ems,) - items: Item # leaves of shape (max_num_items,) - items_mask: chex.Array # (max_num_items,) - items_placed: chex.Array # (max_num_items,) + # Since the items are allowed to take one of 6 orientations the items, items_mask , items_placed + # and action_mask tensors all have an extra dimension of size 6 representing the orientation + # that the items takes. The agent however sees the several orientations of each items as 6 + # different items among which it can only pack one. + items: ItemType # leaves of shape (6,max_num_items,) + items_mask: chex.Array # (6,max_num_items) + items_placed: chex.Array # (6,max_num_items) items_location: Location # leaves of shape (max_num_items,) - action_mask: Optional[chex.Array] # (obs_num_ems, max_num_items) + action_mask: Optional[chex.Array] # (6, obs_num_ems, max_num_items) sorted_ems_indexes: chex.Array # (max_num_ems,) + instance_max_item_value_magnitude: chex.Numeric # () - only for value based optimisation + instance_total_value: chex.Numeric # leaves of shape () - only for value based optimisation key: chex.PRNGKey # (2,) + nb_items: int # leaves of shape () class Observation(NamedTuple): @@ -145,7 +230,7 @@ class Observation(NamedTuple): ems: EMS # leaves of shape (obs_num_ems,) ems_mask: chex.Array # (obs_num_ems,) - items: Item # leaves of shape (max_num_items,) - items_mask: chex.Array # (max_num_items,) - items_placed: chex.Array # (max_num_items,) - action_mask: chex.Array # (obs_num_ems, max_num_items) + items: ItemType # leaves of shape (max_num_items,) + items_mask: chex.Array # (6*max_num_items,) + items_placed: chex.Array # (6*max_num_items,) + action_mask: chex.Array # (obs_num_ems, 6*max_num_items) diff --git a/jumanji/environments/packing/bin_pack/viewer.py b/jumanji/environments/packing/bin_pack/viewer.py index 223cc3813..56046cfbf 100644 --- a/jumanji/environments/packing/bin_pack/viewer.py +++ b/jumanji/environments/packing/bin_pack/viewer.py @@ -274,3 +274,103 @@ def _get_used_volume(self, state: State) -> float: if placed ) return used_volume + + +class ExtendedBinPackViewer(BinPackViewer): + def __init__( + self, name: str, is_rotation_allowed: bool, render_mode: str = "human" + ) -> None: + """ + This class defines a viewer for the ExtendedBinPack environment. + It inherits from the BinPackViewer class and redefines the methods + _add_overlay, _create_entities and _get_used_volume. This overriding + is necessary because the shape of the items array, items_placed array and item_mask + array have changed in this environment. + """ + super().__init__(name, render_mode) + self.is_rotation_allowed = is_rotation_allowed + + def _add_overlay(self, fig: plt.Figure, ax: plt.Axes, state: State) -> None: + """Sets the bounds of the scene and displays text about the scene. + + Args: + state: `State` of the environment + """ + eps = 0.05 + container = item_from_space(state.container) + ax.set( + xlim=(-container.x_len * eps, container.x_len * (1 + eps)), + ylim=(-container.y_len * eps, container.y_len * (1 + eps)), + zlim=(-container.z_len * eps, container.z_len * (1 + eps)), + ) + ax.set_xlabel("x", font=self.FONT_STYLE) + ax.set_ylabel("y", font=self.FONT_STYLE) + ax.set_zlabel("z", font=self.FONT_STYLE) + if self.is_rotation_allowed: + n_items = state.items_mask.shape[1] + else: + n_items = state.items_mask.shape[0] + placed_items: np.ndarray = np.sum(state.items_placed) + container_volume = ( + float(container.x_len) * float(container.y_len) * float(container.z_len) + ) + used_volume = self._get_used_volume(state) + + metrics = [ + ("Placed", f"{placed_items:{len(str(n_items))}}/{n_items}"), + ("Used Volume", f"{used_volume / container_volume:6.1%}"), + ] + title = " | ".join(key + ": " + value for key, value in metrics) + fig.suptitle(title, font=self.FONT_STYLE) + + def _create_entities( # noqa: CCR001 + self, state: State + ) -> List[mpl_toolkits.mplot3d.art3d.Poly3DCollection]: + entities = [] + if self.is_rotation_allowed: + n_items = state.items_mask.shape[1] + cmap = plt.cm.get_cmap("hsv", n_items) + for i in range(6): + for j in range(n_items): + if state.items_placed[i, j]: + box = self._create_box( + ( + state.items_location.x[j], + state.items_location.y[j], + state.items_location.z[j], + ), + ( + state.items.x_len[i, j], + state.items.y_len[i, j], + state.items.z_len[i, j], + ), + cmap(j), + 0.3, + ) + entities.append(box) + + container = item_from_space(state.container) + box = self._create_box( + (0.0, 0.0, 0.0), + (container.x_len, container.y_len, container.z_len), + "cyan", + 0.05, + ) + entities.append(box) + return entities + else: + return super()._create_entities(state) + + def _get_used_volume(self, state: State) -> float: + if self.is_rotation_allowed: + used_volume = sum( + float(state.items.x_len[i, j]) + * float(state.items.y_len[i, j]) + * float(state.items.z_len[i, j]) + for i in range(6) + for j, placed in enumerate(state.items_placed[i]) + if placed + ) + return used_volume + else: + return super()._get_used_volume(state) diff --git a/jumanji/testing/pytrees.py b/jumanji/testing/pytrees.py index a60dee52e..636d15e86 100644 --- a/jumanji/testing/pytrees.py +++ b/jumanji/testing/pytrees.py @@ -46,6 +46,23 @@ def is_equal_pytree(tree1: MixedTypeTree, tree2: MixedTypeTree) -> bool: return bool(is_equal) +def is_close_pytree(tree1: MixedTypeTree, tree2: MixedTypeTree) -> bool: + """Returns true if all leaves in `tree1` and `tree2` are equal. Requires that `tree1` and + `tree2` share the same structure, and that `np.asarray(leaf)` is valid for all leaves of the + trees. + + Note that this function will block gradients between the input and output, and is + created for use in the context of testing rather than for direct use inside RL algorithms.""" + is_equal_func = lambda leaf1, leaf2: np.isclose( + np.asarray(leaf1), np.asarray(leaf2) + ) + is_equal_leaves = tree_lib.flatten( + tree_lib.map_structure(is_equal_func, tree1, tree2) + ) + is_equal = np.all(is_equal_leaves) + return bool(is_equal) + + def assert_trees_are_different(tree1: MixedTypeTree, tree2: MixedTypeTree) -> None: """Checks whether `tree1` and `tree2` have at least one leaf where they differ. Requires that `tree1` and `tree2` share the same structure, and that `np.asarray(leaf)` is valid for all @@ -70,6 +87,15 @@ def assert_trees_are_equal(tree1: MixedTypeTree, tree2: MixedTypeTree) -> None: ), "The trees differ in at least one leaf's value(s)." +def assert_trees_are_close(tree1: MixedTypeTree, tree2: MixedTypeTree) -> None: + """Checks if all leaves in a `tree1` and `tree2` are close (equal if ints, and equal to within + a tolerance if not). Requires that `tree1` and `tree2` share the same structure, and that + `np.asarray(leaf)` is valid for all leaves of the trees.""" + assert is_close_pytree( + tree1, tree2 + ), "The trees differ in at least one leaf's value(s)." + + def is_tree_with_leaves_of_type(input_tree: Any, *leaf_type: Type) -> bool: """Returns true if all leaves in the `input_tree` are of the specified `leaf_type`.""" leaf_is_type_func = lambda leaf: isinstance(leaf, leaf_type) diff --git a/jumanji/training/networks/bin_pack/actor_critic.py b/jumanji/training/networks/bin_pack/actor_critic.py index c3de93f8f..ea0df420b 100644 --- a/jumanji/training/networks/bin_pack/actor_critic.py +++ b/jumanji/training/networks/bin_pack/actor_critic.py @@ -21,7 +21,7 @@ import numpy as np from jumanji.environments.packing.bin_pack import BinPack, Observation -from jumanji.environments.packing.bin_pack.types import EMS, Item +from jumanji.environments.packing.bin_pack.types import EMS, ItemType from jumanji.training.networks.actor_critic import ( ActorCriticNetworks, FeedForwardNetwork, @@ -146,7 +146,7 @@ def embed_ems(self, ems: EMS) -> chex.Array: embeddings = hk.Linear(self.model_size, name="ems_projection")(ems_leaves) return embeddings - def embed_items(self, items: Item) -> chex.Array: + def embed_items(self, items: ItemType) -> chex.Array: # Stack the 3 items attributes into a single vector [x_len, y_len, z_len]. items_leaves = jnp.stack(jax.tree_util.tree_leaves(items), axis=-1) # Projection of the EMSs. diff --git a/jumanji/training/setup_train.py b/jumanji/training/setup_train.py index ef30367ac..c6ab480e9 100644 --- a/jumanji/training/setup_train.py +++ b/jumanji/training/setup_train.py @@ -30,6 +30,7 @@ Cleaner, Connector, FlatPack, + ExtendedBinPack, Game2048, GraphColoring, JobShop, @@ -91,8 +92,36 @@ def setup_logger(cfg: DictConfig) -> Logger: return logger -def _make_raw_env(cfg: DictConfig) -> Environment: - return jumanji.make(cfg.env.registered_version) +def _make_raw_env(cfg: DictConfig, eval: bool = False) -> Environment: + try: + env = jumanji.make(cfg.env.registered_version) + except ValueError as error: + if ( + "Unregistered environment" in str(error) + and cfg.env.name != "extended_bin_pack" + ): + raise ValueError( + "Unregistered environment setup not possible for any other argument" + f"other than bin_pack, env requested is {cfg.env.name}." + ) + env_settings_dict = getattr(cfg.env, "env_settings", {}) + reward_string = cfg.env.env_settings.reward_fn + reward_fn = getattr(jumanji.environments.packing.bin_pack.reward, reward_string) + generator_string = cfg.env.env_settings.generator + generator_settings = cfg.env.generator_settings + if eval: + generator_settings = dict(generator_settings) + generator_settings["is_evaluation"] = True + generator = getattr( + jumanji.environments.packing.bin_pack.generator, generator_string + ) + env_settings_dict = { + **cfg.env.env_settings, + "generator": generator(**generator_settings), + "reward_fn": reward_fn(), + } + env = ExtendedBinPack(**env_settings_dict) + return env def setup_env(cfg: DictConfig) -> Environment: @@ -138,7 +167,7 @@ def _setup_random_policy( # noqa: CCR001 cfg: DictConfig, env: Environment ) -> RandomPolicy: assert cfg.agent == "random" - if cfg.env.name == "bin_pack": + if cfg.env.name == "bin_pack" or cfg.env.name == "extended_bin_pack": assert isinstance(env.unwrapped, BinPack) random_policy = networks.make_random_policy_bin_pack(bin_pack=env.unwrapped) elif cfg.env.name == "snake": @@ -219,7 +248,7 @@ def _setup_actor_critic_neworks( # noqa: CCR001 cfg: DictConfig, env: Environment ) -> ActorCriticNetworks: assert cfg.agent == "a2c" - if cfg.env.name == "bin_pack": + if cfg.env.name == "bin_pack" or cfg.env.name == "extended_bin_pack": assert isinstance(env.unwrapped, BinPack) actor_critic_networks = networks.make_actor_critic_networks_bin_pack( bin_pack=env.unwrapped, @@ -424,7 +453,7 @@ def _setup_actor_critic_neworks( # noqa: CCR001 def setup_evaluators(cfg: DictConfig, agent: Agent) -> Tuple[Evaluator, Evaluator]: - env = _make_raw_env(cfg) + env = _make_raw_env(cfg, eval=True) stochastic_eval = Evaluator( eval_env=env, agent=agent, diff --git a/jumanji/training/train.py b/jumanji/training/train.py index 4d01d7785..7ff6aeadd 100644 --- a/jumanji/training/train.py +++ b/jumanji/training/train.py @@ -84,6 +84,7 @@ def epoch_fn(training_state: TrainingState) -> Tuple[TrainingState, Dict]: training_state.params_state, stochastic_eval_key ) jax.block_until_ready(metrics) + print(utils.first_from_device(metrics)) logger.write( data=utils.first_from_device(metrics), label="eval_stochastic", diff --git a/train_test.py b/train_test.py new file mode 100644 index 000000000..3f5be9ed3 --- /dev/null +++ b/train_test.py @@ -0,0 +1,35 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +from hydra import compose, initialize + +from jumanji.training.train import train + +warnings.filterwarnings("ignore") + +env = "extended_bin_pack" +agent = "a2c" +with initialize(version_base=None, config_path="configs"): + cfg = compose( + config_name="config.yaml", + overrides=[ + f"env={env}", + f"agent={agent}", + "logger.type=terminal", + "logger.save_checkpoint=Trues", + ], + ) +train(cfg) diff --git a/value_based_training_script.py b/value_based_training_script.py new file mode 100644 index 000000000..27b0aabeb --- /dev/null +++ b/value_based_training_script.py @@ -0,0 +1,46 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import warnings + +from hydra import compose, initialize + +from jumanji.training.train import train + +warnings.filterwarnings("ignore") + +# Based on: +# stackoverflow.com/questions/67504079/how-to-check-if-an-nvidia-gpu-is-available-on-my-system +try: + subprocess.check_output("nvidia-smi") + print("a GPU is connected.") +except Exception: + # TPU or CPU + if "COLAB_TPU_ADDR" in os.environ and os.environ["COLAB_TPU_ADDR"]: + import jax.tools.colab_tpu + + jax.tools.colab_tpu.setup_tpu() + print("A TPU is connected.") + else: + print("Only CPU accelerator is connected.") + +config = "configs/config.yaml" +env_config = "configs/env/bin_pack.yaml" + +with initialize(version_base=None, config_path="configs"): + cfg = compose(config_name="config.yaml", overrides=["logger.save_checkpoint=true"]) + +train(cfg)