From 40c294d901b594819a156117511f07cfe707b354 Mon Sep 17 00:00:00 2001 From: "yuxin.wang" Date: Tue, 5 Mar 2024 11:43:42 +0800 Subject: [PATCH] Add support for multi-part messages in AnthropicChat model (#47) Co-authored-by: wangyuxin --- README.md | 2 +- docs/index.md | 2 +- generate/chat_completion/models/anthropic.py | 56 ++++++++++++++++---- generate/version.py | 2 +- pyproject.toml | 2 +- tests/test_chat_completion_model.py | 9 +++- 6 files changed, 56 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index f74d0c1..a6bef97 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ Generate 允许用户通过统一的 api 访问多平台的生成式模型,当 |----------------|---------|---------|---------|-----------|-----------| | OpenAI | ✅ | ✅ | ✅ | ✅ | ✅ | | Azure | ✅ | ✅ | ❌ | ✅ | ✅ | -| Anthropic | ✅ | ✅ | ✅ | ❌ | ❌ | +| Anthropic | ✅ | ✅ | ✅ | ✅ | ❌ | | 文心 Wenxin | ✅ | ✅ | ✅ | ❌ | ✅ | | 百炼 Bailian | ✅ | ✅ | ✅ | ❌ | ❌ | | 灵积 DashScope | ✅ | ✅ | ✅ | ✅ | ❌ | diff --git a/docs/index.md b/docs/index.md index 15a04c4..6f3f65d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -20,7 +20,7 @@ Generate 允许用户通过统一的 api 访问多平台的生成式模型,当 |----------------|---------|---------|---------|-----------|-----------| | OpenAI | ✅ | ✅ | ✅ | ✅ | ✅ | | Azure | ✅ | ✅ | ❌ | ✅ | ✅ | -| Anthropic | ✅ | ✅ | ✅ | ❌ | ❌ | +| Anthropic | ✅ | ✅ | ✅ | ✅ | ❌ | | 文心 Wenxin | ✅ | ✅ | ✅ | ❌ | ✅ | | 百炼 Bailian | ✅ | ✅ | ✅ | ❌ | ❌ | | 灵积 DashScope | ✅ | ✅ | ✅ | ✅ | ❌ | diff --git a/generate/chat_completion/models/anthropic.py b/generate/chat_completion/models/anthropic.py index 345db4a..832c586 100644 --- a/generate/chat_completion/models/anthropic.py +++ b/generate/chat_completion/models/anthropic.py @@ -1,5 +1,6 @@ from __future__ import annotations +import base64 import json from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Literal, Optional @@ -8,7 +9,16 @@ from generate.chat_completion.base import RemoteChatCompletionModel from generate.chat_completion.message import Prompt -from generate.chat_completion.message.core import AssistantMessage, Message, SystemMessage, UserMessage +from generate.chat_completion.message.core import ( + AssistantMessage, + ImagePart, + ImageUrlPart, + Message, + SystemMessage, + TextPart, + UserMessage, + UserMultiPartMessage, +) from generate.chat_completion.message.exception import MessageTypeError from generate.chat_completion.message.utils import ensure_messages from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput @@ -89,7 +99,28 @@ def _convert_message(self, message: Message) -> dict[str, str]: return {'role': 'user', 'content': message.content} if isinstance(message, AssistantMessage): return {'role': 'assistant', 'content': message.content} - raise MessageTypeError(message, (UserMessage, AssistantMessage)) + if isinstance(message, UserMultiPartMessage): + message_dict = {'role': 'user', 'content': []} + for part in message.content: + if isinstance(part, TextPart): + message_dict['content'].append({'type': 'text', 'text': part.text}) + + if isinstance(part, ImagePart): + data = base64.b64encode(part.image).decode() + media_type = part.image_format or 'image/jpeg' + message_dict['content'].append( + {'type': 'image', 'source': {'type': 'base64', 'media_type': media_type, 'data': data}} + ) + + if isinstance(part, ImageUrlPart): + response = self.http_client.get({'url': part.image_url.url}) + data = base64.b64encode(response.content).decode() + media_type = response.headers.get('Content-Type') or 'image/jpeg' + message_dict['content'].append( + {'type': 'image', 'source': {'type': 'base64', 'media_type': media_type, 'data': data}} + ) + return message_dict + raise MessageTypeError(message, (UserMessage, AssistantMessage, UserMultiPartMessage)) @override def _get_request_parameters( @@ -135,15 +166,18 @@ def _process_reponse(self, response: dict[str, Any]) -> ChatCompletionOutput: ) def _calculate_cost(self, input_tokens: int, output_tokens: int) -> float | None: + model_price_mapping = { + 'claude-instant': (0.80, 2.40), + 'claude-2': (8, 24), + 'claude-3-haiku': (0.25, 1.25), + 'claude-3-sonnet': (3, 15), + 'claude-3-opus': (15, 75), + } dollar_to_yuan = 7 - if 'claude-instant' in self.model: - # prompt: $0.80/million tokens, completion: $2.40/million tokens - cost = (input_tokens * 0.8 / 1_000_000) + (output_tokens * 2.4 / 1_000_000) - return cost * dollar_to_yuan - if 'claude-2' in self.model: - # prompt: $8/million tokens, completion: $24/million tokens - cost = (input_tokens * 8 / 1_000_000) + (output_tokens * 24 / 1_000_000) - return cost * dollar_to_yuan + for model_name, (prompt_price, completion_price) in model_price_mapping.items(): + if model_name in self.model: + cost = (input_tokens * prompt_price / 1_000_000) + (output_tokens * completion_price / 1_000_000) + return cost * dollar_to_yuan return None @override @@ -163,7 +197,7 @@ def _process_stream_line(self, line: str, stream_manager: StreamManager) -> Chat delta_dict = data['delta'] stream_manager.delta = '' stream_manager.finish_reason = delta_dict['stop_reason'] - stream_manager.extra['output_tokens'] = data['usage']['output_tokens'] + stream_manager.extra['usage']['output_tokens'] = data['usage']['output_tokens'] stream_manager.cost = self._calculate_cost(**stream_manager.extra['usage']) return stream_manager.build_stream_output() diff --git a/generate/version.py b/generate/version.py index 443436f..f0ede3d 100644 --- a/generate/version.py +++ b/generate/version.py @@ -1 +1 @@ -__version__ = '0.4.0.post1' +__version__ = '0.4.1' diff --git a/pyproject.toml b/pyproject.toml index cc45bd5..ee1d107 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "generate-core" -version = "0.4.0.post1" +version = "0.4.1" description = "文本生成,图像生成,语音生成" authors = ["wangyuxin "] license = "MIT" diff --git a/tests/test_chat_completion_model.py b/tests/test_chat_completion_model.py index f3ece5f..a1a070e 100644 --- a/tests/test_chat_completion_model.py +++ b/tests/test_chat_completion_model.py @@ -70,7 +70,7 @@ def test_http_stream_chat_model(chat_completion_model: ChatCompletionModel) -> N 'test_multimodel_chat_completion', ChatModelRegistry, types='model_cls', - include=['dashscope_multimodal', 'zhipu', 'openai'], + include=['dashscope_multimodal', 'zhipu', 'openai', 'anthropic'], ), ) def test_multimodel_chat_completion(model_cls: Type[ChatCompletionModel]) -> None: @@ -81,7 +81,12 @@ def test_multimodel_chat_completion(model_cls: Type[ChatCompletionModel]) -> Non {'image_url': {'url': 'https://dashscope.oss-cn-beijing.aliyuncs.com/images/dog_and_girl.jpeg'}}, ], } - model = model_cls(model='gpt-4-vision-preview') if model_cls.model_type == 'openai' else model_cls() + if model_cls.model_type == 'openai': + model = model_cls(model='gpt-4-vision-preview') + elif model_cls.model_type == 'anthropic': + model = model_cls(model='claude-3-sonnet-20240229') + else: + model = model_cls() output = model.generate(user_message) assert output.reply != ''