google-research / robotics_transformer

The inference results with the trained checkpoint (RT1main) look weird.

ka2hyeon opened this issue

After downloading the RT-1 dataset here, I ran inference on it with your checkpoint (RT1 main). However, the inference results are far from the actions stored in the dataset. When plotted, the trajectories from the dataset and from inference have similar shapes, but the inferred ones show a scale shift and drift. Moreover, the training loss drops sharply when I further train the checkpoint on the dataset. It seems that either your checkpoint was trained on a different dataset from the one you shared, or I am running inference incorrectly.

I appreciate your efforts on large-scale robot learning, and I want to reproduce your results. Could you comment on the mismatch between the inference results and the ground truth (GT) in your dataset?

Here is pseudo-code of what I did:

# Data source loading
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

builder = tfds.builder_from_directory(builder_dir=[RT1 data dir])
ds = builder.as_data_source(split=split, decoders=tfds.decode.SkipDecoding())
# Parse one serialized episode and pull out its raw feature map
example = tf.train.Example.FromString(ds.data_source[episode_id])
features = example.features.feature

images = tf.io.decode_image(features["steps/observations/images"].bytes_list.value)
rotation_delta = tf.constant(features["steps/actions/rotation_delta"].float_list.value, dtype=np.float32)
...
observations = {'images': images, ...}
actions_gt = {'actions': rotation_delta, ...}

# Inference
from tf_agents.trajectories.time_step import TimeStep

time_step = TimeStep(
    observation=observations,
    ...
)
policy = tf.saved_model.load(pb_path_of_rt1main)
policy_state = policy.get_initial_state(batch_size=1)
action_step = policy.action(time_step, policy_state)

# !! action_step differs from actions_gt
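
One way to put numbers on the "scale shift and drift" described above is to fit the predicted trajectory against the ground truth with a simple linear model. The sketch below is not part of the original report: the function name `fit_scale_offset_drift` is hypothetical, it handles one action dimension at a time, and it assumes the mismatch is roughly `pred = scale * gt + offset + drift * t`. The synthetic arrays only stand in for real `action_step` and `actions_gt` values.

```python
import numpy as np


def fit_scale_offset_drift(pred, gt):
    """Least-squares fit of pred ~= scale * gt + offset + drift * t (1-D series)."""
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    t = np.arange(len(gt), dtype=np.float64)
    # Design matrix columns: ground truth, constant term, time-step index.
    design = np.stack([gt, np.ones_like(gt), t], axis=1)
    coeffs, *_ = np.linalg.lstsq(design, pred, rcond=None)
    scale, offset, drift = (float(c) for c in coeffs)
    # Root-mean-square error left over after removing scale, offset and drift.
    rmse = float(np.sqrt(np.mean((pred - design @ coeffs) ** 2)))
    return scale, offset, drift, rmse


# Synthetic stand-ins for one action dimension of a single episode.
gt_demo = np.sin(np.linspace(0.0, 2.0 * np.pi, 50))
pred_demo = 1.5 * gt_demo + 0.2 + 0.01 * np.arange(50)
scale, offset, drift, rmse = fit_scale_offset_drift(pred_demo, gt_demo)
print(f"scale={scale:.3f} offset={offset:.3f} drift={drift:.3f} rmse={rmse:.2e}")
```

A scale far from 1 or a non-zero offset that is consistent across episodes would hint at a normalization or unit mismatch between the checkpoint's action space and the dataset, whereas a large residual would suggest the predictions differ in more than an affine way.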

Sorry to bother you, but I wonder how you successfully loaded the data. I also downloaded the data, but the file names in ds.data_source are not present in the dataset, e.g. 'RT_1_paper_release-train.array_record-00000-of-01024'. Do you know how to fix this? Thanks!

Refer to the ipynb notebook for loading the dataset.
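
For the file-name mismatch described above, a quick diagnostic is to compare the shard names the data source expects with the files actually present in the download directory. This is a minimal sketch and not something posted in the thread: the helpers `expected_shard_names` and `missing_shards` are hypothetical, and they assume the usual sharded naming `<prefix>-NNNNN-of-MMMMM` seen in the example file name.

```python
import os
import re

# Matches names such as 'RT_1_paper_release-train.array_record-00000-of-01024'.
SHARD_RE = re.compile(r"^(?P<prefix>.+)-(?P<index>\d{5})-of-(?P<total>\d{5})$")


def expected_shard_names(example_name):
    """Expand one shard file name into the full list of shard names."""
    match = SHARD_RE.match(example_name)
    if match is None:
        raise ValueError(f"not a sharded file name: {example_name!r}")
    total = int(match["total"])
    prefix = match["prefix"]
    return [f"{prefix}-{i:05d}-of-{total:05d}" for i in range(total)]


def missing_shards(data_dir, example_name):
    """Return the expected shard names that are absent from data_dir."""
    on_disk = set(os.listdir(data_dir))
    return [name for name in expected_shard_names(example_name) if name not in on_disk]
```

Calling `missing_shards(data_dir, 'RT_1_paper_release-train.array_record-00000-of-01024')` on the download directory and then listing that directory side by side shows whether shards are missing outright or are present under a different suffix or file format.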