
Commit

rename
Rubiksman78 committed Mar 1, 2023
1 parent 8ed985d commit aa2e474
Showing 2 changed files with 20 additions and 20 deletions.
4 changes: 2 additions & 2 deletions chatbot/model.py

@@ -9,9 +9,9 @@
 logger = logging.getLogger(__name__)
 
 with open("chatbot/chatbot_config.yml", "r") as f:
-    PYG_CONFIG = yaml.safe_load(f)
+    CHAT_CONFIG = yaml.safe_load(f)
 
-USE_INT_8 = PYG_CONFIG["use_int_8"]
+USE_INT_8 = CHAT_CONFIG["use_int_8"]
 
 def build_model_and_tokenizer_for(
     model_name: str
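
Note: the hunk above ends before USE_INT_8 is actually consumed. For context, a plausible body for build_model_and_tokenizer_for is sketched below; it is an assumption based on the flag name and the Hugging Face transformers API, not code taken from this commit.

# Illustrative sketch only; not the repository's actual implementation.
# USE_INT_8 is the flag read from CHAT_CONFIG in the hunk above: when true,
# the model is loaded in 8-bit via bitsandbytes, otherwise in fp16 on the GPU.
from transformers import AutoModelForCausalLM, AutoTokenizer

def build_model_and_tokenizer_for(model_name: str):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if USE_INT_8:
        model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", load_in_8bit=True
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name).half().cuda()
    return model, tokenizer
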
36 changes: 18 additions & 18 deletions main.py

@@ -42,28 +42,28 @@ def __exit__(self, exc_type, exc_val, exc_tb):
 import gc
 
 with open("chatbot/chatbot_config.yml", "r") as f:
-    PYG_CONFIG = yaml.safe_load(f)
+    CHAT_CONFIG = yaml.safe_load(f)
 
 with open(f"char_json/{CHARACTER_JSON}", "r") as f:
     char_settings = json.load(f)
     f.close()
 
-model_name = PYG_CONFIG["model_name"]
+model_name = CHAT_CONFIG["model_name"]
 gc.collect()
 torch.cuda.empty_cache()
-pyg_model, tokenizer = build_model_and_tokenizer_for(model_name)
+chat_model, tokenizer = build_model_and_tokenizer_for(model_name)
 
 generation_settings = {
-    "max_new_tokens": PYG_CONFIG["max_new_tokens"],
-    "temperature": PYG_CONFIG["temperature"],
-    "repetition_penalty": PYG_CONFIG["repetition_penalty"],
-    "top_p": PYG_CONFIG["top_p"],
-    "top_k": PYG_CONFIG["top_k"],
-    "do_sample": PYG_CONFIG["do_sample"],
-    "typical_p": PYG_CONFIG["typical_p"],
+    "max_new_tokens": CHAT_CONFIG["max_new_tokens"],
+    "temperature": CHAT_CONFIG["temperature"],
+    "repetition_penalty": CHAT_CONFIG["repetition_penalty"],
+    "top_p": CHAT_CONFIG["top_p"],
+    "top_k": CHAT_CONFIG["top_k"],
+    "do_sample": CHAT_CONFIG["do_sample"],
+    "typical_p": CHAT_CONFIG["typical_p"],
 }
 
-context_size = PYG_CONFIG["context_size"]
+context_size = CHAT_CONFIG["context_size"]
 
 with open("chat_history.txt", "a") as chat_history:
     chat_history.write("Conversation started at: " + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + "\n")
@@ -221,7 +221,7 @@ def listenToClient(client):
     name = "User"
     clients[client] = name
     launched = False
-    pyg_count = 0
+    chat_count = 0
     play_obj = None
     if os.path.exists("char_history.txt"):
         history = open("char_history.txt","r").read()
@@ -262,23 +262,23 @@ def listenToClient(client):
         print("User: "+received_msg)
 
         while True:
-            if pyg_count == 0:
+            if chat_count == 0:
                 sendMessage("server_ok".encode("utf-8"))
                 ok_ready = client.recv(BUFSIZE).decode("utf-8")
-                bot_message = inference_fn(pyg_model,tokenizer,history, "",generation_settings,char_settings,history_length=context_size,count=pyg_count)
+                bot_message = inference_fn(chat_model,tokenizer,history, "",generation_settings,char_settings,history_length=context_size,count=chat_count)
             else:
-                bot_message = inference_fn(pyg_model,tokenizer,history, received_msg,generation_settings,char_settings,history_length=context_size,count=pyg_count)
+                bot_message = inference_fn(chat_model,tokenizer,history, received_msg,generation_settings,char_settings,history_length=context_size,count=chat_count)
             history = history + "\n" + f"You: {received_msg}" + "\n" + f"{bot_message}"
             if received_msg != "QUIT":
                 if received_msg == "REGEN":
                     history.replace("\n" + f"You: {received_msg}" + "\n" + f"{bot_message}","")
-                    bot_message = inference_fn(pyg_model,tokenizer,history, received_msg,generation_settings,char_settings,history_length=context_size,count=pyg_count)
+                    bot_message = inference_fn(chat_model,tokenizer,history, received_msg,generation_settings,char_settings,history_length=context_size,count=chat_count)
                 bot_message = bot_message.replace("<USER>","Player")
                 play_obj = play_TTS(step,bot_message,play_obj)
                 print("Sent: "+ bot_message)
                 send_answer(received_msg,bot_message)
-                pyg_count += 1
-                if pyg_count > 1:
+                chat_count += 1
+                if chat_count > 1:
                     with open("chat_history.txt", "a",encoding="utf-8") as f:
                         f.write(f"You: {received_msg}" + "\n" + f'{char_settings["char_name"]}: {bot_message}' + "\n")
             break
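
Because the rename touches every config lookup, a quick check that chatbot/chatbot_config.yml still provides each key read through CHAT_CONFIG can catch stragglers early. The key list below is exactly what this diff reads; the check itself is an illustrative suggestion, not part of the commit.

# Illustrative sanity check, not part of this commit: confirm the config file
# defines every key that chatbot/model.py and main.py read via CHAT_CONFIG.
import yaml

REQUIRED_KEYS = {
    "use_int_8", "model_name", "context_size",
    "max_new_tokens", "temperature", "repetition_penalty",
    "top_p", "top_k", "do_sample", "typical_p",
}

with open("chatbot/chatbot_config.yml", "r") as f:
    cfg = yaml.safe_load(f)

missing = REQUIRED_KEYS - set(cfg)
if missing:
    raise KeyError(f"chatbot_config.yml is missing: {sorted(missing)}")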
