SforAiDl / genrl

A PyTorch reinforcement learning library for generalizable and reproducible algorithm implementations with an aim to improve accessibility in RL

Home Page:https://genrl.readthedocs.io

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

DCBTrainer: expected Tensor as element 0 in argument 0, but got str

tobiasmorville opened this issue · comments

Running the StatLog shuttle example

from genrl.bandit import StatlogDataBandit

bandit = StatlogDataBandit(download=True)
context = bandit.reset()

from genrl.bandit import NeuralLinearPosteriorAgent

agent = NeuralLinearPosteriorAgent(bandit)
context = bandit.reset()

action = agent.select_action(context)
new_context, reward = bandit.step(action)

from genrl.bandit import DCBTrainer

trainer = DCBTrainer(agent, bandit)
trainer.train(timesteps=5000, batch_size=32)

yields an error in the training:

TypeError: expected Tensor as element 0 in argument 0, but got str

Printing the output at the time of transition self.db[contexts] is a list of strings:

Started at 28-08-20 12:07:36
Training NeuralLinearPosteriorAgent on StatlogDataBandit for 5000 timesteps
timestep                  regret/regret             reward/reward             regret/cumulative_regret  reward/cumulative_reward  regret/regret_moving_avg  reward/reward_moving_avg  
100                       0                         1                         76                        24                        0.76                      0.24                      
200                       1                         0                         162                       38                        0.81                      0.19                      
300                       1                         0                         251                       49                        0.852                     0.148                     
400                       1                         0                         336                       64                        0.872                     0.128                     
500                       0                         1                         421                       79                        0.864                     0.136
                     
self.db[contexts]: ['x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x']

Encounterred exception during training!
expected Tensor as element 0 in argument 0, but got str

Training completed in 1 seconds
Final Regret Moving Average: 0.86 | Final Reward Moving Average: 0.14

Full error trace:

Traceback (most recent call last):
  File "/usr/local/anaconda3/envs/rl/lib/python3.6/site-packages/genrl/bandit/trainer.py", line 194, in train
    action, kwargs.get("batch_size", 64), train_epochs
  File "/usr/local/anaconda3/envs/rl/lib/python3.6/site-packages/genrl/bandit/agents/cb_agents/neural_linpos.py", line 188, in update_params
    z, y = self.latent_db.get_data_for_action(action, batch_size)
  File "/usr/local/anaconda3/envs/rl/lib/python3.6/site-packages/genrl/bandit/agents/cb_agents/common/transition.py", line 110, in get_data_for_action
    torch.stack([self.db["contexts"][i] for i in idx])
TypeError: expected Tensor as element 0 in argument 0, but got str

EDIT:

Printing idx at the time of error yields:

idx = [200, 208, 343, 209, 318, 8, 33, 280, 513, 396, 184, 324, 287, 365, 145, 136, 338, 268, 231, 519, 116, 113, 340, 341, 409, 238, 153, 262, 446, 244, 135, 205]

and fetching correspondant contexts from the agent [agent.db.db['contexts'][x] for x in idx] yields seemingly valid input:

[tensor([ 37.,   0.,  98.,   1.,  30., -16.,  61.,  67.,   6.]),
 tensor([40., -2., 88.,  0., 38.,  0., 48., 49.,  2.]),
 tensor([41.,  0., 81.,  0., 42.,  4., 40., 40.,  0.]),
 tensor([ 37.,   0.,  77.,   0., -22., -14.,  41., 101.,  60.]),
 tensor([83., -1., 86., -2., -4.,  0.,  3., 92., 88.]),
 tensor([37., -5., 79.,  0., 36.,  0., 42., 43.,  0.]),
 tensor([56.,  0., 97., -1., 50., 31., 41., 48.,  6.]),
 tensor([39.,  0., 86.,  2., 38.,  0., 47., 47.,  0.]),
 tensor([55., -3., 84.,  0., 54.,  0., 28., 30.,  2.]),
 tensor([ 44.,   0.,  81.,   6.,  42., -16.,  38.,  40.,   2.]),
 tensor([41.,  5., 82.,  0., 42., 13., 41., 41.,  0.]),
 tensor([48.,  0., 86.,  0., 46., -8., 38., 40.,  2.]),
 tensor([ 44.,   0., 109.,   0.,  42., -15.,  66.,  68.,   2.]),
 tensor([ 46.,   0.,  88.,  -4.,  44., -22.,  43.,  44.,   2.]),
 tensor([ 82.,   1.,  86.,   0., -22.,   6.,   4., 109., 106.]),
 tensor([37.,  0., 74.,  0., 28.,  6., 38., 46.,  8.]),
 tensor([ 55.,   0.,  82.,  -2., -20.,  12.,  26., 103.,  76.]),
 tensor([43.,  0., 86.,  7., 42.,  0., 43., 45.,  2.]),
 tensor([54., -6., 90., -3., 54., -4., 36., 36.,  0.]),
 tensor([47.,  2., 76., -1., 46.,  0., 29., 29.,  0.]),
 tensor([45.,  0., 83.,  0., 46., 29., 37., 36.,  0.]),
 tensor([56.,  0., 81.,  0., -2., 26., 25., 83., 58.]),
 tensor([ 43.,   0.,  83.,   0.,  42., -17.,  39.,  41.,   2.]),
 tensor([37.,  0., 76.,  0., 26., 15., 39., 50., 12.]),
 tensor([37.,  0., 80.,  0., 10., -5., 43., 70., 26.]),
 tensor([ 49.,   0.,  95.,   0.,  46., -15.,  47.,  49.,   2.]),
 tensor([55.,  0., 95., -1., 52., -1., 40., 44.,  4.]),
 tensor([51., -5., 88.,  0., 52.,  0., 36., 36.,  0.]),
 tensor([56.,  5., 78.,  2., 44.,  0., 22., 34., 12.]),
 tensor([44., -5., 77.,  0., 44.,  0., 33., 33.,  0.]),
 tensor([ 80.,   0.,  84.,   0., -36.,   5.,   4., 120., 116.]),
 tensor([45.,  0., 86.,  0., 46., 21., 40., 39.,  0.])]

however, printing self.db[contexts] yields all strings:

['x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x']

@tomonodes Could you update your installation (from source) and try this:

from genrl.utils import StatlogDataBandit

bandit = StatlogDataBandit(download=True)
context = bandit.reset()

from genrl.agents import NeuralLinearPosteriorAgent

agent = NeuralLinearPosteriorAgent(bandit)
context = bandit.reset()

action = agent.select_action(context)
new_context, reward = bandit.step(action)

from genrl.trainers import DCBTrainer

trainer = DCBTrainer(agent, bandit)
trainer.train(timesteps=5000, batch_size=32)

This works for me.

Feel free to close the issue(if this works).

Yep, this looks like an older build. This specific issue was fixed in #260 . The docs still seem to be outdated though. This should be fixed with #281

Installing with pip install genrl installs genrl==0.0.2 which reproduces the bug on a clean conda env with python 3.6.

Building from source installs genrl==0.0.1 in which the above problem is solved 👍.