Skip to content
Snippets Groups Projects
Commit 1ff4d33b authored by Erik Nygren's avatar Erik Nygren
Browse files

included more of adrians performance updates

parent c741a0f1
No related branches found
No related tags found
No related merge requests found
......@@ -8,7 +8,6 @@ from typing import List, NamedTuple, Optional, Dict, Tuple
import msgpack_numpy as m
import numpy as np
from gym.utils import seeding
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
......@@ -28,14 +27,43 @@ from flatland.envs import agent_chains as ac
from flatland.envs.observations import GlobalObsForRailEnv
from gym.utils import seeding
# Direct import of objects / classes does not work with circular imports.
# from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData
# from flatland.envs.observations import GlobalObsForRailEnv
# from flatland.envs.rail_generators import random_rail_generator, RailGenerator
# from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
import pickle
m.patch()
# Adrian Egli performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
return (a < (b + rtol)) or (a < (b - rtol))
def fast_clip(position: (int, int), min_value: (int, int), max_value: (int, int)) -> bool:
return (
max(min_value[0], min(position[0], max_value[0])),
max(min_value[1], min(position[1], max_value[1]))
)
def fast_argmax(possible_transitions: (int, int, int, int)) -> bool:
if possible_transitions[0] == 1:
return 0
if possible_transitions[1] == 1:
return 1
if possible_transitions[2] == 1:
return 2
return 3
def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
def fast_count_nonzero(possible_transitions: (int, int, int, int)):
return possible_transitions[0]+possible_transitions[1]+possible_transitions[2]+possible_transitions[3]
class RailEnvActions(IntEnum):
DO_NOTHING = 0 # implies change of direction in a dead-end!
MOVE_LEFT = 1
......@@ -289,8 +317,8 @@ class RailEnv(Environment):
False: Agent cannot provide an action
"""
return (agent.status == RailAgentStatus.READY_TO_DEPART or (
agent.status == RailAgentStatus.ACTIVE and self.my_isclose(agent.speed_data['position_fraction'], 0.0,
rtol=1e-03)))
agent.status == RailAgentStatus.ACTIVE and fast_isclose(agent.speed_data['position_fraction'], 0.0,
rtol=1e-03)))
def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False,
random_seed: bool = None) -> Tuple[Dict, Dict]:
......@@ -441,7 +469,6 @@ class RailEnv(Environment):
"""
#malfunction: Malfunction = self.malfunction_generator(agent, self.np_random)
if "generate" in dir(self.malfunction_generator):
malfunction: mal_gen.Malfunction = self.malfunction_generator.generate(agent, self.np_random)
else:
......@@ -567,8 +594,6 @@ class RailEnv(Environment):
return self._get_observations(), self.rewards_dict, self.dones, info_dict
def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None):
"""
Performs a step and step, start and stop penalty on a single agent in the following sub steps:
......@@ -615,7 +640,7 @@ class RailEnv(Environment):
# Is the agent at the beginning of the cell? Then, it can take an action.
# As long as the agent is malfunctioning or stopped at the beginning of the cell,
# different actions may be taken!
if self.my_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03):
if fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03):
# No action has been supplied for this agent -> set DO_NOTHING as default
if action is None:
action = RailEnvActions.DO_NOTHING
......@@ -675,8 +700,8 @@ class RailEnv(Environment):
# transition_action_on_cellexit if the cell is free.
if agent.moving:
agent.speed_data['position_fraction'] += agent.speed_data['speed']
if agent.speed_data['position_fraction'] > 1.0 or self.my_isclose(agent.speed_data['position_fraction'], 1.0,
rtol=1e-03):
if agent.speed_data['position_fraction'] > 1.0 or fast_isclose(agent.speed_data['position_fraction'], 1.0,
rtol=1e-03):
# Perform stored action to transition to the next cell as soon as cell is free
# Notice that we've already checked new_cell_valid and transition valid when we stored the action,
# so we only have to check cell_free now!
......@@ -694,7 +719,6 @@ class RailEnv(Environment):
agent.direction = new_direction
agent.speed_data['position_fraction'] = 0.0
# has the agent reached its target?
if np.equal(agent.position, agent.target).all():
agent.status = RailAgentStatus.DONE
......@@ -926,7 +950,6 @@ class RailEnv(Environment):
self.agent_positions[agent.position] = -1
if self.remove_agents_at_target:
agent.position = None
agent.old_position = None
agent.status = RailAgentStatus.DONE_REMOVED
def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
......@@ -953,11 +976,11 @@ class RailEnv(Environment):
new_position = get_new_position(agent.position, new_direction)
new_cell_valid = (
np.array_equal( # Check the new position is still in the grid
new_position,
np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
and # check the new position has some transitions (ie is not an empty cell)
self.rail.get_full_transitions(*new_position) > 0)
fast_position_equal( # Check the new position is still in the grid
new_position,
fast_clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
and # check the new position has some transitions (ie is not an empty cell)
self.rail.get_full_transitions(*new_position) > 0)
# If transition validity hasn't been checked yet.
if transition_valid is None:
......@@ -1027,7 +1050,7 @@ class RailEnv(Environment):
"""
transition_valid = None
possible_transitions = self.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
num_transitions = fast_count_nonzero(possible_transitions)
new_direction = agent.direction
if action == RailEnvActions.MOVE_LEFT:
......@@ -1046,7 +1069,7 @@ class RailEnv(Environment):
# - dead-end, straight line or curved line;
# new_direction will be the only valid transition
# - take only available transition
new_direction = np.argmax(possible_transitions)
new_direction = fast_argmax(possible_transitions)
transition_valid = True
return new_direction, transition_valid
......@@ -1105,6 +1128,3 @@ class RailEnv(Environment):
def save(self, filename):
print("deprecated call to env.save() - pls call RailEnvPersister.save()")
persistence.RailEnvPersister.save(self, filename)
def my_isclose(self, x, y, rtol=1.e-5, atol=1.e-8):
return abs(x - y) <= atol + rtol * abs(y)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment