Skip to content
Snippets Groups Projects
Commit 603d67f6 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

fixed initial malfunction and tests

parent 6a3f528b
No related branches found
No related tags found
No related merge requests found
# In Flatland you can use custom observation builders and predictors
# Observation builders generate the observation needed by the controller
# Predictors can be used to do short-term prediction, which can help in avoiding conflicts in the network
import time
from flatland.envs.observations import GlobalObsForRailEnv
# First of all we import the Flatland rail environment
from flatland.envs.rail_env import RailEnv
......@@ -26,8 +28,8 @@ from flatland.utils.rendertools import RenderTool, AgentRenderVariant
# Here we use the sparse_rail_generator with the following parameters
width = 100 # Width of map
height = 100 # Height of ap
nr_trains = 10 # Number of trains that have an assigned task in the env
height = 100 # Height of map
nr_trains = 50 # Number of trains that have an assigned task in the env
cities_in_map = 20 # Number of cities where agents can start or end
seed = 14 # Random seed
grid_distribution_of_cities = False # Type of city distribution, if False cities are randomly placed
......@@ -151,14 +153,14 @@ for agent_idx, agent in enumerate(env.agents):
# If multiple agents want to enter the same cell at the same time the lower index agent will enter first.
# Let's check if there are any agents with the same start location
agents_with_same_start = []
agents_with_same_start = set()
print("\n The following agents have the same initial position:")
print("=====================================================")
for agent_idx, agent in enumerate(env.agents):
for agent_2_idx, agent2 in enumerate(env.agents):
if agent_idx != agent_2_idx and agent.initial_position == agent2.initial_position:
print("Agent {} as the same initial position as agent {}".format(agent_idx, agent_2_idx))
agents_with_same_start.append(agent_idx)
agents_with_same_start.add(agent_idx)
# Let's try to enter with all of these agents at the same time
action_dict = dict()
......@@ -246,8 +248,11 @@ for step in range(500):
# Environment step which returns the observations for all agents, their corresponding
# reward and whether they are done
start_time = time.time()
next_obs, all_rewards, done, _ = env.step(action_dict)
env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
end_time = time.time()
print(end_time - start_time)
# env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
frame_step += 1
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
......
......@@ -308,7 +308,9 @@ class RailEnv(Environment):
# A proportion of agents in the environment will receive a positive malfunction rate
if self.np_random.rand() < self.proportion_malfunctioning_trains:
agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate
next_breakdown = int(
self._exp_distirbution_synced(rate=agent.malfunction_data['malfunction_rate']))
agent.malfunction_data['next_malfunction'] = next_breakdown
agent.malfunction_data['malfunction'] = 0
initial_malfunction = self._agent_malfunction(i_agent)
......@@ -346,7 +348,7 @@ class RailEnv(Environment):
"""
agent = self.agents[i_agent]
# Decrease counter for next event only if agent is currently not broken
# Decrease counter for next event only if agent is currently not broken and agent has a malfunction rate
if agent.malfunction_data['malfunction_rate'] >= 1 and agent.malfunction_data['next_malfunction'] > 0 and \
agent.malfunction_data['malfunction'] < 1:
agent.malfunction_data['next_malfunction'] -= 1
......
......@@ -126,7 +126,7 @@ def test_malfunction_process():
env.agents[0].malfunction_data['nr_malfunctions'])
# Check that the expected number of stops was performed
assert agent_halts == 20
assert agent_halts == 21
# Check that malfunctioning data was standing around
assert total_down_time > 0
......@@ -155,16 +155,16 @@ def test_malfunction_process_statistically():
env.agents[0].target = (0, 0)
nb_malfunction = 0
agent_malfunction_list = [[6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0],
[6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1],
[6, 6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3],
[6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0],
[6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6],
[6, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6],
[6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2],
[6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5],
[6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0],
[6, 6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3]]
agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4],
[0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 6, 5, 4],
[6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0],
[6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3],
[0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5],
[0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1],
[6, 6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0]]
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
......@@ -175,6 +175,7 @@ def test_malfunction_process_statistically():
# agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction'])
assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step]
env.step(action_dict)
# print(agent_malfunction_list)
def test_malfunction_before_entry():
......@@ -230,14 +231,13 @@ def test_malfunction_before_entry():
assert env.agents[8].malfunction_data['malfunction'] == 2
assert env.agents[9].malfunction_data['malfunction'] == 2
#for a in range(env.get_num_agents()):
# for a in range(env.get_num_agents()):
# print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,
# env.agents[a].malfunction_data[
# 'malfunction']))
def test_initial_malfunction():
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 100, # Rate of malfunction occurrence
'min_duration': 2, # Minimal duration of malfunction
......@@ -410,7 +410,6 @@ def test_initial_malfunction_do_nothing():
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment