NVlabs / trajdata

A unified interface to many trajectory forecasting datasets.

Get map data from Trajdata

Leon0402 opened this issue

Hi,

I have some use cases where I need the map data for a specific scene beforehand. Here are two code examples:

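from pathlib import Path

import dill
import numpy as np
import zarr

from trajdata import UnifiedDataset
from trajdata.caching.df_cache import DataFrameCache
from trajdata.data_structures.scene_metadata import Scene
# MapMetadata (trajdata) and GeometricMap (Trajectron++) must be imported as well;
# their module paths depend on the installed versions.


# Example 1: load the entire rasterized map for a scene.
def load_map(dataset: UnifiedDataset, scene: Scene) -> GeometricMap:
    # Build the scene cache for timestep 0 (this relies on trajdata internals).
    scene_cache = dataset.cache_class(dataset.cache_path, scene, 0, dataset.augmentations)

    # Read the map metadata (name, layers, map_from_world homography, ...).
    maps_path: Path = DataFrameCache.get_maps_path(scene_cache.path, scene_cache.scene.env_name)
    metadata_file: Path = maps_path / f"{scene_cache.scene.location}_metadata.dill"
    with open(metadata_file, "rb") as f:
        map_info: MapMetadata = dill.load(f)

    map_file: Path = maps_path / f"{map_info.name}.zarr"
    disk_data = zarr.open_array(map_file, mode="r")

    # Map data is [layers, height, width], trajectron expects it as [layers, width, height]
    map_data = (disk_data[:] * 255.0).astype(np.uint8)
    map_data = np.swapaxes(map_data, 1, 2)

    return GeometricMap(data=map_data, homography=map_info.map_from_world, description=map_info.layers)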

# Example 2: load only the part of the map around the scene's agents
# (their bounding box plus a 50-unit margin in world coordinates).
def load_map(scene_cache):
    maps_path: Path = DataFrameCache.get_maps_path(scene_cache.path, scene_cache.scene.env_name)
    metadata_file: Path = maps_path / f"{scene_cache.scene.location}_metadata.dill"
    with open(metadata_file, "rb") as f:
        map_info: MapMetadata = dill.load(f)

    map_file: Path = maps_path / f"{map_info.name}.zarr"
    disk_data = zarr.open_array(map_file, mode="r")

    # Project the padded bounding box of all agent positions into pixel space.
    map_from_world = map_info.map_from_world
    scene_df = scene_cache.scene_data_df
    min_x, min_y, _ = np.rint(
        map_from_world @ (scene_df["x"].min() - 50, scene_df["y"].min() - 50, 1)
    ).astype(int)
    max_x, max_y, _ = np.rint(
        map_from_world @ (scene_df["x"].max() + 50, scene_df["y"].max() + 50, 1)
    ).astype(int)

    # Clamp to the raster; a negative start index would otherwise wrap around when slicing.
    _, height, width = disk_data.shape
    min_x, max_x = np.clip([min_x, max_x], 0, width)
    min_y, max_y = np.clip([min_y, max_y], 0, height)

    # Map data is [layers, height, width], trajectron expects it as [layers, width, height]
    map_data = (disk_data[..., min_y:max_y, min_x:max_x] * 255.0).astype(np.uint8)
    map_data = np.swapaxes(map_data, 1, 2)

    # Shift the homography so that pixel (0, 0) is the crop's top-left corner.
    map_from_world = (np.array([[1.0, 0.0, -min_x], [0.0, 1.0, -min_y], [0.0, 0.0, 1.0]]) @ map_from_world)

    rgb_groups = map_info.layer_rgb_groups
    map_data_plot = np.stack([
        np.amax(map_data[rgb_groups[0]], axis=0),
        np.amax(map_data[rgb_groups[1]], axis=0),
        np.amax(map_data[rgb_groups[2]], axis=0),
    ])

    return {
        "PEDESTRIAN": GeometricMap(data=map_data, homography=map_from_world, description=map_info.layers),
        "VISUALIZATION": GeometricMap(data=map_data_plot, homography=map_from_world)
    }

I find reading the map data a bit cumbersome and would love to have a nicer interface for this! Could something be added to make this simpler?

Edit: I know there are methods like load_map_patch, but they were difficult to use in the methods above, mainly because you have to specify a center and they also apply padding.

Hi @Leon0402, I highly recommend you take a look through the new version of trajdata. I've overhauled the map interface so now there's a unified vectorized representation which should be easier to work with.

Hi @BorisIvanovic, I already had a look at the new changes, but I don't see how they help here. I found no methods that would simplify the code above.

But maybe I overlooked something. Can you give an example?

Ahh sorry, I thought you might have been looking for more detailed information in the map. If you're looking to get the entire rasterized map, what you're doing above makes sense 👍

I don't think there's much I can do to simplify the first function. However, load_map_patch should cover a good portion of the second one: since you are already computing the patch bounds as .min() - 50 and .max() + 50, you could pass their average as the world center and their difference as the patch size.
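
The bounds-to-center/size conversion described above can be sketched as a small helper. The name patch_center_and_size is hypothetical (not part of trajdata), and I have not assumed load_map_patch's exact signature: depending on the trajdata version it may take a single square size or radius rather than a (width, height) pair, in which case something like max(size) or max(size) / 2 would be one option.

```python
from typing import Sequence, Tuple


def patch_center_and_size(
    xs: Sequence[float], ys: Sequence[float], margin: float = 50.0
) -> Tuple[Tuple[float, float], Tuple[float, float]]:
    """Turn the agents' padded bounding box into a patch center and size.

    The box is [min - margin, max + margin] along each axis, mirroring the
    bounds computed in the second load_map example.
    """
    min_x, max_x = min(xs) - margin, max(xs) + margin
    min_y, max_y = min(ys) - margin, max(ys) + margin
    center = ((min_x + max_x) / 2.0, (min_y + max_y) / 2.0)  # average of the bounds
    size = (max_x - min_x, max_y - min_y)  # difference of the bounds
    return center, size


# Toy positions; in practice these would be scene_cache.scene_data_df["x"] / ["y"].
center, size = patch_center_and_size([0.0, 10.0], [5.0, 25.0])
print(center, size)  # (5.0, 15.0) (110.0, 120.0)
```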

As for offset_xy, if you don't want any offset, you can always pass in (0, 0) (this will center the patch at your specified world center).

@BorisIvanovic Do you mean there is not much we can do to simplify the first function with what Trajdata offers currently, or in general?

In general, I can think of a few ways to improve the first and second methods:

  • Remove the need to access internals. This would probably require deciding what is internal in the first place; the complete scene cache seems to be internal right now. So there are a few options here:
    • Make the interface of UnifiedDataset more functional by adding methods to access map data (and other data)
    • Add a new interface, such as UnifiedScene, which can be constructed through UnifiedDataset and provides its own methods
    • Treat SceneCache as the public interface. In that case there should at least be a friendlier way to construct the cache, like dataset.access_scene_cache(scene), and SceneCache should then offer more methods
  • A method to load map metadata directly without constructing paths
  • A method to load map data. Perhaps load_map_patch could be refactored slightly to provide default values for easier use, with an overload that takes patch bounds instead of a center, or one that returns the whole map. Padding could also be made optional.

@Leon0402 Your interpretation is correct: "there is not much we can do to simplify the first function with what Trajdata offers currently"

As for your later points, I agree. We're actually internally building something like a MapAPI object whose goal is to provide a nice interface to access map data without the requirement to access internal class members like now (sorry for the state of it now...). Keep an eye out for a future update when I bring it to this repo! 👌