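"""
Evaluates a trained CakeChat model with ranking-based quality metrics: every
unique test-set response is scored as a candidate answer for every test
context, and mean average precision and mean recall@k are computed over the
resulting rankings. (Summary inferred from the code below.)
"""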
import os
import sys

# Make the top-level cakechat package importable when the script is run
# directly from its location inside the repository.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from cakechat.utils.env import init_cuda_env

# Initialize the CUDA environment before any model-related modules are imported.
init_cuda_env()

from collections import defaultdict

from cakechat.config import INPUT_SEQUENCE_LENGTH, INPUT_CONTEXT_SIZE, OUTPUT_SEQUENCE_LENGTH, TEST_CORPUS_NAME, \
    TEST_DATA_DIR
from cakechat.dialog_model.inference.utils import get_sequence_score
from cakechat.dialog_model.quality import compute_retrieval_metric_mean, compute_average_precision, compute_recall_k
from cakechat.dialog_model.model_utils import transform_lines_to_token_ids, transform_contexts_to_token_ids
from cakechat.dialog_model.factory import get_trained_model
from cakechat.utils.text_processing import get_tokens_sequence
from cakechat.utils.files_utils import load_file
from cakechat.utils.data_structures import flatten


def _read_testset():
    corpus_path = os.path.join(TEST_DATA_DIR, '{}.txt'.format(TEST_CORPUS_NAME))
    test_lines = load_file(corpus_path)

    # The corpus file stores alternating lines: a context line followed by its
    # response line. Collect every response seen for the same context into a set.
    testset = defaultdict(set)
    for i in range(0, len(test_lines) - 1, 2):
        context = test_lines[i].strip()
        response = test_lines[i + 1].strip()
        testset[context].add(response)

    return testset
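
# Illustrative shape of the mapping returned above (example values invented):
#   {'how are you ?': {'i am fine', 'pretty good'}, ...}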


def _get_context_to_weighted_responses(nn_model, testset, all_utterances):
    token_to_index = nn_model.token_to_index

    # Tokenize and index every candidate response once, up front.
    all_utterances_ids = transform_lines_to_token_ids(
        list(map(get_tokens_sequence, all_utterances)), token_to_index, OUTPUT_SEQUENCE_LENGTH, add_start_end=True)

    context_to_weighted_responses = {}

    for context in testset:
        context_tokenized = get_tokens_sequence(context)
        # Repeat the context once per candidate response so that all candidates
        # can be scored against it in a single batched call.
        repeated_context_ids = transform_contexts_to_token_ids(
            [[context_tokenized]] * len(all_utterances), token_to_index, INPUT_SEQUENCE_LENGTH, INPUT_CONTEXT_SIZE)

        scores = get_sequence_score(nn_model, repeated_context_ids, all_utterances_ids)

        context_to_weighted_responses[context] = dict(zip(all_utterances, scores))

    return context_to_weighted_responses
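
# Note: the loop above scores every (context, candidate) pair, i.e.
# len(testset) * len(all_utterances) sequence evaluations, which dominates the
# runtime of this script on large test sets.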


def _compute_metrics(model, testset):
    all_utterances = list(flatten(testset.values(), set))  # Get all unique responses
    context_to_weighted_responses = _get_context_to_weighted_responses(model, testset, all_utterances)

    test_set_size = len(all_utterances)
    metrics = {
        'mean_ap':
            compute_retrieval_metric_mean(
                compute_average_precision, testset, context_to_weighted_responses, top_count=test_set_size),
        'mean_recall@10':
            compute_retrieval_metric_mean(compute_recall_k, testset, context_to_weighted_responses, top_count=10),
        'mean_recall@25%':
            compute_retrieval_metric_mean(
                compute_recall_k, testset, context_to_weighted_responses, top_count=test_set_size // 4)
    }

    print('Test set size = {}'.format(test_set_size))
    for metric_name, metric_value in metrics.items():
        print('{} = {}'.format(metric_name, metric_value))
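
# Metric notes (inferred from the helpers used above): mean_ap is the mean
# average precision of the ground-truth responses over the full candidate
# ranking; mean_recall@10 and mean_recall@25% are the mean fractions of
# ground-truth responses retrieved within the top 10 and the top 25% of ranked
# candidates, respectively.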


if __name__ == '__main__':
    nn_model = get_trained_model()
    testset = _read_testset()
    _compute_metrics(nn_model, testset)
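
# Assumed usage (path is illustrative, not confirmed by this file): run the
# script directly once a trained model is available, e.g.
#   python tools/quality/ranking_quality.py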