机器学习RRN训练聊天机器人

前言

上篇写过一个机器学习写唐诗的实验,这次我们搞个稍微复杂些的,实现一个聊天机器人,也是基于腾讯云实验室的一篇教程,有些部分做了改动,大部分时间都用在了环境的适配上面。开始本地是在Mac环境,单独依靠CPU训练,比较慢。后来找了个配置比较好的机器, 6核心12线程,效果好一些。总结来说,机器学习相关有两个重点,一个是基础的训练资源,包括对原始数据的清洗处理和规范化,训练中其实模型是没有很大区别的。其次,是好的机器配置,资源有限,没有上GPU。这次实验,本地训练大概半天到4000步的时候,还只是个复读机,换了高配机器1天左右就可以到30万左右,两天到70万,基本达到损失率稳定(30万就可以)。
以下是本地机器的配置,奈何效果不行

MacBook Pro (13-inch, 2017, Four Thunderbolt 3 Ports)
10.13.6 (17G65)16 GB 2133 MHz LPDDR3
3.1 GHz Intel Core i5

注意事项
强烈建议使用virtualenv配置python,简单而且不会对本地运行环境造成影响。
同时需要安装好TensorFlow环境

过程步骤

实验内容

  1. 首先进行数据的清洗,处理。提取ask和answer数据,并得到字典,以及做向量化处理。训练数据可以使用本次实验链接里的,也可以使用网上的小黄鸡等等语料。注意这里的字典之前查的资料是满足3000左右的常用汉字就可以,是在语料中找到常用字。

  2. 模型学习部分。
    这里引用了seq2seq的部分,单独有一些修改。之前下载实验中提供的训练了30万次左右的模型直接进行对话,但是本地一直提示错误。最终选择了自己训练,保存了完整的checkpoint文件,可以启动程序。如图最终训练在71万次左右,其实30万左右损失率基本就已经不变了,如果能提供更优化的语料应该效果会更好。后续有链接提供所有资料,可以直接下载。


    训练完毕的模型
  3. 模拟对话,这部分是最终的成果,启动本地依赖,加载训练模型之后就可以对话了,效果看图,可以看到有些句子还是可以对上的,一问一答,有些幼稚。


    模拟对话

代码部分

  1. 数据整理和向量化 generate.py
# -*- coding:utf-8 -*-
from io import open
import random
import tensorflow as tf

# version tf 1.12 2018-12-08 22:22:08
PAD = "PAD"
GO = "GO"
EOS = "EOS"
UNK = "UNK"
START_VOCAB = [PAD, GO, EOS, UNK]

PAD_ID = 0  # 填充
GO_ID = 1  # 开始标志
EOS_ID = 2  # 结束标志
UNK_ID = 3  # 未知字符
_buckets = [(10, 15), (20, 25), (40, 50), (80, 100)]
units_num = 256
num_layers = 3
max_gradient_norm = 5.0
batch_size = 50
learning_rate = 0.5
learning_rate_decay_factor = 0.97

train_encode_file = "data/train_encode"
train_decode_file = "data/train_decode"
test_encode_file = "data/test_encode"
test_decode_file = "data/test_decode"
vocab_encode_file = "data/vocab_encode"
vocab_decode_file = "data/vocab_decode"
train_encode_vec_file = "data/train_encode_vec"
train_decode_vec_file = "data/train_decode_vec"
test_encode_vec_file = "data/test_encode_vec"
test_decode_vec_file = "data/test_decode_vec"


def is_chinese(sentence):
    flag = True
    if len(sentence) < 2:
        flag = False
        return flag
    for uchar in sentence:
        if (uchar == ',' or uchar == '。' or
                uchar == '~' or uchar == '?' or
                uchar == '!'):
            flag = True
        elif '一' <= uchar <= '鿿':
            flag = True
        else:
            flag = False
            break
    return flag


def get_chatbot():
    f = open("data/chat.conv", "r", encoding="utf-8")
    train_encode = open(train_encode_file, "w", encoding="utf-8")
    train_decode = open(train_decode_file, "w", encoding="utf-8")
    test_encode = open(test_encode_file, "w", encoding="utf-8")
    test_decode = open(test_decode_file, "w", encoding="utf-8")
    vocab_encode = open(vocab_encode_file, "w", encoding="utf-8")
    vocab_decode = open(vocab_decode_file, "w", encoding="utf-8")
    encode = list()
    decode = list()

    chat = list()
    print("start load source data...")
    step = 0
    for line in f.readlines():
        line = line.strip('\n').strip()
        if not line:
            continue
        if line[0] == "E":
            if step % 1000 == 0:
                print("step:%d" % step)
            step += 1
            if (len(chat) == 2 and is_chinese(chat[0]) and is_chinese(chat[1]) and
                    not chat[0] in encode and not chat[1] in decode):
                encode.append(chat[0])
                decode.append(chat[1])
            chat = list()
        elif line[0] == "M":
            L = line.split(' ')
            if len(L) > 1:
                chat.append(L[1])
    encode_size = len(encode)
    if encode_size != len(decode):
        raise ValueError("encode size not equal to decode size")
    test_index = random.sample([i for i in range(encode_size)], int(encode_size * 0.2))
    print("divide source into two...")
    step = 0
    for i in range(encode_size):
        if step % 1000 == 0:
            print("%d" % step)
        step += 1
        if i in test_index:
            test_encode.write(encode[i] + "\n")
            test_decode.write(decode[i] + "\n")
        else:
            train_encode.write(encode[i] + "\n")
            train_decode.write(decode[i] + "\n")

    vocab_encode_set = set(''.join(encode))
    vocab_decode_set = set(''.join(decode))
    print("get vocab_encode...")
    step = 0
    for word in vocab_encode_set:
        if step % 1000 == 0:
            print("%d" % step)
        step += 1
        vocab_encode.write(word + "\n")
    print("get vocab_decode...")
    step = 0
    for word in vocab_decode_set:
        print("%d" % step)
        step += 1
        vocab_decode.write(word + "\n")


