4关于word2vec的skip-gram模型使⽤负例采样nce_loss损失函
数的源码剖析
<_loss是word2vec的skip-gram模型的负例采样⽅式的函数,下⾯分析其源代码。
1 上下⽂代码
loss = tf.reduce_mean(
<_loss(weights=nce_weights,
biases=nce_biases,
labels=train_labels,
inputs=embed,
num_sampled=num_sampled,
num_classes=vocabulary_size))
其中,
train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
embeddings = tf.Variable(
tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
embed = bedding_lookup(embeddings, train_inputs)
train_inputs中的就是中⼼词,train_label中的就是语料库中该中⼼词在滑动窗⼝内的上下⽂词。
所以,train_inputs中会有连续n-1(n为滑动窗⼝⼤⼩)个元素是相同的。即同⼀中⼼词。
embddings是词嵌⼊,就是要学习的词向量的存储矩阵。共有词汇表⼤⼩的⾏数,每⼀⾏对应⼀个词的向量。
# Construct the variables for the NCE loss
nce_weights = tf.Variable(
stddev=1.0 / math.sqrt(embedding_size)))
nce_biases = tf.s([vocabulary_size]))
nce_weights就是⽤来存储如下负例采样公式中的
、
sigmoid函数有⼀个对称特性:
故⽽上⾯的公式中,就没有出现1-XX的形式。⽤1-XX的形式,可能会更好理解。
具体解释如下:
l #train_inputs中是中⼼词的单词编号,就是词汇表中对该单词的⼀个编号,⼀般按词频排列,⽤顺序进⾏编号。
l #train_labels中是中⼼词的上下⽂中的单次编号,这些都算是正样本,注意和机器学习中的正样本的意思不⼀样,这⾥是做正确答案的意思。
l #embedding_lookup就是取出某⼀⾏。下标从0开始。
l #tf.truncated_normal从截断的正态分布中输出随机值。#⽣成的值服从具有指定平均值和标准偏差的正态分布,如果⽣成的值⼤于平均值2个标准偏差的值则丢弃重新选择。#标准差就是标准偏差,是⽅差的算术平均根。⽽上⾯的代码中对标准⽅差进⾏了限制的原因就是为了防⽌神经⽹络的参数过⼤。为什么embeddings中的参数没有进⾏限制呢?是因为最初初始化的时候,所有的词的词向量之间要保证⼀定的距离。然后通过学习,才能拉近某些词的关系,使得某些词的词向量更加接近。
l #因为是单层神经⽹络,所以要限制参数过⼤。如果是深层神经⽹络,就不需要标准差除⼀⼀个embedding_size的平⽅根了。深层神经⽹络虽然也要进⾏参数的正则化限制,防⽌过拟合和梯度爆炸问
题,但是很少看见,有直接对stddev进⾏限制的。
2 nce_loss源码
def nce_loss(weights,
biases,
labels,
inputs,
num_sampled,
num_classes,
num_true=1,
sampled_values=None,
remove_accidental_hits=False,
partition_strategy="mod",
name="nce_loss"):
logits, labels = _compute_sampled_logits(
weights=weights,
biases=biases,
labels=labels,
inputs=inputs,
num_sampled=num_sampled,
num_classes=num_classes,
num_true=num_true,
sampled_values=sampled_values,
subtract_log_q=True,
remove_accidental_hits=remove_accidental_hits,
partition_strategy=partition_strategy,
name=name)
sampled_losses = sigmoid_cross_entropy_with_logits(
labels=labels, logits=logits, name="sampled_losses")
# sampled_losses is batch_size x {true_loss, }
# We sum out true and sampled losses.
return _sum_rows(sampled_losses)
可以看出核⼼就在于传⼊sigmoid_cross_entropy_with_logits的参数。对于任何⼀个输出节点只有⼀个的⼆分类神经⽹络,⽤sigmoid_cross_entropy_with_logits是最好理解的。logits的维度是batch_size,1。labels的维度就是batch_size,元素取值为0或者1, 来看⼀下sigmoid_cross_entropy_with_logits函
数
sigmoid_cross_entropy_with_logits的返回值是:
Returns:
A `Tensor` of the same shape as `logits` with the componentwise
logistic losses.
也就是说:logits的维度是batch_size,1,其返回的维度也是batch_size,1。这个位置的元素就是⽤这个公式计算的loss:
但是在负例采样中,传⼊的logits的维度不是batch_size,1,⽽是[batch_size, num_true + num_sampled]`。主要观察⼀下
_compute_sampled_logits函数的输出。其输出如下:
Returns:
out_logits: `Tensor` object with shape
`[batch_size, num_true + num_sampled]`, for passing to either
`nn.sigmoid_cross_entropy_with_logits` (NCE) or
`nn.softmax_cross_entropy_with_logits` (sampled softmax).
out_labels: A Tensor object with the same shape as `out_logits`.
"""
其传⼊参数的解释是:
labels: A `Tensor` of type `int64` and shape `[batch_size,
num_true]`. The target classes. Note that this format differs from
the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
objects whose concatenation along dimension 0 has shape
`[num_classes, dim]`. The (possibly-partitioned) class embeddings.
可以看出_compute_sampled_logits完成的是⼀个什么过程呢。就是对于每⼀个样本,计算出⼀个维度为[batch_size, num_true +
num_sampled]的向量,向量的每个元素都同之前logits的每个元素的意义⼀样,是输出值。同时,返回⼀个维度为[batch_size, num_true + num_sampled]的向量labels。这个labels中只有⼀个元素为1。于是再看⼀下如下公式:
其实,此时的out_logits中对应(label位置为0)的元素就是,对应label位置为1)的元素就是。
然后再传给sigmoid_cross_entropy_with_logits,同样是对于每个元素位置的计算使⽤下⾯的公式:
所以,nce_loss中调⽤sigmoid_cross_entropy_with_logits后返回的是:[batch_size, num_true + num_sampled]的向量,其中每个元素都是⼀个⽤上述公式计算出loss。
nce_loss的最后⼀步是_sum_rows:
def _sum_rows(x):
"""Returns a vector summing up each row of the matrix x."""
# _sum_rows(x) is equivalent to duce_sum(x, 1) when x is
# a matrix. The gradient of _sum_rows(x) is more efficient than
# reduce_sum(x, 1)'s gradient in today's implementation. Therefore,
# we use _sum_rows(x) in the nce_loss() computation since the loss
# is mostly used for training.
cols = array_ops.shape(x)[1]
ones_shape = array_ops.stack([cols, 1])
ones = s(ones_shape, x.dtype)
return shape(math_ops.matmul(x, ones), [-1])
最后,再对nce_loss的返回结果⽤reduce_mean即可计算⼀个batch的平均损失。
关于_compute_sampled_logits中如何采样,如何计算的,这⾥就不再阐述,同⽂字理论是⼀样的。
我们将_compute_sampled_logits函数中的
# Construct output logits and labels. The true labels/logits start at col 0.
out_logits = at([true_logits, sampled_logits], 1)
# true_logits is a float tensor, ones_like(true_logits) is a float
# tensor of ones. We then divide by num_true to ensure the per-example
# labels sum to 1.0, i.e. form a proper probability distribution.
out_labels = at([
s_like(true_logits) / num_true,
s_like(sampled_logits)
], 1)
改为
out_logits = at([true_logits, sampled_logits], 1,name="xiaojie_logits")
# true_logits is a float tensor, ones_like(true_logits) is a float
# tensor of ones. We then divide by num_true to ensure the per-example
# labels sum to 1.0, i.e. form a proper probability distribution.
out_labels = at([
truncated模型用什么软件
s_like(true_logits) / num_true,
s_like(sampled_logits)
], 1,name="xiaojie_labels")
然后由于这些代码位于:
with ops.name_scope(name, "compute_sampled_logits",
weights + [biases, inputs, labels]):
ops指定的name下,name为“nce_loss”
我们在word2vec的程序训练迭代的过程中添加如下代码:
for step in range(num_steps):
batch_inputs, batch_labels = generate_batch(
batch_size, num_skips, skip_window)
feed_dict = {train_inputs : batch_inputs, train_labels : batch_labels}
print ("xiaojie Debug:")
xiaojie_logits= _tensor_by_name("nce_loss/xiaojie_logits:0")
xiaojie_labels = _tensor_by_name("nce_loss/xiaojie_labels:0")
xiaojie_logits_value,xiaojie_labels_value=session.run([xiaojie_logits,xiaojie_labels],feed_dict=feed_dict) print (xiaojie_logits_value,xiaojie_labels_value)
可以看出输出结果中传递给sigmoid_cross_entropy_with_logits函数的就是这么个玩意。
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。
发表评论