NVlabs / trajdata

A unified interface to many trajectory forecasting datasets.

Issues with SceneBatch.to_agent

bmacadam-sfu opened this issue

Hi there,

I ran into the following errors in the to_agent method of SceneBatch while trying to visualize some scenes:

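  1. The index_neighbors function does not preserve StateArrays/StateTensors: the indexed and reshaped result no longer carries the state format. This can be fixed by checking whether a StateTensor was passed in and re-wrapping the result with its format:
        def index_neighbors(x: Tensor | StateTensor) -> Tensor | StateTensor:
            # others_mask, batch_size and num_agents come from the enclosing to_agent scope.
            # Drop each batch element's own agent and restore the [batch_size, num_agents - 1, ...] layout.
            neighbors = x[others_mask].reshape([batch_size, num_agents - 1] + list(x.shape[2:]))
            # Indexing + reshape lose the StateTensor format, so re-wrap StateTensor inputs.
            if isinstance(x, StateTensor):
                neighbors = StateTensor.from_array(neighbors, x._format)
            return neighbors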
  2. The index_agent function doesn't handle map names correctly; this can be fixed by wrapping each map name in a singleton list:
            map_names=index_agent_list([[m] for m in self.map_names]),
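The masking-and-reshape logic behind index_neighbors, together with the proposed re-wrap, can be illustrated outside trajdata. This is a minimal NumPy sketch under stated assumptions: `FormattedArray` is a hypothetical stand-in for a StateArray that carries a `_format` string but, like the behaviour reported above, does not copy it to derived arrays; the `agent_idx` parameter and the way `others_mask` is built here are also assumptions, not trajdata's code.

```python
import numpy as np


class FormattedArray(np.ndarray):
    """Hypothetical stand-in for a StateArray: an ndarray tagged with a state
    format string such as "x,y". It deliberately does NOT copy the tag to new
    arrays, so indexing/reshaping yields instances with the default "" format."""

    _format: str = ""

    @classmethod
    def from_array(cls, array: np.ndarray, fmt: str) -> "FormattedArray":
        out = np.asarray(array).view(cls)
        out._format = fmt
        return out


def index_neighbors(x: np.ndarray, agent_idx: np.ndarray) -> np.ndarray:
    """Drop agent agent_idx[b] from batch element b, keeping the other
    num_agents - 1 agents in their original order."""
    batch_size, num_agents = x.shape[:2]
    # others_mask[b, a] is False only for the selected agent of element b
    # (assumed construction; not taken from trajdata).
    others_mask = np.ones((batch_size, num_agents), dtype=bool)
    others_mask[np.arange(batch_size), agent_idx] = False
    # Boolean indexing flattens the first two dims; reshape restores them.
    neighbors = x[others_mask].reshape([batch_size, num_agents - 1] + list(x.shape[2:]))
    # Re-wrap formatted inputs so the state format survives, as in the fix.
    if isinstance(x, FormattedArray):
        neighbors = FormattedArray.from_array(neighbors, x._format)
    return neighbors


states = FormattedArray.from_array(np.arange(12, dtype=float).reshape(2, 3, 2), "x,y")
neighbors = index_neighbors(states, agent_idx=np.array([0, 2]))
print(neighbors.shape, neighbors._format)  # (2, 2, 2) x,y
```

The helper stays agnostic of the array type and only restores the format when the input carried one, so plain arrays pass through unchanged.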

Of course, for the index_neighbors problem, you would presumably want to fix the underlying issue instead, namely that reshape does not preserve StateArrays/StateTensors.
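For the underlying fix, one option is to have the array type itself carry its format through reshape and indexing, so helpers like index_neighbors need no re-wrapping. The sketch below shows that idea with NumPy's `__array_finalize__` subclassing hook on a hypothetical `FormatPreservingArray` class; it is not trajdata's implementation. StateTensor is a torch.Tensor subclass, where the corresponding mechanism would presumably be `__torch_function__`, so treat this only as an illustration of the approach.

```python
import numpy as np


class FormatPreservingArray(np.ndarray):
    """Hypothetical ndarray subclass tagged with a state format string that
    propagates the tag to every array derived from it."""

    def __new__(cls, array, fmt: str = ""):
        obj = np.asarray(array).view(cls)
        obj._format = fmt
        return obj

    def __array_finalize__(self, obj):
        # NumPy calls this for views, slices, reshapes and indexing results;
        # obj is the array the new one was derived from (None only for
        # explicit construction). Inherit its format instead of dropping it.
        self._format = getattr(obj, "_format", "")


states = FormatPreservingArray(np.zeros((2, 3, 2)), "x,y")
others_mask = np.ones((2, 3), dtype=bool)
others_mask[:, 0] = False  # drop agent 0 in every batch element
neighbors = states[others_mask].reshape(2, 2, 2)
print(type(neighbors).__name__, neighbors._format)  # FormatPreservingArray x,y
```

The trade-off is that every derived array inherits the format, including ones whose last dimension no longer matches it (e.g. after slicing out a column), so a real implementation may need to validate or drop the format in those cases.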