Skip to content
Snippets Groups Projects
Commit e4c677e5 authored by spmohanty's avatar spmohanty
Browse files

Addresses #117 - Implements interface for env_client.step

parent 2dfc3174
No related branches found
No related tags found
No related merge requests found
...@@ -11,7 +11,6 @@ m.patch() ...@@ -11,7 +11,6 @@ m.patch()
import hashlib import hashlib
import random import random
from flatland.evaluators import messages from flatland.evaluators import messages
from flatland.evaluators.utils import get_all_env_pickle_files
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import rail_from_file from flatland.envs.generators import rail_from_file
...@@ -110,7 +109,7 @@ class FlatlandRemoteClient(object): ...@@ -110,7 +109,7 @@ class FlatlandRemoteClient(object):
if self.verbose: print("Response : ", _response) if self.verbose: print("Response : ", _response)
_response = msgpack.unpackb(_response, object_hook=m.decode, encoding="utf8") _response = msgpack.unpackb(_response, object_hook=m.decode, encoding="utf8")
if _response['type'] == messages.FLATLAND_RL.ERROR: if _response['type'] == messages.FLATLAND_RL.ERROR:
raise Exception(str(_response)) raise Exception(str(_response["payload"]))
else: else:
return _response return _response
...@@ -195,12 +194,27 @@ if __name__ == "__main__": ...@@ -195,12 +194,27 @@ if __name__ == "__main__":
_action[_idx] = np.random.randint(0, 5) _action[_idx] = np.random.randint(0, 5)
return _action return _action
obs = True
episode = 0 episode = 0
obs = True
while obs: while obs:
obs = env_client.env_create() obs = env_client.env_create()
if not obs:
break
print("Episode : {}".format(episode)) print("Episode : {}".format(episode))
print(obs)
print(env_client.env.width)
print(env_client.env.height)
episode += 1 episode += 1
print(env_client.env.dones['__all__'])
while True:
action = my_controller(obs, env_client.env)
observation, all_rewards, done, info = env_client.env_step(action)
if done['__all__']:
print("Current Episode : ", episode)
print("Episode Done")
print("Reward : ", sum(list(all_rewards.values())))
break
print("Evaluation Complete...")
print(env_client.submit())
...@@ -21,6 +21,8 @@ import os ...@@ -21,6 +21,8 @@ import os
import timeout_decorator import timeout_decorator
import time import time
import traceback
######################################################## ########################################################
# CONSTANTS # CONSTANTS
######################################################## ########################################################
...@@ -69,6 +71,8 @@ class FlatlandRemoteEvaluationService: ...@@ -69,6 +71,8 @@ class FlatlandRemoteEvaluationService:
self.reward = 0 self.reward = 0
self.simulation_count = 0 self.simulation_count = 0
self.simualation_rewards = [] self.simualation_rewards = []
self.simulation_percentage_complete = []
self.simulation_steps = []
self.simulation_times = [] self.simulation_times = []
self.begin_simulation = False self.begin_simulation = False
self.current_step = 0 self.current_step = 0
...@@ -205,6 +209,9 @@ class FlatlandRemoteEvaluationService: ...@@ -205,6 +209,9 @@ class FlatlandRemoteEvaluationService:
self.begin_simulation = time.time() self.begin_simulation = time.time()
self.simualation_rewards.append(0) self.simualation_rewards.append(0)
self.simulation_percentage_complete.append(0)
self.simulation_steps.append(0)
self.current_step = 0 self.current_step = 0
_observation = self.env.reset() _observation = self.env.reset()
...@@ -227,7 +234,7 @@ class FlatlandRemoteEvaluationService: ...@@ -227,7 +234,7 @@ class FlatlandRemoteEvaluationService:
All test env evaluations are complete All test env evaluations are complete
""" """
_command_response = {} _command_response = {}
_command_response['type'] = messages.FLATLAND_RL.ENV_RESET_RESPONSE _command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE
_command_response['payload'] = {} _command_response['payload'] = {}
_command_response['payload']['observation'] = False _command_response['payload']['observation'] = False
_command_response['payload']['env_file_path'] = False _command_response['payload']['env_file_path'] = False
...@@ -240,6 +247,90 @@ class FlatlandRemoteEvaluationService: ...@@ -240,6 +247,90 @@ class FlatlandRemoteEvaluationService:
use_bin_type=True) use_bin_type=True)
) )
def handle_env_step(self, command):
_redis = self.get_redis_connection()
command_response_channel = command['response_channel']
_payload = command['payload']
if self.env.dones['__all__']:
raise Exception("Client attempted to perform an action on an Env which has done['__all__']==True")
action = _payload['action']
_observation, all_rewards, done, info = self.env.step(action)
cumulative_reward = np.sum(list(all_rewards.values()))
self.simualation_rewards[-1] += cumulative_reward
self.simulation_steps[-1] += 1
if done["__all__"]:
# Compute percentage complete
complete = 0
for i_agent in range(self.env.get_num_agents()):
agent = self.env.agents[i_agent]
if agent.position == agent.target:
complete += 1
percentage_complete = complete * 1.0 / self.env.get_num_agents()
self.simulation_percentage_complete[-1] = percentage_complete
# Build and send response
_command_response = {}
_command_response['type'] = messages.FLATLAND_RL.ENV_STEP_RESPONSE
_command_response['payload'] = {}
_command_response['payload']['observation'] = _observation
_command_response['payload']['reward'] = all_rewards
_command_response['payload']['done'] = done
_command_response['payload']['info'] = info
if self.verbose:
# print("Responding with : ", _command_response)
print("Current Step : ", self.simulation_steps[-1])
_redis.rpush(
command_response_channel,
msgpack.packb(
_command_response,
default=m.encode,
use_bin_type=True)
)
def handle_env_submit(self, command):
_redis = self.get_redis_connection()
command_response_channel = command['response_channel']
_payload = command['payload']
# Register simulation time of the last episode
self.simulation_times.append(time.time()-self.begin_simulation)
_response = {}
_response['type'] = messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE
_payload = {}
_payload['mean_reward'] = np.mean(self.simualation_rewards)
_payload['mean_percentage_complete'] = \
np.mean(self.simulation_percentage_complete)
_response['payload'] = _payload
if self.verbose:
print("Responding with : ", _response)
print("Registering Env Submit call")
_redis.rpush(
command_response_channel,
msgpack.packb(
_response,
default=m.encode,
use_bin_type=True)
)
def report_error(self, error_message, command_response_channel):
_redis = self.get_redis_connection()
_response = {}
_response['type'] = messages.FLATLAND_RL.ERROR
_response['payload'] = error_message
_redis.rpush(
command_response_channel,
msgpack.packb(
_response,
default=m.encode,
use_bin_type=True)
)
def run(self): def run(self):
print("Listening for commands at : ", self.command_channel) print("Listening for commands at : ", self.command_channel)
...@@ -269,117 +360,36 @@ class FlatlandRemoteEvaluationService: ...@@ -269,117 +360,36 @@ class FlatlandRemoteEvaluationService:
Respond with an internal _env object Respond with an internal _env object
""" """
self.handle_env_create(command) self.handle_env_create(command)
elif command['type'] == messages.FLATLAND_RL.ENV_RESET:
"""
ENV_RESET
Respond with observation from next simulation or
False if no simulations are left
"""
self.simulation_count += 1
if self.begin_simulation:
self.simulation_times.append(time.time()-self.begin_simulation)
self.begin_simulation = time.time()
if self.seed_map and self.simulation_count < len(self.seed_map):
_observation = self.env.reset(seed=self.seed_map[self.simulation_count], project=False)
self.simualation_rewards.append(0)
self.env_available = True
self.current_step = 0
#_observation = list(_observation)
_command_response = {}
_command_response['type'] = messages.FLATLAND_RL.ENV_RESET_RESPONSE
_command_response['payload'] = {}
_command_response['payload']['observation'] = _observation
if self.verbose: print("Responding with : ", _command_response)
_redis.rpush(command_response_channel, msgpack.packb(_command_response, default=m.encode, use_bin_type=True))
else:
_command_response = {}
_command_response['type'] = messages.FLATLAND_RL.ENV_RESET_RESPONSE
_command_response['payload'] = {}
_command_response['payload']['observation'] = False
if self.verbose: print("Responding with : ", _command_response)
_redis.rpush(command_response_channel, msgpack.packb(_command_response, default=m.encode, use_bin_type=True))
elif command['type'] == messages.FLATLAND_RL.ENV_STEP: elif command['type'] == messages.FLATLAND_RL.ENV_STEP:
""" """
ENV_STEP ENV_STEP
Request : Action array Request : Action dict
Respond with updated [observation,reward,done,info] after step Respond with updated [observation,reward,done,info] after step
""" """
args = command['payload'] self.handle_env_step(command)
action = args['action']
if self.env and self.env_available:
[_observation, reward, done, info] = self.env.step(action)
else:
if self.env:
raise Exception("Attempt to call `step` function after max_steps={} in a single simulation. Please reset your environment before calling the `step` function after max_step s".format(self.max_steps))
else:
raise Exception("Attempt to call `step` function on a non existent `env`")
self.reward += reward
self.simualation_rewards[-1] += reward
self.current_step += 1
#_observation = np.array(_observation).tolist()
if self.current_step >= self.max_steps:
_command_response = {}
_command_response['type'] = messages.FLATLAND_RL.ENV_STEP_RESPONSE
_command_response['payload'] = {}
_command_response['payload']['observation'] = _observation
_command_response['payload']['reward'] = reward
_command_response['payload']['done'] = True
_command_response['payload']['info'] = info
"""
Mark env as unavailable until next reset
"""
self.env_available = False
else:
_command_response = {}
_command_response['type'] = messages.FLATLAND_RL.ENV_STEP_RESPONSE
_command_response['payload'] = {}
_command_response['payload']['observation'] = _observation
_command_response['payload']['reward'] = reward
_command_response['payload']['done'] = done
_command_response['payload']['info'] = info
if done:
"""
Mark env as unavailable until next reset
"""
self.env_available = False
if self.verbose: print("Responding with : ", _command_response)
if self.verbose: print("Current Step : ", self.current_step)
_redis.rpush(command_response_channel, msgpack.packb(_command_response, default=m.encode, use_bin_type=True))
elif command['type'] == messages.FLATLAND_RL.ENV_SUBMIT: elif command['type'] == messages.FLATLAND_RL.ENV_SUBMIT:
""" """
ENV_SUBMIT ENV_SUBMIT
Submit the final cumulative reward Submit the final cumulative reward
""" """
_response = {} self.handle_env_submit(command)
_response['type'] = messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE
_payload = {}
_payload['mean_reward'] = np.float(self.reward)/len(self.seed_map) #Mean reward
_payload['simulation_rewards'] = self.simualation_rewards
_payload['simulation_times'] = self.simulation_times
_response['payload'] = _payload
_redis.rpush(command_response_channel, msgpack.packb(_response, default=m.encode, use_bin_type=True))
elif command['type'] == messages.FLATLAND_RL.ENV_SUBMIT:
if self.verbose: print("Responding with : ", _response)
return _response
else: else:
_error = self._error_template( _error = self._error_template(
"UNKNOWN_REQUEST:{}".format( "UNKNOWN_REQUEST:{}".format(
str(command))) str(command)))
if self.verbose:print("Responding with : ", _error) if self.verbose:print("Responding with : ", _error)
_redis.rpush(command_response_channel, msgpack.packb(_error, default=m.encode, use_bin_type=True)) self.report_error(
self._error_template(str(e)),
command['response_channel'])
return _error return _error
except Exception as e: except Exception as e:
print("Error : ", str(e)) print("Error : ", str(e))
_redis.rpush( command_response_channel, print(traceback.format_exc())
msgpack.packb(self._error_template(str(e)), default=m.encode, use_bin_type=True)) self.report_error(
self._error_template(str(e)),
command['response_channel'])
return self._error_template(str(e)) return self._error_template(str(e))
if __name__ == "__main__": if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment