Knowledge points review:
- ResNet structure analysis
- thoughts on where to place CBAM in a pretrained network
- training strategies for pretrained models: differential learning rates and three-stage fine-tuning

PS: today's code takes a while to train, roughly 40 minutes on a 3080 Ti.

Homework:
- get a solid understanding of the resnet18 model structure
- try the same fine-tuning strategy on vgg16 + cbam (a starting sketch appears at the end of this post)

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time

# Channel attention module
class ChannelAttentionModule(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        avg_features = self.mlp(self.avg_pool(x).view(b, c))
        max_features = self.mlp(self.max_pool(x).view(b, c))
        weights = self.sigmoid(avg_features + max_features).view(b, c, 1, 1)
        return x * weights

# Spatial attention module
class SpatialAttentionModule(nn.Module):
    def __init__(self, kernel=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel, padding=kernel // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_features = torch.mean(x, dim=1, keepdim=True)
        max_features, _ = torch.max(x, dim=1, keepdim=True)
        combined = torch.cat([avg_features, max_features], dim=1)
        spatial_weights = self.sigmoid(self.conv(combined))
        return x * spatial_weights

# CBAM: channel attention followed by spatial attention
class CBAMBlock(nn.Module):
    def __init__(self, channels, reduction=16, kernel=7):
        super().__init__()
        self.channel_attention = ChannelAttentionModule(channels, reduction)
        self.spatial_attention = SpatialAttentionModule(kernel)

    def forward(self, x):
        x = self.channel_attention(x)
        x = self.spatial_attention(x)
        return x
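CBAM only rescales activations, so its output shape always matches its input; that is what lets it be dropped between existing stages without touching any other layer. A minimal standalone sanity check of the block above (variable names here are mine):

# CBAM keeps the tensor shape: it multiplies the input by learned weights
block = CBAMBlock(channels=64)
dummy = torch.randn(2, 64, 32, 32)   # a batch of 2 CIFAR-sized feature maps
out = block(dummy)
print(out.shape)                      # torch.Size([2, 64, 32, 32])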
# Plotting configuration (SimHei covers CJK glyphs; harmless with English labels)
plt.rcParams['font.family'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# Select the compute device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Data preprocessing
train_augmentation = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # CIFAR-10 stats
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Load the datasets
cifar_train = datasets.CIFAR10('./data', train=True, download=True, transform=train_augmentation)
cifar_test = datasets.CIFAR10('./data', train=False, transform=test_transform)
train_loader = DataLoader(cifar_train, batch_size=64, shuffle=True)
test_loader = DataLoader(cifar_test, batch_size=64, shuffle=False)

# Enhanced ResNet: ResNet-18 with CBAM after each residual stage
class EnhancedResNet(nn.Module):
    def __init__(self, num_classes=10, pretrained=True, reduction=16, kernel=7):
        super().__init__()
        # Load the pretrained backbone
        base_model = models.resnet18(pretrained=pretrained)
        # Adapt the stem to small 32x32 inputs: 3x3 conv, no max-pool
        base_model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        base_model.maxpool = nn.Identity()
        # One CBAM block after each residual stage
        self.attention1 = CBAMBlock(64, reduction, kernel)
        self.attention2 = CBAMBlock(128, reduction, kernel)
        self.attention3 = CBAMBlock(256, reduction, kernel)
        self.attention4 = CBAMBlock(512, reduction, kernel)
        # Replace the classifier head
        base_model.fc = nn.Linear(512, num_classes)
        self.base = base_model

    def forward(self, x):
        x = self.base.conv1(x)
        x = self.base.bn1(x)
        x = self.base.relu(x)
        # Alternate residual stages with attention blocks
        x = self.base.layer1(x)
        x = self.attention1(x)
        x = self.base.layer2(x)
        x = self.attention2(x)
        x = self.base.layer3(x)
        x = self.attention3(x)
        x = self.base.layer4(x)
        x = self.attention4(x)
        # Classification head
        x = self.base.avgpool(x)
        x = torch.flatten(x, 1)
        return self.base.fc(x)
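Note where CBAM sits: one block after each residual stage rather than inside every basic block, so every pretrained weight keeps its original role. A quick shape check of the assembled model (standalone; downloads the resnet18 weights on first run):

# Forward-pass shape check on CIFAR-sized input
model = EnhancedResNet(pretrained=True)
dummy = torch.randn(2, 3, 32, 32)    # two CIFAR-10-sized images
logits = model(dummy)
print(logits.shape)                   # torch.Size([2, 10])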
# Stage-wise optimizer configuration for the three-phase fine-tuning schedule
def configure_optimizer(model, stage):
    if stage == 1:
        # Phase 1: freeze the backbone, train only the attention blocks and classifier
        for param in model.parameters():
            param.requires_grad = False
        for name, param in model.named_parameters():
            if 'attention' in name or 'fc' in name:
                param.requires_grad = True
        return optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
    elif stage == 2:
        # Phase 2: additionally unfreeze the two deepest residual stages, lower the lr
        for name, param in model.named_parameters():
            if 'layer3' in name or 'layer4' in name or 'attention' in name or 'fc' in name:
                param.requires_grad = True
        return optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    else:  # stage 3
        # Phase 3: unfreeze everything and fine-tune globally at a small lr
        for param in model.parameters():
            param.requires_grad = True
        return optim.Adam(model.parameters(), lr=1e-5)
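The recap also mentions differential learning rates. configure_optimizer uses a single rate per phase; the same idea can be expressed in one optimizer with Adam parameter groups, letting the pretrained backbone move more slowly than the freshly initialized attention blocks and classifier. A minimal sketch (build_differential_optimizer is a hypothetical helper; the name filters assume the EnhancedResNet above):

# Differential learning rates via parameter groups: pretrained backbone weights
# get a small lr, new attention blocks and the classifier get a larger one.
def build_differential_optimizer(model, base_lr=1e-5, head_lr=1e-3):
    backbone, heads = [], []
    for name, param in model.named_parameters():
        if 'attention' in name or 'fc' in name:
            heads.append(param)       # freshly initialized parameters
        else:
            backbone.append(param)    # pretrained ResNet weights
    return optim.Adam([
        {'params': backbone, 'lr': base_lr},
        {'params': heads, 'lr': head_lr},
    ])

# Usage: optimizer = build_differential_optimizer(model)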
# Training and validation loop
def run_training(model, criterion, train_loader, test_loader, device, total_epochs):
    batch_losses = []
    epoch_losses = []
    train_acc_history = []
    test_acc_history = []
    optimizer = None

    for epoch in range(1, total_epochs + 1):
        start_time = time.time()
        # Switch the optimizer configuration at phase boundaries
        if epoch == 1:
            print('\n' + '=' * 50 + '\nPhase 1: train the attention blocks and classifier\n' + '=' * 50)
            optimizer = configure_optimizer(model, 1)
        elif epoch == 6:
            print('\n' + '=' * 50 + '\nPhase 2: unfreeze the deeper conv stages\n' + '=' * 50)
            optimizer = configure_optimizer(model, 2)
        elif epoch == 21:
            print('\n' + '=' * 50 + '\nPhase 3: global fine-tuning\n' + '=' * 50)
            optimizer = configure_optimizer(model, 3)

        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total_samples = 0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Record the loss
            current_loss = loss.item()
            batch_losses.append(current_loss)
            running_loss += current_loss

            # Track accuracy
            _, predicted = outputs.max(1)
            total_samples += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # Print progress periodically
            if (batch_idx + 1) % 100 == 0:
                avg_loss = running_loss / (batch_idx + 1)
                print(f'Epoch: {epoch}/{total_epochs} | Batch: {batch_idx + 1}/{len(train_loader)} '
                      f'| Current loss: {current_loss:.4f} | Average loss: {avg_loss:.4f}')

        # Epoch-level training statistics
        train_loss = running_loss / len(train_loader)
        train_acc = 100. * correct / total_samples
        epoch_losses.append(train_loss)
        train_acc_history.append(train_acc)

        # Validation phase
        model.eval()
        test_loss = 0.0
        test_correct = 0
        test_total = 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                test_loss += criterion(outputs, targets).item()
                _, predicted = outputs.max(1)
                test_total += targets.size(0)
                test_correct += predicted.eq(targets).sum().item()
        test_loss /= len(test_loader)
        test_acc = 100. * test_correct / test_total
        test_acc_history.append(test_acc)

        # Epoch summary
        epoch_time = time.time() - start_time
        print(f'Epoch {epoch}/{total_epochs} done | Time: {epoch_time:.2f}s | '
              f'Train acc: {train_acc:.2f}% | Test acc: {test_acc:.2f}%')

    # Plot the results
    visualize_results(batch_losses, epoch_losses, train_acc_history, test_acc_history)
    return test_acc_history[-1]
# Result visualization
def visualize_results(batch_losses, epoch_losses, train_acc, test_acc):
    plt.figure(figsize=(15, 5))

    # Per-batch loss
    plt.subplot(1, 3, 1)
    plt.plot(batch_losses, 'b-', alpha=0.7)
    plt.xlabel('Training batch')
    plt.ylabel('Loss')
    plt.title('Per-batch training loss')
    plt.grid(True)

    # Per-epoch loss
    plt.subplot(1, 3, 2)
    plt.plot(epoch_losses, 'r-')
    plt.xlabel('Epoch')
    plt.ylabel('Average loss')
    plt.title('Per-epoch training loss')
    plt.grid(True)

    # Accuracy curves
    plt.subplot(1, 3, 3)
    plt.plot(train_acc, 'g-', label='Train accuracy')
    plt.plot(test_acc, 'b-', label='Test accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Train and test accuracy')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.show()
# Main entry point
if __name__ == '__main__':
    # Build the model
    net = EnhancedResNet().to(device)
    loss_fn = nn.CrossEntropyLoss()

    print('Training the enhanced ResNet model...')
    final_acc = run_training(net, loss_fn, train_loader, test_loader, device, 50)
    print(f'Training finished! Final test accuracy: {final_acc:.2f}%')

    # Save the weights
    torch.save(net.state_dict(), 'enhanced_resnet_cifar10.pth')
    print('Model saved to: enhanced_resnet_cifar10.pth')
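For the homework, the same recipe carries over to VGG16: insert a CBAM block after each of the five conv blocks and replace the classifier head. A rough starting sketch, not a tuned solution (VGG16WithCBAM is a hypothetical name; the slice boundaries are the max-pool indices of torchvision's vgg16().features):

class VGG16WithCBAM(nn.Module):
    def __init__(self, num_classes=10, pretrained=True):
        super().__init__()
        features = models.vgg16(pretrained=pretrained).features
        # (end index, output channels) of each conv block in vgg16.features;
        # each slice ends just after its max-pool layer
        blocks = [(5, 64), (10, 128), (17, 256), (24, 512), (31, 512)]
        layers, start = [], 0
        for end, channels in blocks:
            layers.append(features[start:end])
            layers.append(CBAMBlock(channels))
            start = end
        self.features = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(self.pool(x), 1)
        return self.fc(x)

Phases 1 and 3 of configure_optimizer transfer directly, since this model's parameter names still contain 'attention' and 'fc'; phase 2's 'layer3'/'layer4' filters have no counterpart here and would need to target the deeper feature slices instead.

浙大疏锦行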