diff --git a/local_evaluation.py b/local_evaluation.py index 56ae05b557d586fd44069aa78ff3e92e6a43677f..16da38396920e49049683ce02649e675bfa12e76 100644 --- a/local_evaluation.py +++ b/local_evaluation.py @@ -37,9 +37,8 @@ def get_responses(agent, test_data, BATCH_SIZE): for turn_id in range(7): batch_inputs = [test_data[i][f"turn_{turn_id}"] for i in batch_idx] responses = agent.generate_responses(batch_inputs) - for resp in responses: - for bi in batch_idx: - all_responses[bi][f"turn_{turn_id}"] = resp + for bi, resp in zip(batch_idx, responses): + all_responses[bi][f"turn_{turn_id}"] = resp return all_responses def evaluate(responses, test_data):