Compatibility with TF-Agents
fede72bari opened this issue · comments
All Gym/Gymnasium standard environments are compatible with TensorFlow RL agents, but when I tried to use TF-Agents with anytrading I got errors because some required methods and attributes seem to be missing. For instance,
this code, which works for CartPole and other Gym environments,
# Create the registered 'stocks-v0' environment with its default arguments; the
# custom keyword arguments are left commented out at the end of the line.
env = gym.make('stocks-v0') #, df=df_features, frame_bound=(5,100), window_size=5)
def my_process_data(env):
    """Extract the price series and the feature matrix for the env's frame.

    Rows are taken from ``env.df`` starting at
    ``env.frame_bound[0] - env.window_size`` (inclusive) and ending at
    ``env.frame_bound[1]`` (exclusive).

    Returns:
        A ``(prices, signal_features)`` tuple: the sliced ``'Close'`` column
        and the sliced block of feature columns, each produced by
        ``to_numpy()``.
    """
    feature_columns = [
        'volume', 'volume_delta', 'day_of_the_week', 'US_market', 'EU_market', 'AS_market',
        'POC_volume', 'DPOC_buy_volume', 'DPOC_sell_volume',
        'short_CVD', 'impulsive_CVD', 'avg_short_CVD', 'middle_CVD', 'middle_impulsive_CVD', 'avg_middle_CVD',
        'close_delta', 'open_delta', 'high_delta', 'low_delta',
        'close_%_change', 'open_%_change', 'high_%_change', 'low_%_change',
        'H_shadow', 'L_shadow',
        'BBANDS_short_3devdown', 'BBANDS_short_dev_delta',
        'BBANDS_short_dev_delta_rolling_mean_7', 'BBANDS_short_dev_delta_rolling_mean_14',
        'BBANDS_middle_dev_delta',
        'BBANDS_middle_dev_delta_rolling_mean_7', 'BBANDS_middle_dev_delta_rolling_mean_14',
        'RSI_14', 'RSI_7', 'ATR',
        'HT_sine', 'HT_leadsine', 'HT_sine_over_leadsine',
        'mfdelta', 'mfsign', 'SAR_UP_DWN', 'SAR_UP_DWN_STRENGHT',
        'short_CVD_over_CVD_mean', 'middle_CVD_over_CVD_mean',
        'VWAP_over_POC_price', 'VWAP_over_DPOC_buy_price', 'VWAP_over_DPOC_sell_price',
        'POC_price_over_DPOC_buy_price', 'POC_price_over_DPOC_sell_price',
        'DPOC_buy_price_over_DPOC_sell_price',
        'short_CVD_over_CVD_mean-1', 'middle_CVD_over_CVD_mean-1',
        'VWAP_over_POC_price-1', 'VWAP_over_DPOC_buy_price-1', 'VWAP_over_DPOC_sell_price-1',
        'POC_price_over_DPOC_buy_price-1', 'POC_price_over_DPOC_sell_price-1',
        'DPOC_buy_price_over_DPOC_sell_price-1',
        'short_CVD_over_CVD_mean-2', 'middle_CVD_over_CVD_mean-2',
        'VWAP_over_POC_price-2', 'VWAP_over_DPOC_buy_price-2', 'VWAP_over_DPOC_sell_price-2',
        'POC_price_over_DPOC_buy_price-2', 'POC_price_over_DPOC_sell_price-2',
        'DPOC_buy_price_over_DPOC_sell_price-2',
        'short_CVD_over_CVD_mean-3', 'middle_CVD_over_CVD_mean-3',
        'VWAP_over_POC_price-3', 'VWAP_over_DPOC_buy_price-3', 'VWAP_over_DPOC_sell_price-3',
        'POC_price_over_DPOC_buy_price-3', 'POC_price_over_DPOC_sell_price-3',
        'DPOC_buy_price_over_DPOC_sell_price-3',
        'short_CVD_over_CVD_mean-4', 'middle_CVD_over_CVD_mean-4',
        'VWAP_over_POC_price-4', 'VWAP_over_DPOC_buy_price-4', 'VWAP_over_DPOC_sell_price-4',
        'POC_price_over_DPOC_buy_price-4', 'POC_price_over_DPOC_sell_price-4',
        'DPOC_buy_price_over_DPOC_sell_price-4',
        'high_%_change-1', 'low_%_change-1', 'open_%_change-1', 'close_%_change-1',
        'high_%_change-2', 'low_%_change-2', 'open_%_change-2', 'close_%_change-2',
        'high_%_change-3', 'low_%_change-3', 'open_%_change-3', 'close_%_change-3',
        'high_%_change-4', 'low_%_change-4', 'open_%_change-4', 'close_%_change-4',
        'high_%_change-5', 'low_%_change-5', 'open_%_change-5', 'close_%_change-5',
        'impulsive_CVD_derivate', 'short_CVD_derivate', 'avg_short_CVD_derivate',
        'middle_CVD_derivate', 'avg_middle_CVD_derivate',
        'VWAP_derivate', 'volume_derivate', 'RSI_14_derivate', 'RSI_7_derivate', 'mama_derivate',
        'impulsive_CVD_increasing', 'short_CVD_increasing', 'middle_CVD_increasing',
        'avg_short_CVD_increasing', 'avg_middle_CVD_increasing',
        'VWAP_increasing', 'RSI_14_increasing', 'RSI_7_increasing',
        'doji', 'Engulfing', 'Marubozu',
        'short_CVD_min_delayed', 'short_CVD_max_delayed',
        'middle_CVD_min', 'middle_CVD_min_delayed', 'middle_CVD_max', 'middle_CVD_max_delayed',
        'impulsive_CVD_min_delayed', 'impulsive_CVD_max_delayed',
        'impulsive_CVD_derivate_min_delayed', 'impulsive_CVD_derivate_max_delayed',
        'avg_short_CVD_min_delayed', 'avg_short_CVD_max_delayed',
        'BBANDS_short_dev_delta_min_delayed', 'BBANDS_short_dev_delta_max_delayed',
        'BBANDS_middle_dev_delta_min_delayed', 'BBANDS_middle_dev_delta_max_delayed',
        'RSI_14_min_delayed', 'RSI_14_max_delayed', 'RSI_7_min_delayed', 'RSI_7_max_delayed',
        'zero_crossing_short_CVD', 'zero_crossing_impulsive_CVD',
        'zero_crossing_impulsive_CVD_derivate', 'zero_crossing_avg_short_CVD',
        'H_shadow_mean_4', 'H_shadow_mean_7', 'H_shadow_mean_21',
        'L_shadow_mean_4', 'L_shadow_mean_7', 'L_shadow_mean_21',
        'high_delta_mean_4', 'high_shadow_mean_7', 'high_shadow_mean_21',
        'low_shadow_mean_4', 'low_shadow_mean_7', 'low_shadow_mean_21',
        'VWAP_mean_4_p_derivate', 'VWAP_mean_7_p_derivate', 'VWAP_mean_21_p_derivate',
        'VWAP_mean_63_p_derivate', 'VWAP_mean_189_p_derivate',
        'VWAP_mean_4_vs_7_delta', 'VWAP_mean_4_vs_21_delta', 'VWAP_mean_4_vs_63_delta',
        'VWAP_mean_4_vs_189_delta', 'VWAP_mean_7_vs_21_delta', 'VWAP_mean_7_vs_63_delta',
        'VWAP_mean_7_vs_189_delta', 'VWAP_mean_21_vs_63_delta', 'VWAP_mean_21_vs_189_delta',
        'VWAP_mean_63_vs_189_delta',
        'close_vs_VWAP_4', 'close_vs_VWAP_7', 'close_vs_VWAP_21', 'close_vs_VWAP_63', 'close_vs_VWAP_189',
        'open_vs_VWAP_4', 'open_vs_VWAP_7', 'open_vs_VWAP_21', 'open_vs_VWAP_63', 'open_vs_VWAP_189',
        'low_vs_VWAP_4', 'low_vs_VWAP_7', 'low_vs_VWAP_21', 'low_vs_VWAP_63', 'low_vs_VWAP_189',
        'high_vs_VWAP_4', 'high_vs_VWAP_7', 'high_vs_VWAP_21', 'high_vs_VWAP_63', 'high_vs_VWAP_189',
        'zigzag_trigger', 'CO/HL', 'CO_sign', 'HL_delta', 'CO_delta',
        'CO_delta_ratio_-1', 'CO_delta_ratio_-2', 'CO_delta_ratio_-3',
        'Bar_relative_POC_price', 'Bar_relative_VWAP_price',
        'US_central_time_seconds', 'cumulative_volume_mean_delta', 'HL_rolling_fw_mean',
    ]

    # The frame starts one window before the lower bound.
    first_row = env.frame_bound[0] - env.window_size
    last_row = env.frame_bound[1]
    frame = slice(first_row, last_row)

    close_prices = env.df.loc[:, 'Close'].to_numpy()
    feature_matrix = env.df.loc[:, feature_columns].to_numpy()
    return close_prices[frame], feature_matrix[frame]
# StocksEnv subclass whose `_process_data` attribute is replaced by
# my_process_data; when called on an instance, the instance is passed as `env`.
class MyForexEnv(StocksEnv):
_process_data = my_process_data
# frame_bound[0] equals window_size here, so the `start` index computed in
# my_process_data (frame_bound[0] - window_size) is 0.
env = MyForexEnv(df=df_features, window_size=512, frame_bound=(512, len(df_features)))
[....]
from tf_agents.environments import utils
# NOTE(review): validate_py_environment is annotated to take a
# py_environment.PyEnvironment (see the traceback below), while `env` here is
# the Gym-style MyForexEnv instance; this call is where the AttributeError
# shown below is raised.
utils.validate_py_environment(env, episodes=5)
env = tf_py_environment.TFPyEnvironment(env)
returns this error
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[154], line 3
1 features_n = df_features.shape[1]-4
2 from tf_agents.environments import utils
----> 3 utils.validate_py_environment(env, episodes=5)
5 env = tf_py_environment.TFPyEnvironment(env)
7 q_net = q_network.QNetwork(env.observation_spec(),
8 env.action_spec(),
9 fc_layer_params = (features_n, features_n*10, features_n/8),
10 activation_fn = activation_function,
11 dropout_layer_params = None)
File ~\AppData\Roaming\Python\Python310\site-packages\tf_agents\environments\utils.py:58, in validate_py_environment(environment, episodes, observation_and_action_constraint_splitter)
52 def validate_py_environment(
53 environment: py_environment.PyEnvironment,
54 episodes: int = 5,
55 observation_and_action_constraint_splitter: Optional[
56 types.Splitter] = None):
57 """Validates the environment follows the defined specs."""
---> 58 time_step_spec = environment.time_step_spec()
59 action_spec = environment.action_spec()
61 random_policy = random_py_policy.RandomPyPolicy(
62 time_step_spec=time_step_spec,
63 action_spec=action_spec,
64 observation_and_action_constraint_splitter=(
65 observation_and_action_constraint_splitter))
AttributeError: 'MyForexEnv' object has no attribute 'time_step_spec'
Is it possible to use an anytrading environment with TF-Agents?
Hi @fede72bari,
Try this please:
# Build the Gym environment first, then wrap it with TF-Agents' GymWrapper.
# Presumably the wrapper supplies the PyEnvironment interface (e.g.
# time_step_spec) that the raw Gym env lacks; confirm in the TF-Agents docs.
gym_env = gym.make('stocks-v0')
env = tf_agents.environments.gym_wrapper.GymWrapper(gym_env)