Commit a165ac00 authored by Erik Nygren's avatar Erik Nygren 🚅
Browse files

fixed first tests in malfunction test

parent ff18baa1
Pipeline #2397 failed with stages
in 14 minutes and 26 seconds
......@@ -37,7 +37,7 @@ env = RailEnv(width=100,
seed=14, # Random seed
grid_mode=False,
max_rails_between_cities=2,
max_rails_in_city=6,
max_rails_in_city=8,
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=100,
......
......@@ -55,6 +55,7 @@ class DistanceMap:
self.env_width = rail.width
def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
print("computing distance map")
self.agents_previous_computation = self.agents
self.distance_map = np.inf * np.ones(shape=(len(agents),
self.env_height,
......
......@@ -9,7 +9,6 @@ from typing import List, NamedTuple, Optional, Dict
import msgpack
import msgpack_numpy as m
import numpy as np
from gym.utils import seeding
from flatland.core.env import Environment
......@@ -187,7 +186,7 @@ class RailEnv(Environment):
self.distance_map = DistanceMap(self.agents, self.height, self.width)
self.action_space = [1]
# Stochastic train malfunctioning parameters
if stochastic_data is not None:
prop_malfunction = stochastic_data['prop_malfunction']
......@@ -466,7 +465,7 @@ class RailEnv(Environment):
return
# Is the agent at the beginning of the cell? Then, it can take an action.
# As long as the agent is malfunctioning or stopped at the beginning of the cell,
# As long as the agent is malfunctioning or stopped at the beginning of the cell,
# different actions may be taken!
if agent.speed_data['position_fraction'] == 0.0:
# No action has been supplied for this agent -> set DO_NOTHING as default
......
......@@ -70,14 +70,6 @@ def test_malfunction_process():
'malfunction_rate': 1000,
'min_duration': 3,
'max_duration': 3}
# random.seed(0)
# np.random.seed(0)
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 70, # Rate of malfunction occurence
'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction
}
rail, rail_map = make_simple_rail2()
......@@ -90,8 +82,8 @@ def test_malfunction_process():
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
obs = env.reset(False, False, True, random_seed=0)
obs, info = env.reset(False, False, True, random_seed=0)
print(env.agents[0].malfunction_data)
# Check that a initial duration for malfunction was assigned
assert env.agents[0].malfunction_data['next_malfunction'] > 0
for agent in env.agents:
......@@ -100,6 +92,9 @@ def test_malfunction_process():
agent_halts = 0
total_down_time = 0
agent_old_position = env.agents[0].position
# Move target to unreachable position in order to not interfere with test
env.agents[0].target = (0, 0)
for step in range(100):
actions = {}
......@@ -157,6 +152,7 @@ def test_malfunction_process_statistically():
)
# reset to initialize agents_static
env.reset(False, False, False, random_seed=0)
env.agents[0].target = (0, 0)
nb_malfunction = 0
for step in range(20):
......@@ -166,7 +162,6 @@ def test_malfunction_process_statistically():
action_dict[agent.handle] = RailEnvActions(np.random.randint(4))
env.step(action_dict)
# check that generation of malfunctions works as expected
assert env.agents[0].malfunction_data["nr_malfunctions"] == 4
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment