Skip to content
Snippets Groups Projects
Commit 9d28be8f authored by hagrid67's avatar hagrid67
Browse files

manually merging Adrian's changes (made by Erik) from master

parent be3c58b0
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
import networkx as nx import networkx as nx
import numpy as np import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple from typing import List, Tuple
import graphviz as gv import graphviz as gv
...@@ -372,18 +371,11 @@ def test_agent_following(): ...@@ -372,18 +371,11 @@ def test_agent_following():
for v in lvCells ] for v in lvCells ]
dPos = dict(zip(lvCells, lvCells)) dPos = dict(zip(lvCells, lvCells))
#plt.ion()
nx.draw(omc.G, nx.draw(omc.G,
with_labels=True, arrowsize=20, with_labels=True, arrowsize=20,
pos=dPos, pos=dPos,
node_color = lColours) node_color = lColours)
#plt.pause(20)
#plt.show()
def main(): def main():
test_agent_following() test_agent_following()
......
...@@ -4,13 +4,10 @@ Definition of the RailEnv environment. ...@@ -4,13 +4,10 @@ Definition of the RailEnv environment.
import random import random
# TODO: _ this is a global method --> utils or remove later # TODO: _ this is a global method --> utils or remove later
from enum import IntEnum from enum import IntEnum
from typing import List, NamedTuple, Optional, Dict from typing import List, NamedTuple, Optional, Dict, Tuple
import msgpack
import msgpack_numpy as m
import numpy as np import numpy as np
from gym.utils import seeding
from msgpack import Packer
from flatland.core.env import Environment from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.env_observation_builder import ObservationBuilder
...@@ -28,21 +25,50 @@ from flatland.envs import schedule_generators as sched_gen ...@@ -28,21 +25,50 @@ from flatland.envs import schedule_generators as sched_gen
from flatland.envs import persistence from flatland.envs import persistence
from flatland.envs import agent_chains as ac from flatland.envs import agent_chains as ac
from flatland.envs.observations import GlobalObsForRailEnv
from gym.utils import seeding
# Direct import of objects / classes does not work with circular imports. # Direct import of objects / classes does not work with circular imports.
# from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData # from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData
# from flatland.envs.observations import GlobalObsForRailEnv # from flatland.envs.observations import GlobalObsForRailEnv
# from flatland.envs.rail_generators import random_rail_generator, RailGenerator # from flatland.envs.rail_generators import random_rail_generator, RailGenerator
# from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator # from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
from flatland.envs.observations import GlobalObsForRailEnv
# import debugpy
import pickle
m.patch() m.patch()
# Adrian Egli performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
    """Return True when scalar ``a`` lies strictly within ``rtol`` of ``b``.

    Lightweight scalar replacement for the ``np.isclose`` calls in this
    module.

    Note: despite its name, ``rtol`` is applied here as an *absolute*
    tolerance around ``b`` (it is not scaled by ``abs(b)``).

    :param a: value under test
    :param b: reference value
    :param rtol: half-width of the acceptance window around ``b``
    :return: True if ``b - rtol < a < b + rtol``, False otherwise
    """
    # Both bounds are required: checking only the upper bound would report
    # every value below ``b`` (however far away) as "close".
    return (b - rtol) < a < (b + rtol)
def fast_clip(position: (int, int), min_value: (int, int), max_value: (int, int)) -> (int, int):
    """Clamp a 2-component position component-wise into ``[min_value, max_value]``.

    Scalar replacement for ``np.clip`` on a pair of coordinates.

    :param position: pair to clamp; only indices 0 and 1 are read
    :param min_value: per-component lower bounds
    :param max_value: per-component upper bounds
    :return: a new 2-tuple with each component limited to its bounds
    """
    # Apply the upper bound first, then the lower bound, per component.
    first = max(min_value[0], min(position[0], max_value[0]))
    second = max(min_value[1], min(position[1], max_value[1]))
    return (first, second)
def fast_argmax(possible_transitions: (int, int, int, int)) -> int:
    """Return the index of the first entry equal to 1 among four transitions.

    Scalar replacement for ``np.argmax`` on a 4-element transition tuple.

    :param possible_transitions: 4-element sequence of transition flags
    :return: the lowest index in 0..2 whose entry equals 1, otherwise 3

    Note: entry 3 itself is never inspected, so an input with no entry
    equal to 1 also yields 3.
    """
    for index in range(3):
        if possible_transitions[index] == 1:
            return index
    # Fall-through: none of the first three entries is set.
    return 3
def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
    """Tell whether two 2-component positions have equal components.

    Only indices 0 and 1 of each argument are compared.

    :param pos_1: first position
    :param pos_2: second position
    :return: truthy when both components match, falsy otherwise
    """
    first_matches = pos_1[0] == pos_2[0]
    if not first_matches:
        # Short-circuit: the second component is not compared.
        return first_matches
    return pos_1[1] == pos_2[1]
