From 67d2c70706a9ad7aea6638948f135f1e8628f7da Mon Sep 17 00:00:00 2001 From: Dipam Chakraborty <dipamc77@gmail.com> Date: Fri, 10 Nov 2023 10:53:51 +0530 Subject: [PATCH] fix batching bug --- local_evaluation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/local_evaluation.py b/local_evaluation.py index 56ae05b..16da383 100644 --- a/local_evaluation.py +++ b/local_evaluation.py @@ -37,9 +37,8 @@ def get_responses(agent, test_data, BATCH_SIZE): for turn_id in range(7): batch_inputs = [test_data[i][f"turn_{turn_id}"] for i in batch_idx] responses = agent.generate_responses(batch_inputs) - for resp in responses: - for bi in batch_idx: - all_responses[bi][f"turn_{turn_id}"] = resp + for bi, resp in zip(batch_idx, responses): + all_responses[bi][f"turn_{turn_id}"] = resp return all_responses def evaluate(responses, test_data): -- GitLab