Draft
Changes from 5 commits
4 changes: 4 additions & 0 deletions .gitignore
@@ -276,3 +276,7 @@ sandbox/

# Other
flower
antenna-flatbug/
flat_bug_M.pt
processing_services/example/flat_bug_M.pt
processing_services/example/huggingface_cache/
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,5 +1,5 @@
exclude: "^docs/|/migrations/|ui/"
default_stages: [commit]
default_stages: [pre-commit]

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
82 changes: 82 additions & 0 deletions processing_services/example/FLAT_BUG_IMPLEMENTATION.md
@@ -0,0 +1,82 @@
# Flat-Bug Integration Implementation

## Summary

I've successfully implemented the `FlatBugObjectDetector` class to use the actual flat-bug library instead of relying on a Hugging Face checkpoint. Here's what was changed and how it works:

## Key Changes Made

### 1. Updated `compile()` method
- **Before**: Used `transformers.pipeline` with a placeholder checkpoint
- **After**: Uses `flat_bug.predictor.Predictor` with the default model `'flat_bug_M.pt'`
- The model will be automatically downloaded on first use
- Added configurable hyperparameters (score threshold, IoU threshold, etc.)
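The compile-once behavior behind these changes can be sketched independently of flat-bug. In the snippet below, `get_or_compile` mirrors the `SAVED_MODELS` caching pattern in `algorithms.py`; the `loader` callable stands in for constructing `flat_bug.predictor.Predictor(model="flat_bug_M.pt", device=...)` (the stand-in names are illustrative):

```python
SAVED_MODELS: dict[str, object] = {}  # process-wide cache, keyed per uniquely compiled algorithm


def get_or_compile(key: str, loader):
    """Return a cached model, invoking `loader` only on first use."""
    if key not in SAVED_MODELS:
        SAVED_MODELS[key] = loader()
    return SAVED_MODELS[key]


# Exercise with a stand-in loader; the real code would build a flat-bug
# Predictor here and call set_hyperparameters(...) on it before caching.
calls = []
model_a = get_or_compile("flat_bug_object_detector", lambda: calls.append(1) or "model")
model_b = get_or_compile("flat_bug_object_detector", lambda: calls.append(1) or "model")
```

With this pattern, repeated `compile()` calls reuse the already-loaded model, so the expensive model download and initialization happens once per process.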

### 2. Updated `run()` method
- **Before**: Called `self.model(image, candidate_labels=...)`
- **After**: Uses `self.model.pyramid_predictions(image)` which is the flat-bug API
- Handles the `TensorPredictions` response format from flat-bug
- Converts tensors to numpy arrays and extracts bounding boxes and scores
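The image-preparation half of `run()` reduces to reordering HWC pixel data into CHW; a minimal numpy-only sketch (torch's `permute(2, 0, 1)` performs the same reordering the detector uses):

```python
import numpy as np

# Stand-in for np.array(pil_image): an RGB image as height x width x channels.
image_hwc = np.zeros((480, 640, 3), dtype=np.uint8)

# Reorder to channels x height x width and cast to float, matching
# torch.from_numpy(image_np).permute(2, 0, 1).float() in the detector.
image_chw = image_hwc.transpose(2, 0, 1).astype(np.float32)
```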

### 3. Updated description
- Now accurately reflects that it uses the actual flat-bug library
- Mentions specialization for terrestrial arthropod detection

## How It Works

1. **Installation**: Flat-bug needs to be installed from source:
```bash
git clone https://github.com/darsa-group/flat-bug.git
cd flat-bug
pip install -e .
```

2. **Model Loading**: The `flat_bug_M.pt` weights are downloaded automatically on first use

3. **Inference**: Uses flat-bug's pyramid tiling approach for detection on arbitrarily large images

4. **Output**: Converts flat-bug's `TensorPredictions` format to your existing `Detection` objects
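Step 4 above amounts to iterating paired boxes and confidences; a simplified sketch, with plain dicts standing in for the project's `Detection`/`BoundingBox` schemas (the helper name and dict layout are illustrative):

```python
def boxes_to_detections(boxes, scores, default_score=0.5):
    """Pair each [x1, y1, x2, y2] box with its confidence, falling back to a
    default when scores are missing, as the run() method does."""
    detections = []
    for i, (x1, y1, x2, y2) in enumerate(boxes):
        score = float(scores[i]) if scores is not None and i < len(scores) else default_score
        detections.append({
            "bbox": {"x1": float(x1), "y1": float(y1), "x2": float(x2), "y2": float(y2)},
            "score": score,
            "label": "insect",  # the detector's single candidate label
        })
    return detections


dets = boxes_to_detections([[10, 20, 110, 220]], [0.92])
```

In the real implementation each entry additionally carries the cropped image, timestamps, and an `AlgorithmReference` back to the detector.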

## Installation Requirements

```bash
# Install flat-bug
pip install git+https://github.com/darsa-group/flat-bug.git

# Ensure PyTorch is installed
pip install "torch>=2.3"  # quote the specifier so the shell doesn't treat >= as a redirect
```

> **Review comment (Collaborator Author):** Is this an issue for us?

## Testing

I've created `test_flat_bug_implementation.py` which you can run to:
- Verify the flat-bug installation
- Inspect the actual format of `TensorPredictions` objects
- Confirm the attribute names and data structures

## Potential Adjustments Needed

The implementation makes some assumptions about the flat-bug API that should be verified:

1. **Attribute names**: the code reads `predictions.boxes` and `predictions.confs` (testing showed confidences are exposed as `confs`, not `scores`), but future flat-bug versions may rename these
2. **Box format**: boxes are assumed to be in `[x1, y1, x2, y2]` (xyxy) order, which the test script confirmed
3. **Tensor handling**: the tensor-to-numpy conversion might still need adjustment
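A defensive way to verify those assumptions at runtime is to probe the prediction object before trusting attribute names (a sketch; the candidate names checked here are guesses beyond `boxes`/`confs`):

```python
from types import SimpleNamespace


def inspect_predictions(predictions) -> dict:
    """Report which of the assumed attributes a prediction object actually has.
    Useful for confirming e.g. 'boxes' vs 'bboxes' and 'confs' vs 'scores'."""
    candidates = ("boxes", "bboxes", "confs", "scores", "classes")
    # Anything tensor-like found here can then be normalized the way run()
    # does: value.cpu().numpy() for torch tensors.
    return {name: hasattr(predictions, name) for name in candidates}


# Exercise with a stand-in object; real use would pass the result of
# model.pyramid_predictions(image_tensor).
fake = SimpleNamespace(boxes=[[1, 2, 3, 4]], confs=[0.9])
report = inspect_predictions(fake)
```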

## Running the Test

```bash
cd processing_services/example/
python test_flat_bug_implementation.py
```

This will show you the exact structure of the `TensorPredictions` object and help identify any needed adjustments.

## Benefits of This Approach

1. **No checkpoint URL needed**: Uses the official flat-bug library with built-in model management
2. **Specialized for arthropods**: Flat-bug is specifically trained for terrestrial arthropod detection
3. **High performance**: Uses pyramid tiling for efficient processing of large images
4. **Automatic model download**: No need to manually manage model files
5. **Configurable**: Can adjust detection thresholds and other hyperparameters

The implementation should work as-is, but running the test script will help identify any format discrepancies that need minor adjustments.
147 changes: 145 additions & 2 deletions processing_services/example/api/algorithms.py
@@ -3,6 +3,7 @@
import math
import random

import numpy as np
import torch

from .schemas import (
@@ -28,7 +29,7 @@ def get_best_device() -> str:
MPS is not supported by the current algorithms.
"""
if torch.cuda.is_available():
return f"cuda:{torch.cuda.current_device()}"
return f"cuda:{torch.cuda.current_device()}"  # noqa: E231 (no space: torch expects "cuda:0")
else:
return "cpu"

@@ -178,6 +179,148 @@ def get_algorithm_config_response(self) -> AlgorithmConfigResponse:
)


class FlatBugObjectDetector(Algorithm):
"""
Flat-bug Object Detection model.
Uses the flat-bug library for terrestrial arthropod detection and segmentation.
Produces both a bounding box and a classification for each detection.
"""

candidate_labels: list[str] = ["insect"]

def compile(self, device: str | None = None):
saved_models_key = "flat_bug_object_detector" # generate a key for each uniquely compiled algorithm

if saved_models_key not in SAVED_MODELS:
from flat_bug.predictor import Predictor

device_choice = device or get_best_device()
# device_index = int(device_choice.split(":")[-1]) if ":" in device_choice else -1
logger.info(f"Compiling {self.algorithm_config_response.name} on device {device_choice}...")

# Initialize flat-bug predictor with default model
self.model = Predictor(model="flat_bug_M.pt", device=device_choice) # Default flat-bug model

# Set some reasonable hyperparameters
# TIME=False is critical to avoid CUDA event errors when running on CPU
self.model.set_hyperparameters(
SCORE_THRESHOLD=0.5, IOU_THRESHOLD=0.5, TIME=False # Must be False for CPU compatibility
)

SAVED_MODELS[saved_models_key] = self.model
else:
logger.info(f"Using saved model for {self.algorithm_config_response.name}...")
self.model = SAVED_MODELS[saved_models_key]

def run(self, source_images: list[SourceImage], intermediate=False) -> list[Detection]:
detector_responses: list[Detection] = []
for source_image in source_images:
if source_image.width and source_image.height and source_image._pil:
start_time = datetime.datetime.now()
logger.info("Predicting with flat-bug...")

# Convert PIL image to tensor (flat-bug expects tensor, not PIL Image)
# Convert PIL to numpy then to tensor in CHW format
image_np = np.array(source_image._pil)
image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).float()

# Use flat-bug's pyramid_predictions method with tensor input
predictions = self.model.pyramid_predictions(image_tensor)

end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()

# Extract bounding boxes from flat-bug predictions
# flat-bug returns TensorPredictions with boxes, confs (not scores), and classes
# Based on test results: boxes=int64, confs=float32, classes=float32
if hasattr(predictions, "boxes") and predictions.boxes is not None:
boxes = predictions.boxes.cpu().numpy() # Convert to numpy (int64)
scores = (
predictions.confs.cpu().numpy() # Use 'confs' not 'scores' (float32)
if hasattr(predictions, "confs") and predictions.confs is not None
else None
)

for i, box in enumerate(boxes):
# box format from flat-bug is xyxy: [x1, y1, x2, y2] (verified via test)
x1, y1, x2, y2 = box

bbox = BoundingBox(
x1=float(x1),
x2=float(x2),
y1=float(y1),
y2=float(y2),
)

cropped_image_pil = source_image._pil.crop((bbox.x1, bbox.y1, bbox.x2, bbox.y2))

# Get confidence score if available
confidence_score = float(scores[i]) if scores is not None and i < len(scores) else 0.5

detection = Detection(
id=f"{source_image.id}-crop-{bbox.x1}-{bbox.y1}-{bbox.x2}-{bbox.y2}",
url=source_image.url, # @TODO: ideally, should save cropped image at separate url
width=cropped_image_pil.width,
height=cropped_image_pil.height,
timestamp=datetime.datetime.now(),
source_image=source_image,
bbox=bbox,
inference_time=elapsed_time,
algorithm=AlgorithmReference(
name=self.algorithm_config_response.name,
key=self.algorithm_config_response.key,
),
classifications=[
ClassificationResponse(
classification=self.candidate_labels[
0
], # flat-bug detects arthropods, use first label
labels=self.candidate_labels,
scores=[confidence_score],
logits=[confidence_score],
inference_time=elapsed_time,
timestamp=datetime.datetime.now(),
algorithm=AlgorithmReference(
name=self.algorithm_config_response.name,
key=self.algorithm_config_response.key,
),
terminal=not intermediate,
)
],
)
detection._pil = cropped_image_pil
detector_responses.append(detection)
else:
logger.info("No detections found in image")
else:
raise ValueError(f"Source image {source_image.id} is missing width/height or a loaded image.")

return detector_responses

def get_category_map(self) -> AlgorithmCategoryMapResponse:
return AlgorithmCategoryMapResponse(
data=[{"index": i, "label": label} for i, label in enumerate(self.candidate_labels)],
labels=self.candidate_labels,
version="v1", # TODO confirm version
description="Candidate labels used for flat-bug object detection.",
uri=None,
)

def get_algorithm_config_response(self) -> AlgorithmConfigResponse:
return AlgorithmConfigResponse(
name="Flat Bug Object Detector",
key="flat-bug-object-detector",
task_type="detection",
description=(
"Flat Bug Object Detection model. "
"Produces both a bounding box and a candidate label classification for each detection."
),
version=1,
version_name="v1", # TODO confirm version
category_map=self.get_category_map(),
)


class HFImageClassifier(Algorithm):
"""
A local classifier that uses the Hugging Face pipeline to classify images.
@@ -277,7 +420,7 @@ def get_category_map(self) -> AlgorithmCategoryMapResponse:
labels=labels,
version="ImageNet-1k",
description=description_text,
uri=f"https://huggingface.co/{self.model_name}",
uri=f"https://huggingface.co/{self.model_name}", # noqa: E231
)

def get_algorithm_config_response(self) -> AlgorithmConfigResponse:
2 changes: 2 additions & 0 deletions processing_services/example/api/api.py
@@ -7,6 +7,7 @@
import fastapi

from .pipelines import (
FlatBugDetectorPipeline,
Pipeline,
ZeroShotHFClassifierPipeline,
ZeroShotObjectDetectorPipeline,
@@ -37,6 +38,7 @@


pipelines: list[type[Pipeline]] = [
FlatBugDetectorPipeline,
ZeroShotHFClassifierPipeline,
ZeroShotObjectDetectorPipeline,
ZeroShotObjectDetectorWithConstantClassifierPipeline,
32 changes: 32 additions & 0 deletions processing_services/example/api/pipelines.py
Expand Up @@ -5,6 +5,7 @@
from .algorithms import (
Algorithm,
ConstantClassifier,
FlatBugObjectDetector,
HFImageClassifier,
RandomSpeciesClassifier,
ZeroShotObjectDetector,
@@ -346,3 +347,34 @@ def run(self) -> PipelineResultsResponse:
logger.info(f"Successfully processed {len(detections_with_classifications)} detections.")

return pipeline_response


class FlatBugDetectorPipeline(Pipeline):
"""
A pipeline that uses the flat-bug object detector for arthropod detection.
Produces both a bounding box and a classification for each detection.
"""

batch_sizes = [1]
config = PipelineConfigResponse(
name="Flat Bug Object Detector Pipeline",
slug="flat-bug-object-detector-pipeline",
description="Flat Bug object detector for terrestrial arthropods.",
version=1,
algorithms=[FlatBugObjectDetector().algorithm_config_response],
)

def get_stages(self) -> list[Algorithm]:
flat_bug_detector = FlatBugObjectDetector()
self.config.algorithms = [flat_bug_detector.algorithm_config_response]
return [flat_bug_detector]

def run(self) -> PipelineResultsResponse:
start_time = datetime.datetime.now()
logger.info("[1/1] Running the flat-bug object detector...")
detections = self._get_detections(self.stages[0], self.source_images, self.batch_sizes[0])
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
pipeline_response: PipelineResultsResponse = self._get_pipeline_response(detections, elapsed_time)
logger.info(f"Successfully processed {len(detections)} detections.")
return pipeline_response