- Tokenizers come in two types: continuous and discrete
- Continuous tokenizers encode visual data into continuous latent embeddings, as in latent diffusion models
- These embeddings are suitable for models that generate data by sampling from continuous distributions.
- Discrete tokenizers encode visual data into discrete latent codes, mapping them into quantized indices, as seen in autoregressive transformers such as VideoPoet
Cosmos Tokenizer
Architecture

- Cosmos Tokenizer is designed as an encoder-decoder architecture.
- Given an input video $x_{0:T} \in \mathbb{R}^{(1+T) \times H \times W \times 3}$, with $H$, $W$, $T$ being the height, width, and number of frames, the encoder ($\mathcal{E}$) tokenizes the inputs into a token video $z_{0:T'} \in \mathbb{R}^{(1+T') \times H' \times W' \times C}$, with a spatial compression factor of $s_{HW} = H/H' = W/W'$ and a temporal compression factor of $s_T = T/T'$.
- The decoder ($\mathcal{D}$) then reconstructs the input video from these tokens, resulting in the reconstructed video $\hat{x}_{0:T}$, mathematically given by: $\hat{x}_{0:T} = \mathcal{D}(\mathcal{E}(x_{0:T}))$.
- The architecture employs a temporally causal design, ensuring that each stage processes only current and past frames, independent of future frames.
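As a worked shape example (numbers chosen for illustration): with $s_T = 8$ and $s_{HW} = 8$, an input clip of shape $(1+32) \times 256 \times 256$ maps to a token video of shape $(1 + 32/8) \times (256/8) \times (256/8) = 5 \times 32 \times 32$. The first frame is kept separate, which is what lets the same causal tokenizer also handle single images ($T = 0$).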
3D Haar wavelet transform
- Unlike common approaches, the tokenizer operates in the wavelet space: inputs are first processed by a 2-level 3D Haar wavelet transform.
- Specifically, the wavelet transform maps the input video $x_{0:T}$ in a group-wise manner to downsample the inputs by a factor of four along $x$, $y$, and $t$. The groups are formed as $\{x_0, x_{1:4}, x_{5:8}, \ldots, x_{(T-3):T}\}$.
- The wavelet transform lets the encoder operate on a more compact video representation that eliminates redundancies in pixel information, allowing the remaining layers to focus on more semantic compression.
- Subsequent encoder stages process the frames in a temporally causal manner as $\{g_0, g_{1:2}, g_{3:4}, \ldots\}$, where $g_i$ denote the wavelet-transformed frames.
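To make the shape bookkeeping concrete, here is a minimal single-level 3D Haar transform written with reshapes (an illustrative sketch of the plain, non-causal transform; the actual tokenizer uses the causal, group-wise version described above). Applying one level twice gives the 2-level transform, i.e. 4× downsampling per axis:

```python
import torch

def haar3d_level(x: torch.Tensor) -> torch.Tensor:
    """One level of a 3D Haar transform via reshaping.

    x: (B, C, T, H, W) with T, H, W even. Returns (B, 8*C, T//2, H//2, W//2):
    each output channel group is one of the 8 low/high-pass combinations
    over 2x2x2 blocks (along t, h, w).
    """
    B, C, T, H, W = x.shape
    x = x.reshape(B, C, T // 2, 2, H // 2, 2, W // 2, 2)
    s = 2 ** -1.5  # (1/sqrt(2)) per axis: orthonormal Haar scaling
    out = []
    for st in (1, -1):          # low (+) / high (-) pass along t
        for sh in (1, -1):      # ... along h
            for sw in (1, -1):  # ... along w
                y = x[..., 0] + sw * x[..., 1]                      # w pairs
                y = y[:, :, :, :, :, 0] + sh * y[:, :, :, :, :, 1]  # h pairs
                y = y[:, :, :, 0] + st * y[:, :, :, 1]              # t pairs
                out.append(s * y)
    return torch.cat(out, dim=1)

# Two levels: channels x64, each of T, H, W divided by 4, e.g.
# haar3d_level(haar3d_level(torch.randn(1, 3, 8, 64, 64))).shape
# -> torch.Size([1, 192, 2, 16, 16])
```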
Block design
- They make extensive use of a custom `CausalConv3d`, a small wrapper around `nn.Conv3d` that turns an ordinary 3D conv into a strictly time-causal conv (no look-ahead in time), while still letting you use stride, dilation, and 3D kernels (see the sketch after this list).
- The encoder stages (post wavelet transform) are implemented using a series of residual blocks interleaved with downsampling blocks. In each block, they employ
- a spatio-temporal factorized 3D convolution, where you first apply a 2D convolution with a kernel size of 1 × 3 × 3 to capture spatial information, followed by a temporal convolution with a kernel size of 3 × 1 × 1 to capture temporal dynamics. This is the residual block (sketched after this list).
- They use left padding of k − 1 (with k the temporal kernel size) to ensure causality.
- To capture long-range dependencies, they utilize a spatio-temporal factorized causal self-attention with a global support region (used, for instance, in the last encoder block).
- Note that the Q, K, and V projections are obtained using causal convolutions too (see the attention sketch after this list).
- The downsampling blocks downsample by 2× in time and 2× in space with a causal 3D conv: `stride=2, time_stride=2` → `CausalConv3d` uses `stride=(2, 2, 2)`.
- The decoder mirrors the encoder, replacing the downsampling blocks with upsampling blocks (see the sketch after this list).
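Putting the pieces together, a minimal PyTorch sketch of `CausalConv3d` and the factorized residual block (names and details are assumptions reconstructed from the description above, not the actual Cosmos code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv3d(nn.Module):
    """Time-causal wrapper around nn.Conv3d: pad (k_t - 1) frames on the
    left in time (so no look-ahead), symmetric padding in space."""
    def __init__(self, in_ch, out_ch, kernel_size, stride=(1, 1, 1)):
        super().__init__()
        kt, kh, kw = kernel_size
        # F.pad order for 5D input: (W_left, W_right, H_top, H_bottom, T_left, T_right)
        self.pad = (kw // 2, kw // 2, kh // 2, kh // 2, kt - 1, 0)
        self.conv = nn.Conv3d(in_ch, out_ch, kernel_size, stride=stride)

    def forward(self, x):  # x: (B, C, T, H, W)
        return self.conv(F.pad(x, self.pad))

class FactorizedResBlock(nn.Module):
    """Residual block: spatial 1x3x3 conv, then temporal 3x1x1 causal conv."""
    def __init__(self, ch):
        super().__init__()
        self.spatial = CausalConv3d(ch, ch, (1, 3, 3))
        self.temporal = CausalConv3d(ch, ch, (3, 1, 1))

    def forward(self, x):
        h = self.temporal(F.silu(self.spatial(x)))
        return x + h
```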
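And a sketch of the spatio-temporal factorized causal self-attention. For brevity this uses standard linear Q/K/V projections via `nn.MultiheadAttention`, whereas the real block computes Q, K, V with causal convolutions; the structure (per-frame spatial attention plus causally masked temporal attention) is the point:

```python
class FactorizedCausalAttention(nn.Module):
    """Spatial attention mixes tokens within each frame; temporal attention
    mixes across time with a causal mask (no frame attends to the future)."""
    def __init__(self, dim: int, heads: int = 8):
        super().__init__()
        self.spatial = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.temporal = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, T, H, W = x.shape
        # spatial attention: sequences are the H*W positions of one frame
        s = x.permute(0, 2, 3, 4, 1).reshape(B * T, H * W, C)
        s = s + self.spatial(s, s, s, need_weights=False)[0]
        # temporal attention: sequences are the T steps of one spatial position
        t = s.reshape(B, T, H * W, C).permute(0, 2, 1, 3).reshape(B * H * W, T, C)
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), 1)
        t = t + self.temporal(t, t, t, attn_mask=mask, need_weights=False)[0]
        return t.reshape(B, H * W, T, C).permute(0, 3, 2, 1).reshape(B, C, T, H, W)
```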
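Finally, the resampling blocks, reusing `CausalConv3d` from the sketch above. The decoder's upsampling here is a common nearest-neighbor-plus-conv pattern, assumed for illustration rather than taken from the Cosmos code:

```python
# Encoder downsampling block: 2x in time and 2x in space via a strided causal conv
down = CausalConv3d(256, 256, (3, 3, 3), stride=(2, 2, 2))

class CausalUpsample(nn.Module):
    """Decoder upsampling (assumed pattern): 2x nearest-neighbor upsample
    in time and space, followed by a causal 3D conv."""
    def __init__(self, ch):
        super().__init__()
        self.conv = CausalConv3d(ch, ch, (3, 3, 3))

    def forward(self, x):
        x = F.interpolate(x, scale_factor=(2, 2, 2), mode="nearest")
        return self.conv(x)
```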
Discrete and continuous tokens

- They employ the vanilla autoencoder (AE) formulation to model the continuous tokenizer’s latent space.
- The latent dimension for the continuous tokenizers is 16
- For discrete tokenizers, they adopt Finite Scalar Quantization (FSQ, introduced in "Finite Scalar Quantization: VQ-VAE Made Simple") as the latent-space quantizer (sketched below).
- For the discrete tokenizers, the latent dimension is 6, which is the number of FSQ levels; the levels are (8, 8, 8, 5, 5, 5).
- This configuration corresponds to a vocabulary size of 8 · 8 · 8 · 5 · 5 · 5 = 64,000.
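A minimal sketch of FSQ with these levels (simplified from the FSQ paper's recipe; function names are illustrative):

```python
import torch

LEVELS = torch.tensor([8, 8, 8, 5, 5, 5])  # 8*8*8*5*5*5 = 64,000 codes

def fsq_quantize(z: torch.Tensor, levels: torch.Tensor = LEVELS) -> torch.Tensor:
    """Round each latent dim to one of `levels[i]` values; z: (..., 6).
    Gradients flow through the rounding via the straight-through trick."""
    L = levels.to(z.dtype)
    half_l = (L - 1) / 2
    offset = (levels % 2 == 0).to(z.dtype) * 0.5   # half-shift for even levels
    z = torch.tanh(z) * half_l - offset            # bound so round() hits L values
    z_q = torch.round(z)
    return z + (z_q - z).detach()                  # straight-through estimator

def fsq_index(z_q: torch.Tensor, levels: torch.Tensor = LEVELS) -> torch.Tensor:
    """Map a quantized latent vector to a single token id in [0, 64000)."""
    digits = torch.round(z_q + (levels // 2).to(z_q.dtype)).long()  # in [0, L-1]
    base = torch.cumprod(
        torch.cat([torch.ones(1, dtype=torch.long), levels[:-1]]), dim=0
    )  # mixed-radix place values: [1, 8, 64, 512, 2560, 12800]
    return (digits * base).sum(dim=-1)
```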
Compression rate
- They train the image tokenizers (denoted as CI and DI) at two compression rates: 8 × 8 and 16 × 16.
- They train the video tokenizers (denoted as CV and DV) at three compression rates: 4 × 8 × 8, 8 × 8 × 8, and 8 × 16 × 16.
- Here, the compression rates are expressed as H × W for images and T × H × W for videos.
- In the code, compression is controlled by the variable `channels_mult: list[int]`, where `num_downsample_stages = len(channels_mult)` and `channels_mult[i]` defines the multiplier of the original channel dimension (normally it's 4, because of the 2-level Haar transform) at the i'th stage, i.e. `out_channels = channels * channels_mult[i]`.
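For illustration of the mechanics (the values below are hypothetical, not taken from a released config):

```python
# Hypothetical config, to illustrate how channels_mult drives the encoder stages
channels = 128                        # base channel dimension
channels_mult = [2, 4, 4]             # one multiplier per downsampling stage
num_downsample_stages = len(channels_mult)                   # -> 3
stage_out_channels = [channels * m for m in channels_mult]   # -> [256, 512, 512]
```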
Wan VAE
- 127M parameters
- Very similar design to Cosmos: they use 3D causal convolutions and self-attention.
