Transformer翻译模型Decoder详解(Masking)
写这个博客的原因在于:⼤部分解释Transformer的⽂章都只注重讲解Encoder部分,在Encoder中⼜侧重讲解self-attention原理。为了读者更好地理解整个Transformer的训练过程,我决定结合代码写⼀篇在理解了Encoder部分怎么理解Decoder模块的博⽂。
参考⽂章:jalammar.github.io/illustrated-transformer/
参考代码:github/Kyubyong/transformer
pre: Encoder
根据以上参考⽂章及代码理解Encoder的self-attention原理⾮常容易,这⾥不再赘述。需要说明的是以下维度:
德⽂输⼊X.shape:[batch_size, max_len]
英⽂标注Y.shape:[batch_size, max_len]
Encoder输出维度
[batch_size, max_len, hidden_units]
也就是⾥的[N, T_q, C]
Decoder
在训练过程中,Transformer同所有seq2seq模型⼀样,会⽤到source data以及不断⽣成的target data的部分数据(理解就是RNN的因果关系,训练过程中不像BiRNN⼀样使⽤未来数据,因此需要Masking)。
decoder需要说明的是中的key masking和query masking是对于⽂本padding部分的掩盖,⽬的是使Encoder不过多的关注于padding这种⽆效信息。
causality
代码中的causality部分是是对未来信息的掩盖。这部分代码位于modules.py中。
if causality:
diag_vals = tf.ones_like(outputs[0, :, :]) # (T_q, T_k)
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k)
masks = tf.pand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1]) # (h*N, T_q, T_k)
paddings = tf.ones_like(masks)*(-2**32+1)
outputs = tf.where(tf.equal(masks, 0), paddings, outputs) # (h*N, T_q, T_k)
下⾯我通过对⽐Decoder中的self-attention和Encoder-Decoder attention两个模块说明Decoder在代码中是如何具体同时attention源数据及⽣成数据的。这对理解Decoder如何使⽤数据很关键。
同Encoder⼀样,使⽤多个block叠加:
with tf.variable_scope("num_blocks_{}".format(i)):
block中包含使⽤源数据的self-attention【⽬标数据⾃⾝关注,因此需要掩盖未来数据来模拟逐词⽣成、类似于单向RNN】,
和使⽤⽣成数据的vanilla attention【⽬标数据关注于源数据,也就是en关注于de,由于源数据是存在的,因此没有属于未来的数据,不需要进⾏掩盖未来数据的操作,类似于BiRNN】。
self-attention(⾃⾝关注,需掩盖未来数据)
self.dec = multihead_attention(queries=self.dec,
keys=self.dec,
num_units=hp.hidden_units,
num_heads=hp.num_heads,
dropout_rate=hp.dropout_rate,
is_training=is_training,
causality=True,
scope="self_attention")
vanilla attention(关注源数据,causality=False)
self.dec = multihead_attention(queries=self.dec,
,
num_units=hp.hidden_units,
num_heads=hp.num_heads,
dropout_rate=hp.dropout_rate,
is_training=is_training,
causality=False,
scope="vanilla_attention")
在这个对⽐中,主要的输⼊参数不同是:
keys
causality
keys输⼊⽤来计算关注的权重,在代码中key=value,同时⽤来计算权重以及attention之后的结果。
self-attention:关注self.dec,也就是⾃⾝关注,设置causality=True掩盖训练数据集中的未来数据。
vanilla attention:关注,也就是关注数据集中的源数据,设置causality=False来取消掩盖未来数据(因为训练集的X是已知的)。
causality的不同,具体代码如本⽂的第⼀段代码所⽰,在此复制过来进⾏分析:
if causality:
diag_vals = tf.ones_like(outputs[0, :, :]) # (T_q, T_k)
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k)
masks = tf.pand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1]) # (h*N, T_q, T_k)
paddings = tf.ones_like(masks)*(-2**32+1)
outputs = tf.where(tf.equal(masks, 0), paddings, outputs) # (h*N, T_q, T_k)
这⾥主要是使⽤了
tf.linalg.LinearOperatorLowerTriangular().to_dense()
这个函数⽣成mask,该函数的作⽤是将:
1111
1111
1111
1111
变成:
1000
1100
1110
1111
⽽后通过:
paddings = tf.ones_like(masks)*(-2**32+1)
outputs = tf.where(tf.equal(masks, 0), paddings, outputs) # (h*N, T_q, T_k)
将未来数据的权重设置为⽆穷⼩,以达到在训练过程中不关注未来数据的作⽤。也就是⽣成第⼀个词时关注第0个token,⽣成第⼆个词时关注第0及第1个token,如上表格所⽰。
⽽在vanilla attention中设置causality=False关注源数据的所有token。
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。
发表评论