m4pl1mp/plugins/maple_plugin.py
[d] 1f3ecfdfc8 Update 'plugins/maple_plugin.py'
renamed to replace the deprecation
2022-02-02 10:06:38 +00:00

357 lines
18 KiB
Python

# -*- coding: utf-8 -*-
import os
import io
import irc3
import requests
from tqdm import tqdm
from glob import glob
import torch
import torch.nn.functional as F
import numpy as np
import logging
import signal
import configparser
import logging
import random
from transformers import GPT2Config,GPT2LMHeadModel,GPT2Tokenizer
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
#import ipdb
###########################################################################################################
###########################################################################################################
@irc3.plugin
class Plugin:
#######################################################################################################
#######################################################################################################
PoolExecutor=ThreadPoolExecutor
#######################################################################################################
#######################################################################################################
terminate=False
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
logger=logging.getLogger(__name__)
#######################################################################################
#######################################################################################
CONFIG_FILE={
'small':'https://convaisharables.blob.core.windows.net/lsp/117M/config.json',
'medium':'https://convaisharables.blob.core.windows.net/lsp/345M/config.json'
}
VOCAB_FILE={
'small':'https://convaisharables.blob.core.windows.net/lsp/117M/vocab.json',
'medium':'https://convaisharables.blob.core.windows.net/lsp/345M/vocab.json'
}
MERGE_FILE={
'small':'https://convaisharables.blob.core.windows.net/lsp/117M/merges.txt',
'medium':'https://convaisharables.blob.core.windows.net/lsp/345M/merges.txt'
}
LSP_MODEL_URL={
'multiref':{
'medium_fs':'https://convaisharables.blob.core.windows.net/lsp/multiref/medium_fs.pkl',
'medium_ft':'https://convaisharables.blob.core.windows.net/lsp/multiref/medium_ft.pkl',
'small_fs':'https://convaisharables.blob.core.windows.net/lsp/multiref/small_fs.pkl',
'small_ft':'https://convaisharables.blob.core.windows.net/lsp/multiref/small_ft.pkl'
},
'dstc':{
'small_ft':'https://convaisharables.blob.core.windows.net/lsp/DSTC/medium_ft.pkl'
}
}
#######################################################################################
#######################################################################################
REVERSE_MODEL_URL='https://convaisharables.blob.core.windows.net/lsp/multiref/small_reverse.pkl'
#######################################################################################
#######################################################################################
PERSONALITY="""
[model]
data_folder=models
model_size=medium
dataset=multiref
from_scratch=False
no_cuda=False
use_mmi=False
[decoder]
seed=0
temperature=0.6474
top_k=40
top_p=0
max_length=128
num_samples=1
max_turns_history=1
[personality]
telegram_token=YOUR_TOKEN_HERE
giphy_token=YOUR_TOKEN_HERE
giphy_weirdness=5
"""
#######################################################################################
#######################################################################################
def __init__(self,bot):
self.maple_message=''
self.PERSONALITY=self.PERSONALITY.format(RND=datetime.now().microsecond)
self.bot=bot
self.delay=0.025
self.cycle=0
CONFIG=io.StringIO(self.PERSONALITY)
self.config=configparser.ConfigParser()
self.config.read_file(CONFIG)
self.target_folder_name=self.download_model_folder(self.config)
self.model,self.tokenizer=self.load_model(self.target_folder_name,self.config)
self.use_mmi=self.config.getboolean('model','use_mmi')
if self.use_mmi:
self.mmi_target_folder_name=self.download_reverse_model_folder(self.config)
self.mmi_model,mmi_tokenizer=self.load_model(self.mmi_target_folder_name,self.config)
else:
self.mmi_model=None
self.mmi_tokenizer=None
loop=self.bot.loop
loop.call_later(self.delay,self.main)
#######################################################################################
#######################################################################################
@irc3.event(irc3.rfc.PRIVMSG)
def on_privmsg_search_for_maple(self, mask=None, target=None, data=None, **kw):
##############################################
if mask.nick == 'maple':
data=data.lower().replace('maple','').strip()
# ##############################################
# if not data.lower().find('maple') > -1:
# return
#data=data.lower().replace('maple','').strip()
##############################################
self.maple_message=data
#######################################################################################
#######################################################################################
def signal_handling(self,signum,frame):
self.terminate=True
#######################################################################################
#######################################################################################
def http_get(self,url,temp_file):
req=requests.get(url,stream=True)
content_length=req.headers.get('Content-Length')
total=int(content_length) if content_length is not None else None
progress=tqdm(unit="B",total=total)
for chunk in req.iter_content(chunk_size=1024):
if chunk:
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
#######################################################################################
#######################################################################################
def download_file(self,url,folder):
if not os.path.exists(folder):
os.makedirs(folder,exist_ok=True)
file_name=os.path.basename(url)
if 'pytorch_model.bin' in file_name:
file_name='pytorch_model.bin'
if os.path.isfile(os.path.join(folder,file_name)):
return
with open(os.path.join(folder,file_name),'wb') as f:
self.http_get(url,f)
#######################################################################################
#######################################################################################
def download_model_folder(self,config):
data_folder=config.get('model','data_folder')
model_size=config.get('model','model_size')
dataset=config.get('model','dataset')
from_scratch=config.getboolean('model','from_scratch')
if not os.path.exists(data_folder):
os.makedirs(data_folder, exist_ok=True)
target_folder_name=model_size+"_"+dataset+("_fs" if from_scratch else "_ft")
target_folder=os.path.join(data_folder,target_folder_name)
self.logger.info(f"Downloading model files to {target_folder_name}...")
self.download_file(self.CONFIG_FILE[model_size],target_folder)
self.download_file(self.VOCAB_FILE[model_size],target_folder)
self.download_file(self.MERGE_FILE[model_size],target_folder)
model_train_type=model_size+('_fs' if from_scratch else '_ft')
if model_train_type not in self.LSP_MODEL_URL[dataset]:
k=','.join(list(self.LSP_MODEL_URL[dataset].keys()))
raise ValueError(f"'{model_train_type}' not exist for dataset '{dataset}', please choose from [{k}]")
self.download_file(self.LSP_MODEL_URL[dataset][model_train_type],target_folder)
return target_folder_name
#######################################################################################
#######################################################################################
def download_reverse_model_folder(self,config):
data_folder=config.get('model','data_folder')
model_size='medium'
if not os.path.exists(data_folder):
os.makedirs(data_folder,exist_ok=True)
target_folder_name=model_size+'_reverse'
target_folder=os.path.join(data_folder,target_folder_name)
self.logger.info(f"Downloading model files to {target_folder_name}...")
self.download_file(self.CONFIG_FILE[model_size],target_folder)
self.download_file(self.VOCAB_FILE[model_size],target_folder)
self.download_file(self.MERGE_FILE[model_size],target_folder)
self.download_file(self.REVERSE_MODEL_URL,target_folder)
return target_folder_name
#######################################################################################
#######################################################################################
def load_model(self,target_folder_name,config):
data_folder=config.get('model','data_folder')
model_size=config.get('model','model_size')
no_cuda=config.getboolean('model', 'no_cuda')
self.logger.info(f"Loading model from {target_folder_name}...")
device=torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
target_folder=os.path.join(data_folder,target_folder_name)
tokenizer=GPT2Tokenizer(os.path.join(target_folder, 'vocab.json'), os.path.join(target_folder, 'merges.txt'))
config=GPT2Config.from_json_file(os.path.join(target_folder, 'config.json'))
state_dict_path=glob(os.path.join(target_folder,f'*.pkl'))[0]
state_dict=torch.load(state_dict_path,map_location=device)
if model_size=='small':
for key in list(state_dict.keys()):
state_dict[key.replace('module.','')]=state_dict.pop(key)
state_dict['lm_head.weight']=state_dict['lm_head.decoder.weight']
state_dict.pop("lm_head.decoder.weight",None)
model=GPT2LMHeadModel(config)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
return model,tokenizer
#######################################################################################
#######################################################################################
def set_seed(self,seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
#######################################################################################
#######################################################################################
def top_k_top_p_filtering(self,logits,top_k=0,top_p=0.0,filter_value=-float('Inf')):
top_k=min(top_k,logits.size(-1))
if top_k>0:
indices_to_remove=logits<torch.topk(logits,top_k)[0][...,-1,None]
logits[indices_to_remove]=filter_value
if top_p>0.0:
sorted_logits,sorted_indices=torch.sort(logits,descending=True)
cumulative_probs=torch.cumsum(F.softmax(sorted_logits,dim=-1),dim=-1)
sorted_indices_to_remove=cumulative_probs>top_p
sorted_indices_to_remove[...,1:]=sorted_indices_to_remove[...,:-1].clone()
sorted_indices_to_remove[...,0]=0
indices_to_remove=sorted_indices_to_remove.scatter(dim=1,index=sorted_indices,src=sorted_indices_to_remove)
logits[indices_to_remove]=filter_value
return logits
#######################################################################################
#######################################################################################
def sample_sequence(self,model,tokenizer,context_ids,config):
no_cuda=config.getboolean('model','no_cuda')
num_samples=config.getint('decoder','num_samples')
max_length=config.getint('decoder','max_length')
temperature=config.getfloat('decoder','temperature')
top_k=config.getint('decoder','top_k')
top_p=config.getfloat('decoder','top_p')
device=torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
context_tensor=torch.tensor(context_ids,dtype=torch.long,device=device)
context_tensor=context_tensor.unsqueeze(0).repeat(num_samples,1)
generated=context_tensor
with torch.no_grad():
while True:
inputs={'input_ids':generated}
outputs=model(**inputs)
next_token_logits=outputs[0][:,-1,:]/(temperature if temperature>0 else 1.)
filtered_logits=self.top_k_top_p_filtering(next_token_logits,top_k=top_k,top_p=top_p)
if temperature==0.0:
next_token=torch.argmax(filtered_logits,dim=-1).unsqueeze(-1)
else:
next_token=torch.multinomial(F.softmax(filtered_logits,dim=-1),num_samples=1)
generated=torch.cat((generated,next_token),dim=1)
if (generated[:,len(context_ids):]==tokenizer.eos_token_id).any(dim=1).all():
break
if generated.shape[1]-len(context_ids)>=max_length:
break
return generated
#######################################################################################
#######################################################################################
def select_using_mmi(self,mmi_model,mmi_tokenizer,candidates,config):
no_cuda=config.getboolean('model','no_cuda')
device=torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
scores=[]
for i,candidate in enumerate(candidates):
context=[]
for response in reversed(candidate):
context.extend(response)
context.append(mmi_tokenizer.eos_token_id)
context_ids=mmi_tokenizer.encode(context)
context_tensor=torch.tensor(context_ids,dtype=torch.long,device=device)
loss,_,_=mmi_model(input_ids=context_tensor,labels=context_tensor)
scores.append(-loss.float())
scores=torch.stack(scores, dim=0)
winner=torch.multinomial(F.softmax(scores,dim=0),num_samples=1).item()
return winner
#######################################################################################
#######################################################################################
def generate_response(self,model,tokenizer,context,config,mmi_model=None,mmi_tokenizer=None):
use_mmi=config.getboolean('model','use_mmi')
num_samples=config.getint('decoder','num_samples')
max_length=config.getint('decoder','max_length')
seed=config.get('decoder','seed')
seed=int(seed) if seed is not None else None
if seed is not None:
self.set_seed(seed)
context_ids=tokenizer.encode(context)
samples=self.sample_sequence(model, tokenizer, context_ids, config)
samples=samples[:, len(context_ids):].tolist()
texts=[]
for sample in samples:
text=tokenizer.decode(sample,clean_up_tokenization_spaces=True)
text=text[: text.find(tokenizer.eos_token)]
texts.append(text)
if use_mmi:
assert(num_samples > 1)
candidates=[context+text for text in texts]
best_i=self.select_using_mmi(mmi_model,mmi_tokenizer,candidates,config)
return [texts[best_i]]
return texts
#######################################################################################
#######################################################################################
def run_chat(self,model,tokenizer,config,mmi_model=None,mmi_tokenizer=None):
num_samples=config.getint('decoder','num_samples')
max_turns_history=config.getint('decoder','max_turns_history')
turns=[]
signal.signal(signal.SIGINT,self.signal_handling)
# print("*** RUNNING ***")
# while True:
# if self.terminate:
# break
config.set('decoder','seed',f'{datetime.now().microsecond}')
s=self.maple_message
if len(s)>1:
prompt=s.strip();del(s)
del(self.maple_message)
print(f'human > {prompt}')
if max_turns_history==0:
turns=[]
turn={
'human_messages':[],
'maple_messages':[]
}
turns.append(turn)
turn['human_messages'].append(prompt)
history=""
from_index=max(len(turns)-max_turns_history-1,0) if max_turns_history>=0 else 0
for turn in turns[from_index:]:
for message in turn['human_messages']:
history+=message+tokenizer.eos_token
for message in turn['maple_messages']:
history+=message+tokenizer.eos_token
maple_messages=self.generate_response(
model,
tokenizer,
history,
config,
mmi_model=mmi_model,
mmi_tokenizer=mmi_tokenizer
)
if num_samples==1:
maple_message=maple_messages[0]
else:
maple_message=random.choice(maple_messages)
turn['maple_messages'].append(maple_message)
print(f'maple > {maple_message}')
print(maple_message)
self.bot.privmsg('#staff',self.bot.emo(maple_message))
print('cya')
#######################################################################################
#######################################################################################
def main(self):
loop=self.bot.loop
loop.call_later(self.delay,self.main)
tasks=[]
task=loop.run_in_executor(None,self.run_chat(self.model,self.tokenizer,self.config,mmi_model=self.mmi_model,mmi_tokenizer=self.mmi_tokenizer))
tasks.append(task)
print('l8r')
###########################################################################################################
###########################################################################################################