470 lines
23 KiB
Python
470 lines
23 KiB
Python
# -*- coding: utf-8 -*-
|
|
import os
|
|
import io
|
|
import irc3
|
|
from difflib import SequenceMatcher
|
|
import requests
|
|
from tqdm import tqdm
|
|
from glob import glob
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
import signal
|
|
import configparser
|
|
import logging
|
|
import random
|
|
from transformers import GPT2Config,GPT2LMHeadModel,GPT2Tokenizer
|
|
from datetime import datetime
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from time import time
|
|
###########################################################################################
|
|
class MESSAGE_HISTORY():
    """Rolling record of the bot's own replies and of what users said.

    Attributes:
        maple_messages: the bot's most recent replies, newest first. The list
            keeps whatever length it was seeded with (see
            ``push_maple_messages``).
        user_messages / user_users: parallel lists; entry *i* of one belongs
            to entry *i* of the other.
        processing: integer initialised to 0 (not read anywhere in this
            class).
        bounce: ``False``, or a dict with ``user``/``message``/``target`` keys
            that ``Plugin.run_chat`` moves onto its work queue.
    """

    #######################################################################################

    def __init__(self):
        # The lists are created per instance. Declaring them on the class
        # would make every MESSAGE_HISTORY object mutate the same three lists.
        self.maple_messages = []
        self.user_messages = []
        self.user_users = []
        self.processing = 0
        self.bounce = False

    #######################################################################################

    def push_maple_messages(self, data):
        """Record *data* as the newest reply and drop the oldest one.

        The list length is preserved, so it behaves as a fixed-size window.
        An empty list simply becomes ``[data]`` instead of raising
        ``IndexError``.
        """
        self.maple_messages = [data] + self.maple_messages[:-1]

    #######################################################################################

    def push_user_messages(self, user, data):
        """Append a user's message, keeping the two parallel lists aligned."""
        self.user_users.append(user)
        self.user_messages.append(data)

    #######################################################################################

    def similar(self, a, b):
        """Return the ``difflib`` similarity ratio of *a* and *b* (0.0 to 1.0)."""
        return SequenceMatcher(None, a, b).ratio()
|
|
###########################################################################################
|
|
@irc3.plugin
class Plugin:
    """irc3 plugin that answers IRC messages with a DialoGPT (GPT-2) model.

    The class-level tables below hold the download locations of the model
    files and the text templates used to configure the bot.
    """

    #######################################################################################

    PoolExecutor = ThreadPoolExecutor

    #######################################################################################

    # Set to True by signal_handling() when SIGINT is received.
    terminate = False

    # NOTE: configures the root logger as a side effect of defining the class.
    logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
    logger = logging.getLogger(__name__)

    #######################################################################################

    # Model hyper-parameter files, keyed by model size.
    CONFIG_FILE = {
        'small': 'https://convaisharables.blob.core.windows.net/lsp/117M/config.json',
        'medium': 'https://convaisharables.blob.core.windows.net/lsp/345M/config.json',
    }

    #######################################################################################

    # Tokenizer vocabulary files, keyed by model size.
    VOCAB_FILE = {
        'small': 'https://convaisharables.blob.core.windows.net/lsp/117M/vocab.json',
        'medium': 'https://convaisharables.blob.core.windows.net/lsp/345M/vocab.json',
    }

    #######################################################################################

    # Tokenizer BPE merge files, keyed by model size.
    MERGE_FILE = {
        'small': 'https://convaisharables.blob.core.windows.net/lsp/117M/merges.txt',
        'medium': 'https://convaisharables.blob.core.windows.net/lsp/345M/merges.txt',
    }

    #######################################################################################

    # Model weights: dataset -> "<size>_fs" (from scratch) / "<size>_ft"
    # (fine-tuned) -> URL. download_model_folder() builds the inner key from
    # the configured model_size and from_scratch flag.
    LSP_MODEL_URL = {
        'multiref': {
            'medium_fs': 'https://convaisharables.blob.core.windows.net/lsp/multiref/medium_fs.pkl',
            'medium_ft': 'https://convaisharables.blob.core.windows.net/lsp/multiref/medium_ft.pkl',
            'small_fs': 'https://convaisharables.blob.core.windows.net/lsp/multiref/small_fs.pkl',
            'small_ft': 'https://convaisharables.blob.core.windows.net/lsp/multiref/small_ft.pkl',
        },
        'dstc': {
            # The key now matches the file it points at. It used to be
            # 'small_ft', which made a "small" configuration download the
            # medium weights and made a "medium" configuration unavailable.
            'medium_ft': 'https://convaisharables.blob.core.windows.net/lsp/DSTC/medium_ft.pkl',
        },
    }

    #######################################################################################

    # NOTE(review): the file is named "small_reverse" while
    # download_reverse_model_folder() pairs it with the 'medium' config --
    # confirm which size these weights really are.
    REVERSE_MODEL_URL = 'https://convaisharables.blob.core.windows.net/lsp/multiref/small_reverse.pkl'

    # Whitespace-separated trigger words: a message containing one of them is
    # answered even when the bot's nick is not mentioned.
    OPINION = """
.drinkin
.smokin
"""

    #######################################################################################

    # Optional static lines prepended to every prompt by run_chat(), one per line.
    WISDOM = """
"""

    #######################################################################################

    # INI text parsed with configparser in __init__.
    PERSONALITY = """
[model]
data_folder=models
model_size=medium
dataset=multiref
from_scratch=True
no_cuda=False
use_mmi=False
[decoder]
seed=0
temperature=0.6474
top_k=40
top_p=0
max_length=128
num_samples=1
max_turns_history=-1
"""
|
|
#######################################################################################
|
|
def __init__(self, bot):
    """Attach the plugin to *bot*, load the models and start polling.

    Downloads (when missing) and loads the forward model, plus the reverse
    MMI model when ``use_mmi`` is enabled, then kicks off ``main``.
    """
    self.bot = bot
    self.bot.history = MESSAGE_HISTORY()

    #############################################

    # Five empty placeholders give push_maple_messages() a fixed five-slot
    # window. They are assigned on the instance so that re-creating the
    # plugin never grows a list shared at class level.
    self.bot.history.maple_messages = [""] * 5

    #############################################

    self.mode = 0                    # 1 = answer every message, 0 = only when addressed
    self.span = 0                    # 1 while the current message contains an OPINION word
    self.epoch_time_last = 0         # when flood protection was last switched on
    self.epoch_time_now = 0
    self.epoch_time_boolean = False  # True while flood protection is active
    self.maple_io = []               # FIFO of messages waiting for a reply

    # The shipped template has no "{RND}" placeholder, so this is currently
    # a no-op; it is kept so a template that uses one keeps working.
    self.PERSONALITY = self.PERSONALITY.format(RND=datetime.now().microsecond)
    self.delay = 0.05                # seconds between two polls of the queue

    self.config = configparser.ConfigParser()
    self.config.read_file(io.StringIO(self.PERSONALITY))

    self.target_folder_name = self.download_model_folder(self.config)
    self.model, self.tokenizer = self.load_model(self.target_folder_name, self.config)

    self.use_mmi = self.config.getboolean('model', 'use_mmi')
    if self.use_mmi:
        self.mmi_target_folder_name = self.download_reverse_model_folder(self.config)
        # Store the tokenizer on the instance: main() reads
        # self.mmi_tokenizer, which was previously left unset on this branch.
        self.mmi_model, self.mmi_tokenizer = self.load_model(self.mmi_target_folder_name, self.config)
    else:
        self.mmi_model = None
        self.mmi_tokenizer = None

    # main() re-schedules itself through loop.call_later(), so one call is
    # enough. Scheduling it a second time here started a duplicate polling
    # chain.
    self.main()
|
|
#######################################################################################
|
|
def signal_handling(self, signum, frame):
    """Signal handler: remember that termination was requested.

    *signum* and *frame* are supplied by the ``signal`` module and are not
    used.
    """
    self.terminate = True
|
|
#######################################################################################
|
|
def http_get(self, url, temp_file):
    """Stream *url* into *temp_file* while showing a progress bar.

    Args:
        url: address to download.
        temp_file: writable binary file object that receives the body.

    Raises:
        requests.HTTPError: when the server answers with a 4xx/5xx status,
            so an error page is never written out as if it were model data.
    """
    # Both context managers release their resource even if the transfer
    # fails half-way through.
    with requests.get(url, stream=True) as req:
        req.raise_for_status()
        content_length = req.headers.get('Content-Length')
        total = int(content_length) if content_length is not None else None
        with tqdm(unit="B", total=total) as progress:
            for chunk in req.iter_content(chunk_size=1024):
                if chunk:  # skip keep-alive chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
|
|
#######################################################################################
|
|
def download_file(self, url, folder):
    """Download *url* into *folder* unless the file is already there.

    The file keeps the URL's base name, except that any name containing
    ``pytorch_model.bin`` is stored as exactly ``pytorch_model.bin``.

    The body is written to ``<name>.part`` and renamed only after the
    transfer succeeded. A failed or interrupted download therefore never
    leaves a truncated file that later runs would mistake for a complete one.
    Whatever ``http_get`` raises is propagated.
    """
    os.makedirs(folder, exist_ok=True)

    file_name = os.path.basename(url)
    if 'pytorch_model.bin' in file_name:
        file_name = 'pytorch_model.bin'

    destination = os.path.join(folder, file_name)
    if os.path.isfile(destination):
        return

    partial = destination + '.part'
    try:
        with open(partial, 'wb') as f:
            self.http_get(url, f)
    except BaseException:
        # Also covers Ctrl-C: discard the fragment, then re-raise unchanged.
        try:
            os.remove(partial)
        except OSError:
            pass
        raise
    os.replace(partial, destination)
|
|
#######################################################################################
|
|
def download_model_folder(self, config):
    """Fetch every file of the configured forward model.

    Reads ``data_folder``, ``model_size``, ``dataset`` and ``from_scratch``
    from the ``[model]`` section of *config*.

    Returns:
        The name (not the full path) of the folder inside ``data_folder``
        that holds the model files.

    Raises:
        ValueError: when no weights are listed for the requested size and
            training type in the chosen dataset.
    """
    data_folder = config.get('model', 'data_folder')
    model_size = config.get('model', 'model_size')
    dataset = config.get('model', 'dataset')
    from_scratch = config.getboolean('model', 'from_scratch')

    if not os.path.exists(data_folder):
        os.makedirs(data_folder, exist_ok=True)

    suffix = "_fs" if from_scratch else "_ft"
    target_folder_name = f"{model_size}_{dataset}{suffix}"
    target_folder = os.path.join(data_folder, target_folder_name)
    self.logger.info(f"Downloading model files to {target_folder_name}...")

    # Config, vocabulary and merges only depend on the model size.
    for url_table in (self.CONFIG_FILE, self.VOCAB_FILE, self.MERGE_FILE):
        self.download_file(url_table[model_size], target_folder)

    model_train_type = model_size + suffix
    available = self.LSP_MODEL_URL[dataset]
    if model_train_type not in available:
        k = ','.join(available.keys())
        raise ValueError(f"'{model_train_type}' not exist for dataset '{dataset}', please choose from [{k}]")

    self.download_file(available[model_train_type], target_folder)
    return target_folder_name
|
|
#######################################################################################
|
|
def download_reverse_model_folder(self, config):
    """Fetch every file of the reverse model used for MMI re-ranking.

    Only ``data_folder`` is read from *config*; the model size is fixed to
    ``'medium'``.

    Returns:
        The name (not the full path) of the folder inside ``data_folder``
        that holds the reverse model files.
    """
    data_folder = config.get('model', 'data_folder')
    model_size = 'medium'

    if not os.path.exists(data_folder):
        os.makedirs(data_folder, exist_ok=True)

    target_folder_name = f'{model_size}_reverse'
    target_folder = os.path.join(data_folder, target_folder_name)
    self.logger.info(f"Downloading model files to {target_folder_name}...")

    urls = (
        self.CONFIG_FILE[model_size],
        self.VOCAB_FILE[model_size],
        self.MERGE_FILE[model_size],
        self.REVERSE_MODEL_URL,
    )
    for url in urls:
        self.download_file(url, target_folder)

    return target_folder_name
