tf.nn.dynamic_rnn的输出outputs和state含义

1850阅读 0评论2020-04-18 wwm
分类:Python/Ruby

参考转自https://blog.csdn.net/u010960155/article/details/81707498
实验如下

import tensorflow as tf
import numpy as np
 
def dynamic_rnn(rnn_type='lstm'):
    # 创建输入数据,3代表batch size,6代表输入序列的最大步长(max time),8代表每个序列的维度
    X = np.random.randn(3, 6, 4)
 
    # 第二个输入的实际长度为4。在此处也就是time_step 设定为4了,不再是6.注意看返回结果state,不再是步长内的第6个而是第4个。
    X[1, 4:] = 0
 
    #记录三个输入的实际步长
    X_lengths = [6, 4, 6]
 
    rnn_hidden_size = 5
    if rnn_type == 'lstm':
        cell = tf.contrib.rnn.BasicLSTMCell(num_units=rnn_hidden_size, state_is_tuple=True)
    else:
        cell = tf.contrib.rnn.GRUCell(num_units=rnn_hidden_size)
 
    outputs, last_states = tf.nn.dynamic_rnn(
        cell=cell,
        dtype=tf.float64,
        sequence_length=X_lengths,
        inputs=X)
 
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        o1, s1 = session.run([outputs, last_states])
        print(np.shape(o1))
        print(o1)
        print(np.shape(s1))
        print(s1)
 
 
if __name__ == '__main__':
    dynamic_rnn(rnn_type='lstm')


lstm模式下
outputs 为 [batch_size , time_step, rnn_unit]
last_states   (c,h)  结构为[ 2,batch_size ,rnn_unit]
last_states.h  [batch_size ,rnn_unit]
last_states.c  [batch_size ,rnn_unit]

last_states.h 和outputs  最后一个相同

outputs:包含了所有时刻的输出 H ,
states :包含了 "每个time_step内"最后一个时刻的输出 H 和 C


上一篇:tensorflow中四种不同交叉熵函数tf.nn.softmax_cross_entropy_with_logits()
下一篇:pandas DataFrame 修改单个数据