
Overview

The SegmentationModel class provides trash detection and tracking functionality using a pre-trained YOLO model. It handles device management, model loading, and inference operations for real-time trash segmentation.
The model automatically selects the optimal device (GPU/CPU) and loads the pre-trained trash detection model during initialization.
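The device-selection logic lives in `DeviceManager`; as a rough sketch (an assumption, not the library's actual code), it likely amounts to the standard CUDA-availability check:

```python
import torch

# Minimal sketch (assumption) of what DeviceManager.get_device likely does:
# prefer a CUDA GPU when one is available, otherwise fall back to CPU.
def get_device() -> torch.device:
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

device = get_device()
print(device.type)  # "cuda" on a GPU machine, otherwise "cpu"
```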

Interface Definition

SegmentationModelInterface

Abstract base class defining the contract for segmentation models.
from abc import ABC, abstractmethod
import numpy as np
from typing import Tuple, Dict
from torch import device
from ultralytics.engine.results import Results

class SegmentationModelInterface(ABC):
    @abstractmethod
    def inference(self, image: np.ndarray) -> Tuple[Results, Dict[int, str], device]:
        pass

Class Definition

class SegmentationModel(SegmentationModelInterface):
    def __init__(self):
        self.device = DeviceManager.get_device()
        self.trash_segmentation_model = ModelLoader(self.device).get_model()

Constructor

__init__()

Initializes the segmentation model with automatic device selection and model loading.

Attributes Initialized:
  • device (torch.device): PyTorch device object (CPU or CUDA GPU), automatically selected based on availability
  • trash_segmentation_model (YOLO): Loaded YOLO segmentation model optimized for trash detection
Initialization Process:
  1. Detects available hardware (GPU/CPU) using DeviceManager
  2. Loads the pre-trained trash segmentation model using ModelLoader
  3. Prepares model for inference with optimal device configuration
Example:
from trash_classificator.segmentation.main import SegmentationModel

# Initialize the model (automatic device and model setup)
model = SegmentationModel()

print(f"Model loaded on device: {model.device}")
Model initialization may take several seconds on first load as it downloads and caches the model weights.

Methods

inference()

Performs trash detection and tracking on an input image.
def inference(self, image: np.ndarray) -> tuple[list[Results], Dict[int, str], device]

Parameters

  • image (np.ndarray, required): Input image as a NumPy array in BGR format (OpenCV standard). Should be a 3-channel color image with shape (height, width, 3).
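As a quick sanity check before calling inference, you can verify that an array matches this format. The `is_valid_input` helper below is hypothetical, not part of the library:

```python
import numpy as np

# Hypothetical helper (not part of the library): verify an array matches the
# expected input format -- a 3-channel uint8 image, as produced by cv2.imread.
def is_valid_input(image) -> bool:
    return (
        isinstance(image, np.ndarray)
        and image.ndim == 3
        and image.shape[2] == 3
        and image.dtype == np.uint8
    )

frame = np.zeros((480, 640, 3), dtype=np.uint8)  # synthetic black frame
print(is_valid_input(frame))  # True
```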

Returns

  • results (list[Results]): YOLO Results objects containing:
      • masks: Segmentation masks for detected objects
      • boxes: Bounding boxes with tracking IDs and confidence scores
      • cls: Class indices for detected trash types
    Each Results object represents detections from one frame in the tracking stream. Note that because the underlying call uses stream=True, the returned value is a lazy generator of Results objects; iterate over it to consume detections.
  • trash_classes (Dict[int, str]): Dictionary mapping class indices to trash type names, for example 0: "plastic", 1: "paper", 2: "metal". The exact mappings are defined in trash_classificator.segmentation.models.trash_model.
  • device (torch.device): The PyTorch device used for inference (same as self.device)
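A hypothetical sketch of how such a mapping is typically used; the class names here are placeholders, and the authoritative labels live in trash_classificator.segmentation.models.trash_model:

```python
# Hypothetical sketch of the class-index mapping; the names below are
# placeholders -- the real labels are defined in
# trash_classificator.segmentation.models.trash_model.
trash_classes = {0: "plastic", 1: "paper", 2: "metal"}

def label_for(cls_index) -> str:
    # YOLO exposes class indices as floats; coerce before lookup.
    return trash_classes.get(int(cls_index), "unknown")

print(label_for(0.0))  # "plastic"
print(label_for(99))   # "unknown"
```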

Inference Configuration

The method uses the following YOLO tracking parameters:
  • conf (float, default: 0.55): Confidence threshold for detections (0.0 to 1.0). Only detections with confidence ≥ 0.55 are returned.
  • persist (bool, default: True): Enables persistent tracking across frames, maintaining consistent object IDs.
  • imgsz (int, default: 640): Input image size for the model. Images are resized to 640x640 for inference.
  • stream (bool, default: True): Enables streaming mode for efficient video processing with generator-based results.
  • verbose (bool, default: False): Suppresses YOLO logging output for cleaner console output.
Source Code Reference:
segmentation/main.py
def inference(self, image: np.ndarray) -> tuple[list[Results], Dict[int, str], device]:
    results = self.trash_segmentation_model.track(image, conf=0.55, verbose=False, persist=True, imgsz=640,
                                                  stream=True)
    return results, trash_classes, self.device

Usage Examples

Basic Inference

import cv2
import numpy as np
from trash_classificator.segmentation.main import SegmentationModel

# Initialize model
model = SegmentationModel()

# Load image
image = cv2.imread('frame.jpg')

# Run inference
results, trash_classes, device = model.inference(image)

# Process results
for result in results:
    if result.boxes.id is not None:
        print(f"Detected {len(result.boxes)} trash objects")
        print(f"Classes: {trash_classes}")
        print(f"Device: {device}")

Video Stream Processing

import cv2
from trash_classificator.segmentation.main import SegmentationModel

model = SegmentationModel()
cap = cv2.VideoCapture('trash_video.mp4')

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    
    # Run inference with tracking
    results, trash_classes, device = model.inference(frame)
    
    # Iterate through detected objects
    for result in results:
        if result.boxes.id is not None:
            boxes = result.boxes.xyxy.cpu().numpy()
            track_ids = result.boxes.id.int().cpu().tolist()
            classes = result.boxes.cls.cpu().tolist()
            
            # Process each detection
            for box, track_id, cls in zip(boxes, track_ids, classes):
                x1, y1, x2, y2 = box
                label = trash_classes[int(cls)]
                print(f"Track ID {track_id}: {label} at [{x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f}]")

cap.release()

Accessing Detection Data

import cv2
from trash_classificator.segmentation.main import SegmentationModel

model = SegmentationModel()
image = cv2.imread('scene.jpg')

results, trash_classes, device = model.inference(image)

for result in results:
    if result.boxes.id is None:
        print("No trash detected")
        continue
    
    # Access segmentation masks
    masks = result.masks.xy  # List of polygon coordinates
    print(f"Number of masks: {len(masks)}")
    
    # Access bounding boxes
    boxes = result.boxes.xyxy.cpu()  # [x1, y1, x2, y2] format
    print(f"Bounding boxes shape: {boxes.shape}")
    
    # Access tracking IDs
    track_ids = result.boxes.id.int().cpu().tolist()
    print(f"Track IDs: {track_ids}")
    
    # Access class labels
    class_indices = result.boxes.cls.cpu().tolist()
    class_names = [trash_classes[int(cls)] for cls in class_indices]
    print(f"Detected classes: {class_names}")
    
    # Access confidence scores
    confidences = result.boxes.conf.cpu().tolist()
    print(f"Confidence scores: {confidences}")

Technical Details


Performance Considerations

  • Confidence Threshold: The 0.55 threshold balances precision and recall for trash detection
  • Image Size: 640x640 provides a good trade-off between speed and accuracy
  • Streaming Mode: Reduces memory usage for video processing by yielding results incrementally
  • Device Selection: GPU inference is typically 10-50x faster than CPU
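The memory benefit of streaming mode can be illustrated with a plain Python generator, a stand-in for the lazy Results stream that track(..., stream=True) returns:

```python
# Plain-Python illustration of streaming mode: like track(..., stream=True),
# a generator yields one result at a time instead of building a full list,
# so only the frame currently being processed is held in memory.
def streamed_results(n_frames):
    for i in range(n_frames):
        yield f"result-{i}"  # stand-in for a YOLO Results object

gen = streamed_results(1000)
print(next(gen))  # "result-0" -- only one result materialized so far
```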

Related Components

  • TrashClassificator - Main orchestrator that uses this model
  • Drawing - Visualizes the results from this model
  • DeviceManager - Manages hardware device selection (processor.py:8)
  • ModelLoader - Loads YOLO model weights (processor.py:9)
