全球建站,wordpress html5 音乐,食品网站建设方案书,百度关键词下拉有什么软件引言
在深度学习中的分类问题中#xff0c;类别不平衡问题是常见的挑战之一。尤其在面部表情分类任务中#xff0c;不同表情类别的样本数量可能差异较大#xff0c;比如“开心”表情的样本远远多于“生气”表情。面对这种情况#xff0c;普通的交叉熵损失函数容易导致模型…引言
在深度学习中的分类问题中类别不平衡问题是常见的挑战之一。尤其在面部表情分类任务中不同表情类别的样本数量可能差异较大比如“开心”表情的样本远远多于“生气”表情。面对这种情况普通的交叉熵损失函数容易导致模型过拟合到大类样本忽略少数类样本。为了有效解决类别不平衡问题Class-balanced Exponential Focal Loss (CEFL) 和 Class-balanced Exponential Focal Loss 2 (CEFL2) 损失函数应运而生。
本文将详细介绍CEFL和CEFL2损失函数阐述它们在面部表情分类任务中的应用并提供PyTorch实现代码带有详细注释适合开发者在实际项目中使用。 目录 引言一、CEFL 和 CEFL2 损失函数概述1.1 Focal Loss 的背景1.2 CEFL 的定义1.3 CEFL2 的扩展与改进1.4 对比 CEFL 和 CEFL2 二、面部表情分类中的类别不平衡问题2.1 类别不平衡对模型训练的影响2.2 解决策略 三、如何使用 CEFL 和 CEFL2 损失函数3.1 CEFL 和 CEFL2 损失函数的核心公式3.2 类别频率的计算与应用3.3 计算类别频率的代码示例 四、PyTorch 实现4.1 CEFL 实现4.2 CEFL2 实现4.3 训练过程 总结参考文献 一、CEFL 和 CEFL2 损失函数概述
1.1 Focal Loss 的背景
在传统的分类任务中交叉熵损失Cross-Entropy Loss常常用作优化目标。然而交叉熵损失函数并没有很好地解决类别不平衡问题特别是在少数类样本较少时。Focal Loss焦点损失由 Lin et al. (2017) 提出主要用于解决 类别不平衡 问题旨在通过减小容易分类样本的损失权重增强模型对困难样本的关注。
Focal Loss 引入了一个调节因子 ( 1 − p t ) γ (1 - p_t)^\gamma (1−pt)γ通过减小容易分类样本的损失聚焦模型训练中的难分类样本从而引导模型更加关注难以分类的样本尤其在类别不平衡的情形下避免多数类样本主导训练。其公式如下 F L ( p t ) − α t ( 1 − p t ) γ log ( p t ) FL(p_t) -\alpha_t (1 - p_t)^\gamma \log(p_t) FL(pt)−αt(1−pt)γlog(pt)
其中 p t p_t pt 是预测类别的概率。 γ \gamma γ 调节因子是一个常量超参数通常设为大于 0 的值例如 2用于控制易分类样本的惩罚程度。较大的 γ \gamma γ 会增加对难分类样本的关注。当 γ \gamma γ0时 焦点损失在形式上等价于交叉熵损失。 α t \alpha_t αt 是一个用于平衡类别不平衡的权重因子可以根据每个类别的频率进行调整。
1.2 CEFL 的定义
Class-balanced Exponential Focal Loss (CEFL) 是在 Focal Loss 基础上的进一步改进。它通过在焦点损失中引入类别平衡策略赋予每个类别不同的权重从而有效地应对类别不平衡问题。 通过引入类别平衡策略来处理类别不平衡问题。与 Focal Loss 相比CEFL 会根据每个类别的频率赋予不同的权重从而调整损失函数特别是在类别不平衡的情况下更加有效。
CEFL 的公式如下 CEFL ( p t ) − ( 1 − p t ) log ( p t ) − p t ( 1 − p t ) γ log ( p t ) \text{CEFL}(p_t) -(1 - p_t) \log(p_t) - p_t (1 - p_t)^\gamma \log(p_t) CEFL(pt)−(1−pt)log(pt)−pt(1−pt)γlog(pt)
其中 p t p_t pt表示样本属于正确类别的预测概率。 γ \gamma γ( γ \gamma γ0焦点损失的调节因子通常设置为 2用于放大难以分类的样本的损失使得模型更加关注困难的样本。注意当 γ \gamma γ0时 CEFL损失在形式上是交叉熵损失。
公式的第一项是传统的交叉熵损失第二项则是引入焦点损失后的部分用来减小易分类样本的影响权重使得困难样本对总损失的贡献更大从而模型更加专注于难分类的样本。特别地第二项通过 $( (1 - p_t)^\gamma ) $调节了模型对不同难度样本的关注程度。
1.3 CEFL2 的扩展与改进
CEFL2 是对 CEFL 损失函数的扩展它进一步考虑了类别的频率信息通过精细的调整每个类别的损失权重使得模型在极度不平衡的数据集上表现更好。CEFL2 引入了类别频率class frequency作为权重使用每个类别在数据集中出现的频率来调整每个类别的影响。
CEFL2 的公式为 CEFL2 ( p t ) − ( 1 − p t ) 2 ( 1 − p t ) 2 p t 2 log ( p t ) − p t 2 ( 1 − p t ) 2 p t 2 ( 1 − p t ) γ log ( p t ) \text{CEFL2}(p_t) -\frac{(1 - p_t)^2}{(1 - p_t)^2 p_t^2} \log(p_t) - \frac{p_t^2}{(1 - p_t)^2 p_t^2} (1 - p_t)^\gamma \log(p_t) CEFL2(pt)−(1−pt)2pt2(1−pt)2log(pt)−(1−pt)2pt2pt2(1−pt)γlog(pt)
其中
第一个项和第二个项分别对应于不同类别的损失权重和焦点损失的加权贡献。 ( 1 − p t ) 2 ( 1 − p t ) 2 p t 2 \frac{(1 - p_t)^2}{(1 - p_t)^2 p_t^2} (1−pt)2pt2(1−pt)2 和 p t 2 ( 1 − p t ) 2 p t 2 \frac{p_t^2}{(1 - p_t)^2 p_t^2} (1−pt)2pt2pt2是根据类别的频率对损失进行调整的权重项。具体来说它们的比例反映了每个类别相对于整个数据集的频率。
该损失函数通过动态调整类别的权重使得模型对少数类样本的损失更加敏感从而提升对少数类的识别能力。
1.4 对比 CEFL 和 CEFL2
特性CEFLCEFL2核心思想结合焦点损失和类别平衡引入类别频率进一步优化类别平衡类别权重通过 α t \alpha_t αt 设置权重通过类别频率动态调整权重适用场景通用的类别不平衡问题极度不平衡的类别问题主要优点简单有效适合一般类别不平衡问题更适用于处理极端类别不平衡的数据
二、面部表情分类中的类别不平衡问题
2.1 类别不平衡对模型训练的影响
在面部表情分类任务中可能会出现不同表情类别样本不平衡的情况。例如常见表情如“开心”或“惊讶”在数据集中占有大量样本而“生气”或“害怕”等情绪类别可能样本较少。这种类别不平衡将导致模型偏向于大类表情忽视少数类表情从而影响分类性能尤其是对少数类样本的识别。
影响
模型可能会对大类表情有较高的分类准确率而忽视少数类表情。少数类表情样本的训练效果较差难以学到有效的特征表示。
2.2 解决策略
使用 CEFL 或 CEFL2 损失函数可以有效缓解类别不平衡问题在训练过程中让模型更多关注少数类样本从而提升少数类样本的分类效果。
三、如何使用 CEFL 和 CEFL2 损失函数
3.1 CEFL 和 CEFL2 损失函数的核心公式
损失函数公式说明CEFL − ( 1 − p t ) log ( p t ) − p t ( 1 − p t ) γ log ( p t ) -(1 - p_t) \log(p_t) - p_t (1 - p_t)^\gamma \log(p_t) −(1−pt)log(pt)−pt(1−pt)γlog(pt)基于 Focal Loss加入类别权重调整CEFL2 − ( 1 − p t ) 2 ( 1 − p t ) 2 p t 2 log ( p t ) − p t 2 ( 1 − p t ) 2 p t 2 ( 1 − p t ) γ log ( p t ) -\frac{(1 - p_t)^2}{(1 - p_t)^2 p_t^2} \log(p_t) - \frac{p_t^2}{(1 - p_t)^2 p_t^2} (1 - p_t)^\gamma \log(p_t) −(1−pt)2pt2(1−pt)2log(pt)−(1−pt)2pt2pt2(1−pt)γlog(pt)引入类别频率进一步调整损失权重
3.2 类别频率的计算与应用
在 CEFL2 中需要根据训练集中的类别分布计算每个类别的频率。这些频率作为权重在损失函数中进行调整。类别频率的计算公式如下 class_freq t 1 num_samples_in_class t \text{class\_freq}_t \frac{1}{\text{num\_samples\_in\_class}_t} class_freqtnum_samples_in_classt1
随后将类别频率归一化使其和为 1 normalized_class_freq t class_freq t ∑ class_freq \text{normalized\_class\_freq}_t \frac{\text{class\_freq}_t}{\sum \text{class\_freq}} normalized_class_freqt∑class_freqclass_freqt
3.3 计算类别频率的代码示例
import numpy as npdef compute_class_frequencies(targets, num_classes):# 计算每个类别的样本数量class_counts np.bincount(targets.numpy(), minlengthnum_classes)# 防止除零错误计算每个类别的频率class_freq 1.0 / (class_counts 1e-6)# 归一化类别频率class_freq class_freq / np.sum(class_freq)return torch.tensor(class_freq, dtypetorch.float32)四、PyTorch 实现
4.1 CEFL 实现
import torch
import torch.nn as nn
import torch.nn.functional as Fclass CEFL(nn.Module):def __init__(self, alpha, gamma2.0):super(CEFL, self).__init__()self.alpha alpha # 类别的权重self.gamma gamma # 焦点损失的调节参数def forward(self, inputs, targets):# 使用softmax计算类别概率p F.softmax(inputs, dim1)# 选择正确类别的预测概率p_t p.gather(1, targets.view(-1, 1))# 计算损失loss -self.alpha * (1 - p_t) ** self.gamma * torch.log(p_t)return loss.mean()代码解释 类的构造函数 (__init__): alpha: 这是一个超参数用于对各类别的损失加权。它在训练过程中控制类别的重要性。一般来说alpha 用来增加或减少某些类别的损失权重通常在类别不平衡时使用。gamma: 这是焦点损失的调节参数。焦点损失Focal Loss是一种为了解决类别不平衡问题而提出的损失函数gamma 控制模型对易分类样本和难分类样本的关注程度。较大的 gamma 会增加对难分类样本的关注。 forward 方法: inputs: 网络的输出通常是 logits大小为 (batch_size, num_classes)表示每个样本对于每个类别的预测得分。targets: 真实标签大小为 (batch_size,)是样本的正确类别标签。 F.softmax(inputs, dim1): 对模型的输出 logits 进行 softmax 计算将其转化为概率分布。softmax 的作用是将每个样本的所有类别得分转化为一个概率分布概率值的总和为 1。dim1 表示在类别维度上进行归一化即每个样本的类别概率和为 1。 p.gather(1, targets.view(-1, 1)): p 是通过 softmax 得到的类别概率矩阵p.gather(1, targets.view(-1, 1)) 选择每个样本的正确类别的概率。gather(1, targets.view(-1, 1)) 会根据 targets 中给出的标签索引从 p 中提取每个样本对应类别的概率。view(-1, 1) 将 targets 转换为列向量确保正确地索引每个样本的类别。 焦点损失部分: loss -self.alpha * (1 - p_t) ** self.gamma * torch.log(p_t): p_t: 每个样本在正确类别上的预测概率。(1 - p_t) ** self.gamma: 这是焦点损失的核心部分。它会放大模型对难分类样本的关注。对于那些预测较为确定的样本即 p_t 接近 1(1 - p_t) 会较小损失减少对于难分类样本即 p_t 接近 0(1 - p_t) 会较大损失增加。self.alpha: 用于控制类别的重要性。如果某些类别较为不平衡alpha 可以增加这些类别的损失权重。torch.log(p_t): 计算类别概率的对数值通常是交叉熵的一部分。 返回平均损失: loss.mean(): 返回所有样本的平均损失。
4.2 CEFL2 实现
import torch
import torch.nn as nn
import torch.nn.functional as Fclass CEFL2(nn.Module):def __init__(self, class_frequencies, gamma2.0):super(CEFL2, self).__init__()self.class_frequencies class_frequencies # 类别频率self.gamma gamma # 焦点损失的调节参数def forward(self, inputs, targets):# 使用softmax计算类别概率p F.softmax(inputs, dim1)# 选择正确类别的预测概率p_t p.gather(1, targets.view(-1, 1))# 计算每个类别的加权损失loss_term_1 (1 - p_t)**2 / ((1 - p_t)**2 p_t**2) * torch.log(p_t)loss_term_2 p_t**2 / ((1 - p_t)**2 p_t**2) * (1 - p_t)**self.gamma * torch.log(p_t)# 将每个类别的频率作为加权项loss -self.class_frequencies[targets] * (loss_term_1 loss_term_2)return loss.mean()代码解释 类的构造函数 (__init__): class_frequencies: 这是每个类别的频率。通常频率是类别样本的出现概率或样本的加权值。该参数在处理类别不平衡时尤其重要。较少出现的类别会赋予较高的权重以便模型对这些类别更敏感。gamma: 和 CEFL 中的 gamma 相同用于调节焦点损失的程度控制对难分类样本的关注。 forward 方法: inputs: 与 CEFL 相同是模型的输出即 logits。targets: 与 CEFL 相同是真实标签。 F.softmax(inputs, dim1): 对模型的输出 inputs 进行 softmax 计算得到每个样本在各类别上的概率。 p.gather(1, targets.view(-1, 1)): gather 方法用来根据 targets 中的标签提取每个样本的正确类别的预测概率。 加权损失部分 loss_term_1 (1 - p_t)**2 / ((1 - p_t)**2 p_t**2) * torch.log(p_t): 这是针对正确类别概率 p_t 的一个加权项。这个项的目的是将模型的关注点放在难分类的样本上。计算时考虑了正确类和错误类之间的比例进而调整损失值。 loss_term_2 p_t**2 / ((1 - p_t)**2 p_t**2) * (1 - p_t)**self.gamma * torch.log(p_t): 另一个加权项考虑了模型对难分类样本的关注即当 p_t 小样本难分类时通过增加 gamma 使得模型对难分类样本的权重更加突出。 这两个损失项的组合有助于在类别不平衡问题中进行加权增强对少数类的学习。 加权频率项 loss -self.class_frequencies[targets] * (loss_term_1 loss_term_2): 将每个类别的频率class_frequencies引入损失计算中。这使得类别频率较低的类别通常是少数类在计算损失时有更高的权重从而让模型更加关注少数类。 返回平均损失 loss.mean(): 返回所有样本的加权平均损失。
4.3 训练过程
为了更好地展示如何在训练过程中使用 CEFL 和 CEFL2 损失函数并将其应用于一个简单的神经网络模型。以下是更新后的代码示例
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np# 计算类别频率的函数
def compute_class_frequencies(targets, num_classes):# 计算每个类别的样本数量class_counts torch.bincount(targets, minlengthnum_classes)# 防止除零错误计算每个类别的频率class_freq 1.0 / (class_counts.float() 1e-6)# 归一化类别频率class_freq class_freq / class_freq.sum()return class_freq# 自定义数据集类
class CustomDataset(Dataset):def __init__(self, data, targets):self.data dataself.targets targetsdef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.targets[idx]# 创建一个虚拟数据集
num_samples 1000
num_classes 7
input_dim 128
data torch.randn(num_samples, input_dim)
targets torch.randint(0, num_classes, (num_samples,))# 计算每个类别的频率
class_frequencies compute_class_frequencies(targets, num_classes)# 创建数据加载器
dataset CustomDataset(data, targets)
dataloader DataLoader(dataset, batch_size32, shuffleTrue)# 定义模型例如一个简单的全连接网络
class SimpleModel(nn.Module):def __init__(self, input_dim, num_classes):super(SimpleModel, self).__init__()self.fc nn.Linear(input_dim, num_classes)def forward(self, x):return self.fc(x)# 初始化模型和损失函数
model SimpleModel(input_dim, num_classes)
criterion CEFL2(class_frequencies) # 使用 CEFL2 损失函数
optimizer torch.optim.Adam(model.parameters(), lr0.001)# 训练模型
epochs 10
for epoch in range(epochs):model.train()running_loss 0.0for inputs, targets in dataloader:optimizer.zero_grad()outputs model(inputs)loss criterion(outputs, targets)loss.backward()optimizer.step()running_loss loss.item()avg_loss running_loss / len(dataloader)print(fEpoch [{epoch1}/{epochs}], Loss: {avg_loss:.4f})
代码解释 类别频率计算 (compute_class_frequencies): 这个函数计算了每个类别在数据中的出现频率。我们通过计算每个类别的出现次数并进行归一化得到类别频率。torch.bincount(targets) 用于计算每个类别出现的次数随后通过频率逆转的方式进行归一化。 自定义数据集 (CustomDataset): 这个自定义数据集类返回每个样本的输入和标签对适用于 PyTorch 的 DataLoader。 模型定义 (SimpleModel): 定义了一个简单的全连接层的神经网络用于演示如何应用损失函数。模型输入为 input_dim输出为 num_classes。 训练循环: 在每个 epoch 中模型通过前向传播获得预测结果并计算损失。使用 CEFL2 损失函数基于每个类别的频率进行加权损失计算。optimizer.zero_grad() 清空之前的梯度loss.backward() 计算梯度optimizer.step() 更新模型权重。
总结
在本文中我们深入探讨了 Class-balanced Exponential Focal Loss (CEFL) 和 Class-balanced Exponential Focal Loss 2 (CEFL2) 损失函数的定义、原理及其应用重点介绍了它们如何有效解决类别不平衡问题。通过引入类别权重和类别频率这些损失函数能够帮助模型在训练过程中更好地关注少数类样本避免对多数类样本的过拟合从而提升少数类的分类性能。
本文还提供了 PyTorch 实现的详细代码包括如何计算类别频率、定义损失函数并在训练过程中应用它们。
为帮助理解类别频率的影响以下图示展示了不同类别在训练过程中损失调整的效果 #mermaid-svg-ZuaibbDKoWZtvWvj {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-ZuaibbDKoWZtvWvj .error-icon{fill:#552222;}#mermaid-svg-ZuaibbDKoWZtvWvj .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-ZuaibbDKoWZtvWvj .edge-thickness-normal{stroke-width:2px;}#mermaid-svg-ZuaibbDKoWZtvWvj .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-ZuaibbDKoWZtvWvj .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-ZuaibbDKoWZtvWvj .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-ZuaibbDKoWZtvWvj .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-ZuaibbDKoWZtvWvj .marker{fill:#333333;stroke:#333333;}#mermaid-svg-ZuaibbDKoWZtvWvj .marker.cross{stroke:#333333;}#mermaid-svg-ZuaibbDKoWZtvWvj svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-ZuaibbDKoWZtvWvj .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-ZuaibbDKoWZtvWvj .cluster-label text{fill:#333;}#mermaid-svg-ZuaibbDKoWZtvWvj .cluster-label span{color:#333;}#mermaid-svg-ZuaibbDKoWZtvWvj .label text,#mermaid-svg-ZuaibbDKoWZtvWvj span{fill:#333;color:#333;}#mermaid-svg-ZuaibbDKoWZtvWvj .node rect,#mermaid-svg-ZuaibbDKoWZtvWvj .node circle,#mermaid-svg-ZuaibbDKoWZtvWvj .node ellipse,#mermaid-svg-ZuaibbDKoWZtvWvj .node polygon,#mermaid-svg-ZuaibbDKoWZtvWvj .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-ZuaibbDKoWZtvWvj .node .label{text-align:center;}#mermaid-svg-ZuaibbDKoWZtvWvj .node.clickable{cursor:pointer;}#mermaid-svg-ZuaibbDKoWZtvWvj .arrowheadPath{fill:#333333;}#mermaid-svg-ZuaibbDKoWZtvWvj .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-ZuaibbDKoWZtvWvj .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-ZuaibbDKoWZtvWvj .edgeLabel{background-color:#e8e8e8;text-align:center;}#mermaid-svg-ZuaibbDKoWZtvWvj .edgeLabel rect{opacity:0.5;background-color:#e8e8e8;fill:#e8e8e8;}#mermaid-svg-ZuaibbDKoWZtvWvj .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-ZuaibbDKoWZtvWvj .cluster text{fill:#333;}#mermaid-svg-ZuaibbDKoWZtvWvj .cluster span{color:#333;}#mermaid-svg-ZuaibbDKoWZtvWvj div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-ZuaibbDKoWZtvWvj :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;}#mermaid-svg-ZuaibbDKoWZtvWvj .watermark*{fill:#fff!important;stroke:none!important;font-size:15px!important;opacity:0.8!important;}#mermaid-svg-ZuaibbDKoWZtvWvj .watermark span{fill:#fff!important;stroke:none!important;font-size:15px!important;opacity:0.8!important;} CSDN 2136 原始训练集 计算类别频率 计算类别频率加权后的损失 优化模型 训练结果 类别频率 损失调整 模型优化 CSDN 2136 图中展示了训练过程中如何计算类别频率并利用这些频率对损失进行加权从而优化模型训练效果。 通过本文的讲解您应该对 CEFL 和 CEFL2 损失函数的定义、实现和应用有了更深刻的理解。如果您正在处理类别不平衡的分类任务不妨尝试使用这些损失函数它们能有效提升模型的性能特别是在少数类样本的分类效果上。
参考文献
T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Dollar, Focal loss for dense object detection,’’ in Proc. IEEE Int. Conf. Comput. Vis., Oct. 2017, pp. 2980-2988.doi:10.48550/arXiv.1708.02002.L. Wang, C. Wang, Z. Sun, S. Cheng and L. Guo, “Class Balanced Loss for Image Classification,” in IEEE Access, vol. 8, pp. 81142-81153, 2020, doi: 10.1109/ACCESS.2020.2991237.