diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py index 9dc1e587699c7e4e18523734bc831f068d504b9d..d96e8c5e42aa4c5808bf52809df9686c0fd8c7be 100644 --- a/flatland/evaluators/client.py +++ b/flatland/evaluators/client.py @@ -225,13 +225,15 @@ class FlatlandRemoteClient(object): obs_builder_object=obs_builder_object ) + time_start = time.time() local_observation, info = self.env.reset( regenerate_rail=False, regenerate_schedule=False, activate_agents=False, random_seed=random_seed ) - + time_diff = time.time() - time_start + self.update_running_mean_stats("internal_env_reset_time", time_diff) # Use the local observation # as the remote server uses a dummy observation builder return local_observation, info @@ -296,8 +298,7 @@ if __name__ == "__main__": _action[_idx] = np.random.randint(0, 5) return _action - my_observation_builder = TreeObsForRailEnv(max_depth=3, - predictor=ShortestPathPredictorForRailEnv()) + my_observation_builder = DummyObservationBuilder() episode = 0 obs = True @@ -318,7 +319,10 @@ if __name__ == "__main__": while True: action = my_controller(obs, remote_client.env) + time_start = time.time() observation, all_rewards, done, info = remote_client.env_step(action) + time_diff = time.time() - time_start + print("Step Time : ", time_diff) if done['__all__']: print("Current Episode : ", episode) print("Episode Done")