Skip to content
Snippets Groups Projects
Commit ca6b0fe8 authored by Erik Nygren's avatar Erik Nygren
Browse files

Implemented performance enhancements by Adrian Egli

parent d2a49fc5
No related branches found
No related tags found
No related merge requests found
......@@ -25,17 +25,9 @@ def get_direction(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
def mirror(dir):
    """Return the direction opposite to `dir`.

    The result is ``(dir + 2) % 4``, i.e. a half-turn on a four-way compass.
    Assumes directions are encoded as integers 0..3 -- TODO confirm against
    Grid4TransitionsEnum.

    NOTE: the parameter name shadows the builtin ``dir``; it is kept as-is so
    callers passing it by keyword continue to work.
    """
    return (dir + 2) % 4
# (row, col) offsets for one step in each compass direction, in the order
# NORTH, EAST, SOUTH, WEST. Indexed directly by the movement value, so this
# assumes Grid4TransitionsEnum members are the integers 0..3 in that order
# -- TODO confirm against the enum definition.
MOVEMENT_ARRAY = [(-1, 0), (0, 1), (1, 0), (0, -1)]


def get_new_position(position, movement):
    """Convert a compass movement over a 2D grid into a new position (r, c).

    Parameters
    ----------
    position
        Indexable pair ``(row, col)`` of the current cell.
    movement
        Direction of travel, used as an index into ``MOVEMENT_ARRAY``
        (0 = NORTH, 1 = EAST, 2 = SOUTH, 3 = WEST).

    Returns
    -------
    tuple
        ``(row, col)`` of the neighbouring cell in the given direction.
    """
    # A single table lookup replaces the former if/elif chain over the four
    # enum members; both yield the same offsets for directions 0..3.
    d_row, d_col = MOVEMENT_ARRAY[movement]
    return (position[0] + d_row, position[1] + d_col)
def direction_to_point(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
......
......@@ -35,8 +35,6 @@ from flatland.envs import persistence
from flatland.envs.observations import GlobalObsForRailEnv
import pickle
m.patch()
......@@ -130,11 +128,11 @@ class RailEnv(Environment):
def __init__(self,
width,
height,
rail_generator = None,
schedule_generator = None, # : sched_gen.ScheduleGenerator = sched_gen.random_schedule_generator(),
rail_generator=None,
schedule_generator=None, # : sched_gen.ScheduleGenerator = sched_gen.random_schedule_generator(),
number_of_agents=1,
obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
malfunction_generator_and_process_data=None, #mal_gen.no_malfunction_generator(),
malfunction_generator_and_process_data=None, # mal_gen.no_malfunction_generator(),
remove_agents_at_target=True,
random_seed=1,
record_steps=False
......@@ -179,11 +177,11 @@ class RailEnv(Environment):
if malfunction_generator_and_process_data is None:
malfunction_generator_and_process_data = mal_gen.no_malfunction_generator()
self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
#self.rail_generator: RailGenerator = rail_generator
# self.rail_generator: RailGenerator = rail_generator
if rail_generator is None:
rail_generator = rail_gen.random_rail_generator()
self.rail_generator = rail_generator
#self.schedule_generator: ScheduleGenerator = schedule_generator
# self.schedule_generator: ScheduleGenerator = schedule_generator
if schedule_generator is None:
schedule_generator = sched_gen.random_schedule_generator()
self.schedule_generator = schedule_generator
......@@ -230,8 +228,8 @@ class RailEnv(Environment):
# save episode timesteps ie agent positions, orientations. (not yet actions / observations)
self.record_steps = record_steps # whether to save timesteps
# save timesteps in here: [[[row, col, dir, malfunction],...nAgents], ...nSteps]
self.cur_episode = []
self.list_actions = [] # save actions in here
self.cur_episode = []
self.list_actions = [] # save actions in here
def _seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
......@@ -264,8 +262,6 @@ class RailEnv(Environment):
agent.reset()
self.active_agents = [i for i in range(len(self.agents))]
def action_required(self, agent):
"""
Check if an agent needs to provide an action
......@@ -281,8 +277,8 @@ class RailEnv(Environment):
False: Agent cannot provide an action
"""
return (agent.status == RailAgentStatus.READY_TO_DEPART or (
agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0,
rtol=1e-03)))
agent.status == RailAgentStatus.ACTIVE and self.my_isclose(agent.speed_data['position_fraction'], 0.0,
rtol=1e-03)))
def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False,
random_seed: bool = None) -> (Dict, Dict):
......@@ -329,8 +325,6 @@ class RailEnv(Environment):
if optionals and 'distance_map' in optionals:
self.distance_map.set(optionals['distance_map'])
if regenerate_schedule or regenerate_rail or self.get_num_agents() == 0:
agents_hints = None
if optionals and 'agents_hints' in optionals:
......@@ -548,7 +542,7 @@ class RailEnv(Environment):
# Is the agent at the beginning of the cell? Then, it can take an action.
# As long as the agent is malfunctioning or stopped at the beginning of the cell,
# different actions may be taken!
if np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03):
if self.my_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03):
# No action has been supplied for this agent -> set DO_NOTHING as default
if action is None:
action = RailEnvActions.DO_NOTHING
......@@ -608,8 +602,8 @@ class RailEnv(Environment):
# transition_action_on_cellexit if the cell is free.
if agent.moving:
agent.speed_data['position_fraction'] += agent.speed_data['speed']
if agent.speed_data['position_fraction'] > 1.0 or np.isclose(agent.speed_data['position_fraction'], 1.0,
rtol=1e-03):
if agent.speed_data['position_fraction'] > 1.0 or self.my_isclose(agent.speed_data['position_fraction'], 1.0,
rtol=1e-03):
# Perform stored action to transition to the next cell as soon as cell is free
# Notice that we've already checked new_cell_valid and transition valid when we stored the action,
# so we only have to check cell_free now!
......@@ -740,7 +734,7 @@ class RailEnv(Environment):
pos = (int(agent.position[0]), int(agent.position[1]))
# print("pos:", pos, type(pos[0]))
list_agents_state.append(
[*pos, int(agent.direction), agent.malfunction_data["malfunction"] ])
[*pos, int(agent.direction), agent.malfunction_data["malfunction"]])
self.cur_episode.append(list_agents_state)
self.list_actions.append(dActions)
......@@ -809,7 +803,7 @@ class RailEnv(Environment):
------
Dict object
"""
#print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}")
# print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}")
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
return self.obs_dict
......@@ -828,8 +822,6 @@ class RailEnv(Environment):
"""
return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))
def _exp_distirbution_synced(self, rate: float) -> float:
"""
Generates sample from exponential distribution
......@@ -859,4 +851,5 @@ class RailEnv(Environment):
print("deprecated call to env.save() - pls call RailEnvPersister.save()")
persistence.RailEnvPersister.save(self, filename)
def my_isclose(self, x, y, rtol=1.e-5, atol=1.e-8):
    """Return whether `x` is within tolerance of the reference value `y`.

    The check is ``abs(x - y) <= atol + rtol * abs(y)``. It is not symmetric:
    the relative tolerance scales with `y` only, so `y` should be the
    reference value. ``self`` is not used.

    Parameters
    ----------
    x
        Value under test.
    y
        Reference value.
    rtol
        Relative tolerance, multiplied by ``abs(y)``.
    atol
        Absolute tolerance.

    Returns
    -------
    Result of the ``<=`` comparison (a bool for plain numbers).
    """
    return abs(x - y) <= atol + rtol * abs(y)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment