From 7d6bb308e26ffe9566f4d7cdaf072809bad0f449 Mon Sep 17 00:00:00 2001
From: Travor <3488616445@qq.com>
Date: Thu, 7 May 2026 02:22:01 +0800
Subject: [PATCH] Fix image resize minimum-size handling

---
Review notes (below the three-dash marker, so not part of the commit message):

* resize_image_min_size() scales the long side with integer arithmetic
  (`height * min_size // width`). The earlier form,
  `int(height * (min_size / width))`, truncated a float product that can fall
  just below an exact integer: a 20x45 image came out as 28x62 instead of
  28x63. A regression test covers that case.
* NOTE(review): the call site in grpo_trainer.py still wraps the helper in a
  bare `except: pass`, so if the helper raises, the unresized image is
  appended to the batch without any diagnostic.
* NOTE(review): line breaks and indentation of this patch were restored from a
  whitespace-collapsed copy. The indentation of the grpo_trainer.py context
  and removed lines is an assumption, and the blob ids on the `index` lines
  of the two new files predate the change described above. Confirm against
  the target tree, or apply with `git apply --ignore-whitespace`.

 .../src/open_r1/trainer/grpo_trainer.py       | 15 +----
 .../src/open_r1/utils/image.py                | 31 +++++++++++
 src/open-r1-multimodal/tests/test_image.py    | 52 +++++++++++++++++++
 3 files changed, 86 insertions(+), 12 deletions(-)
 create mode 100644 src/open-r1-multimodal/src/open_r1/utils/image.py
 create mode 100644 src/open-r1-multimodal/tests/test_image.py

diff --git a/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py b/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py
index d7dbe1ab4..08fbd0390 100755
--- a/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py
+++ b/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py
@@ -60,7 +60,9 @@ if is_wandb_available():
     import wandb
 
+from open_r1.utils.image import resize_image_min_size
 from open_r1.vlm_modules.vlm_module import VLMBaseModule
+
 
 # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
 # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
 RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
@@ -535,21 +537,10 @@ def _generate_and_score_completions(self, inputs: dict[str, Union[torch.Tensor,
 
             for img in imgs:
                 try:
-                    # Ensure minimum dimensions of 28 pixels
-                    w, h = img.size
-                    if w < 28 or h < 28:
-                        # Calculate new dimensions maintaining aspect ratio
-                        if w < h:
-                            new_w = 28
-                            new_h = int(h * (28/w))
-                        else:
-                            new_h = 28
-                            new_w = int(w * (28/h))
-                        img = img.resize((new_w, new_h), PIL.Image.Resampling.LANCZOS)
+                    img = resize_image_min_size(img, min_size=28)
                 except:
                     pass
                 images.append(img)
-
 
         prompt_inputs, additional_output = self.vlm_module.prepare_model_inputs(
             self.processing_class,
diff --git a/src/open-r1-multimodal/src/open_r1/utils/image.py b/src/open-r1-multimodal/src/open_r1/utils/image.py
new file mode 100644
index 000000000..92b76fa3b
--- /dev/null
+++ b/src/open-r1-multimodal/src/open_r1/utils/image.py
@@ -0,0 +1,31 @@
+import PIL.Image
+
+
+def resize_image_min_size(image: PIL.Image.Image, min_size: int = 28) -> PIL.Image.Image:
+    """Upscale ``image`` so that its shortest side is at least ``min_size``.
+
+    Args:
+        image: Image to check; it is never modified in place.
+        min_size: Minimum number of pixels required along each side.
+
+    Returns:
+        ``image`` itself when both sides already measure at least
+        ``min_size``. Otherwise a LANCZOS-resampled copy whose shorter side
+        is exactly ``min_size`` and whose longer side is scaled by the same
+        factor, rounded down.
+    """
+    width, height = image.size
+    if width >= min_size and height >= min_size:
+        return image
+
+    # Scale with integer arithmetic: ``int(height * (min_size / width))`` can
+    # land one pixel short when the float product falls just below an exact
+    # integer (e.g. 20x45 with min_size=28 gave 62 instead of 63).
+    if width < height:
+        new_width = min_size
+        new_height = height * min_size // width
+    else:
+        new_height = min_size
+        new_width = width * min_size // height
+
+    return image.resize((new_width, new_height), PIL.Image.Resampling.LANCZOS)
diff --git a/src/open-r1-multimodal/tests/test_image.py b/src/open-r1-multimodal/tests/test_image.py
new file mode 100644
index 000000000..840ee8c53
--- /dev/null
+++ b/src/open-r1-multimodal/tests/test_image.py
@@ -0,0 +1,52 @@
+import PIL.Image
+
+from open_r1.utils.image import resize_image_min_size
+
+
+def test_resize_image_min_size_keeps_large_images():
+    image = PIL.Image.new("RGB", (224, 224))
+
+    resized = resize_image_min_size(image, min_size=28)
+
+    assert resized.size == (224, 224)
+
+
+def test_resize_image_min_size_keeps_boundary_images():
+    image = PIL.Image.new("RGB", (28, 28))
+
+    resized = resize_image_min_size(image, min_size=28)
+
+    assert resized.size == (28, 28)
+
+
+def test_resize_image_min_size_expands_small_width():
+    image = PIL.Image.new("RGB", (20, 50))
+
+    resized = resize_image_min_size(image, min_size=28)
+
+    assert resized.size == (28, 70)
+
+
+def test_resize_image_min_size_expands_small_height():
+    image = PIL.Image.new("RGB", (50, 20))
+
+    resized = resize_image_min_size(image, min_size=28)
+
+    assert resized.size == (70, 28)
+
+
+def test_resize_image_min_size_expands_equal_small_sides():
+    image = PIL.Image.new("RGB", (20, 20))
+
+    resized = resize_image_min_size(image, min_size=28)
+
+    assert resized.size == (28, 28)
+
+
+def test_resize_image_min_size_scales_long_side_exactly():
+    # 45 * 28 / 20 is exactly 63; truncating the float product gave 62.
+    image = PIL.Image.new("RGB", (20, 45))
+
+    resized = resize_image_min_size(image, min_size=28)
+
+    assert resized.size == (28, 63)