Skip to content
Snippets Groups Projects
Commit 4d5624ef authored by spiglerg's avatar spiglerg
Browse files

added speeds in generators

parent 5bf451eb
No related branches found
No related tags found
No related merge requests found
......@@ -47,12 +47,14 @@ class EnvAgentStatic(object):
self.speed_data = speed_data
@classmethod
def from_lists(cls, positions, directions, targets):
def from_lists(cls, positions, directions, targets, speeds=None):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets
"""
speed_datas = []
for i in range(len(positions)):
speed_datas.append({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})
speed_datas.append({'position_fraction': 0.0,
'speed': speeds[i] if speeds is not None else 1.0,
'transition_action_on_cellexit': 0})
return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions), speed_datas)))
def to_list(self):
......
......@@ -18,7 +18,7 @@ def empty_rail_generator():
rail_array = grid_map.grid
rail_array.fill(0)
return grid_map, [], [], []
return grid_map, [], [], [], []
return generator
......@@ -139,7 +139,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
agents_target = [sg[1] for sg in start_goal[:num_agents]]
agents_direction = start_dir[:num_agents]
return grid_map, agents_position, agents_direction, agents_target
return grid_map, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
return generator
......@@ -183,7 +183,7 @@ def rail_from_manual_specifications_generator(rail_spec):
rail,
num_agents)
return rail, agents_position, agents_direction, agents_target
return rail, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
return generator
......@@ -209,7 +209,7 @@ def rail_from_GridTransitionMap_generator(rail_map):
rail_map,
num_agents)
return rail_map, agents_position, agents_direction, agents_target
return rail_map, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
return generator
......@@ -482,6 +482,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
return_rail,
num_agents)
return return_rail, agents_position, agents_direction, agents_target
return return_rail, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
return generator
......@@ -151,7 +151,7 @@ class RailEnv(Environment):
self.rail = tRailAgents[0]
if replace_agents:
self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:4])
self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])
self.restart_agents()
......@@ -191,7 +191,7 @@ class RailEnv(Environment):
# for i in range(len(self.agents_handles)):
for iAgent in range(self.get_num_agents()):
agent = self.agents[iAgent]
agent.speed_data['speed']=0.5
print(agent.speed_data['speed'])
if self.dones[iAgent]: # this agent has already completed...
continue
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment