Commit 2709f085 authored by nimishsantosh107's avatar nimishsantosh107
Browse files

evaluator updated to work with statemachine

parent 3e0b3e87
Pipeline #8517 failed with stages
in 6 minutes and 4 seconds
...@@ -12,7 +12,7 @@ import numpy as np ...@@ -12,7 +12,7 @@ import numpy as np
import redis import redis
import flatland import flatland
from flatland.envs.malfunction_generators import malfunction_from_file from flatland.envs.malfunction_generators import FileMalfunctionGen
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file from flatland.envs.rail_generators import rail_from_file
from flatland.envs.line_generators import line_from_file from flatland.envs.line_generators import line_from_file
...@@ -267,7 +267,7 @@ class FlatlandRemoteClient(object): ...@@ -267,7 +267,7 @@ class FlatlandRemoteClient(object):
self.current_env_path = test_env_file_path self.current_env_path = test_env_file_path
self.env = RailEnv(width=1, height=1, rail_generator=rail_from_file(test_env_file_path), self.env = RailEnv(width=1, height=1, rail_generator=rail_from_file(test_env_file_path),
line_generator=line_from_file(test_env_file_path), line_generator=line_from_file(test_env_file_path),
malfunction_generator_and_process_data=malfunction_from_file(test_env_file_path), malfunction_generator=FileMalfunctionGen(test_env_file_path),
obs_builder_object=obs_builder_object) obs_builder_object=obs_builder_object)
time_start = time.time() time_start = time.time()
...@@ -276,7 +276,6 @@ class FlatlandRemoteClient(object): ...@@ -276,7 +276,6 @@ class FlatlandRemoteClient(object):
local_observation, info = self.env.reset( local_observation, info = self.env.reset(
regenerate_rail=True, regenerate_rail=True,
regenerate_schedule=True, regenerate_schedule=True,
activate_agents=False,
random_seed=random_seed random_seed=random_seed
) )
time_diff = time.time() - time_start time_diff = time.time() - time_start
......
...@@ -21,16 +21,10 @@ import redis ...@@ -21,16 +21,10 @@ import redis
import timeout_decorator import timeout_decorator
import flatland import flatland
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.envs.step_utils.states import TrainState from flatland.envs.step_utils.states import TrainState
from flatland.envs.malfunction_generators import malfunction_from_file
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.line_generators import line_from_file
from flatland.evaluators import aicrowd_helpers from flatland.evaluators import aicrowd_helpers
from flatland.evaluators import messages from flatland.evaluators import messages
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
from flatland.envs.rail_env_utils import load_flatland_environment_from_file
from flatland.envs.persistence import RailEnvPersister from flatland.envs.persistence import RailEnvPersister
use_signals_in_timeout = True use_signals_in_timeout = True
...@@ -755,7 +749,6 @@ class FlatlandRemoteEvaluationService: ...@@ -755,7 +749,6 @@ class FlatlandRemoteEvaluationService:
_observation, _info = self.env.reset( _observation, _info = self.env.reset(
regenerate_rail=True, regenerate_rail=True,
regenerate_schedule=True, regenerate_schedule=True,
activate_agents=False,
random_seed=RANDOM_SEED random_seed=RANDOM_SEED
) )
...@@ -1131,7 +1124,7 @@ class FlatlandRemoteEvaluationService: ...@@ -1131,7 +1124,7 @@ class FlatlandRemoteEvaluationService:
# and then we compute the mean across each of the test_id groups # and then we compute the mean across each of the test_id groups
################################################################################# #################################################################################
################################################################################# #################################################################################
source_df = self.evaluation_metadata_df.dropna() source_df = self.evaluation_metadata_df
# grouped_df = source_df.groupby(['test_id']).mean() # grouped_df = source_df.groupby(['test_id']).mean()
mean_reward = source_df["reward"].mean() mean_reward = source_df["reward"].mean()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment