DCBTrainer: expected Tensor as element 0 in argument 0, but got str
tobiasmorville opened this issue · comments
Running the StatLog shuttle example
from genrl.bandit import StatlogDataBandit
bandit = StatlogDataBandit(download=True)
context = bandit.reset()
from genrl.bandit import NeuralLinearPosteriorAgent
agent = NeuralLinearPosteriorAgent(bandit)
context = bandit.reset()
action = agent.select_action(context)
new_context, reward = bandit.step(action)
from genrl.bandit import DCBTrainer
trainer = DCBTrainer(agent, bandit)
trainer.train(timesteps=5000, batch_size=32)
yields an error in the training:
TypeError: expected Tensor as element 0 in argument 0, but got str
Printing the output at the time of transition self.db[contexts]
is a list of strings:
Started at 28-08-20 12:07:36
Training NeuralLinearPosteriorAgent on StatlogDataBandit for 5000 timesteps
timestep regret/regret reward/reward regret/cumulative_regret reward/cumulative_reward regret/regret_moving_avg reward/reward_moving_avg
100 0 1 76 24 0.76 0.24
200 1 0 162 38 0.81 0.19
300 1 0 251 49 0.852 0.148
400 1 0 336 64 0.872 0.128
500 0 1 421 79 0.864 0.136
self.db[contexts]: ['x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x']
Encounterred exception during training!
expected Tensor as element 0 in argument 0, but got str
Training completed in 1 seconds
Final Regret Moving Average: 0.86 | Final Reward Moving Average: 0.14
Full error trace:
Traceback (most recent call last):
File "/usr/local/anaconda3/envs/rl/lib/python3.6/site-packages/genrl/bandit/trainer.py", line 194, in train
action, kwargs.get("batch_size", 64), train_epochs
File "/usr/local/anaconda3/envs/rl/lib/python3.6/site-packages/genrl/bandit/agents/cb_agents/neural_linpos.py", line 188, in update_params
z, y = self.latent_db.get_data_for_action(action, batch_size)
File "/usr/local/anaconda3/envs/rl/lib/python3.6/site-packages/genrl/bandit/agents/cb_agents/common/transition.py", line 110, in get_data_for_action
torch.stack([self.db["contexts"][i] for i in idx])
TypeError: expected Tensor as element 0 in argument 0, but got str
EDIT:
Printing idx
at the time of error yields:
idx = [200, 208, 343, 209, 318, 8, 33, 280, 513, 396, 184, 324, 287, 365, 145, 136, 338, 268, 231, 519, 116, 113, 340, 341, 409, 238, 153, 262, 446, 244, 135, 205]
and fetching correspondant contexts from the agent [agent.db.db['contexts'][x] for x in idx]
yields seemingly valid input:
[tensor([ 37., 0., 98., 1., 30., -16., 61., 67., 6.]),
tensor([40., -2., 88., 0., 38., 0., 48., 49., 2.]),
tensor([41., 0., 81., 0., 42., 4., 40., 40., 0.]),
tensor([ 37., 0., 77., 0., -22., -14., 41., 101., 60.]),
tensor([83., -1., 86., -2., -4., 0., 3., 92., 88.]),
tensor([37., -5., 79., 0., 36., 0., 42., 43., 0.]),
tensor([56., 0., 97., -1., 50., 31., 41., 48., 6.]),
tensor([39., 0., 86., 2., 38., 0., 47., 47., 0.]),
tensor([55., -3., 84., 0., 54., 0., 28., 30., 2.]),
tensor([ 44., 0., 81., 6., 42., -16., 38., 40., 2.]),
tensor([41., 5., 82., 0., 42., 13., 41., 41., 0.]),
tensor([48., 0., 86., 0., 46., -8., 38., 40., 2.]),
tensor([ 44., 0., 109., 0., 42., -15., 66., 68., 2.]),
tensor([ 46., 0., 88., -4., 44., -22., 43., 44., 2.]),
tensor([ 82., 1., 86., 0., -22., 6., 4., 109., 106.]),
tensor([37., 0., 74., 0., 28., 6., 38., 46., 8.]),
tensor([ 55., 0., 82., -2., -20., 12., 26., 103., 76.]),
tensor([43., 0., 86., 7., 42., 0., 43., 45., 2.]),
tensor([54., -6., 90., -3., 54., -4., 36., 36., 0.]),
tensor([47., 2., 76., -1., 46., 0., 29., 29., 0.]),
tensor([45., 0., 83., 0., 46., 29., 37., 36., 0.]),
tensor([56., 0., 81., 0., -2., 26., 25., 83., 58.]),
tensor([ 43., 0., 83., 0., 42., -17., 39., 41., 2.]),
tensor([37., 0., 76., 0., 26., 15., 39., 50., 12.]),
tensor([37., 0., 80., 0., 10., -5., 43., 70., 26.]),
tensor([ 49., 0., 95., 0., 46., -15., 47., 49., 2.]),
tensor([55., 0., 95., -1., 52., -1., 40., 44., 4.]),
tensor([51., -5., 88., 0., 52., 0., 36., 36., 0.]),
tensor([56., 5., 78., 2., 44., 0., 22., 34., 12.]),
tensor([44., -5., 77., 0., 44., 0., 33., 33., 0.]),
tensor([ 80., 0., 84., 0., -36., 5., 4., 120., 116.]),
tensor([45., 0., 86., 0., 46., 21., 40., 39., 0.])]
however, printing self.db[contexts]
yields all strings:
['x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x']
@tomonodes Could you update your installation (from source) and try this:
from genrl.utils import StatlogDataBandit
bandit = StatlogDataBandit(download=True)
context = bandit.reset()
from genrl.agents import NeuralLinearPosteriorAgent
agent = NeuralLinearPosteriorAgent(bandit)
context = bandit.reset()
action = agent.select_action(context)
new_context, reward = bandit.step(action)
from genrl.trainers import DCBTrainer
trainer = DCBTrainer(agent, bandit)
trainer.train(timesteps=5000, batch_size=32)
This works for me.
Feel free to close the issue(if this works).
Installing with pip install genrl
installs genrl==0.0.2
which reproduces the bug on a clean conda env with python 3.6.
Building from source installs genrl==0.0.1
in which the above problem is solved 👍.