Plot specific scene from dataset
Leon0402 opened this issue
For debugging purposes I wanted to plot the different scenes labeled with their names, so that I know which plot corresponds to which data in the original raw dataset.
I tried using the example here: https://github.com/nvr-avg/trajdata/blob/main/examples/scene_batch_example.py
But I couldn't find a way to get scene information (such as the name) from the SceneBatch, so in the end I wrote a custom implementation like this:
# Assumes `dataset` is a trajdata UnifiedDataset, `output_path` is a pathlib.Path,
# and plot_scene_batch is a plotting helper defined elsewhere.
scene_name = "json_data_2_aug_0"
scene_ts = 1
for scene in dataset.scenes():
    if scene.name != scene_name:
        continue
    # Build a cache for this scene directly from the dataset's cache class.
    scene_cache = dataset.cache_class(dataset.cache_path, scene, 0, dataset.augmentations)
    present_agents = scene.agent_presence[scene_ts]
    # (None, None) requests the full available history / future.
    agent_histories, _, _ = scene_cache.get_agents_history(scene_ts, present_agents, history_sec=(None, None))
    agent_states = [scene_cache.get_state(agent.name, scene_ts) for agent in present_agents]
    agent_futures, _, _ = scene_cache.get_agents_future(scene_ts, present_agents, future_sec=(None, None))
    plot_scene_batch(agent_histories, agent_states, agent_futures, output_path / f"{scene.name}_ts_{scene_ts}.png")
I suspect I shouldn't use the cache class directly, since it's an implementation detail. Is there another way to get scene information when using the standard DataLoader mechanism?
@BorisIvanovic Friendly ping :)
Hi @Leon0402! Have you tried the latest version? It includes the scene IDs in the SceneBatch object: https://github.com/NVlabs/trajdata/blob/main/src/trajdata/data_structures/batch.py#L141
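A minimal sketch of the DataLoader route, assuming each SceneBatch exposes a `scene_ids` list (one scene name per batch element) alongside its `scene_ts` values; check batch.py for the exact field names. The helper `plot_paths_for_batch` is hypothetical: it only reproduces the `{scene.name}_ts_{scene_ts}.png` naming from the snippet in the question, so each plot can be traced back to its scene.

```python
from pathlib import Path
from typing import List, Sequence


def plot_paths_for_batch(
    scene_ids: Sequence[str],
    scene_ts: Sequence[int],
    output_path: Path,
) -> List[Path]:
    """Build one output file path per batch element (hypothetical helper).

    Assumes scene_ids[i] and scene_ts[i] describe the same batch element,
    e.g. batch.scene_ids and batch.scene_ts.tolist() from a SceneBatch.
    """
    if len(scene_ids) != len(scene_ts):
        raise ValueError("scene_ids and scene_ts must have the same length")
    # int() also accepts 0-dim tensors, in case scene_ts is not converted first.
    return [
        output_path / f"{name}_ts_{int(ts)}.png"
        for name, ts in zip(scene_ids, scene_ts)
    ]


if __name__ == "__main__":
    # Stand-in values; in practice these would come from a SceneBatch yielded
    # by a DataLoader set up as in the linked scene_batch_example.py.
    paths = plot_paths_for_batch(
        ["json_data_2_aug_0", "json_data_3_aug_0"], [1, 4], Path("plots")
    )
    for p in paths:
        print(p)
```

Inside the DataLoader loop, the i-th path would then be handed to the plotting call for the i-th batch element.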
If you're looking to get specific Scene objects to work with, you can use https://github.com/NVlabs/trajdata/blob/main/src/trajdata/dataset.py#L628 or call __getitem__ with specific indices. You could obtain those indices by looping through the data index (https://github.com/NVlabs/trajdata/blob/main/src/trajdata/dataset.py#L645), although that's technically an internal object 🤔
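A minimal sketch of the index-selection step, assuming you have already turned the data index into `(dataset_idx, scene_name, scene_ts)` tuples. How to do that depends on trajdata's internal data index, which may change, so the entries below are made up and `indices_for_scene` is a hypothetical helper.

```python
from typing import Iterable, List, Optional, Tuple


def indices_for_scene(
    index_entries: Iterable[Tuple[int, str, int]],
    scene_name: str,
    scene_ts: Optional[int] = None,
) -> List[int]:
    """Return dataset indices whose entry matches scene_name (and scene_ts, if given).

    Each entry is assumed to be (dataset_idx, scene_name, scene_ts).
    """
    return [
        idx
        for idx, name, ts in index_entries
        if name == scene_name and (scene_ts is None or ts == scene_ts)
    ]


if __name__ == "__main__":
    # Hypothetical entries standing in for what the data index would provide.
    entries = [
        (0, "json_data_1_aug_0", 0),
        (1, "json_data_2_aug_0", 0),
        (2, "json_data_2_aug_0", 1),
        (3, "json_data_2_aug_0", 2),
    ]
    print(indices_for_scene(entries, "json_data_2_aug_0"))
    print(indices_for_scene(entries, "json_data_2_aug_0", scene_ts=1))
```

The resulting indices could be used with `dataset[i]` (the __getitem__ route) or wrapped in `torch.utils.data.Subset(dataset, indices)` and passed to a DataLoader with the dataset's collate function, so that only the chosen scene is batched and plotted.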