NVlabs / trajdata

A unified interface to many trajectory forecasting datasets.



Plot specific scene from dataset

Leon0402 opened this issue

For debugging purposes, I wanted to plot each scene together with its name, so that I know which plot corresponds to which data in the original raw dataset.

I tried using the example here: https://github.com/nvr-avg/trajdata/blob/main/examples/scene_batch_example.py

However, I saw no way to get scene information (such as the name) from the SceneBatch, so I ended up writing a custom implementation:

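  # Assumes `dataset` (a trajdata UnifiedDataset), `output_path` (a pathlib.Path)
  # and a user-defined `plot_scene_batch` helper are defined elsewhere.
  scene_name = "json_data_2_aug_0"
  scene_ts = 1

  for scene in dataset.scenes():
      # Skip every scene except the one we want to plot
      if scene.name != scene_name:
          continue

      # Build the scene cache directly (an internal class) to query agent data
      scene_cache = dataset.cache_class(dataset.cache_path, scene, 0, dataset.augmentations)

      # Agents present at the chosen timestep: their history, current state and future
      present_agents = scene.agent_presence[scene_ts]
      agent_histories, _, _ = scene_cache.get_agents_history(scene_ts, present_agents, history_sec=(None, None))
      agent_states = [scene_cache.get_state(agent.name, scene_ts) for agent in present_agents]
      agent_futures, _, _ = scene_cache.get_agents_future(scene_ts, present_agents, future_sec=(None, None))

      plot_scene_batch(agent_histories, agent_states, agent_futures, output_path / f"{scene.name}_ts_{scene_ts}.png")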

I suspect I shouldn't use the cache class directly, since it's an implementation detail. Is there another way to get scene information when using the standard DataLoader mechanism?

@BorisIvanovic Friendly ping :)

Hi @Leon0402! Have you tried the latest version? It includes the scene IDs in the SceneBatch object: https://github.com/NVlabs/trajdata/blob/main/src/trajdata/data_structures/batch.py#L141
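With scene IDs carried in the batch, the plot filenames can be derived from the batch itself rather than from the scene cache. A minimal sketch, assuming the batch exposes parallel `scene_ids` and `scene_ts` fields with one entry per batch element; `plot_filenames` is a hypothetical helper, not part of trajdata:

```python
from typing import List, Sequence


def plot_filenames(scene_ids: Sequence[str], scene_ts: Sequence[int]) -> List[str]:
    """Build one output filename per batch element from its scene ID and timestep.

    Hypothetical helper: assumes `scene_ids` and `scene_ts` are parallel
    sequences of equal length (one entry per batch element).
    """
    if len(scene_ids) != len(scene_ts):
        raise ValueError("scene_ids and scene_ts must have the same length")
    # int() also accepts 0-d tensor elements, e.g. items of a 1-D torch tensor
    return [f"{sid}_ts_{int(ts)}.png" for sid, ts in zip(scene_ids, scene_ts)]


# Sketch of use in a DataLoader loop (assumes `dataloader` yields SceneBatch
# objects with `scene_ids` populated, plus the user's own plotting helper):
# for batch in dataloader:
#     for name in plot_filenames(batch.scene_ids, batch.scene_ts):
#         ...  # plot the matching batch element to output_path / name
```

This keeps the naming scheme from the original snippet (`{scene name}_ts_{timestep}.png`) while avoiding the cache class entirely.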

If you're looking to work with specific Scenes, you can use https://github.com/NVlabs/trajdata/blob/main/src/trajdata/dataset.py#L628 or call __getitem__ with specific indices (which you could obtain by looping through the data index, https://github.com/NVlabs/trajdata/blob/main/src/trajdata/dataset.py#L645, although that's technically an internal object 🤔).
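The "loop through the data index" approach amounts to a filter over (scene, timestep) entries. A minimal sketch, assuming you have already extracted a list of `(scene_name, scene_ts)` pairs whose positions match the indices accepted by `__getitem__`; the real data index is internal and its layout may differ between trajdata versions, and `indices_for_scene` is a hypothetical helper:

```python
from typing import List, Optional, Sequence, Tuple


def indices_for_scene(
    index_entries: Sequence[Tuple[str, int]],
    scene_name: str,
    scene_ts: Optional[int] = None,
) -> List[int]:
    """Return the positions of entries belonging to `scene_name`.

    Hypothetical: `index_entries` stands in for (scene name, timestep) pairs
    extracted from the dataset's internal data index, in __getitem__ order.
    If `scene_ts` is given, only entries at that timestep are kept.
    """
    return [
        i
        for i, (name, ts) in enumerate(index_entries)
        if name == scene_name and (scene_ts is None or ts == scene_ts)
    ]


# Sketch of use (assumes `entries` was built from the data index as above):
# for i in indices_for_scene(entries, "json_data_2_aug_0", scene_ts=1):
#     element = dataset[i]
```

Because it only relies on list positions, the same indices could also be passed to a sampler or subset wrapper if you want to keep using a DataLoader.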