Skip to content

Commit

Permalink
Azure bug (#50)
Browse files Browse the repository at this point in the history
* Add model parameter to AzureChat constructor

---------

Co-authored-by: wangyuxin <wangyuxin@mokahr.com>
  • Loading branch information
wangyuxinwhy and wangyuxin authored Mar 11, 2024
1 parent de517e9 commit 540e25b
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 11 deletions.
10 changes: 9 additions & 1 deletion generate/chat_completion/models/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput
from generate.chat_completion.models.openai import OpenAIChatParameters, OpenAIChatParametersDict
from generate.chat_completion.models.openai_like import convert_to_openai_message, process_openai_like_model_reponse
from generate.chat_completion.stream_manager import StreamManager
from generate.http import HttpClient, HttpxPostKwargs
from generate.platforms.azure import AzureSettings

Expand All @@ -21,14 +22,17 @@ class AzureChat(RemoteChatCompletionModel):

def __init__(
self,
model: str,
model: str | None = None,
parameters: OpenAIChatParameters | None = None,
settings: AzureSettings | None = None,
http_client: HttpClient | None = None,
) -> None:
parameters = parameters or OpenAIChatParameters()
settings = settings or AzureSettings() # type: ignore
http_client = http_client or HttpClient()
model = model or settings.chat_api_engine
if model is None:
raise ValueError('model must be provided or set in settings.chat_api_engine')
super().__init__(model, parameters=parameters, settings=settings, http_client=http_client)

@override
Expand Down Expand Up @@ -75,3 +79,7 @@ def _get_request_parameters(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatPar
@override
def _process_reponse(self, response: dict[str, Any]) -> ChatCompletionOutput:
return process_openai_like_model_reponse(response, model_type=self.model_type)

@override
def _process_stream_line(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None:
raise NotImplementedError
1 change: 0 additions & 1 deletion generate/chat_completion/models/minimax.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import json
import uuid
from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Optional

from pydantic import PositiveInt, field_validator
Expand Down
2 changes: 1 addition & 1 deletion generate/chat_completion/models/openai_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import base64
import json
import uuid
from abc import ABC
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Type, Union, cast

from typing_extensions import NotRequired, TypedDict, override
import uuid

from generate.chat_completion.base import RemoteChatCompletionModel
from generate.chat_completion.cost_caculator import GeneralCostCalculator
Expand Down
3 changes: 3 additions & 0 deletions generate/platforms/azure.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from pydantic import SecretStr
from pydantic_settings import SettingsConfigDict

Expand All @@ -10,4 +12,5 @@ class AzureSettings(PlatformSettings):
api_key: SecretStr
api_base: str
api_version: str
chat_api_engine: Optional[str] = None
platform_url: str = 'https://learn.microsoft.com/en-us/azure/ai-services/openai/'
4 changes: 2 additions & 2 deletions generate/text_to_speech/models/minimax.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _get_request_parameters(self, text: str, parameters: MinimaxSpeechParameters
'Content-Type': 'application/json',
}
return {
'url': self.settings.api_base + 'text_to_speech',
'url': self.settings.api_base + '/text_to_speech',
'json': json_data,
'headers': headers,
'params': {'GroupId': self.settings.group_id},
Expand Down Expand Up @@ -154,7 +154,7 @@ def _get_request_parameters(self, text: str, parameters: MinimaxProSpeechParamet
'Content-Type': 'application/json',
}
return {
'url': self.settings.api_base + 't2a_pro',
'url': self.settings.api_base + '/t2a_pro',
'json': json_data,
'headers': headers,
'params': {'GroupId': self.settings.group_id},
Expand Down
4 changes: 0 additions & 4 deletions tests/test_chat_completion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
RemoteChatCompletionModel,
)
from generate.chat_completion.message import Prompt
from generate.chat_completion.models.azure import AzureChat
from generate.test import get_pytest_params


Expand All @@ -33,9 +32,6 @@ def test_model_type_is_unique() -> None:
],
)
def test_http_chat_model(model_cls: Type[ChatCompletionModel], parameters: dict[str, Any]) -> None:
if issubclass(model_cls, AzureChat):
return

model = model_cls()
prompt = '这是测试,只回复你好'
sync_output = model.generate(prompt, **parameters)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Country(BaseModel):

def test_session() -> None:
model = OpenAIChat().session()
model.generate('I am bob')
reply = model.generate('who am i?').reply
model.generate('call me BOB')
reply = model.generate('TEST: my name is ?').reply

assert 'bob' in reply.lower()

0 comments on commit 540e25b

Please sign in to comment.