rlgraph / rlgraph

RLgraph: Modular computation graphs for deep reinforcement learning

Building policy with continuous action space throws error

janislavjankov opened this issue · comments

Here's a test to demonstrate this:

    def test_policy_for_continuous_action_space(self):
        # state_space (NN is a simple single fc-layer relu network (2 units), random biases, random weights).
        state_space = FloatBox(shape=(4,), add_batch_rank=True)

        # action_space (continuous, bounded in [-1.0, 1.0]).
        action_space = FloatBox(low=-1.0, high=1.0, add_batch_rank=True)

        policy = Policy(network_spec=config_from_path("configs/test_simple_nn.json"), action_space=action_space)
        test = ComponentTest(
            component=policy,
            input_spaces=dict(
                nn_input=state_space,
                actions=action_space,
                logits=FloatBox(shape=(2, ), add_batch_rank=True),
                probabilities=FloatBox(add_batch_rank=True)
            ),
            action_space=action_space
        )

        test.read_variable_values(policy.variables)

This test fails with:

self = <rlgraph.components.policies.policy.Policy object at 0x12ebb08d0>
key = '_T0_'
probabilities = <tf.Tensor 'policy/action-adapter-0/Squeeze:0' shape=(?,) dtype=float32>

    @graph_fn(flatten_ops=True, split_ops=True, add_auto_key_as_first_param=True)
    def _graph_fn_get_distribution_entropies(self, key, probabilities):
        """
        Pushes the given `probabilities` through all our distributions' `entropy` API-methods and returns a
        DataOpDict with the keys corresponding to our `action_space`.
    
        Args:
            probabilities (DataOp): The parameters to define a distribution.
    
        Returns:
            FlattenedDataOp: A DataOpDict with the different distributions' `entropy` outputs. Keys always correspond to
                structure of `self.action_space`.
        """
>       return self.distributions[key].entropy(probabilities)
E       KeyError: '_T0_'

I see that the ActionAdapter, in _graph_fn_get_probabilities_log_probs, returns tuples instead of single values (lines 213, 214): the parameters and the log-parameters, but not the probabilities. It also assumes the distribution is a Gaussian.
Btw, I would expect this to be a call to the underlying distribution, not to the action adapter.
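To make the failure mode concrete, here is a rough, framework-independent sketch (not rlgraph's actual flattening code; the `flatten` helper and its key naming are assumptions based on the '_T0_' key in the traceback above): when the adapter hands back a (mean, log-std) tuple, the flattener produces synthetic tuple keys such as '_T0_', while the policy's distributions dict is keyed by the flattened action-space structure, so the lookup fails.

    def flatten(op, key=""):
        # Flatten a possibly nested op into a dict; tuple slots get synthetic '_T{i}_' keys.
        if isinstance(op, tuple):
            flat = {}
            for i, item in enumerate(op):
                flat.update(flatten(item, key + "_T{}_".format(i)))
            return flat
        return {key: op}

    # The policy keeps one distribution per flattened action-space key (a single,
    # non-nested FloatBox flattens to the empty key "").
    distributions = {"": "<Normal distribution component>"}

    # What the ActionAdapter now returns: a (mean, log-std) tuple instead of a flat op.
    params = ("<mean tensor>", "<log-std tensor>")

    for key, op in flatten(params).items():
        if key not in distributions:
            print("KeyError:", repr(key))  # -> '_T0_' (and '_T1_'), neither in `distributions`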

Thanks for raising this! I will try looking into it tonight.

Ok, so we need to ensure the flattening/splitting does not go wrong when we get nested tuples (mean, std) instead of flat probabilities/logits. Hopefully we can do this this week. The low-tech solution would be to manually check the distribution type everywhere and read out the different args, but I hope we can avoid that.
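One possible direction, sketched very roughly below (an illustration of the idea only, not the actual rlgraph fix; the get_entropies helper and its argument handling are assumptions): flatten only down to the action-space keys and hand a tuple of parameters to the corresponding distribution as one unit, instead of flattening into it.

    def get_entropies(distributions, flat_params):
        # `flat_params` maps flattened action-space keys to either a single op
        # (e.g. Categorical logits) or a tuple of ops (e.g. Normal (mean, log-std)).
        entropies = {}
        for key, params in flat_params.items():
            # A tuple stays intact and parameterizes one distribution; it is not
            # split further into '_T0_'/'_T1_'-style keys.
            entropies[key] = distributions[key].entropy(params)
        return entropies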

Also, the method name (and the function itself) is a bit misleading; it doesn't really generalize between discrete and continuous action spaces.

Yes, the policy class is a bit messy and definitely needs an overhaul; it is a relic of ad hoc functions written under paper-deadline pressure that were never cleaned up afterwards.

Ok, we have fixed the continuous-action problems and added some API methods to the Policy class, mainly due to renaming "probabilities" to "parameters", which generalizes the interface to all kinds of distributions, not just categorical ones. You can still use the old API methods; you will just get a warning asking you to switch to the new names. We will remove the old ones in a few months or so.
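As an illustration of that compatibility behavior (a minimal sketch with hypothetical method names, not the actual Policy code), the old "probabilities"-style names can simply delegate to the new "parameters"-based methods and emit a warning:

    import warnings

    class Policy(object):
        def get_logits_parameters_log_probs(self, nn_input):
            # New "parameters"-based API method (name is illustrative only).
            ...

        def get_logits_probabilities_log_probs(self, nn_input):
            # Old name kept as an alias: still works, but warns about the rename.
            warnings.warn(
                "'probabilities' has been renamed to 'parameters'; please switch "
                "to the 'parameters'-based method.",
                DeprecationWarning
            )
            return self.get_logits_parameters_log_probs(nn_input)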
An example for continuous actions is the Pendulum-v0 test case on PPO here:
tests/agent_learning/short_tasks/test_ppo_agent_short_task_learning.py::test_ppo_on_continuous_action_environment, which remains to be tuned for actual learning.

The parameterization of the Normal and Beta distributions always happens along the last axis of the NN output tensor. For example, for a Normal distribution and the action space FloatBox(shape=(2,)) (2 actions), a single item (of a batch) of the NN output will be [1.0, 2.0, 0.5, 0.01], where the first two floats are the mean values of the 2 actions and the last two floats are the log-stddev values of the 2 actions.
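A small worked example of that layout (plain NumPy, with a hypothetical helper name, just to make the slicing explicit):

    import numpy as np

    def split_normal_parameters(nn_output, num_actions):
        # Split the last axis into the mean block and the log-stddev block.
        mean = nn_output[..., :num_actions]
        log_stddev = nn_output[..., num_actions:]
        return mean, np.exp(log_stddev)

    nn_output = np.array([[1.0, 2.0, 0.5, 0.01]])  # one batch item, 2 actions
    mean, stddev = split_normal_parameters(nn_output, num_actions=2)
    print(mean)    # [[1. 2.]]
    print(stddev)  # [[1.6487... 1.0100...]] (exp of the log-stddevs)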

I'm closing this issue now.

Thanks