Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 3 additions & 12 deletions src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@
if is_wandb_available():
import wandb

from open_r1.utils.image import resize_image_min_size
from open_r1.vlm_modules.vlm_module import VLMBaseModule

# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
Expand Down Expand Up @@ -535,21 +537,10 @@ def _generate_and_score_completions(self, inputs: dict[str, Union[torch.Tensor,

for img in imgs:
try:
# Ensure minimum dimensions of 28 pixels
w, h = img.size
if w < 28 or h < 28:
# Calculate new dimensions maintaining aspect ratio
if w < h:
new_w = 28
new_h = int(h * (28/w))
else:
new_h = 28
new_w = int(w * (28/h))
img = img.resize((new_w, new_h), PIL.Image.Resampling.LANCZOS)
img = resize_image_min_size(img, min_size=28)
except:
pass
images.append(img)


prompt_inputs, additional_output = self.vlm_module.prepare_model_inputs(
self.processing_class,
Expand Down
17 changes: 17 additions & 0 deletions src/open-r1-multimodal/src/open_r1/utils/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import PIL.Image


def resize_image_min_size(image: PIL.Image.Image, min_size: int = 28) -> PIL.Image.Image:
    """Upscale ``image`` so that its shortest side is at least ``min_size``.

    An image whose width and height are both already ``>= min_size`` is
    returned as-is (the very same object). Otherwise the shorter side is set
    to exactly ``min_size`` and the other side is scaled by the same factor,
    rounded down to a whole pixel, so the aspect ratio is preserved as
    closely as integer pixel sizes allow.

    Args:
        image: Source image; only its ``size`` attribute and ``resize``
            method are used.
        min_size: Minimum allowed length, in pixels, of either side.

    Returns:
        ``image`` itself when no resize is needed, otherwise the result of
        ``image.resize`` with LANCZOS resampling.

    Raises:
        ZeroDivisionError: If a resize is needed and the shorter side of
            ``image`` is 0 pixels long.
    """
    width, height = image.size
    if width >= min_size and height >= min_size:
        return image

    # Use exact integer arithmetic for the scaled side. The float form
    # ``int(side * (min_size / other))`` can land just below a whole number
    # (e.g. 49 * (64 / 49) == 63.99999999999999) and truncate to one pixel
    # under ``min_size``, breaking the guarantee this function exists for.
    if width < height:
        new_size = (min_size, height * min_size // width)
    else:
        new_size = (width * min_size // height, min_size)

    return image.resize(new_size, PIL.Image.Resampling.LANCZOS)
43 changes: 43 additions & 0 deletions src/open-r1-multimodal/tests/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import PIL.Image

from open_r1.utils.image import resize_image_min_size


def test_resize_image_min_size_keeps_large_images():
    """An image well above the minimum on both sides keeps its dimensions."""
    source = PIL.Image.new("RGB", (224, 224))
    assert resize_image_min_size(source, min_size=28).size == (224, 224)


def test_resize_image_min_size_keeps_boundary_images():
    """An image exactly at the minimum on both sides keeps its dimensions."""
    source = PIL.Image.new("RGB", (28, 28))
    assert resize_image_min_size(source, min_size=28).size == (28, 28)


def test_resize_image_min_size_expands_small_width():
    """A too-narrow image is widened to the minimum with height scaled to match."""
    source = PIL.Image.new("RGB", (20, 50))
    out_width, out_height = resize_image_min_size(source, min_size=28).size
    assert (out_width, out_height) == (28, 70)


def test_resize_image_min_size_expands_small_height():
    """A too-short image is heightened to the minimum with width scaled to match."""
    source = PIL.Image.new("RGB", (50, 20))
    out_width, out_height = resize_image_min_size(source, min_size=28).size
    assert (out_width, out_height) == (70, 28)


def test_resize_image_min_size_expands_equal_small_sides():
    """A square image below the minimum is scaled up to a minimum-sized square."""
    source = PIL.Image.new("RGB", (20, 20))
    assert resize_image_min_size(source, min_size=28).size == (28, 28)