更加简单的使用方深度学习

RNN训练

Posted on By duimu

流程:

# -*- coding: utf-8 -*-

"""
@from:www.linerect.com

@author:  duimu

@contact: duimu@qq.com

@Created on: 2019/12/2 10:08
"""

import  os
import pickle
from six.moves import urllib
import tflearn
from tflearn.data_utils import *

#path="RNN_text/shakespeare_input.txt"
#char_idx_file='RNN_text/char_idx.pickle'
#maxlen=25

path="RNN_text/poems.txt"
char_idx_file='RNN_text/poems_char_idx.pickle'
maxlen=5

#if not os.path.isfile(path):
#    urllib.request.urlretrieve("https://raw.githubusercontent.com/tflearn/tflearn.github.io/master/resources/shakespeare_input.txt", path)

#将输入文本转换为向量,并通过 string_to_semi_redundant_sequences()
# 返回解析的序列和目标以及关联的字典(函数输出一个元组:包括输入、目标和字典):

char_idx=None
if os.path.isfile(char_idx_file):
    print('loading previos char_idx')
    char_idx=pickle.load(open(char_idx_file,'rb'))

X,Y,char_idx=textfile_to_semi_redundant_sequences(path,seq_maxlen=maxlen,redun_step=25,
                                                  pre_defined_char_idx=char_idx)
pickle.dump(char_idx,open(char_idx_file,'wb'))

#定义由三个 LSTM 组成的 RNN,每个 LTSM 有 512 个节点,
# 并返回完整序列而不是仅返回最后一个序列。请注意,使用概率为 50% 的 drop-out 模块来连接 LSTM 模块。
# 最后一层是全连接层,softmax 长度等于字典尺寸。损失函数采用 categorical_crossentropy,优化器采用 Adam:
g=tflearn.input_data([None,maxlen,len(char_idx)])
g=tflearn.lstm(g,512,return_seq=True)
g=tflearn.dropout(g,0.5)
g=tflearn.lstm(g,512,return_seq=True)
g=tflearn.dropout(g,0.5)
g=tflearn.lstm(g,512)
g=tflearn.dropout(g,0.5)
g=tflearn.fully_connected(g,len(char_idx),activation='softmax')
g=tflearn.regression(g,optimizer='adam',loss='categorical_crossentropy',learning_rate=0.001)

#可以用库函数 flearn.models.generator.SequenceGenerator(network,dictionary=char_idx,seq_maxlen=maxle,clip_gradients=5.0,checkpoint_path='model_shakespeare') 生成序列:
m=tflearn.SequenceGenerator(g,dictionary=char_idx,seq_maxlen=maxlen,
                            clip_gradients=5.0,
                            checkpoint_path='model_shakpeare')

#温度是表征概念变化的量。如果温度高,比如大于1,就代表希望输出结果更稳定。
# 稳定的结果就是可能每次生成的句子都是一样的。如果等于1,
# 那么对结果没有影响。如果小于1,那么就会让每次生成的结果变化比较大
for i in range(50):
    seed=random_sequence_from_textfile(path,maxlen)
    m.fit(X,Y,validation_set=0.1,batch_size=128,n_epoch=1,run_id='shakespeare')
    print("--TESTING...")
    print("--Test with temprature of 1.0 --")
    print(m.generate(600,temperature=1.0,seq_seed=seed))
    print("--Test with temprature of 0.5 --")
    print(m.generate(600, temperature=0.5, seq_seed=seed))