
Commit

Add support for multi-part messages in AnthropicChat model (#47)
Co-authored-by: wangyuxin <wangyuxin@mokahr.com>
wangyuxinwhy and wangyuxin authored Mar 5, 2024
1 parent 4dc47bd commit 40c294d
Showing 6 changed files with 56 additions and 17 deletions.
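
In short, this commit lets the AnthropicChat model accept multi-part (text plus image) user messages. A minimal usage sketch, not taken from the commit itself: it assumes the AnthropicChat class named in the commit title is importable from the module changed below, and it reuses the dict-style prompt, model name, and image URL that appear in the updated test.

# Hedged sketch: the AnthropicChat name and the {'text': ...} part key are assumptions;
# the model name, image URL, and {'image_url': {'url': ...}} part come from the test diff.
from generate.chat_completion.models.anthropic import AnthropicChat

model = AnthropicChat(model='claude-3-sonnet-20240229')
user_message = {
    'role': 'user',
    'content': [
        {'text': 'What is in this image?'},
        {'image_url': {'url': 'https://dashscope.oss-cn-beijing.aliyuncs.com/images/dog_and_girl.jpeg'}},
    ],
}
output = model.generate(user_message)
print(output.reply)
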
2 changes: 1 addition & 1 deletion README.md
@@ -12,7 +12,7 @@ Generate 允许用户通过统一的 api 访问多平台的生成式模型,当
|----------------|---------|---------|---------|-----------|-----------|
| OpenAI ||||||
| Azure ||||||
-| Anthropic |||| ||
+| Anthropic |||| ||
| 文心 Wenxin ||||||
| 百炼 Bailian ||||||
| 灵积 DashScope ||||||
2 changes: 1 addition & 1 deletion docs/index.md
@@ -20,7 +20,7 @@ Generate 允许用户通过统一的 api 访问多平台的生成式模型,当
|----------------|---------|---------|---------|-----------|-----------|
| OpenAI ||||||
| Azure ||||||
-| Anthropic |||| ||
+| Anthropic |||| ||
| 文心 Wenxin ||||||
| 百炼 Bailian ||||||
| 灵积 DashScope ||||||
56 changes: 45 additions & 11 deletions generate/chat_completion/models/anthropic.py
@@ -1,5 +1,6 @@
from __future__ import annotations

+import base64
import json
from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Literal, Optional

@@ -8,7 +9,16 @@

from generate.chat_completion.base import RemoteChatCompletionModel
from generate.chat_completion.message import Prompt
-from generate.chat_completion.message.core import AssistantMessage, Message, SystemMessage, UserMessage
+from generate.chat_completion.message.core import (
+    AssistantMessage,
+    ImagePart,
+    ImageUrlPart,
+    Message,
+    SystemMessage,
+    TextPart,
+    UserMessage,
+    UserMultiPartMessage,
+)
from generate.chat_completion.message.exception import MessageTypeError
from generate.chat_completion.message.utils import ensure_messages
from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput
@@ -89,7 +99,28 @@ def _convert_message(self, message: Message) -> dict[str, str]:
             return {'role': 'user', 'content': message.content}
         if isinstance(message, AssistantMessage):
             return {'role': 'assistant', 'content': message.content}
-        raise MessageTypeError(message, (UserMessage, AssistantMessage))
+        if isinstance(message, UserMultiPartMessage):
+            message_dict = {'role': 'user', 'content': []}
+            for part in message.content:
+                if isinstance(part, TextPart):
+                    message_dict['content'].append({'type': 'text', 'text': part.text})
+
+                if isinstance(part, ImagePart):
+                    data = base64.b64encode(part.image).decode()
+                    media_type = part.image_format or 'image/jpeg'
+                    message_dict['content'].append(
+                        {'type': 'image', 'source': {'type': 'base64', 'media_type': media_type, 'data': data}}
+                    )
+
+                if isinstance(part, ImageUrlPart):
+                    response = self.http_client.get({'url': part.image_url.url})
+                    data = base64.b64encode(response.content).decode()
+                    media_type = response.headers.get('Content-Type') or 'image/jpeg'
+                    message_dict['content'].append(
+                        {'type': 'image', 'source': {'type': 'base64', 'media_type': media_type, 'data': data}}
+                    )
+            return message_dict
+        raise MessageTypeError(message, (UserMessage, AssistantMessage, UserMultiPartMessage))

@override
def _get_request_parameters(
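
The new branch above turns a UserMultiPartMessage into Anthropic-style content blocks. For a prompt with one text part and one image part, _convert_message would return roughly the following dict (illustrative text, abbreviated base64 data; the media type falls back to 'image/jpeg' when the part carries no format):

# Illustrative return value of _convert_message for a UserMultiPartMessage
# containing a TextPart and an ImagePart (hypothetical values).
converted = {
    'role': 'user',
    'content': [
        {'type': 'text', 'text': 'What is in this image?'},
        {'type': 'image', 'source': {'type': 'base64', 'media_type': 'image/jpeg', 'data': '<base64 bytes>'}},
    ],
}
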
@@ -135,15 +166,18 @@ def _process_reponse(self, response: dict[str, Any]) -> ChatCompletionOutput:
)

def _calculate_cost(self, input_tokens: int, output_tokens: int) -> float | None:
+        model_price_mapping = {
+            'claude-instant': (0.80, 2.40),
+            'claude-2': (8, 24),
+            'claude-3-haiku': (0.25, 1.25),
+            'claude-3-sonnet': (3, 15),
+            'claude-3-opus': (15, 75),
+        }
         dollar_to_yuan = 7
-        if 'claude-instant' in self.model:
-            # prompt: $0.80/million tokens, completion: $2.40/million tokens
-            cost = (input_tokens * 0.8 / 1_000_000) + (output_tokens * 2.4 / 1_000_000)
-            return cost * dollar_to_yuan
-        if 'claude-2' in self.model:
-            # prompt: $8/million tokens, completion: $24/million tokens
-            cost = (input_tokens * 8 / 1_000_000) + (output_tokens * 24 / 1_000_000)
-            return cost * dollar_to_yuan
+        for model_name, (prompt_price, completion_price) in model_price_mapping.items():
+            if model_name in self.model:
+                cost = (input_tokens * prompt_price / 1_000_000) + (output_tokens * completion_price / 1_000_000)
+                return cost * dollar_to_yuan
return None

@override
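
The per-model if-branches above are replaced by a substring lookup in model_price_mapping, which also adds Claude 3 pricing. A worked example with illustrative token counts: a claude-3-sonnet call with 1,000 input tokens and 500 output tokens is priced at the (3, 15) dollars-per-million-token pair and converted to yuan at the hard-coded rate of 7.

# Worked example using the prices from model_price_mapping (token counts are illustrative).
input_tokens, output_tokens = 1_000, 500
prompt_price, completion_price = 3, 15  # claude-3-sonnet: $ per million tokens
dollar_to_yuan = 7
cost = (input_tokens * prompt_price / 1_000_000) + (output_tokens * completion_price / 1_000_000)
print(cost, cost * dollar_to_yuan)  # ≈ 0.0105 dollars, ≈ 0.0735 yuan
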
@@ -163,7 +197,7 @@ def _process_stream_line(self, line: str, stream_manager: StreamManager) -> Chat
delta_dict = data['delta']
stream_manager.delta = ''
stream_manager.finish_reason = delta_dict['stop_reason']
-stream_manager.extra['output_tokens'] = data['usage']['output_tokens']
+stream_manager.extra['usage']['output_tokens'] = data['usage']['output_tokens']
stream_manager.cost = self._calculate_cost(**stream_manager.extra['usage'])
return stream_manager.build_stream_output()
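
The one-line change above moves the streamed output token count under extra['usage'], so the keyword expansion on the next line hands _calculate_cost both token counts. A minimal sketch, assuming an earlier stream event already seeded extra['usage'] with 'input_tokens':

# Assumed prior state: stream_manager.extra['usage'] == {'input_tokens': 123}
stream_manager.extra['usage']['output_tokens'] = 456
# _calculate_cost(**stream_manager.extra['usage']) now receives
# input_tokens=123 and output_tokens=456 as keyword arguments.
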

2 changes: 1 addition & 1 deletion generate/version.py
@@ -1 +1 @@
-__version__ = '0.4.0.post1'
+__version__ = '0.4.1'
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "generate-core"
version = "0.4.0.post1"
version = "0.4.1"
description = "文本生成,图像生成,语音生成"
authors = ["wangyuxin <wangyuxin@mokahr.com>"]
license = "MIT"
9 changes: 7 additions & 2 deletions tests/test_chat_completion_model.py
@@ -70,7 +70,7 @@ def test_http_stream_chat_model(chat_completion_model: ChatCompletionModel) -> N
'test_multimodel_chat_completion',
ChatModelRegistry,
types='model_cls',
-include=['dashscope_multimodal', 'zhipu', 'openai'],
+include=['dashscope_multimodal', 'zhipu', 'openai', 'anthropic'],
),
)
def test_multimodel_chat_completion(model_cls: Type[ChatCompletionModel]) -> None:
@@ -81,7 +81,12 @@ def test_multimodel_chat_completion(model_cls: Type[ChatCompletionModel]) -> Non
{'image_url': {'url': 'https://dashscope.oss-cn-beijing.aliyuncs.com/images/dog_and_girl.jpeg'}},
],
}
-    model = model_cls(model='gpt-4-vision-preview') if model_cls.model_type == 'openai' else model_cls()
+    if model_cls.model_type == 'openai':
+        model = model_cls(model='gpt-4-vision-preview')
+    elif model_cls.model_type == 'anthropic':
+        model = model_cls(model='claude-3-sonnet-20240229')
+    else:
+        model = model_cls()
output = model.generate(user_message)
assert output.reply != ''

