网站记登录账号怎么做,建个企业营销型网站,室内设计效果图怎么做出来的,昌吉建设网站为什么 nn.CrossEntropyLoss LogSoftmax nn.NLLLoss#xff1f;
在使用 PyTorch 时#xff0c;我们经常听说 nn.CrossEntropyLoss 是 LogSoftmax 和 nn.NLLLoss 的组合。这句话听起来简单#xff0c;但背后到底是怎么回事#xff1f;为什么这两个分开的功能加起来就等于…为什么 nn.CrossEntropyLoss LogSoftmax nn.NLLLoss
在使用 PyTorch 时我们经常听说 nn.CrossEntropyLoss 是 LogSoftmax 和 nn.NLLLoss 的组合。这句话听起来简单但背后到底是怎么回事为什么这两个分开的功能加起来就等于一个完整的交叉熵损失今天我们就从数学公式到代码实现彻底搞清楚它们的联系。
1. 先认识三个主角
要理解这个等式先得知道每个部分的定义和作用
nn.CrossEntropyLoss交叉熵损失直接接受未归一化的 logits计算模型预测与真实标签的差距适用于多分类任务。LogSoftmax将 logits 转为对数概率log probabilities输出范围是负值。nn.NLLLoss负对数似然损失接受对数概率计算正确类别的负对数值。
表面上看nn.CrossEntropyLoss 是一个独立的损失函数而 LogSoftmax 和 nn.NLLLoss 是两步操作。为什么说它们本质上是一回事呢答案藏在数学公式和计算逻辑里。
2. 数学上的拆解
让我们从交叉熵的定义开始逐步推导。
(1) 交叉熵的数学形式
交叉熵Cross-Entropy衡量两个概率分布的差异。在多分类任务中
( p p p )真实分布通常是 one-hot 编码比如 [0, 1, 0] 表示第 1 类。( q q q )预测分布是模型输出的概率比如 [0.2, 0.5, 0.3]。
交叉熵公式为 H ( p , q ) − ∑ c 1 C p c log ( q c ) H(p, q) -\sum_{c1}^{C} p_c \log(q_c) H(p,q)−c1∑Cpclog(qc)
对于 one-hot 编码( p c p_c pc ) 在正确类别上为 1其他为 0所以简化为 H ( p , q ) − log ( q correct ) H(p, q) -\log(q_{\text{correct}}) H(p,q)−log(qcorrect)
其中 ( q correct q_{\text{correct}} qcorrect ) 是正确类别对应的预测概率。对 ( N N N ) 个样本取平均损失为 Loss − 1 N ∑ i 1 N log ( q i , y i ) \text{Loss} -\frac{1}{N} \sum_{i1}^{N} \log(q_{i, y_i}) Loss−N1i1∑Nlog(qi,yi)
这正是交叉熵损失的核心。
(2) 从 logits 到概率
神经网络输出的是原始分数logits比如 ( z [ z 1 , z 2 , z 3 ] z [z_1, z_2, z_3] z[z1,z2,z3] )。要得到概率 ( q q q )需要用 Softmax q j e z j ∑ k 1 C e z k q_j \frac{e^{z_j}}{\sum_{k1}^{C} e^{z_k}} qj∑k1Cezkezj
交叉熵损失变成 Loss − 1 N ∑ i 1 N log ( e z i , y i ∑ k 1 C e z i , k ) \text{Loss} -\frac{1}{N} \sum_{i1}^{N} \log\left(\frac{e^{z_{i, y_i}}}{\sum_{k1}^{C} e^{z_{i,k}}}\right) Loss−N1i1∑Nlog(∑k1Cezi,kezi,yi)
这就是 nn.CrossEntropyLoss 的数学形式。
(3) 分解为两步
现在我们把这个公式拆开 第一步LogSoftmax 计算对数概率 log ( q j ) log ( e z j ∑ k 1 C e z k ) z j − log ( ∑ k 1 C e z k ) \log(q_j) \log\left(\frac{e^{z_j}}{\sum_{k1}^{C} e^{z_k}}\right) z_j - \log\left(\sum_{k1}^{C} e^{z_k}\right) log(qj)log(∑k1Cezkezj)zj−log(k1∑Cezk) 这正是 LogSoftmax 的定义。它把 logits ( z z z ) 转为对数概率 ( log ( q ) \log(q) log(q) )。 第二步NLLLoss 有了对数概率 ( log ( q ) \log(q) log(q) )取出正确类别的值取负号并平均 NLL − 1 N ∑ i 1 N log ( q i , y i ) \text{NLL} -\frac{1}{N} \sum_{i1}^{N} \log(q_{i, y_i}) NLL−N1i1∑Nlog(qi,yi) 这就是 nn.NLLLoss 的公式。
组合起来
LogSoftmax 把 logits 转为 ( log ( q ) \log(q) log(q) )。nn.NLLLoss 对 ( log ( q ) \log(q) log(q) ) 取负号计算损失。两步合起来正好是 ( − log ( q correct ) -\log(q_{\text{correct}}) −log(qcorrect) )与交叉熵一致。
3. PyTorch 中的实现验证
从数学上看nn.CrossEntropyLoss 的确可以分解为 LogSoftmax 和 nn.NLLLoss。我们用代码验证一下
import torch
import torch.nn as nn# 输入数据
logits torch.tensor([[1.0, 2.0, 0.5], [0.1, 0.5, 2.0]]) # [batch_size, num_classes]
target torch.tensor([1, 2]) # 真实类别索引# 方法 1直接用 nn.CrossEntropyLoss
ce_loss_fn nn.CrossEntropyLoss()
ce_loss ce_loss_fn(logits, target)
print(CrossEntropyLoss:, ce_loss.item())# 方法 2LogSoftmax nn.NLLLoss
log_softmax nn.LogSoftmax(dim1)
nll_loss_fn nn.NLLLoss()
log_probs log_softmax(logits) # 计算对数概率
nll_loss nll_loss_fn(log_probs, target)
print(LogSoftmax NLLLoss:, nll_loss.item())运行结果两个输出的值完全相同比如 0.75。这证明 nn.CrossEntropyLoss 在内部就是先做 LogSoftmax再做 nn.NLLLoss。
4. 为什么 PyTorch 这么设计
既然 nn.CrossEntropyLoss 等价于 LogSoftmax nn.NLLLoss为什么 PyTorch 提供了两种方式 便利性 nn.CrossEntropyLoss 是一个“一体式”工具直接输入 logits 就能用适合大多数场景省去手动搭配的麻烦。 模块化 LogSoftmax 和 nn.NLLLoss 分开设计给开发者更多灵活性 你可以在模型里加 LogSoftmax只用 nn.NLLLoss 计算损失。可以单独调试对数概率比如打印 log_probs。在某些自定义损失中可能需要用到独立的 LogSoftmax。 数值稳定性 nn.CrossEntropyLoss 内部优化了计算避免了分开操作时可能出现的溢出问题比如 logits 很大时Softmax 的分母溢出。
5. 为什么不直接用 Softmax
你可能好奇为什么不用 Softmax 对数 取负而是用 LogSoftmax 答案是数值稳定性
单独计算 Softmax指数运算可能导致溢出比如 ( e 1000 e^{1000} e1000 )。LogSoftmax 把指数和对数合并为 ( z j − log ( ∑ e z k ) z_j - \log(\sum e^{z_k}) zj−log(∑ezk) )计算更稳定。
6. 使用场景对比 nn.CrossEntropyLoss 输入logits。场景标准多分类任务图像分类、文本分类。优点简单直接。 LogSoftmax nn.NLLLoss 输入logits 需手动转为对数概率。场景需要显式控制 Softmax或者模型已输出对数概率。优点灵活性高。
7. 小结为什么等价
数学上交叉熵 ( − log ( q correct ) -\log(q_{\text{correct}}) −log(qcorrect) ) 可以拆成两步 LogSoftmax从 logits 到 ( log ( q ) \log(q) log(q) )。nn.NLLLoss从 ( log ( q ) \log(q) log(q) ) 到 ( − log ( q correct ) -\log(q_{\text{correct}}) −log(qcorrect) )。 实现上nn.CrossEntropyLoss 把这两步封装成一个函数结果一致。设计上PyTorch 提供两种方式满足不同需求。
所以nn.CrossEntropyLoss LogSoftmax nn.NLLLoss 不是巧合而是交叉熵计算的自然分解。理解这一点能帮助你更灵活地使用 PyTorch 的损失函数。
8. 彩蛋手动推导
想自己验证试试手动计算
logits [1.0, 2.0, 0.5]目标是 1。Softmax[0.23, 0.63, 0.14]。LogSoftmax[-1.47, -0.47, -1.97]。NLL-(-0.47) 0.47。直接用 nn.CrossEntropyLoss结果一样
希望这篇博客解开了你的疑惑
后记
2025年2月28日18点51分于上海在grok3 大模型辅助下完成。