tingnapianhai / RL-FlappyBird

Using reinforcement learning to train FlappyBird.

RL Flappy Bird


This project is a basic application of Reinforcement Learning.

It uses the Deep Java Library (DJL) to train the agent with DQN (Deep Q-Network). The pretrained model was trained for 3M steps on a single GPU.

Build the project and run

This project builds with Maven. Use the following command to build it:

mvn compile  

The following command starts training without graphics:

mvn exec:java -Dexec.mainClass="com.kingyu.rlbird.ai.TrainBird"

The above command trains from scratch. You can also train starting from the pretrained weights:

mvn exec:java -Dexec.mainClass="com.kingyu.rlbird.ai.TrainBird" -Dexec.args="-p"

To test the pretrained model directly, run:

mvn exec:java -Dexec.mainClass="com.kingyu.rlbird.ai.TrainBird" -Dexec.args="-p -t"

Argument  Description
-g        Train with graphics.
-b        Batch size to use for training.
-p        Use pretrained weights.
-t        Test the trained model.
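The flags above are passed to TrainBird through `-Dexec.args`. As a hedged illustration of how such flags might be read, here is a minimal Java sketch. The class `TrainArgs`, its fields, the assumption that `-b` is followed by an integer, and the default batch size of 32 are all hypothetical and not taken from the project's source.

```java
/**
 * Hypothetical sketch of parsing the TrainBird flags listed above.
 * Class and field names are illustrative; this is not the project's own parser.
 */
class TrainArgs {
    boolean graphics = false;   // -g: train with graphics
    int batchSize = 32;         // -b <n>: batch size (value format and default are assumptions)
    boolean pretrained = false; // -p: use pretrained weights
    boolean test = false;       // -t: test the trained model

    static TrainArgs parse(String[] args) {
        TrainArgs parsed = new TrainArgs();
        for (int i = 0; i < args.length; i++) {
            switch (args[i]) {
                case "-g":
                    parsed.graphics = true;
                    break;
                case "-p":
                    parsed.pretrained = true;
                    break;
                case "-t":
                    parsed.test = true;
                    break;
                case "-b":
                    if (i + 1 >= args.length) {
                        throw new IllegalArgumentException("-b requires a batch size");
                    }
                    parsed.batchSize = Integer.parseInt(args[++i]); // consume the value too
                    break;
                default:
                    throw new IllegalArgumentException("Unknown argument: " + args[i]);
            }
        }
        return parsed;
    }

    public static void main(String[] args) {
        // Same flags as -Dexec.args="-p -t -b 64"
        TrainArgs parsed = parse(new String[] {"-p", "-t", "-b", "64"});
        // prints "true true 64 false"
        System.out.println(parsed.pretrained + " " + parsed.test + " " + parsed.batchSize + " " + parsed.graphics);
    }
}
```

A real project would more likely rely on a command-line parsing library; the hand-rolled loop only keeps the sketch dependency-free.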

Deep Q-Network Algorithm

The pseudocode for the Deep Q-Learning algorithm, as given in "Human-level Control through Deep Reinforcement Learning" (Nature, 2015), is shown below:

Initialize replay memory D to capacity N
Initialize action-value function Q with random weights
for episode = 1, M do
    Initialize state s_1
    for t = 1, T do
        With probability ϵ select a random action a_t
        otherwise select a_t = argmax_a Q(s_t, a; θ_i)
        Execute action a_t in emulator and observe reward r_t and next state s_(t+1)
        Store transition (s_t, a_t, r_t, s_(t+1)) in D
        Sample a random minibatch of transitions (s_j, a_j, r_j, s_(j+1)) from D
        Set y_j = r_j                                    if s_(j+1) is terminal
                  r_j + γ * max_a' Q(s_(j+1), a'; θ_i)   otherwise
        Perform a gradient step on (y_j - Q(s_j, a_j; θ_i))^2 with respect to θ
    end for
end for
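The pseudocode leaves its bookkeeping implicit. The following is a minimal, framework-free Java sketch of three of its pieces: the replay memory D with capacity N, ϵ-greedy action selection, and the target y_j. It assumes the Q-values are computed elsewhere and passed in as plain arrays; it does not use DJL and is not the project's implementation. The class `DqnSketch` and all of its members are hypothetical names chosen for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** Hypothetical sketch of DQN bookkeeping; the Q-network itself is omitted. */
class DqnSketch {

    /** One stored transition (s_t, a_t, r_t, s_(t+1)) plus a terminal flag. */
    static final class Transition {
        final double[] state;
        final int action;
        final double reward;
        final double[] nextState;
        final boolean terminal;

        Transition(double[] state, int action, double reward, double[] nextState, boolean terminal) {
            this.state = state;
            this.action = action;
            this.reward = reward;
            this.nextState = nextState;
            this.terminal = terminal;
        }
    }

    /** Replay memory D with fixed capacity N; the oldest entry is overwritten when full. */
    static final class ReplayMemory {
        private final Transition[] buffer;
        private int next = 0;
        private int size = 0;

        ReplayMemory(int capacity) {
            buffer = new Transition[capacity];
        }

        void store(Transition t) {
            buffer[next] = t;
            next = (next + 1) % buffer.length; // ring-buffer write position
            size = Math.min(size + 1, buffer.length);
        }

        int size() {
            return size;
        }

        /** Samples a minibatch uniformly at random (with replacement, for simplicity). */
        List<Transition> sample(int batchSize, Random rng) {
            if (size == 0) {
                throw new IllegalStateException("replay memory is empty");
            }
            List<Transition> batch = new ArrayList<>(batchSize);
            for (int i = 0; i < batchSize; i++) {
                batch.add(buffer[rng.nextInt(size)]);
            }
            return batch;
        }
    }

    /** With probability epsilon pick a random action, otherwise argmax_a Q(s, a). */
    static int selectAction(double[] qValues, double epsilon, Random rng) {
        if (rng.nextDouble() < epsilon) {
            return rng.nextInt(qValues.length);
        }
        return argmax(qValues);
    }

    /** Index of the largest value; ties resolve to the lowest index. */
    static int argmax(double[] values) {
        int best = 0;
        for (int a = 1; a < values.length; a++) {
            if (values[a] > values[best]) {
                best = a;
            }
        }
        return best;
    }

    /** y_j = r_j if s_(j+1) is terminal, else r_j + gamma * max_a' Q(s_(j+1), a'). */
    static double target(double reward, boolean terminal, double[] nextQValues, double gamma) {
        if (terminal) {
            return reward;
        }
        return reward + gamma * nextQValues[argmax(nextQValues)];
    }
}
```

In a full training loop, the targets from `target(...)` for a sampled minibatch would be compared with Q(s_j, a_j) to form the squared loss in the gradient step; that part depends on the network library and is left out of this sketch.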


Trained Model

  • Training a bird to a perfect state may take 10+ hours. A model trained for three million steps is provided in the project resource folder: src/main/resources/model/dqn-trained-0000-params


This work is based on the following repos:


MIT © Kingyu Luk



License: MIT License

Language: Java 100.0%