rlgraph / rlgraph

RLgraph: Modular computation graphs for deep reinforcement learning

Inconsistency of the record key name in PyTorch mode

yangysc opened this issue · comments

In ring_buffer.py, the record is sometimes (in the init stage) stored under a key of the form '/terminals' (with a leading '/'), while at other times (in the run stage) the keys of the stored data have been unflattened by the define_by_run_unflatten function, so there is no '/'. This causes an error, because self.terminal_key in the line below (line 146 of ring_buffer.py) always assumes there is a '/':
num_records = get_batch_size(records[self.terminal_key])
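
The mismatch looks roughly like this (a minimal sketch with made-up values, not actual RLgraph internals):

terminal_key = "/terminals"                     # key form stored when the memory is built

records_init = {"/terminals": [False, False]}   # init stage: keys carry the leading '/'
records_run = {"terminals": [False, False]}     # run stage: keys were unflattened, no '/'

print(records_init[terminal_key])               # works: [False, False]
print(records_run[terminal_key])                # raises KeyError: '/terminals'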

Why do we need to flatten it?

Hey,
could you specify 'sometimes' in more detail? When running the unit test ('test_ring_buffer'), I do not see any errors.

The reason we flatten things is to be able to iterate across all the keys of a single flat dict in case the record has a deeply nested structure (nested states/actions). We want to flatten so we can iterate over state/state1, state/state2, etc. within one dict. When we return the record, we unflatten it again so the final returned record has the original nested structure.
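
For illustration, this is roughly what that looks like (simplified helpers, not RLgraph's actual define_by_run_flatten/define_by_run_unflatten implementations):

def flatten_record(record, prefix=""):
    # Flatten a nested dict into a single dict with '/'-prefixed, '/'-joined keys.
    flat = {}
    for key, value in record.items():
        new_key = prefix + "/" + key
        if isinstance(value, dict):
            flat.update(flatten_record(value, new_key))
        else:
            flat[new_key] = value
    return flat

def unflatten_record(flat):
    # Rebuild the original nested structure from the '/'-joined keys.
    nested = {}
    for key, value in flat.items():
        parts = key.lstrip("/").split("/")
        node = nested
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return nested

record = {"states": {"state1": [0.1, 0.2], "state2": [0.3]}, "terminals": [False]}
flat = flatten_record(record)
# {'/states/state1': [0.1, 0.2], '/states/state2': [0.3], '/terminals': [False]}
assert unflatten_record(flat) == record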

Sorry for not mentioning that clearly. Yes, in fact I also ran 'test_ring_buffer.py' and there was no problem. But when you run ppo_cartpole.py in the examples folder with PyTorch as the backend, the error occurs.

/home/noone/anaconda3/envs/lab/bin/python /home/noone/Documents/New_torch/rlgraph-master/examples/ppo_cartpole.py
pygame 1.9.6
Hello from the pygame community. https://www.pygame.org/contribute.html
19-07-16 10:21:07:INFO:Parsed state space definition: Floatbox((4,) <class 'numpy.float32'> )
19-07-16 10:21:07:INFO:Parsed action space definition: Intbox(() <class 'numpy.int32'> )
19-07-16 10:21:07:INFO:No preprocessing required.
Expose API sync from synchronizable to policy code:
@rlgraph_api(component=self, must_be_complete=False, ok_to_overwrite=False)
def sync(self, values_):
	return getattr(self.sub_components['synchronizable'], 'sync')(values_)

Expose API sync from synchronizable to value-function code:
@rlgraph_api(component=self, must_be_complete=False, ok_to_overwrite=False)
def sync(self, values_):
	return getattr(self.sub_components['synchronizable'], 'sync')(values_)

19-07-16 10:21:07:INFO:Execution spec is: {'seed': 15, 'mode': 'single', 'distributed_spec': None, 'disable_monitoring': False, 'gpu_spec': {'gpus_enabled': False, 'max_usable_gpus': 0, 'cuda_devices': None}, 'device_strategy': 'default', 'default_device': None, 'device_map': {}, 'torch_num_threads': 1, 'OMP_NUM_THREADS': 1}
19-07-16 10:21:07:INFO:Meta-graph build completed in 0.14640023600077257 s.
19-07-16 10:21:07:INFO:Meta-graph op-records generated: 297
19-07-16 10:21:11:INFO:Define-by-run computation-graph build completed in 3.9417111479851883 s (38 iterations).
19-07-16 10:21:11:INFO:Initialized single-threaded executor with 1 environments 'OpenAIGym(<TimeLimit<CartPoleEnv<CartPole-v0>>>)' and Agent 'PPOAgent()'
Starting workload, this will take some time for the agents to build.
Traceback (most recent call last):
  File "/home/noone/Documents/New_torch/rlgraph-master/examples/ppo_cartpole.py", line 88, in <module>
    main(sys.argv)
  File "/home/noone/Documents/New_torch/rlgraph-master/examples/ppo_cartpole.py", line 80, in main
    worker.execute_timesteps(10000, use_exploration=True)
  File "/home/noone/Documents/New_torch/rlgraph-master/rlgraph/execution/single_threaded_worker.py", line 112, in execute_timesteps
    reset=reset
  File "/home/noone/Documents/New_torch/rlgraph-master/rlgraph/execution/single_threaded_worker.py", line 329, in _execute
    episode_terminals[i]
  File "/home/noone/Documents/New_torch/rlgraph-master/rlgraph/execution/single_threaded_worker.py", line 405, in _observe
    terminals=terminals, env_id=env_ids
  File "/home/noone/Documents/New_torch/rlgraph-master/rlgraph/agents/agent.py", line 486, in observe
    terminals=np.asarray(self.terminals_buffer[env_id])
  File "/home/noone/Documents/New_torch/rlgraph-master/rlgraph/agents/ppo_agent.py", line 566, in _observe_graph
    self.graph_executor.execute(("insert_records", [preprocessed_states, actions, rewards, terminals]))
  File "/home/noone/Documents/New_torch/rlgraph-master/rlgraph/graphs/pytorch_executor.py", line 96, in execute
    api_ret = self.graph_builder.execute_define_by_run_op(api_method, tensor_params)
  File "/home/noone/Documents/New_torch/rlgraph-master/rlgraph/graphs/graph_builder.py", line 951, in execute_define_by_run_op
    return self.root_component.api_fn_by_name[api_method](self.root_component, *params)
  File "/home/noone/Documents/New_torch/rlgraph-master/rlgraph/utils/decorators.py", line 105, in api_method_wrapper
    output = wrapped_func(self, *args, **kwargs)
  File "/home/noone/Documents/New_torch/rlgraph-master/rlgraph/agents/ppo_agent.py", line 252, in insert_records
    return agent.memory.insert_records(records)
  File "/home/noone/Documents/New_torch/rlgraph-master/rlgraph/utils/decorators.py", line 105, in api_method_wrapper
    output = wrapped_func(self, *args, **kwargs)
  File "/home/noone/Documents/New_torch/rlgraph-master/rlgraph/components/memories/ring_buffer.py", line 146, in _graph_fn_insert_records
    num_records = get_batch_size(records[self.terminal_key])
KeyError: '/terminals'

Process finished with exit code 1

The content of the ppo_cartpole.py file is:

# Copyright 2018/2019 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
Example script for training a Proximal policy optimization agent on an OpenAI gym environment.

Usage:

python ppo_cartpole.py [--config configs/ppo_cartpole.json] [--env CartPole-v0]


"""

import json
import os
import sys

import numpy as np
from absl import flags

from rlgraph.agents import Agent
from rlgraph.environments import OpenAIGymEnv
from rlgraph.execution import SingleThreadedWorker

FLAGS = flags.FLAGS

flags.DEFINE_string('config', './configs/ppo_cartpole.json', 'Agent config file.')
flags.DEFINE_string('env', 'CartPole-v0', 'gym environment ID.')


def main(argv):
    try:
        FLAGS(argv)
    except flags.Error as e:
        print('%s\nUsage: %s ARGS\n%s' % (e, sys.argv[0], FLAGS))

    agent_config_path = os.path.join(os.getcwd(), FLAGS.config)
    with open(agent_config_path, 'rt') as fp:
        agent_config = json.load(fp)

    env = OpenAIGymEnv.from_spec({
        "type": "openai",
        "gym_env": FLAGS.env
    })

    agent = Agent.from_spec(
        agent_config,
        state_space=env.state_space,
        action_space=env.action_space
    )
    episode_returns = []

    def episode_finished_callback(episode_return, duration, timesteps, **kwargs):
        episode_returns.append(episode_return)
        if len(episode_returns) % 10 == 0:
            print("Episode {} finished: reward={:.2f}, average reward={:.2f}.".format(
                len(episode_returns), episode_return, np.mean(episode_returns[-10:])
            ))

    worker = SingleThreadedWorker(env_spec=lambda: env, agent=agent, render=False, worker_executes_preprocessing=False,
                                  episode_finish_callback=episode_finished_callback)
    print("Starting workload, this will take some time for the agents to build.")

    # Use exploration is true for training, false for evaluation.
    worker.execute_timesteps(10000, use_exploration=True)

    print("Mean reward: {:.2f} / over the last 10 episodes: {:.2f}".format(
        np.mean(episode_returns), np.mean(episode_returns[-10:])
    ))


if __name__ == '__main__':
    main(sys.argv)

Then I debugged the ppo_cartpole.py. I found that when running this command
agent = Agent.from_spec( agent_config, state_space=env.state_space, action_space=env.action_space )

The constructed ring_buffer has a key name with a '/'. At this point, in _graph_fn_insert_records(self, records) of ring_buffer.py, the line num_records = get_batch_size(records[self.terminal_key]) does not raise an error, because the keys of records contain a '/' and self.terminal_key also contains a '/'.

But when you run worker.execute_timesteps(10000, use_exploration=True), the same line in _graph_fn_insert_records(self, records) of ring_buffer.py raises an error, because the keys of records no longer contain a '/' while self.terminal_key still does.

I think there may be an inconsistency in whether the records are flattened or not.
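
Just to illustrate, a lookup that tolerates both key forms would avoid the error (a workaround sketch only, not a proper fix):

def get_terminals(records, terminal_key="/terminals"):
    # Accept both the flattened ('/terminals') and unflattened ('terminals') key forms.
    if terminal_key in records:
        return records[terminal_key]
    return records[terminal_key.lstrip("/")]

print(get_terminals({"/terminals": [False, True]}))  # init-stage form -> [False, True]
print(get_terminals({"terminals": [False, True]}))   # run-stage form  -> [False, True]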

Hope this helps.

Hey, this may be due to a recent change regarding flattening in the agent. I will try to follow up, but cannot before the weekend.

Hey, this should be fixed now. The issue was that, due to the way the function is called (a graph function wrapped with an API decorator), the automatic flattening/unflattening was not happening.
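
For anyone reading along, the expected behavior is roughly the following (a simplified conceptual sketch, not RLgraph's actual decorator code): the wrapper flattens dict inputs into '/'-prefixed keys before the wrapped graph function runs, which is the step that was being skipped here.

import functools

def with_auto_flattening(graph_fn):
    # Conceptual sketch only (not the actual @rlgraph_api / graph_fn machinery):
    # flatten dict inputs into '/'-prefixed keys before the wrapped function runs.
    @functools.wraps(graph_fn)
    def wrapper(self, records):
        flat_records = {"/" + key: value for key, value in records.items()}  # one-level flatten
        return graph_fn(self, flat_records)
    return wrapper

class FakeRingBuffer:
    # Hypothetical stand-in for the memory component, just for illustration.
    terminal_key = "/terminals"

    @with_auto_flattening
    def _graph_fn_insert_records(self, records):
        # With the flattening applied by the wrapper, this lookup succeeds again.
        return len(records[self.terminal_key])

print(FakeRingBuffer()._graph_fn_insert_records({"terminals": [False, True]}))  # -> 2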

I will close this now, but let us know if the issue persists.