Skip to content
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
368edc2
feat: Added features field to the classification response
mohamedelabbas1996 Apr 14, 2025
4484f2e
feat: add support for returning features in APIMothClassifier response
mohamedelabbas1996 Apr 14, 2025
3cc31ad
added fallback get_features method to the InferenceBaseClass
mohamedelabbas1996 Apr 14, 2025
8071168
feat: implemented get_features for Resnet50TimmClassifier class
mohamedelabbas1996 Apr 14, 2025
52f0f62
chore: moved features dim to constants
mohamedelabbas1996 Apr 14, 2025
b4c3af7
Default to None if get_features is not implemented
mohamedelabbas1996 Apr 17, 2025
ae62dd5
Added features extraction tests
mohamedelabbas1996 Apr 17, 2025
88c8220
Removed prints
mohamedelabbas1996 Apr 17, 2025
fa7dee8
Added clustering using K-Means and visualization
mohamedelabbas1996 Apr 23, 2025
cce38f3
Added plotly dependency
mohamedelabbas1996 Apr 23, 2025
902331b
Added sklearn dependency
mohamedelabbas1996 Apr 23, 2025
9306bd0
chore: make plotly optional, fix type warnings
mihow Apr 30, 2025
71768a2
merge: resolve conflicts with main, preserve Mohamed's feature extrac…
mihow Mar 25, 2026
4159333
feat: add get_features() to InferenceBaseClass and Resnet50TimmClassi…
mihow Mar 25, 2026
dc5fc49
feat: add include_features and include_logits config toggles
mihow Mar 25, 2026
7028ce6
feat: wire feature and logits extraction into APIMothClassifier
mihow Mar 25, 2026
2afe1e7
feat: pass include_features and include_logits from API and worker
mihow Mar 25, 2026
3183ee4
test: add feature and logits extraction API tests
mihow Mar 25, 2026
aa530fc
merge: update to latest main (GPU utilization fixes)
mihow Mar 25, 2026
f48effb
fix: resolve timing bug and update existing tests for opt-in logits
mihow Mar 25, 2026
598d6ed
test: add worker-path and feature validity tests
mihow Mar 25, 2026
b4b0fbf
fix: release feature tensor after use, add settings UI metadata
mihow Mar 25, 2026
9189018
merge: pick up memory leak threshold bump from main (#124)
mihow Mar 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions trapdata/antenna/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch.multiprocessing as mp
import torchvision

from trapdata.antenna.client import get_full_service_name, get_jobs, post_batch_results
from trapdata.antenna.client import get_full_service_name, get_jobs
from trapdata.antenna.datasets import get_rest_dataloader
from trapdata.antenna.result_posting import ResultPoster
from trapdata.antenna.schemas import AntennaTaskResult, AntennaTaskResultError
Expand Down Expand Up @@ -254,7 +254,12 @@ def _process_job(

# Defer instantiation of poster, detector and classifiers until we have data
if not classifier:
classifier = classifier_class(source_images=[], detections=[])
classifier = classifier_class(
source_images=[],
detections=[],
include_features=settings.include_features,
include_logits=settings.include_logits,
)
detector = APIMothDetector([])
result_poster = ResultPoster(max_pending=MAX_PENDING_POSTS)

Expand Down Expand Up @@ -330,13 +335,14 @@ def _process_job(

if use_binary_filter:
assert binary_filter is not None, "Binary filter not initialized"
detections_for_terminal_classifier, detections_to_return = (
_apply_binary_classification(
binary_filter,
detector_results,
image_tensors,
image_detections,
)
(
detections_for_terminal_classifier,
detections_to_return,
) = _apply_binary_classification(
binary_filter,
detector_results,
image_tensors,
image_detections,
)
else:
# No binary filtering, send all detections to terminal classifier
Expand Down Expand Up @@ -458,7 +464,9 @@ def _process_job(
)
_, t = log_time() # reset time to measure batch load time
logger.info(
f"Finished batch {i + 1}. Total items: {items}, Classification time: {cls_time:.2f}s, Detection time: {det_time:.2f}s, Load time: {load_time:.2f}s"
f"Finished batch {i + 1}. Total items: {items}, "
f"Classification time: {cls_time:.2f}s, Detection time: {det_time:.2f}s, "
f"Load time: {load_time:.2f}s"
)

if result_poster:
Expand Down
2 changes: 2 additions & 0 deletions trapdata/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,8 @@ async def process(data: PipelineRequest) -> PipelineResponse:
# single=True if len(filtered_detections) == 1 else False,
single=True, # @TODO solve issues with reading images in multiprocessing
example_config_param=data.config.example_config_param,
include_features=data.config.include_features,
include_logits=data.config.include_logits,
terminal=True,
# critera=data.config.criteria, # @TODO another approach to intermediate filter models
)
Expand Down
32 changes: 24 additions & 8 deletions trapdata/api/models/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,16 @@ def __init__(
source_images: typing.Iterable[SourceImage],
detections: typing.Iterable[DetectionResponse],
terminal: bool = True,
include_features: bool = False,
include_logits: bool = False,
*args,
**kwargs,
):
self.source_images = source_images
self.detections = list(detections)
self.terminal = terminal
self.include_features = include_features
self.include_logits = include_logits
self.results: list[DetectionResponse] = []
super().__init__(*args, **kwargs)
logger.info(
Expand All @@ -66,28 +70,39 @@ def get_dataset(self):
batch_size=self.batch_size,
)

def post_process_batch(self, logits: torch.Tensor):
"""
Return the labels, softmax/calibrated scores, and the original logits for
each image in the batch.
def predict_batch(self, batch):
    """Run one forward pass over a batch of detection crops.

    Moves the batch to the model device, computes class logits, and —
    when ``self.include_features`` is enabled — also extracts backbone
    feature embeddings via ``self.get_features``.

    Returns:
        Tuple ``(logits, features)`` where ``features`` is ``None``
        unless feature extraction was requested.
    """
    inputs = batch.to(self.device, non_blocking=True)
    logits = self.model(inputs)
    embeddings = self.get_features(inputs) if self.include_features else None
    return logits, embeddings
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

Almost like the base class method, but we need to return the logits as well.
def post_process_batch(self, batch_output):
"""
Return ClassifierResult objects with labels, scores, and
optional logits and feature vectors for each image in the batch.
"""
logits, features = batch_output
predictions = torch.nn.functional.softmax(logits, dim=1)
predictions = predictions.cpu().numpy()
logits = logits.cpu()
logits_cpu = logits.cpu()
if features is not None:
features = features.cpu()

batch_results = []

for i, pred in enumerate(predictions):
class_indices = np.arange(len(pred))
labels = [self.category_map[i] for i in class_indices]
logit = logits[i].tolist()
labels = [self.category_map[idx] for idx in class_indices]
logit = logits_cpu[i].tolist() if self.include_logits else None
feature_vec = features[i].tolist() if features is not None else None

result = ClassifierResult(
labels=labels,
logit=logit,
scores=pred.tolist(),
features=feature_vec,
)

batch_results.append(result)
Expand Down Expand Up @@ -164,6 +179,7 @@ def update_detection_classification(
classification=self.get_best_label(predictions),
scores=predictions.scores,
logits=predictions.logit,
features=predictions.features,
inference_time=seconds_per_item,
algorithm=AlgorithmReference(name=self.name, key=self.get_key()),
timestamp=datetime.datetime.now(),
Expand Down
35 changes: 30 additions & 5 deletions trapdata/api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,22 @@ class ClassificationResponse(pydantic.BaseModel):
),
repr=False, # Too long to display in the repr
)
logits: list[float] = pydantic.Field(
default_factory=list,
logits: list[float] | None = pydantic.Field(
default=None,
description=(
"The raw logits output by the model, before any calibration or "
"normalization."
"Raw logits (unnormalized model outputs) for each class. "
"Only included when include_logits=true in the pipeline config."
),
repr=False, # Too long to display in the repr
repr=False,
)
features: list[float] | None = pydantic.Field(
default=None,
description=(
"Feature vector (embedding) extracted from the model backbone before "
"the classification head. Only included when include_features=true in "
"the pipeline config."
),
repr=False,
)
inference_time: float | None = None
algorithm: AlgorithmReference
Expand Down Expand Up @@ -239,6 +248,22 @@ class PipelineConfigRequest(pydantic.BaseModel):
description="Example of a configuration parameter for a pipeline.",
examples=[3],
)
include_features: bool = pydantic.Field(
default=False,
description=(
"Whether to include feature vectors (embeddings) in classification "
"responses. Feature vectors are 2048-dim floats extracted from the "
"model backbone. Disabled by default to reduce response size."
),
)
include_logits: bool = pydantic.Field(
default=False,
description=(
"Whether to include raw logits in classification responses. "
"Logits are the unnormalized model outputs before softmax. "
"Disabled by default to reduce response size."
),
)


class PipelineRequest(pydantic.BaseModel):
Expand Down
124 changes: 124 additions & 0 deletions trapdata/api/tests/test_features_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import pathlib
from unittest import TestCase

from fastapi.testclient import TestClient

from trapdata.api.api import PipelineChoice, PipelineRequest, PipelineResponse, app
from trapdata.api.schemas import PipelineConfigRequest, SourceImageRequest
from trapdata.api.tests.image_server import StaticFileTestServer
from trapdata.tests import TEST_IMAGES_BASE_PATH


class TestFeatureAndLogitsExtractionAPI(TestCase):
    """End-to-end API tests for the opt-in ``include_features`` and
    ``include_logits`` pipeline-config flags."""

    @classmethod
    def setUpClass(cls):
        # Serve the local test images over HTTP so the pipeline can fetch
        # them by URL, and create one shared FastAPI test client.
        cls.test_images_dir = pathlib.Path(TEST_IMAGES_BASE_PATH)
        cls.file_server = StaticFileTestServer(cls.test_images_dir)
        cls.client = TestClient(app)

    @classmethod
    def tearDownClass(cls):
        # Shut down the static file server started in setUpClass.
        cls.file_server.stop()

    def get_local_test_images(self, num=1):
        """Build ``SourceImageRequest`` objects (id + served URL) for up to
        three bundled test images; ``num`` limits how many are returned."""
        image_paths = [
            "panama/01-20231110214539-snapshot.jpg",
            "panama/01-20231111032659-snapshot.jpg",
            "panama/01-20231111015309-snapshot.jpg",
        ]
        return [
            SourceImageRequest(id=str(i), url=self.file_server.get_url(path))
            for i, path in enumerate(image_paths[:num])
        ]

    def _run_pipeline(
        self,
        include_features: bool = False,
        include_logits: bool = False,
        num_images: int = 1,
    ):
        """POST a /process request with the given extraction flags and
        return the parsed ``PipelineResponse``."""
        test_images = self.get_local_test_images(num=num_images)
        config = PipelineConfigRequest(
            include_features=include_features,
            include_logits=include_logits,
        )
        pipeline_request = PipelineRequest(
            pipeline=PipelineChoice["global_moths_2024"],
            source_images=test_images,
            config=config,
        )
        # The file server must be running while the pipeline downloads images.
        with self.file_server:
            response = self.client.post("/process", json=pipeline_request.model_dump())
        assert response.status_code == 200
        return PipelineResponse(**response.json())

    def test_features_included_when_enabled(self):
        """Features are present and valid when include_features=True."""
        result = self._run_pipeline(include_features=True)
        self.assertTrue(result.detections, "No detections returned")
        for detection in result.detections:
            for classification in detection.classifications:
                # Only terminal classifications carry feature vectors.
                if classification.terminal:
                    self.assertIsNotNone(
                        classification.features,
                        "Features should not be None when enabled",
                    )
                    self.assertIsInstance(classification.features, list)
                    self.assertTrue(
                        all(isinstance(x, float) for x in classification.features)
                    )
                    # ResNet50 backbone embedding size.
                    self.assertEqual(len(classification.features), 2048)

    def test_features_absent_when_disabled(self):
        """Features are None when include_features=False (default)."""
        result = self._run_pipeline(include_features=False)
        self.assertTrue(result.detections, "No detections returned")
        for detection in result.detections:
            for classification in detection.classifications:
                self.assertIsNone(
                    classification.features,
                    "Features should be None when disabled",
                )

    def test_logits_included_when_enabled(self):
        """Logits are present when include_logits=True."""
        result = self._run_pipeline(include_logits=True)
        self.assertTrue(result.detections, "No detections returned")
        for detection in result.detections:
            for classification in detection.classifications:
                if classification.terminal:
                    self.assertIsNotNone(
                        classification.logits,
                        "Logits should not be None when enabled",
                    )
                    self.assertIsInstance(classification.logits, list)
                    self.assertTrue(
                        all(isinstance(x, float) for x in classification.logits)
                    )

    def test_logits_absent_when_disabled(self):
        """Logits are None when include_logits=False (default)."""
        result = self._run_pipeline(include_logits=False)
        self.assertTrue(result.detections, "No detections returned")
        for detection in result.detections:
            for classification in detection.classifications:
                self.assertIsNone(
                    classification.logits,
                    "Logits should be None when disabled",
                )

    def test_both_features_and_logits(self):
        """Both features and logits present when both flags enabled."""
        result = self._run_pipeline(include_features=True, include_logits=True)
        self.assertTrue(result.detections, "No detections returned")
        for detection in result.detections:
            for classification in detection.classifications:
                if classification.terminal:
                    self.assertIsNotNone(classification.features)
                    self.assertIsNotNone(classification.logits)

    def test_default_config_has_nothing_extra(self):
        """Default PipelineConfigRequest disables both features and logits."""
        config = PipelineConfigRequest()
        self.assertFalse(config.include_features)
        self.assertFalse(config.include_logits)
1 change: 0 additions & 1 deletion trapdata/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
NEGATIVE_BINARY_LABEL = "nonmoth"
NULL_DETECTION_LABELS = [NEGATIVE_BINARY_LABEL]
TRACKING_COST_THRESHOLD = 1.0

POSITIVE_COLOR = [0, 100 / 255, 1, 1] # Blue
# POSITIVE_COLOR = [1, 0, 162 / 255, 1] # Pink
# NEUTRAL_COLOR = [1, 1, 1, 0.5] # White
Expand Down
9 changes: 9 additions & 0 deletions trapdata/ml/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ def get_model(self) -> torch.nn.Module:
"""
raise NotImplementedError

def get_features(self, batch_input: torch.Tensor) -> torch.Tensor | None:
"""Extract feature vectors from the model backbone.

Override in subclasses that support feature extraction.
Returns None by default for models that don't implement it.
"""
return None

def get_transforms(self) -> torchvision.transforms.Compose:
"""
This method must be implemented by a subclass.
Expand Down Expand Up @@ -342,3 +350,4 @@ class ClassifierResult:
labels: list[str] | None
logit: list[float] | None
scores: list[float]
features: list[float] | None = None
12 changes: 12 additions & 0 deletions trapdata/ml/models/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,18 @@ def get_model(self):
model.eval()
return model

@torch.no_grad()
def get_features(self, batch_input: torch.Tensor) -> torch.Tensor:
    """Extract pooled 2048-dim feature embeddings from the ResNet50 backbone.

    timm's ``forward_features()`` returns a (B, 2048, H, W) spatial feature
    map for ResNet50; adaptive average pooling collapses the spatial
    dimensions so each image yields a flat (B, 2048) embedding.

    NOTE(review): another feature extractor exists elsewhere in the
    project — consider cross-referencing the two and eventually unifying
    them on this implementation.
    """
    spatial = self.model.forward_features(batch_input)
    pooled = torch.nn.functional.adaptive_avg_pool2d(spatial, (1, 1))
    return pooled.view(pooled.size(0), -1)


class BinaryClassifier(Resnet50ClassifierLowRes):
stage = 2
Expand Down
4 changes: 4 additions & 0 deletions trapdata/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class Settings(BaseSettings):
antenna_service_name: str = "AMI Data Companion"
antenna_api_batch_size: int = 16

# Feature and logits extraction settings
include_features: bool = False
include_logits: bool = False

@pydantic.field_validator("image_base_path", "user_data_path")
def validate_path(cls, v):
"""
Expand Down
Loading