Skip to content
Snippets Groups Projects
Commit 4d5624ef authored by spiglerg's avatar spiglerg
Browse files

added speeds in generators

parent 5bf451eb
No related branches found
No related tags found
No related merge requests found
...@@ -47,12 +47,14 @@ class EnvAgentStatic(object): ...@@ -47,12 +47,14 @@ class EnvAgentStatic(object):
self.speed_data = speed_data self.speed_data = speed_data
@classmethod @classmethod
def from_lists(cls, positions, directions, targets): def from_lists(cls, positions, directions, targets, speeds=None):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets """ Create a list of EnvAgentStatics from lists of positions, directions and targets
""" """
speed_datas = [] speed_datas = []
for i in range(len(positions)): for i in range(len(positions)):
speed_datas.append({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0}) speed_datas.append({'position_fraction': 0.0,
'speed': speeds[i] if speeds is not None else 1.0,
'transition_action_on_cellexit': 0})
return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions), speed_datas))) return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions), speed_datas)))
def to_list(self): def to_list(self):
......
...@@ -18,7 +18,7 @@ def empty_rail_generator(): ...@@ -18,7 +18,7 @@ def empty_rail_generator():
rail_array = grid_map.grid rail_array = grid_map.grid
rail_array.fill(0) rail_array.fill(0)
return grid_map, [], [], [] return grid_map, [], [], [], []
return generator return generator
...@@ -139,7 +139,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= ...@@ -139,7 +139,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
agents_target = [sg[1] for sg in start_goal[:num_agents]] agents_target = [sg[1] for sg in start_goal[:num_agents]]
agents_direction = start_dir[:num_agents] agents_direction = start_dir[:num_agents]
return grid_map, agents_position, agents_direction, agents_target return grid_map, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
return generator return generator
...@@ -183,7 +183,7 @@ def rail_from_manual_specifications_generator(rail_spec): ...@@ -183,7 +183,7 @@ def rail_from_manual_specifications_generator(rail_spec):
rail, rail,
num_agents) num_agents)
return rail, agents_position, agents_direction, agents_target return rail, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
return generator return generator
...@@ -209,7 +209,7 @@ def rail_from_GridTransitionMap_generator(rail_map): ...@@ -209,7 +209,7 @@ def rail_from_GridTransitionMap_generator(rail_map):
rail_map, rail_map,
num_agents) num_agents)
return rail_map, agents_position, agents_direction, agents_target return rail_map, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
return generator return generator
...@@ -482,6 +482,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): ...@@ -482,6 +482,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
return_rail, return_rail,
num_agents) num_agents)
return return_rail, agents_position, agents_direction, agents_target return return_rail, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
return generator return generator
...@@ -151,7 +151,7 @@ class RailEnv(Environment): ...@@ -151,7 +151,7 @@ class RailEnv(Environment):
self.rail = tRailAgents[0] self.rail = tRailAgents[0]
if replace_agents: if replace_agents:
self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:4]) self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])
self.restart_agents() self.restart_agents()
...@@ -191,7 +191,7 @@ class RailEnv(Environment): ...@@ -191,7 +191,7 @@ class RailEnv(Environment):
# for i in range(len(self.agents_handles)): # for i in range(len(self.agents_handles)):
for iAgent in range(self.get_num_agents()): for iAgent in range(self.get_num_agents()):
agent = self.agents[iAgent] agent = self.agents[iAgent]
agent.speed_data['speed']=0.5 print(agent.speed_data['speed'])
if self.dones[iAgent]: # this agent has already completed... if self.dones[iAgent]: # this agent has already completed...
continue continue
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment