ai4co / rl4co

A PyTorch library for all things Reinforcement Learning (RL) for Combinatorial Optimization (CO)

Home Page: https://rl4.co

[BUG] Feature name in StateAugmentation

aziabatz opened this issue · comments

Describe the bug

The data augmentation code assumes that the environment provides a locs key in the given TensorDict. This assumption leads to a KeyError when the environment's TensorDict does not contain a locs key.

To Reproduce

To reproduce the behavior, train a model that uses StateAugmentation (e.g., POMO) on an environment that does not provide the key mentioned above. Below is the stack trace produced when this condition is met:

Traceback (most recent call last):
  [...]
  File "c:\Users\...\hpcenv\Lib\site-packages\lightning\pytorch\trainer\call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "c:\Users\...\hpcenv\Lib\site-packages\lightning\pytorch\strategies\strategy.py", line 412, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
  File "c:\Users\...\hpcenv\Lib\site-packages\rl4co\models\rl\common\base.py", line 257, in validation_step
    return self.shared_step(
  File "c:\Users\...\hpcenv\Lib\site-packages\rl4co\models\zoo\pomo\model.py", line 90, in shared_step
    td = self.augment(td)
  File "c:\Users\...\hpcenv\Lib\site-packages\rl4co\data\transforms.py", line 131, in __call__
    aug_feat = self.augmentation(td_aug[feat], self.num_augment)
  File "c:\Users\...\hpcenv\Lib\site-packages\tensordict\base.py", line 239, in __getitem__
    result = self._get_tuple(idx_unravel, NO_DEFAULT)
  File "c:\Users\...\hpcenv\Lib\site-packages\tensordict\_td.py", line 1577, in _get_tuple
    first = self._get_str(key[0], default)
  File "c:\Users\...\hpcenv\Lib\site-packages\tensordict\base.py", line 2234, in _default_get
    raise KeyError('key "locs" not found in TensorDict with keys [...]')
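
For completeness, a minimal sketch of a setup that triggers this (MyEnvWithoutLocs is a hypothetical placeholder for any environment whose TensorDict lacks a locs key; the rest follows the usual Lightning training pattern):

# Minimal reproduction sketch -- "MyEnvWithoutLocs" is a hypothetical placeholder for
# any RL4CO environment whose observation TensorDict does not contain a "locs" key.
from lightning import Trainer
from rl4co.models import POMO

env = MyEnvWithoutLocs()      # placeholder environment without a "locs" entry
model = POMO(env)             # POMO constructs StateAugmentation internally

trainer = Trainer(max_epochs=1)
trainer.fit(model)            # validation step calls self.augment(td) -> KeyError: "locs"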

System info

  • Installation method: pip (virtualenv)
  • RL4CO: 0.3.1
  • PyTorch: 2.2.0+cpu
  • PyTorch Lightning: 2.2.0.post0
  • TorchRL: 0.3.0
  • TensorDict: 0.3.0
  • Numpy: 1.26.4
  • Python: 3.11.8 (tags/v3.11.8:db85d51, Feb 6 2024, 22:03:32) [MSC v.1937 64 bit (AMD64)]
  • Platform: win32

Reason and Possible fixes

The actual issue originates from the StateAugmentation initialization. When the feats argument is None, the constructor calls env_aug_feats to determine the appropriate feature names based on the environment name (presumably, since it receives the env_name argument). However, env_aug_feats currently ignores the env_name parameter and unconditionally returns ["locs"], assuming this key is present in every environment's TensorDict.

A solution could be to modify env_aug_feats so that it selects the feature names from a dictionary keyed by environment name, as is already done in other parts of the project. Here is a diff with the proposed changes:

--- a/rl4co/data/transforms.py
+++ b/rl4co/data/transforms.py
@@ -1,4 +1,14 @@
+env_features = {
+    "tsp": ["locs"],
+    "whatever": ["feature"],
+    # ....
+}
+
 def env_aug_feats(env_name: str = None):
-    return ["locs"]
+    return env_features.get(env_name, ["locs"])
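
With this change, the lookup resolves features per environment while keeping the current behavior as a fallback; a short usage sketch (the "whatever" entry is a placeholder, mirroring the mapping above):

# Usage sketch for the proposed mapping -- entries other than "tsp" are placeholders.
env_aug_feats("tsp")        # -> ["locs"]
env_aug_feats("whatever")   # -> ["feature"]
env_aug_feats("unknown")    # -> ["locs"]  (fallback preserves existing behavior)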

Another suggestion is to modify the constructors of models such as POMO to allow passing the list of features (feats) directly to StateAugmentation, rather than relying solely on a dictionary lookup. Here is how it could look:

--- a/rl4co/models/zoo/pomo/model.py
+++ b/rl4co/models/zoo/pomo/model.py
@@ -10,6 +10,7 @@
         policy: Union[nn.Module, POMOPolicy] = None,
         policy_kwargs={},
         baseline: str = "shared",
+        feats=None,
         num_augment: int = 8,
         use_dihedral_8: bool = True,
         num_starts: int = None,
@@ -26,7 +27,7 @@
         self.num_augment = num_augment
         if self.num_augment > 1:
             self.augment = StateAugmentation(
-                self.env.name, num_augment=self.num_augment, use_dihedral_8=use_dihedral_8
+                self.env.name, num_augment=self.num_augment, use_dihedral_8=use_dihedral_8, feats=feats
             )
         else:
             self.augment = None
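
With that change, the augmented features could be specified directly when constructing the model; a hedged usage sketch (the environment and feature key are placeholders):

# Usage sketch for the proposed `feats` argument -- environment and feature key are placeholders.
env = MyEnvWithoutLocs()
model = POMO(
    env,
    feats=["coordinates"],    # a key that actually exists in this environment's TensorDict
    num_augment=8,
)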

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have provided a minimal working example to reproduce the bug (required)

Hi @aziabatz, thanks for your suggestion about improving the augmentation function. We refactored the StateAugmentation() class and cleaned up some redundant code.

Now the class has a clear separation between the augmentation function and the feature selection. Users can choose the symmetric or dihedral8 augmentation function, or provide their own, and can customize which features to augment.
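
For reference, a rough sketch of how the refactored class might be used, based on the description above (argument names such as augment_fn are assumptions inferred from this description and should be checked against the released API):

# Sketch only -- argument names are assumptions inferred from the description above.
from rl4co.data.transforms import StateAugmentation

augment = StateAugmentation(
    num_augment=8,
    augment_fn="dihedral8",   # or "symmetric", or a user-defined callable
    feats=["locs"],           # explicitly select which TensorDict keys to augment
)
td_aug = augment(td)          # td: a batched TensorDict from env.reset()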