Commit 0d88ef17 authored by pfrl_rainbow

pytorch with gpu

parent c1159605
...@@ -4,5 +4,5 @@ ...@@ -4,5 +4,5 @@
"authors": ["minerl_rainbow_baseline"], "authors": ["minerl_rainbow_baseline"],
"tags": "RL", "tags": "RL",
"description": "Test Model for MineRL Challenge", "description": "Test Model for MineRL Challenge",
"gpu": false "gpu": true
} }
...@@ -57,12 +57,12 @@ class MineRLAgentBase(abc.ABC): ...@@ -57,12 +57,12 @@ class MineRLAgentBase(abc.ABC):
""" """
To compete in the competition, you are required to implement a To compete in the competition, you are required to implement a
SUBCLASS to this class. SUBCLASS to this class.
YOUR SUBMISSION WILL FAIL IF: YOUR SUBMISSION WILL FAIL IF:
* Rename this class * Rename this class
* You do not implement a subclass to this class * You do not implement a subclass to this class
This class enables the evaluator to run your agent in parallel, This class enables the evaluator to run your agent in parallel,
so you should load your model only once in the 'load_agent' method. so you should load your model only once in the 'load_agent' method.
""" """
...@@ -83,9 +83,9 @@ class MineRLAgentBase(abc.ABC): ...@@ -83,9 +83,9 @@ class MineRLAgentBase(abc.ABC):
You should just implement the standard environment interaction loop here: You should just implement the standard environment interaction loop here:
obs = env.reset() obs = env.reset()
while not done: while not done:
env.step(self.agent.act(obs)) env.step(self.agent.act(obs))
... ...
NOTE: This method will be called in PARALLEL during evaluation. NOTE: This method will be called in PARALLEL during evaluation.
So, only store state in LOCAL variables. So, only store state in LOCAL variables.
For example, if using an LSTM, don't store the hidden state in the class For example, if using an LSTM, don't store the hidden state in the class
...@@ -103,7 +103,7 @@ class MineRLAgentBase(abc.ABC): ...@@ -103,7 +103,7 @@ class MineRLAgentBase(abc.ABC):
class MineRLMatrixAgent(MineRLAgentBase): class MineRLMatrixAgent(MineRLAgentBase):
""" """
An example random agent. An example random agent.
Note, you MUST subclass MineRLAgentBase. Note, you MUST subclass MineRLAgentBase.
""" """
...@@ -128,6 +128,11 @@ class MineRLMatrixAgent(MineRLAgentBase): ...@@ -128,6 +128,11 @@ class MineRLMatrixAgent(MineRLAgentBase):
Args: Args:
single_episode_env (Episode): The episode on which to run the agent. single_episode_env (Episode): The episode on which to run the agent.
""" """
import torch
device = torch.device('cuda:0')
x = torch.randn(64, 1000, device=device, dtype=torch.float)
assert torch.cuda.is_available()
obs = single_episode_env.reset() obs = single_episode_env.reset()
done = False done = False
while not done: while not done:
...@@ -145,9 +150,9 @@ class MineRLRandomAgent(MineRLAgentBase): ...@@ -145,9 +150,9 @@ class MineRLRandomAgent(MineRLAgentBase):
while not done: while not done:
random_act = single_episode_env.action_space.sample() random_act = single_episode_env.action_space.sample()
single_episode_env.step(random_act) single_episode_env.step(random_act)
##################################################################### #####################################################################
# IMPORTANT: SET THIS VARIABLE WITH THE AGENT CLASS YOU ARE USING # # IMPORTANT: SET THIS VARIABLE WITH THE AGENT CLASS YOU ARE USING #
###################################################################### ######################################################################
AGENT_TO_TEST = MineRLMatrixAgent # MineRLMatrixAgent, MineRLRandomAgent, YourAgentHere AGENT_TO_TEST = MineRLMatrixAgent # MineRLMatrixAgent, MineRLRandomAgent, YourAgentHere
...@@ -177,7 +182,7 @@ def main(): ...@@ -177,7 +182,7 @@ def main():
except EpisodeDone: except EpisodeDone:
print("[{}] Episode complete".format(i)) print("[{}] Episode complete".format(i))
pass pass
evaluator_threads = [threading.Thread(target=evaluate, args=(i, envs[i])) for i in range(EVALUATION_THREAD_COUNT)] evaluator_threads = [threading.Thread(target=evaluate, args=(i, envs[i])) for i in range(EVALUATION_THREAD_COUNT)]
for thread in evaluator_threads: for thread in evaluator_threads:
thread.start() thread.start()
...@@ -188,5 +193,3 @@ def main(): ...@@ -188,5 +193,3 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment