# Maple_old/gpt2bot/decoder.py
# Copyright (c) polakowo
# Licensed under the MIT license.
import random
import numpy as np
import torch
import torch.nn.functional as F


def set_seed(seed):
    # Seed the Python, NumPy and PyTorch RNGs so sampling is reproducible
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Also seed the CUDA RNGs in case generation runs on a GPU
        torch.cuda.manual_seed_all(seed)


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits distribution of shape (batch size x vocabulary size)
        top_k > 0: keep only the top k tokens with highest probability (top-k filtering).
        top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the sorted mask back to the original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value

    return logits
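

# A minimal sketch (added for illustration; not part of the original module) of what the
# filtering above does to a toy batch: with top_k=2, everything but the two largest
# logits per row is pushed to -inf, so softmax gives those tokens zero probability.
def _demo_top_k_filtering():
    toy_logits = torch.tensor([[4.0, 3.0, 1.0, 0.5],
                               [0.1, 5.0, 2.0, 4.5]])
    filtered = top_k_top_p_filtering(toy_logits.clone(), top_k=2)
    # Each row keeps only its two highest logits; the rest become -inf
    return F.softmax(filtered, dim=-1)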


def sample_sequence(model, tokenizer, context_ids, config):
    # Parse parameters
    no_cuda = config.getboolean('model', 'no_cuda')
    num_samples = config.getint('decoder', 'num_samples')
    max_length = config.getint('decoder', 'max_length')
    temperature = config.getfloat('decoder', 'temperature')
    top_k = config.getint('decoder', 'top_k')
    top_p = config.getfloat('decoder', 'top_p')

    device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")

    # Replicate the context once per sample so all samples are decoded as one batch
    context_tensor = torch.tensor(context_ids, dtype=torch.long, device=device)
    context_tensor = context_tensor.unsqueeze(0).repeat(num_samples, 1)
    generated = context_tensor

    with torch.no_grad():
        while True:
            inputs = {'input_ids': generated}
            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
            next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0.0:
                # Greedy decoding: always pick the most likely token
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
            else:
                # Sample the next token from the filtered distribution
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)

            if (generated[:, len(context_ids):] == tokenizer.eos_token_id).any(dim=1).all():
                # EOS token id found in each sample
                break
            if generated.shape[1] - len(context_ids) >= max_length:
                # Maximum length reached
                break

    return generated
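

# Rough sketch (an assumption, not part of the original module) of the cached-decoding
# variant hinted at in the comment inside the loop above: with `use_cache=True` the model
# also returns `past_key_values`, so after the first step only the newly sampled token
# needs to be fed back in.
#
#     past = None
#     next_input = generated
#     while True:
#         logits, past = model(input_ids=next_input, past_key_values=past, use_cache=True)[:2]
#         next_token_logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.)
#         ...sample next_token as above...
#         generated = torch.cat((generated, next_token), dim=1)
#         next_input = next_token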


def select_using_mmi(mmi_model, mmi_tokenizer, candidates, config):
    # Parse parameters
    no_cuda = config.getboolean('model', 'no_cuda')

    device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")

    scores = []
    for i, candidate in enumerate(candidates):
        # The MMI (reverse) model scores the dialogue with the turn order reversed.
        # This assumes the turns inside `candidate` are separated by the tokenizer's EOS token.
        turns = [turn for turn in candidate.split(mmi_tokenizer.eos_token) if turn]
        context_ids = []
        for turn in reversed(turns):
            context_ids.extend(mmi_tokenizer.encode(turn))
            context_ids.append(mmi_tokenizer.eos_token_id)
        context_tensor = torch.tensor(context_ids, dtype=torch.long, device=device)
        # Older transformers versions return a (loss, logits, past) tuple when labels are given
        loss, _, _ = mmi_model(input_ids=context_tensor, labels=context_tensor)
        scores.append(-loss.float())
    scores = torch.stack(scores, dim=0)

    # The smaller the loss, the higher the probability of selecting that candidate
    winner = torch.multinomial(F.softmax(scores, dim=0), num_samples=1).item()
    return winner
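

# Worked illustration (an assumption about the candidate format, added for clarity):
# with candidate = "Hi there!<|endoftext|>Hello!<|endoftext|>I'm good", the reverse model
# above is scored on "I'm good<|endoftext|>Hello!<|endoftext|>Hi there!<|endoftext|>",
# so bland replies that say little about the preceding context get a higher loss and
# therefore a lower selection probability.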


def generate_response(model, tokenizer, context, config, mmi_model=None, mmi_tokenizer=None):
    # Parse parameters
    use_mmi = config.getboolean('model', 'use_mmi')
    num_samples = config.getint('decoder', 'num_samples')
    max_length = config.getint('decoder', 'max_length')
    seed = config.get('decoder', 'seed', fallback=None)
    seed = int(seed) if seed is not None else None

    # Make answers reproducible only if wanted
    if seed is not None:
        set_seed(seed)

    # Generate response
    context_ids = tokenizer.encode(context)
    samples = sample_sequence(model, tokenizer, context_ids, config)
    samples = samples[:, len(context_ids):].tolist()
    texts = []
    for sample in samples:
        text = tokenizer.decode(sample, clean_up_tokenization_spaces=True)
        # Cut each sample at the first EOS token, if one was generated
        eos_pos = text.find(tokenizer.eos_token)
        if eos_pos >= 0:
            text = text[:eos_pos]
        texts.append(text)

    if use_mmi:
        assert num_samples > 1, "MMI requires num_samples > 1"
        candidates = [context + text for text in texts]
        best_i = select_using_mmi(mmi_model, mmi_tokenizer, candidates, config)
        return [texts[best_i]]
    return texts
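

# A minimal usage sketch (not part of the original bot script), assuming the transformers
# library is installed: the config keys mirror the options read above, and the model name,
# option values and prompt are illustrative assumptions only.
if __name__ == '__main__':
    from configparser import ConfigParser

    from transformers import AutoModelForCausalLM, AutoTokenizer

    config = ConfigParser()
    config.read_dict({
        'model': {'no_cuda': 'yes', 'use_mmi': 'no'},
        'decoder': {'num_samples': '3', 'max_length': '40', 'temperature': '0.7',
                    'top_k': '0', 'top_p': '0.9', 'seed': '42'},
    })

    tokenizer = AutoTokenizer.from_pretrained('microsoft/DialoGPT-small')
    model = AutoModelForCausalLM.from_pretrained('microsoft/DialoGPT-small')
    model.eval()

    # DialoGPT expects each turn to be terminated with the EOS token
    prompt = 'Does money buy happiness?' + tokenizer.eos_token
    for reply in generate_response(model, tokenizer, prompt, config):
        print(reply)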