# Copyright (c) polakowo
# Licensed under the MIT license.

import random

import numpy as np
import torch
import torch.nn.functional as F


def set_seed(seed):
    """Seed all random number generators so that sampling is reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # No-op when CUDA is unavailable


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits distribution of shape (batch size, vocabulary size).
        top_k > 0: keep only the top k tokens with the highest probability (top-k filtering).
        top_p > 0.0: keep the smallest set of top tokens whose cumulative probability >= top_p
            (nucleus filtering, described in Holtzman et al., http://arxiv.org/abs/1904.09751).

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter the sorted mask back to the original (unsorted) indexing
        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits


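# A quick illustrative check (values are made up, not from the original repo):
# with top_k=2, everything below the 2nd-highest logit is masked to -inf.
#
#   >>> logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
#   >>> top_k_top_p_filtering(logits, top_k=2)
#   tensor([[-inf, -inf, 3., 4.]])

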
def sample_sequence(model, tokenizer, context_ids, config):
    # Parse parameters
    no_cuda = config.getboolean('model', 'no_cuda')
    num_samples = config.getint('decoder', 'num_samples')
    max_length = config.getint('decoder', 'max_length')
    temperature = config.getfloat('decoder', 'temperature')
    top_k = config.getint('decoder', 'top_k')
    top_p = config.getfloat('decoder', 'top_p')

    device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
    context_tensor = torch.tensor(context_ids, dtype=torch.long, device=device)
    context_tensor = context_tensor.unsqueeze(0).repeat(num_samples, 1)
    generated = context_tensor
    with torch.no_grad():
        while True:
            inputs = {'input_ids': generated}
            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden states)
            next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0.0:  # Greedy decoding
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
            if (generated[:, len(context_ids):] == tokenizer.eos_token_id).any(dim=1).all():
                # EOS token id found in each sample
                break
            if generated.shape[1] - len(context_ids) >= max_length:
                # Maximum length reached
                break
    return generated


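# The config object is a standard configparser.ConfigParser. A hypothetical INI
# file covering every option this module reads could look like this (values are
# illustrative, not prescribed by the original repo):
#
#   [model]
#   no_cuda = false
#   use_mmi = false
#
#   [decoder]
#   seed =
#   num_samples = 5
#   max_length = 128
#   temperature = 0.7
#   top_k = 0
#   top_p = 0.9

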
def select_using_mmi(mmi_model, mmi_tokenizer, candidates, config):
    # Parse parameters
    no_cuda = config.getboolean('model', 'no_cuda')

    device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
    scores = []
    for candidate in candidates:
        # The MMI (reverse) model scores P(context | response), so feed the
        # conversation turns in reverse order, each terminated by an EOS token
        turns = [turn for turn in candidate.split(mmi_tokenizer.eos_token) if turn]
        context_ids = []
        for turn in reversed(turns):
            context_ids.extend(mmi_tokenizer.encode(turn))
            context_ids.append(mmi_tokenizer.eos_token_id)
        context_tensor = torch.tensor(context_ids, dtype=torch.long, device=device)
        loss, _, _ = mmi_model(input_ids=context_tensor, labels=context_tensor)
        scores.append(-loss.float())

    scores = torch.stack(scores, dim=0)
    # The smaller the loss, the higher the probability of selecting the candidate
    winner = torch.multinomial(F.softmax(scores, dim=0), num_samples=1).item()
    return winner


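# Illustrative example of the selection step: per-candidate MMI losses of
# [2.3, 1.1, 4.0] yield scores [-2.3, -1.1, -4.0], so the softmax sampling
# strongly favors the lowest-loss candidate (index 1) while keeping some diversity.

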
def generate_response(model, tokenizer, context, config, mmi_model=None, mmi_tokenizer=None):
    # Parse parameters
    use_mmi = config.getboolean('model', 'use_mmi')
    num_samples = config.getint('decoder', 'num_samples')
    seed = config.get('decoder', 'seed', fallback=None)
    seed = int(seed) if seed else None

    # Make answers reproducible only if wanted
    if seed is not None:
        set_seed(seed)

    # Generate response
    context_ids = tokenizer.encode(context)
    samples = sample_sequence(model, tokenizer, context_ids, config)
    samples = samples[:, len(context_ids):].tolist()
    texts = []
    for sample in samples:
        text = tokenizer.decode(sample, clean_up_tokenization_spaces=True)
        # Cut the response at the first EOS token; a sample that hit max_length
        # may contain none, in which case keep the full text
        eos_pos = text.find(tokenizer.eos_token)
        if eos_pos >= 0:
            text = text[:eos_pos]
        texts.append(text)

    if use_mmi:
        assert num_samples > 1, "MMI requires num_samples > 1"
        candidates = [context + text for text in texts]
        best_i = select_using_mmi(mmi_model, mmi_tokenizer, candidates, config)
        return [texts[best_i]]
    return texts
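

# Minimal usage sketch (hypothetical, not part of the original module). It assumes
# a transformers version that returns plain tuples from model calls, as the code
# above does; the checkpoint name, config path, and context string are illustrative:
#
#   from configparser import ConfigParser
#   from transformers import GPT2LMHeadModel, GPT2Tokenizer
#
#   config = ConfigParser()
#   config.read('chatbot.cfg')
#   tokenizer = GPT2Tokenizer.from_pretrained('microsoft/DialoGPT-medium')
#   model = GPT2LMHeadModel.from_pretrained('microsoft/DialoGPT-medium')
#   context = 'Hello, how are you?' + tokenizer.eos_token
#   replies = generate_response(model, tokenizer, context, config)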