Maple_old/gpt2bot/discord_bot.py
2021-10-05 09:09:45 -05:00

307 lines
10 KiB
Python

# Licensed under the MIT license.
import configparser
import argparse
import logging
import random
import asyncio
import discord
from discord.ext import commands
import time
import os
import sys
import re
#import matplotlib as mpl
#mpl.use('Agg')
#import matplotlib.pyplot as plt
from model import download_model_folder, download_reverse_model_folder, load_model
from decoder import generate_response
from textblob import TextBlob
from googletrans import Translator
# Enable logging
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
client = commands.Bot(command_prefix="BOT_NAME")
#tenor_gifs = tenorpy.Tenor()
global translator
global num_samples
global max_turns_history
global model
global tokenizer
global mmi_model
global config
global mmi_tokenizer
global number_of_messages
global number_of_sent_messages
global number_of_servers
global start_time
global history_dict
import datetime
@client.event
async def on_ready():
global translator
global num_samples
global max_turns_history
global model
global tokenizer
global mmi_model
global mmi_tokenizer
global config
global number_of_messages
global number_of_sent_messages
global number_of_servers
global history_dict
if(number_of_messages is None):
number_of_messages = 0
number_of_sent_messages = 0
number_of_servers = str(len(client.guilds))
if(history_dict is None):
history_dict = {}
translator = Translator()
print('Logged in as '+client.user.name+' (ID:'+str(client.user.id)+') | '+str(len(client.guilds))+' servers | ' + getAllUsersCount())
await client.change_presence(activity=discord.Game(name='chat with me!'))
write_status_report()
#schedule.every().day.at("00:00").do(client.loop.call_soon_threadsafe, restart_script())
#client.loop.create_task(run_schedule())
#Called when a message is received
@client.listen()
async def on_message(message):
global number_of_messages
global number_of_sent_messages
global number_of_servers
if not (message.author == client.user): #Check to ensure the bot does not respond to its own messages
if(message.mention_everyone == False):
if(client.user.mentioned_in(message) or isinstance(message.channel, discord.abc.PrivateChannel)): #Check if the bot is mentioned or if the message is in DMs
async with message.channel.typing(): #Show that the bot is typing
number_of_messages += 1
number_of_servers = str(len(client.guilds))
#write_status_report()
translator = Translator()
txtinput = message.content.replace("<@" + str(client.user.id) + ">", "").replace("<@!" + str(client.user.id) + ">", "") #Filter out the mention so the bot does not get confused
if(len(txtinput) > 220): #Spam protection
txt = "I am sorry, that is too long for me."
dicestr = re.search("Roll (\d{1,2})d(\d{1,3})",message.content)
if(dicestr != None):
dice = [dicestr.group(1), dicestr.group(2)]
output = "I rolled "
for i in range(int(dice[0])):
output += str(random.randrange(1, int(dice[1]))) + ", "
txt = output
else:
blob = TextBlob(txtinput)
lang = translator.detect(txtinput).lang
#lang = "en"
if(lang != "en"):
txtinput = str(translator.translate(txtinput, dest="en", src=lang).text)
#_context.append(txtinput)
if(isinstance(message.channel, discord.abc.PrivateChannel)):
txt = get_response(txtinput, message.author.id, False) #Get a response!
else:
txt = get_response(txtinput, message.guild.id, False) #Get a response!
response_blob = TextBlob(txt)
number_of_sent_messages += 1
bot_message = await message.channel.send(txt) #Fire away!
#gifchance = 18
#gifresult = random.randrange(1, 100)
#if(gifresult <= gifchance):
# if(lang != "en"):
# gif_message = tenor_gifs.random(en_text)
# elif(lang == "en"):
# gif_message = tenor_gifs.random(txt)
# else:
# gif_message = tenor_gifs.random(txt)
# await message.channel.send(gif_message)
write_status_report()
def getAllUsersCount():
guilds = client.guilds
user_count = 0
for g in guilds:
user_count += len(g.members)
return("Current user count: " + str(user_count))
def write_status_report():
global number_of_messages
global number_of_sent_messages
global number_of_servers
global history_dict
with open("./status_report.txt", "w") as f:
f.write("```")
f.write("Status Report: " + datetime.datetime.now().strftime("%m/%d/%Y %H:%M:%S") + "\n")
f.write("Number of guilds: " + str(number_of_servers) + "\n")
f.write("Number of messages received since last reboot: " + str(number_of_messages) + "\n")
f.write("Number of messages sent since last reboot: " + str(number_of_sent_messages) + "\n")
f.write("Number of failed responses since last reboot: " + str(number_of_messages - number_of_sent_messages) + "\n")
f.write("Number of guilds in memory: " + str(len(history_dict)) + "\n```")
f.close()
def run_chat():
# Parse parameters
global translator
global num_samples
global max_turns_history
global model
global tokenizer
global mmi_model
global mmi_tokenizer
global config
global number_of_messages
global number_of_sent_messages
global number_of_servers
global history_dict
global token
num_samples = config.getint('decoder', 'num_samples')
max_turns_history = config.getint('decoder', 'max_turns_history')
logger.info("Running the chatbot...")
turns = []
loop = asyncio.get_event_loop()
task1 = loop.create_task(client.start(token))
gathered = asyncio.gather(task1, loop=loop)
loop.run_until_complete(gathered)
def get_prescripted_lines(filepath):
lines = []
with open(filepath, "r") as f:
for line in f:
lines.append(line)
return lines
global static_history
static_history = get_prescripted_lines("./constant_thoughts.txt")
def get_response(prompt, channel_id, do_infinite):
global translator
global num_samples
global max_turns_history
global model
global tokenizer
global mmi_model
global mmi_tokenizer
global config
global history_dict
if max_turns_history == 0:
# If you still get different responses then set seed
turns = []
# A single turn is a group of user messages and bot responses right after
turn = {
'user_messages': [],
'bot_messages': []
}
str_channel_id = str(channel_id)
#turns.append(turn)
turn['user_messages'].append(prompt)
if not channel_id in history_dict:
history_dict[channel_id] = []
history_dict[channel_id].append(turn)
# Merge turns into a single history (don't forget EOS token)
history = ""
from_index = max(len(history_dict[channel_id])-max_turns_history-1, 0) if max_turns_history >= 0 else 0
for message in static_history:
history += message + tokenizer.eos_token
for i in range(len(history_dict[channel_id])):
if(i >= from_index):
turn2 = history_dict[channel_id][i]
else:
continue
# Each turn begings with user messages
for message in turn2['user_messages']:
history += message + tokenizer.eos_token
for message in turn2['bot_messages']:
history += message + tokenizer.eos_token
# Generate bot messages
bot_messages = generate_response(
model,
tokenizer,
history,
config,
mmi_model=mmi_model,
mmi_tokenizer=mmi_tokenizer
)
if num_samples == 1:
bot_message = bot_messages[0]
else:
# TODO: Select a message that is the most appropriate given the context
# This way you can avoid loops
bot_message = random.choice(bot_messages)
turn['bot_messages'].append(bot_message)
#print(history_dict)
return bot_message
def main():
global translator
global num_samples
global max_turns_history
global model
global tokenizer
global mmi_model
global mmi_tokenizer
global config
global number_of_messages
global number_of_sent_messages
global number_of_servers
global history_dict
global token
token = "TOKEN_GOES_HERE" # Replace TOKEN_GOES_HERE with your discord API bot token!
history_dict = {}
# Script arguments can include path of the config
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--config', type=str, default="chatbot.cfg")
args = arg_parser.parse_args()
# Read the config
config = configparser.ConfigParser(allow_no_value=True)
with open(args.config) as f:
config.read_file(f)
# Download and load main model
target_folder_name = download_model_folder(config)
model, tokenizer = load_model(target_folder_name, config)
# Download and load reverse model
use_mmi = config.getboolean('model', 'use_mmi')
if use_mmi:
mmi_target_folder_name = download_reverse_model_folder(config)
mmi_model, mmi_tokenizer = load_model(mmi_target_folder_name, config)
else:
mmi_model = None
mmi_tokenizer = None
# Run chatbot with GPT-2
run_chat()
if __name__ == '__main__':
main()