Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions perceptionmetrics/models/torch_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,14 @@ def __init__(
# Load confidence and NMS thresholds from config
self.confidence_threshold = self.model_cfg.get("confidence_threshold", 0.5)
self.nms_threshold = self.model_cfg.get("nms_threshold", 0.3)
self.max_detections_per_image = self.model_cfg.get(
"max_detections_per_image", 100
Comment thread
RihaanBH-1810 marked this conversation as resolved.
Outdated
)

self.postprocess_args = [self.confidence_threshold]
if self.model_format == "yolo":
self.postprocess_args.append(self.nms_threshold)
self.postprocess_args.append(self.max_detections_per_image)

# Add reverse mapping for idx to class_name
self.idx_to_class_name = {v["idx"]: k for k, v in self.ontology.items()}
Expand Down
17 changes: 16 additions & 1 deletion perceptionmetrics/models/utils/torchvision.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
def postprocess_detection(output: dict, confidence_threshold: float = 0.5):
def postprocess_detection(
output: dict, confidence_threshold: float = 0.5, max_detections: int = 100
Comment thread
RihaanBH-1810 marked this conversation as resolved.
Outdated
):
"""Post-process torchvision model output.

:param output: Dictionary with keys 'boxes', 'labels', and 'scores'.
:type output: dict
:param confidence_threshold: Confidence threshold to filter boxes.
:type confidence_threshold: float
:param max_detections: Maximum number of detections to keep per image after filtering, selected by highest score. A value of 0 or less disables the limit.
:type max_detections: int
:return: Dictionary with keys 'boxes', 'labels', and 'scores'.
:rtype: dict
"""
Expand All @@ -15,4 +19,15 @@ def postprocess_detection(output: dict, confidence_threshold: float = 0.5):
"labels": output["labels"][keep_mask],
"scores": output["scores"][keep_mask],
}

if max_detections > 0:
limit = min(max_detections, output["scores"].shape[0])
if limit > 0:
Comment thread
RihaanBH-1810 marked this conversation as resolved.
Outdated
limited_idx = output["scores"].argsort(descending=True)[:limit]
output = {
"boxes": output["boxes"][limited_idx],
"labels": output["labels"][limited_idx],
"scores": output["scores"][limited_idx],
}

return output
12 changes: 11 additions & 1 deletion perceptionmetrics/models/utils/yolo.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import torch
from torchvision.ops import nms


CLASS_NMS_OFFSET = 7680 # offset to apply to boxes for class-wise NMS


def postprocess_detection(
output: torch.Tensor,
confidence_threshold: float = 0.25,
nms_threshold: float = 0.45,
max_detections: int = 100,
):
"""Post-process YOLO model output.

Expand All @@ -18,6 +18,8 @@ def postprocess_detection(
:type confidence_threshold: float
:param nms_threshold: IoU threshold for Non-Maximum Suppression (NMS). Some models may not perform NMS (e.g. YOLOv26).
:type nms_threshold: float
:param max_detections: Maximum number of detections to keep per image after filtering, selected by highest score. A value of 0 or less disables the limit.
:type max_detections: int
:return: Dictionary with keys 'boxes', 'labels', and 'scores'.
:rtype: dict
"""
Expand Down Expand Up @@ -57,4 +59,12 @@ def postprocess_detection(
scores = scores[keep_idx]
labels = labels[keep_idx]

if max_detections > 0:
limit = min(max_detections, scores.shape[0])
if limit > 0:
limited_idx = scores.argsort(descending=True)[:limit]
boxes_xyxy = boxes_xyxy[limited_idx]
scores = scores[limited_idx]
labels = labels[limited_idx]

return {"boxes": boxes_xyxy, "labels": labels, "scores": scores}