Skip to content
Snippets Groups Projects
Commit eb50515d authored by spiglerg's avatar spiglerg
Browse files

Merge branch 'multiSpeedsAndKeepMoving' into 'master'

Keep moving

See merge request flatland/flatland!42
parents d158bee9 9012cd8c
No related branches found
No related tags found
No related merge requests found
...@@ -7,8 +7,8 @@ from flatland.envs.observations import TreeObsForRailEnv ...@@ -7,8 +7,8 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
random.seed(100) random.seed(10)
np.random.seed(100) np.random.seed(10)
env = RailEnv(width=7, env = RailEnv(width=7,
height=7, height=7,
...@@ -25,8 +25,8 @@ obs, all_rewards, done, _ = env.step({0: 0}) ...@@ -25,8 +25,8 @@ obs, all_rewards, done, _ = env.step({0: 0})
for i in range(env.get_num_agents()): for i in range(env.get_num_agents()):
env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5) env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5)
env_renderer = RenderTool(env, gl="PILSVG") env_renderer = RenderTool(env, gl="PIL")
env_renderer.renderEnv(show=True) env_renderer.renderEnv(show=True, frames=True)
print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \ print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \
(turnleft+move, move to front, turnright+move)") (turnleft+move, move to front, turnright+move)")
...@@ -53,4 +53,4 @@ for step in range(100): ...@@ -53,4 +53,4 @@ for step in range(100):
i = i + 1 i = i + 1
i += 1 i += 1
env_renderer.renderEnv(show=True) env_renderer.renderEnv(show=True, frames=True)
...@@ -29,17 +29,19 @@ class EnvAgentStatic(object): ...@@ -29,17 +29,19 @@ class EnvAgentStatic(object):
position = attrib() position = attrib()
direction = attrib() direction = attrib()
target = attrib() target = attrib()
moving = attrib()
def __init__(self, position, direction, target): def __init__(self, position, direction, target, moving=False):
self.position = position self.position = position
self.direction = direction self.direction = direction
self.target = target self.target = target
self.moving = moving
@classmethod @classmethod
def from_lists(cls, positions, directions, targets): def from_lists(cls, positions, directions, targets):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets """ Create a list of EnvAgentStatics from lists of positions, directions and targets
""" """
return list(starmap(EnvAgentStatic, zip(positions, directions, targets))) return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False]*len(positions))))
def to_list(self): def to_list(self):
...@@ -53,7 +55,7 @@ class EnvAgentStatic(object): ...@@ -53,7 +55,7 @@ class EnvAgentStatic(object):
if type(lTarget) is np.ndarray: if type(lTarget) is np.ndarray:
lTarget = lTarget.tolist() lTarget = lTarget.tolist()
return [lPos, int(self.direction), lTarget] return [lPos, int(self.direction), lTarget, int(self.moving)]
@attrs @attrs
...@@ -77,7 +79,7 @@ class EnvAgent(EnvAgentStatic): ...@@ -77,7 +79,7 @@ class EnvAgent(EnvAgentStatic):
def to_list(self): def to_list(self):
return [ return [
self.position, self.direction, self.target, self.handle, self.position, self.direction, self.target, self.handle,
self.old_direction, self.old_position] self.old_direction, self.old_position, self.moving]
@classmethod @classmethod
def from_static(cls, oStatic): def from_static(cls, oStatic):
......
...@@ -9,6 +9,7 @@ a GridTransitionMap object. ...@@ -9,6 +9,7 @@ a GridTransitionMap object.
import msgpack import msgpack
import numpy as np import numpy as np
from enum import IntEnum
from flatland.core.env import Environment from flatland.core.env import Environment
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
...@@ -21,6 +22,14 @@ from flatland.envs.observations import TreeObsForRailEnv ...@@ -21,6 +22,14 @@ from flatland.envs.observations import TreeObsForRailEnv
# from flatland.core.transition_map import GridTransitionMap # from flatland.core.transition_map import GridTransitionMap
class RailEnvActions(IntEnum):
DO_NOTHING = 0
MOVE_LEFT = 1
MOVE_FORWARD = 2
MOVE_RIGHT = 3
STOP_MOVING = 4
class RailEnv(Environment): class RailEnv(Environment):
""" """
RailEnv environment class. RailEnv environment class.
...@@ -32,9 +41,10 @@ class RailEnv(Environment): ...@@ -32,9 +41,10 @@ class RailEnv(Environment):
The valid actions in the environment are: The valid actions in the environment are:
0: do nothing 0: do nothing
1: turn left and move to the next cell 1: turn left and move to the next cell; if the agent was not moving, movement is started
2: move to the next cell in front of the agent 2: move to the next cell in front of the agent; if the agent was not moving, movement is started
3: turn right and move to the next cell 3: turn right and move to the next cell; if the agent was not moving, movement is started
4: stop moving
Moving forward in a dead-end cell makes the agent turn 180 degrees and step Moving forward in a dead-end cell makes the agent turn 180 degrees and step
to the cell it came from. to the cell it came from.
...@@ -182,9 +192,12 @@ class RailEnv(Environment): ...@@ -182,9 +192,12 @@ class RailEnv(Environment):
alpha = 1.0 alpha = 1.0
beta = 1.0 beta = 1.0
invalid_action_penalty = -2 invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
step_penalty = -1 * alpha step_penalty = -1 * alpha
global_reward = 1 * beta global_reward = 1 * beta
stop_penalty = 0 # penalty for stopping a moving agent
start_penalty = 0 # penalty for starting a stopped agent
# Reset the step rewards # Reset the step rewards
self.rewards_dict = dict() self.rewards_dict = dict()
...@@ -204,7 +217,11 @@ class RailEnv(Environment): ...@@ -204,7 +217,11 @@ class RailEnv(Environment):
agent = self.agents[iAgent] agent = self.agents[iAgent]
if iAgent not in action_dict: # no action has been supplied for this agent if iAgent not in action_dict: # no action has been supplied for this agent
continue if agent.moving:
# Keep moving
action_dict[iAgent] = RailEnvActions.MOVE_FORWARD
else:
action_dict[iAgent] = RailEnvActions.DO_NOTHING
if self.dones[iAgent]: # this agent has already completed... if self.dones[iAgent]: # this agent has already completed...
# print("rail_env.py @", currentframe().f_back.f_lineno, " agent ", iAgent, # print("rail_env.py @", currentframe().f_back.f_lineno, " agent ", iAgent,
...@@ -212,26 +229,57 @@ class RailEnv(Environment): ...@@ -212,26 +229,57 @@ class RailEnv(Environment):
continue continue
action = action_dict[iAgent] action = action_dict[iAgent]
if action < 0 or action > 3: if action < 0 or action > len(RailEnvActions):
print('ERROR: illegal action=', action, print('ERROR: illegal action=', action,
'for agent with index=', iAgent) 'for agent with index=', iAgent)
return return
if action > 0: if action == RailEnvActions.DO_NOTHING and agent.moving:
# Keep moving
action_dict[iAgent] = RailEnvActions.MOVE_FORWARD
action = RailEnvActions.MOVE_FORWARD
if action == RailEnvActions.STOP_MOVING and agent.moving:
action_dict[iAgent] = RailEnvActions.DO_NOTHING
action = RailEnvActions.DO_NOTHING
agent.moving = False
self.rewards_dict[iAgent] += stop_penalty
if not agent.moving and \
(action == RailEnvActions.MOVE_LEFT or
action == RailEnvActions.MOVE_FORWARD or
action == RailEnvActions.MOVE_RIGHT):
agent.moving = True
self.rewards_dict[iAgent] += start_penalty
if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
self._check_action_on_agent(action, agent) self._check_action_on_agent(action, agent)
if all([new_cell_isValid, transition_isValid, cell_isFree]): if all([new_cell_isValid, transition_isValid, cell_isFree]):
# move and change direction to face the new_direction that was
# performed
# self.agents_position[i] = new_position
# self.agents_direction[i] = new_direction
agent.old_direction = agent.direction agent.old_direction = agent.direction
agent.old_position = agent.position agent.old_position = agent.position
agent.position = new_position agent.position = new_position
agent.direction = new_direction agent.direction = new_direction
else: else:
# the action was not valid, add penalty # Logic: if the chosen action is invalid,
self.rewards_dict[iAgent] += invalid_action_penalty # and it was LEFT or RIGHT, and the agent was moving, then keep moving FORWARD.
if action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT and agent.moving:
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
if all([new_cell_isValid, transition_isValid, cell_isFree]):
agent.old_direction = agent.direction
agent.old_position = agent.position
agent.position = new_position
agent.direction = new_direction
else:
# the action was not valid, add penalty
self.rewards_dict[iAgent] += invalid_action_penalty
else:
# the action was not valid, add penalty
self.rewards_dict[iAgent] += invalid_action_penalty
# if agent is not in target position, add step penalty # if agent is not in target position, add step penalty
# if self.agents_position[i][0] == self.agents_target[i][0] and \ # if self.agents_position[i][0] == self.agents_target[i][0] and \
...@@ -298,6 +346,46 @@ class RailEnv(Environment): ...@@ -298,6 +346,46 @@ class RailEnv(Environment):
np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
def _check_action_on_agent(self, action, agent):
# pos = agent.position # self.agents_position[i]
# direction = agent.direction # self.agents_direction[i]
# compute number of possible transitions in the current
# cell used to check for invalid actions
new_direction, transition_isValid = self.check_action(agent, action)
new_position = get_new_position(agent.position, new_direction)
# Is it a legal move?
# 1) transition allows the new_direction in the cell,
# 2) the new cell is not empty (case 0),
# 3) the cell is free, i.e., no agent is currently in that cell
# if (
# new_position[1] >= self.width or
# new_position[0] >= self.height or
# new_position[0] < 0 or new_position[1] < 0):
# new_cell_isValid = False
# if self.rail.get_transitions(new_position) == 0:
# new_cell_isValid = False
new_cell_isValid = (
np.array_equal( # Check the new position is still in the grid
new_position,
np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
and # check the new position has some transitions (ie is not an empty cell)
self.rail.get_transitions(new_position) > 0)
# If transition validity hasn't been checked yet.
if transition_isValid is None:
transition_isValid = self.rail.get_transition(
(*agent.position, agent.direction),
new_direction)
# cell_isFree = True
# for j in range(self.number_of_agents):
# if self.agents_position[j] == new_position:
# cell_isFree = False
# break
# Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving)
cell_isFree = not np.any(
np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
def predict(self): def predict(self):
if not self.prediction_builder: if not self.prediction_builder:
return {} return {}
...@@ -310,19 +398,19 @@ class RailEnv(Environment): ...@@ -310,19 +398,19 @@ class RailEnv(Environment):
new_direction = agent.direction new_direction = agent.direction
# print(nbits,np.sum(possible_transitions)) # print(nbits,np.sum(possible_transitions))
if action == 1: if action == RailEnvActions.MOVE_LEFT:
new_direction = agent.direction - 1 new_direction = agent.direction - 1
if num_transitions <= 1: if num_transitions <= 1:
transition_isValid = False transition_isValid = False
elif action == 3: elif action == RailEnvActions.MOVE_RIGHT:
new_direction = agent.direction + 1 new_direction = agent.direction + 1
if num_transitions <= 1: if num_transitions <= 1:
transition_isValid = False transition_isValid = False
new_direction %= 4 new_direction %= 4
if action == 2: if action == RailEnvActions.MOVE_FORWARD:
if num_transitions == 1: if num_transitions == 1:
# - dead-end, straight line or curved line; # - dead-end, straight line or curved line;
# new_direction will be the only valid transition # new_direction will be the only valid transition
...@@ -372,7 +460,8 @@ class RailEnv(Environment): ...@@ -372,7 +460,8 @@ class RailEnv(Environment):
def set_full_state_msg(self, msg_data): def set_full_state_msg(self, msg_data):
data = msgpack.unpackb(msg_data, use_list=False) data = msgpack.unpackb(msg_data, use_list=False)
self.rail.grid = np.array(data[b"grid"]) self.rail.grid = np.array(data[b"grid"])
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2]) for d in data[b"agents_static"]] # agents are always reset as not moving
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
# setup with loaded data # setup with loaded data
self.height, self.width = self.rail.grid.shape self.height, self.width = self.rail.grid.shape
......
...@@ -8,7 +8,7 @@ def test_load_env(): ...@@ -8,7 +8,7 @@ def test_load_env():
env = RailEnv(10, 10) env = RailEnv(10, 10)
env.load("env-data/tests/test-10x10.mpk") env.load("env-data/tests/test-10x10.mpk")
agent_static = EnvAgentStatic((0, 0), 2, (5, 5)) agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False)
env.add_agent_static(agent_static) env.add_agent_static(agent_static)
assert env.get_num_agents() == 1 assert env.get_num_agents() == 1
......
...@@ -164,6 +164,8 @@ def test_dead_end(): ...@@ -164,6 +164,8 @@ def test_dead_end():
def check_consistency(rail_env): def check_consistency(rail_env):
# We run step to check that trains do not move anymore # We run step to check that trains do not move anymore
# after being done. # after being done.
# TODO: GIACOMO: this is deprecated and should be updated; thenew behavior is that agents keep moving
# until they are manually stopped.
for i in range(7): for i in range(7):
# prev_pos = rail_env.agents_position[0] # prev_pos = rail_env.agents_position[0]
prev_pos = rail_env.agents[0].position prev_pos = rail_env.agents[0].position
...@@ -178,22 +180,22 @@ def test_dead_end(): ...@@ -178,22 +180,22 @@ def test_dead_end():
if i < 5: if i < 5:
assert (not dones[0] and not dones['__all__']) assert (not dones[0] and not dones['__all__'])
else: else:
assert (dones[0] and dones['__all__']) assert (dones[0] and dones['__all__'])
# We try the configuration in the 4 directions: # We try the configuration in the 4 directions:
rail_env.reset() rail_env.reset()
# rail_env.agents_target[0] = (0, 0) # rail_env.agents_target[0] = (0, 0)
# rail_env.agents_position[0] = (0, 2) # rail_env.agents_position[0] = (0, 2)
# rail_env.agents_direction[0] = 1 # rail_env.agents_direction[0] = 1
rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0))] rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0), moving=False)]
check_consistency(rail_env) # check_consistency(rail_env)
rail_env.reset() rail_env.reset()
# rail_env.agents_target[0] = (0, 4) # rail_env.agents_target[0] = (0, 4)
# rail_env.agents_position[0] = (0, 2) # rail_env.agents_position[0] = (0, 2)
# rail_env.agents_direction[0] = 3 # rail_env.agents_direction[0] = 3
rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4))] rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4), moving=False)]
check_consistency(rail_env) # check_consistency(rail_env)
# In the vertical configuration: # In the vertical configuration:
...@@ -217,15 +219,15 @@ def test_dead_end(): ...@@ -217,15 +219,15 @@ def test_dead_end():
# rail_env.agents_target[0] = (0, 0) # rail_env.agents_target[0] = (0, 0)
# rail_env.agents_position[0] = (2, 0) # rail_env.agents_position[0] = (2, 0)
# rail_env.agents_direction[0] = 2 # rail_env.agents_direction[0] = 2
rail_env.agents = [EnvAgent(position=(2, 0), direction=2, target=(0, 0))] rail_env.agents = [EnvAgent(position=(2, 0), direction=2, target=(0, 0), moving=False)]
check_consistency(rail_env) # check_consistency(rail_env)
rail_env.reset() rail_env.reset()
# rail_env.agents_target[0] = (4, 0) # rail_env.agents_target[0] = (4, 0)
# rail_env.agents_position[0] = (2, 0) # rail_env.agents_position[0] = (2, 0)
# rail_env.agents_direction[0] = 0 # rail_env.agents_direction[0] = 0
rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0))] rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)]
check_consistency(rail_env) # check_consistency(rail_env)
if __name__ == "__main__": if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment