diff --git a/models/vanilla_llama3_baseline.py b/models/vanilla_llama3_baseline.py index 88724f5568414989a8c25fb5945a2592e07c1aba..eafb3107339b33e438113fb5c579e3c359e2ce0a 100644 --- a/models/vanilla_llama3_baseline.py +++ b/models/vanilla_llama3_baseline.py @@ -50,6 +50,7 @@ class Llama3_8B_ZeroShotModel(ShopBenchBaseModel): # initialize the model with vllm self.llm = vllm.LLM( self.model_name, + worker_use_ray=True, tensor_parallel_size=VLLM_TENSOR_PARALLEL_SIZE, gpu_memory_utilization=VLLM_GPU_MEMORY_UTILIZATION, trust_remote_code=True,