Maple Backup
This commit is contained in:
commit
aa4c40d32c
|
@ -0,0 +1,2 @@
|
|||
models
|
||||
__pycache__
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,14 @@
|
|||
FROM python:3.7-slim-stretch
|
||||
|
||||
RUN apt-get -y update && apt-get -y install gcc
|
||||
|
||||
COPY requirements.txt .
|
||||
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Clean up APT when done.
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
|
||||
|
||||
COPY gpt2bot .
|
||||
|
||||
CMD ["python", "telegram_bot.py"]
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2020 Oleg Polakow
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,120 @@
|
|||
# Maxwell - A DialoGPT variant for discord.py
|
||||
|
||||
Maxwell is my experiment with Microsoft's DialoGPT model and OpenAI's GPT-2 language model, for use on any discord server. I am planning on fine-tuning much more in the future but for now DialoGPT's model performs admirably.
|
||||
As of 3-2-2020 I am running this model on my old bot, Maple, which you can invite here:
|
||||
<a href="https://discordbotlist.com/bots/571924469661302814">
|
||||
<img
|
||||
width="380"
|
||||
height="140"
|
||||
src="https://discordbotlist.com/bots/571924469661302814/widget"
|
||||
alt="Lithium stats on Discord Bot List">
|
||||
</a>
|
||||
|
||||
## Setup
|
||||
|
||||
The recommended python version is 3.6.8.
|
||||
|
||||
### Requirements
|
||||
You can install all requirements with "pip install -r requirements.txt".
|
||||
|
||||
python 3.6.8
|
||||
numpy 1.16.4
|
||||
torch 1.2.0
|
||||
transformers 2.3.0
|
||||
python-telegram-bot 12.3.0 (Only if you are using the telegram bot)
|
||||
discord.py 1.2.5
|
||||
goolgetrans 2.4.0 (For automatic translation with non-english users)
|
||||
textblob 0.15.3 (Used in some text processing cases)
|
||||
matplotlib 2.0.2 (Used for modeling statistics)
|
||||
|
||||
In discord_bot.py, at line 117, replace "TOKEN_GOES_HERE" with your discord bot's API token.
|
||||
The model will automatically download and set up upon the first run of the program; you should be good to go!
|
||||
|
||||
# gpt2bot - Original Readme
|
||||
|
||||
```
|
||||
User >>> Can we achieve singularity?
|
||||
Bot >>> What does this mean?
|
||||
User >>> Can computers become smarter than humans?
|
||||
Bot >>> Is there any evidence that this is possible?
|
||||
User >>> It must be doable
|
||||
Bot >>> But... how?
|
||||
User >>> I want to find out, that's why I created you
|
||||
Bot >>> You created me?
|
||||
User >>> You're my bot
|
||||
Bot >>> You monster
|
||||
```
|
||||
|
||||
<img src="https://github.com/polakowo/gpt2bot/blob/master/reddit.png?raw=true" width=80>
|
||||
|
||||
gpt2bot implements
|
||||
- a decoder ([source](https://github.com/polakowo/gpt2bot/blob/master/gpt2bot/decoder.py)) for [DialoGPT](https://github.com/microsoft/DialoGPT),
|
||||
- an interactive multiturn chatbot ([source](https://github.com/polakowo/gpt2bot/blob/master/gpt2bot/interactive_bot.py)), and
|
||||
- a Telegram chatbot ([source](https://github.com/polakowo/gpt2bot/blob/master/gpt2bot/telegram_bot.py)).
|
||||
|
||||
The bot is built around [DialoGPT](https://github.com/microsoft/DialoGPT) - a large-scale pretrained dialogue response generation model trained by Microsoft, which was trained on 147M multi-turn dialogue from Reddit discussion thread. The human evaluation results indicate that its quility is comparable to human response quality under a single-turn conversation Turing test.
|
||||
|
||||
Since even with properly filtered Reddit dataset the model can generate toxic/inappropriate responses, the Microsoft team was unable to provide the decoding script. This repository implements the decoding script inspired by `run_generation.py` released earlier by Hugging Face. Moreover, it implements a Telegram bot that can be deployed locally, remotely, and even on Colab, and just makes testing fun.
|
||||
|
||||
## How to use?
|
||||
|
||||
### 1. Create a Telegram bot
|
||||
|
||||
- Register a new Telegram bot via BotFather (see https://core.telegram.org/bots)
|
||||
|
||||
### 2. Deploy the bot
|
||||
|
||||
#### Google Colab
|
||||
|
||||
[A Colab interactive notebook](https://colab.research.google.com/github/polakowo/gpt2bot/blob/master/Demo.ipynb)
|
||||
|
||||
A good thing about Google Colab is free GPU. So why not running the Telegram bot there, for blazingly fast chat? Run the notebook at daytime and do not forget to stop it at night.
|
||||
|
||||
#### Docker
|
||||
|
||||
- Clone the repository
|
||||
- Set your parameters such as API token in dialog.cfg
|
||||
- To avoid re-downloading model files at each re-deployment, download the model files beforehand with
|
||||
```
|
||||
# cd gpt2bot/gpt2bot
|
||||
python model.py
|
||||
```
|
||||
- Finally, deploy the container from the root folder
|
||||
```
|
||||
docker build -t gpt2bot . && docker run gpt2bot
|
||||
```
|
||||
|
||||
#### Manually
|
||||
|
||||
- Clone the repository
|
||||
- Set your parameters such as API token in dialog.cfg
|
||||
- Install packages listed in requirements.txt
|
||||
- Run the script
|
||||
```
|
||||
# cd gpt2bot/gpt2bot
|
||||
python telegram_bot.py
|
||||
```
|
||||
- To test the things out in the console, run
|
||||
```
|
||||
python interactive_bot.py
|
||||
```
|
||||
|
||||
### 3. Start chatting!
|
||||
|
||||
![](telegram_bot.gif)
|
||||
|
||||
Just start texting. Append @gif for the bot to generate a GIF instead of text. To reset, type "Bye".
|
||||
|
||||
## Updates
|
||||
|
||||
#### 18/01/2020
|
||||
|
||||
- EOS token is being checked during generation -> gpt2bot is now fast enough to be run on CPU.
|
||||
- Add support for maximum mutual information (MMI) -> more quality, but slower.
|
||||
|
||||
## References
|
||||
|
||||
- [Official DialoGPT implementation](https://github.com/microsoft/DialoGPT) and [DialoGPT paper](https://arxiv.org/abs/1911.00536)
|
||||
- [Thread on current decoding scripts](https://github.com/microsoft/DialoGPT/issues/3)
|
||||
|
||||
You can wait for a full DialoGPT release and then replace the decoder.
|
|
@ -0,0 +1 @@
|
|||
from GPT2Bot import model, decoder
|
|
@ -0,0 +1,69 @@
|
|||
[model]
|
||||
# Path to folder where the model files will be stored.
|
||||
# If path is relative, then the model.py must be called from the same directory.
|
||||
data_folder = models
|
||||
|
||||
# Size of the GPT-2 model. Could be one of 'small' (117M) or 'medium' (345M)
|
||||
# Select small for CPU or experimentation, and medium for GPU
|
||||
model_size = medium
|
||||
|
||||
# Dataset name the model was trained on. One of 'multiref' (147M multi-turn dialogue
|
||||
# from Reddit discussion thread) or 'dstc' (DSTC-7 grounded dialogue generation challenge).
|
||||
dataset = multiref
|
||||
|
||||
# True: load model trained from scratch or False: load model trained from fine-tuning the GPT-2.
|
||||
from_scratch = False
|
||||
|
||||
# Avoid using CUDA when available.
|
||||
no_cuda = False
|
||||
|
||||
# Further increases quality by selecting the response that yields lowest backward model loss.
|
||||
# Keep in mind: Uses inference on another medium model and further decreases bot's response time.
|
||||
# You should set num_samples > 1 for this to work
|
||||
use_mmi = False
|
||||
|
||||
[decoder]
|
||||
# Seed for random number generators, fix seed to reproduce results.
|
||||
# By default there is no seed and each turn should be unique.
|
||||
seed
|
||||
|
||||
# Float value controlling randomness in boltzmann
|
||||
# distribution. Lower temperature results in less random completions. As the
|
||||
# temperature approaches zero, the model will become deterministic and
|
||||
# repetitive. Higher temperature results in more random completions.
|
||||
temperature = 0.6474
|
||||
|
||||
# Integer value controlling diversity. 1 means only 1 word is
|
||||
# considered for each step (token), resulting in deterministic completions,
|
||||
# while 40 means 40 words are considered at each step. 0 (default) is a
|
||||
# special setting meaning no restrictions. 40 generally is a good value.
|
||||
top_k = 40
|
||||
|
||||
# Like top_k, top_p is a constraint on the craziness of the output
|
||||
top_p = 0
|
||||
|
||||
# The maximal number of tokens to be returned, inclusive of punctuations etc.
|
||||
# It will automatically stop if the end-of-sequence token was found earlier.
|
||||
# Usually, only in rare cases generation will go beyond 64 tokens.
|
||||
max_length = 128
|
||||
|
||||
# Number of samples to generate.
|
||||
# You will have to implement a strategy to choose one message from generated list.
|
||||
# For example, you can choose the most dissimilar message, or the lengthiest one.
|
||||
# But keep in mind: the higher, the slower the generation.
|
||||
num_samples = 1
|
||||
|
||||
# The number of turns (turn = answer and response) the model should consider.
|
||||
# Set to 0 to focus on the last message. Set to -1 for unlimited context length.
|
||||
max_turns_history = 1
|
||||
|
||||
[chatbot]
|
||||
|
||||
# Your Telegram token. See https://core.telegram.org/bots -- Functionality is disabled
|
||||
telegram_token = YOUR_TOKEN_HERE
|
||||
|
||||
# Your GIPHY API token. See
|
||||
giphy_token = YOUR_TOKEN_HERE
|
||||
|
||||
# Value from 0-10 which makes results weirder as you go up the scale.
|
||||
giphy_weirdness = 5
|
|
@ -0,0 +1 @@
|
|||
My name is Maxwell.
|
|
@ -0,0 +1,122 @@
|
|||
# Copyright (c) polakowo
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
def set_seed(seed):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
|
||||
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
||||
Args:
|
||||
logits: logits distribution shape (batch size x vocabulary size)
|
||||
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
||||
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
||||
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
||||
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
||||
"""
|
||||
top_k = min(top_k, logits.size(-1)) # Safety check
|
||||
if top_k > 0:
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits[indices_to_remove] = filter_value
|
||||
if top_p > 0.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
# Remove tokens with cumulative probability above the threshold
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
# Shift the indices to the right to keep also the first token above the threshold
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
||||
logits[indices_to_remove] = filter_value
|
||||
return logits
|
||||
|
||||
|
||||
def sample_sequence(model, tokenizer, context_ids, config):
|
||||
# Parse parameters
|
||||
no_cuda = config.getboolean('model', 'no_cuda')
|
||||
num_samples = config.getint('decoder', 'num_samples')
|
||||
max_length = config.getint('decoder', 'max_length')
|
||||
temperature = config.getfloat('decoder', 'temperature')
|
||||
top_k = config.getint('decoder', 'top_k')
|
||||
top_p = config.getfloat('decoder', 'top_p')
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
|
||||
context_tensor = torch.tensor(context_ids, dtype=torch.long, device=device)
|
||||
context_tensor = context_tensor.unsqueeze(0).repeat(num_samples, 1)
|
||||
generated = context_tensor
|
||||
with torch.no_grad():
|
||||
while True:
|
||||
inputs = {'input_ids': generated}
|
||||
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
|
||||
next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)
|
||||
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
||||
if temperature == 0.0: # greedy sampling:
|
||||
next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
|
||||
else:
|
||||
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
||||
generated = torch.cat((generated, next_token), dim=1)
|
||||
if (generated[:, len(context_ids):] == tokenizer.eos_token_id).any(dim=1).all():
|
||||
# EOS token id found in each sample
|
||||
break
|
||||
if generated.shape[1] - len(context_ids) >= max_length:
|
||||
# Maximum length reached
|
||||
break
|
||||
return generated
|
||||
|
||||
def select_using_mmi(mmi_model, mmi_tokenizer, candidates, config):
|
||||
# Parse parameters
|
||||
no_cuda = config.getboolean('model', 'no_cuda')
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
|
||||
scores = []
|
||||
for i, candidate in enumerate(candidates):
|
||||
context = []
|
||||
for response in reversed(candidate):
|
||||
context.extend(response)
|
||||
context.append(mmi_tokenizer.eos_token_id)
|
||||
context_ids = mmi_tokenizer.encode(context)
|
||||
context_tensor = torch.tensor(context_ids, dtype=torch.long, device=device)
|
||||
loss, _, _ = mmi_model(input_ids=context_tensor, labels=context_tensor)
|
||||
scores.append(-loss.float())
|
||||
|
||||
scores = torch.stack(scores, dim=0)
|
||||
# The smaller the loss, the higher must be the probability of selecting it
|
||||
winner = torch.multinomial(F.softmax(scores, dim=0), num_samples=1).item()
|
||||
return winner
|
||||
|
||||
def generate_response(model, tokenizer, context, config, mmi_model=None, mmi_tokenizer=None):
|
||||
# Parse parameters
|
||||
use_mmi = config.getboolean('model', 'use_mmi')
|
||||
num_samples = config.getint('decoder', 'num_samples')
|
||||
max_length = config.getint('decoder', 'max_length')
|
||||
seed = config.get('decoder', 'seed')
|
||||
seed = int(seed) if seed is not None else None
|
||||
|
||||
# Make answers reproducible only if wanted
|
||||
if seed is not None:
|
||||
set_seed(seed)
|
||||
|
||||
# Generate response
|
||||
context_ids = tokenizer.encode(context)
|
||||
samples = sample_sequence(model, tokenizer, context_ids, config)
|
||||
samples = samples[:, len(context_ids):].tolist()
|
||||
texts = []
|
||||
for sample in samples:
|
||||
text = tokenizer.decode(sample, clean_up_tokenization_spaces=True)
|
||||
text = text[: text.find(tokenizer.eos_token)]
|
||||
texts.append(text)
|
||||
|
||||
if use_mmi:
|
||||
assert(num_samples > 1, "MMI requires num_samples > 1")
|
||||
candidates = [context + text for text in texts]
|
||||
best_i = select_using_mmi(mmi_model, mmi_tokenizer, candidates, config)
|
||||
return [texts[best_i]]
|
||||
return texts
|
|
@ -0,0 +1,306 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
import configparser
|
||||
import argparse
|
||||
import logging
|
||||
import random
|
||||
import asyncio
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
import time
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
#import matplotlib as mpl
|
||||
#mpl.use('Agg')
|
||||
#import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
from model import download_model_folder, download_reverse_model_folder, load_model
|
||||
from decoder import generate_response
|
||||
|
||||
|
||||
from textblob import TextBlob
|
||||
from googletrans import Translator
|
||||
|
||||
|
||||
|
||||
|
||||
# Enable logging
|
||||
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
client = commands.Bot(command_prefix="BOT_NAME")
|
||||
|
||||
|
||||
#tenor_gifs = tenorpy.Tenor()
|
||||
|
||||
global translator
|
||||
|
||||
global num_samples
|
||||
global max_turns_history
|
||||
global model
|
||||
global tokenizer
|
||||
global mmi_model
|
||||
global config
|
||||
global mmi_tokenizer
|
||||
global number_of_messages
|
||||
global number_of_sent_messages
|
||||
global number_of_servers
|
||||
global start_time
|
||||
global history_dict
|
||||
import datetime
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
global translator
|
||||
|
||||
global num_samples
|
||||
global max_turns_history
|
||||
global model
|
||||
global tokenizer
|
||||
global mmi_model
|
||||
global mmi_tokenizer
|
||||
global config
|
||||
global number_of_messages
|
||||
global number_of_sent_messages
|
||||
global number_of_servers
|
||||
global history_dict
|
||||
if(number_of_messages is None):
|
||||
number_of_messages = 0
|
||||
number_of_sent_messages = 0
|
||||
number_of_servers = str(len(client.guilds))
|
||||
|
||||
if(history_dict is None):
|
||||
history_dict = {}
|
||||
translator = Translator()
|
||||
print('Logged in as '+client.user.name+' (ID:'+str(client.user.id)+') | '+str(len(client.guilds))+' servers | ' + getAllUsersCount())
|
||||
await client.change_presence(activity=discord.Game(name='chat with me!'))
|
||||
write_status_report()
|
||||
#schedule.every().day.at("00:00").do(client.loop.call_soon_threadsafe, restart_script())
|
||||
#client.loop.create_task(run_schedule())
|
||||
|
||||
|
||||
#Called when a message is received
|
||||
@client.listen()
|
||||
async def on_message(message):
|
||||
global number_of_messages
|
||||
global number_of_sent_messages
|
||||
global number_of_servers
|
||||
if not (message.author == client.user): #Check to ensure the bot does not respond to its own messages
|
||||
if(message.mention_everyone == False):
|
||||
if(client.user.mentioned_in(message) or isinstance(message.channel, discord.abc.PrivateChannel)): #Check if the bot is mentioned or if the message is in DMs
|
||||
async with message.channel.typing(): #Show that the bot is typing
|
||||
number_of_messages += 1
|
||||
number_of_servers = str(len(client.guilds))
|
||||
#write_status_report()
|
||||
translator = Translator()
|
||||
txtinput = message.content.replace("<@" + str(client.user.id) + ">", "").replace("<@!" + str(client.user.id) + ">", "") #Filter out the mention so the bot does not get confused
|
||||
if(len(txtinput) > 220): #Spam protection
|
||||
txt = "I am sorry, that is too long for me."
|
||||
dicestr = re.search("Roll (\d{1,2})d(\d{1,3})",message.content)
|
||||
if(dicestr != None):
|
||||
dice = [dicestr.group(1), dicestr.group(2)]
|
||||
output = "I rolled "
|
||||
for i in range(int(dice[0])):
|
||||
output += str(random.randrange(1, int(dice[1]))) + ", "
|
||||
txt = output
|
||||
else:
|
||||
blob = TextBlob(txtinput)
|
||||
lang = translator.detect(txtinput).lang
|
||||
#lang = "en"
|
||||
if(lang != "en"):
|
||||
txtinput = str(translator.translate(txtinput, dest="en", src=lang).text)
|
||||
#_context.append(txtinput)
|
||||
if(isinstance(message.channel, discord.abc.PrivateChannel)):
|
||||
txt = get_response(txtinput, message.author.id, False) #Get a response!
|
||||
else:
|
||||
txt = get_response(txtinput, message.guild.id, False) #Get a response!
|
||||
response_blob = TextBlob(txt)
|
||||
number_of_sent_messages += 1
|
||||
bot_message = await message.channel.send(txt) #Fire away!
|
||||
#gifchance = 18
|
||||
#gifresult = random.randrange(1, 100)
|
||||
#if(gifresult <= gifchance):
|
||||
# if(lang != "en"):
|
||||
# gif_message = tenor_gifs.random(en_text)
|
||||
# elif(lang == "en"):
|
||||
# gif_message = tenor_gifs.random(txt)
|
||||
# else:
|
||||
# gif_message = tenor_gifs.random(txt)
|
||||
# await message.channel.send(gif_message)
|
||||
write_status_report()
|
||||
|
||||
|
||||
def getAllUsersCount():
|
||||
guilds = client.guilds
|
||||
user_count = 0
|
||||
for g in guilds:
|
||||
user_count += len(g.members)
|
||||
return("Current user count: " + str(user_count))
|
||||
|
||||
|
||||
def write_status_report():
|
||||
global number_of_messages
|
||||
global number_of_sent_messages
|
||||
global number_of_servers
|
||||
global history_dict
|
||||
with open("./status_report.txt", "w") as f:
|
||||
f.write("```")
|
||||
f.write("Status Report: " + datetime.datetime.now().strftime("%m/%d/%Y %H:%M:%S") + "\n")
|
||||
f.write("Number of guilds: " + str(number_of_servers) + "\n")
|
||||
f.write("Number of messages received since last reboot: " + str(number_of_messages) + "\n")
|
||||
f.write("Number of messages sent since last reboot: " + str(number_of_sent_messages) + "\n")
|
||||
f.write("Number of failed responses since last reboot: " + str(number_of_messages - number_of_sent_messages) + "\n")
|
||||
f.write("Number of guilds in memory: " + str(len(history_dict)) + "\n```")
|
||||
f.close()
|
||||
|
||||
def run_chat():
|
||||
# Parse parameters
|
||||
global translator
|
||||
|
||||
global num_samples
|
||||
global max_turns_history
|
||||
global model
|
||||
global tokenizer
|
||||
global mmi_model
|
||||
global mmi_tokenizer
|
||||
global config
|
||||
global number_of_messages
|
||||
global number_of_sent_messages
|
||||
global number_of_servers
|
||||
global history_dict
|
||||
global token
|
||||
|
||||
num_samples = config.getint('decoder', 'num_samples')
|
||||
max_turns_history = config.getint('decoder', 'max_turns_history')
|
||||
|
||||
logger.info("Running the chatbot...")
|
||||
turns = []
|
||||
loop = asyncio.get_event_loop()
|
||||
task1 = loop.create_task(client.start(token))
|
||||
gathered = asyncio.gather(task1, loop=loop)
|
||||
loop.run_until_complete(gathered)
|
||||
|
||||
|
||||
|
||||
def get_prescripted_lines(filepath):
|
||||
lines = []
|
||||
with open(filepath, "r") as f:
|
||||
for line in f:
|
||||
lines.append(line)
|
||||
return lines
|
||||
global static_history
|
||||
static_history = get_prescripted_lines("./constant_thoughts.txt")
|
||||
def get_response(prompt, channel_id, do_infinite):
|
||||
global translator
|
||||
|
||||
global num_samples
|
||||
global max_turns_history
|
||||
global model
|
||||
global tokenizer
|
||||
global mmi_model
|
||||
global mmi_tokenizer
|
||||
global config
|
||||
global history_dict
|
||||
if max_turns_history == 0:
|
||||
# If you still get different responses then set seed
|
||||
turns = []
|
||||
|
||||
# A single turn is a group of user messages and bot responses right after
|
||||
turn = {
|
||||
'user_messages': [],
|
||||
'bot_messages': []
|
||||
}
|
||||
str_channel_id = str(channel_id)
|
||||
#turns.append(turn)
|
||||
turn['user_messages'].append(prompt)
|
||||
if not channel_id in history_dict:
|
||||
history_dict[channel_id] = []
|
||||
|
||||
|
||||
history_dict[channel_id].append(turn)
|
||||
# Merge turns into a single history (don't forget EOS token)
|
||||
history = ""
|
||||
from_index = max(len(history_dict[channel_id])-max_turns_history-1, 0) if max_turns_history >= 0 else 0
|
||||
for message in static_history:
|
||||
history += message + tokenizer.eos_token
|
||||
for i in range(len(history_dict[channel_id])):
|
||||
if(i >= from_index):
|
||||
turn2 = history_dict[channel_id][i]
|
||||
else:
|
||||
continue
|
||||
# Each turn begings with user messages
|
||||
for message in turn2['user_messages']:
|
||||
history += message + tokenizer.eos_token
|
||||
for message in turn2['bot_messages']:
|
||||
history += message + tokenizer.eos_token
|
||||
|
||||
# Generate bot messages
|
||||
bot_messages = generate_response(
|
||||
model,
|
||||
tokenizer,
|
||||
history,
|
||||
config,
|
||||
mmi_model=mmi_model,
|
||||
mmi_tokenizer=mmi_tokenizer
|
||||
)
|
||||
if num_samples == 1:
|
||||
bot_message = bot_messages[0]
|
||||
else:
|
||||
# TODO: Select a message that is the most appropriate given the context
|
||||
# This way you can avoid loops
|
||||
bot_message = random.choice(bot_messages)
|
||||
turn['bot_messages'].append(bot_message)
|
||||
#print(history_dict)
|
||||
return bot_message
|
||||
|
||||
def main():
|
||||
global translator
|
||||
|
||||
global num_samples
|
||||
global max_turns_history
|
||||
global model
|
||||
global tokenizer
|
||||
global mmi_model
|
||||
global mmi_tokenizer
|
||||
global config
|
||||
global number_of_messages
|
||||
global number_of_sent_messages
|
||||
global number_of_servers
|
||||
global history_dict
|
||||
global token
|
||||
|
||||
token = "TOKEN_GOES_HERE" # Replace TOKEN_GOES_HERE with your discord API bot token!
|
||||
history_dict = {}
|
||||
# Script arguments can include path of the config
|
||||
arg_parser = argparse.ArgumentParser()
|
||||
arg_parser.add_argument('--config', type=str, default="chatbot.cfg")
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
# Read the config
|
||||
config = configparser.ConfigParser(allow_no_value=True)
|
||||
with open(args.config) as f:
|
||||
config.read_file(f)
|
||||
|
||||
# Download and load main model
|
||||
target_folder_name = download_model_folder(config)
|
||||
model, tokenizer = load_model(target_folder_name, config)
|
||||
|
||||
# Download and load reverse model
|
||||
use_mmi = config.getboolean('model', 'use_mmi')
|
||||
if use_mmi:
|
||||
mmi_target_folder_name = download_reverse_model_folder(config)
|
||||
mmi_model, mmi_tokenizer = load_model(mmi_target_folder_name, config)
|
||||
else:
|
||||
mmi_model = None
|
||||
mmi_tokenizer = None
|
||||
|
||||
# Run chatbot with GPT-2
|
||||
run_chat()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
# Copyright (c) polakowo
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import configparser
|
||||
import argparse
|
||||
import logging
|
||||
import random
|
||||
|
||||
from model import download_model_folder, download_reverse_model_folder, load_model
|
||||
from decoder import generate_response
|
||||
|
||||
# Enable logging
|
||||
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def run_chat(model, tokenizer, config, mmi_model=None, mmi_tokenizer=None):
|
||||
# Parse parameters
|
||||
num_samples = config.getint('decoder', 'num_samples')
|
||||
max_turns_history = config.getint('decoder', 'max_turns_history')
|
||||
|
||||
logger.info("Running the chatbot...")
|
||||
turns = []
|
||||
print("Bot >>>", "Just start texting me. If I'm getting annoying, type \"Bye\". To quit the chat type \"Quit\".")
|
||||
while True:
|
||||
prompt = input("User >>> ")
|
||||
if max_turns_history == 0:
|
||||
# If you still get different responses then set seed
|
||||
turns = []
|
||||
if prompt.lower() == 'bye':
|
||||
print("Bot >>>", "Bye")
|
||||
turns = []
|
||||
continue
|
||||
if prompt.lower() == 'quit':
|
||||
break
|
||||
# A single turn is a group of user messages and bot responses right after
|
||||
turn = {
|
||||
'user_messages': [],
|
||||
'bot_messages': []
|
||||
}
|
||||
turns.append(turn)
|
||||
turn['user_messages'].append(prompt)
|
||||
# Merge turns into a single history (don't forget EOS token)
|
||||
history = ""
|
||||
from_index = max(len(turns)-max_turns_history-1, 0) if max_turns_history >= 0 else 0
|
||||
for turn in turns[from_index:]:
|
||||
# Each turn begings with user messages
|
||||
for message in turn['user_messages']:
|
||||
history += message + tokenizer.eos_token
|
||||
for message in turn['bot_messages']:
|
||||
history += message + tokenizer.eos_token
|
||||
|
||||
# Generate bot messages
|
||||
bot_messages = generate_response(
|
||||
model,
|
||||
tokenizer,
|
||||
history,
|
||||
config,
|
||||
mmi_model=mmi_model,
|
||||
mmi_tokenizer=mmi_tokenizer
|
||||
)
|
||||
if num_samples == 1:
|
||||
bot_message = bot_messages[0]
|
||||
else:
|
||||
# TODO: Select a message that is the most appropriate given the context
|
||||
# This way you can avoid loops
|
||||
bot_message = random.choice(bot_messages)
|
||||
print("Bot >>>", bot_message)
|
||||
turn['bot_messages'].append(bot_message)
|
||||
|
||||
def main():
|
||||
# Script arguments can include path of the config
|
||||
arg_parser = argparse.ArgumentParser()
|
||||
arg_parser.add_argument('--config', type=str, default="chatbot.cfg")
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
# Read the config
|
||||
config = configparser.ConfigParser(allow_no_value=True)
|
||||
with open(args.config) as f:
|
||||
config.read_file(f)
|
||||
|
||||
# Download and load main model
|
||||
target_folder_name = download_model_folder(config)
|
||||
model, tokenizer = load_model(target_folder_name, config)
|
||||
|
||||
# Download and load reverse model
|
||||
use_mmi = config.getboolean('model', 'use_mmi')
|
||||
if use_mmi:
|
||||
mmi_target_folder_name = download_reverse_model_folder(config)
|
||||
mmi_model, mmi_tokenizer = load_model(mmi_target_folder_name, config)
|
||||
else:
|
||||
mmi_model = None
|
||||
mmi_tokenizer = None
|
||||
|
||||
# Run chatbot with GPT-2
|
||||
run_chat(model, tokenizer, config, mmi_model=mmi_model, mmi_tokenizer=mmi_tokenizer)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,172 @@
|
|||
import os
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
from glob import glob
|
||||
import torch
|
||||
import configparser
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
# !pip install transformers==2.3.0
|
||||
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
|
||||
# If you get tensorflow deprecation warnings, run
|
||||
# pip uninstall numpy
|
||||
# pip install numpy==1.16.4
|
||||
|
||||
# Enable logging
|
||||
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Model configuration files
|
||||
CONFIG_FILE = {
|
||||
'small': 'https://convaisharables.blob.core.windows.net/lsp/117M/config.json',
|
||||
'medium': 'https://convaisharables.blob.core.windows.net/lsp/345M/config.json'
|
||||
}
|
||||
VOCAB_FILE = {
|
||||
'small': 'https://convaisharables.blob.core.windows.net/lsp/117M/vocab.json',
|
||||
'medium': 'https://convaisharables.blob.core.windows.net/lsp/345M/vocab.json'
|
||||
}
|
||||
MERGE_FILE = {
|
||||
'small': 'https://convaisharables.blob.core.windows.net/lsp/117M/merges.txt',
|
||||
'medium': 'https://convaisharables.blob.core.windows.net/lsp/345M/merges.txt'
|
||||
}
|
||||
|
||||
# Model files
|
||||
# Note that the model size is roughly half of the GPT model because our model is saved by fp16
|
||||
LSP_MODEL_URL = {
|
||||
'multiref': {
|
||||
'medium_fs': 'https://convaisharables.blob.core.windows.net/lsp/multiref/medium_fs.pkl',
|
||||
'medium_ft': 'https://convaisharables.blob.core.windows.net/lsp/multiref/medium_ft.pkl',
|
||||
'small_fs': 'https://convaisharables.blob.core.windows.net/lsp/multiref/small_fs.pkl',
|
||||
'small_ft': 'https://convaisharables.blob.core.windows.net/lsp/multiref/small_ft.pkl'
|
||||
},
|
||||
'dstc': { # medium_ft.pkl is actually a small model
|
||||
'small_ft': 'https://convaisharables.blob.core.windows.net/lsp/DSTC/medium_ft.pkl'
|
||||
}
|
||||
}
|
||||
|
||||
# The reverse model is predicting the source from the target. This model is used for MMI reranking.
|
||||
# small_reverse.pkl is actually a medium model
|
||||
REVERSE_MODEL_URL = 'https://convaisharables.blob.core.windows.net/lsp/multiref/small_reverse.pkl'
|
||||
|
||||
def http_get(url, temp_file):
|
||||
req = requests.get(url, stream=True)
|
||||
content_length = req.headers.get('Content-Length')
|
||||
total = int(content_length) if content_length is not None else None
|
||||
progress = tqdm(unit="B", total=total)
|
||||
for chunk in req.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
progress.update(len(chunk))
|
||||
temp_file.write(chunk)
|
||||
progress.close()
|
||||
|
||||
|
||||
def download_file(url, folder):
|
||||
if not os.path.exists(folder):
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
|
||||
file_name = os.path.basename(url)
|
||||
if 'pytorch_model.bin' in file_name:
|
||||
file_name = 'pytorch_model.bin'
|
||||
|
||||
if os.path.isfile(os.path.join(folder, file_name)):
|
||||
return
|
||||
|
||||
with open(os.path.join(folder, file_name), 'wb') as f:
|
||||
http_get(url, f)
|
||||
|
||||
|
||||
def download_model_folder(config):
|
||||
# Parse parameters
|
||||
data_folder = config.get('model', 'data_folder')
|
||||
model_size = config.get('model', 'model_size')
|
||||
dataset = config.get('model', 'dataset')
|
||||
from_scratch = config.getboolean('model', 'from_scratch')
|
||||
|
||||
# Create data folder if needed
|
||||
if not os.path.exists(data_folder):
|
||||
os.makedirs(data_folder, exist_ok=True)
|
||||
# Build target folder name (must be unique across all parameter combinations)
|
||||
target_folder_name = model_size + "_" + dataset + ("_fs" if from_scratch else "_ft")
|
||||
target_folder = os.path.join(data_folder, target_folder_name)
|
||||
# Download files
|
||||
logger.info(f"Downloading model files to {target_folder_name}...")
|
||||
download_file(CONFIG_FILE[model_size], target_folder)
|
||||
download_file(VOCAB_FILE[model_size], target_folder)
|
||||
download_file(MERGE_FILE[model_size], target_folder)
|
||||
model_train_type = model_size + ('_fs' if from_scratch else '_ft')
|
||||
if model_train_type not in LSP_MODEL_URL[dataset]:
|
||||
k = ','.join(list(LSP_MODEL_URL[dataset].keys()))
|
||||
raise ValueError(f"'{model_train_type}' not exist for dataset '{dataset}', please choose from [{k}]")
|
||||
download_file(LSP_MODEL_URL[dataset][model_train_type], target_folder)
|
||||
return target_folder_name
|
||||
|
||||
def download_reverse_model_folder(config):
|
||||
# Parse parameters
|
||||
data_folder = config.get('model', 'data_folder')
|
||||
# Only one size is currently supported
|
||||
model_size = 'medium'
|
||||
|
||||
# Create data folder if needed
|
||||
if not os.path.exists(data_folder):
|
||||
os.makedirs(data_folder, exist_ok=True)
|
||||
# Build target folder name (must be unique across all parameter combinations)
|
||||
target_folder_name = model_size + '_reverse'
|
||||
target_folder = os.path.join(data_folder, target_folder_name)
|
||||
# Download files
|
||||
logger.info(f"Downloading model files to {target_folder_name}...")
|
||||
download_file(CONFIG_FILE[model_size], target_folder)
|
||||
download_file(VOCAB_FILE[model_size], target_folder)
|
||||
download_file(MERGE_FILE[model_size], target_folder)
|
||||
download_file(REVERSE_MODEL_URL, target_folder)
|
||||
return target_folder_name
|
||||
|
||||
def load_model(target_folder_name, config):
|
||||
# Parse parameters
|
||||
data_folder = config.get('model', 'data_folder')
|
||||
model_size = config.get('model', 'model_size')
|
||||
no_cuda = config.getboolean('model', 'no_cuda')
|
||||
|
||||
logger.info(f"Loading model from {target_folder_name}...")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
|
||||
# Tokenizer
|
||||
target_folder = os.path.join(data_folder, target_folder_name)
|
||||
tokenizer = GPT2Tokenizer(os.path.join(target_folder, 'vocab.json'), os.path.join(target_folder, 'merges.txt'))
|
||||
# Config
|
||||
config = GPT2Config.from_json_file(os.path.join(target_folder, 'config.json'))
|
||||
# Weights
|
||||
state_dict_path = glob(os.path.join(target_folder, f'*.pkl'))[0]
|
||||
state_dict = torch.load(state_dict_path, map_location=device)
|
||||
if model_size == 'small':
|
||||
for key in list(state_dict.keys()):
|
||||
state_dict[key.replace('module.', '')] = state_dict.pop(key)
|
||||
state_dict['lm_head.weight'] = state_dict['lm_head.decoder.weight']
|
||||
state_dict.pop("lm_head.decoder.weight", None)
|
||||
# Model
|
||||
model = GPT2LMHeadModel(config)
|
||||
model.load_state_dict(state_dict)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
return model, tokenizer
|
||||
|
||||
def main():
|
||||
# Script arguments can include path of the config
|
||||
arg_parser = argparse.ArgumentParser()
|
||||
arg_parser.add_argument('--config', type=str, default="chatbot.cfg")
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
# Read the config
|
||||
config = configparser.ConfigParser(allow_no_value=True)
|
||||
with open(args.config) as f:
|
||||
config.read_file(f)
|
||||
|
||||
# Download main model
|
||||
download_model_folder(config)
|
||||
# Download reverse model
|
||||
use_mmi = config.getboolean('model', 'use_mmi')
|
||||
if use_mmi:
|
||||
download_reverse_model_folder(config)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,220 @@
|
|||
# Copyright (c) polakowo
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# !pip install python-telegram-bot --upgrade
|
||||
from telegram.ext import Updater, CommandHandler, MessageHandler, Filters
|
||||
from telegram import ChatAction, ParseMode
|
||||
from functools import wraps
|
||||
import configparser
|
||||
import argparse
|
||||
import logging
|
||||
import requests
|
||||
from urllib.parse import urlencode
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
import random
|
||||
import re
|
||||
|
||||
from model import download_model_folder, download_reverse_model_folder, load_model
|
||||
from decoder import generate_response
|
||||
|
||||
# Enable logging
|
||||
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# https://github.com/python-telegram-bot/python-telegram-bot/wiki/Code-snippets
|
||||
|
||||
def start_command(update, context):
|
||||
context.chat_data['turns'] = []
|
||||
update.message.reply_text("Just start texting me. Append \"@gif\" for me to generate a GIF. If I'm getting annoying, type \"Bye\"")
|
||||
|
||||
def requests_retry_session(
|
||||
retries=3,
|
||||
backoff_factor=0.3,
|
||||
status_forcelist=(500, 502, 504),
|
||||
session=None,
|
||||
):
|
||||
session = session or requests.Session()
|
||||
retry = Retry(
|
||||
total=retries,
|
||||
read=retries,
|
||||
connect=retries,
|
||||
backoff_factor=backoff_factor,
|
||||
status_forcelist=status_forcelist,
|
||||
)
|
||||
adapter = HTTPAdapter(max_retries=retry)
|
||||
session.mount('http://', adapter)
|
||||
session.mount('https://', adapter)
|
||||
return session
|
||||
|
||||
def translate_message_to_gif(message, config):
|
||||
# https://engineering.giphy.com/contextually-aware-search-giphy-gets-work-specific/
|
||||
params = {
|
||||
'api_key': config.get('chatbot', 'giphy_token'),
|
||||
's': message,
|
||||
'weirdness': config.getint('chatbot', 'giphy_weirdness')
|
||||
}
|
||||
url = "http://api.giphy.com/v1/gifs/translate?" + urlencode(params)
|
||||
response = requests_retry_session().get(url)
|
||||
return response.json()['data']['images']['fixed_height']['url']
|
||||
|
||||
def self_decorator(self, func):
|
||||
"""Passes bot object to func command."""
|
||||
# TODO: Any other ways to pass variables to handlers?
|
||||
def command_func(update, context, *args, **kwargs):
|
||||
return func(self, update, context, *args, **kwargs)
|
||||
return command_func
|
||||
|
||||
def send_action(action):
|
||||
"""Sends `action` while processing func command."""
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def command_func(self, update, context, *args, **kwargs):
|
||||
context.bot.send_chat_action(chat_id=update.effective_message.chat_id, action=action)
|
||||
return func(self, update, context, *args, **kwargs)
|
||||
return command_func
|
||||
return decorator
|
||||
|
||||
send_typing_action = send_action(ChatAction.TYPING)
|
||||
|
||||
def gpt_normalize(txt):
|
||||
txt = re.sub(r"[^A-Za-z0-9()\[\]:,.!?'“”\"]", " ", txt) # remove illegal chars
|
||||
return ' '.join(txt.strip().split()) # remove unnecessary spaces
|
||||
|
||||
@send_typing_action
|
||||
def message(self, update, context):
|
||||
# Parse parameters
|
||||
num_samples = self.config.getint('decoder', 'num_samples')
|
||||
max_turns_history = self.config.getint('decoder', 'max_turns_history')
|
||||
if 'turns' not in context.chat_data:
|
||||
context.chat_data['turns'] = []
|
||||
turns = context.chat_data['turns']
|
||||
|
||||
user_message = update.message.text
|
||||
if user_message.lower() == 'bye':
|
||||
# Restart chat
|
||||
context.chat_data['turns'] = []
|
||||
update.message.reply_text("Bye")
|
||||
return None
|
||||
return_gif = False
|
||||
if '@gif' in user_message:
|
||||
# Return gif
|
||||
return_gif = True
|
||||
user_message = user_message.replace('@gif', '').strip()
|
||||
if max_turns_history == 0:
|
||||
# If you still get different responses then set seed
|
||||
context.chat_data['turns'] = []
|
||||
# A single turn is a group of user messages and bot responses right after
|
||||
turn = {
|
||||
'user_messages': [],
|
||||
'bot_messages': []
|
||||
}
|
||||
turns.append(turn)
|
||||
turn['user_messages'].append(user_message)
|
||||
logger.info(f"{update.effective_message.chat_id} - User >>> {user_message}")
|
||||
# Merge turns into a single history (don't forget EOS token)
|
||||
history = ""
|
||||
from_index = max(len(turns)-max_turns_history-1, 0) if max_turns_history >= 0 else 0
|
||||
for turn in turns[from_index:]:
|
||||
# Each turn begings with user messages
|
||||
for message in turn['user_messages']:
|
||||
history += gpt_normalize(message) + self.tokenizer.eos_token
|
||||
for message in turn['bot_messages']:
|
||||
history += gpt_normalize(message) + self.tokenizer.eos_token
|
||||
|
||||
# Generate bot messages
|
||||
bot_messages = generate_response(
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
history,
|
||||
self.config,
|
||||
mmi_model=self.mmi_model,
|
||||
mmi_tokenizer=self.mmi_tokenizer
|
||||
)
|
||||
if num_samples == 1:
|
||||
bot_message = bot_messages[0]
|
||||
else:
|
||||
# TODO: Select a message that is the most appropriate given the context
|
||||
# This way you can avoid loops
|
||||
bot_message = random.choice(bot_messages)
|
||||
turn['bot_messages'].append(bot_message)
|
||||
logger.info(f"{update.effective_message.chat_id} - Bot >>> {bot_message}")
|
||||
if return_gif:
|
||||
# Return response as GIF
|
||||
gif_url = translate_message_to_gif(bot_message, self.config)
|
||||
context.bot.send_animation(update.effective_message.chat_id, gif_url)
|
||||
else:
|
||||
# Return response as text
|
||||
update.message.reply_text(bot_message)
|
||||
|
||||
|
||||
def error(update, context):
|
||||
logger.warning(context.error)
|
||||
|
||||
class TelegramBot:
|
||||
def __init__(self, model, tokenizer, config, mmi_model=None, mmi_tokenizer=None):
|
||||
logger.info("Initializing the bot...")
|
||||
|
||||
# Set global variables
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.mmi_model = mmi_model
|
||||
self.mmi_tokenizer = mmi_tokenizer
|
||||
self.config = config
|
||||
|
||||
# Set up Telegram bot
|
||||
self.updater = Updater(config.get('chatbot', 'telegram_token'), use_context=True)
|
||||
dp = self.updater.dispatcher
|
||||
|
||||
# on different commands - answer in Telegram
|
||||
# conversation with bot
|
||||
dp.add_handler(MessageHandler(Filters.text, self_decorator(self, message)))
|
||||
|
||||
# chatbot settings
|
||||
dp.add_handler(CommandHandler('start', start_command))
|
||||
|
||||
# log all errors
|
||||
dp.add_error_handler(error)
|
||||
|
||||
def run_chat(self):
|
||||
logger.info("Running the chatbot...")
|
||||
|
||||
# Start the Bot
|
||||
self.updater.start_polling()
|
||||
|
||||
# Run the bot until you press Ctrl-C or the process receives SIGINT,
|
||||
# SIGTERM or SIGABRT. This should be used most of the time, since
|
||||
# start_polling() is non-blocking and will stop the bot gracefully.
|
||||
self.updater.idle()
|
||||
|
||||
def main():
|
||||
# Script arguments can include path of the config
|
||||
arg_parser = argparse.ArgumentParser()
|
||||
arg_parser.add_argument('--config', type=str, default="chatbot.cfg")
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
# Read the config
|
||||
config = configparser.ConfigParser(allow_no_value=True)
|
||||
with open(args.config) as f:
|
||||
config.read_file(f)
|
||||
|
||||
# Download and load main model
|
||||
target_folder_name = download_model_folder(config)
|
||||
model, tokenizer = load_model(target_folder_name, config)
|
||||
|
||||
# Download and load reverse model
|
||||
use_mmi = config.getboolean('model', 'use_mmi')
|
||||
if use_mmi:
|
||||
mmi_target_folder_name = download_reverse_model_folder(config)
|
||||
mmi_model, mmi_tokenizer = load_model(mmi_target_folder_name, config)
|
||||
else:
|
||||
mmi_model = None
|
||||
mmi_tokenizer = None
|
||||
|
||||
# Run Telegram bot
|
||||
bot = TelegramBot(model, tokenizer, config, mmi_model=mmi_model, mmi_tokenizer=mmi_tokenizer)
|
||||
bot.run_chat()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Binary file not shown.
After Width: | Height: | Size: 18 KiB |
|
@ -0,0 +1,8 @@
|
|||
numpy==1.16.4
|
||||
torch==1.2.0
|
||||
transformers==2.3.0
|
||||
python-telegram-bot==12.3.0
|
||||
discord.py==1.2.5
|
||||
googletrans==2.4.0
|
||||
textblob==0.15.3
|
||||
matplotlib==2.0.2
|
Binary file not shown.
After Width: | Height: | Size: 5.5 MiB |
Loading…
Reference in New Issue