Farama-Foundation / PettingZoo

An API standard for multi-agent reinforcement learning environments, with popular reference environments and related utilities

Home Page: https://pettingzoo.farama.org

[Bug Report] API test failing for nested observations

dm-ackerman opened this issue

Describe the bug

(Reported by someone else on Discord.)

The API test fails on an env whose observation is nested (a dict of dicts).
An example env that reproduces the failure is below.

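>>> from pettingzoo.test import api_test
>>> import example  # module containing the code example below
>>> env = example.env()
>>> api_test(env, num_cycles=1000, verbose_progress=False)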
...
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/home/code/PettingZoo/pettingzoo/test/api_test.py", line 592, in api_test
    play_test(env, observation_0, num_cycles)
  File "/opt/home/code/PettingZoo/pettingzoo/test/api_test.py", line 479, in play_test
    == prev_observe["observation"].dtype
AttributeError: 'collections.OrderedDict' object has no attribute 'dtype'
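
Judging from the traceback, the test reads `.dtype` directly from the "observation" entry. That works when the entry is an array but not when it is itself a dict. A minimal illustration using only NumPy and plain dicts (the variable names `flat` and `nested` are hypothetical, and this is not the test's actual code):

```python
from collections import OrderedDict

import numpy as np

# Flat case: the "observation" entry is an array, so .dtype exists.
flat = {"observation": np.zeros(2, dtype=np.int8)}

# Nested case: the "observation" entry is a mapping of arrays.
nested = {"observation": OrderedDict(nested_item2=np.zeros(2, dtype=np.int8))}

print(flat["observation"].dtype)
print(hasattr(nested["observation"], "dtype"))
```

In the nested case, any dtype comparison has to descend into the mapping before it reaches an array.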

Code example

import functools

import numpy as np
from gymnasium import spaces

from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector


class env(AECEnv):
    def __init__(self, render_mode=None):
        self.possible_agents = [f"player_{i}" for i in range(2)]

    @functools.lru_cache(maxsize=None)
    def observation_space(self, agent):
        return spaces.Dict(
            {
                "observation": spaces.Dict(
                    {
                        "nested_item1": spaces.Box(
                            low=0,
                            high=1,
                            shape=(2,),
                            dtype=bool,
                        ),
                        "nested_item2": spaces.MultiDiscrete([3, 5], dtype=np.int8),
                    }
                ),
                "action_mask": spaces.MultiBinary((25,)),
            }
        )

    @functools.lru_cache(maxsize=None)
    def action_space(self, agent):
        return spaces.Discrete(25)

    def render(self):
        pass

    def observe(self, agent):
        return self.observation_space(agent).sample()

    def close(self):
        pass

    def reset(self, seed=None, options=None):
        self.agents = self.possible_agents[:]
        self.rewards = {agent: 0 for agent in self.agents}
        self._cumulative_rewards = {agent: 0 for agent in self.agents}
        self.terminations = {agent: False for agent in self.agents}
        self.truncations = {agent: False for agent in self.agents}
        self.infos = {agent: {} for agent in self.agents}
        self._agent_selector = agent_selector(self.agents)
        self.agent_selection = self._agent_selector.next()

    def step(self, action):
        return

System info

>>> import sys; sys.version
'3.9.12 (main, Apr  5 2022, 06:56:58) \n[GCC 7.5.0]'
>>> import pettingzoo; pettingzoo.__version__
'1.24.3'

Additional context

This may be partially related to #1111, which seemed to involve nested dicts, but the code from that issue is no longer available, so I can't be sure. Regardless, I think this is a distinct problem, since numpy isn't the cause here.

Checklist

  • I have checked that there is no similar issue in the repo

I have a simple fix for this that skips that assert when there are nested dicts, but I want to see whether I have time to fix it properly by recursing into the dicts.
Either way, I'll submit a PR to fix it today or tomorrow. I wanted to open an issue so that the code to reproduce it is available.

I pushed a PR for this that recursively checks compatibility in nested dicts. With that change, the example above passes the test.

To make the test fail, use the example code above but replace the observe() function with the following:

    def observe(self, agent):
        rval = self.observation_space(agent).sample()
        # np.ones(2) is float64, which does not match the int8 dtype of the space
        rval["observation"]["nested_item2"] = np.ones(2)
        return rval

That gives:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/home/code/PettingZoo/pettingzoo/test/api_test.py", line 614, in api_test
    play_test(env, observation_0, num_cycles)
  File "/opt/home/code/PettingZoo/pettingzoo/test/api_test.py", line 499, in play_test
    _test_observation_space_compatibility(
  File "/opt/home/code/PettingZoo/pettingzoo/test/api_test.py", line 407, in _test_observation_space_compatibility
    _test_observation_space_compatibility(
  File "/opt/home/code/PettingZoo/pettingzoo/test/api_test.py", line 407, in _test_observation_space_compatibility
    _test_observation_space_compatibility(
  File "/opt/home/code/PettingZoo/pettingzoo/test/api_test.py", line 411, in _test_observation_space_compatibility
    assert (
AssertionError: dtype for observation at [observation][nested_item2] is float64, but observation space specifies int8.

Also, the revised test flagged a problem with an example env. I pushed a separate PR to fix that; it needs to be merged before the fix for this issue can pass the tests.