Skip to content
Snippets Groups Projects
Commit 97a27559 authored by spiglerg's avatar spiglerg
Browse files

part3 of getting started, new custom_obs and custom_rail examples, fixes to GlobalObs

parent 441f7493
No related branches found
No related tags found
No related merge requests found
import random
from flatland.envs.generators import random_rail_generator, random_rail_generator
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from flatland.core.env_observation_builder import ObservationBuilder
import numpy as np
random.seed(100)
np.random.seed(100)
class CustomObs(ObservationBuilder):
    """Minimal example observation builder.

    Each agent's observation is a NumPy vector of five entries, all equal to
    the agent's handle.
    """

    def __init__(self):
        # Advertised observation shape: a single dimension of size 5.
        self.observation_space = [5]

    def reset(self):
        """Do nothing: this builder computes nothing ahead of time."""
        return None

    def get(self, handle):
        """Return a length-5 vector whose entries all equal ``handle``."""
        return handle * np.ones((5,))
# Build a 7x7 environment with three agents whose observations are produced
# by the CustomObs builder.
env = RailEnv(
    width=7,
    height=7,
    rail_generator=random_rail_generator(),
    number_of_agents=3,
    obs_builder_object=CustomObs(),
)

# Print the observation vector for each agent.
# The action dict {0: 0} is presumably {agent handle: action} -- confirm
# against RailEnv.step.
obs, all_rewards, done, _ = env.step({0: 0})
for agent_handle in range(env.get_num_agents()):
    print("Agent ", agent_handle, "'s observation: ", obs[agent_handle])
import random
from flatland.envs.generators import random_rail_generator, random_rail_generator
from flatland.envs.rail_env import RailEnv
from flatland.core.transitions import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.utils.rendertools import RenderTool
import numpy as np
random.seed(100)
np.random.seed(100)
def custom_rail_generator():
    """Create a rail generator that yields a zero-filled grid and no agents.

    Returns:
        A callable ``generator(width, height, num_agents=0, num_resets=0)``
        that returns the tuple ``(grid_map, agents_positions,
        agents_direction, agents_target)``.
    """

    def generator(width, height, num_agents=0, num_resets=0):
        """Build a ``width`` x ``height`` transition map with every cell set to 0.

        ``num_agents`` and ``num_resets`` are accepted but not used here.
        """
        grid_map = GridTransitionMap(
            width=width,
            height=height,
            transitions=RailEnvTransitions(),
        )
        # Zero the underlying grid array in place.
        grid_map.grid.fill(0)

        # No agents are placed: positions, directions and targets are three
        # separate empty lists.
        return grid_map, [], [], []

    return generator
# Build a 6x4 environment with one agent, using the custom rail generator.
env = RailEnv(
    width=6,
    height=4,
    rail_generator=custom_rail_generator(),
    number_of_agents=1,
)
env.reset()

# Render the environment with the "QT" backend, then block until the user
# presses Enter.
env_renderer = RenderTool(env, gl="QT")
env_renderer.renderEnv(show=True)
input("Press Enter to continue...")
...@@ -3,6 +3,7 @@ import random ...@@ -3,6 +3,7 @@ import random
from flatland.envs.generators import random_rail_generator, random_rail_generator from flatland.envs.generators import random_rail_generator, random_rail_generator
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
from flatland.core.env_observation_builder import ObservationBuilder
import numpy as np import numpy as np
random.seed(100) random.seed(100)
...@@ -11,7 +12,8 @@ np.random.seed(100) ...@@ -11,7 +12,8 @@ np.random.seed(100)
env = RailEnv(width=7, env = RailEnv(width=7,
height=7, height=7,
rail_generator=random_rail_generator(), rail_generator=random_rail_generator(),
number_of_agents=2) number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
# Print the distance map of each cell to the target of the first agent # Print the distance map of each cell to the target of the first agent
# for i in range(4): # for i in range(4):
......
...@@ -491,8 +491,14 @@ class GlobalObsForRailEnv(ObservationBuilder): ...@@ -491,8 +491,14 @@ class GlobalObsForRailEnv(ObservationBuilder):
""" """
def __init__(self):
    """Start with an empty observation space, then run the base initializer."""
    # Placeholder shape, assigned before the base-class initializer runs.
    self.observation_space = ()
    super().__init__()
def _set_env(self, env):
    """Delegate to the base class's ``_set_env``, then size the observation space.

    The observation space becomes ``[4, height, width]``, read from
    ``self.env`` after the base-class call.
    """
    super()._set_env(env)
    grid_height = self.env.height
    grid_width = self.env.width
    self.observation_space = [4, grid_height, grid_width]
def reset(self): def reset(self):
self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
for i in range(self.rail_obs.shape[0]): for i in range(self.rail_obs.shape[0]):
......
...@@ -90,6 +90,9 @@ class RailEnv(Environment): ...@@ -90,6 +90,9 @@ class RailEnv(Environment):
self.obs_builder = obs_builder_object self.obs_builder = obs_builder_object
self.obs_builder._set_env(self) self.obs_builder._set_env(self)
self.action_space = [1]
self.observation_space = self.obs_builder.observation_space # updated on resets?
self.actions = [0] * number_of_agents self.actions = [0] * number_of_agents
self.rewards = [0] * number_of_agents self.rewards = [0] * number_of_agents
self.done = False self.done = False
...@@ -112,10 +115,6 @@ class RailEnv(Environment): ...@@ -112,10 +115,6 @@ class RailEnv(Environment):
self.valid_positions = None self.valid_positions = None
self.action_space = [1]
self.observation_space = self.obs_builder.observation_space # updated on resets?
# no more agent_handles # no more agent_handles
def get_agent_handles(self):
    """Return the agent handles as ``range(0, self.get_num_agents())``."""
    agent_count = self.get_num_agents()
    return range(agent_count)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment