# -*- coding: utf-8 -*-
"""
@from:www.linerect.com
@author: duimu
@contact: duimu@qq.com
@Created on: 2019/12/2 10:08
"""
import os
import pickle
from six.moves import urllib
import tflearn
from tflearn.data_utils import *
#path="RNN_text/shakespeare_input.txt"
#char_idx_file='RNN_text/char_idx.pickle'
#maxlen=25
path="RNN_text/poems.txt"
char_idx_file='RNN_text/poems_char_idx.pickle'
maxlen=5
#if not os.path.isfile(path):
# urllib.request.urlretrieve("https://raw.githubusercontent.com/tflearn/tflearn.github.io/master/resources/shakespeare_input.txt", path)
#将输入文本转换为向量,并通过 string_to_semi_redundant_sequences()
# 返回解析的序列和目标以及关联的字典(函数输出一个元组:包括输入、目标和字典):
char_idx=None
if os.path.isfile(char_idx_file):
print('loading previos char_idx')
char_idx=pickle.load(open(char_idx_file,'rb'))
X,Y,char_idx=textfile_to_semi_redundant_sequences(path,seq_maxlen=maxlen,redun_step=25,
pre_defined_char_idx=char_idx)
pickle.dump(char_idx,open(char_idx_file,'wb'))
#定义由三个 LSTM 组成的 RNN,每个 LTSM 有 512 个节点,
# 并返回完整序列而不是仅返回最后一个序列。请注意,使用概率为 50% 的 drop-out 模块来连接 LSTM 模块。
# 最后一层是全连接层,softmax 长度等于字典尺寸。损失函数采用 categorical_crossentropy,优化器采用 Adam:
g=tflearn.input_data([None,maxlen,len(char_idx)])
g=tflearn.lstm(g,512,return_seq=True)
g=tflearn.dropout(g,0.5)
g=tflearn.lstm(g,512,return_seq=True)
g=tflearn.dropout(g,0.5)
g=tflearn.lstm(g,512)
g=tflearn.dropout(g,0.5)
g=tflearn.fully_connected(g,len(char_idx),activation='softmax')
g=tflearn.regression(g,optimizer='adam',loss='categorical_crossentropy',learning_rate=0.001)
#可以用库函数 flearn.models.generator.SequenceGenerator(network,dictionary=char_idx,seq_maxlen=maxle,clip_gradients=5.0,checkpoint_path='model_shakespeare') 生成序列:
m=tflearn.SequenceGenerator(g,dictionary=char_idx,seq_maxlen=maxlen,
clip_gradients=5.0,
checkpoint_path='model_shakpeare')
#温度是表征概念变化的量。如果温度高,比如大于1,就代表希望输出结果更稳定。
# 稳定的结果就是可能每次生成的句子都是一样的。如果等于1,
# 那么对结果没有影响。如果小于1,那么就会让每次生成的结果变化比较大
for i in range(50):
seed=random_sequence_from_textfile(path,maxlen)
m.fit(X,Y,validation_set=0.1,batch_size=128,n_epoch=1,run_id='shakespeare')
print("--TESTING...")
print("--Test with temprature of 1.0 --")
print(m.generate(600,temperature=1.0,seq_seed=seed))
print("--Test with temprature of 0.5 --")
print(m.generate(600, temperature=0.5, seq_seed=seed))