Smaller World Models for Reinforcement Learning

Model-based reinforcement learning algorithms try to learn an agent by training a model that simulates the environment. However, such models tend to be quite large, which can itself become a burden. In this paper, we address the question of how to design a model with fewer parameters than previous model-based approaches while achieving the same performance in the 100 K-interactions regime. For this purpose, we create a world model that combines a vector quantized-variational autoencoder to encode observations with a convolutional long short-term memory to model the dynamics. It is connected to a model-free proximal policy optimization agent that is trained purely on simulated experience from this world model. Detailed experiments on the Atari environments show that it is possible to reach performance comparable to the SimPLe method with a significantly smaller world model. A series of ablation studies justify our design choices and give additional insights.


Introduction
Reinforcement learning environments can be formalized with Markov decision processes (MDPs). An MDP is defined by a set of states S, a set of actions A, the starting distribution p(s_1), and the dynamics distribution p(s_{t+1}, r_t | s_t, a_t) for any time step t. The dynamics distribution describes the probability of the next state s_{t+1} ∈ S and reward r_t ∈ R, given the current state s_t ∈ S and action a_t ∈ A. This one-step dynamics distribution fully describes the environment's dynamics, since the environment is assumed to satisfy the Markov property, i.e., states retain all relevant information about the past. For practicality, we add another variable d_t ∈ {0, 1} that indicates whether the state is non-terminal or terminal, respectively.
The dynamics distribution then becomes p(s_{t+1}, r_t, d_t | s_t, a_t). This still fits into the MDP definition, as d_t can be seen as part of the state.
A reinforcement learning agent interacts with an environment and tries to maximize the sum of rewards by taking the right actions. The behavior of an agent is defined by a policy distribution π_θ(a_t | s_t) that maps states onto actions, where θ denotes the parameters of some learned model. Reinforcement learning methods provide means to optimize the policy distribution such that the actions that maximize the sum of rewards have high probability.
Model-free reinforcement learning algorithms try to optimize the policy π_θ(a_t | s_t) based on real experience from the environment, without any knowledge of the underlying dynamics. They have shown great success in a wide range of environments, yet they often require a large amount of training data, i.e., many interactions with the environment. This low sample efficiency still renders them inappropriate for real-world applications with expensive data collection. Many efforts have been made to alleviate this problem, e.g., experience replay, where data is stored and reused multiple times instead of being thrown away immediately.
Model-based algorithms approximate the true dynamics with a model of the environment p_φ(s_{t+1}, r_t, d_t | s_t, a_t) ≈ p(s_{t+1}, r_t, d_t | s_t, a_t) with parameters φ. To differentiate it from other models, we will call p_φ(s_{t+1}, r_t, d_t | s_t, a_t) a world model. The process of improving a policy with a world model is called planning and can be carried out in two ways [23]: (i) Decision-time planning dynamically improves the action selection process at runtime by looking ahead into the future using the world model. This can be illustrated with the game of chess, where it is beneficial to think about the consequences of possible moves. It slows down the agent at decision time, but may be favorable for environments in which predictions about the future are crucial. (ii) Background planning generates novel training data by sampling from the world model. The agent can be trained with this simulated experience, possibly in combination with real experience. This increases training time, but is suited for environments in which fast decisions are required and collecting experience is expensive, e.g., self-driving cars or robots, because after training, the world model is not needed anymore.

Contributions
In this work we consider a model-based approach with background planning as described above. The ability to generate new experience without acting in the real environment increases the sample efficiency, because less real data is necessary. An overview of our approach can be seen in Fig. 1. Our main contributions and insights can be summarized as follows:
• We implement a world model based on a VQ-VAE which requires fewer parameters than existing approaches (see Table 4).
• Our world model uses a two-dimensional discrete latent representation combined with a dynamics network built from convolutional LSTMs (see Fig. 5), and we demonstrate that this choice is favorable for model-based reinforcement learning (see Table 6).
• We extensively evaluate our world model in the Arcade Learning Environment and restrict the training to 100 K interactions instead of the usually used 10 M or 50 M interactions. We show that learning a latent space world model and training an agent is possible in this restricted regime (see Table 2).

Preliminaries
Many real-world environments only expose observations o_t that might not contain all information of the underlying states s_t. The Markov property might not hold for these observations. For instance, the observation o_t could be a camera image while the state s_t could describe the actual position and velocity of objects. This partial observability can be seen as a special case of function approximation (e.g., the parameters can be chosen in such a way that the approximation does not depend on the unobservable parts of the state), and all algorithms that consider function approximation can be applied to partial observability [23]. Therefore, we will continue to denote the observations as s_t for simplicity.

Actor-Critic Algorithms
A central concept of reinforcement learning is the value, which is the expected discounted sum of rewards when starting at some position and following a specific policy. The state-value function v^{π_θ}(s_t) computes the value when starting at state s_t and following policy π_θ. The action-value function q^{π_θ}(s_t, a_t) computes the value when starting at state s_t with action a_t. In their formal definitions, R_t, S_t, and A_t are random variables for the reward, state, and action at time step t, respectively, and γ ∈ [0, 1) is the discount factor that determines the weighting of distant rewards. The exact definitions of the random variables and the expectations can be derived from the definition of the MDP and the dynamics distribution. Furthermore, the advantage of a particular action over the average action in terms of value can be computed with the advantage function a^{π_θ}(s_t, a_t) := q^{π_θ}(s_t, a_t) − v^{π_θ}(s_t) for a given policy π_θ. Model-free algorithms can be divided into two classes: value-based methods, which estimate a value function and implicitly define the policy such that the value is maximized, and policy gradient methods, which directly learn a policy distribution. In this work we focus on policy gradient methods. The policy gradient objective L^{PG}(θ) provides a learning signal to optimize the parameters θ of the policy distribution. There are several ways to formulate this objective; we use a formulation in which the log-probability of the taken action is weighted by a_ψ(s_t, a_t) ≈ a^{π_θ}(s_t, a_t), an estimator of the true advantage function with parameters ψ [23]. If this estimator involves an estimation of a value function, the method is called an actor-critic method. A common estimator of the advantage function is the expected multi-step temporal difference error, where v_ψ(s_t) ≈ v^{π_θ}(s_t) is an estimator of the state-value function with parameters ψ. The horizon T determines the number of time steps before "bootstrapping" the remaining value with the learned state-value function, and it controls the trade-off between bias and variance. Typically, data is sampled from N environment instances for T time steps to obtain a minibatch of shape N × T, in order to approximate the expectations in Eq. 3 and Eq. 4 and update the parameters θ. At the same time, the learned state-value function in Eq. 4 can be optimized by minimizing a mean squared error against bootstrapped target values, where sg(•) is the stop-gradient operator that prevents gradients from flowing through the target values when applying stochastic gradient descent.
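A compact way to write these quantities, following the common actor-critic formulation (a sketch of the standard definitions consistent with the description above, not the paper's exact numbered equations), is:

```latex
v^{\pi_\theta}(s_t) := \mathbb{E}\Big[\textstyle\sum_{k=0}^{\infty} \gamma^k R_{t+k} \,\Big|\, S_t = s_t\Big], \qquad
q^{\pi_\theta}(s_t, a_t) := \mathbb{E}\Big[\textstyle\sum_{k=0}^{\infty} \gamma^k R_{t+k} \,\Big|\, S_t = s_t, A_t = a_t\Big]

L^{PG}(\theta) := \mathbb{E}\big[\log \pi_\theta(A_t \mid S_t)\, a_\psi(S_t, A_t)\big]

a_\psi(s_t, a_t) := \mathbb{E}\Big[\textstyle\sum_{k=0}^{T-1} \gamma^k R_{t+k} + \gamma^T v_\psi(S_{t+T}) \,\Big|\, S_t = s_t, A_t = a_t\Big] - v_\psi(s_t)

L^{V}(\psi) := \mathbb{E}\Big[\big(v_\psi(S_t) - \mathrm{sg}\big(\textstyle\sum_{k=0}^{T-1} \gamma^k R_{t+k} + \gamma^T v_\psi(S_{t+T})\big)\big)^2\Big]
```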
To improve the exploration behavior of the policy, [16] suggest an entropy regularization term that encourages the policy to stay slightly uncertain; it denotes the conditional entropy of A_t given S_t under the distribution π_θ. We can define a combined maximization objective for both the policy and the value function, where c_1, c_2 ≥ 0 are coefficients that control the contribution of the entropy regularization and the state-value function, respectively. The "vanilla" policy gradient discussed so far can be further improved, for instance, via trust region methods, which bound large policy updates in terms of changes to the policy distribution. This can be achieved by imposing a constraint on the divergence between the old policy and the new policy after a parameter update, and either requires expensive second-order optimization or KL divergence approximations. Proximal policy optimization (PPO, [19]) instead proposes a heuristically determined first-order objective (Eq. 7) aimed at bounding large policy updates, where π_θ_old is the policy used to collect the data, r_θ(a_t, s_t) := π_θ(a_t | s_t) / π_θ_old(a_t | s_t) is the likelihood ratio, a_ψ(s_t, a_t) ≈ a^{π_θ_old}(s_t, a_t) is an estimator of the advantage function of the old policy, and clip(x, a, b) := max{a, min{b, x}}. In order to reduce variance, the likelihood ratio is clipped between 1 − ε and 1 + ε for a small ε, thus keeping it close to 1. The bound is pessimistic, as the objective is only clipped if it would improve too heavily, while decreasing the objective is not prevented (because of the minimum). Clipping takes effect in two cases: (i) the advantage estimate a_ψ(s_t, a_t) is positive and the policy gradient wants to increase the probabilities of the actions excessively, and (ii) the advantage estimate is negative and the policy wants to decrease the probabilities of the actions too heavily. The likelihood ratios allow for multiple parameter updates on a single minibatch, so more information can be squeezed out of the data before throwing it away. The PPO objective can be combined with the state-value function objective and the entropy regularization term, resulting in the final objective we will use in our work (Eq. 8), where the coefficients c_1, c_2 ≥ 0 control the contribution of the objectives.
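For completeness, the standard PPO clipped surrogate described above has the form:

```latex
L^{CLIP}(\theta) := \mathbb{E}\Big[\min\big(r_\theta(a_t, s_t)\, a_\psi(s_t, a_t),\; \mathrm{clip}(r_\theta(a_t, s_t),\, 1-\epsilon,\, 1+\epsilon)\, a_\psi(s_t, a_t)\big)\Big]
```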

Arcade Learning Environment
The Arcade Learning Environment (ALE) by [3] is a commonly used framework to evaluate reinforcement learning algorithms. It provides access to many games of the Atari 2600 console.
The actions are discrete and correspond to the inputs that can be made with the Atari 2600 joystick, with a total number of 18. The reward is defined as the change in the score of a game. The observations are raw video frames, which are usually preprocessed before being passed to the agent. Typically, the RGB frames of shape 210 × 160 × 3 are converted to grayscale and scaled down to a square image, in our case to 96 × 96. We follow the common preprocessing steps from [15]: (i) Convert frames to grayscale and scale them down. (ii) Frame skipping, i.e., only consider every nth frame and repeat the same action in between, which reduces the effective frame rate and enables the agent to reach more advanced observations in fewer time steps. (iii) Frame stacking, i.e., concatenate the last k frames (after frame skipping), which allows the agent to observe short-time effects, e.g., the velocity of objects. (iv) Max pooling, i.e., take the component-wise maximum of the two most recent frames, since hardware limitations of the Atari 2600 can cause objects to be displayed only every other frame. A sketch of these steps is given below.
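The following minimal sketch illustrates steps (i), (ii), and (iv), assuming an old-style Gym step interface (obs, reward, done, info) and OpenCV for the image operations; the function names and the skip value are illustrative, not the exact implementation used in this work.

```python
import numpy as np
import cv2

def preprocess_frame(frame, size=96):
    # (i) convert an RGB Atari frame of shape (210, 160, 3) to a scaled-down grayscale image
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    return cv2.resize(gray, (size, size), interpolation=cv2.INTER_AREA)

def skip_and_pool(env, action, skip=4):
    # (ii) repeat the same action for `skip` raw frames and accumulate the reward
    # (iv) max-pool the two most recent raw frames to undo Atari flickering
    raw_frames, total_reward, done, info = [], 0.0, False, {}
    for _ in range(skip):
        obs, reward, done, info = env.step(action)
        raw_frames.append(obs)
        total_reward += reward
        if done:
            break
    pooled = np.maximum(raw_frames[-1], raw_frames[-2]) if len(raw_frames) > 1 else raw_frames[-1]
    return preprocess_frame(pooled), total_reward, done, info
```

Frame stacking (iii) is then simply a matter of keeping the last k = 4 preprocessed frames in a buffer and concatenating them along the channel axis.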
In the following, an interaction with an environment denotes a single action that is sent to the game. We use the common frame skipping value of 4 (i.e., we only consider every fourth frame). Therefore, an interaction corresponds to four frames. There are common values for the number of frames that are collected during training of the reinforcement learning agent, among which are 10, 40, 100, and 200 M (or, in terms of interactions, 2.5, 10, 25, and 50 M, respectively). In terms of real time these are roughly 46, 185, 463, and 962 hours, respectively, which is orders of magnitude longer than the time a human would usually need to become comparably good at an Atari 2600 game. In this work we limit the agent to about 100 K interactions, i.e., 400 K frames, as proposed by [11], which approximately equals 1 hour and 51 minutes of real-time gameplay.

Variational Autoencoders
A variational autoencoder (VAE, [13]) is a powerful generative model that can produce new samples of a complex distribution p(x) by decoding samples of a simpler distribution p(z), e.g., samples from p(x) could be images and samples from p(z) low-dimensional real vectors. Conversely, a VAE can also map high-dimensional samples x onto low-dimensional representations z.
VAEs can be derived from latent variable models, which are generative models that sample a latent variable z and decode it into the actual data x via a learned decoder p(x | z), instead of directly sampling x. This has the advantage that high-level or global information can be decided upon before generating the high-dimensional output, e.g., the digit, its position, its shape, etc., before generating the pixels. However, these models suffer from the problem that maximizing the likelihood p(x) = ∫ p(x | z) p_0(z) dz requires a lot of samples from the prior p_0(z) in order to approximate the integral. Even if z is low-dimensional, the number of possible samples usually is very large, p(x | z) will be close to zero for most z, and capturing the entire latent space via random sampling is not tractable. VAEs approach this problem by additionally learning an encoder p(z | x) that can produce samples of z conditioned on training data x, so that only samples of z that are relevant for maximizing the likelihood are considered. Using variational inference, a lower bound on the log-likelihood can be derived [13]. The expected value in Eq. 9 can be interpreted as a "reconstruction" term and is often approximated using a single sample from p(z | x). The KL divergence acts as a regularizer on the encoder, since it keeps its outputs closer to the prior distribution.
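In the usual notation, with encoder q(z | x), decoder p(x | z), and prior p_0(z), this evidence lower bound takes the standard form (a generic sketch; the original uses its own parameter subscripts):

```latex
\log p(x) \;\geq\; \mathbb{E}_{z \sim q(z \mid x)}\big[\log p(x \mid z)\big] \;-\; D_{KL}\big(q(z \mid x)\,\|\,p_0(z)\big)
```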
VAEs with a normally distributed latent space often have the issue that the outputs of the decoder are "blurry" (e.g., blurry images). One reason is that the training samples still only sparsely fill the latent space and sampling from a standard normal prior can generate latent variables with low probability in p(x | z). Vector quantized-variational autoencoders (VQ-VAEs, [25]) provide two improvements: (i) the latent space is discrete, which makes it more compact, and (ii) samples from the prior are generated more carefully with an additionally trained autoregressive neural network.
A discrete latent space for image inputs is achieved as follows. An encoder convolutional neural network f(x) outputs tensors of shape H × W × C, with spatial dimensions H, W, and number of hidden channels C. This output is interpreted as a two-dimensional structure of C-dimensional vectors that get discretized via vector quantization (VQ). That is, a list of K embedding vectors is maintained and the output vectors are replaced by those embedding vectors that have the smallest Euclidean distance. The latent representation z then is a matrix of integers, where z_{i,j} is the index of the closest embedding vector for the output vector f(x)_{i,j} at row i and column j. The embedding vectors are updated by either moving them slightly towards the output vectors in every parameter update or by keeping an exponential moving average of the output vectors. During backpropagation, the gradients are estimated by passing them straight through the vector quantization step, as if the quantization did not happen.
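A minimal sketch of the quantization step and the straight-through gradient estimator, written here in PyTorch purely for illustration (the text does not prescribe a framework):

```python
import torch

def vector_quantize(f_x, codebook):
    # f_x: encoder output of shape (B, H, W, C); codebook: (K, C) embedding vectors
    flat = f_x.reshape(-1, f_x.shape[-1])                     # (B*H*W, C)
    dists = torch.cdist(flat, codebook)                       # Euclidean distance to every embedding vector
    indices = dists.argmin(dim=1)                             # index of the closest embedding vector
    z_q = codebook[indices].reshape(f_x.shape)                # quantized output vectors
    z_q = f_x + (z_q - f_x).detach()                          # straight-through: gradients bypass the quantization
    return z_q, indices.reshape(f_x.shape[:-1])               # quantized tensor and integer latent matrix (B, H, W)
```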
The probabilistic interpretation of the vector quantization is that the encoder p(z | x) outputs a two-dimensional grid of shape H × W of independent categorical distributions, where the number of categories corresponds to the number of embedding vectors. For each entry z_{i,j} of the latent representation, the index of the embedding vector closest to the output vector has a probability of one; all other categories have zero probability. The prior p_0(z) is a two-dimensional grid of shape H × W of uniform categorical distributions, i.e., each p_0(z_{i,j}) = Cat(z_{i,j}; 1/K, ..., 1/K) is a categorical distribution with equal probabilities. This choice of distributions leads to a constant KL divergence term in the lower bound from Eq. 9, which is another beneficial property of VQ-VAEs and makes them more robust across different training sets, since the magnitude of the KL divergence stays the same, independent of the training data x [25]. For standard (Gaussian) VAEs, this is usually addressed by inserting a coefficient for the KL term in the lower bound in Eq. 9 [9], which needs to be manually fine-tuned for each dataset.
To generate latent variables autoregressively, a PixelCNN [24] is trained to predict the indices of the embedding vectors after the training of the encoder, decoder, and embedding vectors is finished. The indices can be looked up in the list of embedding vectors, which are then decoded. This allows VQ-VAEs to generate stable, non-blurry samples.

Method
This section provides a summary of the methods that we use to build our world model, starting with a high-level view and gradually becoming more concrete. First, in Sect. 3.1 we introduce the concepts and our notation for latent space world models and policies. Second, in Sect. 3.2 we show the neural network architectures that we employ. Lastly, in Sect. 3.3 we explain our training procedure and how the models work together.

Latent Space World Models
Simulating experience in latent space instead of the raw state or observation space can be beneficial. It is computationally more efficient, and the potentially less complex latent space can ease the prediction tasks. Therefore, we estimate the dynamics p(s_{t+1}, r_t, d_t | s_t, a_t) with latent variables. Note that the latent variables are not direct approximations of the true MDP states, as their sole purpose is to facilitate the generation of experience.
Using the sum rule of probability, we can introduce latent variables z_t, z_{t+1} into the dynamics and write the dynamics as an expectation over them. We assume that the graphical models in Fig. 2 describe the dependencies between the random variables. This allows us to estimate the true dynamics using the world model p_φ(s_{t+1}, r_t, d_t | s_t, a_t) ≈ p(s_{t+1}, r_t, d_t | s_t, a_t) with parameters φ in a specific way, as described in the following. The conditional independencies induced by the graphical model in Fig. 2a lead to a factorized approximation of the dynamics; in practice we approximate each expectation with a single Monte Carlo sample.

(Fig. 2 shows the graphical models of the non-recurrent world model and the policy. From the perspective of the world model, the observation s_t and the action a_t are noise variables; from the perspective of the policy, s_t is a noise variable. The black dot is an intermediate variable, which reduces the number of arrows.)

This graphical model is just one of many possibilities to build a world model, and we now explain our motivation for choosing it. We want an observation encoding model p_φ(z_t | s_t) that does not depend on the action, so that we can apply a (vector quantized-)variational autoencoder to the observations, analogous to the "vision" component of [6]. We also want to predict the next latent variable based on the previous latent variable and action, i.e., p_φ(z_{t+1} | z_t, a_t), independent of the observation s_t, so that no decoding into observations is necessary during the simulation of experience. Furthermore, the reward and the next latent variable should not depend on each other, which allows us to use two prediction heads and to compute them in a single neural network pass. In contrast to [11], our policy is conditioned on the latent variables, so that no decoding into high-dimensional observations is necessary when choosing actions, where θ are the parameters of the policy. We use the conditional independence a_t ⊥ s_t | z_t induced by the graphical model in Fig. 2b. So far, the world model p_φ(s_{t+1}, r_t, d_t | s_t, a_t) has to predict s_{t+1}, r_t, and d_t based on the current observation s_t and action a_t. As discussed in Sect. 1, the observation might only contain partial information of the true state, and the Markov property might not hold. Therefore, we introduce recurrent variables h_{t−1}, h_t that can capture information over multiple time steps, analogous to the "memory" component of [6], and we obtain a recurrent world model. The graphical models in Fig. 3 describe the dependencies between the random variables for our recurrent world model and allow us to derive the recurrent dynamics; once more, we use the independencies induced by the graphical model. Following the graphical model in Fig. 3b, the policy can also be conditioned on the recurrent variable. We will implement a recurrent world model in this work. We believe that this type of world model is appropriate for a wide range of environments and can also be applied to domains other than Atari, since we made no domain-specific assumptions (except the independence between reward and next latent variable, which could be unfavorable for some environments). From Eq. 13 it becomes clear that several models have to be learned, which are summarized and labeled in Table 1, in addition to the policy. The parameters of the world model are denoted by φ and the parameters of the policy by θ. In the next section we describe our architectures that implement the models from Table 1.
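Concretely, the recurrent world model factorizes over the components listed in Table 1; a sketch of the resulting generative process (our reading of the factorization, using the conditionals from Table 1) is:

```latex
z_t \sim p_\phi(z_t \mid s_t), \qquad
h_t \sim p_\phi(h_t \mid z_t, a_t, h_{t-1}), \qquad
z_{t+1} \sim p_\phi(z_{t+1} \mid z_t, h_t, a_t),

r_t \sim p_\phi(r_t \mid z_t, h_t, a_t), \qquad
d_t \sim p_\phi(d_t \mid z_t, h_t, a_t), \qquad
s_{t+1} \sim p_\phi(s_{t+1} \mid z_{t+1})
```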

Observation Encoder and Decoder
We use a vector quantized-variational autoencoder [25] for the observation encoder and decoder, so that each latent variable z_t is a 6 × 6 matrix filled with discrete embedding indices. We set the size of the embedding vectors to 32, and there are 128 embedding vectors in total. We apply certain preprocessing steps to the raw frames of the Atari game (see Sect. 2.2), so that the observations s_t are a stack of four grayscale frames, which results in the shape 96 × 96 × 4. Frame stacking allows the observation encoder to incorporate short-time information into the stationary latent representations.
The encoder consists of multiple convolutions, each followed by batch normalization [10], and the decoder of multiple deconvolutions. See Fig. 4 for a visualization. After each batch normalization and deconvolution we add a leaky ReLU nonlinearity [14]. The encoder p_φ(z | s) and the prior p_0(z) follow the default VQ-VAE distributions, i.e., one-hot categorical distributions and uniform categorical distributions, respectively. The last deconvolution computes the logits of 96 × 96 × 4 continuous Bernoulli distributions for the decoder p_φ(s | z), with independence among the stacked frames and pixels. Our motivation for a two-dimensional latent representation is that it is able to maintain spatial correlations. In combination with convolutional operations, these correlations can be respected when predicting the next time step.

Dynamics
The recurrent dynamics model is implemented by a two-cell convolutional LSTM [22] with layer normalization [1] after each cell. The input consists of a 6 × 6 × 50 tensor, where the first 32 channels are made up of the embedding vectors, which are looked up in the VQ-VAE using the indices of the latent representation; the last 18 channels contain one-hot encodings of the action, repeated along the spatial dimensions, in order to condition the dynamics on the action. For environments with fewer than 18 actions the number of channels is reduced accordingly. The action encodings are also concatenated to the output of each convolutional LSTM cell, because the action information might otherwise get lost during the forward pass. Since the output h_t of the last convolutional LSTM cell is deterministic, there is no stochasticity in p_φ(h_t | z_t, a_t, h_{t−1}). In addition to the convolutional LSTM we have tested the successor architectures spatio-temporal LSTM [26] and causal LSTM [27], but they increased the number of parameters and the training time significantly, while the performance stayed the same.
After that there are two prediction heads: one for the latent predictor, consisting of one convolutional layer, and one for the reward predictor, consisting of a convolutional layer with layer normalization and two fully-connected layers, each followed by a leaky ReLU nonlinearity. The entire model is depicted in Fig. 5 (after the second convolutional LSTM cell, the network splits into the reward predictor head and the latent predictor head; the recurrent states h_t, h_{t−1} are omitted there for clarity).
The latent predictor head computes the unnormalized scores U ∈ R^{6×6×128} for the embedding indices. These scores get normalized via the softmax function, in order to obtain an independent categorical distribution for each component of z_{t+1}, where (z_{t+1})_{j,k} is the latent matrix entry at row j and column k, and U_{j,k} is the output vector at row j and column k. We suppose that the discretization of the latent space stabilizes the latent predictor, since it has to predict scores for a predefined set of categories instead of real values. This is especially important when we take into account that the targets are moving, i.e., that the latent representations, which are produced by the VQ-VAE, change during training.
The rewards are discretized into three categories {−1, 0, 1} by clipping them into the interval [−1, 1] and rounding to the nearest integer. The reward predictor head computes the unnormalized scores V ∈ R^3 for the corresponding reward categories. These scores are normalized via the softmax function to obtain a categorical distribution. The support of this distribution is r ∈ {1, 2, 3}, so we have to map the rewards accordingly (r = r_orig + 2) when we compute the likelihood. Since Atari games are episodic, the world model has to terminate episodes by predicting the binary terminal variable d_t. This prediction has to be reliable, as an incorrect prediction of d_t = 1 can have a severe impact on the simulated experience, and thus on the policy. We follow [11] and employ a naive but effective solution, which is to end all episodes after a fixed number of time steps T_sim, i.e., we assign a fixed distribution that predicts d_t = 1 once T_sim steps have been simulated, where T_sim := 50 for all of our experiments. On the downside, this prevents the policy from learning from effects that are longer than T_sim [11].
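As an illustration, the dynamics input described above (embedding vectors plus spatially tiled one-hot action encodings) could be assembled as in the following minimal PyTorch sketch; the channels-first layout and all names are ours, not the paper's:

```python
import torch
import torch.nn.functional as F

def make_dynamics_input(z_embeddings, actions, num_actions=18):
    # z_embeddings: (B, 32, 6, 6) embedding vectors looked up from the VQ-VAE codebook
    # actions: (B,) integer-encoded actions
    one_hot = F.one_hot(actions, num_actions).float()               # (B, 18)
    action_planes = one_hot[:, :, None, None].expand(-1, -1, 6, 6)  # tile over the 6 x 6 spatial grid
    return torch.cat([z_embeddings, action_planes], dim=1)          # (B, 50, 6, 6) input to the ConvLSTM
```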

Policy
Our policy π_θ(a_t | z_t) is only conditioned on the latent representation, and not on the recurrent variable h_{t−1}, as presented in Eq. 12. The policy network is visualized in Fig. 6: after a first fully-connected layer that computes a "hidden" vector, the network splits into a value prediction head and an action distribution head. The embedding vectors of the VQ-VAE are sent through two convolutional layers with layer normalization and a fully-connected layer; all layers are followed by leaky ReLU nonlinearities. The "hidden" output is passed to two separate fully-connected layers: the estimated state-value function v_θ(z_t), and the action logits W ∈ R^M, which are the unnormalized scores for the M possible actions. The action distribution is a categorical distribution, which we obtain by normalizing the scores with the softmax function.

Training Procedure
We follow the Simulated Policy Learning (SimPLe) training procedure that was presented by [11]. The idea is to iteratively repeat three steps, namely: (i) collecting real experience and storing it in a data buffer, (ii) training the world model with experience from the data buffer, and (iii) training the agent by simulating experience with the world model. This loop can be illustrated with the pseudocode below; in the following subsections we provide more detailed descriptions of the individual steps.
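A hedged sketch of the loop (all names are illustrative placeholders rather than the actual implementation):

```python
# Pseudocode for the SimPLe-style training loop
buffer = []
policy = RandomPolicy()                                  # uniform actions in the first iteration
for iteration in range(num_iterations):                  # 15 iterations in our setup
    # (i) collect real experience with the current policy
    buffer += collect_experience(real_env, policy, num_interactions_per_iteration)
    # (ii) train the world model (VQ-VAE + dynamics) on the buffer
    world_model.train(buffer, epochs=epochs_per_iteration[iteration])
    # (iii) train the agent with PPO purely on experience simulated by the world model
    policy = train_policy_in_simulation(world_model, policy, start_observations=buffer)
```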

Collecting Experience
In the first step of each iteration, experience is collected from the real environment. In the first iteration, we apply a random policy that samples actions uniformly. In subsequent iterations, we use the trained policy. The observation encoder processes the raw observations to produce latent representations, since the policy is conditioned on them. This process is visualized in Fig. 7a. There is no recurrence involved at this point. We adopt the evaluation method from [11] and limit the total number of interactions (i.e., the number of actions taken) throughout the entire training process to roughly 100 K. We follow [11] and perform 15 iterations with 6400 interactions per iteration. They perform an additional 6400 interactions prior to the first iteration; therefore, we perform 12800 interactions in the first iteration, resulting in the same number of total interactions, i.e., 102400. We sample from a single environment instance and do not reset the environment between iterations.
As [11] have noted, collecting useful experience is crucial for the performance of the world model, especially in the low data regime. To improve exploration, we increase the randomness of the policy by inserting a temperature parameter τ > 0 into the softmax calculation of the policy, where f_θ(z_t) ∈ R^M are the unnormalized scores computed by the policy network, and a_i is the ith action for i ∈ {1, ..., M}. Increasing the temperature τ causes the softmax normalization to be "softer", i.e., it makes the distribution more uniform, while decreasing τ moves the distribution closer to a maximum.
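Written out, the tempered policy described above (Eq. 19) corresponds to the standard temperature softmax:

```latex
\pi_\theta(a_i \mid z_t) \;=\; \frac{\exp\!\big(f_\theta(z_t)_i / \tau\big)}{\sum_{j=1}^{M} \exp\!\big(f_\theta(z_t)_j / \tau\big)}
```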

Training the World Model
In the second step of each iteration the world model is trained in a supervised manner using the collected real experience. We sample minibatches of N_obs observations from the data buffer, s^(i) ∼ p_buffer(s) for i ∈ {1, ..., N_obs}, where p_buffer(s) is a uniform distribution over all collected real observations in the data buffer. With these minibatches we minimize the VQ-VAE loss using the Adam optimizer [12], in order to optimize the observation encoder and decoder models.
To perform a parameter update of the dynamics models, we first sample a minibatch of N_seq sequences of length T_seq of observations, actions, and rewards from the data buffer. Then, we initialize the recurrent states with all zeros and iterate over the sequences step by step. At each step the observation encoder converts the observations to latent variables. The dynamics models take the latent variables, the actions, and the last recurrent states to predict the next latent variables and rewards, while also computing the next recurrent states. After all T_seq time steps are processed, the sum over the sequence of the negative log-likelihood of the predicted latent and reward distributions is minimized with the Adam optimizer. To make training more robust, the "true" encoded latent variables are used as input to the dynamics models in the first T_context time steps (also called the context size), and after that the latent variables from the latent predictor are used. This can mitigate the issue that bad predictions at the beginning of a sequence degrade the remaining ones. Gradients coming from the reward predictor head (see Fig. 5) can have a relatively high magnitude, which can have a degrading effect on the performance of the other models. Therefore, we scale down the negative log-likelihood of the rewards, but increase the learning rate of the reward predictor head to compensate for the scaled-down gradients.
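A minimal sketch of this sequence loss with the context window, using hypothetical interfaces (the actual training code may differ in its details):

```python
def dynamics_sequence_loss(world_model, obs_seq, act_seq, rew_seq, T_context):
    # obs_seq: (T_seq, N_seq, ...); act_seq, rew_seq: (T_seq, N_seq)
    h = world_model.initial_state(batch_size=obs_seq.shape[1])     # recurrent states start at zero
    z = world_model.encode(obs_seq[0])                             # "true" latent of the first observation
    loss = 0.0
    for t in range(obs_seq.shape[0] - 1):
        z_next_dist, r_dist, h = world_model.step(z, act_seq[t], h)
        z_next_true = world_model.encode(obs_seq[t + 1])
        loss += -(z_next_dist.log_prob(z_next_true).sum() + r_dist.log_prob(rew_seq[t]).sum())
        # teacher forcing inside the context window; afterwards feed back the model's own predictions
        z = z_next_true if t + 1 < T_context else z_next_dist.sample()
    return loss
```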
The VQ-VAE and the dynamics models are trained in an alternating fashion, i.e., we perform a single parameter update of the VQ-VAE and then of the dynamics models. When the entire data buffer has been used for training the dynamics models, one epoch is finished. Multiple epochs are executed per iteration, and the number of epochs decreases in every iteration since the data buffer grows; the exact numbers can be found in Table 5. After collecting the first batch of data we train the VQ-VAE separately for 50 epochs with a higher learning rate (the "warm-up" phase). We want to give the dynamics models a better starting point by feeding them with latent representations that already contain useful information. This cannot be done in later training stages, however, as the representations would change and the dynamics models would not be able to keep up easily. Indeed, over the entire course of training, the rate at which the observation encoder model changes proved to be important. We have to adjust the training of the VQ-VAE in such a way that it can learn from newly obtained observations without changing the latent representations too heavily. After this initial phase we even slow down the training of the observation encoder and decoder models by performing only every second or third parameter update, effectively fixing the parameters in between (see the update intervals in Table 5). (Fig. 7 shows the interactions between world model, agent, and environment, at inference and at training time; in both cases the agent receives the experience in simulation space and does not notice a difference.)
As a side note, [11] state that the scale of the KL divergence loss of a variational autoencoder is game dependent, which makes VAEs impractical to apply to all Atari games without fine-tuning. The VQ-VAE does not suffer from this problem, because the KL term is constant and only depends on the dimensions of the latent space and the number of embedding vectors, as explained in Sect. 2.3.

Simulating Experience
The third step in each iteration is to improve the policy by simulating experience from the world model. To simulate a minibatch of N_sim episodes, we initialize the latent variables by encoding N_sim randomly selected observations from the data buffer. This enables the policy to learn from experience from any stage of the environment, although the number of time steps is limited to T_sim (see Sect. 3.2.2). The rest of the experience is generated iteratively. At any time step t, the policy selects actions based on the current latent states, and the world model computes the next latent state, reward, and hidden state based on the current latent states, actions, and recurrent states. See Fig. 7b for a visualization of this process.

Training the Policy
The last part of the third step is to train the policy in a model-free manner. We apply proximal policy optimization [19] to the simulated experience, with the action distribution and state-value being computed by the policy network (see Fig. 6). We set the temperature parameter from Eq. 19 back to τ := 1, since we do not want the extra randomness in the context of simulation. We approximate the PPO objective from Eq. 8 using minibatches of simulated experience with batch size N_sim and horizon T_horizon. We align the horizon with the maximum number of simulated time steps, i.e., we set T_horizon := T_sim. We estimate the advantage function using generalized advantage estimation (GAE, [20]). In most iterations 1000 minibatches are simulated and trained upon, but in some iterations we generate 2000 or 3000 minibatches. These numbers were determined empirically by [11] and can be found in Table 5.
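For reference, a minimal sketch of generalized advantage estimation as used conceptually here (terminal masking is omitted because simulated episodes always run for exactly T_sim steps; λ = 0.95 is only an illustrative value, not a hyperparameter taken from the paper):

```python
import numpy as np

def gae_advantages(rewards, values, gamma=0.99, lam=0.95):
    # rewards: (T,) rewards of one simulated episode
    # values: (T+1,) state-value estimates, including the bootstrap value for the last state
    T = len(rewards)
    advantages = np.zeros(T)
    running = 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]   # one-step TD error
        running = delta + gamma * lam * running                  # exponentially weighted sum of TD errors
        advantages[t] = running
    return advantages
```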

Evaluation
In this section we evaluate our world model on Atari games, provide analyses of the results, and perform further ablation studies on a smaller scale.

Performance Comparison
We restrict our agent to 100 K interactions per game. We evaluate two approaches: using a single frame and using four stacked frames. Both approaches have slight advantages in some of the games. Further discussion on this can be found in Sect. 4.3.
We average the results over five training runs with different seeds. For every run we evaluate the latest policy in each iteration by rolling out 32 episodes in the real environment and computing the mean of the (cumulative) episode rewards. In Table 2 we report the mean final episode rewards (i.e., the mean episode reward after the final iteration, averaged over five runs) for 36 Atari games. In Table 5 we provide a detailed list of all hyperparameters. In Fig. 9 we show the course of the mean episode reward over the 15 iterations for all games.
We compare our method with the default variant of SimPLe (stochastic discrete, 50 steps, γ = 0.99), which comes closest to our model in terms of hyperparameters (discount rate, batch size, etc.) and number of parameter updates. This implies that the results of [11] that we show are not necessarily the best results across all of their variants, but the best for a fairer comparison. Our method achieves a higher value than this SimPLe variant in 20 out of 36 tested games.

Comparison with Human Performance and 10M
In Table 3 we compare our method (the best score of the default frame stacking and no frame stacking) with a random agent, a human baseline [15], and PPO trained on 10 M interactions [19]. The human score is the episode reward averaged over 20 episodes, after around two hours of practice in each game, meaning that it is comparable to the 100 K interactions. We also state the human normalized score [15] of our method, which is calculated as (our score − random score)/(human score − random score), such that 0% corresponds to the random score and 100% corresponds to the human score. Values lower than 0% mean that the performance is worse than random, and values higher than 100% indicate "superhuman" performance. We achieve superhuman performance in three games. In Fig. 8 we compare our human normalized scores with SimPLe.

Model Size
In Table 4 we show the number of parameters of our world model compared with the SimPLe world model by [11], which uses about seven times as many parameters. It also shows the number of parameters of all of our models in detail. At training time all models are used, but at inference time only the observation encoder and the policy network are necessary.

Training Times
A single run takes approximately 12 hours on an Nvidia A100 GPU. Due to the small model size and small batch sizes, the memory usage is quite low. On an Nvidia P100 GPU, our method takes roughly 25 hours, while SimPLe [11] requires 500 hours on the same hardware [21], which is 20 times slower.

Table 2
Mean final episode reward of our method (with and without frame stacking) in comparison with SimPLe [11] and model-free PPO [19] trained with 100 K interactions.
The numbers are taken from [11]

Analysis
In our experiments we observe the same phenomenon as [11], namely, that the results can vary drastically for different runs with the same hyperparameters but different random seeds. A possible explanation for this is that the world model cannot infer the dynamics of regions of the environment's state space that it has never seen before, and that the algorithm is sensitive to the exploration-exploitation trade-off, as the number of interactions is relatively low.

Latent Representations
In Fig. 10 we show a visualization of the latent variables and their reconstructions. We assign random colors to the embedding indices and draw a colored square for each entry of the latent variable matrix. Note that a single square does not correspond to a fixed area in the frame. The VQ-VAE is able to encode most information into the latent representation. We have picked these two games to show that changes in the scene (Fig. 10a) or even switching scenes (Fig. 10b) can be represented, although this can cause some loss of details (e.g., see the left reconstruction in Fig. 10b).

(Fig. 9 shows the human normalized scores of our method (4 frames), averaged over five runs ± standard deviation. The x-axis shows the number of interactions with the real environment, but does not reflect the number of parameter updates that were performed in between. The dashed lines show the best of the five runs and the horizontal lines show the final human normalized score of SimPLe [11], since we only know their final scores.)

Generated Sequences
A straightforward method to evaluate the dynamics models is to observe the generated sequences. A video of some generated sequences can be found at https://github.com/jrobine/smaller-world-models. In this section we analyze the sequences in more detail. First, Fig. 11a shows frames of Freeway at 40-frame intervals. The world model has learned to predict the car movement in the correct direction (to the left at the top, and to the right at the bottom) and with varying speeds (but constant speed per car; the red [top] and yellow [bottom] highlighted cars move faster than the blue [middle] one). Second, Fig. 11b shows that the world model is able to predict the ball movement in Breakout. In the first frame the ball moves down to the left; in the second frame the ball has correctly bounced off the paddle and the wall, and moves up to the right; in the third frame the ball has removed a block and has bounced off, and the score has been increased. However, the life counter has been increased incorrectly in the second frame.
In Fig. 12a we see a consecutive sequence of frames of FishingDerby in which some fish appear and disappear (highlighted in red). We observe this phenomenon in several games that are complex and involve multiple objects. Moreover, Fig. 12b shows consecutive frames of Gopher in which a second farmer appears out of nowhere, which should be impossible, and the tunnels at the bottom change arbitrarily in the third frame.
These examples illustrate that the world model tries to predict everything that is visible, e.g., even scores or other elements of the user interface, no matter how helpful the information is for the agent. Furthermore, although the world model is not capable of modeling the dynamics of all games, we have observed that the predictions are relatively stable in the sense that the frames rarely degenerate completely (for example, fixed parts of the frames like the background are relatively stable), even well beyond the number of time steps that are used to train the dynamics models.

Latent Prediction Accuracy
On the one hand, acting on latent representations is computationally more efficient, since simulation can be performed purely in latent space without reconstructing observations. On the other hand, the representations can be less stable, since they change during training and do not need to converge, and after each world model training iteration the agent is confronted with new representations. In Fig. 13a we show the accuracy of the latent predictor p_φ(z_{t+1} | z_t, a_t, h_t) over the course of training for three Atari games. There is a drop in accuracy every time a new iteration starts, since new unseen data is fed into the network, especially in early iterations. As already mentioned, it has proved to be better to train the observation encoder and decoder and the dynamics models simultaneously, because otherwise we would see more drops similar to Fig. 13a. Notice that the accuracy stays approximately constant, so we might ask whether the system is learning anything. The reason is that Fig. 13a reports the accuracy on the data seen so far, so the accuracy on the right has a different meaning than the accuracy on the left.

Mean Episode Reward
We consider three learning curves from Fig. 9 that are typical for our world model. First, in Freeway the agent's performance successfully increases over the course of training. Second, in Gravitar the agent's performance decreases over the course of training. This can have various reasons, e.g., the agent reaches a new area of the state space and the environment's dynamics change drastically. Finally, in Hero the agent has comparably high performance from the beginning, which is most likely due to the inductive bias of the model.

Ablation Studies
To better understand which components of our model contribute to the overall performance, we conducted a series of ablation studies to answer the following questions (short answers in parentheses). We describe the results of the ablation studies in detail in the next sections.
(1) Can frame stacking improve the performance? (In a lot of environments feeding a single frame into the VQ-VAE model is sufficient.)
(2) What role does the discretization of the latent space play and how does it interact with the dimensionality of the latent space? (A two-dimensional discrete latent space yields the best results.)
(3) Is conditioning the policy on reconstructed frames more suitable than acting in the latent space? (The latent representation is sufficient in most environments.)
(4) Can the performance be improved by autoregressively sampling the latent variable? (No, and it furthermore slows down the inference time.)
(5) Does a smaller receptive field improve the performance, since it simplifies the task of the dynamics model? (No, but it allows for more interpretable visualizations.)

Q1: Can Frame Stacking Improve the Performance?
As described earlier, our default VQ-VAE works with four stacked frames as input and output. On the downside, this introduces complexity for the dynamics models, since the same frame can get a different representation depending on the three other frames in the stack. In this experiment we set the frame stacking value to one, i.e., we disable frame stacking.
The results indicate that this can lead to significant improvements, which is why we have expanded this experiment to all environments and have included it in the main results. In Fig. 13b we show the increase in accuracy of the latent predictor in Gopher when using no frame stacking. However, this increase is not necessarily present in all environments, even if the overall performance is improved, since there are other important factors, e.g., how useful the latent representations are for the policy. The training times stay roughly the same, since the only change is that the depth of the filters in the first convolutional layer increases.

Q2: What Role Does the Discretization of the Latent Space Play and How Does It Interact with the Dimensionality of the Latent Space?
In this experiment we have compared different models to evaluate a discrete vs. a continuous latent space and a two-dimensional vs. a one-dimensional latent space. This leads to four model configurations:
(1) Discrete, two-dimensional latent space: This is the default configuration, so no changes have been made.
(2) Discrete, one-dimensional latent space: We have changed the size of the latent representation from 6 × 6 to 36 × 1 by stacking the rows of the output of the observation encoder. This way we can still use the convolutional LSTM that takes embedding vectors as input. The architecture of the policy network only needs to be adjusted to the new input size.
(3) Continuous, two-dimensional latent space: We have changed the posterior distribution of the observation encoder to a product of independent normal distributions (instead of categorical distributions), and the latent space prior to a standard multivariate normal distribution. This corresponds to the usual setup for VAEs, but with a two-dimensional latent space. The output of the encoder is a 6 × 6 × 2 tensor with means in the first channel and standard deviations (actually, the logarithm of the variance) in the second channel. The latent representations are sampled from the posterior and no vector quantization is applied. Most parts of the dynamics model stay the same, but the input of the convolutional LSTM is a 6 × 6 × 1 tensor instead of the 6 × 6 tensor of embedding vectors, and the next latent head computes the parameters of independent normal distributions, analogous to the observation encoder. The policy network is adjusted to the one-channel input.
(4) Continuous, one-dimensional latent space: This is a combination of the previous two configurations, i.e., the observation encoder computes two vectors of length 36, one for the means and one for the standard deviations of 36 independent normal distributions. The latent space prior is a standard multivariate normal distribution. In the dynamics model we have used an LSTM instead of a convolutional LSTM, since the latent representations are real-valued vectors. For the same reason, the policy network uses linear layers instead of convolutional layers.
For the evaluation we have selected five environments in which our method performed best and no frame stacking has been used. The results are shown in Table 6. The main observations are that a discrete latent space tends to lead to better performance (compare the first two columns with the last two columns), as does a two-dimensional one (compare column one with column three, and column two with column four).

Q3: Is Conditioning the Policy on Reconstructed Frames More Suitable Than Acting in the Latent Space?
In this experiment we condition the policy on the reconstructed frames, i.e., the outputs of the VQ-VAE decoder, similar to the SimPLe world model by [11]. To improve the quality of the reconstructions, we apply no frame stacking. We adjust the policy network to match the 96 × 96 × 1 input tensors and add downsampling convolutional layers. We evaluate this setup on a randomly selected subset of the environments and achieve better results in three out of eight environments compared with the policy conditioned on latent representations, while the results in the remaining five environments are relatively good as well, as can be seen in Table 7. This means that in some environments the frames are a more suitable representation for the agent than the latent variables. The training time is increased quite significantly, due to the additional costs of the observation decoder and the policy network, from roughly 12 hours to 21 hours per run on an Nvidia A100 GPU.

Q4: Can the Performance be Improved by Autoregressively Sampling the Latent Variable?
Currently, the latent predictor samples the indices of the next latent variable independently. This might be disadvantageous because conditional dependencies between the indices, which correspond to certain areas of the video frames, are ignored. In this experiment the world model predicts the next latent variable autoregressively. For this purpose, we change the architecture of the dynamics network to use a conditional PixelCNN with three gated convolutional layers [24] that is conditioned on the output of the convolutional LSTM. We reduce the convolutional LSTM to one cell and also apply no frame stacking. We only have preliminary results for this experiment, shown in Table 8, since the training time of a single run increases from 12 hours to 70 hours on an Nvidia A100 GPU. Most of the time is spent on simulating experience, because this is where the (sequential) autoregressive sampling is required. We did not adjust the hyperparameters, so they are not fine-tuned for this architecture. We see a minor improvement in two of the four environments; however, the results are not convincing considering the additional training time.

Q5: Does a Smaller Receptive Field Improve the Performance?

In this experiment we narrow the receptive field of the latent representations by changing the kernel sizes, strides, and paddings of the convolutional layers in the observation encoder and decoder. As a consequence, each entry z_{i,j} of the latent representation matrix corresponds to a 16 × 16 image patch in the 96 × 96 frame without overlap. The results that we get on a small scale indicate that this does not improve the performance, but it allows for a more interpretable visualization of the latent space. In Fig. 14 we can observe that individual objects like the balls or the paddles are assigned a specific index in the latent representation. However, depending on its location inside an image patch, the same object may get assigned varying indices. When an object is located between two patches, both corresponding indices are changed.

Summary
Our ablation study shows that the two-dimensional discrete latent representation is essential for the good performance of our approach. We have also seen that the latent representation is more efficient than reconstructed frames.

Related Work

World Models [6]
Modeling environments with complex visual observations is a hard task. The complexity can be alleviated by encoding observations into low-dimensional representations and acting completely in this latent space. [6] train a variational autoencoder, which they call the "vision" component. It extracts information from the observation at the current time step. An LSTM combined with a mixture density network [4] predicts the latent variable at the next time step stochastically. They call it the "memory" component, which can accumulate information over multiple time steps. The policy is conditioned on latent variables instead of observations, which enables them to simulate experience in latent space, without decoding back into pixel space. This is more efficient and can reduce the effect of accumulating errors [6]. In Sect. 3.1 we describe in more detail how to integrate latent variables and recurrent states into the world model.
They successfully evaluate their architecture on two environments, but it involves some manual fine-tuning of the policy. The policy is optimized with an evolution strategy, which is not suitable for bigger networks. Additionally, their training procedure is non-iterative, i.e., they randomly collect real experience only once and then train the world model and the policy. This implies that no new experience can be obtained with the improved policies and that a random policy has to ensure sufficient exploration, which makes the approach inappropriate for more complex environments.

Simulated Policy Learning [11]
[11] successfully train a world model on a subset of games of the Arcade Learning Environment, while restricting each game to 400 K frames. First, they introduce an iterative training procedure (SimPLe) that alternates between collecting real experience, training the world model, and improving the policy with simulated experience from the world model. Second, a video prediction model similar to SV2P [2] predicts the next frame conditioned on the current frame, while the input action is incorporated in the decoding process. The latent variables are discretized into a bit vector and an LSTM predicts it autoregressively at inference time (similar to the PixelCNN of a VQ-VAE). Although a latent space is used, simulating new experience is relatively expensive, as decoding into pixel space is necessary at each time step. They train the policy with the model-free PPO algorithm [19], conditioned on the real or predicted frames, and get excellent results in a lot of Atari games considering the low number of interactions.

DreamerV1 [7] and DreamerV2 [8]
[7, 8] train a world model with meaningful latent representations, so that the agent can operate directly on the latent variables. Similar to [6], a variational autoencoder encodes observations into latent space and a recurrent neural network predicts the next time step. The authors cleverly split the state into a stochastic and a deterministic part. One of the main improvements of DreamerV2 over the world model of DreamerV1 is the discretization of the latent space, as it uses a vector of categorical variables instead of Gaussian variables. The probabilities of the categorical distributions are computed using the softmax function.
In contrast to this, we base our world model on a VQ-VAE [25] and thus discretize the latent space using vector quantization. This means that we use a dictionary of embedding vectors to discretize the feature representations via a nearest-neighbor look-up. These embedding vectors are also learned during training and can contain richer information than class scores. Furthermore, with our approach we can easily increase the number of discrete categories without increasing the complexity of the encoder and the decoder, since it only affects the number of embedding vectors. Another difference of our approach to DreamerV2 is that the latent space is a two-dimensional matrix instead of a one-dimensional vector, since we do not flatten the output of the encoder. Therefore, we employ a convolutional LSTM instead of a gated recurrent unit (GRU, [5]) to compute the recurrent states.
They show that their agent beats model-free algorithms in many Atari games after 50 M interactions. We attempt to learn as much as possible from only 100 K interactions. Unfortunately, DreamerV2 was developed concurrently with our work.

MuZero [18]
One of the first successful applications of decision-time planning to visually complex environments like Atari 2600 games has been accomplished by the MuZero algorithm. Without prior knowledge of the environment's dynamics, a trained world model looks ahead via Monte Carlo tree search. Their results are remarkable, but come at the cost of long training times and large models.

Conclusion
We have shown that it is possible to build a world model with far fewer parameters than previous model-based reinforcement learning approaches, while performing comparably. To achieve this, we employ a VQ-VAE with a two-dimensional discrete latent space, a convolutional LSTM as the dynamics network, and a convolutional neural network for parameterizing the policy. This setup is able to effectively encode the complex visual observations of the Atari environments into representations on which a policy can be learned. Furthermore, acting entirely in latent space speeds up training, since we avoid the computational costs of decoding high-dimensional frames. Detailed experiments including several ablation studies confirm our design choices.
In the future, we would like to predict the end of episodes, instead of terminating episodes after a fixed number of time steps (since the latter prevents the agent from learning beyond this time horizon). Similar to DreamerV2 [8], one could learn a discount factor that downweights the rewards instead of actually ending the episodes abruptly. Moreover, we train the encoder and the dynamics model separately, since this leads to more control over the individual components and makes hyperparameter tuning easier. However, the prediction of rewards could potentially be improved by incorporating it into the training of the encoder.
Another important question is how to improve exploration in order to achieve higher sample efficiency. Currently, we increase the randomness of our policy in Eq. 19 to encourage exploration. Other ideas could incorporate knowledge of the world model, e.g., by measuring the novelty of observations. [17] train a dynamics model and use the prediction error as an intrinsic reward. Such an approach would be a good match for our world model, since it is inherently able to compute the likelihood of the next observations. However, we leave this to future work.

Fig. 1
Fig. 1 Overview of our approach. A VQ-VAE encodes game frames into discrete latent representations, which are fed into the policy to determine the next actions. The dynamics model is conditioned on the latent representation and the action, and predicts the reward and the next latent representation using a convolutional LSTM with hidden states. The architecture of the individual components is described in detail in Sect. 3.2

Fig. 3
Fig. 3 Graphical models of the recurrent world model and the policy

Fig. 4
Fig. 4 Observation encoder network (top) and observation decoder network (bottom)

Fig. 12 Fig. 13
Fig. 12 Snapshots of bad sequences produced by our dynamics models

Table 1
Summary of the models:
Observation encoder: p_φ(z_t | s_t)
Recurrent model: p_φ(h_t | z_t, a_t, h_{t−1})
Latent predictor: p_φ(z_{t+1} | z_t, h_t, a_t)
Reward predictor: p_φ(r_t | z_t, h_t, a_t)
Terminal predictor: p_φ(d_t | z_t, h_t, a_t)
Policy: π_θ(a_t | z_t) or π_θ(a_t | z_t, h_{t−1})
Together, these define a recurrent world model p_φ(s_{t+1}, r_t, d_t, h_t | s_t, a_t, h_{t−1}) ≈ p(s_{t+1}, r_t, d_t | s_t, a_t) that potentially is better at dealing with partial observability.

Table 3
Final episode reward and human normalized score of our method compared with a random agent, the human baseline, and PPO at 10 M [11]
Fig. 8 Human normalized score (HNS) of our models and SimPLe [11] on a symmetrical logarithmic scale

Table 5
Summary of the hyperparameters

Table 6
Mean final episode reward of the model configurations

Table 7
Mean final episode reward of our models when the policy is conditioned on the reconstructed frames (without frame stacking) compared with the default models, averaged over five runs

Table 8
Mean final episode reward of our autoregressive world model (without frame stacking), averaged over three training runs, compared with the default models, which independently sample the indices of the next latent variable