# m4pl1mp/plugins/maple_plugin.py
# -*- coding: utf-8 -*-
import os
import io
import irc3
import requests
from tqdm import tqdm
from glob import glob
import torch
import torch.nn.functional as F
import numpy as np
import signal
import configparser
import logging
import random
from transformers import GPT2Config,GPT2LMHeadModel,GPT2Tokenizer
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from time import time
###########################################################################################
@irc3.plugin
class Plugin:
#######################################################################################
PoolExecutor=ThreadPoolExecutor
#######################################################################################
terminate=False
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
logger=logging.getLogger(__name__)
#######################################################################################
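    # tokenizer/config/checkpoint URLs for the published DialoGPT ("LSP")
    # conversational models, keyed by model size and training dataset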
CONFIG_FILE={
'small':'https://convaisharables.blob.core.windows.net/lsp/117M/config.json',
'medium':'https://convaisharables.blob.core.windows.net/lsp/345M/config.json'
}
VOCAB_FILE={
'small':'https://convaisharables.blob.core.windows.net/lsp/117M/vocab.json',
'medium':'https://convaisharables.blob.core.windows.net/lsp/345M/vocab.json'
}
MERGE_FILE={
'small':'https://convaisharables.blob.core.windows.net/lsp/117M/merges.txt',
'medium':'https://convaisharables.blob.core.windows.net/lsp/345M/merges.txt'
}
LSP_MODEL_URL={
'multiref':{
'medium_fs':'https://convaisharables.blob.core.windows.net/lsp/multiref/medium_fs.pkl',
'medium_ft':'https://convaisharables.blob.core.windows.net/lsp/multiref/medium_ft.pkl',
'small_fs':'https://convaisharables.blob.core.windows.net/lsp/multiref/small_fs.pkl',
'small_ft':'https://convaisharables.blob.core.windows.net/lsp/multiref/small_ft.pkl'
},
'dstc':{
            # the DSTC checkpoint published at this URL is the medium fine-tuned model
            'medium_ft':'https://convaisharables.blob.core.windows.net/lsp/DSTC/medium_ft.pkl'
}
}
#######################################################################################
REVERSE_MODEL_URL='https://convaisharables.blob.core.windows.net/lsp/multiref/small_reverse.pkl'
#######################################################################################
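    # trigger keywords: a message containing any of these can draw a reply
    # even when the bot is not addressed by nick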
OPINION="""
love
hate
kiss
kill
hack
crack
fuck
suck
smoke
drink
exploit
sploit
0day
malware
network
nft
fhj
fuckholejones
exceptions
upset
happy
music
food
hungry
rap
song
computer
email
lake
"""
#######################################################################################
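    # persona lines prepended to every dialogue history as static context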
WISDOM="""
my name is maple
i am a female
i am a beautiful woman
i am not a male
i am not a guy
down to fuck
any furpiles?
on jones
exceptions thrown
slow pumpers
fuck the ops
cut of your jib
^c's up
watch the way i walk
watch the way i talk
allegedly drug dealing
accidentally done drugs
nfts are a crypto scam
nfts are nice fucking titties
"""
#######################################################################################
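    # ini-style model/decoder settings parsed with configparser in __init__;
    # {RND} is filled with the current microsecond so each run seeds differently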
PERSONALITY="""
[model]
data_folder=models
model_size=medium
dataset=multiref
from_scratch=False
no_cuda=False
use_mmi=False
[decoder]
seed={RND}
temperature=0.6474
top_k=40
top_p=0
max_length=128
num_samples=1
max_turns_history=1
"""
#######################################################################################
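    # download and load the model once, then start the polling loop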
def __init__(self,bot):
self.mode=0
self.span=0
self.epoch_time_last=0
self.epoch_time_now=0
self.epoch_time_boolean=False
self.maple_io=[]
self.PERSONALITY=self.PERSONALITY.format(RND=datetime.now().microsecond)
self.bot=bot
self.delay=0.05
CONFIG=io.StringIO(self.PERSONALITY)
self.config=configparser.ConfigParser()
self.config.read_file(CONFIG)
self.target_folder_name=self.download_model_folder(self.config)
self.model,self.tokenizer=self.load_model(self.target_folder_name,self.config)
self.use_mmi=self.config.getboolean('model','use_mmi')
if self.use_mmi:
self.mmi_target_folder_name=self.download_reverse_model_folder(self.config)
            self.mmi_model,self.mmi_tokenizer=self.load_model(self.mmi_target_folder_name,self.config)
else:
self.mmi_model=None
self.mmi_tokenizer=None
        # main() reschedules itself via loop.call_later, so a single call here
        # starts the polling loop; scheduling it again would run two loops
        self.main()
#######################################################################################
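    # SIGINT handler: flags termination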
def signal_handling(self,signum,frame):
self.terminate=True
#######################################################################################
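    # stream a URL into an open file handle with a tqdm progress bar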
def http_get(self,url,temp_file):
req=requests.get(url,stream=True)
content_length=req.headers.get('Content-Length')
total=int(content_length) if content_length is not None else None
progress=tqdm(unit="B",total=total)
for chunk in req.iter_content(chunk_size=1024):
if chunk:
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
#######################################################################################
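    # fetch one artifact into the given folder, skipping files already on disk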
def download_file(self,url,folder):
if not os.path.exists(folder):
os.makedirs(folder,exist_ok=True)
file_name=os.path.basename(url)
if 'pytorch_model.bin' in file_name:
file_name='pytorch_model.bin'
if os.path.isfile(os.path.join(folder,file_name)):
return
with open(os.path.join(folder,file_name),'wb') as f:
self.http_get(url,f)
#######################################################################################
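    # make sure config/vocab/merges and the requested checkpoint exist locally;
    # returns the folder name, e.g. medium_multiref_ft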
def download_model_folder(self,config):
data_folder=config.get('model','data_folder')
model_size=config.get('model','model_size')
dataset=config.get('model','dataset')
from_scratch=config.getboolean('model','from_scratch')
if not os.path.exists(data_folder):
os.makedirs(data_folder, exist_ok=True)
target_folder_name=model_size+"_"+dataset+("_fs" if from_scratch else "_ft")
target_folder=os.path.join(data_folder,target_folder_name)
self.logger.info(f"Downloading model files to {target_folder_name}...")
self.download_file(self.CONFIG_FILE[model_size],target_folder)
self.download_file(self.VOCAB_FILE[model_size],target_folder)
self.download_file(self.MERGE_FILE[model_size],target_folder)
model_train_type=model_size+('_fs' if from_scratch else '_ft')
if model_train_type not in self.LSP_MODEL_URL[dataset]:
k=','.join(list(self.LSP_MODEL_URL[dataset].keys()))
            raise ValueError(f"'{model_train_type}' does not exist for dataset '{dataset}', please choose from [{k}]")
self.download_file(self.LSP_MODEL_URL[dataset][model_train_type],target_folder)
return target_folder_name
#######################################################################################
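    # same as above, but for the reverse model used for MMI reranking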
def download_reverse_model_folder(self,config):
data_folder=config.get('model','data_folder')
model_size='medium'
if not os.path.exists(data_folder):
os.makedirs(data_folder,exist_ok=True)
target_folder_name=model_size+'_reverse'
target_folder=os.path.join(data_folder,target_folder_name)
self.logger.info(f"Downloading model files to {target_folder_name}...")
self.download_file(self.CONFIG_FILE[model_size],target_folder)
self.download_file(self.VOCAB_FILE[model_size],target_folder)
self.download_file(self.MERGE_FILE[model_size],target_folder)
self.download_file(self.REVERSE_MODEL_URL,target_folder)
return target_folder_name
#######################################################################################
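    # build a GPT2LMHeadModel from the downloaded .pkl state dict; small
    # checkpoints need their keys renamed to the transformers layout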
def load_model(self,target_folder_name,config):
data_folder=config.get('model','data_folder')
model_size=config.get('model','model_size')
no_cuda=config.getboolean('model', 'no_cuda')
self.logger.info(f"Loading model from {target_folder_name}...")
device=torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
target_folder=os.path.join(data_folder,target_folder_name)
tokenizer=GPT2Tokenizer(os.path.join(target_folder, 'vocab.json'), os.path.join(target_folder,'merges.txt'))
config=GPT2Config.from_json_file(os.path.join(target_folder,'config.json'))
        state_dict_path=glob(os.path.join(target_folder,'*.pkl'))[0]
state_dict=torch.load(state_dict_path,map_location=device)
if model_size=='small':
for key in list(state_dict.keys()):
state_dict[key.replace('module.','')]=state_dict.pop(key)
state_dict['lm_head.weight']=state_dict['lm_head.decoder.weight']
state_dict.pop("lm_head.decoder.weight",None)
model=GPT2LMHeadModel(config)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
return model,tokenizer
#######################################################################################
def set_seed(self,seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
#######################################################################################
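    # standard top-k / nucleus (top-p) filtering: mask everything outside the
    # top k logits and outside the smallest set with cumulative prob > top_p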
def top_k_top_p_filtering(self,logits,top_k=0,top_p=0.0,filter_value=-float('Inf')):
top_k=min(top_k,logits.size(-1))
if top_k>0:
indices_to_remove=logits<torch.topk(logits,top_k)[0][...,-1,None]
logits[indices_to_remove]=filter_value
if top_p>0.0:
sorted_logits,sorted_indices=torch.sort(logits,descending=True)
cumulative_probs=torch.cumsum(F.softmax(sorted_logits,dim=-1),dim=-1)
sorted_indices_to_remove=cumulative_probs>top_p
sorted_indices_to_remove[...,1:]=sorted_indices_to_remove[...,:-1].clone()
sorted_indices_to_remove[...,0]=0
indices_to_remove=sorted_indices_to_remove.scatter(dim=1,index=sorted_indices,src=sorted_indices_to_remove)
logits[indices_to_remove]=filter_value
return logits
#######################################################################################
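    # autoregressive decoding: append one sampled token per step until every
    # sample has produced eos or max_length new tokens have been generated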
def sample_sequence(self,model,tokenizer,context_ids,config):
no_cuda=config.getboolean('model','no_cuda')
num_samples=config.getint('decoder','num_samples')
max_length=config.getint('decoder','max_length')
temperature=config.getfloat('decoder','temperature')
top_k=config.getint('decoder','top_k')
top_p=config.getfloat('decoder','top_p')
device=torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
context_tensor=torch.tensor(context_ids,dtype=torch.long,device=device)
context_tensor=context_tensor.unsqueeze(0).repeat(num_samples,1)
generated=context_tensor
with torch.no_grad():
while True:
inputs={'input_ids':generated}
outputs=model(**inputs)
next_token_logits=outputs[0][:,-1,:]/(temperature if temperature>0 else 1.)
filtered_logits=self.top_k_top_p_filtering(next_token_logits,top_k=top_k,top_p=top_p)
if temperature==0.0:
next_token=torch.argmax(filtered_logits,dim=-1).unsqueeze(-1)
else:
next_token=torch.multinomial(F.softmax(filtered_logits,dim=-1),num_samples=1)
generated=torch.cat((generated,next_token),dim=1)
if (generated[:,len(context_ids):]==tokenizer.eos_token_id).any(dim=1).all():
break
if generated.shape[1]-len(context_ids)>=max_length:
break
return generated
#######################################################################################
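    # maximum-mutual-information reranking: score each candidate with the
    # reverse model and sample the winner from the softmaxed scores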
def select_using_mmi(self,mmi_model,mmi_tokenizer,candidates,config):
no_cuda=config.getboolean('model','no_cuda')
device=torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
scores=[]
        for candidate in candidates:
            # candidates are flat dialogue strings joined by eos_token; the
            # reverse model scores the dialogue with the message order flipped
            messages=[m for m in candidate.split(mmi_tokenizer.eos_token) if m]
            context=mmi_tokenizer.eos_token.join(reversed(messages))+mmi_tokenizer.eos_token
            context_ids=mmi_tokenizer.encode(context)
context_tensor=torch.tensor(context_ids,dtype=torch.long,device=device)
loss,_,_=mmi_model(input_ids=context_tensor,labels=context_tensor)
scores.append(-loss.float())
scores=torch.stack(scores, dim=0)
winner=torch.multinomial(F.softmax(scores,dim=0),num_samples=1).item()
return winner
#######################################################################################
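    # encode the history, sample num_samples continuations, cut each at eos,
    # and optionally keep only the best one according to MMI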
def generate_response(self,model,tokenizer,context,config,mmi_model=None,mmi_tokenizer=None):
use_mmi=config.getboolean('model','use_mmi')
num_samples=config.getint('decoder','num_samples')
max_length=config.getint('decoder','max_length')
seed=config.get('decoder','seed')
seed=int(seed) if seed is not None else None
if seed is not None:
self.set_seed(seed)
context_ids=tokenizer.encode(context)
samples=self.sample_sequence(model, tokenizer, context_ids, config)
samples=samples[:, len(context_ids):].tolist()
texts=[]
for sample in samples:
text=tokenizer.decode(sample,clean_up_tokenization_spaces=True)
text=text[: text.find(tokenizer.eos_token)]
texts.append(text)
if use_mmi:
            assert num_samples>1
candidates=[context+text for text in texts]
best_i=self.select_using_mmi(mmi_model,mmi_tokenizer,candidates,config)
return [texts[best_i]]
return texts
#######################################################################################
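    # PRIVMSG hook: apply flood protection, decide whether the message is for
    # the bot (addressed by nick, keyword hit, or hang-out mode), and queue it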
@irc3.event(irc3.rfc.PRIVMSG)
def on_privmsg_search_for_maple(self, mask=None, target=None, data=None, **kw):
##############################################
        if mask.nick == self.bot.config["nick"]:
            print('ignoring message from self')
            return
##############################################
        if self.epoch_time_boolean:
            epoch_time_now=int(time())
            if epoch_time_now-self.epoch_time_last>=30:
                self.epoch_time_boolean=False
                print('[ turned off flood protection ]')
            else:
                return
##############################################
data=data.strip().lower()
##############################################
        self.span=0
        # substring match against the opinion keywords, so short words can
        # trigger replies easily
        for word in data.split():
            if any(word in keyword for keyword in self.OPINION.split()):
                self.span=1
                break
##############################################
if mask.nick == "d":
if data.find(self.bot.config["nick"])>-1:
print('heard d say my name')
if data.find('hang with us')>-1:
msg="ok, i'll hang out for a bit"
self.bot.privmsg(target,self.bot.emo(msg))
self.mode=1
if data.find('leave us alone')>-1:
msg="ok, gotta go"
self.bot.privmsg(target,self.bot.emo(msg))
self.mode=0
##############################################
if self.mode==0:
if not data.find(self.bot.config["nick"])>-1:
if self.span==0:
return
##############################################
self.span=0
##############################################
data=data.replace(self.bot.config["nick"],'')
self.maple_io.append({'user':mask.nick,'message':data,'target':target})
if len(self.maple_io) > 5:
self.maple_io=[]
            self.epoch_time_now=int(time())
            self.epoch_time_last=self.epoch_time_now
            self.epoch_time_boolean=True
msg=f"kind of busy at the moment {mask.nick}, i'll be right back"
print('[ turned on flood protection ]')
self.bot.privmsg(target,msg)
#######################################################################################
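    # pop one queued message, rebuild the dialogue history (persona lines plus
    # recent turns), generate a reply, and send it back to the channel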
def run_chat(self,model,tokenizer,config,mmi_model=None,mmi_tokenizer=None):
num_samples=config.getint('decoder','num_samples')
max_turns_history=config.getint('decoder','max_turns_history')
turns=[]
        try:
            signal.signal(signal.SIGINT,self.signal_handling)  # only legal in the main thread
        except ValueError:
            pass
config.set('decoder','seed',f'{datetime.now().microsecond}')
        try:
            # FIFO: take the oldest queued message
            maple_io=self.maple_io.pop(0)
            MESSAGE=maple_io['message'].strip()
            TARGET=maple_io['target']
        except IndexError:
            return self.exit_strategy()
print(f'human > {MESSAGE}')
if max_turns_history==0:
turns=[]
turn={
'human_messages':[],
'maple_messages':[]
}
turns.append(turn)
turn['human_messages'].append(MESSAGE)
history=""
from_index=max(len(turns)-max_turns_history-1,0) if max_turns_history>=0 else 0
        WISDOM=[line.strip() for line in self.WISDOM.splitlines() if line.strip()]
        static_history=WISDOM
for message in static_history:
history += message + tokenizer.eos_token
for turn in turns[from_index:]:
for message in turn['human_messages']:
history+=message+tokenizer.eos_token
for message in turn['maple_messages']:
history+=message+tokenizer.eos_token
maple_messages=self.generate_response(
model,
tokenizer,
history,
config,
mmi_model=mmi_model,
mmi_tokenizer=mmi_tokenizer
)
if num_samples==1:
maple_message=maple_messages[0]
else:
maple_message=random.choice(maple_messages)
turn['maple_messages'].append(maple_message)
print(f'maple > {maple_message}')
self.bot.privmsg(TARGET,maple_message)
        return self.exit_strategy()
#######################################################################################
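    # polling loop: reschedules itself every self.delay seconds and hands
    # queued chat work to a thread-pool executor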
def main(self):
loop=self.bot.loop
loop.call_later(self.delay,self.main)
tasks=[]
        # run_in_executor needs a callable: partial defers run_chat to the
        # executor instead of invoking it inline on the event-loop thread
        task=loop.run_in_executor(None,partial(self.run_chat,self.model,self.tokenizer,self.config,mmi_model=self.mmi_model,mmi_tokenizer=self.mmi_tokenizer))
tasks.append(task)
#######################################################################################
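    # no-op hook invoked when run_chat finishes or has nothing to do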
def exit_strategy(self):
pass
###########################################################################################