support vllm-0.6.6 #214

Status: Open. Wants to merge 9 commits into base: main. Showing changes from 8 commits.
38 changes: 15 additions & 23 deletions chatlearn/models/vllm/hooks/__init__.py
@@ -19,26 +19,18 @@
 from .. import is_vllm_v2
 
 
-if is_vllm_v2():
-    if importlib.util.find_spec("vllm"):
-        from . import ray_gpu_executor
-        from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion
-        if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
-            from chatlearn.models.vllm.hooks import input_preprocess
-            from chatlearn.models.vllm.hooks import async_llm_engine
-            from chatlearn.models.vllm.hooks import llm
-            from chatlearn.models.vllm.hooks import loader
-            from chatlearn.models.vllm.hooks import worker_base
-else:
-    if importlib.util.find_spec("vllm"):
-        import vllm
-        from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion  # pylint: disable=ungrouped-imports
-        if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
-            from chatlearn.models.vllm.hooks import sampler
-        elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
-            from chatlearn.models.vllm.hooks import llm_engine, logits_processor
-            if CURRENT_VLLM_VERSION == VLLMVersion.v_0_5_1:
-                from chatlearn.models.vllm.hooks import worker
-            else:
-                from chatlearn.models.vllm.hooks import input_preprocess
-                from chatlearn.models.vllm.hooks import format_device_name
+if importlib.util.find_spec("vllm"):
+
+    from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion
+
+    if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
+        from chatlearn.models.vllm.hooks.vllm_0_3_0 import *
+    elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_5_1:
+        from chatlearn.models.vllm.hooks.vllm_0_5_1 import *
+    elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
+        from chatlearn.models.vllm.hooks.vllm_0_6_3 import *
+    elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_6:
+        from .vllm_0_6_6 import *
+    else:
+        raise RuntimeError(
+            f"vLLM version expected in {list(member.value for member in VLLMVersion)}, while {CURRENT_VLLM_VERSION}.")
62 changes: 0 additions & 62 deletions chatlearn/models/vllm/hooks/input_preprocess.py

This file was deleted.

21 changes: 21 additions & 0 deletions chatlearn/models/vllm/hooks/vllm_0_3_0/__init__.py
@@ -0,0 +1,21 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Additional hooks of vllm-0.3.0."""

from ... import is_vllm_v2

assert not is_vllm_v2(), "vLLM-0.3.0 only supports vLLM Module v1. Set env `ENABLE_VLLM_V2=False`."

from . import sampler
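The `is_vllm_v2()` guard comes from the parent package; its contract, per the asserts in these `__init__` files, is an `ENABLE_VLLM_V2` env switch. A minimal sketch of such a switch, assuming the default and accepted spellings (the real definition lives elsewhere in `chatlearn/models/vllm/` and may differ):

```python
# Hedged sketch: an env-driven switch consistent with the asserts above.
# The default value and the accepted spellings are assumptions.
import os

def is_vllm_v2() -> bool:
    """Return True when the vLLM Module v2 path is enabled."""
    return os.getenv("ENABLE_VLLM_V2", "True").strip().lower() in ("true", "1")
```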
23 changes: 23 additions & 0 deletions chatlearn/models/vllm/hooks/vllm_0_5_1/__init__.py
@@ -0,0 +1,23 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Additional hooks of vllm-0.5.1."""

from ... import is_vllm_v2

assert not is_vllm_v2(), "vLLM-0.5.1 only supports vLLM Module v1. Set env `ENABLE_VLLM_V2=False`."

from . import llm_engine
from . import logits_processor
from . import worker
29 changes: 29 additions & 0 deletions chatlearn/models/vllm/hooks/vllm_0_6_3/__init__.py
@@ -0,0 +1,29 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Additional hooks of vllm-0.6.3."""

from ... import is_vllm_v2
from . import format_device_name
from . import input_preprocess

if is_vllm_v2():
    from . import async_llm_engine
    from . import llm
    from . import loader
    from . import ray_gpu_executor
    from . import worker_base
else:
    from . import llm_engine
    from . import logits_processor
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Hooks of vllm-0.6.3 del init_ray_cluster in AsyncLLMEngine."""
+"""del init_ray_cluster in AsyncLLMEngine."""
 
 from typing import Dict, Optional
 
55 changes: 55 additions & 0 deletions chatlearn/models/vllm/hooks/vllm_0_6_3/input_preprocess.py
@@ -0,0 +1,55 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks of vllm-0.6.3 input preprocess to pass prompt text."""

# pylint: disable=unused-import,unused-argument
from vllm.inputs import preprocess
from vllm.inputs.parse import parse_singleton_prompt

def extract_prompt_components(
        self,
        prompt,
        request_id,
        lora_request=None):
    '''
    Extract the components of any single encoder or decoder input prompt.

    Arguments:

    * request_id
    * prompt: single encoder or decoder input prompt
    * lora_request: this is only valid for decoder prompts

    Returns:

    * prompt
    * prompt_token_ids
    * multi_modal_data
    * mm_processor_kwargs (request-level input processor/mapper overrides)
    '''
    parsed = parse_singleton_prompt(prompt)

    assert parsed["type"] == "tokens", \
        f"prompt_token_ids must be passed when adding a request to the scheduler, got prompt {prompt}"

    prompt_text = parsed["content"]["prompt"]
    prompt_token_ids = parsed["content"]["prompt_token_ids"]
    multi_modal_data = parsed["content"].get("multi_modal_data")
    mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")

    return (prompt_text, prompt_token_ids, multi_modal_data,
            mm_processor_kwargs)

preprocess.InputPreprocessor._extract_prompt_components = extract_prompt_components
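The assert means every request reaching the scheduler must already be tokenized, with the original text carried alongside so the hook can pass it through. A hedged usage sketch of a prompt that satisfies the "tokens" branch (field names follow `vllm.inputs` in 0.6.3; the token ids are placeholders):

```python
# Hedged usage sketch: a pre-tokenized prompt as this hook expects it.
# Token ids below are made up, not real tokenizer output.
prompt = {
    "prompt_token_ids": [101, 2023, 2003, 102],  # ids the scheduler consumes
    "prompt": "this is",                         # text the hook now passes through
}
```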
30 changes: 30 additions & 0 deletions chatlearn/models/vllm/hooks/vllm_0_6_3/llm_engine.py
@@ -0,0 +1,30 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks of vllm-0.5.1 llm_engine remove __reduce__ function."""

import inspect

# pylint: disable=unused-import,wildcard-import,unused-argument
from vllm.engine import llm_engine


source = inspect.getsource(llm_engine.LLMEngine.__reduce__)
if 'RuntimeError' in source:
    # Upstream defines __reduce__ to raise, which blocks pickling. Drop it
    # so that the LLMEngine can be referenced in the closure used to
    # initialize Ray worker actors.
    del llm_engine.LLMEngine.__reduce__
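Sniffing the upstream source before patching makes the hook self-disabling: it only removes `__reduce__` while upstream still raises there, and becomes a no-op if vLLM changes. A self-contained sketch of the same guard-then-patch pattern on a toy class (not vLLM):

```python
# Self-contained illustration of the guard-then-patch pattern, using a
# toy class in place of vllm's LLMEngine.
import inspect
import pickle

class Engine:
    def __reduce__(self):
        raise RuntimeError("Engine should not be pickled!")

if 'RuntimeError' in inspect.getsource(Engine.__reduce__):
    del Engine.__reduce__  # fall back to the default pickling behavior

pickle.loads(pickle.dumps(Engine()))  # round-trips instead of raising
```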
@@ -73,7 +73,6 @@ def init(self, load_config):
 
 loader.DummyModelLoader.__init__ = init
 
-
 # add ckpt loading of megatron format
 def load_model(self, *, model_config,
                device_config,
42 changes: 42 additions & 0 deletions chatlearn/models/vllm/hooks/vllm_0_6_3/logits_processor.py
@@ -0,0 +1,42 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks of vllm-0.5.1 logits_processor to allgather logits of all ranks."""

import inspect

# pylint: disable=wildcard-import,ungrouped-imports
from vllm.model_executor.layers import logits_processor


source = inspect.getsource(logits_processor.LogitsProcessor._get_logits)
if 'tensor_model_parallel_gather' in source:
    import torch
    from typing import Optional
    from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding

    def _get_logits(self, hidden_states: torch.Tensor,
                    lm_head: VocabParallelEmbedding,
                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Get the logits for the next tokens.
        logits = lm_head.linear_method.apply(lm_head,
                                             hidden_states,
                                             bias=embedding_bias)
        from vllm.distributed.communication_op import tensor_model_parallel_all_gather  # pylint: disable=import-outside-toplevel
        logits = tensor_model_parallel_all_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.org_vocab_size]
        return logits

    logits_processor.LogitsProcessor._get_logits = _get_logits
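Upstream 0.6.3 calls `tensor_model_parallel_gather`, which collects logits onto the driver rank and leaves the other ranks with `None`; the replacement all-gathers so every tensor-parallel rank sees the full vocabulary. A single-process sketch of the resulting shard layout (plain torch, no distributed runtime; shapes are illustrative):

```python
# Single-process sketch of what the all-gather achieves: each TP rank
# holds one vocab shard, and after all-gather every rank holds the
# concatenation, trimmed back to the original vocab size.
import torch

tp_size, batch, vocab_per_rank = 2, 3, 8
shards = [torch.randn(batch, vocab_per_rank) for _ in range(tp_size)]

full_logits = torch.cat(shards, dim=-1)  # every rank ends up with this
org_vocab_size = 15                      # strip padding added for divisibility
full_logits = full_logits[:, :org_vocab_size]
assert full_logits.shape == (batch, org_vocab_size)
```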
27 changes: 27 additions & 0 deletions chatlearn/models/vllm/hooks/vllm_0_6_6/__init__.py
@@ -0,0 +1,27 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Additional hooks of vllm-0.6.6."""

from ... import is_vllm_v2

assert is_vllm_v2(), "vLLM-0.6.6 only supports vLLM Module v2."

from . import async_llm_engine
from . import input_preprocess
from . import llm
from . import llm_engine
from . import loader
from . import ray_gpu_executor
from . import worker_base
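Given the assert, opting into these hooks is a matter of the env switch plus an installed vllm 0.6.6. A hedged usage sketch (the env name comes from the asserts in the sibling packages; the default-on behavior is an assumption):

```python
# Hedged usage sketch: enable the v2 path before the hooks import so the
# assert above passes with vllm 0.6.6 installed.
import os
os.environ["ENABLE_VLLM_V2"] = "True"  # assumption: consulted by is_vllm_v2()

import chatlearn.models.vllm.hooks  # dispatches to the vllm_0_6_6 hooks
```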