Skip to content
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
368edc2
feat: Added features field to the classification response
mohamedelabbas1996 Apr 14, 2025
4484f2e
feat: add support for returning features in APIMothClassifier response
mohamedelabbas1996 Apr 14, 2025
3cc31ad
added fallback get_features method to the InferenceBaseClass
mohamedelabbas1996 Apr 14, 2025
8071168
feat: implemented get_features for Resnet50TimmClassifier class
mohamedelabbas1996 Apr 14, 2025
52f0f62
chore: moved features dim to constants
mohamedelabbas1996 Apr 14, 2025
b4c3af7
Default to None if get_features is not implemented
mohamedelabbas1996 Apr 17, 2025
ae62dd5
Added features extraction tests
mohamedelabbas1996 Apr 17, 2025
88c8220
Removed prints
mohamedelabbas1996 Apr 17, 2025
fa7dee8
Added clustering using K-Means and visualization
mohamedelabbas1996 Apr 23, 2025
cce38f3
Added plotly dependency
mohamedelabbas1996 Apr 23, 2025
902331b
Added sklearn dependency
mohamedelabbas1996 Apr 23, 2025
9306bd0
chore: make plotly optional, fix type warnings
mihow Apr 30, 2025
71768a2
merge: resolve conflicts with main, preserve Mohamed's feature extrac…
mihow Mar 25, 2026
4159333
feat: add get_features() to InferenceBaseClass and Resnet50TimmClassi…
mihow Mar 25, 2026
dc5fc49
feat: add include_features and include_logits config toggles
mihow Mar 25, 2026
7028ce6
feat: wire feature and logits extraction into APIMothClassifier
mihow Mar 25, 2026
2afe1e7
feat: pass include_features and include_logits from API and worker
mihow Mar 25, 2026
3183ee4
test: add feature and logits extraction API tests
mihow Mar 25, 2026
aa530fc
merge: update to latest main (GPU utilization fixes)
mihow Mar 25, 2026
f48effb
fix: resolve timing bug and update existing tests for opt-in logits
mihow Mar 25, 2026
598d6ed
test: add worker-path and feature validity tests
mihow Mar 25, 2026
b4b0fbf
fix: release feature tensor after use, add settings UI metadata
mihow Mar 25, 2026
9189018
merge: pick up memory leak threshold bump from main (#124)
mihow Mar 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
690 changes: 690 additions & 0 deletions docs/superpowers/plans/2026-03-25-feature-vector-extraction.md

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion trapdata/antenna/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,12 @@ def _process_job(

# Defer instantiation of poster, detector and classifiers until we have data
if not classifier:
classifier = classifier_class(source_images=[], detections=[])
classifier = classifier_class(
source_images=[],
detections=[],
include_features=settings.include_features,
include_logits=settings.include_logits,
)
detector = APIMothDetector([])
result_poster = ResultPoster(max_pending=MAX_PENDING_POSTS)

Expand Down
2 changes: 2 additions & 0 deletions trapdata/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,8 @@ async def process(data: PipelineRequest) -> PipelineResponse:
# single=True if len(filtered_detections) == 1 else False,
single=True, # @TODO solve issues with reading images in multiprocessing
example_config_param=data.config.example_config_param,
include_features=data.config.include_features,
include_logits=data.config.include_logits,
terminal=True,
# critera=data.config.criteria, # @TODO another approach to intermediate filter models
)
Expand Down
33 changes: 25 additions & 8 deletions trapdata/api/models/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,16 @@ def __init__(
source_images: typing.Iterable[SourceImage],
detections: typing.Iterable[DetectionResponse],
terminal: bool = True,
include_features: bool = False,
include_logits: bool = False,
*args,
**kwargs,
):
self.source_images = source_images
self.detections = list(detections)
self.terminal = terminal
self.include_features = include_features
self.include_logits = include_logits
self.results: list[DetectionResponse] = []
super().__init__(*args, **kwargs)
logger.info(
Expand All @@ -66,28 +70,40 @@ def get_dataset(self):
batch_size=self.batch_size,
)

def post_process_batch(self, logits: torch.Tensor):
"""
Return the labels, softmax/calibrated scores, and the original logits for
each image in the batch.
def predict_batch(self, batch):
    """Run one forward pass of the model on a batch of inputs.

    Moves the batch to the classifier's device, runs the model, and —
    when ``include_features`` is enabled — captures the backbone feature
    vectors for this batch in ``self._last_features`` so a later
    post-processing step can attach them to the results.

    Returns the raw (unnormalized) logits tensor.
    """
    inputs = batch.to(self.device, non_blocking=True)
    outputs = self.model(inputs)
    # Reset-or-set in one branch so stale features from a previous batch
    # are never carried over when the toggle is off.
    if self.include_features:
        self._last_features = self.get_features(inputs)
    else:
        self._last_features = None
    return outputs

Almost like the base class method, but we need to return the logits as well.
def post_process_batch(self, batch_output):
"""
Return ClassifierResult objects with labels, scores, and
optional logits and feature vectors for each image in the batch.
"""
logits = batch_output
features = self._last_features
predictions = torch.nn.functional.softmax(logits, dim=1)
predictions = predictions.cpu().numpy()
logits = logits.cpu()
logits_cpu = logits.cpu()
if features is not None:
features = features.cpu()

batch_results = []

for i, pred in enumerate(predictions):
class_indices = np.arange(len(pred))
labels = [self.category_map[i] for i in class_indices]
logit = logits[i].tolist()
labels = [self.category_map[idx] for idx in class_indices]
logit = logits_cpu[i].tolist() if self.include_logits else None
feature_vec = features[i].tolist() if features is not None else None

result = ClassifierResult(
labels=labels,
logit=logit,
scores=pred.tolist(),
features=feature_vec,
)

batch_results.append(result)
Expand Down Expand Up @@ -164,6 +180,7 @@ def update_detection_classification(
classification=self.get_best_label(predictions),
scores=predictions.scores,
logits=predictions.logit,
features=predictions.features,
inference_time=seconds_per_item,
algorithm=AlgorithmReference(name=self.name, key=self.get_key()),
timestamp=datetime.datetime.now(),
Expand Down
35 changes: 30 additions & 5 deletions trapdata/api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,22 @@ class ClassificationResponse(pydantic.BaseModel):
),
repr=False, # Too long to display in the repr
)
logits: list[float] = pydantic.Field(
default_factory=list,
logits: list[float] | None = pydantic.Field(
default=None,
description=(
"The raw logits output by the model, before any calibration or "
"normalization."
"Raw logits (unnormalized model outputs) for each class. "
"Only included when include_logits=true in the pipeline config."
),
repr=False, # Too long to display in the repr
repr=False,
)
features: list[float] | None = pydantic.Field(
default=None,
description=(
"Feature vector (embedding) extracted from the model backbone before "
"the classification head. Only included when include_features=true in "
"the pipeline config."
),
repr=False,
)
inference_time: float | None = None
algorithm: AlgorithmReference
Expand Down Expand Up @@ -239,6 +248,22 @@ class PipelineConfigRequest(pydantic.BaseModel):
description="Example of a configuration parameter for a pipeline.",
examples=[3],
)
include_features: bool = pydantic.Field(
default=False,
description=(
"Whether to include feature vectors (embeddings) in classification "
"responses. Feature vectors are 2048-dim floats extracted from the "
"model backbone. Disabled by default to reduce response size."
),
)
include_logits: bool = pydantic.Field(
default=False,
description=(
"Whether to include raw logits in classification responses. "
"Logits are the unnormalized model outputs before softmax. "
"Disabled by default to reduce response size."
),
)


class PipelineRequest(pydantic.BaseModel):
Expand Down
9 changes: 4 additions & 5 deletions trapdata/api/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from trapdata.api.schemas import PipelineConfigRequest
from trapdata.api.tests.image_server import StaticFileTestServer
from trapdata.api.tests.utils import get_test_images, get_pipeline_class
from trapdata.api.tests.utils import get_pipeline_class, get_test_images
from trapdata.tests import TEST_IMAGES_BASE_PATH

logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -106,6 +106,7 @@ def test_processing_with_only_binary_classifier(self):
pipeline_request = PipelineRequest(
pipeline=PipelineChoice[binary_classifier_pipeline_choice],
source_images=self.get_test_images(num=2),
config=PipelineConfigRequest(include_logits=True),
)
with self.file_server:
response = self.client.post("/process", json=pipeline_request.model_dump())
Expand All @@ -132,9 +133,7 @@ def test_logits_in_classification_response(self):

test_pipeline_slug = "insect_orders_2025"

config = PipelineConfigRequest(
# return_logits=True
)
config = PipelineConfigRequest(include_logits=True)
pipeline_request = PipelineRequest(
pipeline=PipelineChoice[test_pipeline_slug],
source_images=test_images,
Expand Down Expand Up @@ -181,7 +180,7 @@ def test_config_num_classification_predictions(self):

test_pipeline_slug = "insect_orders_2025"

config = PipelineConfigRequest()
config = PipelineConfigRequest(include_logits=True)
pipeline_request = PipelineRequest(
pipeline=PipelineChoice[test_pipeline_slug],
source_images=test_images,
Expand Down
188 changes: 188 additions & 0 deletions trapdata/api/tests/test_features_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import pathlib
from unittest import TestCase

from fastapi.testclient import TestClient

from trapdata.api.api import PipelineChoice, PipelineRequest, PipelineResponse, app
from trapdata.api.schemas import PipelineConfigRequest, SourceImageRequest
from trapdata.api.tests.image_server import StaticFileTestServer
from trapdata.tests import TEST_IMAGES_BASE_PATH


class TestFeatureAndLogitsExtractionAPI(TestCase):
    """End-to-end tests for opt-in feature-vector and logits extraction.

    Each test posts a ``PipelineRequest`` to the ``/process`` endpoint with
    the ``include_features`` / ``include_logits`` config toggles set, then
    inspects the classifications returned in the ``PipelineResponse``.
    """

    # Expected dimensionality of the backbone embedding.
    # NOTE(review): assumes the global_moths_2024 pipeline uses a backbone
    # with 2048-dim pooled features (e.g. ResNet-50) — update if the model
    # behind this pipeline slug changes.
    EXPECTED_FEATURE_DIM = 2048

    @classmethod
    def setUpClass(cls):
        # Shared across all tests: a local static server for the test images
        # and a FastAPI test client for the app under test.
        cls.test_images_dir = pathlib.Path(TEST_IMAGES_BASE_PATH)
        cls.file_server = StaticFileTestServer(cls.test_images_dir)
        cls.client = TestClient(app)

    @classmethod
    def tearDownClass(cls):
        cls.file_server.stop()

    def get_local_test_images(self, num=1):
        """Return up to ``num`` SourceImageRequest objects whose URLs are
        served by the local static file server."""
        image_paths = [
            "panama/01-20231110214539-snapshot.jpg",
            "panama/01-20231111032659-snapshot.jpg",
            "panama/01-20231111015309-snapshot.jpg",
        ]
        return [
            SourceImageRequest(id=str(i), url=self.file_server.get_url(path))
            for i, path in enumerate(image_paths[:num])
        ]

    def _run_pipeline(
        self,
        include_features: bool = False,
        include_logits: bool = False,
        num_images: int = 1,
    ):
        """POST a ``/process`` request with the given toggles and return the
        parsed ``PipelineResponse``. Fails the test on a non-200 status."""
        test_images = self.get_local_test_images(num=num_images)
        config = PipelineConfigRequest(
            include_features=include_features,
            include_logits=include_logits,
        )
        pipeline_request = PipelineRequest(
            pipeline=PipelineChoice["global_moths_2024"],
            source_images=test_images,
            config=config,
        )
        with self.file_server:
            response = self.client.post("/process", json=pipeline_request.model_dump())
            self.assertEqual(
                response.status_code, 200, f"Request failed: {response.text}"
            )
        return PipelineResponse(**response.json())

    def _all_classifications(self, result):
        """Yield every classification across all detections in the response."""
        for detection in result.detections:
            yield from detection.classifications

    def _terminal_classifications(self, result):
        """Yield only the terminal (final-stage) classifications."""
        return (c for c in self._all_classifications(result) if c.terminal)

    def test_features_included_when_enabled(self):
        """Features are present and valid when include_features=True."""
        result = self._run_pipeline(include_features=True)
        self.assertTrue(result.detections, "No detections returned")
        for classification in self._terminal_classifications(result):
            self.assertIsNotNone(
                classification.features,
                "Features should not be None when enabled",
            )
            self.assertIsInstance(classification.features, list)
            self.assertTrue(
                all(isinstance(x, float) for x in classification.features)
            )
            self.assertEqual(
                len(classification.features), self.EXPECTED_FEATURE_DIM
            )

    def test_features_absent_when_disabled(self):
        """Features are None when include_features=False (default)."""
        result = self._run_pipeline(include_features=False)
        self.assertTrue(result.detections, "No detections returned")
        for classification in self._all_classifications(result):
            self.assertIsNone(
                classification.features,
                "Features should be None when disabled",
            )

    def test_logits_included_when_enabled(self):
        """Logits are present when include_logits=True."""
        result = self._run_pipeline(include_logits=True)
        self.assertTrue(result.detections, "No detections returned")
        for classification in self._terminal_classifications(result):
            self.assertIsNotNone(
                classification.logits,
                "Logits should not be None when enabled",
            )
            self.assertIsInstance(classification.logits, list)
            self.assertTrue(
                all(isinstance(x, float) for x in classification.logits)
            )

    def test_logits_absent_when_disabled(self):
        """Logits are None when include_logits=False (default)."""
        result = self._run_pipeline(include_logits=False)
        self.assertTrue(result.detections, "No detections returned")
        for classification in self._all_classifications(result):
            self.assertIsNone(
                classification.logits,
                "Logits should be None when disabled",
            )

    def test_both_features_and_logits(self):
        """Both features and logits present when both flags enabled."""
        result = self._run_pipeline(include_features=True, include_logits=True)
        self.assertTrue(result.detections, "No detections returned")
        for classification in self._terminal_classifications(result):
            self.assertIsNotNone(classification.features)
            self.assertIsNotNone(classification.logits)

    def test_default_config_has_nothing_extra(self):
        """Default PipelineConfigRequest disables both features and logits."""
        config = PipelineConfigRequest()
        self.assertFalse(config.include_features)
        self.assertFalse(config.include_logits)

    def test_worker_path_features_via_predict_and_postprocess(self):
        """Features survive the full predict/post-process pipeline.

        NOTE(review): despite the name, this runs the HTTP pipeline
        end-to-end rather than calling predict_batch/post_process_batch
        directly; the antenna worker uses the same classifier methods
        under the hood. TODO: add a direct unit test of those methods.
        """
        result = self._run_pipeline(include_features=True)
        self.assertTrue(result.detections, "No detections returned")

        terminal_features = [
            c.features
            for c in self._terminal_classifications(result)
            if c.features is not None
        ]
        self.assertTrue(
            terminal_features, "No features found in terminal classifications"
        )

        # Each feature vector should have the expected dimensionality.
        for features in terminal_features:
            self.assertEqual(len(features), self.EXPECTED_FEATURE_DIM)

    def test_feature_vectors_are_meaningful(self):
        """Verify features are non-trivial: non-zero and varying.

        (Determinism across runs is deliberately NOT asserted here — that
        would require a second, expensive pipeline run.)
        """
        result = self._run_pipeline(include_features=True)
        self.assertTrue(result.detections, "No detections returned")

        terminal_features = [
            c.features
            for c in self._terminal_classifications(result)
            if c.features is not None
        ]
        self.assertGreaterEqual(
            len(terminal_features), 1, "Need at least one feature vector"
        )

        for features in terminal_features:
            # An all-zero vector would indicate the extraction hook failed.
            self.assertFalse(
                all(v == 0.0 for v in features),
                "Feature vector is all zeros — model may not be extracting properly",
            )
            # A (near-)constant vector is degenerate even if non-zero.
            self.assertGreater(
                len(set(features)),
                10,
                "Feature vector has too few unique values — likely degenerate",
            )

        # Distinct detections should not yield identical embeddings.
        if len(terminal_features) >= 2:
            self.assertNotEqual(
                terminal_features[0],
                terminal_features[1],
                "Different detections produced identical features",
            )
1 change: 0 additions & 1 deletion trapdata/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
NEGATIVE_BINARY_LABEL = "nonmoth"
NULL_DETECTION_LABELS = [NEGATIVE_BINARY_LABEL]
TRACKING_COST_THRESHOLD = 1.0

POSITIVE_COLOR = [0, 100 / 255, 1, 1] # Blue
# POSITIVE_COLOR = [1, 0, 162 / 255, 1] # Pink
# NEUTRAL_COLOR = [1, 1, 1, 0.5] # White
Expand Down
Loading
Loading