diff --git a/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py index 02e0f185b..e1ba4651c 100644 --- a/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py @@ -63,20 +63,20 @@ def translate_triton_params(self, parameters: dict) -> dict: :return: The same parameters dict, but with TensorRT-LLM style parameter names. """ - if "request_output_len" not in parameters.keys(): + if "request_output_len" not in parameters: parameters["request_output_len"] = parameters.pop( "max_new_tokens", 30) - if "top_k" in parameters.keys(): + if "top_k" in parameters: parameters["runtime_top_k"] = parameters.pop("top_k") - if "top_p" in parameters.keys(): + if "top_p" in parameters: parameters["runtime_top_p"] = parameters.pop("top_p") - if "seed" in parameters.keys(): + if "seed" in parameters: parameters["random_seed"] = int(parameters.pop("seed")) if parameters.pop("do_sample", False): parameters["runtime_top_k"] = parameters.get("runtime_top_k", 5) parameters["runtime_top_p"] = parameters.get("runtime_top_p", 0.85) parameters["temperature"] = parameters.get("temperature", 0.8) - if "length_penalty" in parameters.keys(): + if "length_penalty" in parameters: parameters['len_penalty'] = parameters.pop('length_penalty') parameters["streaming"] = parameters.pop( "stream", parameters.get("streaming", True)) diff --git a/engines/python/setup/djl_python/transformers_neuronx_scheduler/slot.py b/engines/python/setup/djl_python/transformers_neuronx_scheduler/slot.py index c4bcbbec6..3501acdf4 100644 --- a/engines/python/setup/djl_python/transformers_neuronx_scheduler/slot.py +++ b/engines/python/setup/djl_python/transformers_neuronx_scheduler/slot.py @@ -29,8 +29,7 @@ def translate_neuronx_params(parameters: dict) -> dict: # TODO: Remove this once presence_penalty is supported - if "presence_penalty" in parameters.keys( - ) and "repetition_penalty" not in parameters.keys(): + if "presence_penalty" in parameters and "repetition_penalty" not in parameters: parameters["repetition_penalty"] = float( parameters.pop("presence_penalty")) + 2.0 return parameters diff --git a/engines/python/setup/djl_python/utils.py b/engines/python/setup/djl_python/utils.py index a8c79183a..a5c8cc48a 100644 --- a/engines/python/setup/djl_python/utils.py +++ b/engines/python/setup/djl_python/utils.py @@ -69,7 +69,7 @@ def is_best_of(parameters: dict) -> bool: :param parameters: parameters dictionary :return: boolean """ - return "best_of" in parameters.keys() and parameters.get("best_of") > 1 + return "best_of" in parameters and parameters.get("best_of") > 1 def is_beam_search(parameters: dict) -> bool: @@ -78,7 +78,7 @@ def is_beam_search(parameters: dict) -> bool: :param parameters: parameters dictionary :return: boolean """ - return "num_beams" in parameters.keys() and parameters.get("num_beams") > 1 + return "num_beams" in parameters and parameters.get("num_beams") > 1 def is_multiple_sequences(parameters: dict) -> bool: @@ -88,7 +88,7 @@ def is_multiple_sequences(parameters: dict) -> bool: :param parameters: parameters dictionary :return: boolean """ - return "n" in parameters.keys() and parameters.get("n") > 1 + return "n" in parameters and parameters.get("n") > 1 def is_streaming(parameters: dict) -> bool: @@ -97,7 +97,7 @@ def is_streaming(parameters: dict) -> bool: :param parameters: parameters dictionary :return: boolean """ - return "stream" in parameters.keys() and parameters.get("stream") + return "stream" in parameters and parameters.get("stream") def wait_till_generation_finished(parameters):