Skip to content
Snippets Groups Projects
Commit 20d0f8fb authored by spiglerg's avatar spiglerg
Browse files

tmp mods

parent 6fc2266b
No related branches found
No related tags found
No related merge requests found
...@@ -6,8 +6,8 @@ from flatland.utils.rendertools import RenderTool ...@@ -6,8 +6,8 @@ from flatland.utils.rendertools import RenderTool
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
import numpy as np import numpy as np
random.seed(100) random.seed(10)
np.random.seed(100) np.random.seed(10)
env = RailEnv(width=7, env = RailEnv(width=7,
height=7, height=7,
...@@ -24,7 +24,7 @@ obs, all_rewards, done, _ = env.step({0: 0}) ...@@ -24,7 +24,7 @@ obs, all_rewards, done, _ = env.step({0: 0})
for i in range(env.get_num_agents()): for i in range(env.get_num_agents()):
env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5) env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5)
env_renderer = RenderTool(env, gl="QT") env_renderer = RenderTool(env, gl="PIL")
env_renderer.renderEnv(show=True) env_renderer.renderEnv(show=True)
print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \ print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \
...@@ -52,4 +52,4 @@ for step in range(100): ...@@ -52,4 +52,4 @@ for step in range(100):
i = i + 1 i = i + 1
i += 1 i += 1
env_renderer.renderEnv(show=True) env_renderer.renderEnv(show=True, frames=True)
...@@ -29,17 +29,19 @@ class EnvAgentStatic(object): ...@@ -29,17 +29,19 @@ class EnvAgentStatic(object):
position = attrib() position = attrib()
direction = attrib() direction = attrib()
target = attrib() target = attrib()
moving = attrib()
def __init__(self, position, direction, target): def __init__(self, position, direction, target):
self.position = position self.position = position
self.direction = direction self.direction = direction
self.target = target self.target = target
self.moving = False
@classmethod @classmethod
def from_lists(cls, positions, directions, targets): def from_lists(cls, positions, directions, targets):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets """ Create a list of EnvAgentStatics from lists of positions, directions and targets
""" """
return list(starmap(EnvAgentStatic, zip(positions, directions, targets))) return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False]*len(positions))))
def to_list(self): def to_list(self):
...@@ -53,7 +55,7 @@ class EnvAgentStatic(object): ...@@ -53,7 +55,7 @@ class EnvAgentStatic(object):
if type(lTarget) is np.ndarray: if type(lTarget) is np.ndarray:
lTarget = lTarget.tolist() lTarget = lTarget.tolist()
return [lPos, int(self.direction), lTarget] return [lPos, int(self.direction), lTarget, int(self.moving)]
@attrs @attrs
...@@ -77,7 +79,7 @@ class EnvAgent(EnvAgentStatic): ...@@ -77,7 +79,7 @@ class EnvAgent(EnvAgentStatic):
def to_list(self): def to_list(self):
return [ return [
self.position, self.direction, self.target, self.handle, self.position, self.direction, self.target, self.handle,
self.old_direction, self.old_position] self.old_direction, self.old_position, self.moving]
@classmethod @classmethod
def from_static(cls, oStatic): def from_static(cls, oStatic):
......
...@@ -9,6 +9,7 @@ a GridTransitionMap object. ...@@ -9,6 +9,7 @@ a GridTransitionMap object.
import msgpack import msgpack
import numpy as np import numpy as np
from enum import IntEnum
from flatland.core.env import Environment from flatland.core.env import Environment
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
...@@ -21,6 +22,14 @@ from flatland.envs.observations import TreeObsForRailEnv ...@@ -21,6 +22,14 @@ from flatland.envs.observations import TreeObsForRailEnv
# from flatland.core.transition_map import GridTransitionMap # from flatland.core.transition_map import GridTransitionMap
class RailEnvActions(IntEnum):
DO_NOTHING = 0
MOVE_LEFT = 1
MOVE_FORWARD = 2
MOVE_RIGHT = 3
STOP_MOVING = 4
class RailEnv(Environment): class RailEnv(Environment):
""" """
RailEnv environment class. RailEnv environment class.
...@@ -32,9 +41,10 @@ class RailEnv(Environment): ...@@ -32,9 +41,10 @@ class RailEnv(Environment):
The valid actions in the environment are: The valid actions in the environment are:
0: do nothing 0: do nothing
1: turn left and move to the next cell 1: turn left and move to the next cell; if the agent was not moving, movement is started
2: move to the next cell in front of the agent 2: move to the next cell in front of the agent; if the agent was not moving, movement is started
3: turn right and move to the next cell 3: turn right and move to the next cell; if the agent was not moving, movement is started
4: stop moving
Moving forward in a dead-end cell makes the agent turn 180 degrees and step Moving forward in a dead-end cell makes the agent turn 180 degrees and step
to the cell it came from. to the cell it came from.
...@@ -176,7 +186,7 @@ class RailEnv(Environment): ...@@ -176,7 +186,7 @@ class RailEnv(Environment):
alpha = 1.0 alpha = 1.0
beta = 1.0 beta = 1.0
invalid_action_penalty = -2 invalid_action_penalty = 0 # -2 GIACOMO: we decided that invalid actions will carry no penalty
step_penalty = -1 * alpha step_penalty = -1 * alpha
global_reward = 1 * beta global_reward = 1 * beta
...@@ -198,7 +208,11 @@ class RailEnv(Environment): ...@@ -198,7 +208,11 @@ class RailEnv(Environment):
agent = self.agents[iAgent] agent = self.agents[iAgent]
if iAgent not in action_dict: # no action has been supplied for this agent if iAgent not in action_dict: # no action has been supplied for this agent
continue if agent.moving:
# Keep moving
action_dict[iAgent] = RailEnvActions.MOVE_FORWARD
else:
action_dict[iAgent] = RailEnvActions.DO_NOTHING
if self.dones[iAgent]: # this agent has already completed... if self.dones[iAgent]: # this agent has already completed...
# print("rail_env.py @", currentframe().f_back.f_lineno, " agent ", iAgent, # print("rail_env.py @", currentframe().f_back.f_lineno, " agent ", iAgent,
...@@ -206,12 +220,27 @@ class RailEnv(Environment): ...@@ -206,12 +220,27 @@ class RailEnv(Environment):
continue continue
action = action_dict[iAgent] action = action_dict[iAgent]
if action < 0 or action > 3: if action < 0 or action > len(RailEnvActions):
print('ERROR: illegal action=', action, print('ERROR: illegal action=', action,
'for agent with index=', iAgent) 'for agent with index=', iAgent)
return return
if action > 0: if action == RailEnvActions.DO_NOTHING and agent.moving:
# Keep moving
action_dict[iAgent] = RailEnvActions.MOVE_FORWARD
action = RailEnvActions.MOVE_FORWARD
if action == RailEnvActions.STOP_MOVING and agent.moving:
action_dict[iAgent] = RailEnvActions.DO_NOTHING
action = RailEnvActions.DO_NOTHING
agent.moving = False
# TODO: possibly, penalty for stopping!
if not agent.moving and (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_FORWARD or action == RailEnvActions.MOVE_RIGHT):
agent.moving = True
# TODO: possibly, may add a penalty for starting, but the best is only for stopping (GIACOMO's opinion)
if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
# pos = agent.position # self.agents_position[i] # pos = agent.position # self.agents_position[i]
# direction = agent.direction # self.agents_direction[i] # direction = agent.direction # self.agents_direction[i]
...@@ -293,7 +322,7 @@ class RailEnv(Environment): ...@@ -293,7 +322,7 @@ class RailEnv(Environment):
# Reset the step actions (in case some agent doesn't 'register_action' # Reset the step actions (in case some agent doesn't 'register_action'
# on the next step) # on the next step)
self.actions = [0] * self.get_num_agents() self.actions = [RailEnvActions.DO_NOTHING] * self.get_num_agents()
return self._get_observations(), self.rewards_dict, self.dones, {} return self._get_observations(), self.rewards_dict, self.dones, {}
def check_action(self, agent, action): def check_action(self, agent, action):
...@@ -303,19 +332,19 @@ class RailEnv(Environment): ...@@ -303,19 +332,19 @@ class RailEnv(Environment):
new_direction = agent.direction new_direction = agent.direction
# print(nbits,np.sum(possible_transitions)) # print(nbits,np.sum(possible_transitions))
if action == 1: if action == RailEnvActions.MOVE_LEFT:
new_direction = agent.direction - 1 new_direction = agent.direction - 1
if num_transitions <= 1: if num_transitions <= 1:
transition_isValid = False transition_isValid = False
elif action == 3: elif action == RailEnvActions.MOVE_RIGHT:
new_direction = agent.direction + 1 new_direction = agent.direction + 1
if num_transitions <= 1: if num_transitions <= 1:
transition_isValid = False transition_isValid = False
new_direction %= 4 new_direction %= 4
if action == 2: if action == RailEnvActions.MOVE_FORWARD:
if num_transitions == 1: if num_transitions == 1:
# - dead-end, straight line or curved line; # - dead-end, straight line or curved line;
# new_direction will be the only valid transition # new_direction will be the only valid transition
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment