Skip to content
Snippets Groups Projects
Commit 38b49cf5 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front: Committed by spmohanty
Browse files

fixed first tests in malfunction test

parent 81819368
No related branches found
No related tags found
No related merge requests found
...@@ -37,7 +37,11 @@ env = RailEnv(width=100, ...@@ -37,7 +37,11 @@ env = RailEnv(width=100,
seed=14, # Random seed seed=14, # Random seed
grid_mode=False, grid_mode=False,
max_rails_between_cities=2, max_rails_between_cities=2,
<<<<<<< HEAD
max_rails_in_city=13, max_rails_in_city=13,
=======
max_rails_in_city=8,
>>>>>>> fixed first tests in malfunction test
), ),
schedule_generator=sparse_schedule_generator(speed_ration_map), schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=100, number_of_agents=100,
......
...@@ -55,6 +55,7 @@ class DistanceMap: ...@@ -55,6 +55,7 @@ class DistanceMap:
self.env_width = rail.width self.env_width = rail.width
def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap): def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
print("computing distance map")
self.agents_previous_computation = self.agents self.agents_previous_computation = self.agents
self.distance_map = np.inf * np.ones(shape=(len(agents), self.distance_map = np.inf * np.ones(shape=(len(agents),
self.env_height, self.env_height,
......
...@@ -82,7 +82,11 @@ def test_malfunction_process(): ...@@ -82,7 +82,11 @@ def test_malfunction_process():
obs_builder_object=SingleAgentNavigationObs() obs_builder_object=SingleAgentNavigationObs()
) )
# reset to initialize agents_static # reset to initialize agents_static
<<<<<<< HEAD
obs, info = env.reset(False, False, True, random_seed=10) obs, info = env.reset(False, False, True, random_seed=10)
=======
obs, info = env.reset(False, False, True, random_seed=0)
>>>>>>> fixed first tests in malfunction test
print(env.agents[0].malfunction_data) print(env.agents[0].malfunction_data)
# Check that a initial duration for malfunction was assigned # Check that a initial duration for malfunction was assigned
assert env.agents[0].malfunction_data['next_malfunction'] > 0 assert env.agents[0].malfunction_data['next_malfunction'] > 0
...@@ -151,7 +155,11 @@ def test_malfunction_process_statistically(): ...@@ -151,7 +155,11 @@ def test_malfunction_process_statistically():
obs_builder_object=SingleAgentNavigationObs() obs_builder_object=SingleAgentNavigationObs()
) )
# reset to initialize agents_static # reset to initialize agents_static
<<<<<<< HEAD
env.reset(True, True, False, random_seed=10) env.reset(True, True, False, random_seed=10)
=======
env.reset(False, False, False, random_seed=0)
>>>>>>> fixed first tests in malfunction test
env.agents[0].target = (0, 0) env.agents[0].target = (0, 0)
nb_malfunction = 0 nb_malfunction = 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment