From 4d5624ef47b4377d686b63a2fa8b6802f0b9e00e Mon Sep 17 00:00:00 2001 From: Giacomo Spigler <spiglerg@gmail.com> Date: Wed, 19 Jun 2019 18:50:47 +0200 Subject: [PATCH] added speeds in generators --- flatland/envs/agent_utils.py | 6 ++++-- flatland/envs/generators.py | 10 +++++----- flatland/envs/rail_env.py | 4 ++-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 2d07eee..aa46aec 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -47,12 +47,14 @@ class EnvAgentStatic(object): self.speed_data = speed_data @classmethod - def from_lists(cls, positions, directions, targets): + def from_lists(cls, positions, directions, targets, speeds=None): """ Create a list of EnvAgentStatics from lists of positions, directions and targets """ speed_datas = [] for i in range(len(positions)): - speed_datas.append({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0}) + speed_datas.append({'position_fraction': 0.0, + 'speed': speeds[i] if speeds is not None else 1.0, + 'transition_action_on_cellexit': 0}) return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions), speed_datas))) def to_list(self): diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index f644bc1..085d6fd 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -18,7 +18,7 @@ def empty_rail_generator(): rail_array = grid_map.grid rail_array.fill(0) - return grid_map, [], [], [] + return grid_map, [], [], [], [] return generator @@ -139,7 +139,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= agents_target = [sg[1] for sg in start_goal[:num_agents]] agents_direction = start_dir[:num_agents] - return grid_map, agents_position, agents_direction, agents_target + return grid_map, agents_position, agents_direction, agents_target, [1.0]*len(agents_position) return generator @@ -183,7 +183,7 @@ def rail_from_manual_specifications_generator(rail_spec): rail, num_agents) - return rail, agents_position, agents_direction, agents_target + return rail, agents_position, agents_direction, agents_target, [1.0]*len(agents_position) return generator @@ -209,7 +209,7 @@ def rail_from_GridTransitionMap_generator(rail_map): rail_map, num_agents) - return rail_map, agents_position, agents_direction, agents_target + return rail_map, agents_position, agents_direction, agents_target, [1.0]*len(agents_position) return generator @@ -482,6 +482,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): return_rail, num_agents) - return return_rail, agents_position, agents_direction, agents_target + return return_rail, agents_position, agents_direction, agents_target, [1.0]*len(agents_position) return generator diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 2621308..58df3a1 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -151,7 +151,7 @@ class RailEnv(Environment): self.rail = tRailAgents[0] if replace_agents: - self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:4]) + self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5]) self.restart_agents() @@ -191,7 +191,7 @@ class RailEnv(Environment): # for i in range(len(self.agents_handles)): for iAgent in range(self.get_num_agents()): agent = self.agents[iAgent] - agent.speed_data['speed']=0.5 + print(agent.speed_data['speed']) if self.dones[iAgent]: # this agent has already completed... continue -- GitLab