Skip to content
Snippets Groups Projects
Commit d1c6a975 authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files

editor support agent rotate

parent e47d3f5f
No related branches found
No related tags found
No related merge requests found
No preview for this file type
No preview for this file type
No preview for this file type
...@@ -6,7 +6,6 @@ import time ...@@ -6,7 +6,6 @@ import time
import numpy as np import numpy as np
import torch import torch
from flatland.baselines.dueling_double_dqn import Agent
from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import complex_rail_generator
# from flatland.envs.generators import rail_from_list_of_saved_GridTransitionMap_generator # from flatland.envs.generators import rail_from_list_of_saved_GridTransitionMap_generator
from flatland.envs.generators import random_rail_generator from flatland.envs.generators import random_rail_generator
...@@ -126,13 +125,7 @@ class Demo: ...@@ -126,13 +125,7 @@ class Demo:
def __init__(self, env): def __init__(self, env):
self.env = env self.env = env
self.create_renderer() self.create_renderer()
self.load_agent()
def load_agent(self):
self.state_size = 105 * 2
self.action_size = 4 self.action_size = 4
self.agent = Agent(self.state_size, self.action_size, "FC", 0)
self.agent.qnetwork_local.load_state_dict(torch.load('./flatland/baselines/Nets/avoid_checkpoint15000.pth'))
def create_renderer(self): def create_renderer(self):
self.renderer = RenderTool(self.env, gl="QTSVG") self.renderer = RenderTool(self.env, gl="QTSVG")
...@@ -170,10 +163,12 @@ class Demo: ...@@ -170,10 +163,12 @@ class Demo:
# print(step) # print(step)
# Action # Action
for a in range(self.env.get_num_agents()): for a in range(self.env.get_num_agents()):
action = self.agent.act(agent_obs[a]) action = np.random.choice(self.action_size) #self.agent.act(agent_obs[a])
action_prob[action] += 1 action_prob[action] += 1
action_dict.update({a: action}) action_dict.update({a: action})
print(action_dict)
self.renderer.renderEnv(show=True,action_dict=action_dict) self.renderer.renderEnv(show=True,action_dict=action_dict)
# Environment step # Environment step
...@@ -196,7 +191,7 @@ class Demo: ...@@ -196,7 +191,7 @@ class Demo:
break break
if True: if False:
demo_000 = Demo(Scenario_Generator.generate_random_scenario()) demo_000 = Demo(Scenario_Generator.generate_random_scenario())
demo_000.run_demo() demo_000.run_demo()
demo_000 = None demo_000 = None
......
...@@ -132,6 +132,7 @@ class View(object): ...@@ -132,6 +132,7 @@ class View(object):
dict(name="Clear", method=self.controller.clear, tip="Clear rails and agents"), dict(name="Clear", method=self.controller.clear, tip="Clear rails and agents"),
dict(name="Reset", method=self.controller.reset, dict(name="Reset", method=self.controller.reset,
tip="Standard env reset, including regen rail + agents"), tip="Standard env reset, including regen rail + agents"),
dict(name="Rotate Agent", method=self.controller.rotate_agent, tip="Rotate selected agent"),
dict(name="Restart Agents", method=self.controller.restartAgents, dict(name="Restart Agents", method=self.controller.restartAgents,
tip="Move agents back to start positions"), tip="Move agents back to start positions"),
dict(name="Regenerate", method=self.controller.regenerate, dict(name="Regenerate", method=self.controller.regenerate,
...@@ -325,6 +326,15 @@ class Controller(object): ...@@ -325,6 +326,15 @@ class Controller(object):
self.model.reset(replace_agents=self.view.wReplaceAgents.value, self.model.reset(replace_agents=self.view.wReplaceAgents.value,
nAgents=self.view.wRegenNAgents.value) nAgents=self.view.wRegenNAgents.value)
def rotate_agent(self,event):
self.log("Rotate Agent:", self.model.iSelectedAgent)
if self.model.iSelectedAgent is not None:
for iAgent, agent in enumerate(self.model.env.agents_static):
if agent is None:
continue
agent.direction = (agent.direction + 1) % 4
self.model.redraw()
def restartAgents(self, event): def restartAgents(self, event):
self.log("Restart Agents - nAgents:", self.view.wRegenNAgents.value) self.log("Restart Agents - nAgents:", self.view.wRegenNAgents.value)
self.model.restartAgents() self.model.restartAgents()
...@@ -695,8 +705,7 @@ class EditorModel(object): ...@@ -695,8 +705,7 @@ class EditorModel(object):
# No # No
if self.iSelectedAgent is None: if self.iSelectedAgent is None:
# Create a new agent and select it. # Create a new agent and select it.
## ADRIAN agent_static = EnvAgentStatic(rcCell,0, rcCell)
agent_static = EnvAgentStatic(rcCell, np.random.choice(4), rcCell)
self.iSelectedAgent = self.env.add_agent_static(agent_static) self.iSelectedAgent = self.env.add_agent_static(agent_static)
self.player = None # will need to start a new player self.player = None # will need to start a new player
else: else:
......
...@@ -105,7 +105,7 @@ ...@@ -105,7 +105,7 @@
{ {
"data": { "data": {
"application/vnd.jupyter.widget-view+json": { "application/vnd.jupyter.widget-view+json": {
"model_id": "83640a5c6059421d92e0d69049ad232f", "model_id": "7b66ea9348c9477f881ff27456987363",
"version_major": 2, "version_major": 2,
"version_minor": 0 "version_minor": 0
}, },
...@@ -131,7 +131,7 @@ ...@@ -131,7 +131,7 @@
{ {
"data": { "data": {
"application/vnd.jupyter.widget-view+json": { "application/vnd.jupyter.widget-view+json": {
"model_id": "dc7691bf5f804c5c95604cb551dbb335", "model_id": "ffa0f869fe8a4921a7415384b75c1ded",
"version_major": 2, "version_major": 2,
"version_minor": 0 "version_minor": 0
}, },
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment