This API documentation is provisional. WorldStereo code and model weights are not yet publicly released. The API described here represents the expected interface based on the research framework.

Overview

WorldStereo employs two dedicated geometric memory modules that enable multi-view-consistent video generation:
  1. Global Geometric Memory: Provides coarse structural priors through incrementally updated point clouds
  2. Spatial-Stereo Memory: Constrains attention receptive fields with 3D correspondence for fine-grained details
These modules work together to inject 3D geometric understanding into the video generation process, enabling precise camera control and high-quality 3D reconstruction.
The memory modules operate as control branches that integrate with the VDM backbone without requiring joint training, ensuring efficiency and modularity.

Global Geometric Memory

The global geometric memory module maintains a dynamic 3D representation of the scene through incrementally updated point clouds. It provides coarse structural priors that guide the video generation process.

Class: GlobalGeometricMemory

Constructor

GlobalGeometricMemory(
    point_cloud_dim: int = 256,
    feature_dim: int = 768,
    num_neighbors: int = 32,
    update_frequency: str = "per_frame",
    spatial_index: str = "kd_tree",
    max_points: int = 100000
)
Parameters:
  • point_cloud_dim (int, default: 256) - Dimension of point cloud feature embeddings.
  • feature_dim (int, default: 768) - Dimension of output features to inject into the VDM.
  • num_neighbors (int, default: 32) - Number of nearest neighbors to consider for each query point.
  • update_frequency (str, default: "per_frame") - How often to update the point cloud: "per_frame", "per_step", or "manual".
  • spatial_index (str, default: "kd_tree") - Spatial indexing structure for efficient nearest neighbor search: "kd_tree", "ball_tree", or "cuda_knn".
  • max_points (int, default: 100000) - Maximum number of points to maintain in memory (oldest points are removed first).

Methods

initialize
Initialize or reset the memory with a point cloud.
initialize(
    point_cloud: torch.Tensor,
    features: Optional[torch.Tensor] = None,
    colors: Optional[torch.Tensor] = None
) -> None
Parameters:
  • point_cloud (torch.Tensor, required) - Initial point cloud coordinates, shape (N, 3).
  • features (torch.Tensor, optional) - Per-point features, shape (N, D). If not provided, features are computed from coordinates.
  • colors (torch.Tensor, optional) - RGB colors for each point, shape (N, 3).
update
Incrementally update the memory with new observations.
update(
    new_points: torch.Tensor,
    camera_pose: torch.Tensor,
    new_features: Optional[torch.Tensor] = None,
    merge_threshold: float = 0.01
) -> Dict[str, Any]
Parameters:
  • new_points (torch.Tensor, required) - New point cloud observations, shape (M, 3).
  • camera_pose (torch.Tensor, required) - Camera pose for the observations, a (4, 4) transformation matrix.
  • new_features (torch.Tensor, optional) - Features for the new points, shape (M, D).
  • merge_threshold (float, default: 0.01) - Distance threshold for merging nearby points (in scene units).

Returns a dict with:
  • num_added (int) - Number of new points added to memory.
  • num_merged (int) - Number of points merged with existing points.
  • total_points (int) - Total number of points in memory after the update.
query
Query the memory for features at specific 3D locations.
query(
    query_points: torch.Tensor,
    camera_params: CameraParameters,
    aggregation: str = "mean"
) -> torch.Tensor
Parameters:
  • query_points (torch.Tensor, required) - 3D query locations, shape (B, Q, 3), where Q is the number of query points.
  • camera_params (CameraParameters, required) - Current camera parameters for view-dependent feature computation.
  • aggregation (str, default: "mean") - How to aggregate neighbor features: "mean", "max", "weighted", or "attention".

Returns:
  • features (torch.Tensor) - Aggregated features for query points, shape (B, Q, feature_dim).
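The "mean" and "weighted" aggregation modes can be sketched in a few lines of NumPy. This is a conceptual stand-in (unbatched, brute-force neighbor search) rather than the library's implementation:

```python
import numpy as np

def aggregate_neighbors(query, points, feats, k=4, mode="mean"):
    """Aggregate the features of a query point's k nearest neighbors."""
    d = np.linalg.norm(points - query, axis=1)
    idx = np.argsort(d)[:k]  # indices of the k nearest points
    if mode == "mean":
        return feats[idx].mean(axis=0)
    if mode == "weighted":  # inverse-distance weighting
        w = 1.0 / (d[idx] + 1e-8)
        w /= w.sum()
        return (w[:, None] * feats[idx]).sum(axis=0)
    raise ValueError(f"unknown mode: {mode}")

rng = np.random.default_rng(0)
points = rng.normal(size=(100, 3))  # memory point cloud
feats = rng.normal(size=(100, 8))   # per-point features
out = aggregate_neighbors(np.zeros(3), points, feats, k=4, mode="weighted")
print(out.shape)  # (8,)
```

"weighted" favors closer points, which matters when the query sits near a surface boundary; "mean" treats all k neighbors equally.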
inject_to_latent
Inject geometric memory features into the VDM latent representation.
inject_to_latent(
    latent: torch.Tensor,
    camera_params: CameraParameters,
    injection_method: str = "cross_attention"
) -> torch.Tensor
Parameters:
  • latent (torch.Tensor, required) - VDM latent representation, shape (B, T, C, H, W).
  • camera_params (CameraParameters, required) - Camera parameters for the current viewpoint.
  • injection_method (str, default: "cross_attention") - Method for feature injection: "cross_attention", "addition", or "concatenation".

Returns:
  • latent_with_memory (torch.Tensor) - Latent representation with injected geometric features, shape (B, T, C, H, W).
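The "cross_attention" injection method can be sketched as latent tokens attending over memory features with a residual connection. The sketch below omits learned projections and batching, and uses NumPy so it is self-contained; it shows the mechanism, not the released code:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def inject_cross_attention(latent_tokens, memory_feats):
    """Inject memory features into latent tokens via single-head
    cross-attention plus a residual connection (projections omitted)."""
    scale = 1.0 / np.sqrt(latent_tokens.shape[-1])
    attn = softmax(latent_tokens @ memory_feats.T * scale, axis=-1)
    return latent_tokens + attn @ memory_feats  # residual injection

rng = np.random.default_rng(0)
latent = rng.normal(size=(16, 32))   # 16 flattened latent tokens, dim 32
memory = rng.normal(size=(50, 32))   # 50 geometric memory features
out = inject_cross_attention(latent, memory)
print(out.shape)  # (16, 32)
```

Because the injection is residual, the latent shape is preserved, which is what lets the memory act as a control branch without retraining the VDM backbone.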

Properties

@property
point_cloud() -> torch.Tensor
    # Returns current point cloud, shape (N, 3)

@property
point_features() -> torch.Tensor
    # Returns point features, shape (N, feature_dim)

@property
num_points() -> int
    # Returns current number of points in memory

Example Usage

import torch
from worldstereo import GlobalGeometricMemory, CameraParameters

# Create memory module
global_memory = GlobalGeometricMemory(
    point_cloud_dim=256,
    feature_dim=768,
    num_neighbors=32
)

# Initialize with point cloud
initial_points = torch.randn(5000, 3).cuda()
global_memory.initialize(initial_points)

# Query features
query_points = torch.randn(1, 100, 3).cuda()
camera_params = CameraParameters(...)
features = global_memory.query(
    query_points=query_points,
    camera_params=camera_params,
    aggregation="weighted"
)

Spatial-Stereo Memory

The spatial-stereo memory module maintains a memory bank of fine-grained visual features with 3D correspondence information. It constrains the model’s attention receptive fields to focus on geometrically consistent regions.

Class: SpatialStereoMemory

Constructor

SpatialStereoMemory(
    memory_size: int = 1024,
    feature_dim: int = 768,
    num_correspondences: int = 64,
    attention_window: int = 7,
    correspondence_threshold: float = 0.5,
    update_strategy: str = "fifo"
)
Parameters:
  • memory_size (int, default: 1024) - Maximum number of feature vectors in the memory bank.
  • feature_dim (int, default: 768) - Dimension of feature vectors.
  • num_correspondences (int, default: 64) - Number of 3D correspondences to maintain per memory entry.
  • attention_window (int, default: 7) - Spatial window size for constrained attention (in pixels).
  • correspondence_threshold (float, default: 0.5) - Confidence threshold for accepting 3D correspondences.
  • update_strategy (str, default: "fifo") - Memory replacement strategy: "fifo", "lru", or "importance_based".
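The "fifo" and "lru" replacement strategies follow standard cache semantics. A minimal sketch (class name and structure are illustrative, not the library's API):

```python
from collections import OrderedDict

class MemoryBank:
    """Bounded feature store with FIFO or LRU replacement."""

    def __init__(self, memory_size, strategy="fifo"):
        self.memory_size = memory_size
        self.strategy = strategy
        self.entries = OrderedDict()  # key -> feature, ordered by recency
        self.counter = 0

    def add(self, feature):
        if len(self.entries) >= self.memory_size:
            self.entries.popitem(last=False)  # evict the least recent entry
        key = self.counter
        self.entries[key] = feature
        self.counter += 1
        return key

    def get(self, key):
        feature = self.entries[key]
        if self.strategy == "lru":
            self.entries.move_to_end(key)  # refresh recency on access
        return feature

bank = MemoryBank(memory_size=2, strategy="lru")
a, b = bank.add("feat_a"), bank.add("feat_b")
bank.get(a)          # touch a, so b is now least recently used
bank.add("feat_c")   # evicts b under LRU
print(list(bank.entries.values()))  # ['feat_a', 'feat_c']
```

Under "fifo" the same sequence would evict feat_a instead, since insertion order alone decides eviction.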

Methods

add_to_memory
Add new features and correspondences to the memory bank.
add_to_memory(
    features: torch.Tensor,
    correspondences: torch.Tensor,
    confidence: torch.Tensor,
    metadata: Optional[Dict[str, Any]] = None
) -> List[int]
Parameters:
  • features (torch.Tensor, required) - Feature vectors to add, shape (N, feature_dim).
  • correspondences (torch.Tensor, required) - 3D correspondence locations, shape (N, num_correspondences, 3).
  • confidence (torch.Tensor, required) - Confidence scores for correspondences, shape (N, num_correspondences).
  • metadata (Dict[str, Any], optional) - Metadata such as camera pose or frame index.

Returns:
  • indices (List[int]) - Memory indices where the features were stored.
constrained_attention
Compute attention with spatial-stereo constraints.
constrained_attention(
    query: torch.Tensor,
    query_points: torch.Tensor,
    camera_params: CameraParameters,
    use_correspondence_mask: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]
Parameters:
  • query (torch.Tensor, required) - Query features, shape (B, N, feature_dim).
  • query_points (torch.Tensor, required) - 3D locations of query points, shape (B, N, 3).
  • camera_params (CameraParameters, required) - Current camera parameters.
  • use_correspondence_mask (bool, default: True) - Whether to apply correspondence-based attention masking.

Returns a tuple of:
  • output (torch.Tensor) - Attention output, shape (B, N, feature_dim).
  • attention_weights (torch.Tensor) - Attention weights for visualization, shape (B, N, memory_size).
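The core idea of constrained attention is that the correspondence mask zeros out attention to memory entries that are not geometrically consistent with each query. A self-contained NumPy sketch (unbatched, with an arbitrary mask standing in for the correspondence check):

```python
import numpy as np

def masked_attention(query, memory, mask):
    """Attention over a memory bank where mask[i, j] = True means query i
    may attend to memory entry j (e.g. a 3D-correspondence constraint)."""
    scale = 1.0 / np.sqrt(query.shape[-1])
    scores = query @ memory.T * scale               # (N, M) raw scores
    scores = np.where(mask, scores, -np.inf)        # block masked entries
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over allowed set
    return weights @ memory, weights

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 16))      # 4 query features
mem = rng.normal(size=(10, 16))   # 10 memory entries
mask = np.zeros((4, 10), dtype=bool)
mask[:, :3] = True  # each query may only see its first 3 candidates
out, w = masked_attention(q, mem, mask)
print(np.allclose(w[:, 3:], 0))  # True: masked entries get zero weight
```

Each row of the mask must allow at least one entry; otherwise the softmax is undefined, which is why the documented correspondence_threshold gates correspondences before masking rather than after.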
retrieve
Retrieve relevant memory entries based on 3D proximity.
retrieve(
    query_points: torch.Tensor,
    camera_params: CameraParameters,
    top_k: int = 32
) -> Tuple[torch.Tensor, torch.Tensor]
Parameters:
  • query_points (torch.Tensor, required) - 3D query locations, shape (B, N, 3).
  • camera_params (CameraParameters, required) - Camera parameters for view-dependent retrieval.
  • top_k (int, default: 32) - Number of nearest memory entries to retrieve.

Returns a tuple of:
  • features (torch.Tensor) - Retrieved feature vectors, shape (B, N, top_k, feature_dim).
  • indices (torch.Tensor) - Indices of retrieved memory entries, shape (B, N, top_k).
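Proximity-based retrieval reduces to a top-k nearest-neighbor lookup over the stored 3D correspondences. An unbatched NumPy sketch of that lookup (the real module would use a spatial index rather than dense distances):

```python
import numpy as np

def retrieve_top_k(query_points, memory_points, memory_feats, top_k=3):
    """Retrieve the top_k memory entries nearest to each 3D query point."""
    # Dense pairwise distances between queries and memory entries: (N, M)
    d = np.linalg.norm(query_points[:, None, :] - memory_points[None, :, :],
                       axis=-1)
    idx = np.argsort(d, axis=-1)[:, :top_k]  # (N, top_k) nearest indices
    feats = memory_feats[idx]                # gather features: (N, top_k, D)
    return feats, idx

rng = np.random.default_rng(0)
mem_pts = rng.normal(size=(20, 3))    # 3D locations of memory entries
mem_feats = rng.normal(size=(20, 8))  # stored feature vectors
queries = mem_pts[:2] + 1e-4          # queries right next to entries 0 and 1
feats, idx = retrieve_top_k(queries, mem_pts, mem_feats, top_k=3)
print(idx[0, 0], idx[1, 0])  # 0 1
```

The shapes match the documented return values minus the batch dimension: features is (N, top_k, feature_dim) and indices is (N, top_k).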
clear
Clear the memory bank.
clear() -> None

Properties

@property
memory_bank() -> torch.Tensor
    # Returns current memory features, shape (memory_size, feature_dim)

@property
correspondence_map() -> torch.Tensor
    # Returns 3D correspondences, shape (memory_size, num_correspondences, 3)

@property
occupancy() -> float
    # Returns fraction of memory bank currently in use (0.0 to 1.0)

Example Usage

from worldstereo import SpatialStereoMemory

# Create memory module
spatial_memory = SpatialStereoMemory(
    memory_size=1024,
    feature_dim=768,
    num_correspondences=64,
    attention_window=7
)

# Extract features from the initial frame. extract_features,
# compute_correspondences, and compute_confidence are placeholders
# for your own feature-extraction and correspondence pipeline.
features = model.extract_features(image)  # (N, 768)
correspondences = compute_correspondences(image)  # (N, 64, 3)
confidence = compute_confidence(correspondences)  # (N, 64)

# Add to memory
indices = spatial_memory.add_to_memory(
    features=features,
    correspondences=correspondences,
    confidence=confidence
)

Memory Integration

The two memory modules work together during video generation:
  1. Global Geometric Memory: Provides coarse 3D structure from the point cloud, guiding overall scene geometry and camera control.
  2. Spatial-Stereo Memory: Refines generation with fine-grained details by constraining attention to geometrically consistent regions.
  3. Incremental Updates: Both memories are updated as new frames are generated, improving consistency across multi-view sequences.
The memory modules operate independently but can be queried in parallel for optimal performance. The global memory focuses on structure, while spatial memory handles texture and detail.
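Since the real classes are not yet available, the generation loop can be simulated with minimal stand-in memories. The stub classes below only track what is stored; they exist to show the per-frame update order, not the released API:

```python
import numpy as np

class GlobalMemoryStub:
    """Stand-in for the global geometric memory: accumulates 3D points."""
    def __init__(self):
        self.points = np.empty((0, 3))
    def update(self, new_points):
        self.points = np.vstack([self.points, new_points])

class SpatialMemoryStub:
    """Stand-in for the spatial-stereo memory: accumulates features."""
    def __init__(self):
        self.bank = []
    def add(self, feats):
        self.bank.extend(feats)

global_mem, spatial_mem = GlobalMemoryStub(), SpatialMemoryStub()
rng = np.random.default_rng(0)

for frame in range(3):
    # 1. Generate the next frame conditioned on both memories
    #    (stubbed here as random features and random 3D observations).
    frame_feats = rng.normal(size=(10, 8))
    frame_points = rng.normal(size=(10, 3))
    # 2. Coarse structure goes into the global geometric memory...
    global_mem.update(frame_points)
    # 3. ...and fine-grained features into the spatial-stereo memory.
    spatial_mem.add(frame_feats)

print(global_mem.points.shape, len(spatial_mem.bank))  # (30, 3) 30
```

The key point is the order: each generated frame feeds both memories before the next frame is produced, so later frames are conditioned on everything generated so far.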

Performance Considerations

  • Global memory uses spatial indexing (KD-tree) for efficient nearest neighbor search
  • Spatial memory implements FIFO or LRU caching to maintain bounded size
  • Point cloud merging reduces redundancy and memory footprint
  • GPU-accelerated operations for real-time performance
  • Global memory can handle 100K+ points efficiently
  • Spatial memory bank size is configurable based on available GPU memory
  • Batch processing for multiple queries reduces overhead
  • Incremental updates avoid full recomputation
  • Increase num_neighbors for better geometric conditioning (slower)
  • Increase memory_size for more detailed spatial memory (more GPU memory)
  • Larger attention_window improves detail but reduces speed
  • Increasing num_correspondences improves accuracy at the cost of additional computation

Related Pages

  • WorldStereo Model: Main model class that integrates the memory modules.
  • Inference API: High-level generation interface.
  • Global Geometric Memory: Conceptual overview.
  • Spatial-Stereo Memory: Conceptual overview.
