Dod-o / NLP-practice-program

力求囊括主流NLP模型练手项目,不断更新中

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

关于harry_potter_lstm.py 152行 tf.concat有个疑问

lu161513 opened this issue · comments

按照示例中我们有的是三维的:
[n_seqs, n_sequencd_length, lstm_num_units]
现在要变成二维的:
[n_seqs * n_sequencd_length, lstm_num_units]
是不是应该在第0维度上进行拼接?axis=0而不是axis=1?

比如我有下面的数据:
t1 = [ [[0,1],[2,3],[3,4],[4,5]], [[5,6],[6,7],[7,8],[8,9]], [[9,10],[10,11],[11,12],[12,13]] ] t2=tf.concat(t1,axis=0) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(t2)) # t2是[[ 0 1] [ 2 3] [ 3 4] [ 4 5] [ 5 6] [ 6 7] [ 7 8] [ 8 9] [ 9 10] [10 11] [11 12] [12 13]]
如果按照axis=1拼接的话,是在第二个维度拼接,会变成:
[[ 0 1 5 6 9 10] [ 2 3 6 7 10 11] [ 3 4 7 8 11 12] [ 4 5 8 9 12 13]]

是我理解错了还是这里有问题,感觉输出的最终维度(列)应该就是lstm的units

commented

按照示例中我们有的是三维的:
[n_seqs, n_sequencd_length, lstm_num_units]
现在要变成二维的:
[n_seqs * n_sequencd_length, lstm_num_units]
是不是应该在第0维度上进行拼接?axis=0而不是axis=1?

比如我有下面的数据:
t1 = [ [[0,1],[2,3],[3,4],[4,5]], [[5,6],[6,7],[7,8],[8,9]], [[9,10],[10,11],[11,12],[12,13]] ] t2=tf.concat(t1,axis=0) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(t2)) # t2是[[ 0 1] [ 2 3] [ 3 4] [ 4 5] [ 5 6] [ 6 7] [ 7 8] [ 8 9] [ 9 10] [10 11] [11 12] [12 13]]
如果按照axis=1拼接的话,是在第二个维度拼接,会变成:
[[ 0 1 5 6 9 10] [ 2 3 6 7 10 11] [ 3 4 7 8 11 12] [ 4 5 8 9 12 13]]

是我理解错了还是这里有问题,感觉输出的最终维度(列)应该就是lstm的units

我做了一些测试,之前对tf.concat的理解可能存在一些误解。

t1 = [[[0, 1], [2, 3], [3, 4], [4, 5]], [[5, 6], [6, 7], [7, 8], [8, 9]], [[9, 10], [10, 11], [11, 12], [12, 13]]]
t1 = tf.Variable(t1)
print(t1.shape)		#(3, 4, 2)

t2 = tf.concat([t1], axis=1)
print(t2.shape)		#(3, 4, 2)
with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	print(sess.run(t1))
	print('----------------')
	print(sess.run(t2))

运行结果中t1和t2仍然是两个完全相同的矩阵,包括形状和内部元素。我认为主要原因在于tf.concat被用于两个矩阵之间的连接,对于单个张量的传入并不能起作用,对tf.concat传入两个及以上参数后才能正常工作:

t1 = [[[0, 1], [2, 3], [3, 4], [4, 5]], [[5, 6], [6, 7], [7, 8], [8, 9]], [[9, 10], [10, 11], [11, 12], [12, 13]]]
t1 = tf.Variable(t1)
print(t1.shape)     #(3, 4, 2)

t2 = tf.concat([t1], axis=1)
print(t2.shape)     #(3, 4, 2)

t3 = tf.concat([t1, t1], axis=1)
print(t3.shape)     #(3, 8, 2)

因此下方第一行的concat没有产生作用,实际上应该去掉,程序能正常运行主要依靠第二行的reshape。

lstm_output = tf.concat(lstm_output, axis=1)
lstm_output = tf.reshape(lstm_output, shape=(-1, in_size))

至于t1未转换成张量前,以list形式单独传入可以正常运行,应该是list被迫降了1维,例如上面的t1是(3, 4, 2),最外侧括号被认为是多个矩阵拼接用的括号,参数被转换为3个(4, 2)的矩阵,所以对于原始的t1,虽然t1是三维的,但axis只能为0或1,不能为2。

最后,非常感谢issues的提出,我已经将代码修改重新运行,如果最后结果无误,会更新github上的代码。

这样就没问题了,实际起作用的是后面的reshape,输出的维度就是in_size也就是lstm的unit,行数根据两个相乘的来
谢谢

ezoic increase your site revenue