import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Any
import aiortc
import av
import cv2
import numpy as np
from vision_agents.core.processors.base_processor import VideoProcessorPublisher
from vision_agents.core.utils.video_forwarder import VideoForwarder
from vision_agents.core.utils.video_track import QueuedVideoTrack
from vision_agents.core.warmup import Warmable
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
class ObjectDetectionProcessor(VideoProcessorPublisher, Warmable[Optional[Any]]):
    """
    Detects objects in video frames and annotates them with bounding boxes.

    Incoming frames are sampled by a :class:`VideoForwarder`, run through a
    YOLO model inside a thread pool (inference is blocking), and the annotated
    frames are published on an outgoing :class:`QueuedVideoTrack`. When the
    model is unavailable or a frame fails, the original frame is forwarded
    unmodified (best-effort pass-through).
    """

    name = "object_detection"

    def __init__(
        self,
        model_path: str = "yolo11n.pt",
        conf_threshold: float = 0.5,
        fps: int = 10,
        max_workers: int = 10,
    ):
        """
        Args:
            model_path: Path or name of the YOLO weights to load at warmup.
            conf_threshold: Minimum confidence for a detection to be kept.
            fps: Target frame rate for sampling and processing.
            max_workers: Thread-pool size for blocking work (load + inference).
        """
        # Cooperative multiple inheritance: initialize base-class state.
        # The original omitted this call entirely.
        # NOTE(review): assumes the base initializers take no required
        # arguments — confirm against VideoProcessorPublisher / Warmable.
        super().__init__()
        self.model_path = model_path
        self.conf_threshold = conf_threshold
        self.fps = fps
        # Output track that downstream consumers subscribe to.
        self._video_track = QueuedVideoTrack()
        self._video_forwarder: Optional[VideoForwarder] = None
        # Model is loaded asynchronously in on_warmup(); None until then.
        self._model: Optional[Any] = None
        # Thread pool for CPU-bound / blocking work.
        self.executor = ThreadPoolExecutor(
            max_workers=max_workers,
            thread_name_prefix="object_detection",
        )
        self._shutdown = False
        logger.info("Object Detection Processor initialized")

    async def on_warmup(self) -> Optional[Any]:
        """Load the YOLO model off the event loop.

        Returns:
            The loaded model, or ``None`` on failure (the processor then
            degrades to pass-through; see ``_process_and_publish``).
        """
        try:
            from ultralytics import YOLO

            # get_running_loop() is the correct call inside a coroutine;
            # get_event_loop() is deprecated here since Python 3.10.
            loop = asyncio.get_running_loop()
            model = await loop.run_in_executor(
                self.executor,
                lambda: YOLO(self.model_path),
            )
            logger.info("Model loaded: %s", self.model_path)
            return model
        except Exception as e:
            # Deliberate best-effort: a missing model must not crash warmup.
            logger.warning("Model load failed: %s", e)
            return None

    def on_warmed_up(self, resource: Optional[Any]) -> None:
        """Receive the model produced by ``on_warmup``."""
        self._model = resource

    async def process_video(
        self,
        track: aiortc.VideoStreamTrack,
        participant_id: Optional[str],
        shared_forwarder: Optional[VideoForwarder] = None,
    ) -> None:
        """Start processing incoming video.

        Attaches ``_process_and_publish`` as a frame handler on the given
        shared forwarder, or on a newly created one wrapping *track*.
        """
        logger.info("Starting video processing at %s FPS", self.fps)
        self._video_forwarder = shared_forwarder or VideoForwarder(
            track,
            max_buffer=self.fps,
            fps=self.fps,
            name="detection_forwarder",
        )
        self._video_forwarder.add_frame_handler(
            self._process_and_publish,
            fps=float(self.fps),
            name="detection",
        )

    async def _process_and_publish(self, frame: av.VideoFrame) -> None:
        """Detect objects in *frame* and publish the annotated result.

        Forwards the original frame unchanged when shutting down, when no
        model is loaded, or when any processing step raises.
        """
        if self._shutdown or not self._model:
            await self._video_track.add_frame(frame)
            return
        try:
            frame_bgr = frame.to_ndarray(format="bgr24")
            # Run blocking inference off the event loop.
            loop = asyncio.get_running_loop()
            detections = await loop.run_in_executor(
                self.executor,
                self._detect_objects,
                frame_bgr,
            )
            annotated = self._draw_boxes(frame_bgr, detections)
            output = av.VideoFrame.from_ndarray(annotated, format="bgr24")
            # from_ndarray() produces a frame with no timestamps; carry the
            # source timing over so annotated frames stay consistent with the
            # pass-through path. NOTE(review): harmless if QueuedVideoTrack
            # re-stamps frames itself — confirm.
            output.pts = frame.pts
            output.time_base = frame.time_base
            await self._video_track.add_frame(output)
        except Exception:
            # Best-effort: never drop the stream because one frame failed.
            logger.exception("Frame processing failed")
            await self._video_track.add_frame(frame)

    def _detect_objects(self, frame_bgr: np.ndarray) -> list:
        """Run YOLO inference (synchronous; executed in the thread pool).

        Returns:
            A list of dicts with keys ``"bbox"`` (x1, y1, x2, y2 as ints),
            ``"conf"`` (float) and ``"label"`` (class name string).
        """
        # Re-check: the model reference may have been cleared since this
        # call was scheduled on the executor.
        if not self._model:
            return []
        results = self._model(
            frame_bgr,
            verbose=False,
            conf=self.conf_threshold,
        )
        detections = []
        if results and results[0].boxes:
            # Hoist loop-invariant lookup out of the per-box loop.
            names = results[0].names
            for box in results[0].boxes:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                detections.append({
                    "bbox": (int(x1), int(y1), int(x2), int(y2)),
                    "conf": float(box.conf[0]),
                    "label": names[int(box.cls[0])],
                })
        return detections

    def _draw_boxes(self, frame: np.ndarray, detections: list) -> np.ndarray:
        """Return a copy of *frame* with detection boxes and labels drawn."""
        result = frame.copy()
        for det in detections:
            x1, y1, x2, y2 = det["bbox"]
            cv2.rectangle(result, (x1, y1), (x2, y2), (0, 255, 0), 2)
            text = f"{det['label']} {det['conf']:.2f}"
            cv2.putText(result, text, (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
        return result

    def publish_video_track(self) -> aiortc.VideoStreamTrack:
        """Return the annotated-output video track to publish."""
        return self._video_track

    async def stop_processing(self) -> None:
        """Stop video processing; safe to call more than once."""
        if self._video_forwarder:
            await self._video_forwarder.remove_frame_handler(
                self._process_and_publish
            )
            # Clear the reference so a second call doesn't try to remove
            # the same handler again (idempotence).
            self._video_forwarder = None
        logger.info("Video processing stopped")

    async def close(self) -> None:
        """Clean up resources: stop processing and release the thread pool."""
        self._shutdown = True
        await self.stop_processing()
        # wait=False: don't block the event loop on in-flight inference.
        self.executor.shutdown(wait=False)