Issues with SceneBatch.to_agent
bmacadam-sfu opened this issue
bmacadam-sfu commented
Hi there,
I ran into the following errors in the `to_agent` method of `SceneBatch` while trying to visualize some scenes:
- The `index_neighbors` function does not preserve `StateArray`/`StateTensor` types. This can be fixed by checking whether a `StateTensor` was passed in and re-wrapping the result with its format:

```python
def index_neighbors(x: Tensor | StateTensor) -> Tensor | StateTensor:
    # Drop the selected agent, keeping the other num_agents-1 entries per scene
    neighbors = x[others_mask].reshape([batch_size, num_agents - 1] + list(x.shape[2:]))
    if isinstance(x, StateTensor):
        # Indexing/reshape drops the StateTensor type, so re-wrap with the input's format
        neighbors = StateTensor.from_array(neighbors, x._format)
    return neighbors
```
- The `index_agent` function doesn't handle map names correctly; this can be fixed by wrapping each map name in a singleton list:

```python
map_names=index_agent_list([[m] for m in self.map_names]),
```
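Why the wrapping can help is easy to sketch in plain Python, under an assumption about the helper: suppose the per-agent indexing helper expects each scene's entry to be a sequence and picks one element out of it. A Python `str` is itself a sequence, so a bare map name gets indexed character by character, while `[m]` makes the whole name the selected element. `pick_first_per_scene` and the map names below are hypothetical illustrations, not trajdata's actual `index_agent_list`, and the sketch assumes the selected index is 0.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def pick_first_per_scene(values: Sequence[Sequence[T]]) -> List[T]:
    """Hypothetical stand-in for a per-scene indexing helper.

    Assumes each entry of `values` is a per-scene sequence and keeps its
    element at index 0 (e.g. the selected agent).
    """
    return [scene_values[0] for scene_values in values]


map_names = ["boston-seaport", "singapore-onenorth"]  # example values only

# A bare str is itself a sequence, so indexing it returns a single character.
print(pick_first_per_scene(map_names))  # ['b', 's']

# Wrapping each name in a singleton list makes element 0 the whole name.
print(pick_first_per_scene([[m] for m in map_names]))  # ['boston-seaport', 'singapore-onenorth']
```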
bmacadam-sfu commented
Of course, for the `index_neighbors` problem you would presumably want to fix the underlying issue with `reshape` and `StateArray`/`StateTensor` instead of patching each call site.
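The difference between patching call sites and fixing the operation itself can be illustrated with a plain-Python analogue of the same kind of type loss: slicing a `list` subclass returns a plain `list`, dropping the subclass and any extra attributes. The toy classes below (`FormattedList`, `PreservingFormattedList`, and the `fmt` parameter) are hypothetical and are not trajdata classes; for PyTorch `Tensor` subclasses, how types propagate through operations such as `reshape` is governed by `__torch_function__`, so the real fix would depend on how trajdata implements `StateTensor`.

```python
from typing import Iterable


class FormattedList(list):
    """Toy analogue of a format-tagged container (hypothetical, not StateTensor)."""

    def __init__(self, items: Iterable = (), fmt: str = ""):
        super().__init__(items)
        self._format = fmt


class PreservingFormattedList(FormattedList):
    """Same container, but slicing re-wraps the result with the original format."""

    def __getitem__(self, key):
        result = super().__getitem__(key)
        if isinstance(key, slice):
            # list slicing returns a plain list; restore the subclass and its
            # format once, inside the operation, instead of at every call site
            return PreservingFormattedList(result, fmt=self._format)
        return result


plain = FormattedList([1, 2, 3], fmt="x,y,z")
print(type(plain[1:]).__name__)  # list: the subclass and its _format are lost

fixed = PreservingFormattedList([1, 2, 3], fmt="x,y,z")
sub = fixed[1:]
print(type(sub).__name__, sub._format)  # PreservingFormattedList x,y,z
```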