Maple_old/gpt2bot/interactive_bot.py

# Copyright (c) polakowo
# Licensed under the MIT license.
import configparser
import argparse
import logging
import random
from model import download_model_folder, download_reverse_model_folder, load_model
from decoder import generate_response

# Enable logging
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)


def run_chat(model, tokenizer, config, mmi_model=None, mmi_tokenizer=None):
    # Parse parameters
    num_samples = config.getint('decoder', 'num_samples')
    max_turns_history = config.getint('decoder', 'max_turns_history')

    logger.info("Running the chatbot...")
    turns = []
    print("Bot >>>", "Just start texting me. If I'm getting annoying, type \"Bye\". To quit the chat type \"Quit\".")
    while True:
        prompt = input("User >>> ")
        if max_turns_history == 0:
            # If you still get different responses, set a seed
            turns = []
        if prompt.lower() == 'bye':
            print("Bot >>>", "Bye")
            turns = []
            continue
        if prompt.lower() == 'quit':
            break
        # A single turn is a group of user messages and the bot responses that follow them
        turn = {
            'user_messages': [],
            'bot_messages': []
        }
        turns.append(turn)
        turn['user_messages'].append(prompt)
        # Merge turns into a single history (don't forget the EOS token)
        history = ""
        from_index = max(len(turns) - max_turns_history - 1, 0) if max_turns_history >= 0 else 0
        for turn in turns[from_index:]:
            # Each turn begins with user messages
            for message in turn['user_messages']:
                history += message + tokenizer.eos_token
            for message in turn['bot_messages']:
                history += message + tokenizer.eos_token
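        # Note: the loop above reuses the name `turn`, so after it finishes `turn`
        # again refers to the last element of `turns`, i.e. the turn appended above.
        # With DialoGPT's tokenizer the merged history looks roughly like
        # "Hi<|endoftext|>Hello!<|endoftext|>How are you?<|endoftext|>"
        # (illustrative; the exact EOS string depends on the tokenizer).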
        # Generate bot messages
        bot_messages = generate_response(
            model,
            tokenizer,
            history,
            config,
            mmi_model=mmi_model,
            mmi_tokenizer=mmi_tokenizer
        )
        if num_samples == 1:
            bot_message = bot_messages[0]
        else:
            # TODO: Select a message that is the most appropriate given the context
            # This way you can avoid loops
            bot_message = random.choice(bot_messages)
        print("Bot >>>", bot_message)
        turn['bot_messages'].append(bot_message)


def main():
    # Script arguments can include the path of the config
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--config', type=str, default="chatbot.cfg")
    args = arg_parser.parse_args()
    # Read the config
    config = configparser.ConfigParser(allow_no_value=True)
    with open(args.config) as f:
        config.read_file(f)
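    # A minimal config sketch for this script (illustrative values; only the keys
    # read in this file are shown; model.py and decoder.py read further keys
    # from the same file):
    #
    #   [model]
    #   use_mmi = false
    #
    #   [decoder]
    #   num_samples = 1
    #   max_turns_history = 2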
    # Download and load main model
    target_folder_name = download_model_folder(config)
    model, tokenizer = load_model(target_folder_name, config)
    # Download and load reverse model
    use_mmi = config.getboolean('model', 'use_mmi')
    if use_mmi:
        mmi_target_folder_name = download_reverse_model_folder(config)
        mmi_model, mmi_tokenizer = load_model(mmi_target_folder_name, config)
    else:
        mmi_model = None
        mmi_tokenizer = None
    # Run chatbot with GPT-2
    run_chat(model, tokenizer, config, mmi_model=mmi_model, mmi_tokenizer=mmi_tokenizer)


if __name__ == '__main__':
    main()
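
# Example invocation (assuming a config like the sketch in main() is saved as chatbot.cfg):
#   python interactive_bot.py --config chatbot.cfg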