diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index 1978e27c8511154d317ef5c96c3a9bc6816168c5..8c98eea7f1810d29522364703a496a1fb883fb36 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -6,8 +6,8 @@ from flatland.utils.rendertools import RenderTool from flatland.envs.observations import TreeObsForRailEnv import numpy as np -random.seed(100) -np.random.seed(100) +random.seed(10) +np.random.seed(10) env = RailEnv(width=7, height=7, @@ -24,7 +24,7 @@ obs, all_rewards, done, _ = env.step({0: 0}) for i in range(env.get_num_agents()): env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5) -env_renderer = RenderTool(env, gl="QT") +env_renderer = RenderTool(env, gl="PIL") env_renderer.renderEnv(show=True) print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \ @@ -52,4 +52,4 @@ for step in range(100): i = i + 1 i += 1 - env_renderer.renderEnv(show=True) + env_renderer.renderEnv(show=True, frames=True) diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index a4dc6962ebd69b38008c684e8015fc2c3e1b5d43..5105ec5dfb792895a7361060431d84d8ceb5c39e 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -29,17 +29,19 @@ class EnvAgentStatic(object): position = attrib() direction = attrib() target = attrib() + moving = attrib() def __init__(self, position, direction, target): self.position = position self.direction = direction self.target = target + self.moving = False @classmethod def from_lists(cls, positions, directions, targets): """ Create a list of EnvAgentStatics from lists of positions, directions and targets """ - return list(starmap(EnvAgentStatic, zip(positions, directions, targets))) + return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False]*len(positions)))) def to_list(self): @@ -53,7 +55,7 @@ class EnvAgentStatic(object): if type(lTarget) is np.ndarray: lTarget = lTarget.tolist() - return [lPos, 
int(self.direction), lTarget] + return [lPos, int(self.direction), lTarget, int(self.moving)] @attrs @@ -77,7 +79,7 @@ class EnvAgent(EnvAgentStatic): def to_list(self): return [ self.position, self.direction, self.target, self.handle, - self.old_direction, self.old_position] + self.old_direction, self.old_position, self.moving] @classmethod def from_static(cls, oStatic): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index a23b6ac71651e8e424bb90a23cbcfd472ce89a12..ba4e9e46e8d5be9207e6bad64d0f7f364840f4d4 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -9,6 +9,7 @@ a GridTransitionMap object. import msgpack import numpy as np +from enum import IntEnum from flatland.core.env import Environment from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent @@ -21,6 +22,14 @@ from flatland.envs.observations import TreeObsForRailEnv # from flatland.core.transition_map import GridTransitionMap +class RailEnvActions(IntEnum): + DO_NOTHING = 0 + MOVE_LEFT = 1 + MOVE_FORWARD = 2 + MOVE_RIGHT = 3 + STOP_MOVING = 4 + + class RailEnv(Environment): """ RailEnv environment class. @@ -32,9 +41,10 @@ class RailEnv(Environment): The valid actions in the environment are: 0: do nothing - 1: turn left and move to the next cell - 2: move to the next cell in front of the agent - 3: turn right and move to the next cell + 1: turn left and move to the next cell; if the agent was not moving, movement is started + 2: move to the next cell in front of the agent; if the agent was not moving, movement is started + 3: turn right and move to the next cell; if the agent was not moving, movement is started + 4: stop moving Moving forward in a dead-end cell makes the agent turn 180 degrees and step to the cell it came from. 
@@ -176,7 +186,7 @@ class RailEnv(Environment): alpha = 1.0 beta = 1.0 - invalid_action_penalty = -2 + invalid_action_penalty = 0 # -2 GIACOMO: we decided that invalid actions will carry no penalty step_penalty = -1 * alpha global_reward = 1 * beta @@ -198,7 +208,11 @@ class RailEnv(Environment): agent = self.agents[iAgent] if iAgent not in action_dict: # no action has been supplied for this agent - continue + if agent.moving: + # Keep moving + action_dict[iAgent] = RailEnvActions.MOVE_FORWARD + else: + action_dict[iAgent] = RailEnvActions.DO_NOTHING if self.dones[iAgent]: # this agent has already completed... # print("rail_env.py @", currentframe().f_back.f_lineno, " agent ", iAgent, @@ -206,12 +220,27 @@ class RailEnv(Environment): continue action = action_dict[iAgent] - if action < 0 or action > 3: + if action < 0 or action >= len(RailEnvActions): print('ERROR: illegal action=', action, 'for agent with index=', iAgent) return - if action > 0: + if action == RailEnvActions.DO_NOTHING and agent.moving: + # Keep moving + action_dict[iAgent] = RailEnvActions.MOVE_FORWARD + action = RailEnvActions.MOVE_FORWARD + + if action == RailEnvActions.STOP_MOVING and agent.moving: + action_dict[iAgent] = RailEnvActions.DO_NOTHING + action = RailEnvActions.DO_NOTHING + agent.moving = False + # TODO: possibly, penalty for stopping! 
+ + if not agent.moving and (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_FORWARD or action == RailEnvActions.MOVE_RIGHT): + agent.moving = True + # TODO: possibly, may add a penalty for starting, but the best is only for stopping (GIACOMO's opinion) + + if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING: # pos = agent.position # self.agents_position[i] # direction = agent.direction # self.agents_direction[i] @@ -293,7 +322,7 @@ class RailEnv(Environment): # Reset the step actions (in case some agent doesn't 'register_action' # on the next step) - self.actions = [0] * self.get_num_agents() + self.actions = [RailEnvActions.DO_NOTHING] * self.get_num_agents() return self._get_observations(), self.rewards_dict, self.dones, {} def check_action(self, agent, action): @@ -303,19 +332,19 @@ class RailEnv(Environment): new_direction = agent.direction # print(nbits,np.sum(possible_transitions)) - if action == 1: + if action == RailEnvActions.MOVE_LEFT: new_direction = agent.direction - 1 if num_transitions <= 1: transition_isValid = False - elif action == 3: + elif action == RailEnvActions.MOVE_RIGHT: new_direction = agent.direction + 1 if num_transitions <= 1: transition_isValid = False new_direction %= 4 - if action == 2: + if action == RailEnvActions.MOVE_FORWARD: if num_transitions == 1: # - dead-end, straight line or curved line; # new_direction will be the only valid transition