SforAiDl / genrl

A PyTorch reinforcement learning library for generalizable and reproducible algorithm implementations with an aim to improve accessibility in RL

Home Page:https://genrl.readthedocs.io

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Errors in documentation for Adding a new Data Bandit

TMorville opened this issue · comments

Running the example code found here yields a shape error:

from typing import Tuple

import pandas as pd
import torch

from genrl.utils.data_bandits.base import DataBasedBandit
from genrl.utils.data_bandits.utils import download_data


# Default location the wine data file is downloaded from; used when the
# caller does not pass a `url` keyword argument to WineDataBandit.
URL = "http://archive.ics.uci.edu/ml/machine-learning-databases/wine/wine.data"

class WineDataBandit(DataBasedBandit):
    """Data-based contextual bandit built on the wine data file.

    The file is read as a header-less CSV. Column 0 is treated as the class
    label and every remaining column as a context feature. Each step presents
    the features of one row as the context; an action earns reward 1 when
    ``action + 1`` equals that row's label, and 0 otherwise.

    Keyword Args:
        path (str): Location of the data. Defaults to ``"./data/Wine/"``.
            When ``download`` is falsy this value is passed straight to
            ``pandas.read_csv``.
            NOTE(review): the default looks like a directory, so it
            presumably only works together with ``download=True`` - confirm.
        download (bool): If truthy, fetch the data with ``download_data`` and
            read the path it returns. Defaults to ``None``.
        force_download (bool): Forwarded to ``download_data``. Defaults to
            ``None``.
        url (str): Source URL forwarded to ``download_data``. Defaults to the
            module-level ``URL``.

    Any further keyword arguments are forwarded unchanged to
    ``DataBasedBandit.__init__``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        path = kwargs.get("path", "./data/Wine/")
        download = kwargs.get("download", None)
        force_download = kwargs.get("force_download", None)
        url = kwargs.get("url", URL)

        if download:
            path = download_data(path, url, force_download)

        self._df = pd.read_csv(path, header=None)
        # Column 0 is the label, so its distinct values define the action set
        # and every other column is a context feature.
        self.n_actions = len(self._df[0].unique())
        self.context_dim = self._df.shape[1] - 1
        self.len = len(self._df)

        # Diagnostic output of the derived problem dimensions.
        print(self.n_actions, self.context_dim, self.len)

    def reset(self) -> torch.Tensor:
        """Reset the bandit, reshuffle the rows and return the first context.

        Returns:
            torch.Tensor: Context of the row at the current index after the
            reset.
        """
        self._reset()
        # Rebind ``_df`` as well as ``df``: ``_compute_reward`` and
        # ``_get_context`` read ``self._df``, so assigning the shuffled frame
        # only to ``self.df`` (as before) left the row order unchanged.
        shuffled = self._df.sample(frac=1).reset_index(drop=True)
        self._df = shuffled
        self.df = shuffled
        return self._get_context()

    def _compute_reward(self, action: int) -> Tuple[int, int]:
        """Score ``action`` against the label of the current row.

        Args:
            action (int): Zero-based index of the chosen action.

        Returns:
            Tuple[int, int]: ``(reward, max_reward)`` where ``reward`` is 1 if
            ``action + 1`` equals the row's label and 0 otherwise, and
            ``max_reward`` is always 1.
        """
        label = self._df.iloc[self.idx, 0]
        # Actions are zero-based while the comparison target is ``action + 1``
        # (labels are presumably 1-based in the data file).
        r = int(label == (action + 1))
        return r, 1

    def _get_context(self) -> torch.Tensor:
        """Return the feature columns of the current row as a float tensor.

        Returns:
            torch.Tensor: 1-D tensor with ``context_dim`` entries, placed on
            ``self.device``.
        """
        # Skip column 0 (the label). Selecting column 0 here yielded a single
        # scalar instead of the feature vector and broke the example.
        return torch.tensor(
            self._df.iloc[self.idx, 1:].values,
            device=self.device,
            dtype=torch.float,
        )

The error occurs because self._df.iloc[self.idx, 0] selects only the label column, which is a single integer rather than a context vector. Changing the line to self._df.iloc[self.idx, 1:].values, so that every column except the target label is returned as the context, makes the example work.

Thanks for pointing this out!

Yep, this is a typo; it should be self._df.iloc[self.idx, 1:].values.

Would you be up for opening a PR for the same?