Skip to content
Snippets Groups Projects
Commit 576715cb authored by Erik Nygren's avatar Erik Nygren
Browse files

updated rendering method for debugging

parent c528e652
No related branches found
No related tags found
No related merge requests found
...@@ -8,16 +8,7 @@ import torch,random ...@@ -8,16 +8,7 @@ import torch,random
random.seed(1) random.seed(1)
np.random.seed(1) np.random.seed(1)
"""
transition_probability = [1.0, # empty cell - Case 0
3.0, # Case 1 - straight
1.0, # Case 2 - simple switch
3.0, # Case 3 - diamond drossing
2.0, # Case 4 - single slip
1.0, # Case 5 - double slip
1.0, # Case 6 - symmetrical
1.0] # Case 7 - dead end
"""
# Example generate a rail given a manual specification, # Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation) # a map of tuples (cell_type, rotation)
transition_probability = [1.0, # empty cell - Case 0 transition_probability = [1.0, # empty cell - Case 0
...@@ -63,6 +54,7 @@ for trials in range(1, n_trials + 1): ...@@ -63,6 +54,7 @@ for trials in range(1, n_trials + 1):
# Run episode # Run episode
for step in range(100): for step in range(100):
#env_renderer.renderEnv(show=True)
# Action # Action
for a in range(env.number_of_agents): for a in range(env.number_of_agents):
...@@ -73,7 +65,6 @@ for trials in range(1, n_trials + 1): ...@@ -73,7 +65,6 @@ for trials in range(1, n_trials + 1):
next_obs, all_rewards, done, _ = env.step(action_dict) next_obs, all_rewards, done, _ = env.step(action_dict)
# Update replay buffer and train agent # Update replay buffer and train agent
for a in range(env.number_of_agents): for a in range(env.number_of_agents):
agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]) agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
...@@ -81,7 +72,7 @@ for trials in range(1, n_trials + 1): ...@@ -81,7 +72,7 @@ for trials in range(1, n_trials + 1):
obs = next_obs.copy() obs = next_obs.copy()
if all(done): if done['__all__']:
env_done = 1 env_done = 1
break break
# Epsioln decay # Epsioln decay
......
...@@ -403,10 +403,9 @@ class RenderTool(object): ...@@ -403,10 +403,9 @@ class RenderTool(object):
# cell_size is a bit pointless with matplotlib - it does not relate to pixels, # cell_size is a bit pointless with matplotlib - it does not relate to pixels,
# so for now I've changed it to 1 (from 10) # so for now I've changed it to 1 (from 10)
cell_size = 1 cell_size = 1
plt.clf()
# if oFigure is None: # if oFigure is None:
# oFigure = plt.figure() # oFigure = plt.figure()
def drawTrans(oFrom, oTo, sColor="gray"): def drawTrans(oFrom, oTo, sColor="gray"):
plt.plot( plt.plot(
[oFrom[0], oTo[0]], # x [oFrom[0], oTo[0]], # x
...@@ -551,7 +550,11 @@ class RenderTool(object): ...@@ -551,7 +550,11 @@ class RenderTool(object):
plt.xlim([0, env.width * cell_size]) plt.xlim([0, env.width * cell_size])
plt.ylim([-env.height * cell_size, 0]) plt.ylim([-env.height * cell_size, 0])
if show: if show:
plt.show() plt.show(block=False)
plt.pause(0.00001)
return
def _draw_square(self, center, size, color): def _draw_square(self, center, size, color):
x0 = center[0]-size/2 x0 = center[0]-size/2
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment