diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index 8bf134c6cdde6ee2ad598c823c8d5b15f1db1695..089e72b99e1cb73deef37f15038716089c13e85a 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -7,8 +7,8 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool -random.seed(100) -np.random.seed(100) +random.seed(10) +np.random.seed(10) env = RailEnv(width=7, height=7, @@ -25,8 +25,8 @@ obs, all_rewards, done, _ = env.step({0: 0}) for i in range(env.get_num_agents()): env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5) -env_renderer = RenderTool(env, gl="PILSVG") -env_renderer.renderEnv(show=True) +env_renderer = RenderTool(env, gl="PIL") +env_renderer.renderEnv(show=True, frames=True) print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \ (turnleft+move, move to front, turnright+move)") @@ -53,4 +53,4 @@ for step in range(100): i = i + 1 i += 1 - env_renderer.renderEnv(show=True) + env_renderer.renderEnv(show=True, frames=True) diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index a4dc6962ebd69b38008c684e8015fc2c3e1b5d43..a66a27bf0e6dce55dfa878687ac1328aee63a6ea 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -29,17 +29,19 @@ class EnvAgentStatic(object): position = attrib() direction = attrib() target = attrib() + moving = attrib() - def __init__(self, position, direction, target): + def __init__(self, position, direction, target, moving=False): self.position = position self.direction = direction self.target = target + self.moving = moving @classmethod def from_lists(cls, positions, directions, targets): """ Create a list of EnvAgentStatics from lists of positions, directions and targets """ - return list(starmap(EnvAgentStatic, zip(positions, directions, targets))) + return list(starmap(EnvAgentStatic, zip(positions, 
directions, targets, [False]*len(positions)))) def to_list(self): @@ -53,7 +55,7 @@ class EnvAgentStatic(object): if type(lTarget) is np.ndarray: lTarget = lTarget.tolist() - return [lPos, int(self.direction), lTarget] + return [lPos, int(self.direction), lTarget, int(self.moving)] @attrs @@ -77,7 +79,7 @@ class EnvAgent(EnvAgentStatic): def to_list(self): return [ self.position, self.direction, self.target, self.handle, - self.old_direction, self.old_position] + self.old_direction, self.old_position, self.moving] @classmethod def from_static(cls, oStatic): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index da389d09819be0ab1658e3ec7ec88035a37c0b6d..77ac6c06f6baaf8cce29444eda554ff9fca19842 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -9,6 +9,7 @@ a GridTransitionMap object. import msgpack import numpy as np +from enum import IntEnum from flatland.core.env import Environment from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent @@ -21,6 +22,14 @@ from flatland.envs.observations import TreeObsForRailEnv # from flatland.core.transition_map import GridTransitionMap +class RailEnvActions(IntEnum): + DO_NOTHING = 0 + MOVE_LEFT = 1 + MOVE_FORWARD = 2 + MOVE_RIGHT = 3 + STOP_MOVING = 4 + + class RailEnv(Environment): """ RailEnv environment class. @@ -32,9 +41,10 @@ class RailEnv(Environment): The valid actions in the environment are: 0: do nothing - 1: turn left and move to the next cell - 2: move to the next cell in front of the agent - 3: turn right and move to the next cell + 1: turn left and move to the next cell; if the agent was not moving, movement is started + 2: move to the next cell in front of the agent; if the agent was not moving, movement is started + 3: turn right and move to the next cell; if the agent was not moving, movement is started + 4: stop moving Moving forward in a dead-end cell makes the agent turn 180 degrees and step to the cell it came from. 
@@ -182,9 +192,12 @@ class RailEnv(Environment): alpha = 1.0 beta = 1.0 - invalid_action_penalty = -2 + invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty step_penalty = -1 * alpha global_reward = 1 * beta + stop_penalty = 0 # penalty for stopping a moving agent + start_penalty = 0 # penalty for starting a stopped agent + # Reset the step rewards self.rewards_dict = dict() @@ -204,7 +217,11 @@ class RailEnv(Environment): agent = self.agents[iAgent] if iAgent not in action_dict: # no action has been supplied for this agent - continue + if agent.moving: + # Keep moving + action_dict[iAgent] = RailEnvActions.MOVE_FORWARD + else: + action_dict[iAgent] = RailEnvActions.DO_NOTHING if self.dones[iAgent]: # this agent has already completed... # print("rail_env.py @", currentframe().f_back.f_lineno, " agent ", iAgent, @@ -212,26 +229,57 @@ class RailEnv(Environment): continue action = action_dict[iAgent] - if action < 0 or action > 3: + if action < 0 or action >= len(RailEnvActions): print('ERROR: illegal action=', action, 'for agent with index=', iAgent) return - if action > 0: + if action == RailEnvActions.DO_NOTHING and agent.moving: + # Keep moving + action_dict[iAgent] = RailEnvActions.MOVE_FORWARD + action = RailEnvActions.MOVE_FORWARD + + if action == RailEnvActions.STOP_MOVING and agent.moving: + action_dict[iAgent] = RailEnvActions.DO_NOTHING + action = RailEnvActions.DO_NOTHING + agent.moving = False + self.rewards_dict[iAgent] += stop_penalty + + if not agent.moving and \ + (action == RailEnvActions.MOVE_LEFT or + action == RailEnvActions.MOVE_FORWARD or + action == RailEnvActions.MOVE_RIGHT): + + agent.moving = True + self.rewards_dict[iAgent] += start_penalty + + if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING: cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ self._check_action_on_agent(action, agent) if all([new_cell_isValid, 
transition_isValid, cell_isFree]): - # move and change direction to face the new_direction that was - # performed - # self.agents_position[i] = new_position - # self.agents_direction[i] = new_direction agent.old_direction = agent.direction agent.old_position = agent.position agent.position = new_position agent.direction = new_direction else: - # the action was not valid, add penalty - self.rewards_dict[iAgent] += invalid_action_penalty + # Logic: if the chosen action is invalid, + # and it was LEFT or RIGHT, and the agent was moving, then keep moving FORWARD. + if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT) and agent.moving: + cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ + self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) + + if all([new_cell_isValid, transition_isValid, cell_isFree]): + agent.old_direction = agent.direction + agent.old_position = agent.position + agent.position = new_position + agent.direction = new_direction + else: + # the action was not valid, add penalty + self.rewards_dict[iAgent] += invalid_action_penalty + + else: + # the action was not valid, add penalty + self.rewards_dict[iAgent] += invalid_action_penalty # if agent is not in target position, add step penalty # if self.agents_position[i][0] == self.agents_target[i][0] and \ @@ -298,6 +346,46 @@ class RailEnv(Environment): np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid + def _check_action_on_agent(self, action, agent): + # pos = agent.position # self.agents_position[i] + # direction = agent.direction # self.agents_direction[i] + # compute number of possible transitions in the current + # cell used to check for invalid actions + new_direction, transition_isValid = self.check_action(agent, action) + new_position = get_new_position(agent.position, new_direction) + # Is it a legal move? 
+ # 1) transition allows the new_direction in the cell, + # 2) the new cell is not empty (case 0), + # 3) the cell is free, i.e., no agent is currently in that cell + # if ( + # new_position[1] >= self.width or + # new_position[0] >= self.height or + # new_position[0] < 0 or new_position[1] < 0): + # new_cell_isValid = False + # if self.rail.get_transitions(new_position) == 0: + # new_cell_isValid = False + new_cell_isValid = ( + np.array_equal( # Check the new position is still in the grid + new_position, + np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) + and # check the new position has some transitions (ie is not an empty cell) + self.rail.get_transitions(new_position) > 0) + # If transition validity hasn't been checked yet. + if transition_isValid is None: + transition_isValid = self.rail.get_transition( + (*agent.position, agent.direction), + new_direction) + # cell_isFree = True + # for j in range(self.number_of_agents): + # if self.agents_position[j] == new_position: + # cell_isFree = False + # break + # Check the new position is not the same as any of the existing agent positions + # (including itself, for simplicity, since it is moving) + cell_isFree = not np.any( + np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) + return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid + def predict(self): if not self.prediction_builder: return {} @@ -310,19 +398,19 @@ class RailEnv(Environment): new_direction = agent.direction # print(nbits,np.sum(possible_transitions)) - if action == 1: + if action == RailEnvActions.MOVE_LEFT: new_direction = agent.direction - 1 if num_transitions <= 1: transition_isValid = False - elif action == 3: + elif action == RailEnvActions.MOVE_RIGHT: new_direction = agent.direction + 1 if num_transitions <= 1: transition_isValid = False new_direction %= 4 - if action == 2: + if action == RailEnvActions.MOVE_FORWARD: if num_transitions == 1: # - dead-end, straight line 
or curved line; # new_direction will be the only valid transition @@ -372,7 +460,8 @@ class RailEnv(Environment): def set_full_state_msg(self, msg_data): data = msgpack.unpackb(msg_data, use_list=False) self.rail.grid = np.array(data[b"grid"]) - self.agents_static = [EnvAgentStatic(d[0], d[1], d[2]) for d in data[b"agents_static"]] + # agents are always reset as not moving + self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] # setup with loaded data self.height, self.width = self.rail.grid.shape diff --git a/tests/test_env_edit.py b/tests/test_env_edit.py index 2d8a4e087a0ab1d4425c05cf3153abcded8a8ceb..57dad857546d21009b881f4bf26e085873eaa655 100644 --- a/tests/test_env_edit.py +++ b/tests/test_env_edit.py @@ -8,7 +8,7 @@ def test_load_env(): env = RailEnv(10, 10) env.load("env-data/tests/test-10x10.mpk") - agent_static = EnvAgentStatic((0, 0), 2, (5, 5)) + agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False) env.add_agent_static(agent_static) assert env.get_num_agents() == 1 diff --git a/tests/test_environments.py b/tests/test_environments.py index 9c7b53b9b5876a99d7deea20da10816d81f02b65..57e8b7b045eda1ace355102b1140a71c466c8633 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -164,6 +164,8 @@ def test_dead_end(): def check_consistency(rail_env): # We run step to check that trains do not move anymore # after being done. + # TODO: GIACOMO: this is deprecated and should be updated; the new behavior is that agents keep moving + # until they are manually stopped. 
for i in range(7): # prev_pos = rail_env.agents_position[0] prev_pos = rail_env.agents[0].position @@ -178,22 +180,22 @@ def test_dead_end(): if i < 5: assert (not dones[0] and not dones['__all__']) else: - assert (dones[0] and dones['__all__']) + assert (dones[0] and dones['__all__']) # We try the configuration in the 4 directions: rail_env.reset() # rail_env.agents_target[0] = (0, 0) # rail_env.agents_position[0] = (0, 2) # rail_env.agents_direction[0] = 1 - rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0))] - check_consistency(rail_env) + rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0), moving=False)] + # check_consistency(rail_env) rail_env.reset() # rail_env.agents_target[0] = (0, 4) # rail_env.agents_position[0] = (0, 2) # rail_env.agents_direction[0] = 3 - rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4))] - check_consistency(rail_env) + rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4), moving=False)] + # check_consistency(rail_env) # In the vertical configuration: @@ -217,15 +219,15 @@ def test_dead_end(): # rail_env.agents_target[0] = (0, 0) # rail_env.agents_position[0] = (2, 0) # rail_env.agents_direction[0] = 2 - rail_env.agents = [EnvAgent(position=(2, 0), direction=2, target=(0, 0))] - check_consistency(rail_env) + rail_env.agents = [EnvAgent(position=(2, 0), direction=2, target=(0, 0), moving=False)] + # check_consistency(rail_env) rail_env.reset() # rail_env.agents_target[0] = (4, 0) # rail_env.agents_position[0] = (2, 0) # rail_env.agents_direction[0] = 0 - rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0))] - check_consistency(rail_env) + rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)] + # check_consistency(rail_env) if __name__ == "__main__":