[Bug Report] API test failing for nested observations
dm-ackerman opened this issue
Describe the bug
(Reported by someone else on Discord)
The API test fails on an env with a nested observation space (a dict of dicts).
An example env to reproduce it is below.
>>> env = example.env()
>>> api_test(env, num_cycles=1000, verbose_progress=False)
...
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/home/code/PettingZoo/pettingzoo/test/api_test.py", line 592, in api_test
    play_test(env, observation_0, num_cycles)
  File "/opt/home/code/PettingZoo/pettingzoo/test/api_test.py", line 479, in play_test
    == prev_observe["observation"].dtype
AttributeError: 'collections.OrderedDict' object has no attribute 'dtype'
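For context, the failure in the traceback above comes down to reading `.dtype` on a dict rather than on an array. The sketch below reproduces that without PettingZoo or Gymnasium. The `prev_observe` value is a hand-built stand-in (an assumption for illustration), not output from the real env.

```python
from collections import OrderedDict

# Hand-built stand-in for an observation whose "observation" entry is
# itself a dict (as sampled from a nested Dict space in the traceback).
prev_observe = {
    "observation": OrderedDict(
        nested_item1=[True, False],
        nested_item2=[1, 2],
    ),
    "action_mask": [0] * 25,
}

error = None
try:
    # Dicts have no dtype, so this attribute lookup raises.
    prev_observe["observation"].dtype
except AttributeError as err:
    error = str(err)

print(error)
```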
Code example
import functools

import numpy as np
from gymnasium import spaces

from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector


class env(AECEnv):
    def __init__(self, render_mode=None):
        self.possible_agents = [f"player_{i}" for i in range(2)]

    @functools.lru_cache(maxsize=None)
    def observation_space(self, agent):
        return spaces.Dict(
            {
                "observation": spaces.Dict(
                    {
                        "nested_item1": spaces.Box(
                            low=0,
                            high=1,
                            shape=(2,),
                            dtype=bool,
                        ),
                        "nested_item2": spaces.MultiDiscrete([3, 5], dtype=np.int8),
                    }
                ),
                "action_mask": spaces.MultiBinary((25,)),
            }
        )

    @functools.lru_cache(maxsize=None)
    def action_space(self, agent):
        return spaces.Discrete(25)

    def render(self):
        pass

    def observe(self, agent):
        return self.observation_space(agent).sample()

    def close(self):
        pass

    def reset(self, seed=None, options=None):
        self.agents = self.possible_agents[:]
        self.rewards = {agent: 0 for agent in self.agents}
        self._cumulative_rewards = {agent: 0 for agent in self.agents}
        self.terminations = {agent: False for agent in self.agents}
        self.truncations = {agent: False for agent in self.agents}
        self.infos = {agent: {} for agent in self.agents}
        self._agent_selector = agent_selector(self.agents)
        self.agent_selection = self._agent_selector.next()

    def step(self, action):
        return
System info
>>> import sys; sys.version
'3.9.12 (main, Apr 5 2022, 06:56:58) \n[GCC 7.5.0]'
>>> pettingzoo.__version__
'1.24.3'
Additional context
This may be partially related to #1111, which seemed to involve nested dicts, but that code isn't available any more, so I'm not sure. Regardless, I think this is distinct, as numpy isn't the problem.
Checklist
- I have checked that there is no similar issue in the repo
I have a simple fix for this that skips that assert when there are nested dicts, but I want to see if I have time to fix this correctly by recursing into the dicts.
Either way, I'll submit a PR to fix it today or tomorrow. I wanted to add an issue so that code to reproduce it is available.
I pushed a PR for this to recursively check compatibility in nested dicts. It will pass with the example above.
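As a rough illustration of what a recursive dtype check over nested dicts can look like, here is a minimal sketch. It is not the code from the PR: `check_nested_dtypes` is a hypothetical helper, and `expected` is a hand-built stand-in for a sample drawn from the observation space, so the sketch needs only numpy.

```python
import numpy as np


def check_nested_dtypes(expected, seen, path=()):
    """Recursively compare dtypes of two (possibly nested) observations.

    Hypothetical helper for illustration only. `expected` stands in for a
    sample from the observation space, `seen` for what the env returned.
    """
    where = "".join(f"[{key}]" for key in path)
    if isinstance(expected, dict):  # also covers OrderedDict
        assert isinstance(seen, dict), f"observation at {where} should be a dict"
        for key, sub_expected in expected.items():
            assert key in seen, f"observation at {where} is missing key '{key}'"
            # Descend one level, remembering the path for error messages.
            check_nested_dtypes(sub_expected, seen[key], path + (key,))
    else:
        # Leaf value: only at this level is it meaningful to compare dtypes.
        expected_dtype = np.asarray(expected).dtype
        seen_dtype = np.asarray(seen).dtype
        assert seen_dtype == expected_dtype, (
            f"dtype for observation at {where} is {seen_dtype}, "
            f"but observation space specifies {expected_dtype}."
        )


expected = {
    "observation": {
        "nested_item1": np.zeros(2, dtype=bool),
        "nested_item2": np.zeros(2, dtype=np.int8),
    },
    "action_mask": np.zeros(25, dtype=np.int8),
}

# Same structure and dtypes as `expected`: the check passes silently.
good = {
    "observation": {
        "nested_item1": np.ones(2, dtype=bool),
        "nested_item2": np.ones(2, dtype=np.int8),
    },
    "action_mask": np.ones(25, dtype=np.int8),
}
check_nested_dtypes(expected, good)

# One nested leaf with the wrong dtype (np.ones defaults to float64).
bad = {**good, "observation": {**good["observation"], "nested_item2": np.ones(2)}}
message = None
try:
    check_nested_dtypes(expected, bad)
except AssertionError as err:
    message = str(err)

print(message)
```

Carrying the key path through the recursion is what lets the assertion point at the exact nested entry that is wrong, rather than only reporting that the top-level observation is inconsistent.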
To trigger a failure of the test, use the example code above but replace the observe() function with the following:
def observe(self, agent):
    rval = self.observation_space(agent).sample()
    rval["observation"]["nested_item2"] = np.ones(2)
    return rval
that gives:
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/home/code/PettingZoo/pettingzoo/test/api_test.py", line 614, in api_test
    play_test(env, observation_0, num_cycles)
  File "/opt/home/code/PettingZoo/pettingzoo/test/api_test.py", line 499, in play_test
    _test_observation_space_compatibility(
  File "/opt/home/code/PettingZoo/pettingzoo/test/api_test.py", line 407, in _test_observation_space_compatibility
    _test_observation_space_compatibility(
  File "/opt/home/code/PettingZoo/pettingzoo/test/api_test.py", line 407, in _test_observation_space_compatibility
    _test_observation_space_compatibility(
  File "/opt/home/code/PettingZoo/pettingzoo/test/api_test.py", line 411, in _test_observation_space_compatibility
    assert (
AssertionError: dtype for observation at [observation][nested_item2] is float64, but observation space specifies int8.
Also, the revised test flagged a problem with an example env. I pushed a separate PR to fix that. That PR needs to be merged before the fix for this issue will pass the tests.