RobertTLange / gymnax

RL Environments in JAX 🌍

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Four Rooms (Sutton et al. 1999) environment

RobertTLange opened this issue · comments

Implement the classic Four Rooms environment. Start from the old NumPy implementation written for the HRL MSc thesis:

import numpy as np
import copy

# Action definitions: the four primitive moves, numbered 0..3 in
# counter-clockwise order (RIGHT -> UP -> LEFT -> DOWN).
RIGHT, UP, LEFT, DOWN = range(4)


class RoomWorld:
    """The four-rooms grid world used as environment for Sutton's semi-MDP HRL.

    The world is a 13x13 grid of walls, four rooms and four hallway cells.
    Agents are external objects registered through ``add_agent``; the
    environment relies on them providing ``move(direction)`` (returns the
    candidate position), ``set_position(pos)``, ``get_position()`` and an
    ``initial_position`` attribute. Each registered agent receives an integer
    identifier ``sebango`` starting at 2, so identifiers never clash with the
    0/1 cells of the walkability map or the -1 goal marker in the
    observation map.
    """

    # Identifier given to the first registered agent; agent ``k`` in
    # ``self.agents`` has ``sebango == k + _FIRST_SEBANGO``.
    _FIRST_SEBANGO = 2

    def __init__(self, goal_position=(7, 9), env_noise=0.1):
        """Build the maps, the state/action spaces and the reward settings.

        goal_position: (row, col) of the goal cell.
        env_noise:     probability that a step does NOT go in the requested
                       direction (it is then spread evenly over the other
                       three directions).

        Map of the rooms: -1 indicates wall, 0 indicates hallway,
        positive numbers indicate numbered rooms.
        """
        self.numbered_map = np.array([
        [-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1],
        [-1, 1, 1, 1, 1, 1,-1, 2, 2, 2, 2, 2,-1],
        [-1, 1, 1, 1, 1, 1,-1, 2, 2, 2, 2, 2,-1],
        [-1, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2, 2,-1],
        [-1, 1, 1, 1, 1, 1,-1, 2, 2, 2, 2, 2,-1],
        [-1, 1, 1, 1, 1, 1,-1, 2, 2, 2, 2, 2,-1],
        [-1,-1, 0,-1,-1,-1,-1, 2, 2, 2, 2, 2,-1],
        [-1, 3, 3, 3, 3, 3,-1,-1,-1, 0,-1,-1,-1],
        [-1, 3, 3, 3, 3, 3,-1, 4, 4, 4, 4, 4,-1],
        [-1, 3, 3, 3, 3, 3,-1, 4, 4, 4, 4, 4,-1],
        [-1, 3, 3, 3, 3, 3, 0, 4, 4, 4, 4, 4,-1],
        [-1, 3, 3, 3, 3, 3,-1, 4, 4, 4, 4, 4,-1],
        [-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1]])
        # Same layout reduced to 1 = walkable (room or hallway), 0 = wall.
        self.walkability_map = np.array([
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
        [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0],
        [0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
        # One (row, col) entry per walkable cell.
        self.state_space   = np.argwhere(self.walkability_map)
        self.action_space  = np.arange(4)
        self.goal_position = np.array(goal_position)
        self.action_success_rate = 1 - env_noise
        self.agents = [] # agents affect each other's observations, so should be included
        # Rewards
        # Sutton used a step reward of 0 and depended on the discounting
        # effect of gamma to push toward more efficient policies (-0.1 was
        # tried before).
        self.step_reward      = 0.0
        # Was -0.1 at first, but spending a timestep without moving is
        # already a penalty.
        self.collision_reward = 0.0
        self.goal_reward      = 1.   # previously 10.
        self.invalid_plan_reward = 0.0   # previously -10.

    def _get_agent(self, sebango):
        """Returns the registered agent carrying the identifier sebango.

        Raises IndexError for an unknown identifier. Identifiers below the
        first valid one are rejected explicitly, because the resulting
        negative list index would otherwise silently select another agent.
        """
        index = sebango - self._FIRST_SEBANGO
        if index < 0:
            raise IndexError("no agent with sebango %r" % (sebango,))
        return self.agents[index]

    def add_agent(self,agent):
        """Adds an agent to the environment after giving it an identifier
        """
        agent.sebango = len(self.agents) + self._FIRST_SEBANGO
        self.agents.append(agent)

    def move_agent(self,direction,sebango=2):
        """Attempts moving an agent in a specified direction.
           If the move would put the agent in a wall, the agent remains
           where he is. Returns True if such a collision happened, False
           if the agent moved. No reward is assigned here; see
           evaluate_reward.
        """
        agent   = self._get_agent(sebango)
        new_pos = agent.move(direction)
        if self.walkability_map[tuple(new_pos)].all():
            agent.set_position(new_pos)
            return False
        return True

    def evaluate_reward(self,sebango=2,collision=False):
        """Calculates the reward to be given for the current timestep after an
           action has been taken.
           Returns (reward, done); done is True once the agent stands on the
           goal position.
        """
        agent  = self._get_agent(sebango)
        reward = self.step_reward
        done   = False
        if collision:
            reward += self.collision_reward
        if (agent.get_position() == self.goal_position).all():
            reward += self.goal_reward
            done = True
        return reward, done

    def get_observation_map(self):
        """Returns the observation of the current state as a walkability map
           with agents (sebango) and goal position (-1) labeled
        """
        obs = copy.copy(self.walkability_map)
        for ag in self.agents:
            obs[tuple(ag.get_position())] = ag.sebango
        # Written last, so the goal marker wins over an agent standing on it.
        obs[tuple(self.goal_position)] = -1
        return obs

    def get_observation_pos(self,sebango):
        """Returns the observation of the current state as the position of the
           agent indicated by sebango.
           Assumes single agent and static goal location so only need agent pos
        """
        return self._get_agent(sebango).get_position()

    def step(self,direction,sebango=2):
        """Takes one timestep with a specific direction.
           Only deals with primitive actions.
           Determines the actual direction of motion stochastically: with
           probability action_success_rate the requested direction is used,
           otherwise one of the other three directions, each equally likely.
           Returns (obs, reward, done), where obs is the position of the
           agent indicated by sebango after the move.
        """
        roll = np.random.random()
        sr = self.action_success_rate
        fr = 1.0 - sr
        # Number of counter-clockwise quarter turns applied to the request.
        if roll <= sr:
            turns = 0
        elif roll <= sr+fr/3.:
            turns = 1
        elif roll <= sr+fr*2./3.:
            turns = 2
        else:
            turns = 3
        coll = self.move_agent((direction+turns)%4,sebango)
        # Observe the agent that actually acted, not always the first one.
        obs = self.get_observation_pos(sebango)
        reward, done = self.evaluate_reward(sebango, collision=coll)
        return obs, reward, done

    def reset(self, random_placement=False):
        """Resets the state of the world, putting all registered  agents back
           to their initial positions (positions set at instantiation),
           unless random_placement = True, in which case every agent is put
           on a uniformly drawn walkable cell.
           Returns the position of the first registered agent.
        """
        if random_placement:
            random_index = np.random.randint(low=0,
                    high=self.state_space.shape[0],size=len(self.agents))
            for ag, idx in zip(self.agents, random_index):
                # Copy so an agent can never modify state_space in place.
                ag.set_position(self.state_space[idx].copy())
        else:
            for ag in self.agents:
                ag.set_position(ag.initial_position)
        # CURRENTLY ASSUMING ONE AGENT!
        obs = self.get_observation_pos(self._FIRST_SEBANGO)
        return obs