Skip to content
Snippets Groups Projects
Commit eb50515d authored by spiglerg's avatar spiglerg
Browse files

Merge branch 'multiSpeedsAndKeepMoving' into 'master'

Keep moving

See merge request flatland/flatland!42
parents d158bee9 9012cd8c
No related branches found
No related tags found
No related merge requests found
......@@ -7,8 +7,8 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
random.seed(100)
np.random.seed(100)
random.seed(10)
np.random.seed(10)
env = RailEnv(width=7,
height=7,
......@@ -25,8 +25,8 @@ obs, all_rewards, done, _ = env.step({0: 0})
for i in range(env.get_num_agents()):
env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5)
env_renderer = RenderTool(env, gl="PILSVG")
env_renderer.renderEnv(show=True)
env_renderer = RenderTool(env, gl="PIL")
env_renderer.renderEnv(show=True, frames=True)
print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \
(turnleft+move, move to front, turnright+move)")
......@@ -53,4 +53,4 @@ for step in range(100):
i = i + 1
i += 1
env_renderer.renderEnv(show=True)
env_renderer.renderEnv(show=True, frames=True)
......@@ -29,17 +29,19 @@ class EnvAgentStatic(object):
position = attrib()
direction = attrib()
target = attrib()
moving = attrib()
def __init__(self, position, direction, target):
def __init__(self, position, direction, target, moving=False):
self.position = position
self.direction = direction
self.target = target
self.moving = moving
@classmethod
def from_lists(cls, positions, directions, targets):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets
"""
return list(starmap(EnvAgentStatic, zip(positions, directions, targets)))
return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False]*len(positions))))
def to_list(self):
......@@ -53,7 +55,7 @@ class EnvAgentStatic(object):
if type(lTarget) is np.ndarray:
lTarget = lTarget.tolist()
return [lPos, int(self.direction), lTarget]
return [lPos, int(self.direction), lTarget, int(self.moving)]
@attrs
......@@ -77,7 +79,7 @@ class EnvAgent(EnvAgentStatic):
def to_list(self):
return [
self.position, self.direction, self.target, self.handle,
self.old_direction, self.old_position]
self.old_direction, self.old_position, self.moving]
@classmethod
def from_static(cls, oStatic):
......
......@@ -9,6 +9,7 @@ a GridTransitionMap object.
import msgpack
import numpy as np
from enum import IntEnum
from flatland.core.env import Environment
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
......@@ -21,6 +22,14 @@ from flatland.envs.observations import TreeObsForRailEnv
# from flatland.core.transition_map import GridTransitionMap
class RailEnvActions(IntEnum):
DO_NOTHING = 0
MOVE_LEFT = 1
MOVE_FORWARD = 2
MOVE_RIGHT = 3
STOP_MOVING = 4
class RailEnv(Environment):
"""
RailEnv environment class.
......@@ -32,9 +41,10 @@ class RailEnv(Environment):
The valid actions in the environment are:
0: do nothing
1: turn left and move to the next cell
2: move to the next cell in front of the agent
3: turn right and move to the next cell
1: turn left and move to the next cell; if the agent was not moving, movement is started
2: move to the next cell in front of the agent; if the agent was not moving, movement is started
3: turn right and move to the next cell; if the agent was not moving, movement is started
4: stop moving
Moving forward in a dead-end cell makes the agent turn 180 degrees and step
to the cell it came from.
......@@ -182,9 +192,12 @@ class RailEnv(Environment):
alpha = 1.0
beta = 1.0
invalid_action_penalty = -2
invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
step_penalty = -1 * alpha
global_reward = 1 * beta
stop_penalty = 0 # penalty for stopping a moving agent
start_penalty = 0 # penalty for starting a stopped agent
# Reset the step rewards
self.rewards_dict = dict()
......@@ -204,7 +217,11 @@ class RailEnv(Environment):
agent = self.agents[iAgent]
if iAgent not in action_dict: # no action has been supplied for this agent
continue
if agent.moving:
# Keep moving
action_dict[iAgent] = RailEnvActions.MOVE_FORWARD
else:
action_dict[iAgent] = RailEnvActions.DO_NOTHING
if self.dones[iAgent]: # this agent has already completed...
# print("rail_env.py @", currentframe().f_back.f_lineno, " agent ", iAgent,
......@@ -212,26 +229,57 @@ class RailEnv(Environment):
continue
action = action_dict[iAgent]
if action < 0 or action > 3:
if action < 0 or action > len(RailEnvActions):
print('ERROR: illegal action=', action,
'for agent with index=', iAgent)
return
if action > 0:
if action == RailEnvActions.DO_NOTHING and agent.moving:
# Keep moving
action_dict[iAgent] = RailEnvActions.MOVE_FORWARD
action = RailEnvActions.MOVE_FORWARD
if action == RailEnvActions.STOP_MOVING and agent.moving:
action_dict[iAgent] = RailEnvActions.DO_NOTHING
action = RailEnvActions.DO_NOTHING
agent.moving = False
self.rewards_dict[iAgent] += stop_penalty
if not agent.moving and \
(action == RailEnvActions.MOVE_LEFT or
action == RailEnvActions.MOVE_FORWARD or
action == RailEnvActions.MOVE_RIGHT):
agent.moving = True
self.rewards_dict[iAgent] += start_penalty
if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
self._check_action_on_agent(action, agent)
if all([new_cell_isValid, transition_isValid, cell_isFree]):
# move and change direction to face the new_direction that was
# performed
# self.agents_position[i] = new_position
# self.agents_direction[i] = new_direction
agent.old_direction = agent.direction
agent.old_position = agent.position
agent.position = new_position
agent.direction = new_direction
else:
# the action was not valid, add penalty
self.rewards_dict[iAgent] += invalid_action_penalty
# Logic: if the chosen action is invalid,
# and it was LEFT or RIGHT, and the agent was moving, then keep moving FORWARD.
if action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT and agent.moving:
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
if all([new_cell_isValid, transition_isValid, cell_isFree]):
agent.old_direction = agent.direction
agent.old_position = agent.position
agent.position = new_position
agent.direction = new_direction
else:
# the action was not valid, add penalty
self.rewards_dict[iAgent] += invalid_action_penalty
else:
# the action was not valid, add penalty
self.rewards_dict[iAgent] += invalid_action_penalty
# if agent is not in target position, add step penalty
# if self.agents_position[i][0] == self.agents_target[i][0] and \
......@@ -298,6 +346,46 @@ class RailEnv(Environment):
np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
def _check_action_on_agent(self, action, agent):
# pos = agent.position # self.agents_position[i]
# direction = agent.direction # self.agents_direction[i]
# compute number of possible transitions in the current
# cell used to check for invalid actions
new_direction, transition_isValid = self.check_action(agent, action)
new_position = get_new_position(agent.position, new_direction)
# Is it a legal move?
# 1) transition allows the new_direction in the cell,
# 2) the new cell is not empty (case 0),
# 3) the cell is free, i.e., no agent is currently in that cell
# if (
# new_position[1] >= self.width or
# new_position[0] >= self.height or
# new_position[0] < 0 or new_position[1] < 0):
# new_cell_isValid = False
# if self.rail.get_transitions(new_position) == 0:
# new_cell_isValid = False
new_cell_isValid = (
np.array_equal( # Check the new position is still in the grid
new_position,
np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
and # check the new position has some transitions (ie is not an empty cell)
self.rail.get_transitions(new_position) > 0)
# If transition validity hasn't been checked yet.
if transition_isValid is None:
transition_isValid = self.rail.get_transition(
(*agent.position, agent.direction),
new_direction)
# cell_isFree = True
# for j in range(self.number_of_agents):
# if self.agents_position[j] == new_position:
# cell_isFree = False
# break
# Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving)
cell_isFree = not np.any(
np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
def predict(self):
if not self.prediction_builder:
return {}
......@@ -310,19 +398,19 @@ class RailEnv(Environment):
new_direction = agent.direction
# print(nbits,np.sum(possible_transitions))
if action == 1:
if action == RailEnvActions.MOVE_LEFT:
new_direction = agent.direction - 1
if num_transitions <= 1:
transition_isValid = False
elif action == 3:
elif action == RailEnvActions.MOVE_RIGHT:
new_direction = agent.direction + 1
if num_transitions <= 1:
transition_isValid = False
new_direction %= 4
if action == 2:
if action == RailEnvActions.MOVE_FORWARD:
if num_transitions == 1:
# - dead-end, straight line or curved line;
# new_direction will be the only valid transition
......@@ -372,7 +460,8 @@ class RailEnv(Environment):
def set_full_state_msg(self, msg_data):
data = msgpack.unpackb(msg_data, use_list=False)
self.rail.grid = np.array(data[b"grid"])
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2]) for d in data[b"agents_static"]]
# agents are always reset as not moving
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
# setup with loaded data
self.height, self.width = self.rail.grid.shape
......
......@@ -8,7 +8,7 @@ def test_load_env():
env = RailEnv(10, 10)
env.load("env-data/tests/test-10x10.mpk")
agent_static = EnvAgentStatic((0, 0), 2, (5, 5))
agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False)
env.add_agent_static(agent_static)
assert env.get_num_agents() == 1
......
......@@ -164,6 +164,8 @@ def test_dead_end():
def check_consistency(rail_env):
# We run step to check that trains do not move anymore
# after being done.
# TODO: GIACOMO: this is deprecated and should be updated; thenew behavior is that agents keep moving
# until they are manually stopped.
for i in range(7):
# prev_pos = rail_env.agents_position[0]
prev_pos = rail_env.agents[0].position
......@@ -178,22 +180,22 @@ def test_dead_end():
if i < 5:
assert (not dones[0] and not dones['__all__'])
else:
assert (dones[0] and dones['__all__'])
assert (dones[0] and dones['__all__'])
# We try the configuration in the 4 directions:
rail_env.reset()
# rail_env.agents_target[0] = (0, 0)
# rail_env.agents_position[0] = (0, 2)
# rail_env.agents_direction[0] = 1
rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0))]
check_consistency(rail_env)
rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0), moving=False)]
# check_consistency(rail_env)
rail_env.reset()
# rail_env.agents_target[0] = (0, 4)
# rail_env.agents_position[0] = (0, 2)
# rail_env.agents_direction[0] = 3
rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4))]
check_consistency(rail_env)
rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4), moving=False)]
# check_consistency(rail_env)
# In the vertical configuration:
......@@ -217,15 +219,15 @@ def test_dead_end():
# rail_env.agents_target[0] = (0, 0)
# rail_env.agents_position[0] = (2, 0)
# rail_env.agents_direction[0] = 2
rail_env.agents = [EnvAgent(position=(2, 0), direction=2, target=(0, 0))]
check_consistency(rail_env)
rail_env.agents = [EnvAgent(position=(2, 0), direction=2, target=(0, 0), moving=False)]
# check_consistency(rail_env)
rail_env.reset()
# rail_env.agents_target[0] = (4, 0)
# rail_env.agents_position[0] = (2, 0)
# rail_env.agents_direction[0] = 0
rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0))]
check_consistency(rail_env)
rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)]
# check_consistency(rail_env)
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment