DIAMOND
💎
Diffusion for World Modeling:
Visual Details Matter in Atari

1University of Geneva, 2University of Edinburgh, 3Microsoft Research
*Equal Contribution  †Equal Supervision

NeurIPS 2024 Spotlight

DIAMOND agent playing in an Atari diffusion world model.
DIAMOND 💎 (DIffusion As a Model Of eNvironment Dreams) is a reinforcement learning agent trained entirely in a diffusion world model. The agent playing in the diffusion model is shown above.

DIAMOND's diffusion world model can also be trained to simulate 3D environments, such as Counter-Strike: Global Offensive (CSGO).

Abstract

World models constitute a promising approach for training reinforcement learning agents in a safe and sample-efficient manner. Recent world models predominantly operate on sequences of discrete latent variables to model environment dynamics. However, this compression into a compact discrete representation may ignore visual details that are important for reinforcement learning. Concurrently, diffusion models have become a dominant approach for image generation, challenging well-established methods modeling discrete latents. Motivated by this paradigm shift, we introduce DIAMOND (DIffusion As a Model Of eNvironment Dreams), a reinforcement learning agent trained in a diffusion world model. We analyze the key design choices that are required to make diffusion suitable for world modeling, and demonstrate how improved visual details can lead to improved agent performance. DIAMOND achieves a mean human normalized score of 1.46 on the competitive Atari 100k benchmark; a new best for agents trained entirely within a world model. To foster future research on diffusion for world modeling, we release our code, agents and playable world models at https://github.com/eloialonso/diamond.

CSGO DIAMOND 💎 Diffusion World Model Demonstrations

All videos generated by a human playing with keyboard and mouse inside
DIAMOND's diffusion world model, trained on CSGO.

Try it for yourself

Try out our playable CSGO and Atari world models for yourself with the installation instructions below:


git clone git@github.com:eloialonso/diamond.git
cd diamond
conda create -n diamond python=3.10
conda activate diamond
pip install -r requirements.txt

To play our Atari world models:
python src/play.py --pretrained

To play our CSGO world model:
git checkout csgo
python src/play.py

How does it work?

We train a diffusion model to predict the next frame of the game. The diffusion model takes into account the agent's action and the previous frames to simulate the environment's response.

The diffusion world model takes into account the agent's action and previous frames to generate the next frame.
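
In code, this amounts to a denoiser conditioned on past frames and actions. Below is a minimal PyTorch-style sketch; the module, its layer choices, and the conditioning mechanism are illustrative stand-ins, not DIAMOND's actual architecture.

import torch
import torch.nn as nn

class NextFrameDenoiser(nn.Module):
    """Toy conditional denoiser: maps a noised next frame, the previous
    frames, and the previous actions to an estimate of the clean next frame."""

    def __init__(self, num_actions: int, context: int = 4, channels: int = 3):
        super().__init__()
        self.action_emb = nn.Embedding(num_actions, channels)
        # Context frames are stacked channel-wise with the noised frame;
        # a single conv stands in here for the full UNet.
        self.net = nn.Conv2d(channels * (context + 1), channels, 3, padding=1)

    def forward(self, noised_next, past_frames, past_actions):
        # noised_next:  (B, C, H, W)    candidate next frame plus noise
        # past_frames:  (B, T*C, H, W)  channel-stacked context frames
        # past_actions: (B, T)          discrete action indices
        x = self.net(torch.cat([noised_next, past_frames], dim=1))
        # Inject the most recent action as a per-channel bias (a real model
        # would use a richer mechanism, e.g. adaptive normalization).
        a = self.action_emb(past_actions[:, -1])      # (B, C)
        return x + a[:, :, None, None]                # denoised next frame

Training then follows the usual diffusion recipe: noise a real next frame, and regress the denoiser's output toward the clean frame, conditioned on the preceding frames and actions.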


The agent repeatedly provides new actions, and the diffusion model updates the game.

The diffusion model acts as a world model in which the agent can learn to play.

Autoregressive generation with diffusion world model.
Autoregressive generation enables the diffusion model to act as a world model in which the agent can learn to play.
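
A sketch of that imagination loop, assuming a policy(frame) -> action function and a sample_next_frame(frames, actions) -> frame function that runs the full denoising loop (both are hypothetical names, not the repo's API):

import torch
from collections import deque

@torch.no_grad()
def imagine(policy, sample_next_frame, init_frames, init_actions, horizon):
    """Roll out a trajectory entirely inside the world model (sketch)."""
    frames = deque(init_frames, maxlen=len(init_frames))
    actions = deque(init_actions, maxlen=len(init_actions))
    trajectory = []
    for _ in range(horizon):
        action = policy(frames[-1])            # the agent chooses an action
        actions.append(action)
        # The world model generates the next frame from recent frames + actions.
        next_frame = sample_next_frame(list(frames), list(actions))
        frames.append(next_frame)              # generated frame becomes context
        trajectory.append((action, next_frame))
    return trajectory

Because every frame the agent sees is generated rather than collected, the policy can be trained without touching the real environment.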

To make the world model fast, we need to reduce the number of denoising steps. We found DDPM (Ho et al., 2020) to become unstable at low numbers of denoising steps. In contrast, we found EDM (Karras et al., 2022) to produce stable trajectories even with a single denoising step.

DDPM vs EDM based diffusion world models. The DDPM-based model becomes unstable for low numbers of denoising steps, while the EDM-based model remains stable.
The DDPM-based model is unstable at low numbers of denoising steps due to accumulating autoregressive error, while the EDM-based model remains stable. Fewer denoising steps make for a faster world model.
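
The difference is easiest to see in the sampler itself. Below is a minimal deterministic Euler sampler over the EDM/Karras noise schedule; this is a sketch of EDM-style sampling under standard default parameters, not DIAMOND's exact sampler.

import torch

def edm_sample(denoise, shape, num_steps=3, sigma_min=2e-3, sigma_max=80.0, rho=7.0):
    """Deterministic Euler sampler on the Karras sigma schedule (sketch).
    denoise(x, sigma) must return an estimate of the clean image."""
    i = torch.arange(num_steps)
    sigmas = (sigma_max ** (1 / rho)
              + i / max(num_steps - 1, 1)
              * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    sigmas = torch.cat([sigmas, sigmas.new_zeros(1)])  # final step lands at 0

    x = torch.randn(shape) * sigmas[0]                 # start from pure noise
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoise(x, sigma)) / sigma            # direction toward data
        x = x + (sigma_next - sigma) * d               # Euler step
    return x

With num_steps=1, the loop runs once and x collapses to denoise(x, sigma_max), a single denoiser evaluation: stable, but prone to averaging over possible futures, as the Boxing example below shows.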

However, in Boxing, single-step denoising interpolates between possible outcomes, resulting in blurry predictions for the unpredictable black player.

In contrast, using more denoising steps enables better selection of a particular mode, improving consistency over time.

Diffusion world model trajectories for the Atari game Boxing for varying numbers of denoising steps.
Larger numbers of denoising steps n enable better mode selection for transitions with multiple modes. We therefore use n=3 for DIAMOND's diffusion world model.

Interestingly, the white player's movements are predicted correctly regardless of the number of denoising steps. This is because the white player is controlled by the policy, so its actions are given to the world model, which removes the ambiguity that would otherwise cause blurry predictions.

We find that diffusion-based DIAMOND provides better modeling of important visual details than the discrete token-based IRIS.

Visualisation of the IRIS and DIAMOND world models on Asterix, Breakout and RoadRunner.
DIAMOND's world model is able to better capture important visual details than the discrete token-based IRIS.

Training an agent with reinforcement learning inside this diffusion world model, DIAMOND achieves a mean human-normalized score of 1.46 on Atari 100k (46% better than the human reference); a new best for agents trained entirely within a world model on 100k frames.
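
For context, the human-normalized score for each game is the standard Atari measure, computed against the usual random and human reference scores from the literature:

def human_normalized_score(agent_score, random_score, human_score):
    """0.0 = random play, 1.0 = the human reference, above 1 = superhuman."""
    return (agent_score - random_score) / (human_score - random_score)

A mean of 1.46 across the 26 Atari 100k games therefore corresponds to scoring 46% above the human reference on average.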

Check out our paper for more details!

BibTeX

@inproceedings{alonso2024diffusionworldmodelingvisual,
      title={Diffusion for World Modeling: Visual Details Matter in Atari},
      author={Eloi Alonso and Adam Jelley and Vincent Micheli and Anssi Kanervisto and Amos Storkey and Tim Pearce and François Fleuret},
      booktitle={Thirty-eighth Conference on Neural Information Processing Systems},
      year={2024},
      url={https://arxiv.org/abs/2405.12399},
}