|
|
#######################################################################################
|
|
def load_model(self, target_folder_name, config):
    """Build the GPT-2 model and tokenizer stored in *target_folder_name*.

    Args:
        target_folder_name: folder name inside the configured
            ``data_folder``, as returned by the download helpers.
        config: parsed bot configuration (``[model]`` section is read).

    Returns:
        ``(model, tokenizer)``; the model is on the selected device and in
        evaluation mode.

    Raises:
        FileNotFoundError: when the folder contains no ``*.pkl`` weights.
    """
    data_folder = config.get('model', 'data_folder')
    model_size = config.get('model', 'model_size')
    no_cuda = config.getboolean('model', 'no_cuda')

    self.logger.info(f"Loading model from {target_folder_name}...")
    device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
    target_folder = os.path.join(data_folder, target_folder_name)

    tokenizer = GPT2Tokenizer(os.path.join(target_folder, 'vocab.json'), os.path.join(target_folder, 'merges.txt'))
    # Named separately so the bot configuration passed in as `config` is
    # not shadowed by the network configuration.
    model_config = GPT2Config.from_json_file(os.path.join(target_folder, 'config.json'))

    # Sorted so the choice is deterministic should several files be present.
    weight_files = sorted(glob(os.path.join(target_folder, '*.pkl')))
    if not weight_files:
        raise FileNotFoundError(f"no '*.pkl' model weights found in {target_folder}")

    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only load weights obtained from a trusted source.
    state_dict = torch.load(weight_files[0], map_location=device)

    if model_size == 'small':
        # Strip the 'module.' prefix from the parameter names.
        for key in list(state_dict.keys()):
            state_dict[key.replace('module.', '')] = state_dict.pop(key)

    # GPT2LMHeadModel expects the output projection under 'lm_head.weight'.
    # NOTE(review): the source indentation was lost, so it is unclear
    # whether this rename was meant for the 'small' size only. Renaming
    # whenever the old key is present is correct in both readings.
    if 'lm_head.decoder.weight' in state_dict:
        state_dict['lm_head.weight'] = state_dict.pop('lm_head.decoder.weight')

    model = GPT2LMHeadModel(model_config)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, tokenizer
|
|
#######################################################################################
|
|
def set_seed(self, seed):
    """Seed Python's, NumPy's and PyTorch's random generators with *seed*."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
|
|
#######################################################################################
|
|
def top_k_top_p_filtering(self, logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask logits outside the top-k set and/or outside the top-p nucleus.

    Args:
        logits: tensor whose last dimension is the vocabulary. Any leading
            shape is accepted, including a plain 1-D vector.
        top_k: keep only the *top_k* largest logits (0 disables the filter).
        top_p: keep the smallest set of most likely tokens whose cumulative
            probability exceeds *top_p* (0.0 disables the filter).
        filter_value: value written over the masked entries.

    Returns:
        *logits* itself. The tensor is modified in place.
    """
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Everything below the k-th best score of its row is discarded.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right by one so the token that crosses the threshold is
        # kept, and always keep the most likely token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order. Working
        # on the last dimension (instead of a hard-coded dim=1) makes this
        # valid for 1-D input as well as batched input.
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value

    return logits
