-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathllm.py
124 lines (98 loc) · 2.61 KB
/
llm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import requests
import json_repair
import time
from dotenv import load_dotenv
# Load variables from a local ".env" file so that settings such as
# MISTRAL_API_KEY (read via os.getenv in the request helpers) can be kept
# out of the source tree.
load_dotenv(".env")
# Mistral REST endpoints used by the request helpers in this module.
MISTRAL_API_CHAT_URL = "https://api.mistral.ai/v1/chat/completions"
MISTRAL_API_EMBED_URL = "https://api.mistral.ai/v1/embeddings"
def mistral_request(messages, model, **kwargs):
    """Send a chat-completion request to the Mistral API and return its JSON.

    Rate-limited responses (HTTP 429) are retried with exponential backoff
    (1, 2, 4, ... seconds, capped at 20 seconds), for at most 7 attempts in
    total. Any other unsuccessful response is printed and raised at once.

    Args:
        messages: Chat messages, sent as the ``messages`` field of the body.
        model: Model name, sent as the ``model`` field of the body.
        **kwargs: Extra fields merged verbatim into the JSON request body.

    Returns:
        The decoded JSON body of the successful response.

    Raises:
        requests.HTTPError: On a non-429 error status, or when the request is
            still rate limited on the last attempt.
        requests.Timeout: If the server does not respond within the timeout.
    """
    api_key = os.getenv("MISTRAL_API_KEY")
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": messages,
        **kwargs,
    }

    max_tries = 7
    max_delay = 20
    # (connect, read) seconds. Without a timeout `requests` waits forever on
    # a stalled connection; the read limit is generous because completions
    # from large models can take a while.
    timeout = (10, 300)

    for attempt in range(max_tries):
        response = requests.post(
            MISTRAL_API_CHAT_URL,
            json=payload,
            headers=headers,
            timeout=timeout,
        )
        if response.ok:
            return response.json()

        is_last_attempt = attempt == max_tries - 1
        if response.status_code != 429 or is_last_attempt:
            # Either a non-retryable error or the retry budget is spent:
            # surface the server's message and raise without sleeping first.
            print(response.text)
            response.raise_for_status()

        wait_time = min(max_delay, 2 ** attempt)
        print(f"Waiting {wait_time} second(s)...")
        time.sleep(wait_time)
def mistral_embed_texts(inputs):
    """Embed one text or a list of texts with the ``mistral-embed`` model.

    Rate-limited responses (HTTP 429) are retried silently with exponential
    backoff (1, 2, 4 seconds), for at most 4 attempts in total. Any other
    unsuccessful response is printed and raised at once.

    Args:
        inputs: A single string, or a collection of strings, sent as the
            ``input`` field of the request body.

    Returns:
        The embedding of the text when ``inputs`` is a ``str``; otherwise a
        list with one embedding per entry of the response's ``data`` array.

    Raises:
        requests.HTTPError: On a non-429 error status, or when the request is
            still rate limited on the last attempt.
        requests.Timeout: If the server does not respond within the timeout.
    """
    api_key = os.getenv("MISTRAL_API_KEY")
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "mistral-embed",
        "input": inputs,
    }

    max_tries = 4
    # (connect, read) seconds. Without a timeout `requests` waits forever on
    # a stalled connection.
    timeout = (10, 120)

    for attempt in range(max_tries):
        response = requests.post(
            MISTRAL_API_EMBED_URL,
            json=payload,
            headers=headers,
            timeout=timeout,
        )
        if response.ok:
            break

        is_last_attempt = attempt == max_tries - 1
        if response.status_code != 429 or is_last_attempt:
            # Either a non-retryable error or the retry budget is spent:
            # surface the server's message and raise without sleeping first.
            print(response.text)
            response.raise_for_status()

        time.sleep(2 ** attempt)

    embed_res = response.json()
    if isinstance(inputs, str):
        # A bare string yields exactly one embedding; unwrap it for the caller.
        return embed_res["data"][0]["embedding"]
    return [obj["embedding"] for obj in embed_res["data"]]
def _convert_system_to_user(messages):
new_messages = []
for msg in messages:
role = msg["role"]
content = msg["content"]
if role == "system":
role = "user"
new_messages.append({"role":role, "content":content})
return new_messages
class MistralLLM:
    """Small convenience wrapper that sends prompts to one Mistral chat model."""

    def __init__(self, model="mistral-large-latest"):
        """Remember the model name used for every ``generate`` call."""
        self.model = model

    def generate(self, prompt, return_json=False, **kwargs):
        """Run a chat completion and return the first choice's content.

        Args:
            prompt: Either a plain string (sent as a single ``user`` message)
                or a ready-made list of chat message dicts.
            return_json: When true, request the ``json_object`` response
                format and parse the reply with ``json_repair``; otherwise
                request ``text`` and return the raw string.
            **kwargs: Extra fields forwarded to ``mistral_request``.

        Returns:
            The parsed JSON value when ``return_json`` is true, else the
            message content as returned by the API.
        """
        if isinstance(prompt, str):
            messages = [{"role":"user", "content":prompt}]
        else:
            messages = prompt

        # Only "mistral-large-latest" receives system messages as-is.
        # NOTE(review): presumably other models do not handle the system
        # role well -- confirm.
        if self.model != "mistral-large-latest":
            messages = _convert_system_to_user(messages)

        response_format = {"type": "json_object" if return_json else "text"}
        completion = mistral_request(
            messages,
            **kwargs,
            model=self.model,
            response_format=response_format,
        )
        content = completion["choices"][0]["message"]["content"]
        return json_repair.loads(content) if return_json else content