def fast_count_nonzero(possible_transitions: (int, int, int, int)):
    """Add up the four entries of a transition tuple.

    The result equals the number of non-zero entries only when every
    entry is 0 or 1, because the entries are summed rather than tested.

    :param possible_transitions: 4-element sequence of transition flags
    :return: sum of the entries at indices 0..3
    """
    leading_pair = possible_transitions[0] + possible_transitions[1]
    return leading_pair + possible_transitions[2] + possible_transitions[3]
class RailEnvActions(IntEnum): class RailEnvActions(IntEnum):
DO_NOTHING = 0 # implies change of direction in a dead-end! DO_NOTHING = 0 # implies change of direction in a dead-end!
MOVE_LEFT = 1 MOVE_LEFT = 1
...@@ -298,11 +324,11 @@ class RailEnv(Environment): ...@@ -298,11 +324,11 @@ class RailEnv(Environment):
False: Agent cannot provide an action False: Agent cannot provide an action
""" """
return (agent.status == RailAgentStatus.READY_TO_DEPART or ( return (agent.status == RailAgentStatus.READY_TO_DEPART or (
agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, agent.status == RailAgentStatus.ACTIVE and fast_isclose(agent.speed_data['position_fraction'], 0.0,
rtol=1e-03))) rtol=1e-03)))
def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False,
random_seed: bool = None) -> (Dict, Dict): random_seed: bool = None) -> Tuple[Dict, Dict]:
""" """
reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed) reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed)
...@@ -604,7 +630,7 @@ class RailEnv(Environment): ...@@ -604,7 +630,7 @@ class RailEnv(Environment):
RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_FORWARD] RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_FORWARD]
if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position): RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
agent.status = RailAgentStatus.ACTIVE agent.status = RailAgentStatus.ACTIVE
self._set_agent_to_initial_position(agent, agent.initial_position) self._set_agent_to_initial_position(agent, agent.initial_position)
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
...@@ -626,7 +652,7 @@ class RailEnv(Environment): ...@@ -626,7 +652,7 @@ class RailEnv(Environment):
# Is the agent at the beginning of the cell? Then, it can take an action. # Is the agent at the beginning of the cell? Then, it can take an action.
# As long as the agent is malfunctioning or stopped at the beginning of the cell, # As long as the agent is malfunctioning or stopped at the beginning of the cell,
# different actions may be taken! # different actions may be taken!
if np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03): if fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03):
# No action has been supplied for this agent -> set DO_NOTHING as default # No action has been supplied for this agent -> set DO_NOTHING as default
if action is None: if action is None:
action = RailEnvActions.DO_NOTHING action = RailEnvActions.DO_NOTHING
...@@ -686,8 +712,8 @@ class RailEnv(Environment): ...@@ -686,8 +712,8 @@ class RailEnv(Environment):
# transition_action_on_cellexit if the cell is free. # transition_action_on_cellexit if the cell is free.
if agent.moving: if agent.moving:
agent.speed_data['position_fraction'] += agent.speed_data['speed'] agent.speed_data['position_fraction'] += agent.speed_data['speed']
if agent.speed_data['position_fraction'] > 1.0 or np.isclose(agent.speed_data['position_fraction'], 1.0, if agent.speed_data['position_fraction'] > 1.0 or fast_isclose(agent.speed_data['position_fraction'], 1.0,
rtol=1e-03): rtol=1e-03):
# Perform stored action to transition to the next cell as soon as cell is free # Perform stored action to transition to the next cell as soon as cell is free
# Notice that we've already checked new_cell_valid and transition valid when we stored the action, # Notice that we've already checked new_cell_valid and transition valid when we stored the action,
# so we only have to check cell_free now! # so we only have to check cell_free now!
...@@ -695,7 +721,7 @@ class RailEnv(Environment): ...@@ -695,7 +721,7 @@ class RailEnv(Environment):
# Traditional check that next cell is free # Traditional check that next cell is free
# cell and transition validity was checked when we stored transition_action_on_cellexit! # cell and transition validity was checked when we stored transition_action_on_cellexit!
cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent( cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
agent.speed_data['transition_action_on_cellexit'], agent) agent.speed_data['transition_action_on_cellexit'], agent)
# N.B. validity of new_cell and transition should have been verified before the action was stored! # N.B. validity of new_cell and transition should have been verified before the action was stored!
assert new_cell_valid assert new_cell_valid
...@@ -845,7 +871,6 @@ class RailEnv(Environment): ...@@ -845,7 +871,6 @@ class RailEnv(Environment):
trans_block = sbTrans[agent.direction*4 : agent.direction * 4 + 4] trans_block = sbTrans[agent.direction*4 : agent.direction * 4 + 4]
if (trans_block == "0000"): if (trans_block == "0000"):
print (i_agent, agent.position, agent.direction, sbTrans, trans_block) print (i_agent, agent.position, agent.direction, sbTrans, trans_block)
# debugpy.breakpoint()
# if agent cannot enter env, then we should have move=False # if agent cannot enter env, then we should have move=False
...@@ -862,20 +887,16 @@ class RailEnv(Environment): ...@@ -862,20 +887,16 @@ class RailEnv(Environment):
if not all([transition_valid, new_cell_valid]): if not all([transition_valid, new_cell_valid]):
print(f"ERRROR: step_agent2 invalid transition ag {i_agent} dir {new_direction} pos {agent.position} next {rc_next}") print(f"ERRROR: step_agent2 invalid transition ag {i_agent} dir {new_direction} pos {agent.position} next {rc_next}")
# debugpy.breakpoint()
if new_position != rc_next: if new_position != rc_next:
print(f"ERROR: agent {i_agent} new_pos {new_position} != rc_next {rc_next} " + print(f"ERROR: agent {i_agent} new_pos {new_position} != rc_next {rc_next} " +
f"pos {agent.position} dir {agent.direction} new_dir {new_direction}" + f"pos {agent.position} dir {agent.direction} new_dir {new_direction}" +
f"stored action: {agent.speed_data['transition_action_on_cellexit']}") f"stored action: {agent.speed_data['transition_action_on_cellexit']}")
# debugpy.breakpoint()
sbTrans = format(self.rail.grid[agent.position], "016b") sbTrans = format(self.rail.grid[agent.position], "016b")
trans_block = sbTrans[agent.direction*4 : agent.direction * 4 + 4] trans_block = sbTrans[agent.direction*4 : agent.direction * 4 + 4]
if (trans_block == "0000"): if (trans_block == "0000"):
print ("ERROR: ", i_agent, agent.position, agent.direction, sbTrans, trans_block) print ("ERROR: ", i_agent, agent.position, agent.direction, sbTrans, trans_block)
# debugpy.breakpoint()
agent.position = rc_next agent.position = rc_next
agent.direction = new_direction agent.direction = new_direction
...@@ -937,6 +958,7 @@ class RailEnv(Environment): ...@@ -937,6 +958,7 @@ class RailEnv(Environment):
self.agent_positions[agent.position] = -1 self.agent_positions[agent.position] = -1
if self.remove_agents_at_target: if self.remove_agents_at_target:
agent.position = None agent.position = None
# setting old_position to None here stops the DONE agents from appearing in the rendered image
agent.old_position = None agent.old_position = None
agent.status = RailAgentStatus.DONE_REMOVED agent.status = RailAgentStatus.DONE_REMOVED
...@@ -964,9 +986,9 @@ class RailEnv(Environment): ...@@ -964,9 +986,9 @@ class RailEnv(Environment):
new_position = get_new_position(agent.position, new_direction) new_position = get_new_position(agent.position, new_direction)
new_cell_valid = ( new_cell_valid = (
np.array_equal( # Check the new position is still in the grid fast_position_equal( # Check the new position is still in the grid
new_position, new_position,
np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) fast_clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
and # check the new position has some transitions (ie is not an empty cell) and # check the new position has some transitions (ie is not an empty cell)
self.rail.get_full_transitions(*new_position) > 0) self.rail.get_full_transitions(*new_position) > 0)
...@@ -1038,7 +1060,7 @@ class RailEnv(Environment): ...@@ -1038,7 +1060,7 @@ class RailEnv(Environment):
""" """
transition_valid = None transition_valid = None
possible_transitions = self.rail.get_transitions(*agent.position, agent.direction) possible_transitions = self.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions) num_transitions = fast_count_nonzero(possible_transitions)
new_direction = agent.direction new_direction = agent.direction
if action == RailEnvActions.MOVE_LEFT: if action == RailEnvActions.MOVE_LEFT:
...@@ -1057,7 +1079,7 @@ class RailEnv(Environment): ...@@ -1057,7 +1079,7 @@ class RailEnv(Environment):
# - dead-end, straight line or curved line; # - dead-end, straight line or curved line;
# new_direction will be the only valid transition # new_direction will be the only valid transition
# - take only available transition # - take only available transition
new_direction = np.argmax(possible_transitions) new_direction = fast_argmax(possible_transitions)
transition_valid = True transition_valid = True
return new_direction, transition_valid return new_direction, transition_valid
......
...@@ -122,5 +122,10 @@ def makeTestEnv(sName="single_alternative", nAg=2, bUCF=True): ...@@ -122,5 +122,10 @@ def makeTestEnv(sName="single_alternative", nAg=2, bUCF=True):
dSpec = ddEnvSpecs[sName] dSpec = ddEnvSpecs[sName]
return makeEnv2(nAg=nAg, bUCF=bUCF, **dSpec) return makeEnv2(nAg=nAg, bUCF=bUCF, **dSpec)
def getAgentState(env):
    """Snapshot every agent's position and direction, keyed by agent index.

    :param env: object exposing an ``agents`` iterable; each agent needs an
        unpackable ``position`` and a ``direction`` attribute
    :return: dict mapping the agent's index in ``env.agents`` to a tuple of
        its position components followed by its direction

    Note: ``position`` is unpacked, so an agent whose position is ``None``
    raises ``TypeError``.
    """
    return {
        iAg: (*ag.position, ag.direction)
        for iAg, ag in enumerate(env.agents)
    }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment