Skip to content

Commit

Permalink
[python]remove keys to check param dict (#2677)
Browse files Browse the repository at this point in the history
  • Loading branch information
sindhuvahinis authored Jan 24, 2025
1 parent 69bbe22 commit 6b29181
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,20 @@ def translate_triton_params(self, parameters: dict) -> dict:
:return: The same parameters dict, but with TensorRT-LLM style parameter names.
"""
if "request_output_len" not in parameters.keys():
if "request_output_len" not in parameters:
parameters["request_output_len"] = parameters.pop(
"max_new_tokens", 30)
if "top_k" in parameters.keys():
if "top_k" in parameters:
parameters["runtime_top_k"] = parameters.pop("top_k")
if "top_p" in parameters.keys():
if "top_p" in parameters:
parameters["runtime_top_p"] = parameters.pop("top_p")
if "seed" in parameters.keys():
if "seed" in parameters:
parameters["random_seed"] = int(parameters.pop("seed"))
if parameters.pop("do_sample", False):
parameters["runtime_top_k"] = parameters.get("runtime_top_k", 5)
parameters["runtime_top_p"] = parameters.get("runtime_top_p", 0.85)
parameters["temperature"] = parameters.get("temperature", 0.8)
if "length_penalty" in parameters.keys():
if "length_penalty" in parameters:
parameters['len_penalty'] = parameters.pop('length_penalty')
parameters["streaming"] = parameters.pop(
"stream", parameters.get("streaming", True))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@

def translate_neuronx_params(parameters: dict) -> dict:
# TODO: Remove this once presence_penalty is supported
if "presence_penalty" in parameters.keys(
) and "repetition_penalty" not in parameters.keys():
if "presence_penalty" in parameters and "repetition_penalty" not in parameters:
parameters["repetition_penalty"] = float(
parameters.pop("presence_penalty")) + 2.0
return parameters
Expand Down
8 changes: 4 additions & 4 deletions engines/python/setup/djl_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def is_best_of(parameters: dict) -> bool:
:param parameters: parameters dictionary
:return: boolean
"""
return "best_of" in parameters.keys() and parameters.get("best_of") > 1
return "best_of" in parameters and parameters.get("best_of") > 1


def is_beam_search(parameters: dict) -> bool:
Expand All @@ -78,7 +78,7 @@ def is_beam_search(parameters: dict) -> bool:
:param parameters: parameters dictionary
:return: boolean
"""
return "num_beams" in parameters.keys() and parameters.get("num_beams") > 1
return "num_beams" in parameters and parameters.get("num_beams") > 1


def is_multiple_sequences(parameters: dict) -> bool:
Expand All @@ -88,7 +88,7 @@ def is_multiple_sequences(parameters: dict) -> bool:
:param parameters: parameters dictionary
:return: boolean
"""
return "n" in parameters.keys() and parameters.get("n") > 1
return "n" in parameters and parameters.get("n") > 1


def is_streaming(parameters: dict) -> bool:
Expand All @@ -97,7 +97,7 @@ def is_streaming(parameters: dict) -> bool:
:param parameters: parameters dictionary
:return: boolean
"""
return "stream" in parameters.keys() and parameters.get("stream")
return "stream" in parameters and parameters.get("stream")


def wait_till_generation_finished(parameters):
Expand Down

0 comments on commit 6b29181

Please sign in to comment.