AZ much worse than generic solution for simple game
70Gage70 opened this issue
I'm trying to train AZ on single-player 21. You have a shuffled deck of cards, and at each step you either "take" a card (adding its value to your total, where Ace = 1, 2 = 2, ..., face cards = 10) or "stop" and receive your current total as the reward; going over 21 ends the game with reward 0. The obvious strategy is to take if the expected value of a draw would leave your total <= 21, and stop otherwise. This gives an average reward of roughly 14. I defined the game and used the exact training parameters from gridworld.jl, and this is the result:
I don't understand (i) why the rewards are much lower than 14 and (ii) why the full AZ agent is worse than the raw network.
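For context, here is the threshold that greedy rule implies under a full-deck approximation of the expected draw (a quick sketch; the ~14 average itself comes from the simulation in the Canonical strategy section below):

```julia
# Average value of a card in the full 52-card deck:
# 4*(1+2+...+10) + 12*10 = 340 points over 52 cards ≈ 6.54.
mean_card = (4 * sum(1:10) + 12 * 10) / 52   # ≈ 6.54
# Greedy rule: keep drawing while total + mean_card <= 21, i.e. while total <= 14.
draw_threshold = 21 - mean_card              # ≈ 14.46
```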
Game
using AlphaZero
using CommonRLInterface
const RL = CommonRLInterface
import Random as RNG
# using StaticArrays
# using Crayons
const NONFACE_CARDS = [i for j = 1:4 for i = 1:10]
const FACE_CARDS = [10 for j = 1:4 for i = 1:3]
const STANDARD_DECK = map(UInt8, vcat(NONFACE_CARDS, FACE_CARDS))
### MANDATORY INTERFACE
# state = "what the player should look at"
mutable struct Env21 <: AbstractEnv
    deck::Vector{UInt8}
    state::UInt8 # points
    reward::UInt8
    terminated::Bool
end
function RL.reset!(env::Env21)
    env.deck = RNG.shuffle(STANDARD_DECK)
    env.state = 0
    env.reward = 0
    env.terminated = false
    return nothing
end

function Env21()
    deck = RNG.shuffle(STANDARD_DECK)
    state = 0
    reward = 0
    terminated = false
    return Env21(deck, state, reward, terminated)
end
RL.actions(env::Env21) = [:take, :stop]
RL.observe(env::Env21) = env.state
RL.terminated(env::Env21) = env.terminated
function RL.act!(env::Env21, action)
    if action == :take
        draw = popfirst!(env.deck)
        env.state += draw
        if env.state >= 22
            env.reward = 0
            env.state = 0 ######################### okay?
            env.terminated = true
        end
    elseif action == :stop
        env.reward = env.state
        env.terminated = true
    else
        error("Invalid action $action")
    end
    return env.reward
end
### TESTING
# env = Env21()
# reset!(env)
# rsum = 0.0
# while !terminated(env)
# global rsum += act!(env, rand(actions(env)))
# end
# @show rsum
### MULTIPLAYER INTERFACE
RL.players(env::Env21) = [1]
RL.player(env::Env21) = 1
### Optional Interface
RL.observations(env::Env21) = map(UInt8, collect(0:21))
RL.clone(env::Env21) = Env21(copy(env.deck), copy(env.state), copy(env.reward), copy(env.terminated))
RL.state(env::Env21) = env.state
RL.setstate!(env::Env21, new_state) = (env.state = new_state)
RL.valid_action_mask(env::Env21) = BitVector([1, 1])
### AlphaZero Interface
function GI.render(env::Env21)
    println(env.deck)
    println(env.state)
    println(env.reward)
    println(env.terminated)
    return nothing
end

function GI.vectorize_state(env::Env21, state)
    v = zeros(Float32, 22)
    v[state + 1] = 1
    return v
end

const action_names = ["take", "stop"]

function GI.action_string(env::Env21, a)
    idx = findfirst(==(a), RL.actions(env))
    return isnothing(idx) ? "?" : action_names[idx]
end

function GI.parse_action(env::Env21, s)
    idx = findfirst(==(s), action_names)
    return isnothing(idx) ? nothing : RL.actions(env)[idx]
end

function GI.read_state(env::Env21)
    return env.state
end
GI.heuristic_value(::Env21) = 0.
GameSpec() = CommonRLInterfaceWrapper.Spec(Env21())
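For reference, the commented TESTING block above extends naturally into a quick random-policy baseline (a small sketch using the same exported CommonRLInterface functions; the function name is just illustrative). It gives a lower bound to compare the AZ benchmark numbers against:

```julia
# Average reward of a uniformly random policy over many episodes.
function random_baseline(n_episodes)
    env = Env21()
    total = 0.0
    for _ in 1:n_episodes
        reset!(env)
        r = 0.0
        while !terminated(env)
            r = act!(env, rand(actions(env)))
        end
        total += float(r)
    end
    return total / n_episodes
end

# random_baseline(10_000)
```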
Canonical strategy
import Random as RNG
const NONFACE_CARDS = [i for j = 1:4 for i = 1:10]
const FACE_CARDS = [10 for j = 1:4 for i = 1:3]
const STANDARD_DECK = map(UInt8, vcat(NONFACE_CARDS, FACE_CARDS))
function mc_run()
    deck = RNG.shuffle(STANDARD_DECK)
    score = 0
    while true
        # Expected value of the next draw, given the cards remaining in the deck.
        expected_score = score + sum(deck) / length(deck)
        if expected_score >= 22
            return score
        else
            score = score + popfirst!(deck)
            if score >= 22
                return 0
            end
        end
    end
end

function mc(n_trials)
    score = 0
    for i = 1:n_trials
        score = score + mc_run()
    end
    return score / n_trials
end
mc(10000)
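As a cross-check that the environment itself is wired correctly, the same greedy rule can be played through the Env21 interface defined above (a small sketch; like mc_run it reads the remaining deck directly to estimate the expected draw, and the function names are just illustrative). Its average should come out close to mc(10000):

```julia
# Play one episode greedily through the Env21 interface:
# draw while the expected value of the next card keeps the total below 22.
function greedy_env_run(env::Env21)
    RL.reset!(env)
    r = 0.0
    while !RL.terminated(env)
        ev = sum(env.deck) / length(env.deck)   # expected value of the next draw
        a = env.state + ev < 22 ? :take : :stop
        r = RL.act!(env, a)
    end
    return float(r)
end

# Average reward over many episodes; should be close to mc(10000).
function greedy_env_benchmark(n_trials)
    env = Env21()
    return sum(greedy_env_run(env) for _ in 1:n_trials) / n_trials
end

# greedy_env_benchmark(10_000)
```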
I don't have time to look too deeply, but here are a few remarks:
- AZ not learning a good policy with default hyperparameters is not necessarily a red flag in itself, even for simple games. AZ cannot be used as a black box in general, and tuning is important.
- The MCTS policy being worse than the raw network policy is more surprising. Admittedly, I haven't tested AZ.jl on many stochastic environments, so there may be subtleties and rough edges here.
- In particular, there are many ways to handle stochasticity in MCTS, with different tradeoffs. The current implementation is an open-loop MCTS: the tree is keyed by action sequences rather than by sampled states, so returns from different chance outcomes get averaged within a single node. If I remember correctly, this copes reasonably well with light stochasticity but can struggle with highly stochastic environments.
My advice:
- Try benchmarking pure MCTS (with rollouts); a standalone sketch is given after this list. If it does terribly even with a lot of search, then there may be a bug in the MCTS implementation, or AZ's MCTS implementation may not be suited to your game.
- Do not hesitate to make the environment smaller in your tests, for example by having very small decks.
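To make the first suggestion concrete, below is a minimal standalone UCT sketch with random rollouts. This is not AlphaZero.jl's MCTS module; it only reuses the Env21 / RL / STANDARD_DECK definitions from the game code above, the helper names are illustrative, and `argmax(f, xs)` assumes Julia ≥ 1.7:

```julia
import Random as RNG

# Per-(total, action) statistics: visit count and accumulated return.
mutable struct Stats
    n::Int
    w::Float64
end

# Random playout to the end of the game; returns the final reward.
function rollout(env)
    r = 0.0
    while !RL.terminated(env)
        r = RL.act!(env, rand(RL.actions(env)))
    end
    return float(r)
end

# Run `nsims` UCT simulations from `env` and return the most-visited root action.
# `c` is the exploration constant, scaled roughly to the 0-21 reward range.
function uct_action(env, nsims; c = 21.0)
    stats = Dict{Tuple{UInt8, Symbol}, Stats}()
    root = RL.observe(env)
    for _ in 1:nsims
        sim = RL.clone(env)
        sim.deck = RNG.shuffle(sim.deck)   # determinize: hide the true draw order
        path = Tuple{UInt8, Symbol}[]
        r = 0.0
        while !RL.terminated(sim)
            s = RL.observe(sim)
            acts = RL.actions(sim)
            unexplored = [a for a in acts if !haskey(stats, (s, a))]
            if !isempty(unexplored)
                # Expand one new node, then finish the episode with a rollout.
                a = rand(unexplored)
                push!(path, (s, a))
                r = float(RL.act!(sim, a))
                RL.terminated(sim) || (r = rollout(sim))
                break
            end
            # UCB1 selection among fully expanded actions.
            ntot = sum(stats[(s, a)].n for a in acts)
            a = argmax(a -> stats[(s, a)].w / stats[(s, a)].n +
                            c * sqrt(log(ntot) / stats[(s, a)].n), acts)
            push!(path, (s, a))
            r = float(RL.act!(sim, a))
        end
        for key in path                    # backpropagate the episode return
            st = get!(stats, key, Stats(0, 0.0))
            st.n += 1
            st.w += r
        end
    end
    return argmax(a -> haskey(stats, (root, a)) ? stats[(root, a)].n : -1,
                  RL.actions(env))
end

# Average reward of the pure-MCTS player over `n_games` episodes.
function benchmark_mcts(n_games; nsims = 200)
    env = Env21()
    total = 0.0
    for _ in 1:n_games
        RL.reset!(env)
        r = 0.0
        while !RL.terminated(env)
            r = RL.act!(env, uct_action(env, nsims))
        end
        total += float(r)
    end
    return total / n_games
end

# benchmark_mcts(1000)   # ideally close to the ~14 greedy baseline if search is healthy
```

Swapping STANDARD_DECK for a much smaller deck, per the second suggestion, makes this kind of check fast and also gives AZ an easier target.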
Thanks for the tips, appreciate it. I'll certainly take a look at the MCTS.