Overview

The CoCaLoss class extends ClipLoss to support training CoCa models, which combine contrastive learning with autoregressive caption generation. It computes two loss terms, each scaled by its own weight:
  1. Contrastive loss: CLIP-style image-text matching (inherited from ClipLoss)
  2. Caption loss: cross-entropy loss for next-token prediction in caption generation
The two terms are returned separately and are typically summed in the training loop. This dual-objective training lets a model both match images with text descriptions and generate natural-language captions.

Class Definition

from open_clip import CoCaLoss

Initialization Parameters

caption_loss_weight
float
required
Weight applied to the caption generation loss. Controls the relative importance of caption quality vs. contrastive matching.
clip_loss_weight
float
required
Weight applied to the contrastive loss. Set to 0 to train only the caption decoder.
pad_id
int
default:"0"
Padding token ID. Positions with this token are ignored when computing caption loss.
local_loss
bool
default:"False"
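If True, each process computes the contrastive loss between its local features and the features gathered from all processes, instead of between the full gathered sets on both sides. See ClipLoss documentation for details.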
gather_with_grad
bool
default:"False"
If True, gathers features with gradient flow enabled for contrastive loss.
cache_labels
bool
default:"False"
If True, caches the ground-truth labels for the contrastive loss and reuses them across calls instead of rebuilding them for every batch.
rank
int
default:"0"
Current process rank in distributed training.
world_size
int
default:"1"
Total number of processes in distributed training.
use_horovod
bool
default:"False"
If True, uses Horovod for distributed operations instead of torch.distributed.
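
To make the effect of pad_id concrete, the sketch below computes a token-level cross-entropy in plain Python and skips padded positions, mirroring what ignore_index does in torch.nn.CrossEntropyLoss with its default mean reduction. The helper name masked_token_nll and the toy probabilities are illustrative assumptions, not part of open_clip.

```python
import math

def masked_token_nll(target_probs, labels, pad_id=0):
    """Mean negative log-likelihood over non-pad positions.

    target_probs[i] is the probability the model assigned to labels[i].
    This helper is hypothetical and only illustrates the masking.
    """
    kept = [-math.log(p) for p, y in zip(target_probs, labels) if y != pad_id]
    if not kept:
        # Design choice of this sketch: fail loudly instead of returning NaN.
        raise ValueError("all positions are padding")
    return sum(kept) / len(kept)

labels = [17, 42, 0, 0]               # last two positions are padding
target_probs = [0.5, 0.25, 0.9, 0.9]  # values at padded positions are ignored
loss = masked_token_nll(target_probs, labels)
```

Only the first two positions contribute, so the result is the average of -log(0.5) and -log(0.25); the probabilities at the padded positions have no effect.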

Attributes

Inherits all attributes from ClipLoss, plus:
  • clip_loss_weight: Weight for contrastive loss component
  • caption_loss_weight: Weight for caption generation loss component
  • caption_loss: CrossEntropyLoss module with ignore_index=pad_id

Key Methods

forward

def forward(
    self,
    image_features: torch.Tensor,
    text_features: torch.Tensor,
    logits: torch.Tensor,
    labels: torch.Tensor,
    logit_scale: torch.Tensor,
    output_dict: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]:
Computes the weighted contrastive and caption losses.
Parameters:
  • image_features: Normalized contrastive image features of shape (batch_size, embed_dim)
  • text_features: Normalized contrastive text features of shape (batch_size, embed_dim)
  • logits: Caption generation logits of shape (batch_size, seq_len, vocab_size)
  • labels: Target token IDs for caption generation of shape (batch_size, seq_len)
  • logit_scale: Scale (inverse temperature) applied to the contrastive logits, typically model.logit_scale.exp(); the CoCa model output already provides it as output['logit_scale']
  • output_dict: If True, returns dict with named losses, else returns tuple
Returns:
  • If output_dict=False: Tuple of (clip_loss, caption_loss) - both weighted by their respective coefficients
  • If output_dict=True: Dictionary with keys "contrastive_loss" and "caption_loss"
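
The following is a single-process sketch of the two terms that forward returns, written in plain PyTorch rather than copied from open_clip. It omits the cross-process feature gathering that ClipLoss performs when world_size > 1. The function name coca_loss_sketch and the toy shapes are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def coca_loss_sketch(image_features, text_features, logits, labels, logit_scale,
                     caption_loss_weight=2.0, clip_loss_weight=1.0, pad_id=0):
    """Single-process approximation of the two weighted CoCa loss terms."""
    # Contrastive term: symmetric cross-entropy over the similarity matrix,
    # where the i-th image matches the i-th text.
    logits_per_image = logit_scale * image_features @ text_features.T
    logits_per_text = logits_per_image.T
    targets = torch.arange(image_features.shape[0], device=image_features.device)
    clip_loss = (F.cross_entropy(logits_per_image, targets)
                 + F.cross_entropy(logits_per_text, targets)) / 2

    # Caption term: cross_entropy expects the class dimension second,
    # so (batch, seq_len, vocab) is permuted to (batch, vocab, seq_len).
    caption_loss = F.cross_entropy(logits.permute(0, 2, 1), labels, ignore_index=pad_id)

    return clip_loss_weight * clip_loss, caption_loss_weight * caption_loss

torch.manual_seed(0)
batch, seq_len, vocab, dim = 4, 6, 11, 8
img = F.normalize(torch.randn(batch, dim), dim=-1)
txt = F.normalize(torch.randn(batch, dim), dim=-1)
cap_logits = torch.randn(batch, seq_len, vocab)
cap_labels = torch.randint(1, vocab, (batch, seq_len))
cap_labels[:, -2:] = 0  # pad the tail of every caption

clip_loss, caption_loss = coca_loss_sketch(
    img, txt, cap_logits, cap_labels, torch.tensor(10.0)
)
```

Both returned values are scalar tensors. Changing the caption logits at padded positions leaves caption_loss unchanged, because those positions are excluded by ignore_index.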

Usage Example

import torch
from open_clip import create_model, CoCaLoss

# Create CoCa model
model = create_model('coca_ViT-B-32', pretrained=False)

# Create loss function
loss_fn = CoCaLoss(
    caption_loss_weight=2.0,  # Caption loss is weighted 2x
    clip_loss_weight=1.0,     # Standard contrastive loss weight
    pad_id=0
)

# Dummy batch for illustration (random images and token IDs)
images = torch.randn(16, 3, 224, 224)
captions = torch.randint(0, 49408, (16, 77))

# Forward pass through model
output = model(images, captions)

# Extract required tensors
image_features = output['image_features']  # Contrastive features
text_features = output['text_features']    # Contrastive features
logits = output['logits']                  # Caption logits
labels = output['labels']                  # Shifted caption targets
logit_scale = output['logit_scale']

# Compute loss
clip_loss, caption_loss = loss_fn(
    image_features,
    text_features,
    logits,
    labels,
    logit_scale
)

total_loss = clip_loss + caption_loss
total_loss.backward()

Distributed Training Example

import os

import torch
import torch.distributed as dist
from open_clip import create_model, CoCaLoss

# Initialize distributed (assumes a launcher such as torchrun that sets LOCAL_RANK)
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
world_size = dist.get_world_size()

# Use the local rank for device placement; the global rank is only
# a valid GPU index on single-node jobs
local_rank = int(os.environ.get('LOCAL_RANK', 0))
device = torch.device('cuda', local_rank)
torch.cuda.set_device(device)

# Create model and loss
model = create_model('coca_ViT-L-14').to(device)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

loss_fn = CoCaLoss(
    caption_loss_weight=2.0,
    clip_loss_weight=1.0,
    pad_id=0,
    rank=rank,
    world_size=world_size,
    cache_labels=True
)

# dataloader and optimizer are assumed to be defined elsewhere
for images, captions in dataloader:
    images = images.to(device)
    captions = captions.to(device)

    output = model(images, captions)

    clip_loss, caption_loss = loss_fn(
        output['image_features'],
        output['text_features'],
        output['logits'],
        output['labels'],
        output['logit_scale']
    )

    total_loss = clip_loss + caption_loss
    total_loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Dictionary Output for Logging

loss_fn = CoCaLoss(
    caption_loss_weight=2.0,
    clip_loss_weight=1.0
)

output = model(images, captions)

loss_dict = loss_fn(
    output['image_features'],
    output['text_features'],
    output['logits'],
    output['labels'],
    output['logit_scale'],
    output_dict=True
)

print(loss_dict)
# Prints a dict with the keys 'contrastive_loss' and 'caption_loss', each a
# scalar tensor, e.g. values around 4.2 and 3.1 (illustrative, not measured)

# Easy logging to wandb/tensorboard
for key, value in loss_dict.items():
    logger.log({key: value.item()})

Caption-Only Training

# Train with the caption objective only (contrastive term disabled;
# no model parameters are frozen by this setting)
loss_fn = CoCaLoss(
    caption_loss_weight=1.0,
    clip_loss_weight=0.0,  # Disable contrastive loss
    pad_id=0
)

# With clip_loss_weight=0.0 the contrastive term is skipped and
# clip_loss is returned as a constant zero tensor
clip_loss, caption_loss = loss_fn(...)
print(clip_loss)  # zero-valued tensor

Custom Loss Weighting Strategy

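from open_clip import CoCaLoss

class AdaptiveCoCaLoss(CoCaLoss):
    """Dynamically adjust loss weights during training."""

    def __init__(self, initial_caption_weight=2.0, **kwargs):
        # Keyword-only forwarding avoids a clash between positional
        # arguments and caption_loss_weight
        super().__init__(caption_loss_weight=initial_caption_weight, **kwargs)
        self.initial_caption_weight = initial_caption_weight
        self.step = 0

    def forward(self, *args, **kwargs):
        # Increase the caption loss weight by 1.0 for every 10,000 calls,
        # starting from the configured initial weight
        self.caption_loss_weight = self.initial_caption_weight + self.step / 10000
        self.step += 1
        return super().forward(*args, **kwargs)

loss_fn = AdaptiveCoCaLoss(
    clip_loss_weight=1.0,
    pad_id=0
)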
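
The subclass above keeps increasing the caption weight for as long as training runs. Where a bounded schedule is preferred, the weight can be computed by a small pure function and assigned to caption_loss_weight before each call. This is a minimal sketch; the function name caption_weight_at and its default values are hypothetical, not part of open_clip.

```python
def caption_weight_at(step, start=2.0, end=3.0, ramp_steps=10_000):
    """Linearly ramp the caption weight from start to end, then hold it."""
    if ramp_steps <= 0:
        return end
    # Clamp progress to [0, 1] so the weight never leaves [start, end].
    frac = min(max(step / ramp_steps, 0.0), 1.0)
    return start + frac * (end - start)

weights = [caption_weight_at(s) for s in (0, 5_000, 10_000, 20_000)]
```

Keeping the schedule separate from the loss module makes it easy to test on its own and to log alongside the loss values.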

Monitoring Loss Components

from collections import defaultdict

from open_clip import CoCaLoss

loss_fn = CoCaLoss(
    caption_loss_weight=2.0,
    clip_loss_weight=1.0
)

# model, dataloader, optimizer, and num_epochs are assumed to be defined elsewhere
for epoch in range(num_epochs):
    # Reset per epoch so the summary covers the current epoch only
    metrics = defaultdict(list)

    for images, captions in dataloader:
        output = model(images, captions)

        clip_loss, caption_loss = loss_fn(
            output['image_features'],
            output['text_features'],
            output['logits'],
            output['labels'],
            output['logit_scale']
        )
        total_loss = clip_loss + caption_loss

        # Track both components (values are already scaled by their weights)
        metrics['clip_loss'].append(clip_loss.item())
        metrics['caption_loss'].append(caption_loss.item())
        metrics['total_loss'].append(total_loss.item())

        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Epoch summary
    print(f"Epoch {epoch}:")
    print(f"  Avg CLIP loss: {sum(metrics['clip_loss'])/len(metrics['clip_loss']):.4f}")
    print(f"  Avg Caption loss: {sum(metrics['caption_loss'])/len(metrics['caption_loss']):.4f}")

Mathematical Formulation

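The two weighted terms returned by CoCaLoss are typically summed in the training loop, giving the total CoCa loss:
$$\mathcal{L}_{\text{total}} = w_1 \cdot \mathcal{L}_{\text{contrastive}} + w_2 \cdot \mathcal{L}_{\text{caption}}$$
Where $w_1$ is clip_loss_weight, $w_2$ is caption_loss_weight, and:
  1. Contrastive Loss (inherited from ClipLoss):
     $$\mathcal{L}_{\text{contrastive}} = \frac{1}{2}\left[\text{CE}(\tau \cdot I T^\top, y) + \text{CE}(\tau \cdot T I^\top, y)\right]$$
     Here $I$ and $T$ stack the normalized image and text features row-wise, $\tau$ is logit_scale, and $y$ assigns each row the index of its matching pair.
  2. Caption Loss (cross-entropy with teacher forcing):
     $$\mathcal{L}_{\text{caption}} = -\frac{1}{M} \sum_{i=1}^{N} \sum_{t=1}^{S} \mathbb{1}[y_i^t \neq \text{pad}] \cdot \log P(y_i^t \mid y_i^{<t}, x_i), \qquad M = \sum_{i=1}^{N} \sum_{t=1}^{S} \mathbb{1}[y_i^t \neq \text{pad}]$$
     Where:
    • $N$ = batch size
    • $S$ = sequence length
    • $M$ = number of non-padding target tokens in the batch (the default mean reduction of CrossEntropyLoss averages over these, not over $N$)
    • $y_i^t$ = target token at position $t$
    • $x_i$ = image features
    • Padding tokens are ignored via ignore_index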
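
As a quick numeric check of the contrastive formula, the sketch below evaluates it in plain Python on a toy 2x2 similarity matrix. The helper names are illustrative assumptions, and the scale is chosen only so that the result has a closed form.

```python
import math

def cross_entropy_rows(matrix, targets):
    """Mean cross-entropy of each row's softmax against its target index."""
    losses = []
    for row, t in zip(matrix, targets):
        log_z = math.log(sum(math.exp(v) for v in row))
        losses.append(log_z - row[t])
    return sum(losses) / len(losses)

def contrastive_loss(sim, scale):
    """Symmetric image-to-text and text-to-image cross-entropy."""
    per_image = [[scale * v for v in row] for row in sim]
    per_text = [list(col) for col in zip(*per_image)]  # transpose
    targets = list(range(len(sim)))                    # i-th image matches i-th text
    return 0.5 * (cross_entropy_rows(per_image, targets)
                  + cross_entropy_rows(per_text, targets))

# Perfectly aligned pairs: cosine similarity 1 on the diagonal, 0 elsewhere.
sim = [[1.0, 0.0], [0.0, 1.0]]
loss = contrastive_loss(sim, scale=math.log(3))
```

With scale = log 3, each row has softmax probability 3/4 on the correct pair, so the loss equals log(4/3). A larger scale pushes that probability toward 1 and the loss toward 0.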

Hyperparameter Tuning

Recommended loss weight ratios:
Dataset Type            | Caption Weight | CLIP Weight | Notes
Large web data (LAION)  | 1.0-2.0        | 1.0         | Balanced training
Caption-focused (COCO)  | 2.0-3.0        | 1.0         | Prioritize generation quality
Retrieval-focused       | 0.5-1.0        | 1.0         | Prioritize matching
Fine-tuning             | 3.0-5.0        | 0.1-0.5     | Adapt caption style
Guidelines:
  • Start with caption_loss_weight=2.0, clip_loss_weight=1.0
  • If captions are low quality, increase caption_loss_weight
  • If retrieval performance is poor, increase clip_loss_weight
  • Monitor both loss components separately
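
Because both returned values are already multiplied by their weights, comparing runs that use different weight settings is easier after dividing the weights back out. A minimal sketch, assuming the weighted values have been converted to Python floats with .item(); the helper name unweighted_losses is hypothetical, not an open_clip function.

```python
def unweighted_losses(clip_loss, caption_loss, clip_loss_weight, caption_loss_weight):
    """Recover raw loss values from the weighted ones returned by the loss."""
    # A zero weight means that term carries no information, so return None for it.
    raw_clip = clip_loss / clip_loss_weight if clip_loss_weight else None
    raw_caption = caption_loss / caption_loss_weight if caption_loss_weight else None
    return raw_clip, raw_caption

raw_clip, raw_caption = unweighted_losses(
    clip_loss=4.0, caption_loss=6.0, clip_loss_weight=1.0, caption_loss_weight=2.0
)
```

Logging the raw values next to the weighted ones shows whether a change in a curve comes from the model or from the weighting.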
