[BUG] Feature name in StateAugmentation
aziabatz opened this issue
Describe the bug
The code responsible for data augmentation assumes that the environment provides a locs key in the given TensorDict. This assumption leads to a KeyError when the TensorDict does not contain a locs key.
To Reproduce
To reproduce the behavior, train a model that uses StateAugmentation (e.g., POMO) on an environment whose TensorDict does not contain the locs key. The resulting stack trace is shown below:
Traceback (most recent call last):
[...]
File "c:\Users\...\hpcenv\Lib\site-packages\lightning\pytorch\trainer\call.py", line 309, in _call_strategy_hook
output = fn(*args, **kwargs)
File "c:\Users\...\hpcenv\Lib\site-packages\lightning\pytorch\strategies\strategy.py", line 412, in validation_step
return self.lightning_module.validation_step(*args, **kwargs)
File "c:\Users\...\hpcenv\Lib\site-packages\rl4co\models\rl\common\base.py", line 257, in validation_step
return self.shared_step(
File "c:\Users\...\hpcenv\Lib\site-packages\rl4co\models\zoo\pomo\model.py", line 90, in shared_step
td = self.augment(td)
File "c:\Users\...\hpcenv\Lib\site-packages\rl4co\data\transforms.py", line 131, in __call__
aug_feat = self.augmentation(td_aug[feat], self.num_augment)
File "c:\Users\...\hpcenv\Lib\site-packages\tensordict\base.py", line 239, in __getitem__
result = self._get_tuple(idx_unravel, NO_DEFAULT)
File "c:\Users\...\hpcenv\Lib\site-packages\tensordict\_td.py", line 1577, in _get_tuple
first = self._get_str(key[0], default)
File "c:\Users\...\hpcenv\Lib\site-packages\tensordict\base.py", line 2234, in _default_get
raise KeyError('key "locs" not found in TensorDict with keys [...]')
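The failure mode can be illustrated without RL4CO installed. The following is a minimal sketch, not the library's code: a plain dict stands in for the TensorDict, and the names default_aug_feats, augment, coords, demand and my_env are hypothetical.

```python
# Stand-in for the batch TensorDict: a plain dict with hypothetical keys.
# This environment stores its coordinates under "coords", so "locs" is absent.
td = {"coords": [[0.1, 0.3], [0.7, 0.2]], "demand": [0.5, 0.4]}


def default_aug_feats(env_name=None):
    # Mirrors the behavior described in this issue: env_name is ignored.
    return ["locs"]


def augment(td, feats):
    # Simplified lookup step only; the real class also transforms each feature.
    return {feat: td[feat] for feat in feats}


try:
    augment(td, default_aug_feats("my_env"))
    error = None
except KeyError as exc:
    error = exc

print(repr(error))  # KeyError('locs')
```

The lookup fails on the first feature name, which matches the td_aug[feat] frame in the trace above.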
System info
- Installation method: pip (virtualenv)
- RL4CO: 0.3.1
- PyTorch: 2.2.0+cpu
- PyTorch Lightning: 2.2.0.post0
- TorchRL: 0.3.0
- TensorDict: 0.3.0
- Numpy: 1.26.4
- Python: 3.11.8 (tags/v3.11.8:db85d51, Feb 6 2024, 22:03:32) [MSC v.1937 64 bit (AMD64)]
- Platform: win32
Reason and Possible fixes
The issue originates in the StateAugmentation initialization. When the feats argument is None, the constructor calls env_aug_feats to determine the appropriate feature names, presumably based on the environment name, since it receives the env_name argument. However, env_aug_feats currently ignores the env_name parameter and unconditionally returns ["locs"], assuming this key is present in all environments.
One solution would be to modify env_aug_feats to select the feature names from a dictionary keyed by environment name, as is done in other parts of the project. Here is a diff with the proposed changes:
--- a/rl4co/data/transforms.py
+++ b/rl4co/data/transforms.py
@@ -1,4 +1,14 @@
+env_features = {
+    "tsp": ["locs"],
+    "whatever": ["feature"],
+    # ....
+}
+
 def env_aug_feats(env_name: str = None):
-    return ["locs"]
+    return env_features.get(env_name, ["locs"])
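As a standalone sketch of the dictionary approach, the lookup could be written as below. Only the "tsp" entry comes from the diff above. The "my_env" entry, the ENV_FEATURES name and the strict flag are assumptions added for illustration and are not part of RL4CO.

```python
from typing import Dict, List, Optional

# Hypothetical registry: only "tsp" -> ["locs"] is taken from the proposal.
ENV_FEATURES: Dict[str, List[str]] = {
    "tsp": ["locs"],
    "my_env": ["coords"],  # hypothetical environment and key
}


def env_aug_feats(env_name: Optional[str] = None, strict: bool = False) -> List[str]:
    """Return the feature names to augment for the given environment."""
    if env_name in ENV_FEATURES:
        # Return a copy so callers cannot mutate the registry by accident.
        return list(ENV_FEATURES[env_name])
    if strict:
        # Assumption: failing early may be clearer than a later KeyError.
        raise ValueError(f"no augmentation features registered for {env_name!r}")
    # Same fallback as in the proposed diff.
    return ["locs"]


print(env_aug_feats("my_env"))   # ['coords']
print(env_aug_feats("unknown"))  # ['locs']
```

Note that the silent fallback to ["locs"] still leads to the original KeyError for any unregistered environment without that key, which is why the sketch also offers a strict mode.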
Another option is to modify the constructors of models such as POMO so that the list of features (feats) can be passed directly to StateAugmentation, without relying solely on a dictionary. Here is how it could look:
--- a/rl4co/models/zoo/pomo/model.py
+++ b/rl4co/models/zoo/pomo/model.py
@@ -10,6 +10,7 @@
         policy: Union[nn.Module, POMOPolicy] = None,
         policy_kwargs={},
         baseline: str = "shared",
+        feats=None,
         num_augment: int = 8,
         use_dihedral_8: bool = True,
         num_starts: int = None,
@@ -26,7 +27,7 @@
         self.num_augment = num_augment
         if self.num_augment > 1:
             self.augment = StateAugmentation(
-                self.env.name, num_augment=self.num_augment, use_dihedral_8=use_dihedral_8
+                self.env.name, num_augment=self.num_augment, use_dihedral_8=use_dihedral_8, feats=feats
             )
         else:
             self.augment = None
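The effect of the feats passthrough can be sketched in isolation. The class below is a hypothetical stand-in, not the RL4CO StateAugmentation: it works on plain dicts, and the "augmentation" is a placeholder that merely repeats each feature num_augment times, so that only the feature selection logic is shown.

```python
from typing import Any, Dict, List, Optional, Sequence


def env_aug_feats(env_name: Optional[str] = None) -> List[str]:
    # Default described in this issue.
    return ["locs"]


class StateAugmentationSketch:
    """Hypothetical stand-in showing how an explicit feats list takes priority."""

    def __init__(self, env_name: str, num_augment: int = 8,
                 feats: Optional[Sequence[str]] = None):
        self.num_augment = num_augment
        # An explicit feats list overrides the environment-based default.
        self.feats = list(feats) if feats is not None else env_aug_feats(env_name)

    def __call__(self, td: Dict[str, Any]) -> Dict[str, Any]:
        td_aug = dict(td)
        for feat in self.feats:
            # Placeholder transformation: repeat the entries num_augment times.
            td_aug[feat] = list(td[feat]) * self.num_augment
        return td_aug


augment = StateAugmentationSketch("my_env", num_augment=2, feats=["coords"])
out = augment({"coords": [(0.1, 0.3)], "demand": [0.5]})
print(out["coords"])  # [(0.1, 0.3), (0.1, 0.3)]
```

With this design, a custom environment works without any change to a shared dictionary, at the cost of every model constructor having to forward the argument.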
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have provided a minimal working example to reproduce the bug (required)
Hi @aziabatz, thanks for your suggestion on improving the augmentation function. We refactored the StateAugmentation() class and cleaned up some redundant code.
The class now has a clearly defined augmentation function and feature selection. Users can choose the symmetric or dihedral8 augmentation function or supply their own, and can customize which features to augment.
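The refactored API is not shown in this thread, so the following is only a rough sketch of the idea of a pluggable augmentation function combined with explicit feature selection. It assumes 2D coordinates in the unit square and uses the eight reflections and axis swaps commonly meant by 8-fold dihedral augmentation. The names dihedral_8, augment_features and coords are hypothetical.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple

Point = Tuple[float, float]


def dihedral_8(point: Point) -> List[Point]:
    """Return the 8 symmetric variants of a point in the unit square."""
    x, y = point
    return [
        (x, y), (1 - x, y), (x, 1 - y), (1 - x, 1 - y),
        (y, x), (1 - y, x), (y, 1 - x), (1 - y, 1 - x),
    ]


def augment_features(
    td: Dict[str, Any],
    feats: Sequence[str],
    aug_fn: Callable[[Point], List[Point]] = dihedral_8,
) -> Dict[str, Any]:
    """Apply aug_fn to every point of each selected feature; copy the rest."""
    out = dict(td)
    for feat in feats:
        out[feat] = [aug_fn(p) for p in td[feat]]
    return out


td = {"coords": [(0.1, 0.3)], "demand": [0.5]}
augmented = augment_features(td, feats=["coords"])
print(len(augmented["coords"][0]))  # 8
```

Any callable with the same signature can be passed as aug_fn, and only the keys listed in feats are touched.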