NVlabs / trajdata

A unified interface to many trajectory forecasting datasets.

Problem loading cache in `UnifiedDataset` `load_or_create_cache()` function.

aallaire91 opened this issue

Hello! I noticed an issue when loading a cache with the `load_or_create_cache()` function in `UnifiedDataset`. In the following code snippet, `keep_ids` is never defined when `cache_path` already exists. That is exactly the case when a cache has been created previously and you only want to load it. After the `if`/`else` block, the code expects `keep_ids` to exist so it can remove the unwanted entries from the data index (see the line `self.remove_elements(keep_ids=keep_ids)`). `keep_mask`, the second value returned by `dill.load(f, encoding="latin1")`, is the same list that the `else` branch saves to disk as `keep_ids`. It should therefore be renamed to `keep_ids` to fix this issue.

```python
def load_or_create_cache(
    self, cache_path: str, num_workers=0, filter_fn=None
) -> None:
    if isfile(cache_path):
        print(f"Loading cache from {cache_path} ...", end="")
        t = time.time()
        with open(cache_path, "rb") as f:
            # keep_ids is never assigned on this branch: the second
            # value is bound to keep_mask instead.
            self._cached_batch_elements, keep_mask = dill.load(f, encoding="latin1")
        print(f" done in {time.time() - t:.1f}s.")

    else:
        # Build cache
        cached_batch_elements = []
        keep_ids = []

        if num_workers <= 0:
            cache_data_iterator = self
        else:
            # Use DataLoader as a generic multiprocessing framework.
            # We set batch_size=1 and a custom collate function.
            # In effect this will just call self.__getitem__ in parallel.
            cache_data_iterator = DataLoader(
                self,
                batch_size=1,
                num_workers=num_workers,
                shuffle=False,
                collate_fn=lambda xlist: xlist[0],
            )

        for element in tqdm(
            cache_data_iterator,
            desc=f"Caching batch elements ({num_workers} CPUs): ",
            disable=False,
        ):
            if filter_fn is None or filter_fn(element):
                cached_batch_elements.append(element)
                keep_ids.append(element.data_index)

        # Just deletes the variable cache_data_iterator,
        # not self (in case it is set to that)!
        del cache_data_iterator

        print(f"Saving cache to {cache_path} ....", end="")
        t = time.time()
        with open(cache_path, "wb") as f:
            dill.dump((cached_batch_elements, keep_ids), f)
        print(f" done in {time.time() - t:.1f}s.")

        self._cached_batch_elements = cached_batch_elements

    # Remove unwanted elements
    # (fails when the cache was loaded above, since keep_ids is unbound there)
    self.remove_elements(keep_ids=keep_ids)

    # Verify
    if len(self._cached_batch_elements) != self._data_len:
        raise ValueError("Current data and cached data lengths do not match!")
```
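With the snippet as written, the load path fails with `UnboundLocalError` (a subclass of `NameError`). Python treats `keep_ids` as a local variable because it is assigned elsewhere in the function, so reading it before any assignment raises.

The proposed fix can be sketched in a self-contained way. This is a minimal sketch under several assumptions:

- `ToyCachedDataset` is a hypothetical stand-in, not trajdata's real `UnifiedDataset`.
- It keeps only the cache control flow.
- It uses the standard-library `pickle` in place of `dill`.
- It drops the `DataLoader`/`tqdm` machinery.

The substantive change is that the load branch unpacks the second saved value as `keep_ids`.

```python
import os
import pickle
import tempfile
from typing import Callable, List, Optional


class ToyCachedDataset:
    """Hypothetical stand-in for UnifiedDataset; only the caching flow is modeled."""

    def __init__(self, num_elements: int) -> None:
        # Each toy "element" is simply its own data index.
        self._data_index: List[int] = list(range(num_elements))
        self._data_len: int = len(self._data_index)
        self._cached_batch_elements: List[int] = []

    def __len__(self) -> int:
        return self._data_len

    def __getitem__(self, idx: int) -> int:
        return self._data_index[idx]

    def remove_elements(self, keep_ids: List[int]) -> None:
        # Keep only the data-index entries that survived filtering.
        keep = set(keep_ids)
        self._data_index = [i for i in self._data_index if i in keep]
        self._data_len = len(self._data_index)

    def load_or_create_cache(
        self, cache_path: str, filter_fn: Optional[Callable[[int], bool]] = None
    ) -> None:
        if os.path.isfile(cache_path):
            with open(cache_path, "rb") as f:
                # Fix: bind the second saved value to keep_ids (was keep_mask),
                # matching the tuple written by the else branch.
                self._cached_batch_elements, keep_ids = pickle.load(f)
        else:
            cached_batch_elements = []
            keep_ids = []
            for idx in range(len(self)):
                element = self[idx]
                if filter_fn is None or filter_fn(element):
                    cached_batch_elements.append(element)
                    keep_ids.append(element)  # element doubles as its data index
            with open(cache_path, "wb") as f:
                pickle.dump((cached_batch_elements, keep_ids), f)
            self._cached_batch_elements = cached_batch_elements

        # keep_ids is now bound on both branches.
        self.remove_elements(keep_ids=keep_ids)
        if len(self._cached_batch_elements) != self._data_len:
            raise ValueError("Current data and cached data lengths do not match!")


def keep_even(element: int) -> bool:
    return element % 2 == 0


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "cache.pkl")

    first = ToyCachedDataset(10)
    first.load_or_create_cache(path, filter_fn=keep_even)  # builds and saves the cache

    second = ToyCachedDataset(10)
    second.load_or_create_cache(path)  # loads the cache; no error on this path now
    print(second._data_len, second._cached_batch_elements)  # 5 [0, 2, 4, 6, 8]
```

On the load path, the filter is not re-applied. The saved `keep_ids` alone determine which data-index entries survive, so the second dataset ends up with the same five elements as the first.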