joschu / modular_rl

Implementation of TRPO and related algorithms

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Error when using saved weights to continue learning

abhinavrai44 opened this issue · comments

I am getting the following warning when I try to save the weights. Here I am loading the weights from a previously trained model.

{'warnflag': 1, 'task': 'STOP: TOTAL NO. of ITERATIONS EXCEEDS LIMIT', 'nit': 26, 'funcalls': 30}
got zero gradient. not updating

This is the code that I am using

 if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    update_argument_parser(parser, GENERAL_OPTIONS)
    parser.add_argument("--agent",required=True)
    parser.add_argument("--plot",action="store_true")
    parser.add_argument('--visualize', dest='visualize', action='store_true', default=False)
    args,_ = parser.parse_known_args([arg for arg in sys.argv[1:] if arg not in ('-h', '--help')])
    env = StandEnv(args.visualize)
    hdf = h5py.File('a.h5','r')
    snapnames = hdf['agent_snapshots'].keys()
    snapname = snapnames[-1]
    agent = cPickle.loads(hdf['agent_snapshots'][snapname].value)
    agent.stochastic=False
    env_spec = env.spec

    agent_ctor = get_agent_cls(args.agent)
    update_argument_parser(parser, agent_ctor.options)
    args = parser.parse_args()
    
    args.timestep_limit = 200
    cfg = args.__dict__
    np.random.seed(args.seed)
    if args.use_hdf:
        hdf, diagnostics = prepare_h5_file(args)
    gym.logger.setLevel(logging.WARN)

    COUNTER = 0
    def callback(stats):
        global COUNTER
        COUNTER += 1
        # Print stats
        print "*********** Iteration %i ****************" % COUNTER
        print tabulate(filter(lambda (k,v) : np.asarray(v).size==1, stats.items())) #pylint: disable=W0110
        # Store to hdf5
        if args.use_hdf:
            for (stat,val) in stats.items():
                if np.asarray(val).ndim==0:
                    diagnostics[stat].append(val)
                else:
                    assert val.ndim == 1
                    diagnostics[stat].extend(val)
            if args.snapshot_every and ((COUNTER % args.snapshot_every==0) or (COUNTER==args.n_iter)):
                hdf['/agent_snapshots/%0.4i'%COUNTER] = np.array(cPickle.dumps(agent,-1))
        # Plot
        if args.plot:
            animate_rollout(env, agent, min(500, args.timestep_limit))

    run_policy_gradient_algorithm(env, agent, callback=callback, usercfg = cfg)

    if args.use_hdf:
        hdf['env_id'] = env_spec.id
        try: hdf['env'] = np.array(cPickle.dumps(env, -1))
        except Exception: print "failed to pickle env" #pylint: disable=W0703
    env.close()