def gen_chatbot_vectors(input_file, vocab_file, output_file):
    vocab_f = open(vocab_file, "r", encoding="utf-8")
    output_f = open(output_file, "w")
    input_f = open(input_file, "r", encoding="utf-8")
    words = list()
    for word in vocab_f.readlines():
        word = word.strip('\n').strip()
        words.append(word)
    word_to_id = {word: i for i, word in enumerate(words)}
    to_id = lambda word: word_to_id.get(word, UNK_ID)
    print("get %s vectors" % input_file)
    step = 0
    for line in input_f.readlines():
        if step % 1000 == 0:
            print("step:%d" % step)
        step += 1
        line = line.strip('\n').strip()
        vec = map(to_id, line)
        output_f.write(' '.join([str(n) for n in vec]) + "\n")


def get_vectors():
    gen_chatbot_vectors(train_encode_file, vocab_encode_file, train_encode_vec_file)
    gen_chatbot_vectors(train_decode_file, vocab_decode_file, train_decode_vec_file)
    gen_chatbot_vectors(test_encode_file, vocab_encode_file, test_encode_vec_file)
    gen_chatbot_vectors(test_decode_file, vocab_decode_file, test_decode_vec_file)


def get_vocabs(vocab_file):
    words = list()
    with open(vocab_file, "r", encoding="utf-8") as vocab_f:
        for word in vocab_f:
            words.append(word.strip('\n').strip())
    id_to_word = {i: word for i, word in enumerate(words)}
    word_to_id = {v: k for k, v in id_to_word.items()}
    vocab_size = len(id_to_word)
    return id_to_word, word_to_id, vocab_size


def read_data(source_path, target_path, max_size=None):
    data_set = [[] for _ in _buckets]
    with tf.gfile.GFile(source_path, mode="r") as source_file:
        with tf.gfile.GFile(target_path, mode="r") as target_file:
            source, target = source_file.readline(), target_file.readline()
            counter = 0
            while source and target and (not max_size or counter < max_size):
                counter += 1
                source_ids = [int(x) for x in source.split()]
                target_ids = [int(x) for x in target.split()]
                target_ids.append(EOS_ID)
                for bucket_id, (source_size, target_size) in enumerate(_buckets):
                    if len(source_ids) < source_size and len(target_ids) < target_size:
                        data_set[bucket_id].append([source_ids, target_ids])
                        break
                source, target = source_file.readline(), target_file.readline()
    return data_set


# run
#获取 ask、answer 数据并生成字典
# get_chatbot()
#训练数据转化为数字表示
# get_vectors()
  1. 学习模型

简书限制太长无法发布,只能在最后的链接获取了
seq2seq.py
seq2seq_model.py

  1. 训练模块

可以改小配置中的step部分,简单验证下效果。这里有些改动,加了间隔一定步骤之后,保存checkpoint到本地的功能,防止中间如果有异常,比如断电或者不小心关闭程序或者其他原因造成程序崩溃,导致前功尽弃。

train_chat.py

# -*- coding:utf-8 -*-
import generate as generate_chat
import seq2seq_model as seq2seq_model
import tensorflow as tf
import numpy as np
import logging
import logging.handlers

