
Commit

Add integration test
xyang16 committed Jan 29, 2025
1 parent c7fa2f8 commit 93dda05
Showing 3 changed files with 149 additions and 3 deletions.
124 changes: 122 additions & 2 deletions tests/integration/llm/client.py
@@ -597,7 +597,19 @@ def get_model_name():
"batch_size": [1, 4],
"seq_length": [256],
"tokenizer": "TheBloke/Llama-2-7B-Chat-fp16"
}
},
"llama3-8b-tool": {
"batch_size": [1, 4],
"seq_length": [256],
"tool": True,
"tokenizer": "unsloth/Meta-Llama-3.1-8B-Instruct"
},
"mistral-7b-v03-tool": {
"batch_size": [1, 4],
"seq_length": [256],
"tool": True,
"tokenizer": "unsloth/mistral-7b-instruct-v0.3"
},
}
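For context, here is a minimal sketch of how an entry such as "llama3-8b-tool" could drive request construction; build_dryrun_request is a hypothetical helper that mirrors the test_handler_rolling_batch_chat change later in this commit, not part of the suite.

# Illustrative only: pick the payload builder based on the spec's "tool"
# flag, mirroring the handler change later in this commit.
def build_dryrun_request(spec):
    if spec.get("tool", False):
        return batch_generation_tool(1)[0]
    return {"messages": batch_generation_chat(1)[0]}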

lmi_dist_aiccl_model_spec = {
@@ -1277,6 +1289,111 @@ def batch_generation_pair(batch_size):
    return data[:batch_size]


def batch_generation_tool(batch_size):
    data = [{
        "messages": [{
            "role": "user",
            "content": "Hi! How are you doing today?"
        }, {
            "role": "assistant",
            "content": "I am doing well! How can I help you?"
        }, {
            "role": "user",
            "content":
            "Can you tell me what the temperature will be in Dallas, in fahrenheit?"
        }],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description":
                            "The city to find the weather for, e.g. San Francisco"
                        },
                        "state": {
                            "type": "string",
                            "description":
                            "the two-letter abbreviation for the state that the city is in, e.g. CA which would mean California"
                        },
                        "unit": {
                            "type": "string",
                            "description": "The unit to fetch the temperature in",
                            "enum": ["celsius", "fahrenheit"]
                        }
                    },
                    "required": ["city", "state", "unit"]
                }
            }
        }],
        "tool_choice": "auto"
    }, {
        "messages": [{
            "role": "user",
            "content": "Hi! How are you doing today?"
        }, {
            "role": "assistant",
            "content": "I am doing well! How can I help you?"
        }, {
            "role": "user",
            "content":
            "Can you tell me what the temperature will be in Dallas, in fahrenheit?"
        }],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description":
                            "The city to find the weather for, e.g. San Francisco"
                        },
                        "state": {
                            "type": "string",
                            "description":
                            "the two-letter abbreviation for the state that the city is in, e.g. CA which would mean California"
                        },
                        "unit": {
                            "type": "string",
                            "description": "The unit to fetch the temperature in",
                            "enum": ["celsius", "fahrenheit"]
                        }
                    },
                    "required": ["city", "state", "unit"]
                }
            }
        }],
        "tool_choice": {
            "type": "function",
            "function": {
                "name": "get_current_weather"
            }
        },
    }]

    if batch_size > len(data):
        # dynamically extend to support larger batch sizes by repetition
        data *= math.ceil(batch_size / len(data))
    return data[:batch_size]
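For illustration, a minimal sketch of sending one of these payloads to an OpenAI-compatible chat endpoint and reading back the tool call; the URL and route are assumptions for the sketch, not the suite's actual client plumbing.

# Minimal sketch, not part of this commit. Assumes an OpenAI-compatible
# /v1/chat/completions route on localhost; the real test client differs.
import requests

req = batch_generation_tool(1)[0]
req["max_tokens"] = 100

resp = requests.post("http://localhost:8080/v1/chat/completions", json=req)
resp.raise_for_status()
message = resp.json()["choices"][0]["message"]
# With tool_choice="auto" the model may reply with a tool call instead of text.
print(message.get("tool_calls"))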


def t5_batch_generation(batch_size):
    input_sentences = [
        "translate English to German: The house is wonderful.",
@@ -1521,7 +1638,10 @@ def test_handler_rolling_batch_chat(model, model_spec):
    check_worker_number(spec["worker"])
    stream_values = spec.get("stream", [False, True])
    # dryrun phase
    if spec.get("tool", False):
        req = batch_generation_tool(1)[0]
    else:
        req = {"messages": batch_generation_chat(1)[0]}
    seq_length = 100
    req["max_tokens"] = seq_length
    req["logprobs"] = True
16 changes: 15 additions & 1 deletion tests/integration/llm/prepare.py
@@ -1056,7 +1056,21 @@
"option.max_model_len": 8192,
"option.max_rolling_batch_size": 16,
"option.enforce_eager": True,
}
},
"llama3-8b-tool": {
"option.model_id": "s3://djl-llm/llama-3.1-8b-hf/",
"option.tensor_parallel_degree": 4,
"option.max_rolling_batch_size": 4,
"option.enable_auto_tool_choice": True,
"option.tool_call_parser": "llama3_json",
},
"mistral-7b-v03-tool": {
"option.model_id": "s3://djl-llm/mistral-7b-instruct-v03/",
"option.tensor_parallel_degree": 4,
"option.max_rolling_batch_size": 4,
"option.enable_auto_tool_choice": True,
"option.tool_call_parser": "mistral",
},
}
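As a rough sketch of what these entries configure: each "option.*" key is typically written out as one line of a serving.properties file for the model. The helper below is hypothetical; the actual writer in prepare.py may differ.

# Hypothetical helper, for illustration only: render an "option.*" dict
# into a serving.properties file for DJL Serving to consume.
def write_serving_properties(options, path="serving.properties"):
    with open(path, "w") as f:
        for key, value in options.items():
            f.write(f"{key}={value}\n")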

vllm_neo_model_list = {
12 changes: 12 additions & 0 deletions tests/integration/tests.py
@@ -630,6 +630,18 @@ def test_llama_68m_speculative_eagle(self):
            r.launch()
            client.run("vllm llama-68m-speculative-eagle".split())

    def test_llama3_8b_tool(self):
        with Runner('lmi', 'llama3-8b-tool') as r:
            prepare.build_vllm_model("llama3-8b-tool")
            r.launch()
            client.run("vllm_chat llama3-8b-tool".split())

    def test_mistral_7b_v03_tool(self):
        with Runner('lmi', 'mistral-7b-v03-tool') as r:
            prepare.build_vllm_model("mistral-7b-v03-tool")
            r.launch()
            client.run("vllm_chat mistral-7b-v03-tool".split())


@pytest.mark.vllm
@pytest.mark.lora