|
|
#######################################################################################
|
|
def sample_sequence(self, model, tokenizer, context_ids, config):
    """Generate continuations of *context_ids*, one token at a time.

    ``num_samples`` copies of the context are extended in parallel until
    every copy contains an end-of-sequence token in its generated part, or
    until ``max_length`` new tokens were produced.

    Returns:
        Tensor with one row per sample: the context followed by the
        generated tokens.
    """
    no_cuda = config.getboolean('model', 'no_cuda')
    num_samples = config.getint('decoder', 'num_samples')
    max_length = config.getint('decoder', 'max_length')
    temperature = config.getfloat('decoder', 'temperature')
    top_k = config.getint('decoder', 'top_k')
    top_p = config.getfloat('decoder', 'top_p')

    device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
    prompt_length = len(context_ids)
    greedy = temperature == 0.0
    # A non-positive temperature leaves the logits unscaled.
    divisor = temperature if temperature > 0 else 1.

    prompt = torch.tensor(context_ids, dtype=torch.long, device=device)
    generated = prompt.unsqueeze(0).repeat(num_samples, 1)

    with torch.no_grad():
        finished = False
        while not finished:
            # Only the prediction for the last position matters.
            last_logits = model(input_ids=generated)[0][:, -1, :] / divisor
            candidates = self.top_k_top_p_filtering(last_logits, top_k=top_k, top_p=top_p)

            if greedy:
                next_token = torch.argmax(candidates, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(candidates, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)

            new_tokens = generated[:, prompt_length:]
            every_row_ended = (new_tokens == tokenizer.eos_token_id).any(dim=1).all()
            finished = bool(every_row_ended) or new_tokens.shape[1] >= max_length

    return generated
|
|
#######################################################################################
|
|
def select_using_mmi(self, mmi_model, mmi_tokenizer, candidates, config):
    """Pick one candidate index, favouring low loss under the reverse model.

    Each candidate is walked back to front; every element is followed by
    the tokenizer's end-of-sequence id, and the result is scored by
    *mmi_model*. The winner is sampled from the softmax of the negated
    losses.

    NOTE(review): ``generate_response`` passes plain strings as candidates,
    so the walk below is over single characters and the sequence handed to
    ``encode`` mixes characters with an integer id. This looks like it was
    written for lists of turns -- confirm before enabling ``use_mmi``.

    Returns:
        Index into *candidates* of the selected entry.
    """
    no_cuda = config.getboolean('model', 'no_cuda')
    device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")

    negated_losses = []
    for candidate in candidates:
        reversed_context = []
        for response in reversed(candidate):
            reversed_context.extend(response)
            reversed_context.append(mmi_tokenizer.eos_token_id)

        token_ids = mmi_tokenizer.encode(reversed_context)
        token_tensor = torch.tensor(token_ids, dtype=torch.long, device=device)
        # The model is expected to return exactly three values, loss first.
        loss, _, _ = mmi_model(input_ids=token_tensor, labels=token_tensor)
        negated_losses.append(-loss.float())

    probabilities = F.softmax(torch.stack(negated_losses, dim=0), dim=0)
    return torch.multinomial(probabilities, num_samples=1).item()
|
|
#######################################################################################
|
|
def generate_response(self, model, tokenizer, context, config, mmi_model=None, mmi_tokenizer=None):
    """Generate reply text(s) for the prompt *context*.

    Args:
        model, tokenizer: forward model and its tokenizer.
        context: prompt string (history joined with end-of-sequence tokens).
        config: parsed bot configuration.
        mmi_model, mmi_tokenizer: reverse model used when ``use_mmi`` is on.

    Returns:
        A list of strings: ``num_samples`` replies, or a single re-ranked
        reply when ``use_mmi`` is enabled.
    """
    use_mmi = config.getboolean('model', 'use_mmi')
    num_samples = config.getint('decoder', 'num_samples')

    # The seed is optional: a missing or empty option leaves the random
    # generators untouched.
    seed = config.get('decoder', 'seed', fallback=None)
    if seed not in (None, ''):
        self.set_seed(int(seed))

    context_ids = tokenizer.encode(context)
    samples = self.sample_sequence(model, tokenizer, context_ids, config)
    samples = samples[:, len(context_ids):].tolist()  # drop the prompt tokens

    texts = []
    for sample in samples:
        text = tokenizer.decode(sample, clean_up_tokenization_spaces=True)
        # Cut at the end-of-sequence marker only when there is one.
        # str.find() returns -1 for "not found", and slicing with that
        # value used to chop the last character off unterminated replies.
        eos_at = text.find(tokenizer.eos_token)
        if eos_at != -1:
            text = text[:eos_at]
        texts.append(text)

    if use_mmi:
        assert num_samples > 1, "use_mmi needs num_samples > 1 to have candidates to choose from"
        candidates = [context + text for text in texts]
        best_i = self.select_using_mmi(mmi_model, mmi_tokenizer, candidates, config)
        return [texts[best_i]]

    return texts
|
|
#######################################################################################
|
|
@irc3.event(irc3.rfc.PRIVMSG)
def on_privmsg_search_for_maple(self, mask=None, target=None, data=None, **kw):
    """Queue an incoming message for the model when the bot should answer it.

    A message is queued when it mentions the bot's nick, contains one of the
    ``OPINION`` trigger words, or while "hang out" mode is on. Queueing more
    than five messages empties the queue and turns on a 30 second flood
    protection during which every message is ignored.
    """
    # IRC nicks are case-insensitive and the message text is lower-cased
    # below, so every nick comparison is done on lower-cased values.
    bot_nick = self.bot.config["nick"].lower()
    sender = mask.nick.lower()

    ##############################################

    if sender == bot_nick or sender == 'nickserv':
        print('returning, message data from bot not user')
        return

    ##############################################

    if self.epoch_time_boolean:
        print('[ checking flood protection status ]')
        if int(time()) - self.epoch_time_last >= 30:
            self.epoch_time_boolean = False
            print('[ turned off flood protection ]')
        else:
            # NOTE(review): the source indentation was lost; the return is
            # assumed to belong to this branch only, so that the message
            # which ends the protection is still handled.
            print('[ flood protection still on ]')
            return

    ##############################################

    data = data.strip().lower()
    addressed = bot_nick in data

    ##############################################

    triggers = set(self.OPINION.split())
    self.span = 1 if any(word in triggers for word in data.split()) else 0

    ##############################################

    # The operator "d" can switch between answering everything (mode 1)
    # and answering only when addressed (mode 0).
    if mask.nick == "d" and addressed:
        if 'hang with us' in data:
            msg = "ok, i'll hang out for a bit"
            self.bot.privmsg(target, self.bot.emo(msg))
            self.mode = 1
        if 'leave us alone' in data:
            msg = "ok, gotta go"
            self.bot.privmsg(target, self.bot.emo(msg))
            self.mode = 0

    ##############################################

    if self.mode == 0 and not addressed and self.span == 0:
        return

    ##############################################

    self.span = 0

    ##############################################

    self.maple_io.append({'user': mask.nick, 'message': data, 'target': target})
    if len(self.maple_io) > 5:
        # Too much pending work: drop it all and go quiet for a while.
        self.maple_io = []
        self.epoch_time_now = int(time())
        self.epoch_time_last = self.epoch_time_now
        self.epoch_time_boolean = True
        msg = f"kind of busy at the moment {mask.nick}, i'll be right back"
        print('[ turned on flood protection ]')
        self.bot.privmsg(target, msg)
|
|
#######################################################################################
|
|
def run_chat(self, model, tokenizer, config, mmi_model=None, mmi_tokenizer=None):
    """Answer the oldest queued message, if there is one.

    Takes one entry from ``self.maple_io``, builds the prompt, asks the
    model for a reply and sends it to the channel. A reply that fails one of
    the sanity checks is not sent; the message is put back on the queue with
    a suffix so a later call tries again.

    Returns:
        ``self.exit_strategy`` in every case.
    """
    num_samples = config.getint('decoder', 'num_samples')
    max_turns_history = config.getint('decoder', 'max_turns_history')
    turns = []

    # signal.signal() only works from the main thread.
    signal.signal(signal.SIGINT, self.signal_handling)
    # A fresh seed for every reply.
    config.set('decoder', 'seed', f'{datetime.now().microsecond}')

    # A "bounced" message (a dict parked on the history object) is moved
    # onto the queue. A malformed one is ignored, as before, but unrelated
    # errors are no longer hidden by a bare except.
    try:
        bounce = self.bot.history.bounce
        if not isinstance(bounce, bool):
            USER = bounce['user']
            MESSAGE = bounce['message']
            TARGET = bounce['target']
            print(f"<<< mapleai: processing {TARGET} {USER} message: {MESSAGE}")
            self.maple_io.append({'user': USER, 'message': MESSAGE, 'target': TARGET})
            self.bot.history.bounce = False
    except (AttributeError, KeyError, TypeError):
        pass

    # Nothing queued is the normal case on most polls, so it is tested
    # directly instead of being detected through an exception.
    if not self.maple_io:
        return self.exit_strategy

    queued = self.maple_io.pop(0)
    try:
        USER = queued['user']
        MESSAGE = queued['message'].strip()
        TARGET = queued['target']
    except (AttributeError, KeyError, TypeError):
        return self.exit_strategy

    print(f'human > {MESSAGE}')

    # NOTE(review): `turns` is rebuilt on every call, so no conversation
    # history survives between calls and max_turns_history has little effect.
    turn = {
        'human_messages': [],
        'maple_messages': []
    }
    turns.append(turn)
    turn['human_messages'].append(f'{USER} {MESSAGE}')

    from_index = max(len(turns) - max_turns_history - 1, 0) if max_turns_history >= 0 else 0

    # Static prompt lines: drop the first empty line, trim the others.
    wisdom_lines = self.WISDOM.splitlines()
    if '' in wisdom_lines:
        wisdom_lines.remove('')
    prompt_parts = [line.strip() for line in wisdom_lines]
    for past_turn in turns[from_index:]:
        prompt_parts.extend(past_turn['human_messages'])
        prompt_parts.extend(past_turn['maple_messages'])
    # Every message is terminated by the end-of-sequence token.
    history = ''.join(part + tokenizer.eos_token for part in prompt_parts)

    maple_messages = self.generate_response(
        model,
        tokenizer,
        history,
        config,
        mmi_model=mmi_model,
        mmi_tokenizer=mmi_tokenizer
    )
    if num_samples == 1:
        maple_message = maple_messages[0]
    else:
        maple_message = random.choice(maple_messages)
    turn['maple_messages'].append(maple_message)

    ################################################################################### REPROCESSOR SOF

    def reject(reason):
        # Put the message back on the queue for another attempt.
        # NOTE(review): the suffix becomes part of the next prompt and is
        # appended again on every retry, with no retry limit -- consider a
        # neutral wording and a cap.
        self.maple_io.append({'user': USER, 'message': f'{MESSAGE} are you retarded', 'target': TARGET})
        print(f'maple - logic ! rejected // {reason}')
        return self.exit_strategy

    # SIMILARITY
    if any(self.bot.history.similar(maple_message, str(previous)) > 0.9
           for previous in self.bot.history.maple_messages):
        return reject('maple similarity - repeat of previous response')

    ###################################################################################

    # MOCK / DUPE
    if self.bot.history.similar(maple_message, MESSAGE) > 0.9:
        return reject('human mock - maple response same as human')

    ###################################################################################

    words = maple_message.split()
    word_count = len(words)
    unique_count = len(set(words))

    # GPT LOOP GLITCH
    if unique_count < int(word_count / 2):
        return reject('gpt loop glitch - reiterating same thing in multiples')

    ###################################################################################

    # LIMITED RESPONSE
    # NOTE(review): this tests the number of *distinct* words, as the
    # original did. The original recomputed the total word count right
    # before the test without using it -- confirm which was intended.
    if unique_count < 3:
        return reject('limited response - skip an unfinished token chain')

    ###################################################################################

    self.bot.history.push_maple_messages(maple_message)

    ################################################################################### REPROCESSOR EOF

    print(f'maple > {maple_message}')
    self.bot.privmsg(TARGET, f'\x02\x0302{USER:}\x0F\x02\x0304 ▶ \x0F{maple_message.lower()}')
    return self.exit_strategy
|
|
#######################################################################################
|
|
def main(self):
    """Poll the message queue once and schedule the next poll.

    NOTE(review): ``run_chat`` is *called* here, on the event-loop thread;
    only the callable it returns (``exit_strategy``, a no-op) is handed to
    the executor. Model inference therefore blocks the loop. Passing
    ``run_chat`` itself to the executor is not a drop-in fix, because
    ``run_chat`` installs a signal handler, which is only possible from the
    main thread.
    """
    event_loop = self.bot.loop
    # Re-arm first, so polling continues even if the call below raises.
    event_loop.call_later(self.delay, self.main)

    after_chat = self.run_chat(
        self.model,
        self.tokenizer,
        self.config,
        mmi_model=self.mmi_model,
        mmi_tokenizer=self.mmi_tokenizer,
    )
    # The returned future is not kept; nothing awaits it.
    event_loop.run_in_executor(None, after_chat)
|
|
#######################################################################################
|
|
def exit_strategy(self):
    """Do nothing.

    ``run_chat`` returns this method and ``main`` submits it to the
    executor, so it only has to be callable.
    """
    return None
|
|
###########################################################################################
|