if __name__ == '__main__':

    _, _, source_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_encode_file)
    _, _, target_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_decode_file)
    train_set = generate_chat.read_data(generate_chat.train_encode_vec_file, generate_chat.train_decode_vec_file)
    test_set = generate_chat.read_data(generate_chat.test_encode_vec_file, generate_chat.test_decode_vec_file)
    train_bucket_sizes = [len(train_set[i]) for i in range(len(generate_chat._buckets))]
    train_total_size = float(sum(train_bucket_sizes))
    train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size for i in range(len(train_bucket_sizes))]
    cpu_config = tf.ConfigProto(intra_op_parallelism_threads=6,inter_op_parallelism_threads=6,device_count={'CPU':6})
    with tf.Session(config=cpu_config) as sess:
        model = seq2seq_model.Seq2SeqModel(source_vocab_size,
                                           target_vocab_size,
                                           generate_chat._buckets,
                                           generate_chat.units_num,
                                           generate_chat.num_layers,
                                           generate_chat.max_gradient_norm,
                                           generate_chat.batch_size,
                                           generate_chat.learning_rate,
                                           generate_chat.learning_rate_decay_factor,
                                           use_lstm=True)

        ckpt = tf.train.get_checkpoint_state('./mytrain')

        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print("Created model with fresh parameters.")
            sess.run(tf.global_variables_initializer())
        loss = 0.0
        step = 0
        previous_losses = []
        while True:
            random_number_01 = np.random.random_sample()
            bucket_id = min([i for i in range(len(train_buckets_scale)) if train_buckets_scale[i] > random_number_01])
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(train_set, bucket_id)
            _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, False)
            print("step:%d,loss:%f" % (step, step_loss))
            loss += step_loss / 2000
            step += 1
            if step % 1000 == 0:
                print("step:%d,per_loss:%f" % (step, loss))
                if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
                    sess.run(model.learning_rate_decay_op)
                previous_losses.append(loss)
                model.saver.save(sess, "mytrain/chatbot.ckpt", global_step=model.global_step)
                loss = 0.0
            if step % 5000 == 0:
                for bucket_id in range(len(generate_chat._buckets)):
                    if len(test_set[bucket_id]) == 0:
                        continue
                        encoder_inputs, decoder_inputs, target_weights = model.get_batch(test_set, bucket_id)
                        _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id,
                                                     True)
                        print("bucket_id:%d,eval_loss:%f" % (bucket_id, eval_loss))

  1. 对话模块
    chat.py
# -*- coding:utf-8 -*-
import generate as generate_chat
import seq2seq_model as seq2seq_model
import tensorflow as tf
import numpy as np
import sys

if __name__ == '__main__':
    source_id_to_word, source_word_to_id, source_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_encode_file)
    target_id_to_word, target_word_to_id, target_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_decode_file)
    to_id = lambda word: source_word_to_id.get(word, generate_chat.UNK_ID)
    cpu_config = tf.ConfigProto(intra_op_parallelism_threads=6,inter_op_parallelism_threads=6,device_count={'CPU':6})
    with tf.Session(config=cpu_config) as sess:
        model = seq2seq_model.Seq2SeqModel(source_vocab_size,
                                           target_vocab_size,
                                           generate_chat._buckets,
                                           generate_chat.units_num,
                                           generate_chat.num_layers,
                                           generate_chat.max_gradient_norm,
                                           1,
                                           generate_chat.learning_rate,
                                           generate_chat.learning_rate_decay_factor,
                                           forward_only=True,
                                           use_lstm=True)
        #model.saver.restore(sess, "model/chatbot.ckpt-317000")
        model.saver.restore(sess, "mytrain/chatbot.ckpt-717000")
        while True:
            sys.stdout.write("ask > ")
            sys.stdout.flush()
            sentence = sys.stdin.readline().strip('\n')
            flag = generate_chat.is_chinese(sentence)
            if not sentence or not flag:
                print("请输入纯中文")
                continue
            sentence_vec = list(map(to_id, sentence))
            bucket_id = len(generate_chat._buckets) - 1
            if len(sentence_vec) > generate_chat._buckets[bucket_id][0]:
                print("sentence too long max:%d" % generate_chat._buckets[bucket_id][0])
                exit(0)
            for i, bucket in enumerate(generate_chat._buckets):
                if bucket[0] >= len(sentence_vec):
                    bucket_id = i
                    break
            encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(sentence_vec, [])]},
                                                                             bucket_id)
            _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            if generate_chat.EOS_ID in outputs:
                outputs = outputs[:outputs.index(generate_chat.EOS_ID)]
            answer = "".join([tf.compat.as_str(target_id_to_word[output]) for output in outputs])
            print("answer > " + answer)

注意
这里在train_chat.py 和 chat.py中,tf.session有个配置改动,限制了使用的CPU数,在Ubuntu下如果没有限制,会造成TF占用所有的CPU资源,导致系统卡死,具体数值根据CPU核心数设置。
代码如下:

cpu_config = tf.ConfigProto(intra_op_parallelism_threads=6,inter_op_parallelism_threads=6,device_count={'CPU':6})
    with tf.Session(config=cpu_config) as sess:

结语

感谢阅读,最后放上实验的实际地址和我自己训练的所有资源,本地实验在mac tf 1.12.0 和 python3.6.7,以及Ubuntu tf.1.12.0 和 python3.5环境下都正常,再次建议在virtualenv环境下。
实验链接(时间过久可能失效):https://cloud.tencent.com/developer/labs/lab/10406
本地实验资源:https://iss.igosh.com/share/201903/tencent-me.tar.gz

推荐阅读更多精彩内容