Skip to content
Snippets Groups Projects
Commit 41856d28 authored by u214892's avatar u214892
Browse files

Merge branch 'master' of gitlab.aicrowd.com:flatland/flatland into cleanup-continuous-integration

parents b2eab862 229ddb8c
No related branches found
No related tags found
No related merge requests found
......@@ -84,6 +84,21 @@ class Environment:
"""
raise NotImplementedError()
def predict(self):
"""
Predictions step.
Returns predictions for the agents.
The returns are dicts mapping from agent_id strings to values.
Returns
-------
predictions : dict
New predictions for each ready agent.
"""
raise NotImplementedError()
def render(self):
"""
Perform rendering of the environment.
......
......@@ -19,7 +19,6 @@ class ObservationBuilder:
def __init__(self):
self.observation_space = ()
pass
def _set_env(self, env):
self.env = env
......
"""
PredictionBuilder objects are objects that can be passed to environments designed for customizability.
The PredictionBuilder-derived custom classes implement 2 functions, reset() and get([handle]).
If predictions are not required in every step or not for all agents, then
+ Reset() is called after each environment reset, to allow for pre-computing relevant data.
+ Get() is called whenever an step has to be computed, potentially for each agent independently in
case of multi-agent environments.
"""
class PredictionBuilder:
"""
PredictionBuilder base class.
"""
def __init__(self, max_depth: int = 20):
self.max_depth = max_depth
def _set_env(self, env):
self.env = env
def reset(self):
"""
Called after each environment reset.
"""
pass
def get(self, handle=0):
"""
Called whenever step_prediction is called on the environment.
Parameters
-------
handle : int (optional)
Handle of the agent for which to compute the observation vector.
Returns
-------
function
A prediction structure, specific to the corresponding environment.
"""
raise NotImplementedError()
"""
Collection of environment-specific PredictionBuilder.
"""
import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
class DummyPredictorForRailEnv(PredictionBuilder):
"""
DummyPredictorForRailEnv object.
This object returns predictions for agents in the RailEnv environment.
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def get(self, handle=None):
"""
Called whenever step_prediction is called on the environment.
Parameters
-------
handle : int (optional)
Handle of the agent for which to compute the observation vector.
Returns
-------
function
Returns a dictionary index by the agent handle and for each agent a vector of 5 elements:
- time_offset
- position axis 0
- position axis 1
- direction
- action taken to come here
"""
agents = self.env.agents
if handle:
agents = [self.env.agents[handle]]
prediction_dict = {}
for agent in agents:
# 0: do nothing
# 1: turn left and move to the next cell
# 2: move to the next cell in front of the agent
# 3: turn right and move to the next cell
action_priorities = [2, 1, 3]
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
prediction = np.zeros(shape=(self.max_depth, 5))
prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
for index in range(1, self.max_depth):
action_done = False
for action in action_priorities:
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = self.env._check_action_on_agent(action,
agent)
if all([new_cell_isValid, transition_isValid]):
# move and change direction to face the new_direction that was
# performed
agent.position = new_position
agent.direction = new_direction
prediction[index] = [index, new_position[0], new_position[1], new_direction, action]
action_done = True
break
if not action_done:
print("Cannot move further.")
prediction_dict[agent.handle] = prediction
agent.position = _agent_initial_position
agent.direction = _agent_initial_direction
return prediction_dict
......@@ -51,7 +51,9 @@ class RailEnv(Environment):
height,
rail_generator=random_rail_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2)):
obs_builder_object=TreeObsForRailEnv(max_depth=2),
prediction_builder_object=None
):
"""
Environment init.
......@@ -94,6 +96,11 @@ class RailEnv(Environment):
self.obs_builder = obs_builder_object
self.obs_builder._set_env(self)
self.prediction_builder = prediction_builder_object
if self.prediction_builder:
self.prediction_builder._set_env(self)
self.action_space = [1]
self.observation_space = self.obs_builder.observation_space # updated on resets?
......@@ -212,52 +219,8 @@ class RailEnv(Environment):
return
if action > 0:
# pos = agent.position # self.agents_position[i]
# direction = agent.direction # self.agents_direction[i]
# compute number of possible transitions in the current
# cell used to check for invalid actions
new_direction, transition_isValid = self.check_action(agent, action)
new_position = get_new_position(agent.position, new_direction)
# Is it a legal move?
# 1) transition allows the new_direction in the cell,
# 2) the new cell is not empty (case 0),
# 3) the cell is free, i.e., no agent is currently in that cell
# if (
# new_position[1] >= self.width or
# new_position[0] >= self.height or
# new_position[0] < 0 or new_position[1] < 0):
# new_cell_isValid = False
# if self.rail.get_transitions(new_position) == 0:
# new_cell_isValid = False
new_cell_isValid = (
np.array_equal( # Check the new position is still in the grid
new_position,
np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
and # check the new position has some transitions (ie is not an empty cell)
self.rail.get_transitions(new_position) > 0)
# If transition validity hasn't been checked yet.
if transition_isValid is None:
transition_isValid = self.rail.get_transition(
(*agent.position, agent.direction),
new_direction)
# cell_isFree = True
# for j in range(self.number_of_agents):
# if self.agents_position[j] == new_position:
# cell_isFree = False
# break
# Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving)
cell_isFree = not np.any(
np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = self._check_action_on_agent(action,
agent)
if all([new_cell_isValid, transition_isValid, cell_isFree]):
# move and change direction to face the new_direction that was
# performed
......@@ -296,6 +259,52 @@ class RailEnv(Environment):
self.actions = [0] * self.get_num_agents()
return self._get_observations(), self.rewards_dict, self.dones, {}
def _check_action_on_agent(self, action, agent):
# pos = agent.position # self.agents_position[i]
# direction = agent.direction # self.agents_direction[i]
# compute number of possible transitions in the current
# cell used to check for invalid actions
new_direction, transition_isValid = self.check_action(agent, action)
new_position = get_new_position(agent.position, new_direction)
# Is it a legal move?
# 1) transition allows the new_direction in the cell,
# 2) the new cell is not empty (case 0),
# 3) the cell is free, i.e., no agent is currently in that cell
# if (
# new_position[1] >= self.width or
# new_position[0] >= self.height or
# new_position[0] < 0 or new_position[1] < 0):
# new_cell_isValid = False
# if self.rail.get_transitions(new_position) == 0:
# new_cell_isValid = False
new_cell_isValid = (
np.array_equal( # Check the new position is still in the grid
new_position,
np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
and # check the new position has some transitions (ie is not an empty cell)
self.rail.get_transitions(new_position) > 0)
# If transition validity hasn't been checked yet.
if transition_isValid is None:
transition_isValid = self.rail.get_transition(
(*agent.position, agent.direction),
new_direction)
# cell_isFree = True
# for j in range(self.number_of_agents):
# if self.agents_position[j] == new_position:
# cell_isFree = False
# break
# Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving)
cell_isFree = not np.any(
np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
def predict(self):
if not self.prediction_builder:
return {}
return self.prediction_builder.get()
def check_action(self, agent, action):
transition_isValid = None
possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
......@@ -332,6 +341,11 @@ class RailEnv(Environment):
self.obs_dict[iAgent] = self.obs_builder.get(iAgent)
return self.obs_dict
def _get_predictions(self):
if not self.prediction_builder:
return {}
return {}
def render(self):
# TODO:
pass
......
......@@ -394,9 +394,9 @@ class PILSVG(PILGL):
print("Illegal target rail:", row, col, format(binTrans, "#018b")[2:])
if isSelected:
svgBG = self.pilFromSvgFile("./svg/Selected_Agent.svg")
self.clear_layer(3, 0)
self.drawImageRC(svgBG, (row, col), layer=3)
svgBG = self.pilFromSvgFile("./svg/Selected_Target.svg")
self.clear_layer(3,0)
self.drawImageRC(svgBG,(row,col),layer=3)
def recolorImage(self, pil, a3BaseColor, ltColors):
rgbaImg = array(pil)
......
%% Cell type:markdown id: tags:
# Rail Editor v0.2
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Output
The autoreload extension is already loaded. To reload it, use:
%reload_ext autoreload
%% Cell type:code id: tags:
``` python
import numpy as np
from numpy import array
import ipywidgets
import IPython
from IPython.core.display import display, HTML
```
%% Cell type:code id: tags:
``` python
display(HTML("<style>.container { width:95% !important; }</style>"))
```
%% Output
%% Cell type:code id: tags:
``` python
from flatland.utils.editor import EditorMVC, EditorModel, View, Controller
```
%% Output
cairo installed: OK
%% Cell type:code id: tags:
``` python
mvc = EditorMVC(sGL="PILSVG" ) # sGL="PIL")
```
%% Output
<flatland.utils.graphics_pil.PILSVG object at 0x0000022C5FB44198> <class 'flatland.utils.graphics_pil.PILSVG'>
<flatland.utils.graphics_pil.PILSVG object at 0x000001FC6FB9E198> <class 'flatland.utils.graphics_pil.PILSVG'>
<super: <class 'PILSVG'>, <PILSVG object>> <class 'super'>
Clear rails
%% Cell type:markdown id: tags:
## Instructions
- Drag to draw (improved dead-ends)
- Shift-Drag to erase rails (erasing dead ends not yet automated - drag right across them)
- ctrl-click to add agent
- direction chosen randomly to fit rail
- ctrl-shift-click to add target for last agent
- target can be moved by repeating
- to Resize the env (cannot preserve work):
- select "Regen" tab, set regen size slider, click regenerate.
- alt-click remove all rails from cell
%% Cell type:code id: tags:
``` python
mvc.view.display()
```
%% Output
<flatland.utils.graphics_pil.PILSVG object at 0x0000022C6066EC50> <class 'flatland.utils.graphics_pil.PILSVG'>
<flatland.utils.graphics_pil.PILSVG object at 0x000001FC6FBB7FD0> <class 'flatland.utils.graphics_pil.PILSVG'>
<super: <class 'PILSVG'>, <PILSVG object>> <class 'super'>
<flatland.utils.graphics_pil.PILSVG object at 0x000001FC6FA8C5C0> <class 'flatland.utils.graphics_pil.PILSVG'>
<super: <class 'PILSVG'>, <PILSVG object>> <class 'super'>
<flatland.utils.graphics_pil.PILSVG object at 0x000001FC73AF2908> <class 'flatland.utils.graphics_pil.PILSVG'>
<super: <class 'PILSVG'>, <PILSVG object>> <class 'super'>
%% Cell type:code id: tags:
``` python
mvc.view.wOutput.clear_output()
mvc.view.wOutput
```
%% Output
%% Cell type:code id: tags:
``` python
len(mvc.editor.env.agents), len(mvc.editor.env.agents_static)
```
%% Output
(0, 0)
......
......@@ -2,45 +2,45 @@
<!-- Generator: Adobe Illustrator 23.0.3, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
<svg
xmlns:dc="http://purl.org/dc/elements/1.1/"
xmlns:cc="http://creativecommons.org/ns#"
xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns:svg="http://www.w3.org/2000/svg"
xmlns="http://www.w3.org/2000/svg"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
version="1.1"
id="Ebene_1"
x="0px"
y="0px"
viewBox="0 0 240 240"
style="enable-background:new 0 0 240 240;"
xml:space="preserve"
sodipodi:docname="Selected_Agent.svg"
inkscape:version="0.92.4 (5da689c313, 2019-01-14)"><metadata
xmlns:dc="http://purl.org/dc/elements/1.1/"
xmlns:cc="http://creativecommons.org/ns#"
xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns="http://www.w3.org/2000/svg"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
version="1.1"
id="Ebene_1"
x="0px"
y="0px"
viewBox="0 0 240 240"
style="enable-background:new 0 0 240 240;"
xml:space="preserve"
sodipodi:docname="Selected_Agent.svg"
inkscape:version="0.92.4 (5da689c313, 2019-01-14)"><metadata
id="metadata11"><rdf:RDF><cc:Work
rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
rdf:resource="http://purl.org/dc/dcmitype/StillImage"/><dc:title/></cc:Work></rdf:RDF></metadata>
<defs
id="defs9" /><sodipodi:namedview
pagecolor="#ffffff"
bordercolor="#666666"
borderopacity="1"
objecttolerance="10"
gridtolerance="10"
guidetolerance="10"
inkscape:pageopacity="0"
inkscape:pageshadow="2"
inkscape:window-width="1920"
inkscape:window-height="1137"
id="namedview7"
showgrid="false"
inkscape:zoom="2.7812867"
inkscape:cx="205.50339"
inkscape:cy="161.549"
inkscape:window-x="-8"
inkscape:window-y="-8"
inkscape:window-maximized="1"
inkscape:current-layer="Ebene_1" />
pagecolor="#ffffff"
bordercolor="#666666"
borderopacity="1"
objecttolerance="10"
gridtolerance="10"
guidetolerance="10"
inkscape:pageopacity="0"
inkscape:pageshadow="2"
inkscape:window-width="1920"
inkscape:window-height="1137"
id="namedview7"
showgrid="false"
inkscape:zoom="2.7812867"
inkscape:cx="126.94263"
inkscape:cy="161.549"
inkscape:window-x="-8"
inkscape:window-y="-8"
inkscape:window-maximized="1"
inkscape:current-layer="Ebene_1" />
<style
type="text/css"
id="style2">
......@@ -48,45 +48,60 @@
</style>
<rect
id="rect13"
width="23.389822"
height="23.38983"
x="1.697217e-07"
y="-0.23616901" /><rect
id="rect13-0"
width="23.389822"
height="23.38983"
x="216.82077"
y="0.26172119" /><rect
id="rect13-0-0"
width="23.389822"
height="23.38983"
x="216.75911"
y="216.39955" /><rect
id="rect13-0-0-4"
width="23.389822"
height="23.38983"
x="0.50847793"
y="216.6974" /><rect
id="rect60"
width="2.5168207"
height="198.4693"
x="10.067283"
y="22.474777" /><rect
id="rect60-8"
width="2.5168207"
height="198.4693"
x="228.49136"
y="22.115229" /><rect
id="rect60-8-5"
width="2.5168207"
height="198.4693"
x="-11.868174"
y="19.775019"
transform="rotate(-90)" /><rect
id="rect60-8-5-1"
width="2.5168207"
height="198.4693"
x="-230.11249"
y="20.853657"
transform="rotate(-90)" /></svg>
\ No newline at end of file
id="rect13"
width="23.389822"
height="23.38983"
x="1.697217e-07"
y="-0.23616901"
style="fill:#ff0000"/>
<rect
id="rect13-0"
width="23.389822"
height="23.38983"
x="216.82077"
y="0.26172119"
style="fill:#ff0000"/>
<rect
id="rect13-0-0"
width="23.389822"
height="23.38983"
x="216.75911"
y="216.39955"
style="fill:#ff0000"/>
<rect
id="rect13-0-0-4"
width="23.389822"
height="23.38983"
x="0.50847793"
y="216.6974"
style="fill:#ff0000"/>
<rect
id="rect60"
width="2.5741608"
height="240.53616"
x="1.697217e-07"
y="-0.23616901"
style="stroke-width:1.11335897;fill:#ff0000"/>
<rect
id="rect60-8"
width="2.5741608"
height="239.45752"
x="237.63643"
y="0.26172119"
style="stroke-width:1.11085987;fill:#ff0000"/>
<rect
id="rect60-8-5"
width="2.5168207"
height="237.86693"
x="-2.778542"
y="2.3436601"
transform="rotate(-90)"
style="stroke-width:1.09476364;fill:#ff0000"/>
<rect
id="rect60-8-5-1"
width="2.5168207"
height="240.0242"
x="-240.29999"
y="1.6972172e-07"
transform="rotate(-90)"
style="stroke-width:1.09971678;fill:#ff0000"/></svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<!-- Generator: Adobe Illustrator 23.0.3, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
<svg
xmlns:dc="http://purl.org/dc/elements/1.1/"
xmlns:cc="http://creativecommons.org/ns#"
xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns="http://www.w3.org/2000/svg"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
version="1.1"
id="Ebene_1"
x="0px"
y="0px"
viewBox="0 0 240 240"
style="enable-background:new 0 0 240 240;"
xml:space="preserve"
sodipodi:docname="Selected_Target.svg"
inkscape:version="0.92.4 (5da689c313, 2019-01-14)"><metadata
id="metadata11"><rdf:RDF><cc:Work
rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
rdf:resource="http://purl.org/dc/dcmitype/StillImage"/><dc:title/></cc:Work></rdf:RDF></metadata>
<defs
id="defs9" /><sodipodi:namedview
pagecolor="#ffffff"
bordercolor="#666666"
borderopacity="1"
objecttolerance="10"
gridtolerance="10"
guidetolerance="10"
inkscape:pageopacity="0"
inkscape:pageshadow="2"
inkscape:window-width="1920"
inkscape:window-height="1137"
id="namedview7"
showgrid="false"
inkscape:zoom="2.7812867"
inkscape:cx="48.381869"
inkscape:cy="161.549"
inkscape:window-x="-8"
inkscape:window-y="-8"
inkscape:window-maximized="1"
inkscape:current-layer="Ebene_1" />
<style
type="text/css"
id="style2">
.st0{fill:#9CCB89;}
</style>
<rect
id="rect60"
width="13.303195"
height="220.04205"
x="10.067283"
y="9.8906736"
style="fill:#ff0000;stroke-width:2.42079496"/>
<rect
id="rect60-8"
width="14.741379"
height="220.40161"
x="215.72748"
y="9.531127"
style="fill:#ff0000;stroke-width:2.5503726"/>
<rect
id="rect60-8-5"
width="12.584104"
height="220.04205"
x="-22.115231"
y="10.4268"
transform="rotate(-90)"
style="fill:#ff0000;stroke-width:2.35445929"/>
<rect
id="rect60-8-5-1"
width="13.662741"
height="220.40158"
x="216.26999"
y="10.067283"
transform="matrix(0,1,1,0,0,0)"
style="fill:#ff0000;stroke-width:2.45529389"/></svg>
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
"""Test predictions for `flatland` package."""
def test_predictions():
# We instantiate a very simple rail network on a 7x10 grid:
# |
# |
# |
# _ _ _ /_\ _ _ _ _ _ _
# \ /
# |
# |
# |
cells = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
int('1000010000100001', 2), # Case 3 - diamond drossing
int('1001011000100001', 2), # Case 4 - single slip switch
int('1100110000110011', 2), # Case 5 - double slip switch
int('0101001000000010', 2), # Case 6 - symmetrical switch
int('0010000000000000', 2)] # Case 7 - dead end
transitions = Grid4Transitions([])
empty = cells[0]
dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
vertical_straight = cells[1]
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
double_switch_south_horizontal_straight = horizontal_straight + cells[6]
double_switch_north_horizontal_straight = transitions.rotate_transition(
double_switch_south_horizontal_straight, 180)
rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
[[dead_end_from_east] + [horizontal_straight] * 2 +
[double_switch_north_horizontal_straight] +
[horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
[horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv(),
prediction_builder_object=DummyPredictorForRailEnv(max_depth=20)
)
env.reset()
# set initial position and direction for testing...
env.agents[0].position = (5, 6)
env.agents[0].direction = 0
predictions = env.predict()
positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
# compare against expected values
expected_positions = np.array([[5., 6.],
[4., 6.],
[3., 6.],
[3., 5.],
[3., 4.],
[3., 3.],
[3., 2.],
[3., 1.],
[3., 0.],
[3., 1.],
[3., 2.],
[3., 3.],
[3., 4.],
[3., 5.],
[3., 6.],
[3., 7.],
[3., 8.],
[3., 9.],
[3., 8.],
[3., 7.]])
expected_directions = np.array([[0.],
[0.],
[0.],
[3.],
[3.],
[3.],
[3.],
[3.],
[3.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[3.],
[3.]])
expected_time_offsets = np.array([[0.],
[1.],
[2.],
[3.],
[4.],
[5.],
[6.],
[7.],
[8.],
[9.],
[10.],
[11.],
[12.],
[13.],
[14.],
[15.],
[16.],
[17.],
[18.],
[19.]])
expected_actions = np.array([[0.],
[2.],
[2.],
[1.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.]])
assert np.array_equal(positions, expected_positions)
assert np.array_equal(directions, expected_directions)
assert np.array_equal(time_offsets, expected_time_offsets)
assert np.array_equal(actions, expected_actions)
def main():
test_predictions()
if __name__ == "__main__":
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment