Commit ff79c0a2 authored by dzorlu's avatar dzorlu

change tf dataset generator to accommodate different episode lengths

parent 7e9d488f
......@@ -117,10 +117,15 @@ class DemonstrationRecorder:
# Read-only property: returns the value currently stored in self._episode_reward.
@property
def episode_reward(self):
return self._episode_reward
def _change_shape(self, shape):
    """Return `shape` as a tuple with its leading dimension replaced by None.

    Used to build shape specs whose first axis is left unspecified, so that
    elements differing only in the size of that axis share one spec.

    Args:
        shape: Any iterable of dimensions (tuple, list, ...).

    Returns:
        A tuple equal to `shape` except that element 0 is None. A rank-0
        shape (no dimensions) is returned unchanged as an empty tuple.
    """
    dims = list(shape)
    # A rank-0 shape has no leading axis to relax; indexing it would raise.
    if dims:
        dims[0] = None
    return tuple(dims)
def make_tf_dataset(self):
    """Build a repeating, shuffled tf.data.Dataset over the recorded demos.

    The first recorded demo is used as the structural template: its leaf
    dtypes become the dataset's output types and its leaf shapes, with the
    leading dimension relaxed to None, become the output shapes. Relaxing
    the leading dimension lets demos of different lengths flow through the
    same generator signature.

    Returns:
        A tf.data.Dataset that yields the entries of self._demos, repeated
        indefinitely and shuffled with a buffer of len(self._demos).

    Raises:
        IndexError: if no demos have been recorded (self._demos is empty).
    """
    template = self._demos[0]
    types = tree.map_structure(lambda x: x.dtype, template)
    # Only the relaxed shapes are needed; computing the exact shapes first
    # would be discarded immediately.
    shapes = tree.map_structure(
        lambda x: self._change_shape(x.shape), template)
    ds = tf.data.Dataset.from_generator(lambda: self._demos, types, shapes)
    return ds.repeat().shuffle(len(self._demos))
......@@ -165,7 +170,7 @@ def main():
env_loop_logger = loggers.TerminalLogger(label='env_loop', time_delta=10.)
# Build demonstrations
recorder = build_demonstrations(data, sequence_length)
demonstration_dataset = build_demonstrations(data, sequence_length)
# Construct the network.
network = create_network()
......@@ -176,14 +181,14 @@ def main():
environment_spec=spec,
network=network,
target_network=target_network,
demonstration_dataset=recorder,
demonstration_dataset=demonstration_dataset,
demonstration_ratio=0.5,
batch_size=10,
samples_per_insert=2,
min_replay_size=10,
burn_in_length=burn_in_length,
trace_length=trace_length,
replay_period=4,
replay_period=40, # per R2D3 paper.
checkpoint=False,
logger=agent_logger
)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment