• Tokenizers come in two types: continuous and discrete
    • Continuous tokenizers encode visual data into continuous latent embeddings, as in latent diffusion models
      • These embeddings are suitable for models that generate data by sampling from continuous distributions.
    • Discrete tokenizers encode visual data into discrete latent codes, mapping them into quantized indices, as seen in autoregressive transformers such as VideoPoet

Cosmos tokenizer

Architecture

  • Cosmos Tokenizer is designed as an encoder-decoder architecture.
  • Given an input video $x_{0:T} \in \mathbb{R}^{(1+T) \times H \times W \times 3}$, with $H$, $W$, $T$ being the height, width, and number of frames, the encoder ($\mathcal{E}$) tokenizes the inputs into a token video $z_{0:T'} \in \mathbb{R}^{(1+T') \times H' \times W' \times C}$, with a spatial compression factor of $s_{HW} = H/H' = W/W'$ and a temporal compression factor of $s_T = T/T'$ (a quick shape sketch follows this list).
  • The decoder ($\mathcal{D}$) then reconstructs the input video from these tokens, resulting in the reconstructed video $\hat{x}_{0:T}$, mathematically given by:

$$\hat{x}_{0:T} = \mathcal{D}(\mathcal{E}(x_{0:T}))$$
  • The architecture employs a temporally causal design, ensuring that each stage processes only current and past frames, independent of future frames.
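
As a quick sanity check of these shapes, here is a minimal sketch; the concrete numbers (33 input frames, $s_T = 8$, $s_{HW} = 16$) are example values, not the only supported configuration:

```python
# Token-grid shapes for a temporally causal video tokenizer.
T_total, H, W = 33, 512, 512   # 1 + 32 frames; the first frame stands alone
s_T, s_HW = 8, 16              # temporal / spatial compression factors

tokens_T = 1 + (T_total - 1) // s_T   # causal design: frame 0 gets its own token slot
tokens_H, tokens_W = H // s_HW, W // s_HW
print(tokens_T, tokens_H, tokens_W)   # -> 5 32 32
```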

3D Haar wavelet transform

  • Unlike common approaches, the tokenizer operates in the wavelet space, where inputs are first processed by a 2-level wavelet transform (3D Haar Wavelet Transform).
    • Specifically, the wavelet transform maps the input video $x_{0:T}$ in a group-wise manner to downsample the inputs by a factor of four along the temporal and both spatial dimensions. The groups are formed as $\{x_0\}, \{x_{1:4}\}, \{x_{5:8}\}, \ldots, \{x_{(T-3):T}\}$.
    • Operating in the wavelet space gives the tokenizer a more compact video representation that eliminates pixel-level redundancy, allowing the remaining layers to focus on more semantic compression.
  • Subsequent encoder stages process these groups in the same temporally causal manner, so the output for each group depends only on that group and earlier ones (a sketch of one Haar level follows this list).
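
For concreteness, here is a sketch of a single 3D Haar level in PyTorch. This is not the Cosmos implementation (which additionally arranges the transform causally around the first frame); it only shows the underlying transform that, applied for two levels, yields the 4x downsampling per axis:

```python
import torch
import torch.nn.functional as F

def haar3d_level(x: torch.Tensor) -> torch.Tensor:
    """One level of a 3D Haar transform: halves T, H, W and stacks the
    8 subbands along the channel dimension. x: (B, C, T, H, W), even dims."""
    C = x.shape[1]
    lo = torch.tensor([1.0, 1.0]) / 2 ** 0.5   # low-pass Haar filter
    hi = torch.tensor([1.0, -1.0]) / 2 ** 0.5  # high-pass Haar filter
    # 8 separable filters: every low/high combination over (t, h, w).
    filters = torch.stack([
        torch.einsum('t,h,w->thw', ft, fh, fw)
        for ft in (lo, hi) for fh in (lo, hi) for fw in (lo, hi)
    ])                                          # (8, 2, 2, 2)
    weight = filters.unsqueeze(1).repeat(C, 1, 1, 1, 1)  # (8*C, 1, 2, 2, 2)
    return F.conv3d(x, weight.to(x), stride=2, groups=C)

# In a standard multi-level transform, the second level is applied to the
# low-pass (LLL) subband of the first, giving 4x downsampling along t, h, w.
```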

Block design

  • They make extensive use of a custom CausalConv3d, which is a small wrapper around nn.Conv3d that turns an ordinary 3D conv into a strictly time-causal conv (no look-ahead in time), while still letting you use stride / dilation / 3D kernels.
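
A minimal sketch of such a wrapper, assuming left-only zero padding in time and ordinary symmetric padding in space (the real module may differ in details, e.g. padding mode):

```python
import torch.nn as nn
import torch.nn.functional as F

class CausalConv3d(nn.Module):
    """Time-causal 3D conv: pads (k_t - 1) * dilation frames on the left in
    time so no output ever looks at future frames."""
    def __init__(self, in_ch, out_ch, kernel_size=(3, 3, 3),
                 stride=(1, 1, 1), dilation=(1, 1, 1)):
        super().__init__()
        kt, kh, kw = kernel_size
        self.time_pad = (kt - 1) * dilation[0]                  # left-only, in time
        self.space_pad = (kw // 2, kw // 2, kh // 2, kh // 2)   # symmetric, in H/W
        self.conv = nn.Conv3d(in_ch, out_ch, kernel_size,
                              stride=stride, dilation=dilation)

    def forward(self, x):  # x: (B, C, T, H, W)
        # F.pad order for 5D input: (W_left, W_right, H_top, H_bottom, T_left, T_right)
        x = F.pad(x, self.space_pad + (self.time_pad, 0))
        return self.conv(x)
```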

  • The encoder stages (post wavelet transform) are implemented using a series of residual blocks interleaved with downsampling blocks. In each block, they employ

    • a spatio-temporal factorized 3D convolution: first a 2D convolution with a kernel size of $1 \times 3 \times 3$ to capture spatial information, followed by a temporal convolution with a kernel size of $3 \times 1 \times 1$ to capture temporal dynamics. This is the residual block (a sketch appears after this list).
      • They use a left padding of $k - 1$ frames, where $k$ is the temporal kernel size, to ensure causality.
    • To capture long-range dependencies, they utilize a spatio-temporal factorized causal self-attention with a global support region; at the last encoder block, for instance, the attention spans the entire downsampled feature map.
    • Note that the QKV are obtained using causal convolutions too.
    • The downsampling blocks downsample by 2× in time and 2× in space with a causal 3D conv.
      • With stride=2 and time_stride=2, the CausalConv3d uses stride=(2, 2, 2).
    • The decoder mirrors the encoder, replacing the downsampling blocks with an upsampling block.
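
Here is a sketch of the factorized causal residual block described above, reusing the CausalConv3d sketch from earlier. The normalization and activation choices (GroupNorm + SiLU) are assumptions, not confirmed details of the Cosmos code:

```python
import torch.nn as nn
import torch.nn.functional as F

class FactorizedCausalResBlock(nn.Module):
    """Residual block with spatio-temporal factorized causal convolutions:
    a 1x3x3 spatial conv followed by a 3x1x1 temporal conv."""
    def __init__(self, ch: int):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, ch)   # assumes ch divisible by 32
        self.spatial = CausalConv3d(ch, ch, kernel_size=(1, 3, 3))
        self.norm2 = nn.GroupNorm(32, ch)
        self.temporal = CausalConv3d(ch, ch, kernel_size=(3, 1, 1))

    def forward(self, x):                    # x: (B, C, T, H, W)
        h = self.spatial(F.silu(self.norm1(x)))
        h = self.temporal(F.silu(self.norm2(h)))
        return x + h                         # residual connection

# A downsampling block would be a single CausalConv3d with stride=(2, 2, 2).
```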

Discrete and continuous tokens

  • They employ the vanilla autoencoder (AE) formulation to model the continuous tokenizer’s latent space.
    • The latent dimension for the continuous tokenizers is 16
  • For discrete tokenizers, they adopt Finite Scalar Quantization (FSQ, from "Finite Scalar Quantization: VQ-VAE Made Simple") as the latent-space quantizer (sketched below).
    • For the discrete tokenizers, the latent dimension is 6, which is the number of FSQ dimensions; the per-dimension levels are (8, 8, 8, 5, 5, 5).
    • This configuration corresponds to a vocabulary size of $8 \cdot 8 \cdot 8 \cdot 5 \cdot 5 \cdot 5 = 64{,}000$.
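
A minimal FSQ sketch with the levels above, following the FSQ paper's recipe (bound, round, straight-through) rather than the exact Cosmos code:

```python
import math
import torch

def fsq_quantize(z: torch.Tensor, levels=(8, 8, 8, 5, 5, 5)) -> torch.Tensor:
    """Bound each latent dim, round it onto a grid of levels[i] values, and
    pass gradients straight through. z: (..., len(levels))."""
    L = torch.tensor(levels, dtype=z.dtype, device=z.device)
    half = (L - 1) / 2
    offset = (L % 2 == 0).to(z.dtype) * 0.5   # even level counts use a shifted grid
    bounded = torch.tanh(z) * half - offset   # exactly levels[i] reachable integers per dim
    quantized = torch.round(bounded)
    # Straight-through estimator: quantized values forward, identity gradient backward.
    return bounded + (quantized - bounded).detach()

print(math.prod((8, 8, 8, 5, 5, 5)))  # vocabulary size: 64000
```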

Compression rate

  • They train the image tokenizers (denoted as CI and DI) at two compression rates: 8 × 8 and 16 × 16.

  • They train the video tokenizers (denoted as CV and DV) at three compression rates: 4 × 8 × 8, 8 × 8 × 8, and 8 × 16 × 16.

  • Here, the compression rates are expressed as $H \times W$ for images and $T \times H \times W$ for videos

  • In the code, compression is controlled by the variable `channels_mult: list[int]`

    • where num_downsample_stages = len(channels_mult)
    • and channels_mult[i] defines the multiplier applied to the original channel dimension (normally 4, because of the 2-level Haar transform) at the i'th stage, i.e. out_channels = channels * channels_mult[i] (toy example below)
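
A toy illustration of this wiring (example values; the assumption that every stage halves H and W may not hold in the actual repo):

```python
channels = 128                       # base channel width (example)
channels_mult = [2, 4]               # one multiplier per downsampling stage (example)
num_downsample_stages = len(channels_mult)

stage_out_channels = [channels * m for m in channels_mult]
print(stage_out_channels)            # -> [256, 512]

# Total spatial compression: 4x from the 2-level Haar transform, then 2x per stage.
spatial_compression = 4 * 2 ** num_downsample_stages
print(spatial_compression)           # -> 16, i.e. a 16 x 16 spatial rate
```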

Wan VAE

  • 127M parameters
  • Very similar design to Cosmos: they use 3D causal convolutions and self-attention