Commit 9ae54ce7 authored by Erik Nygren

Merge branch 'master' of gitlab.aicrowd.com:flatland/flatland into 264_fixing_examples

# Conflicts:
#	flatland/envs/rail_env.py
parents 5e89c7d2 4ed208c4
@@ -78,6 +78,30 @@ class TreeObsForRailEnv(ObservationBuilder):
self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
self.predicted_dir.update({t: dir_list})
self.max_prediction_depth = len(self.predicted_pos)
# Update local lookup table for all agents' positions
# ignore other agents not in the grid (only status active and done)
# self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if
# agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]}
self.location_has_agent = {}
self.location_has_agent_direction = {}
self.location_has_agent_speed = {}
self.location_has_agent_malfunction = {}
self.location_has_agent_ready_to_depart = {}
for _agent in self.env.agents:
if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \
_agent.position:
self.location_has_agent[tuple(_agent.position)] = 1
self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed']
self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[
'malfunction']
if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \
_agent.initial_position:
self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1
observations = super().get_many(handles)
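The block above trades the former per-call scan over `self.env.agents` (the commented-out dict comprehension) for a set of per-cell lookup tables built once per step in `get_many()`; the identical block is removed from the per-agent observation call in the next hunk. A minimal, self-contained sketch of the same pattern, using a hypothetical `Agent` namedtuple rather than flatland's `EnvAgent`:

``` python
from collections import namedtuple

# Hypothetical stand-in for flatland's EnvAgent; only the fields used here.
Agent = namedtuple("Agent", ["position", "direction", "speed"])

agents = [Agent((3, 4), 1, 1.0), Agent((7, 2), 0, 0.5)]

# Build O(1) lookups keyed by cell coordinate, mirroring the tables above.
location_has_agent = {tuple(a.position): 1 for a in agents if a.position}
location_has_agent_direction = {tuple(a.position): a.direction for a in agents if a.position}

# During tree expansion, a cell can now be checked in constant time
# instead of scanning the full agent list at every node.
cell = (3, 4)
if cell in location_has_agent:
    other_agent_direction = location_has_agent_direction[cell]
```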
@@ -162,30 +186,6 @@ class TreeObsForRailEnv(ObservationBuilder):
In case the target node is reached, the values are [0, 0, 0, 0, 0].
"""
# Update local lookup table for all agents' positions
# ignore other agents not in the grid (only status active and done)
# self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if
# agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]}
self.location_has_agent = {}
self.location_has_agent_direction = {}
self.location_has_agent_speed = {}
self.location_has_agent_malfunction = {}
self.location_has_agent_ready_to_depart = {}
for _agent in self.env.agents:
if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \
_agent.position:
self.location_has_agent[tuple(_agent.position)] = 1
self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed']
self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data['malfunction']
if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \
_agent.initial_position:
self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1
if handle > len(self.env.agents):
print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
agent = self.env.agents[handle] # TODO: handle being treated as index
@@ -14,6 +14,7 @@ from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus
from flatland.envs.distance_map import DistanceMap
@@ -83,6 +84,7 @@ class RailEnv(Environment):
- invalid_action_penalty = 0
- step_penalty = -alpha
- global_reward = beta
- epsilon = small value used to avoid rounding errors
- stop_penalty = 0 # penalty for stopping a moving agent
- start_penalty = 0 # penalty for starting a stopped agent
@@ -217,6 +219,9 @@
self.valid_positions = None
# global numpy array of agents position, True means that there is an agent at that cell
self.agent_positions: np.ndarray = np.full((height, width), False)
def _seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
@@ -242,7 +247,7 @@
agent = self.agents[handle]
if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position):
agent.status = RailAgentStatus.ACTIVE
agent.position = agent.initial_position
self._set_agent_to_initial_position(agent, agent.initial_position)
def restart_agents(self):
""" Reset the agents to their starting positions defined in agents_static
@@ -275,6 +280,24 @@
alpha = 2
return int(timedelay_factor * alpha * (width + height + ratio_nr_agents_to_nr_cities))
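The return line above is the whole episode-length formula. A quick worked example with assumed inputs (only `alpha = 2` is visible in this hunk; `timedelay_factor` and the agent/city counts are placeholders for illustration):

``` python
# Assumed values for illustration only; alpha = 2 is the only constant visible above.
timedelay_factor = 4
alpha = 2
width, height = 25, 25
ratio_nr_agents_to_nr_cities = 20 / 2  # e.g. 20 agents spread over 2 cities

max_episode_steps = int(timedelay_factor * alpha * (width + height + ratio_nr_agents_to_nr_cities))
print(max_episode_steps)  # 4 * 2 * (25 + 25 + 10) = 480
```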
def action_required(self, agent):
"""
Check if an agent needs to provide an action
Parameters
----------
agent: RailEnvAgent
Agent we want to check
Returns
-------
True: Agent needs to provide an action
False: Agent does not need to provide an action
"""
return (agent.status == RailAgentStatus.READY_TO_DEPART or (
agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0,
rtol=1e-03)))
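This is the same predicate that fills `info_dict['action_required']` in `reset()` and `step()` below. A typical way to use it in a training loop, sketched with a hypothetical `policy` object (its `act()` method is a placeholder, not part of flatland):

``` python
# Sketch only: `policy` and its act() method are hypothetical placeholders.
obs, info = env.reset()
done = {"__all__": False}
while not done["__all__"]:
    action_dict = {}
    for handle in range(env.get_num_agents()):
        # Query the policy only when the env will actually consume a new action,
        # i.e. the agent is ready to depart or sits exactly on a cell boundary.
        if info["action_required"][handle]:
            action_dict[handle] = policy.act(handle, obs[handle])
    obs, rewards, done, info = env.step(action_dict)
```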
def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False,
random_seed: bool = None) -> (Dict, Dict):
"""
@@ -339,6 +362,8 @@
else:
self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height)
self.agent_positions = np.full((self.height, self.width), False)
self.restart_agents()
if activate_agents:
@@ -370,10 +395,7 @@
self.distance_map.reset(self.agents, self.rail)
info_dict: Dict = {
'action_required': {
i: (agent.status == RailAgentStatus.READY_TO_DEPART or (
agent.status == RailAgentStatus.ACTIVE and agent.speed_data['position_fraction'] == 0.0))
for i, agent in enumerate(self.agents)},
'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
'malfunction': {
i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
},
@@ -390,26 +412,27 @@
"""
agent = self.agents[i_agent]
# Skip agents that cannot break
# TODO: Make a better malfunction model such that not always the same agents break.
if agent.malfunction_data['malfunction_rate'] < 1:
return False
# Decrease counter for next event only if agent is currently not broken and agent has a malfunction rate
if agent.malfunction_data['malfunction_rate'] >= 1 and agent.malfunction_data['next_malfunction'] > 0 and \
agent.malfunction_data['malfunction'] < 1:
agent.malfunction_data['next_malfunction'] -= 1
# If agent is currently working and next malfunction time is reached we set it to malfunctioning
if 1 > agent.malfunction_data['malfunction'] and agent.malfunction_data['next_malfunction'] < 1:
# Only agents that have a positive rate for malfunctions and are not currently broken are considered
# If counter has come to zero --> Agent has malfunction
# set next malfunction time and duration of current malfunction
if agent.malfunction_data['malfunction_rate'] >= 1 and 1 > agent.malfunction_data['malfunction'] and \
agent.malfunction_data['next_malfunction'] < 1:
# Increase number of malfunctions
agent.malfunction_data['nr_malfunctions'] += 1
# Next malfunction in number of steps
# Next malfunction in number of stops
next_breakdown = int(
self._exp_distirbution_synced(rate=agent.malfunction_data['malfunction_rate']))
agent.malfunction_data['next_malfunction'] = max(next_breakdown, 1)
# Duration of current malfunction
num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
self.max_number_of_steps_broken + 1)
self.max_number_of_steps_broken + 1) + 1
agent.malfunction_data['malfunction'] = num_broken_steps
# Remember current moving state of the agent
agent.malfunction_data['moving_before_malfunction'] = agent.moving
return True
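The breakdown schedule above combines an exponential draw for the time to the next malfunction (via `_exp_distirbution_synced`) with a uniform draw for its duration. A rough, self-contained sketch of that sampling with plain NumPy, assuming `rate` is the mean number of steps between malfunctions (the exact semantics of the synced draw are not shown in this hunk):

``` python
import numpy as np

rng = np.random.RandomState(42)

def sample_malfunction(rate, min_steps_broken, max_steps_broken):
    """Rough stand-in for the sampling performed in the block above."""
    # Steps until the next breakdown: exponential, floored at 1 so a
    # breakdown is never scheduled for the current step.
    next_breakdown = max(int(rng.exponential(scale=rate)), 1)
    # Duration of the breakdown: uniform over [min_steps_broken, max_steps_broken].
    num_broken_steps = rng.randint(min_steps_broken, max_steps_broken + 1)
    return next_breakdown, num_broken_steps

print(sample_malfunction(rate=30, min_steps_broken=3, max_steps_broken=10))
```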
@@ -429,16 +452,17 @@
# Nothing left to do with broken agent
return True
# Decrease counter for next event only if agent is currently not broken and agent has a malfunction rate
if agent.malfunction_data['malfunction_rate'] >= 1 and agent.malfunction_data['next_malfunction'] > 0 and \
agent.malfunction_data['malfunction'] < 1:
agent.malfunction_data['next_malfunction'] -= 1
return False
def step(self, action_dict_: Dict[int, RailEnvActions]):
"""
Updates rewards for the agents at a step.
Parameters
----------
action_dict_ : Dict[int,RailEnvActions]
"""
self._elapsed_steps += 1
# If we're done, set reward and info_dict and step() is done.
@@ -479,10 +503,7 @@
have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])
# Build info dict
info_dict["action_required"][i_agent] = \
(agent.status == RailAgentStatus.READY_TO_DEPART or (
agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0,
rtol=1e-03)))
info_dict["action_required"][i_agent] = self.action_required(agent)
info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
info_dict["speed"][i_agent] = agent.speed_data['speed']
info_dict["status"][i_agent] = agent.status
@@ -520,7 +541,7 @@
if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
agent.status = RailAgentStatus.ACTIVE
agent.position = agent.initial_position
self._set_agent_to_initial_position(agent, agent.initial_position)
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
return
else:
@@ -615,7 +636,7 @@
assert new_cell_valid
assert transition_valid
if cell_free:
agent.position = new_position
self._move_agent_to_new_position(agent, new_position)
agent.direction = new_direction
agent.speed_data['position_fraction'] = 0.0
@@ -624,16 +645,54 @@
agent.status = RailAgentStatus.DONE
self.dones[i_agent] = True
agent.moving = False
if self.remove_agents_at_target:
agent.position = None
agent.status = RailAgentStatus.DONE_REMOVED
self._remove_agent_from_scene(agent)
else:
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
else:
# step penalty if not moving (stopped now or before)
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D):
"""
Sets the agent to its initial position. Updates the agent object and the position
of the agent inside the global agent_positions numpy array
Parameters
-------
agent: EnvAgent object
new_position: IntVector2D
"""
agent.position = new_position
self.agent_positions[agent.position] = True
def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D):
"""
Move the agent to a new position. Updates the agent object and the position
of the agent inside the global agent_positions numpy array
Parameters
-------
agent: EnvAgent object
new_position: IntVector2D
"""
agent.position = new_position
self.agent_positions[agent.old_position] = False
self.agent_positions[agent.position] = True
def _remove_agent_from_scene(self, agent: EnvAgent):
"""
Remove the agent from the scene. Updates the agent object and the position
of the agent inside the global agent_positions numpy array
Parameters
-------
agent: EnvAgent object
"""
self.agent_positions[agent.position] = False
if self.remove_agents_at_target:
agent.position = None
agent.status = RailAgentStatus.DONE_REMOVED
def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
"""
@@ -670,16 +729,32 @@
(*agent.position, agent.direction),
new_direction)
# Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving)
cell_free = self.cell_free(new_position)
# only call cell_free() if new cell is inside the scene
if new_cell_valid:
# Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving)
cell_free = self.cell_free(new_position)
else:
# if new cell is outside of scene -> cell_free is False
cell_free = False
return cell_free, new_cell_valid, new_direction, new_position, transition_valid
def cell_free(self, position):
agent_positions = [agent.position for agent in self.agents if agent.position is not None]
ret = len(agent_positions) == 0 or not np.any(np.equal(position, agent_positions).all(1))
return ret
def cell_free(self, position: IntVector2D) -> bool:
"""
Utility to check if a cell is free
Parameters
----------
position : Tuple[int, int]
Returns
-------
bool
is the cell free or not?
"""
return not self.agent_positions[position]
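Together with `_set_agent_to_initial_position`, `_move_agent_to_new_position` and `_remove_agent_from_scene` above, this keeps one invariant: `agent_positions[cell]` is True exactly while some agent occupies `cell`, so `cell_free` becomes a constant-time array lookup instead of the former scan over all agent positions (the old body is still visible above). A tiny illustration of the same check on a bare NumPy array:

``` python
import numpy as np

height, width = 4, 6
agent_positions = np.full((height, width), False)

# An agent enters cell (2, 3); the grid mirrors its position.
agent_positions[(2, 3)] = True

def cell_free(position):
    # O(1): one boolean lookup, independent of the number of agents.
    return not agent_positions[position]

print(cell_free((2, 3)))  # False
print(cell_free((0, 0)))  # True
```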
def check_action(self, agent: EnvAgent, action: RailEnvActions):
"""
@@ -722,13 +797,35 @@
return new_direction, transition_valid
def _get_observations(self):
"""
Utility which returns the observations for all agents in the environment
Returns
-------
Dict object
"""
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
return self.obs_dict
def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
"""
Returns directions in which the agent can move
Parameters:
---------
row : int
col : int
Returns:
-------
List[int]
"""
return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))
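A small usage sketch of the helper above, assuming `env` is an already constructed `RailEnv` (for instance the one built in the notebook further down in this diff):

``` python
# Sketch: assumes `env` is an already constructed RailEnv with a generated rail.
row, col = 2, 3
directions = env.get_valid_directions_on_grid(row, col)
# Per the docstring above: a List[int] of directions in which an agent can move
# from this cell (Grid4 convention: 0 = North, 1 = East, 2 = South, 3 = West).
print(directions)
```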
def get_full_state_msg(self):
"""
Returns state of environment in msgpack object
"""
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents]
@@ -742,12 +839,22 @@
return msgpack.packb(msg_data, use_bin_type=True)
def get_agent_state_msg(self):
"""
Returns agents information in msgpack object
"""
agent_data = [agent.to_list() for agent in self.agents]
msg_data = {
"agents": agent_data}
return msgpack.packb(msg_data, use_bin_type=True)
def set_full_state_msg(self, msg_data):
"""
Sets environment state with msgdata object passed as argument
Parameters
-------
msg_data: msgpack object
"""
data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
self.rail.grid = np.array(data["grid"])
# agents are always reset as not moving
@@ -760,6 +867,13 @@
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
def set_full_state_dist_msg(self, msg_data):
"""
Sets environment grid state and distance map with msgdata object passed as argument
Parameters
-------
msg_data: msgpack object
"""
data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
self.rail.grid = np.array(data["grid"])
# agents are always reset as not moving
@@ -774,6 +888,9 @@
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
def get_full_state_dist_msg(self):
"""
Returns environment information with distance map information as msgpack object
"""
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents]
@@ -791,6 +908,14 @@
return msgpack.packb(msg_data, use_bin_type=True)
def save(self, filename, save_distance_maps=False):
"""
Saves environment and distance map information in a file
Parameters:
---------
filename: string
save_distance_maps: bool
"""
if save_distance_maps is True:
if self.distance_map.get() is not None:
if len(self.distance_map.get()) > 0:
@@ -807,14 +932,31 @@
file_out.write(self.get_full_state_msg())
def load(self, filename):
"""
Load environment with distance map from a file
Parameters:
-------
filename: string
"""
with open(filename, "rb") as file_in:
load_data = file_in.read()
self.set_full_state_dist_msg(load_data)
def load_pkl(self, pkl_data):
"""
Load environment with distance map from a pickle file
Parameters:
-------
pkl_data: pickle file
"""
self.set_full_state_msg(pkl_data)
def load_resource(self, package, resource):
"""
Load environment with distance map from a binary
"""
from importlib_resources import read_binary
load_data = read_binary(package, resource)
self.set_full_state_msg(load_data)
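The save/load pair above round-trips the full environment state (grid, agents and, optionally, the distance map) through msgpack. A minimal usage sketch, assuming `env` is an already constructed and reset `RailEnv`; the file name is a hypothetical placeholder:

``` python
# Sketch: `env` is assumed to be a constructed, reset RailEnv; the file name is arbitrary.
env.save("episode_snapshot.mpk", save_distance_maps=True)

# Later, restore the saved grid, agents and distance map into a RailEnv instance
# (here, for simplicity, the same one).
env.load("episode_snapshot.mpk")
```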
@@ -343,6 +343,17 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=1) -> R
template = [template[-1]] + template[:-1]
def get_matching_templates(template):
"""
Returns a list of possible transition maps for a given template
Parameters:
------
template:List[int]
Returns:
------
List[int]
"""
ret = []
for i in range(len(transitions_templates_)):
is_match = True
%% Cell type:markdown id: tags:
### Example 1 - generate a rail from a manual specification
From a map of tuples (cell_type, rotation)
%% Cell type:code id: tags:
``` python
from flatland.envs.rail_generators import rail_from_manual_specifications_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from PIL import Image
```
%% Output
---------------------------------------------------------------------------
SystemError Traceback (most recent call last)
<ipython-input-2-b6a25a9cfbbb> in <module>
----> 1 from ..flatland.envs.rail_generators import rail_from_manual_specifications_generator
2 from flatland.envs.observations import TreeObsForRailEnv
3 from flatland.envs.rail_env import RailEnv
4 from flatland.utils.rendertools import RenderTool
5 from PIL import Image
SystemError: Parent module '' not loaded, cannot perform relative import
%% Cell type:code id: tags:
``` python
specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
[(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)],
[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]]
env = RailEnv(width=6,
height=4,
rail_generator=rail_from_manual_specifications_generator(specs),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
env.reset()
env_renderer = RenderTool(env, gl="PILSVG")
env_renderer.render_env(show=False)
```
%% Cell type:code id: tags:
``` python
Image.fromarray(env_renderer.gl.get_image())
```
%% Output
<PIL.Image.Image image mode=RGBA size=718x480 at 0x14DD8FD52E8>
@@ -88,7 +88,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
"version": "3.5.2"
},
"latex_envs": {
"LaTeX_envs_menu_present": true,