Skip to content
Snippets Groups Projects
Commit 159d5a66 authored by Shinya Shiroshita's avatar Shinya Shiroshita
Browse files

fix unexpected change, remove debug code

parent 3281ac8d
No related branches found
Tags submission-v0.1.17
No related merge requests found
......@@ -106,6 +106,8 @@ def _batch_reset_recurrent_states_when_episodes_end(
def load_experiences_from_demonstrations(
expert_dataset, batch_size, reward=1):
if expert_dataset is None:
raise ValueError("Expert dataset must be provided.")
ret = []
for _ in range(batch_size):
ob, act, _, next_ob, done = expert_dataset.sample()
......@@ -162,7 +164,6 @@ class RewardBasedSampler:
n_samples = [0 for _ in range(len(self.reward_boundaries) + 1)]
for frame in experiences:
n_samples[self._policy_index(frame[0]['state'])] += 1
print(n_samples)
ret = []
for rbuf, n_sample in zip(self.replay_buffers, n_samples):
samples = rbuf.sample(n_sample)
......@@ -235,8 +236,6 @@ class SQIL(agent.AttributeSavingMixin, agent.BatchAgent):
recurrent=False,
reward_boundaries=None, # specific to options
):
if expert_dataset is None:
raise ValueError("Expert dataset must be provided.")
self.expert_dataset = expert_dataset
self.model = q_function
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment