Skip to content
Snippets Groups Projects
Commit 311b9814 authored by u214892's avatar u214892
Browse files

25 skeleton

parent 6fc2266b
No related branches found
No related tags found
No related merge requests found
...@@ -84,6 +84,21 @@ class Environment: ...@@ -84,6 +84,21 @@ class Environment:
""" """
raise NotImplementedError() raise NotImplementedError()
def predict(self):
"""
Predictions step.
Returns predictions for the agents.
The returns are dicts mapping from agent_id strings to values.
Returns
-------
predictions : dict
New predictions for each ready agent.
"""
raise NotImplementedError()
def render(self): def render(self):
""" """
Perform rendering of the environment. Perform rendering of the environment.
......
...@@ -19,7 +19,6 @@ class ObservationBuilder: ...@@ -19,7 +19,6 @@ class ObservationBuilder:
def __init__(self): def __init__(self):
self.observation_space = () self.observation_space = ()
pass
def _set_env(self, env): def _set_env(self, env):
self.env = env self.env = env
......
"""
PredictionBuilder objects are objects that can be passed to environments designed for customizability.
The PredictionBuilder-derived custom classes implement 2 functions, reset() and get([handle]).
If predictions are not required in every step or not for all agents, then
+ Reset() is called after each environment reset, to allow for pre-computing relevant data.
+ Get() is called whenever an step has to be computed, potentially for each agent independently in
case of multi-agent environments.
"""
class PredictionBuilder:
"""
PredictionBuilder base class.
"""
def __init__(self):
pass
def _set_env(self, env):
self.env = env
def reset(self):
"""
Called after each environment reset.
"""
raise NotImplementedError()
def get(self, handle=0):
"""
Called whenever an observation has to be computed for the `env' environment, possibly
for each agent independently (agent id `handle').
Parameters
-------
handle : int (optional)
Handle of the agent for which to compute the observation vector.
Returns
-------
function
An prediction structure, specific to the corresponding environment.
"""
raise NotImplementedError()
"""
Collection of environment-specific PredictionBuilder.
"""
from flatland.core.env_prediction_builder import PredictionBuilder
class DummyPredictorForRailEnv(PredictionBuilder):
"""
DummyPredictorForRailEnv object.
This object returns predictions for agents in the RailEnv environment.
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def __init__(self):
pass
def reset(self):
pass
def get(self, handle=0):
return {}
...@@ -51,7 +51,9 @@ class RailEnv(Environment): ...@@ -51,7 +51,9 @@ class RailEnv(Environment):
height, height,
rail_generator=random_rail_generator(), rail_generator=random_rail_generator(),
number_of_agents=1, number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2)): obs_builder_object=TreeObsForRailEnv(max_depth=2),
prediction_builder_object=None
):
""" """
Environment init. Environment init.
...@@ -94,6 +96,11 @@ class RailEnv(Environment): ...@@ -94,6 +96,11 @@ class RailEnv(Environment):
self.obs_builder = obs_builder_object self.obs_builder = obs_builder_object
self.obs_builder._set_env(self) self.obs_builder._set_env(self)
self.prediction_builder = prediction_builder_object
if self.prediction_builder:
self.prediction_builder._set_env(self)
self.action_space = [1] self.action_space = [1]
self.observation_space = self.obs_builder.observation_space # updated on resets? self.observation_space = self.obs_builder.observation_space # updated on resets?
...@@ -296,6 +303,12 @@ class RailEnv(Environment): ...@@ -296,6 +303,12 @@ class RailEnv(Environment):
self.actions = [0] * self.get_num_agents() self.actions = [0] * self.get_num_agents()
return self._get_observations(), self.rewards_dict, self.dones, {} return self._get_observations(), self.rewards_dict, self.dones, {}
def predict(self):
if not self.prediction_builder:
return {}
return self.prediction_builder.get()
def check_action(self, agent, action): def check_action(self, agent, action):
transition_isValid = None transition_isValid = None
possible_transitions = self.rail.get_transitions((*agent.position, agent.direction)) possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
...@@ -332,6 +345,11 @@ class RailEnv(Environment): ...@@ -332,6 +345,11 @@ class RailEnv(Environment):
self.obs_dict[iAgent] = self.obs_builder.get(iAgent) self.obs_dict[iAgent] = self.obs_builder.get(iAgent)
return self.obs_dict return self.obs_dict
def _get_predictions(self):
if not self.prediction_builder:
return {}
return {}
def render(self): def render(self):
# TODO: # TODO:
pass pass
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
from flatland.envs.predictions import DummyPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import rail_from_GridTransitionMap_generator
"""Tests for `flatland` package."""
def test_predictions():
# We instantiate a very simple rail network on a 7x10 grid:
# |
# |
# |
# _ _ _ /_\ _ _ _ _ _ _
# \ /
# |
# |
# |
cells = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
int('1000010000100001', 2), # Case 3 - diamond drossing
int('1001011000100001', 2), # Case 4 - single slip switch
int('1100110000110011', 2), # Case 5 - double slip switch
int('0101001000000010', 2), # Case 6 - symmetrical switch
int('0010000000000000', 2)] # Case 7 - dead end
transitions = Grid4Transitions([])
empty = cells[0]
dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
vertical_straight = cells[1]
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
double_switch_south_horizontal_straight = horizontal_straight + cells[6]
double_switch_north_horizontal_straight = transitions.rotate_transition(
double_switch_south_horizontal_straight, 180)
rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
[[dead_end_from_east] + [horizontal_straight] * 2 +
[double_switch_north_horizontal_straight] +
[horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
[horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv(),
prediction_builder_object=DummyPredictorForRailEnv()
)
env.reset()
predictions = env.predict()
def main():
test_predictions()
if __name__ == "__main__":
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment