Skip to content
Snippets Groups Projects
Commit 50973768 authored by u214892's avatar u214892
Browse files

Merge branch 'master' of gitlab.aicrowd.com:flatland/flatland into 80-specs

parents dab3d244 0a5dcd7d
No related branches found
No related tags found
No related merge requests found
......@@ -23,7 +23,7 @@ class SimpleObs(ObservationBuilder):
def reset(self):
return
def get(self, handle: int = 0):
def get(self, handle: int = 0) -> np.ndarray:
observation = handle * np.ones((5,))
return observation
......
......@@ -2,6 +2,7 @@ import getopt
import random
import sys
import time
from typing import List
import numpy as np
......@@ -36,7 +37,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
# Recompute the distance map, if the environment has changed.
super().reset()
def get(self, handle: int = 0):
def get(self, handle: int = 0) -> List[int]:
agent = self.env.agents[handle]
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
......
......@@ -2,7 +2,7 @@ import getopt
import random
import sys
import time
from typing import Optional, List
from typing import Optional, List, Dict
import numpy as np
......@@ -41,7 +41,7 @@ class ObservePredictions(TreeObsForRailEnv):
# Recompute the distance map, if the environment has changed.
super().reset()
def get_many(self, handles: Optional[List[int]] = None):
def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, np.ndarray]:
'''
Because we do not want to call the predictor seperately for every agent we implement the get_many function
Here we can call the predictor just ones for all the agents and use the predictions to generate our observations
......@@ -69,7 +69,7 @@ class ObservePredictions(TreeObsForRailEnv):
observations[h] = self.get(h)
return observations
def get(self, handle: int = 0):
def get(self, handle: int = 0) -> np.ndarray:
'''
Lets write a simple observation which just indicates whether or not the own predicted path
overlaps with other predicted paths at any time. This is useless for the task of navigation but might
......
import random
import time
from typing import List
import numpy as np
......@@ -33,7 +34,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
# Recompute the distance map, if the environment has changed.
super().reset()
def get(self, handle: int = 0):
def get(self, handle: int = 0) -> List[int]:
agent = self.env.agents[handle]
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
......
......@@ -97,8 +97,8 @@ class DummyObservationBuilder(ObservationBuilder):
def reset(self):
pass
def get_many(self, handles: Optional[List[int]] = None):
def get_many(self, handles: Optional[List[int]] = None) -> bool:
return True
def get(self, handle: int = 0):
def get(self, handle: int = 0) -> bool:
return True
......@@ -2,7 +2,7 @@
Collection of environment-specific ObservationBuilder.
"""
import pprint
from typing import Optional, List
from typing import Optional, List, Dict, T, Tuple
import numpy as np
......@@ -46,7 +46,7 @@ class TreeObsForRailEnv(ObservationBuilder):
def reset(self):
self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents}
def get_many(self, handles: Optional[List[int]] = None):
def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, List[int]]:
"""
Called whenever an observation has to be computed for the `env` environment, for each agent with handle
in the `handles` list.
......@@ -75,7 +75,7 @@ class TreeObsForRailEnv(ObservationBuilder):
observations[h] = self.get(h)
return observations
def get(self, handle: int = 0):
def get(self, handle: int = 0) -> List[int]:
"""
Computes the current observation for agent `handle` in env
......@@ -534,7 +534,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
bitlist = [0] * (16 - len(bitlist)) + bitlist
self.rail_obs[i, j] = np.array(bitlist)
def get(self, handle: int = 0):
def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
obs_targets = np.zeros((self.env.height, self.env.width, 2))
obs_agents_state = np.zeros((self.env.height, self.env.width, 4))
agents = self.env.agents
......@@ -600,7 +600,7 @@ class LocalObsForRailEnv(ObservationBuilder):
bitlist = [0] * (16 - len(bitlist)) + bitlist
self.rail_obs[i, j] = np.array(bitlist)
def get(self, handle: int = 0):
def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray):
agents = self.env.agents
agent = agents[handle]
......@@ -640,7 +640,7 @@ class LocalObsForRailEnv(ObservationBuilder):
direction = np.identity(4)[agent.direction]
return local_rail_obs, obs_map_state, obs_other_agents_state, direction
def get_many(self, handles: Optional[List[int]] = None):
def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
"""
Called whenever an observation has to be computed for the `env` environment, for each agent with handle
in the `handles` list.
......
import random
from typing import Dict
from typing import Dict, List
import numpy as np
......@@ -31,7 +31,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
# Recompute the distance map, if the environment has changed.
super().reset()
def get(self, handle: int = 0):
def get(self, handle: int = 0) -> List[int]:
agent = self.env.agents[handle]
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment