Newer
Older
"""
The env module defines the base Environment class.
The base Environment class is adapted from rllib.env.MultiAgentEnv
(https://github.com/ray-project/ray).
"""
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
class Environment:
    """Base interface for multi-agent environments in Flatland.

    Agents are identified by agent ids (handles). This class only defines
    the interface: apart from the constructor, every method raises
    ``NotImplementedError`` and is meant to be overridden by a subclass.

    Examples
    --------
    >>> obs = env.reset()
    >>> print(obs)
    {
        "train_0": [2.4, 1.6],
        "train_1": [3.4, -3.2],
    }
    >>> obs, rewards, dones, infos = env.step(
    ...     action_dict={
    ...         "train_0": 1, "train_1": 0})
    >>> print(rewards)
    {
        "train_0": 3,
        "train_1": -1,
    }
    >>> print(dones)
    {
        "train_0": False,    # train_0 is still running
        "train_1": True,     # train_1 is done
        "__all__": False,    # the env is not done
    }
    >>> print(infos)
    {
        "train_0": {},  # info for train_0
        "train_1": {},  # info for train_1
    }
    """

    def __init__(self):
        """Create the base environment; the base class sets up no state."""

    def reset(self):
        """
        Resets the env and returns observations from agents in the
        environment.

        Returns
        -------
        obs : dict
            New observations for each agent.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override.
        """
        raise NotImplementedError()

    def step(self, action_dict):
        """
        Performs an environment step with simultaneous execution of actions
        for agents in action_dict.

        Returns observations, rewards, done flags and infos for agents in
        the environment. The returns are dicts mapping from agent_id strings
        to values.

        Parameters
        ----------
        action_dict : dict
            Dictionary of actions to execute, indexed by agent id.

        Returns
        -------
        obs : dict
            New observations for each ready agent.
        rewards : dict
            Reward values for each ready agent.
        dones : dict
            Done values for each ready agent. The special key "__all__"
            (required) is used to indicate env termination.
        infos : dict
            Optional info values for each agent id.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override.
        """
        raise NotImplementedError()

    def render(self):
        """
        Perform rendering of the environment.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override.
        """
        raise NotImplementedError()

    def get_agent_handles(self):
        """
        Returns a list of agents' handles to be used as keys in the step()
        function.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override.
        """
        raise NotImplementedError()