Attention Mechanisms: A PyTorch Implementation
Queries, Keys, and Values
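Queries, keys, and values are the three basic building blocks of the attention mechanism. The attention scoring function is the main way attention is built from them, and variants such as additive attention, dot-product attention, soft and hard attention, and multi-head attention all combine these three quantities through different scoring functions.
Queries: volitional cues; they tell the model what to pay attention to.
Keys: nonvolitional cues; each key is paired with a value and describes the information that is available.
Values: the information itself; the queries are matched against the keys, and the resulting attention weights are used to weight the values, which yields the attention output.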
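The three roles can be made concrete with a tiny sketch of attention pooling. It is only an illustration under stated assumptions: the toy tensors are made up, and a plain dot product stands in for the scoring function. It already follows the score, softmax, weighted-sum pattern that the scoring-function steps describe.

```python
import torch
from torch.nn import functional as F

# Hypothetical toy data: 1 query, 3 key-value pairs, feature size 2
query = torch.tensor([[1.0, 0.0]])                          # [1, 2]
keys = torch.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])  # [3, 2]
values = torch.tensor([[10.0], [20.0], [30.0]])             # [3, 1]

# Step 1: score every query-key pair (a plain dot product, chosen only for illustration)
scores = query @ keys.T                 # [1, 3]
# Step 2: softmax over the keys turns the scores into weights that sum to 1
weights = F.softmax(scores, dim=-1)     # [1, 3]
# Step 3: the output is the weighted sum of the values
output = weights @ values               # [1, 1]
print(weights, output)
```

The first key points in the same direction as the query, so it receives the largest weight and the output is pulled toward its value, 10.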
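Attention scoring function: the scoring function models the relationship between a query and the keys, producing a weight for each key. The basic steps are:
1. Compute a score for each query-key pair with some function of the query and the key.
2. Feed the scores into a softmax to obtain the attention weights.
3. Use these weights to take a weighted sum of the values.
Additive attention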
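When the queries and keys are vectors of different lengths, additive attention is commonly used as the scoring function. Given a query $\mathbf{q} \in \mathbb{R}^q$ and a key $\mathbf{k} \in \mathbb{R}^k$, the additive attention score is

$$a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k}) \in \mathbb{R},$$

where $\mathbf{W}_q \in \mathbb{R}^{h \times q}$, $\mathbf{W}_k \in \mathbb{R}^{h \times k}$, and $\mathbf{w}_v \in \mathbb{R}^h$ are learnable parameters. In other words, the score is computed by a multilayer perceptron with one hidden layer whose number of hidden units h is a hyperparameter (num_hiddens in the code). The code is as follows: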
import torch
from torch import nn
from torch.nn import functional as F

class AdditiveAttention(nn.Module):
    def __init__(self, keys_size, queries_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_q = nn.Linear(queries_size, num_hiddens, bias=False)
        self.W_k = nn.Linear(keys_size, num_hiddens, bias=False)
        self.W_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # queries --> [batch_size, queries_length, num_hiddens]
        # keys    --> [batch_size, keys_length, num_hiddens]
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        # queries.unsqueeze(2) --> [batch_size, queries_length, 1, num_hiddens]
        # keys.unsqueeze(1)    --> [batch_size, 1, keys_length, num_hiddens]
        # features (broadcast) --> [batch_size, queries_length, keys_length, num_hiddens]
        features = torch.tanh(features)
        scores = self.W_v(features).squeeze(-1)
        # self.W_v(features) --> [batch_size, queries_length, keys_length, 1]
        # scores             --> [batch_size, queries_length, keys_length]
        # Normalize over the keys (last dimension) so that each query's weights sum to 1
        self.attention_weights = F.softmax(scores, dim=-1)
        # self.attention_weights --> [batch_size, queries_length, keys_length]
        # output --> [batch_size, queries_length, value_features_num]
        return torch.bmm(self.dropout(self.attention_weights), values)

### Example test ###
queries, keys = torch.normal(0, 1, (2, 2, 20)), torch.ones((2, 10, 2))
# The two value matrices in the `values` minibatch are identical
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
attention = AdditiveAttention(
    keys_size=2, queries_size=20, num_hiddens=8, dropout=0.1)
attention.eval()
output = attention(queries, keys, values)
'''
output (shape [2, 2, 4]):
All keys are identical, so each query gives the same score to all 10 keys,
the softmax weights are uniformly 0.1, and every output row is the mean of
the value rows: [18., 19., 20., 21.] (up to floating-point rounding).
'''
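To connect the class with the scoring formula, the sketch below computes $a(\mathbf{q}, \mathbf{k})$ once for every query-key pair in an explicit loop and once with the broadcast used in forward, then checks that they agree. The matrices W_q, W_k, and w_v are hypothetical random stand-ins for the nn.Linear weights (nn.Linear without bias computes x @ weight.T, so the vectorized line mirrors the module), and the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
batch, n_q, n_k, q_size, k_size, h = 2, 3, 5, 4, 6, 8   # arbitrary sizes

# Hypothetical random parameters standing in for W_q, W_k and w_v
W_q = torch.randn(h, q_size)
W_k = torch.randn(h, k_size)
w_v = torch.randn(h)
queries = torch.randn(batch, n_q, q_size)
keys = torch.randn(batch, n_k, k_size)

# Vectorized version: project, broadcast-add, tanh, then project to a scalar
features = (queries @ W_q.T).unsqueeze(2) + (keys @ W_k.T).unsqueeze(1)  # [batch, n_q, n_k, h]
scores_vec = torch.tanh(features) @ w_v                                  # [batch, n_q, n_k]

# Reference version: apply a(q, k) = w_v^T tanh(W_q q + W_k k) to each pair separately
scores_loop = torch.empty(batch, n_q, n_k)
for b in range(batch):
    for i in range(n_q):
        for j in range(n_k):
            scores_loop[b, i, j] = w_v @ torch.tanh(W_q @ queries[b, i] + W_k @ keys[b, j])

weights = torch.softmax(scores_vec, dim=-1)  # normalize over the keys
print(torch.allclose(scores_vec, scores_loop, atol=1e-5))
print(weights.sum(dim=-1))  # every entry is 1 up to rounding
```

Normalizing with dim=-1 is what makes each query's weights over the keys sum to 1; normalizing over another dimension would mix weights across different queries.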
Dimension transformations in the algorithm
The forward pass relies on broadcasting of multi-dimensional tensors, explained below.
How broadcasting is applied:
queries: [batch_size, queries_len, num_hiddens]
--> [batch_size, queries_len, 1, num_hiddens]
keys: [batch_size, keys_len, num_hiddens]
--> [batch_size, 1, keys_len, num_hiddens]
queries + keys is then evaluated with broadcasting, and the result has shape [batch_size, queries_len, keys_len, num_hiddens].
The addition is equivalent to:
features_list = [keys[:, 0, :, :] + queries[:, i, :, :] for i in range(queries_len)]
features = torch.stack(features_list, dim=1)
In words: along the second dimension, take each query vector in turn and add it to every row of the keys matrix.
A small example:
aa = torch.arange(12).reshape(1,1,4,3)  # plays the role of keys.unsqueeze(1)
'''output:
tensor([[[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]]]])'''
bb = torch.arange(6).reshape(1,2,1,3)  # plays the role of queries.unsqueeze(2)
'''output:
tensor([[[[0, 1, 2]],
[[3, 4, 5]]]])'''
aa + bb
'''output:
tensor([[[[ 0, 2, 4],
[ 3, 5, 7],
[ 6, 8, 10],
[ 9, 11, 13]],
[[ 3, 5, 7],
[ 6, 8, 10],
[ 9, 11, 13],
[12, 14, 16]]]])'''
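As a quick check of the equivalence described above, the sketch below rebuilds aa + bb with the explicit per-query loop and compares the two tensors element by element.

```python
import torch

aa = torch.arange(12).reshape(1, 1, 4, 3)   # plays the role of keys.unsqueeze(1)
bb = torch.arange(6).reshape(1, 2, 1, 3)    # plays the role of queries.unsqueeze(2)

broadcast = aa + bb                          # [1, 2, 4, 3]
# Explicit version: add each query row (index i along dim 1) to every key row, then stack
looped = torch.stack([aa[:, 0, :, :] + bb[:, i, :, :] for i in range(bb.shape[1])], dim=1)
print(broadcast.shape, torch.equal(broadcast, looped))  # torch.Size([1, 2, 4, 3]) True
```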
Scaled dot-product attention
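A dot product gives a more computationally efficient scoring function, but it requires the queries and keys to have the same length d, i.e., the same number of features. The scoring function is

$$a(\mathbf{q}, \mathbf{k}) = \frac{\mathbf{q}^\top \mathbf{k}}{\sqrt{d}}.$$

The division by $\sqrt{d}$ keeps the scores from growing with d: if the components of $\mathbf{q}$ and $\mathbf{k}$ are independent with zero mean and unit variance, $\mathbf{q}^\top \mathbf{k}$ has variance d, so the scaled score has variance 1. When attention is computed for n queries and m key-value pairs, where queries and keys have length d and values have length v, the scaled dot-product attention of queries $\mathbf{Q} \in \mathbb{R}^{n \times d}$, keys $\mathbf{K} \in \mathbb{R}^{m \times d}$, and values $\mathbf{V} \in \mathbb{R}^{m \times v}$ is

$$\mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{V} \in \mathbb{R}^{n \times v}.$$

The code is as follows: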
import math
import torch
from torch import nn
from torch.nn import functional as F

class DotProductAttention(nn.Module):
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values):
        '''
        queries --> [batch_size, queries_length, queries_features_num]
        keys    --> [batch_size, keys_values_length, keys_features_num]
        values  --> [batch_size, keys_values_length, values_features_num]
        In the dot-product model: queries_features_num = keys_features_num
        '''
        d = queries.shape[-1]
        # Swap the last two dimensions of keys; this is the transpose K^T in the formula
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        # Normalize over the keys (last dimension)
        self.attention_weights = F.softmax(scores, dim=-1)
        return torch.bmm(self.dropout(self.attention_weights), values)

# Reuses `keys` and `values` from the additive attention example above
queries = torch.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
dot_output = attention(queries, keys, values)
print(dot_output)
'''
dot_output (shape [2, 1, 4]):
All keys are identical, so the weights are again uniform over the 10 keys
and each row is [18., 19., 20., 21.] (up to floating-point rounding).
'''
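The class above works on batched 3-D tensors. The sketch below writes the formula directly for the unbatched shapes used in the text ($\mathbf{Q}$: n×d, $\mathbf{K}$: m×d, $\mathbf{V}$: m×v); the function name scaled_dot_product and the sizes are illustrative assumptions. It cross-checks the result against torch.nn.functional.scaled_dot_product_attention, which exists only in PyTorch 2.0 and later, so the comparison is skipped on older versions.

```python
import math
import torch
from torch.nn import functional as F

def scaled_dot_product(Q, K, V):
    """softmax(Q K^T / sqrt(d)) V for Q: [n, d], K: [m, d], V: [m, v]."""
    d = Q.shape[-1]
    weights = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(d), dim=-1)  # [n, m]
    return weights @ V                                                       # [n, v]

torch.manual_seed(0)
n, m, d, v = 3, 5, 4, 6   # arbitrary sizes for illustration
Q, K, V = torch.randn(n, d), torch.randn(m, d), torch.randn(m, v)
out = scaled_dot_product(Q, K, V)
print(out.shape)  # torch.Size([3, 6])

# PyTorch 2.0+ provides a built-in version (default scale 1/sqrt(d), no dropout);
# add a batch dimension to match its [batch, length, features] layout
if hasattr(F, "scaled_dot_product_attention"):
    ref = F.scaled_dot_product_attention(Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0))[0]
    print(torch.allclose(out, ref, atol=1e-5))
```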
Applications of the attention mechanism
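The attention mechanism is often combined with sequence-to-sequence (seq2seq) models. The corresponding queries, keys, and values are:
Keys and values: the encoder's final-layer hidden states at all time steps, which form the key-value pairs.
Queries: at each decoding time step, the decoder's final-layer hidden state from the previous time step serves as the query.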
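The basic workflow of sequence-to-sequence with attention is:
1. Feed the input sequence into the encoder to obtain the final-layer hidden output at every time step and the hidden state at the last time step.
2. Use the encoder's per-time-step hidden outputs [batch_size, time_step, num_hiddens] as both keys and values, and use the final-layer hidden state [batch_size, 1, num_hiddens] as the query (initially the encoder's last hidden state, afterwards the decoder's updated hidden state).
3. Apply attention to the keys, values, and query to obtain the context: context --> [batch_size, query_length=1, num_hiddens].
4. Concatenate the context with the input x of the current single time step along the feature dimension and feed the result into the recurrent network.
5. Use the hidden_state returned by the recurrent network to update the loop's hidden_state, then process the x of the next time step.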
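One iteration of steps 2 to 5 can be sketched with random tensors to make the shapes concrete. The encoder outputs and hidden state here are random placeholders standing in for step 1, the sizes are illustrative, and to keep the sketch short it scores with scaled dot-product attention instead of the additive attention the decoder below uses.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch, src_len, feat, hid, layers = 4, 8, 10, 20, 2   # illustrative sizes

enc_outputs = torch.randn(batch, src_len, hid)   # keys and values: [batch, time_step, hid]
hidden_state = torch.randn(layers, batch, hid)   # [num_layers, batch, hid]
x = torch.randn(batch, feat)                     # decoder input at one time step
rnn = nn.GRU(feat + hid, hid, layers)            # input layout [time, batch, features]

# Step 2: the last layer of the hidden state is the query, [batch, 1, hid]
query = hidden_state[-1].unsqueeze(1)
# Step 3: context via attention (dot-product scoring, chosen only for brevity)
scores = torch.bmm(query, enc_outputs.transpose(1, 2)) / hid ** 0.5   # [batch, 1, src_len]
context = torch.bmm(torch.softmax(scores, dim=-1), enc_outputs)       # [batch, 1, hid]
# Step 4: concatenate context and x along the feature dimension
rnn_in = torch.cat((context, x.unsqueeze(1)), dim=-1)                 # [batch, 1, hid + feat]
# Step 5: one GRU step; its returned hidden state becomes the next query source
out, hidden_state = rnn(rnn_in.permute(1, 0, 2), hidden_state)
print(context.shape, out.shape, hidden_state.shape)
```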
A sequence-to-sequence model has two parts, an encoder and a decoder, and the attention mechanism is added to the decoder. First define the encoder:
import torch
from torch import nn

class Encoder(nn.Module):
    def __init__(self, inputs_dim, num_hiddens, hiddens_layers):
        super(Encoder, self).__init__()
        self.rnn = nn.GRU(
            input_size=inputs_dim, hidden_size=num_hiddens,
            num_layers=hiddens_layers)

    def forward(self, inputs):
        '''nn.GRU is created without batch_first=True,
        so its input layout is [time_step_num, batch_size, num_features].
        Output shapes:
            output: [time_step_num, batch_size, num_hiddens]
            hidSta: [num_layers, batch_size, num_hiddens]
        '''
        # [batch_size, time_step_num, num_features] --> [time_step_num, batch_size, num_features]
        inputs = inputs.permute(1, 0, 2)
        encOut, hidSta = self.rnn(inputs)
        return encOut, hidSta
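The shape comments in Encoder.forward can be checked directly against nn.GRU. The sketch below uses illustrative sizes and also shows that batch_first=True, which the code above does not use, accepts [batch, time, features] input without the permute, while the hidden state keeps its [num_layers, batch, hidden] layout. Both GRUs share the same weights, so their outputs match.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch, steps, feat, hid, layers = 4, 8, 10, 20, 2   # illustrative sizes
inputs = torch.randn(batch, steps, feat)

# Default layout (batch_first=False): permute to [steps, batch, feat] first
gru = nn.GRU(input_size=feat, hidden_size=hid, num_layers=layers)
out, h = gru(inputs.permute(1, 0, 2))
print(out.shape, h.shape)  # torch.Size([8, 4, 20]) torch.Size([2, 4, 20])

# Same weights with batch_first=True: no permute, output is [batch, steps, hid]
gru_bf = nn.GRU(input_size=feat, hidden_size=hid, num_layers=layers, batch_first=True)
gru_bf.load_state_dict(gru.state_dict())
out_bf, h_bf = gru_bf(inputs)
print(torch.allclose(out_bf, out.permute(1, 0, 2), atol=1e-6), torch.allclose(h_bf, h, atol=1e-6))
```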
class AttentionDecoder(nn.Module):
    def __init__(
            self, inputs_dim, num_hiddens, num_layers, outputs_dim, dropout):
        super(AttentionDecoder, self).__init__()
        self.attention = AdditiveAttention(
            num_hiddens, num_hiddens, num_hiddens, dropout)
        self.rnn = nn.GRU(
            inputs_dim + num_hiddens, num_hiddens, num_layers,
            dropout=dropout)
        self.dense = nn.Linear(num_hiddens, outputs_dim)

    def forward(self, inputs, states):
        '''
        inputs: [batch_size, time_step_num, features]
        states: (enc_outputs, enc_hidden_state), as returned by Encoder
        '''
        enc_outputs, hidden_state = states
        # Reorder enc_outputs to [batch_size, time_step_num, enc_hidden_num]
        enc_outputs = enc_outputs.permute(1, 0, 2)
        # Reorder inputs to [time_step_num, batch_size, features_num]
        inputs = inputs.permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        # Process the inputs one time step at a time and fuse each step with the context
        for x in inputs:
            # The last layer of hidden_state is the query; add a dimension at position 1:
            # hidden_state[-1]: [batch_size, enc_hidden_num]
            #   --> [batch_size, 1, enc_hidden_num]
            query = hidden_state[-1].unsqueeze(dim=1)
            # context: [batch_size, query_length=1, num_hiddens]
            context = self.attention(query, enc_outputs, enc_outputs)
            # x: [batch_size, features] --> [batch_size, 1, num_hiddens + features]
            x = torch.cat((context, x.unsqueeze(dim=1)), dim=-1)
            # The GRU expects [1, batch_size, ...] and returns the updated hidden_state
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        # [time_step_num, batch_size, num_hiddens] --> [time_step_num, batch_size, outputs_dim]
        outputs = self.dense(torch.cat(outputs, dim=0))
        # Return enc_outputs in the same [time_step_num, batch_size, ...] layout it arrived in
        return outputs.permute(1, 0, 2), [enc_outputs.permute(1, 0, 2), hidden_state]
### Example ###
encoder = Encoder(inputs_dim=10, num_hiddens=20, hiddens_layers=2)
decoder = AttentionDecoder(
    inputs_dim=10, num_hiddens=20, num_layers=2, outputs_dim=8, dropout=0.1)
inputs = torch.normal(0, 1, (4, 8, 10))
state = encoder(inputs)
dec_inputs = torch.normal(0, 1, (4, 1, 10))
dec_output, state = decoder(dec_inputs, state)
print(dec_output.shape)
'''
output:
torch.Size([4, 1, 8])
'''
Combining the encoder and decoder
A wrapper class can combine the encoder and the decoder:
class EncoderDecoder(nn.Module):
    """The base class for the encoder-decoder architecture."""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        state = self.encoder(enc_X, *args)
        # Returns (decoder outputs, [enc_outputs, hidden_state])
        dec_state = self.decoder(dec_X, state)
        return dec_state

net = EncoderDecoder(encoder, decoder)
output = net(inputs, dec_inputs)
print(output[0].shape)  # --> torch.Size([4, 1, 8])