keras-rl / keras-rl

Deep Reinforcement Learning for Keras.

Home Page: http://keras-rl.readthedocs.io/

Numpy data wrangling in rl/callbacks.py

jan-gebauer opened this issue · comments

commented

Hi, I ran into an issue while calling fit() on a SARSAAgent. It occurs in rl/callbacks.py, in the on_step_begin function of class TrainIntervalLogger. Specifically, the assert on metrics.shape fails because of the array's data type (it ends up comparing tuples of the form (x,) and (y, z)), and np.isnan(metrics).all() then fails as well, because NumPy cannot safely handle metrics with that data type.

I've fixed it for myself, but I was wondering whether I should open an issue and submit a pull request. My fix is probably not well done, but I'm sure some Python guru can refactor it in a flash.

I'm using Python 3.8.7 and NumPy 1.19.5. My keras-rl checkout should be the most recent one; I pulled it the other day (02.02.2021?). The following script reproduces the issue.

import numpy as np

# Ragged nested list; mimics the self.metrics variable
ml = [[], [1, 2, 3, 4], [1, 2]]

na = np.array(ml)

foo = 3
try:
    # Same kind of shape check as the assert in on_step_begin
    assert na.shape == (na.shape[0], foo)
except AssertionError:
    print("assert did not pass")

try:
    print(np.isnan(na).all())
except TypeError:
    print("isnan did not pass")

My suggestion for a quick fix:

metrics = np.array(self.metrics)
assert metrics.shape[0] == self.interval
formatted_metrics = ''
nan_check = False
for element in metrics:
    if np.isnan(element).all():
        nan_check = True
if not nan_check:  # not all values are means

The nan_check would then have to be repeated for the next code block as well, since it contains "np.isnan(infos).all()", which also doesn't work.

I could probably improve the for loop if that were required.
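One possible way to tidy up the loop is sketched below. It converts each row to a float array on its own, so NumPy never has to build an array from the ragged list. The helper name all_nan is hypothetical and not part of keras-rl. Note that the loop in the quick fix sets nan_check as soon as a single row is entirely NaN (an empty row also qualifies), whereas np.isnan(metrics).all() asks whether every value is NaN. This sketch assumes the second reading is the intended one.

```python
import numpy as np


def all_nan(rows):
    """Return True if every value in a possibly ragged sequence of rows is NaN.

    Hypothetical helper, not part of keras-rl. Empty rows contain no values,
    so they never make the result False.
    """
    for row in rows:
        # Convert each row separately to avoid creating a ragged object array
        values = np.asarray(row, dtype=float)
        if not np.isnan(values).all():
            return False
    return True


# Ragged input like the one in the reproduction script
print(all_nan([[], [1, 2, 3, 4], [1, 2]]))        # False
print(all_nan([[np.nan, np.nan], [], [np.nan]]))  # True
```

Under these assumptions, the same helper could be applied to both self.metrics and infos, for example "if not all_nan(self.metrics):", instead of repeating the loop in each code block.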

commented

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.