WaveNetの学習時と推論時の畳み込み入力xの形状
zzxiang opened this issue · comments
聞く場所を間違えたら申し訳ございません。
第7章のWaveNetのソースコードを拝見するときによくわからないところがあります。
WaveNetの学習においては、畳み込み入力x
の形状は(B, out_channels, T)
です。μ-lawアルゴリズムによって音声波形を8 bit(2^8=256通り)に量子化した場合、out_channels
は256です。
# 量子化された離散値列から One-hot ベクトルに変換
# (B, T) -> (B, T, out_channels) -> (B, out_channels, T)
x = F.one_hot(x, self.out_channels).transpose(1, 2).float()
# 条件付き特徴量のアップサンプリング
c = self.upsample_net(c)
assert c.size(-1) == x.size(-1)
# One-hot ベクトルの次元から隠れ層の次元に変換
x = self.first_conv(x)
なので、自分の認識として、学習時の畳み込みは時間軸方向に行われると思います。本の224ページ目の話通り、
時刻t - 1の教師データ(図7-15a)を時刻tの入力として利用することで、学習の難しさを緩和します。この方法はteacher forcingと呼ばれます[64]。
その一方、
推論時には教師データは得られないため、1サンプルずつ逐次的に音声を生成しなければなりません。
推論においては、畳み込み入力x
の形状は(1, out_channels)
です。
outputs = []
# 自己回帰生成における初期値
current_input = torch.zeros(B, 1, self.out_channels).to(c.device)
current_input[:, :, int(mulaw_quantize(0))] = 1
# ...
# 逐次的に生成
for t in ts:
# 時刻 t における入力は、時刻 t-1 における出力
if t > 0:
current_input = outputs[-1]
# 時刻 t における条件付け特徴量
ct = c[:, t, :].unsqueeze(1)
x = current_input
x = self.first_conv.incremental_forward(x)
# ...
outputs += [x.data]
すなわち、推論時の畳み込みはout_channels
方向、あるいはOne-hotベクトル方向に行われて、学習時と違います。
この認識は正しいですか?
もし正しければ、なぜ時間軸に学習した畳み込みの重みはOne-hotベクトル方向に使えますでしょうか?
ご質問ありがとうございます。質問はgithub issuesでしていただいて問題ありません。
WaveNetにおける一次元畳み込みは、学習時と推論時の両方において、時間方向に対して行われるのが正しいです。
すなわち、推論時の畳み込みはout_channels方向、あるいはOne-hotベクトル方向に行われて、学習時と違います。
この認識は正しいですか?
この認識が誤っています。推論時の畳み込みは時間方向に行われ、その点で言えば学習時と同じです。補足で説明させていただきます。
WaveNetで利用している一次元畳み込みの計算方法は、時間方向に並列で処理する前向き計算 (forward
関数) と、時間方向に逐次的に計算するインクリメンタル前向き計算 (incremental_forward
関数) の二種類があります。
ttslearn/ttslearn/wavenet/conv.py
Lines 9 to 10 in 925d491
前者は学習時に利用し、実装レベルの話で言いますと、(B, out_channels, T)
のサイズを持つテンソルを入力とします。一方、後者は推論時に利用し、(B, 1, out_channels)
のサイズを持つテンソルを入力とします。
後者は、一見して時間方向に畳み込みを行っていないように思われるかもしれませんが、実際には内部的に時間方向の情報がキャッシュされており、時間方向の畳み込みが行われます。内部的な計算の効率化の都合上、入力とするテンソルのサイズが通常の一次元畳み込みの場合 ((B, out_channels, T)
) と異なることにご注意ください。実装の詳細は、以下のとおりです。
ttslearn/ttslearn/wavenet/conv.py
Lines 21 to 53 in 925d491
#18 こちらで、forwardとincremetnal_forwardが同じ時間方向の畳み込みを行っていることを示すテストコードを追加しました。参考になれば幸いです。
ご回答と例のテストコードありがとうございます!理解できたと思います!
incremental_forward
の場合、t - 1 時刻以前の出力はinput_buffer
に(B, T, C)
の形状でキャッシュされます。正しく言えば、(B, kw + (kw - 1) * (dilation - 1), C)
の形状です。kw
はkernel_size
です。下記F.linear
を呼び出すソースコードに、入力input
は(B, (kw + (kw - 1) * (dilation - 1)) * C)
の形状に変形されます。
ttslearn/ttslearn/wavenet/conv.py
Line 52 in 925d491
このソースコードと結果的に等しい演算として、時間方向に幅が kw + (kw - 1) * (dilation - 1)
の全ての入力チャネルにある値と対応の重みに対して、積和演算を行います。これはまさに時間方向の一次元畳み込みです!
ちなみに、もう一つ気づいたことがあります。forward
の場合、入力は (B, C, T)
のサイズである一方、incremental_forward
の場合、入力は (B, 1, C)
のサイズですね。しかしながら、下記推論のソースコードに、同じ (B, 1, C)
のサイズの入力 x
は incremental_forward
と forward
両方に使えます。
ttslearn/ttslearn/wavenet/wavenet.py
Lines 163 to 167 in 925d491
なぜならば、self.last_conv_layers
に入っているのは ReLU と Conv1d1x1
だけです。Conv1d1x1
には incremental_forward
が定義されています。 ReLU
には incremental_foward
がないですが、そもそも畳み込み演算を行わないので、形状にこだわる必要がないです。この認識は合っていますか?
いただいたテストソースコードも試しました。ありがとうございます!
はい、おっしゃるとおりの認識で合っています。ReLUの場合には、形状にこだわる必要はないため、そのような実装になっています。
ありがとうございました!