Skip to content
Snippets Groups Projects
Commit 9ce6b221 authored by Erik Nygren's avatar Erik Nygren
Browse files

using new level generator for training and inference

parent b8e41dc1
No related branches found
No related tags found
No related merge requests found
......@@ -9,47 +9,59 @@ from predictors.predictions import ShortestPathPredictorForRailEnv
import torch_training.Nets
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
from flatland.envs.rail_generators import rail_from_file, sparse_rail_generator
from flatland.envs.schedule_generators import schedule_from_file, sparse_schedule_generator
from flatland.utils.rendertools import RenderTool
from torch_training.dueling_double_dqn import Agent
from utils.observation_utils import normalize_observation
random.seed(3)
np.random.seed(2)
tree_depth = 3
observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv(10))
file_name = "./railway/simple_avoid.pkl"
env = RailEnv(width=10,
height=20,
rail_generator=rail_from_file(file_name),
schedule_generator=schedule_from_file(file_name),
obs_builder_object=observation_helper)
x_dim = env.width
y_dim = env.height
"""
x_dim = 10 # np.random.randint(8, 20)
y_dim = 10 # np.random.randint(8, 20)
n_agents = 5 # np.random.randint(3, 8)
n_goals = n_agents + np.random.randint(0, 3)
min_dist = int(0.75 * min(x_dim, y_dim))
# Parameters for the Environment
x_dim = 20
y_dim = 20
n_agents = 5
tree_depth = 2
# Use the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence
'min_duration': 3, # Minimal duration of malfunction
'max_duration': 20 # Max duration of malfunction
}
# Custom observation builder
predictor = ShortestPathPredictorForRailEnv()
observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor)
# Different agent types (trains) with different speeds.
speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
env = RailEnv(width=x_dim,
height=y_dim,
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist,
max_dist=99999,
seed=0),
schedule_generator=complex_schedule_generator(),
obs_builder_object=observation_helper,
number_of_agents=n_agents)
rail_generator=sparse_rail_generator(num_cities=5,
# Number of cities in map (where train stations are)
num_intersections=4,
# Number of intersections (no start / target)
num_trainstations=10, # Number of possible start/targets on map
min_node_dist=3, # Minimal distance of nodes
node_radius=2, # Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities/intersections
seed=15, # Random seed
grid_mode=True,
enhance_intersection=False
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=n_agents,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=observation_helper)
env.reset(True, True)
"""
env_renderer = RenderTool(env, gl="PILSVG", )
handle = env.get_agent_handles()
num_features_per_node = env.obs_builder.observation_dim
......
......@@ -14,9 +14,9 @@ import torch_training.Nets
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.rail_generators import sparse_rail_generator
# Import Flatland/ Observations and Predictors
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from torch_training.dueling_double_dqn import Agent
from utils.observation_utils import normalize_observation
......@@ -36,30 +36,55 @@ def main(argv):
np.random.seed(1)
# Initialize a random map with a random number of agents
x_dim = np.random.randint(8, 15)
y_dim = np.random.randint(8, 15)
n_agents = np.random.randint(3, 8)
n_goals = n_agents + np.random.randint(0, 3)
min_dist = int(0.75 * min(x_dim, y_dim))
tree_depth = 3
print("main2")
"""
Get an observation builder and predictor:
The predictor will always predict the shortest path from the current location of the agent.
This is used to warn for potential conflicts --> Should be enhanced to get better performance!
"""
# Parameters for the Environment
x_dim = 20
y_dim = 20
n_agents = 5
tree_depth = 2
# Use the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence
'min_duration': 3, # Minimal duration of malfunction
'max_duration': 20 # Max duration of malfunction
}
# Custom observation builder
predictor = ShortestPathPredictorForRailEnv()
observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor)
# Different agent types (trains) with different speeds.
speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
env = RailEnv(width=x_dim,
height=y_dim,
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist,
max_dist=99999,
seed=0),
schedule_generator=complex_schedule_generator(),
obs_builder_object=observation_helper,
number_of_agents=n_agents)
rail_generator=sparse_rail_generator(num_cities=5,
# Number of cities in map (where train stations are)
num_intersections=4,
# Number of intersections (no start / target)
num_trainstations=10, # Number of possible start/targets on map
min_node_dist=3, # Minimal distance of nodes
node_radius=2, # Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities/intersections
seed=15, # Random seed
grid_mode=True,
enhance_intersection=False
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=n_agents,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=observation_helper)
env.reset(True, True)
handle = env.get_agent_handles()
......@@ -105,19 +130,26 @@ def main(argv):
and the size of the levels every 50 episodes.
"""
if episodes % 50 == 1:
x_dim = np.random.randint(8, 15)
y_dim = np.random.randint(8, 15)
n_agents = np.random.randint(3, 8)
n_goals = n_agents + np.random.randint(0, 3)
min_dist = int(0.75 * min(x_dim, y_dim))
env = RailEnv(width=x_dim,
height=y_dim,
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist,
max_dist=99999,
seed=0),
schedule_generator=complex_schedule_generator(),
obs_builder_object=observation_helper,
number_of_agents=n_agents)
rail_generator=sparse_rail_generator(num_cities=5,
# Number of cities in map (where train stations are)
num_intersections=4,
# Number of intersections (no start / target)
num_trainstations=10,
# Number of possible start/targets on map
min_node_dist=3, # Minimal distance of nodes
node_radius=2, # Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities/intersections
seed=15, # Random seed
grid_mode=True,
enhance_intersection=False
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=n_agents,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=observation_helper)
# Adjust the parameters according to the new env.
max_steps = int((env.height + env.width))
......
......@@ -8,6 +8,76 @@ from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.rail_env import RailEnvActions
class DummyPredictorForRailEnv(PredictionBuilder):
    """
    DummyPredictorForRailEnv object.

    This object returns predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment. At every step it takes the first
    action that is valid for the agent, trying MOVE_FORWARD, then MOVE_LEFT, then MOVE_RIGHT.
    """

    def get(self, custom_args=None, handle=None):
        """
        Called whenever get_many in the observation build is called.

        Parameters
        -------
        custom_args: dict
            Not used in this dummy implementation.
        handle : int (optional)
            Handle of the agent for which to compute the prediction. When None (the default),
            predictions are computed for all agents of the environment.

        Returns
        -------
        dict
            Returns a dictionary indexed by the agent handle and for each agent an array of
            (max_depth + 1)x5 elements:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at 0 is the current position, direction etc.

        Raises
        -------
        Exception
            If an agent that has not reached its target has no valid action left.
        """
        agents = self.env.agents
        # Compare against None explicitly: handle 0 is a valid agent handle but is falsy.
        if handle is not None:
            agents = [self.env.agents[handle]]

        # Actions are tried in this order; the first valid one is taken.
        action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
        prediction_dict = {}

        for agent in agents:
            _agent_initial_position = agent.position
            _agent_initial_direction = agent.direction
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]

            # The simulation below moves the real agent object step by step, so its state
            # must be restored afterwards, even when no further move is possible and we raise.
            try:
                for index in range(1, self.max_depth + 1):
                    # if we're at the target, stop moving...
                    if agent.position == agent.target:
                        prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                        continue

                    for action in action_priorities:
                        _, new_cell_is_valid, new_direction, new_position, transition_is_valid = \
                            self.env._check_action_on_agent(action, agent)
                        if new_cell_is_valid and transition_is_valid:
                            # move and change direction to face the new_direction that was
                            # performed
                            agent.position = new_position
                            agent.direction = new_direction
                            prediction[index] = [index, *new_position, new_direction, action]
                            break
                    else:
                        # None of the prioritized actions was valid for this agent.
                        raise Exception("Cannot move further. Something is wrong")
            finally:
                agent.position = _agent_initial_position
                agent.direction = _agent_initial_direction

            prediction_dict[agent.handle] = prediction

        return prediction_dict
class ShortestPathPredictorForRailEnv(PredictionBuilder):
"""
ShortestPathPredictorForRailEnv object.
......@@ -16,7 +86,8 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def __init__(self, max_depth):
def __init__(self, max_depth=20):
# Initialize with depth 20
self.max_depth = max_depth
def get(self, custom_args=None, handle=None):
......@@ -53,10 +124,13 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
for agent in agents:
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
agent_speed = agent.speed_data["speed"]
times_per_cell = int(np.reciprocal(agent_speed))
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
new_direction = _agent_initial_direction
new_position = _agent_initial_position
visited = set()
for index in range(1, self.max_depth + 1):
# if we're at the target, stop moving...
if agent.position == agent.target:
......@@ -70,12 +144,10 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
# Take shortest possible path
cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
new_position = None
new_direction = None
if np.sum(cell_transitions) == 1:
if np.sum(cell_transitions) == 1 and index % times_per_cell == 0:
new_direction = np.argmax(cell_transitions)
new_position = get_new_position(agent.position, new_direction)
elif np.sum(cell_transitions) > 1:
elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0:
min_dist = np.inf
no_dist_found = True
for direction in range(4):
......@@ -87,7 +159,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
new_direction = direction
no_dist_found = False
new_position = get_new_position(agent.position, new_direction)
else:
elif index % times_per_cell == 0:
raise Exception("No transition possible {}".format(cell_transitions))
# update the agent's position and direction
......
......@@ -36,9 +36,6 @@ def main(argv):
n_goals = 5
min_dist = 5
# We are training an Agent using the Tree Observation with depth 2
observation_builder = TreeObsForRailEnv(max_depth=2)
# Use the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment