How to obtain a single prediction from a trained model?
Karlheinzniebuhr opened this issue
I'm trying to obtain a single prediction from a trained model for production.
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import VecNormalize

# Load the VecNormalize object (envs is the vectorized environment created earlier)
vec_normalize = VecNormalize.load("vecnormalize.pkl", envs)
# Load the agent
loaded_model = A2C.load("model_multi_vec.zip")
# Pass the VecNormalize object to the loaded agent
loaded_model.set_env(vec_normalize)
# Get the window of data to be used as an observation
single_obs = gymdf.iloc[-10:].values
# Normalize the observation using the loaded VecNormalize object
single_obs = vec_normalize.normalize_obs(single_obs)
# Perform the prediction
action, _states = loaded_model.predict(single_obs)
action, _states
But I'm getting this error:
vec_normalize.py:212, in VecNormalize.normalize_obs(self, obs)
210 obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32)
211 else:
--> 212 obs_ = self._normalize_obs(obs, self.obs_rms).astype(np.float32)
213 return obs_
vec_normalize.py:188, in VecNormalize._normalize_obs(self, obs, obs_rms)
181 def _normalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray:
182 """
183 Helper to normalize observation.
184 :param obs:
185 :param obs_rms: associated statistics
186 :return: normalized observation
187 """
--> 188 return np.clip((obs - obs_rms.mean) / np.sqrt(obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs)
ValueError: operands could not be broadcast together with shapes (10,6) (10,2)
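The failing line is a plain elementwise NumPy expression, so the error can be reproduced without the model. The following is a minimal sketch, not the library's code: the statistics are placeholder zeros and ones with the (10, 2) shape from the error message, and `check_obs_shape` is a hypothetical helper added here to fail early with a clearer message.

```python
import numpy as np


def normalize_obs(obs, mean, var, epsilon=1e-8, clip_obs=10.0):
    """Same arithmetic as the _normalize_obs line shown in the traceback."""
    return np.clip((obs - mean) / np.sqrt(var + epsilon), -clip_obs, clip_obs)


def check_obs_shape(obs, mean):
    """Hypothetical helper: compare shapes before normalizing."""
    if obs.shape != mean.shape:
        raise ValueError(
            f"observation shape {obs.shape} does not match "
            f"the shape of the saved statistics {mean.shape}"
        )


# Placeholder statistics (assumed values) with the shape from the error: (10, 2)
mean = np.zeros((10, 2))
var = np.ones((10, 2))

# A window with 6 columns cannot be broadcast against (10, 2) statistics
bad_obs = np.ones((10, 6))
try:
    normalize_obs(bad_obs, mean, var)
    broadcast_failed = False
except ValueError:
    broadcast_failed = True

# A window with the same shape as the statistics normalizes fine
good_obs = np.ones((10, 2))
check_obs_shape(good_obs, mean)
normalized = normalize_obs(good_obs, mean, var)
```

The second shape in the message, (10, 2), is the shape of the running statistics, so it shows what observation shape the environment produced during training.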
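Reading the source code, I noticed that the environment doesn't actually use OHLCV by default, only Close and its differences. That explains why the shapes in the error differ: my window had 6 columns, but the environment's observations had only 2. To get it working, I had to add the other columns myself, like this: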
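def normalize_data(env):
    # Same row range the environment uses, including the leading window
    start = env.frame_bound[0] - env.window_size
    end = env.frame_bound[1]
    prices = env.df.loc[:, 'Close'].to_numpy()[start:end]
    # Use Open, High, Low and Close as features instead of Close only
    signal_features = env.df.loc[:, ['Open', 'High', 'Low', 'Close']].to_numpy()[start:end]
    # Price change between consecutive rows; 0 for the first row
    diff = np.insert(np.diff(prices), 0, 0)
    signal_features = np.column_stack((signal_features, diff))
    return prices, signal_features

class MyForexEnv(ForexEnv):
    # Override the default data processing with the function above
    _process_data = normalize_data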
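With this override the environment emits five features per row (Open, High, Low, Close, and the Close difference), so a single observation built for prediction needs the same five columns in the same order. The following is a rough sketch of one way to build it. `build_single_obs` is a hypothetical helper, and the input array is made-up data standing in for the real OHLC columns.

```python
import numpy as np


def build_single_obs(ohlc, window_size):
    """Build one observation window with the same columns as normalize_data.

    ohlc: array of shape (n_rows, 4), columns Open, High, Low, Close.
    Returns an array of shape (window_size, 5).
    """
    close = ohlc[:, 3]
    # Difference is computed over all rows first, then the window is sliced,
    # so the first row of the window keeps its real difference
    diff = np.insert(np.diff(close), 0, 0)
    features = np.column_stack((ohlc, diff))
    return features[-window_size:].astype(np.float32)


# Made-up data: 12 rows of Open, High, Low, Close
example_ohlc = np.arange(48, dtype=np.float64).reshape(12, 4)
single_obs = build_single_obs(example_ohlc, window_size=10)
```

Assuming the model and the VecNormalize statistics were retrained with `MyForexEnv`, a window built this way should have the shape that `normalize_obs` and `predict` expect.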