哪个网站可以做优惠券,科技公司网站设,文创做的好的网站推荐,福永做网站PyTorch#xff1a;6-可视化
注#xff1a;所有资料来源且归属于thorough-pytorch(https://datawhalechina.github.io/thorough-pytorch/)#xff0c;下文仅为学习记录
6.1#xff1a;可视化网络结构
Keras中可以调用model.summary()的API进行模型参数可视化
torchinfo…PyTorch6-可视化
注所有资料来源且归属于thorough-pytorch(https://datawhalechina.github.io/thorough-pytorch/)下文仅为学习记录
6.1可视化网络结构
Keras中可以调用model.summary()的API进行模型参数可视化
torchinfo是由torchsummary和torchsummaryX重构出的库用于可视化网络结构
6.1.1使用print函数打印模型基础信息
【案例resnet18】
模型构建
import torchvision.models as models
model models.resnet18()直接print模型只能得出基础构件的信息
ResNet((conv1): Conv2d(3, 64, kernel_size(7, 7), stride(2, 2), padding(3, 3), biasFalse)(bn1): BatchNorm2d(64, eps1e-05, momentum0.1, affineTrue, track_running_statsTrue)(relu): ReLU(inplaceTrue)(maxpool): MaxPool2d(kernel_size3, stride2, padding1, dilation1, ceil_modeFalse)(layer1): Sequential((0): Bottleneck((conv1): Conv2d(64, 64, kernel_size(1, 1), stride(1, 1), biasFalse)(bn1): BatchNorm2d(64, eps1e-05, momentum0.1, affineTrue, track_running_statsTrue)(conv2): Conv2d(64, 64, kernel_size(3, 3), stride(1, 1), padding(1, 1), biasFalse)(bn2): BatchNorm2d(64, eps1e-05, momentum0.1, affineTrue, track_running_statsTrue)(conv3): Conv2d(64, 256, kernel_size(1, 1), stride(1, 1), biasFalse)(bn3): BatchNorm2d(256, eps1e-05, momentum0.1, affineTrue, track_running_statsTrue)(relu): ReLU(inplaceTrue)(downsample): Sequential((0): Conv2d(64, 256, kernel_size(1, 1), stride(1, 1), biasFalse)(1): BatchNorm2d(256, eps1e-05, momentum0.1, affineTrue, track_running_statsTrue)))... ...)(avgpool): AdaptiveAvgPool2d(output_size(1, 1))(fc): Linear(in_features2048, out_features1000, biasTrue)
)结果既不能显示出每一层的shape也不能显示对应参数量的大小。
6.1.2使用torchinfo可视化网络结构
安装
# 安装方法一
pip install torchinfo
# 安装方法二
conda install -c conda-forge torchinfo使用
使用torchinfo.summary()函数必需的参数分别是modelinput_size[batch_size,channel,h,w]。
import torchvision.models as models
from torchinfo import summary
resnet18 models.resnet18()
# 实例化模型
summary(resnet18, (1, 3, 224, 224))
# 1batch_size 3:图片的通道数 224: 图片的高宽结构化输出 Layer (type:depth-idx) Output Shape Param #ResNet -- --
├─Conv2d: 1-1 [1, 64, 112, 112] 9,408
├─BatchNorm2d: 1-2 [1, 64, 112, 112] 128
├─ReLU: 1-3 [1, 64, 112, 112] --
├─MaxPool2d: 1-4 [1, 64, 56, 56] --
├─Sequential: 1-5 [1, 64, 56, 56] --
│ └─BasicBlock: 2-1 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-1 [1, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-2 [1, 64, 56, 56] 128
│ │ └─ReLU: 3-3 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-4 [1, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-5 [1, 64, 56, 56] 128
│ │ └─ReLU: 3-6 [1, 64, 56, 56] --
│ └─BasicBlock: 2-2 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-7 [1, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-8 [1, 64, 56, 56] 128
│ │ └─ReLU: 3-9 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-10 [1, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-11 [1, 64, 56, 56] 128
│ │ └─ReLU: 3-12 [1, 64, 56, 56] --
├─Sequential: 1-6 [1, 128, 28, 28] --
│ └─BasicBlock: 2-3 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-13 [1, 128, 28, 28] 73,728
│ │ └─BatchNorm2d: 3-14 [1, 128, 28, 28] 256
│ │ └─ReLU: 3-15 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-16 [1, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-17 [1, 128, 28, 28] 256
│ │ └─Sequential: 3-18 [1, 128, 28, 28] 8,448
│ │ └─ReLU: 3-19 [1, 128, 28, 28] --
│ └─BasicBlock: 2-4 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-20 [1, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-21 [1, 128, 28, 28] 256
│ │ └─ReLU: 3-22 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-23 [1, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-24 [1, 128, 28, 28] 256
│ │ └─ReLU: 3-25 [1, 128, 28, 28] --
├─Sequential: 1-7 [1, 256, 14, 14] --
│ └─BasicBlock: 2-5 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-26 [1, 256, 14, 14] 294,912
│ │ └─BatchNorm2d: 3-27 [1, 256, 14, 14] 512
│ │ └─ReLU: 3-28 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-29 [1, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-30 [1, 256, 14, 14] 512
│ │ └─Sequential: 3-31 [1, 256, 14, 14] 33,280
│ │ └─ReLU: 3-32 [1, 256, 14, 14] --
│ └─BasicBlock: 2-6 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-33 [1, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-34 [1, 256, 14, 14] 512
│ │ └─ReLU: 3-35 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-36 [1, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-37 [1, 256, 14, 14] 512
│ │ └─ReLU: 3-38 [1, 256, 14, 14] --
├─Sequential: 1-8 [1, 512, 7, 7] --
│ └─BasicBlock: 2-7 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-39 [1, 512, 7, 7] 1,179,648
│ │ └─BatchNorm2d: 3-40 [1, 512, 7, 7] 1,024
│ │ └─ReLU: 3-41 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-42 [1, 512, 7, 7] 2,359,296
│ │ └─BatchNorm2d: 3-43 [1, 512, 7, 7] 1,024
│ │ └─Sequential: 3-44 [1, 512, 7, 7] 132,096
│ │ └─ReLU: 3-45 [1, 512, 7, 7] --
│ └─BasicBlock: 2-8 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-46 [1, 512, 7, 7] 2,359,296
│ │ └─BatchNorm2d: 3-47 [1, 512, 7, 7] 1,024
│ │ └─ReLU: 3-48 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-49 [1, 512, 7, 7] 2,359,296
│ │ └─BatchNorm2d: 3-50 [1, 512, 7, 7] 1,024
│ │ └─ReLU: 3-51 [1, 512, 7, 7] --
├─AdaptiveAvgPool2d: 1-9 [1, 512, 1, 1] --
├─Linear: 1-10 [1, 1000] 513,000Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
Total mult-adds (G): 1.81Input size (MB): 0.60
Forward/backward pass size (MB): 39.75
Params size (MB): 46.76
Estimated Total Size (MB): 87.11注意使用colab或者jupyter notebook时想要实现该方法summary()一定是该单元即notebook中的cell的返回值否则就需要使用print(summary(...))来可视化。
6.2CNN可视化
可视化内容可视化特征是如何提取的、提取到的特征的形式、模型在输入数据上的关注点
6.2.1CNN卷积核可视化
卷积核在CNN中负责提取特征——可视化特征是如何提取的
靠近输入的层提取的特征是相对简单的结构靠近输出的层提取的特征和图中的实体形状相近
kernel可视化的核心特定层的卷积核即特定层的模型权重可视化卷积核即可视化对应的权重矩阵
【案例VGG11】
【1】加载模型确定层信息
import torch
from torchvision.models import vgg11model vgg11(pretrainedTrue)
print(dict(model.features.named_children()))
{0: Conv2d(3, 64, kernel_size(3, 3), stride(1, 1), padding(1, 1)),1: ReLU(inplaceTrue),2: MaxPool2d(kernel_size2, stride2, padding0, dilation1, ceil_modeFalse),3: Conv2d(64, 128, kernel_size(3, 3), stride(1, 1), padding(1, 1)),4: ReLU(inplaceTrue),5: MaxPool2d(kernel_size2, stride2, padding0, dilation1, ceil_modeFalse),6: Conv2d(128, 256, kernel_size(3, 3), stride(1, 1), padding(1, 1)),7: ReLU(inplaceTrue),8: Conv2d(256, 256, kernel_size(3, 3), stride(1, 1), padding(1, 1)),9: ReLU(inplaceTrue),10: MaxPool2d(kernel_size2, stride2, padding0, dilation1, ceil_modeFalse),11: Conv2d(256, 512, kernel_size(3, 3), stride(1, 1), padding(1, 1)),12: ReLU(inplaceTrue),13: Conv2d(512, 512, kernel_size(3, 3), stride(1, 1), padding(1, 1)),14: ReLU(inplaceTrue),15: MaxPool2d(kernel_size2, stride2, padding0, dilation1, ceil_modeFalse),16: Conv2d(512, 512, kernel_size(3, 3), stride(1, 1), padding(1, 1)),17: ReLU(inplaceTrue),18: Conv2d(512, 512, kernel_size(3, 3), stride(1, 1), padding(1, 1)),19: ReLU(inplaceTrue),20: MaxPool2d(kernel_size2, stride2, padding0, dilation1, ceil_modeFalse)}【2】可视化卷积层的对应参数第3层
卷积核对应的应为卷积层Conv2d
conv1 dict(model.features.named_children())[3]
kernel_set conv1.weight.detach()
num len(conv1.weight.detach())
print(kernel_set.shape)torch.Size([128, 64, 3, 3])for i in range(0,num):i_kernel kernel_set[i]plt.figure(figsize(20, 17))if (len(i_kernel)) 1:for idx, filer in enumerate(i_kernel):plt.subplot(9, 9, idx1) plt.axis(off)plt.imshow(filer[ :, :].detach(),cmapbwr)由于第3层的特征图由64维变为128维因此共有128*64个卷积核
6.2.2CNN特征图可视化
特征图输入的原始图像经过每次卷积层得到的数据
可视化卷积核是为了看模型提取哪些特征可视化特征图则是为了看模型提取到的特征是什么样子的。
PyTorch提供了一个专用的接口使得网络在前向传播过程中能够获取到特征图接口的名称叫hook。
实现过程
class Hook(object):def __init__(self):self.module_name []self.features_in_hook []self.features_out_hook []def __call__(self,module, fea_in, fea_out):print(hooker working, self)self.module_name.append(module.__class__)self.features_in_hook.append(fea_in)self.features_out_hook.append(fea_out)return Nonedef plot_feature(model, idx, inputs):hh Hook()model.features[idx].register_forward_hook(hh)# forward_model(model,False)model.eval()_ model(inputs)print(hh.module_name)print((hh.features_in_hook[0][0].shape))print((hh.features_out_hook[0].shape))out1 hh.features_out_hook[0]total_ft out1.shape[1]first_item out1[0].cpu().clone() plt.figure(figsize(20, 17))for ftidx in range(total_ft):if ftidx 99:breakft first_item[ftidx]plt.subplot(10, 10, ftidx1) plt.axis(off)#plt.imshow(ft[ :, :].detach(),cmapgray)plt.imshow(ft[ :, :].detach())首先实现了一个hook类之后在plot_feature函数中将该hook类的对象注册到要进行可视化的网络的某层中。
model在进行前向传播的时候会调用hook的__call__函数Hook类在此处存储了当前层的输入和输出。
Hook类种的hook输入为in输出为out是一个list每次前向传播一次都是调用一次即 hook 长度会增加1。
6.2.3CNN class activation map可视化
class activation map CAM的作用是判断哪些变量对模型来说是重要的。
在CNN可视化的场景下即判断图像中哪些像素点对预测结果是重要的。
CAM系列操作的实现可以通过开源工具包pytorch-grad-cam来实现。
安装
pip install grad-cam案例
加载图片
import torch
from torchvision.models import vgg11,resnet18,resnet101,resnext101_32x8d
import matplotlib.pyplot as plt
from PIL import Image
import numpy as npmodel vgg11(pretrainedTrue)
img_path ./dog.png
# resize操作是为了和传入神经网络训练图片大小一致
img Image.open(img_path).resize((224,224))
# 需要将原始图片转为np.float32格式并且在0-1之间
rgb_img np.float32(img)/255
plt.imshow(img)CAM可视化
from pytorch_grad_cam import GradCAM,ScoreCAM,GradCAMPlusPlus,AblationCAM,XGradCAM,EigenCAM,FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image# 将图片转为tensor
img_tensor torch.from_numpy(rgb_img).permute(2,0,1).unsqueeze(0)target_layers [model.features[-1]]
# 选取合适的类激活图但是ScoreCAM和AblationCAM需要batch_size
cam GradCAM(modelmodel,target_layerstarget_layers)
targets [ClassifierOutputTarget(preds)]
# 上方preds需要设定比如ImageNet有1000类这里可以设为200
grayscale_cam cam(input_tensorimg_tensor, targetstargets)
grayscale_cam grayscale_cam[0, :]
cam_img show_cam_on_image(rgb_img, grayscale_cam, use_rgbTrue)
print(type(cam_img))
Image.fromarray(cam_img)6.2.4FlashTorch快速实现CNN可视化
https://github.com/MisaOgura/flashtorch
安装
pip install flashtorch可视化梯度
import matplotlib.pyplot as plt
import torchvision.models as models
from flashtorch.utils import apply_transforms, load_image
from flashtorch.saliency import Backpropmodel models.alexnet(pretrainedTrue)
backprop Backprop(model)image load_image(/content/images/great_grey_owl.jpg)
owl apply_transforms(image)target_class 24
backprop.visualize(owl, target_class, guidedTrue, use_gpuTrue)可视化卷积核
import torchvision.models as models
from flashtorch.activmax import GradientAscentmodel models.vgg16(pretrainedTrue)
g_ascent GradientAscent(model.features)# specify layer and filter info
conv5_1 model.features[24]
conv5_1_filters [45, 271, 363, 489]g_ascent.visualize(conv5_1, conv5_1_filters, titleVGG16: conv5_1)6.3使用TensorBoard可视化训练过程
6.3.1安装
使用pip安装
pip install tensorboardX6.3.2TensorBoard可视化的基本逻辑
可将TensorBoard看做一个记录员记录我们指定的数据包括模型每一层的feature map权重训练loss等。
TensorBoard将记录下来的内容保存在一个用户指定的文件夹里程序不断运行中TensorBoard会不断记录记录下的内容可以通过网页的形式加以可视化。
6.3.3TensorBoard的配置和启动
【1】指定保存记录数据的文件夹调用tensorboard中的SummaryWriter作为记录员
from tensorboardX import SummaryWriter
# from torch.utils.tensorboard import SummaryWriter
# 使用PyTorch自带的tensorboard
writer SummaryWriter(./runs)上面的操作实例化SummaryWritter为变量writer并指定writer的输出目录为当前目录下的runs目录。
【2】启动tensorboard
tensorboard --logdir/path/to/logs/ --portxxxx“path/to/logs/是指定的保存tensorboard记录结果的文件路径
–port是外部访问TensorBoard的端口号可以通过访问ip:port访问tensorboard
6.3.4TensorBoard模型结构可视化
【1】定义模型
import torch.nn as nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 nn.Conv2d(in_channels3,out_channels32,kernel_size 3)self.pool nn.MaxPool2d(kernel_size 2,stride 2)self.conv2 nn.Conv2d(in_channels32,out_channels64,kernel_size 5)self.adaptive_pool nn.AdaptiveMaxPool2d((1,1))self.flatten nn.Flatten()self.linear1 nn.Linear(64,32)self.relu nn.ReLU()self.linear2 nn.Linear(32,1)self.sigmoid nn.Sigmoid()def forward(self,x):x self.conv1(x)x self.pool(x)x self.conv2(x)x self.pool(x)x self.adaptive_pool(x)x self.flatten(x)x self.linear1(x)x self.relu(x)x self.linear2(x)y self.sigmoid(x)return ymodel Net()
print(model)
Net((conv1): Conv2d(3, 32, kernel_size(3, 3), stride(1, 1))(pool): MaxPool2d(kernel_size2, stride2, padding0, dilation1, ceil_modeFalse)(conv2): Conv2d(32, 64, kernel_size(5, 5), stride(1, 1))(adaptive_pool): AdaptiveMaxPool2d(output_size(1, 1))(flatten): Flatten(start_dim1, end_dim-1)(linear1): Linear(in_features64, out_features32, biasTrue)(relu): ReLU()(linear2): Linear(in_features32, out_features1, biasTrue)(sigmoid): Sigmoid()
)可视化模型的思路给定一个输入数据前向传播后得到模型的结构再通过TensorBoard进行可视化
【2】使用add_graph
writer.add_graph(model, input_to_model torch.rand(1, 3, 224, 224))
writer.close()6.3.5TensorBoard图像可视化
对于单张图片的显示使用add_image对于多张图片的显示使用add_images有时需要使用torchvision.utils.make_grid将多张图片拼成一张图片后用writer.add_image显示
【案例CIFAR10】
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoadertransform_train transforms.Compose([transforms.ToTensor()])
transform_test transforms.Compose([transforms.ToTensor()])train_data datasets.CIFAR10(., trainTrue, downloadTrue, transformtransform_train)
test_data datasets.CIFAR10(., trainFalse, downloadTrue, transformtransform_test)
train_loader DataLoader(train_data, batch_size64, shuffleTrue)
test_loader DataLoader(test_data, batch_size64)images, labels next(iter(train_loader))# 仅查看一张图片
writer SummaryWriter(./pytorch_tb)
writer.add_image(images[0], images[0])
writer.close()# 将多张图片拼接成一张图片中间用黑色网格分割
# create grid of images
writer SummaryWriter(./pytorch_tb)
img_grid torchvision.utils.make_grid(images)
writer.add_image(image_grid, img_grid)
writer.close()# 将多张图片直接写入
writer SummaryWriter(./pytorch_tb)
writer.add_images(images,images,global_step 0)
writer.close()6.3.6TensorBoard连续变量可视化
可视化连续变量或时序变量的变化过程通过add_scalar实现
writer SummaryWriter(./pytorch_tb)
for i in range(500):x iy x**2writer.add_scalar(x, x, i) #日志中记录x在第step i 的值writer.add_scalar(y, y, i) #日志中记录y在第step i 的值
writer.close()如果想在同一张图中显示多个曲线则需要分别建立存放子路径使用SummaryWriter指定路径即可自动创建但需要在tensorboard运行目录下同时在add_scalar中修改曲线的标签使其一致即可。
writer1 SummaryWriter(./pytorch_tb/x)
writer2 SummaryWriter(./pytorch_tb/y)
for i in range(500):x iy x*2writer1.add_scalar(same, x, i) #日志中记录x在第step i 的值writer2.add_scalar(same, y, i) #日志中记录y在第step i 的值
writer1.close()
writer2.close()6.3.7TensorBoard参数分布可视化
对参数或向量的变化或者对其分布进行研究时可通过add_histogram实现。
import torch
import numpy as np# 创建正态分布的张量模拟参数矩阵
def norm(mean, std):t std * torch.randn((100, 20)) meanreturn twriter SummaryWriter(./pytorch_tb/)
for step, mean in enumerate(range(-10, 10, 1)):w norm(mean, 1)writer.add_histogram(w, w, step)writer.flush()
writer.close()6.3.8服务器端使用TensorBoard
由于服务器端没有浏览器纯命令模式因此需要进行相应的配置才可以在本地浏览器使用tensorboard查看服务器运行的训练过程。
方法【1】【2】都是建立SSH隧道实现远程端口到本机端口的转发。
【1】MobaXterm
在MobaXterm点击Tunneling。选择New SSH tunnel。对新建的SSH通道做以下设置第一栏选择Local port forwarding Remote Server处填写localhost Remote port处填写6006tensorboard默认会在6006端口进行显示。也可以根据 tensorboard --logdir/path/to/logs/ --portxxxx的命令中的port进行修改 SSH server 填写连接服务器的ip地址SSH login填写连接的服务器的用户名SSH port填写端口号通常为22 forwarded port填写本地的一个端口号以便后续进行访问。设定好之后点击Save然后Start。再次启动tensorboard在本地的浏览器输入http://localhost:6006/对其进行访问。
【2】Xshell
连接上服务器后打开当前会话属性选择隧道点击添加。目标主机代表的是服务器源主机代表的是本地端口的选择根据实际情况而定。启动tensorboard在本地127.0.0.1:6006 或者 localhost:6006进行访问。
6.4使用wandb可视化训练过程
wandb是Weights Biases的缩写能自动记录模型训练过程中的超参数和输出指标然后可视化和比较结果并快速与其他人共享结果。
6.4.1安装
【1】使用pip安装
pip install wandb【2】在官网注册账号并复制API keyshttps://wandb.ai/
【3】在本地使用命令登录
wandb login【4】粘贴API keys
6.4.2使用
import wandb
wandb.init(projectmy-project, entitymy-name)Quickstart | Weights Biases Documentation (wandb.ai)
project和entity是在wandb上创建的项目名称和用户名
6.4.3demo演示
【案例CIFAR10的图像分类】
【1】导入库
import random # to set the python random seed
import numpy # to set the numpy random seed
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import warnings
warnings.filterwarnings(ignore)【2】初始化wandb
# 初始化wandb
import wandb
wandb.init(projectthorough-pytorch,namewandb_demo,)【3】设置超参数
使用wandb.config来设置超参数这样就可以在wandb的界面上看到超参数的变化。
wandb.config的使用方法和字典类似可以使用config.key的方式来设置超参数。
# 超参数设置
config wandb.config # config的初始化
config.batch_size 64
config.test_batch_size 10
config.epochs 5
config.lr 0.01
config.momentum 0.1
config.use_cuda True
config.seed 2043
config.log_interval 10 # 设置随机数
def set_seed(seed):random.seed(config.seed) torch.manual_seed(config.seed) numpy.random.seed(config.seed) 【4】构建train和test的pipeline
def train(model, device, train_loader, optimizer):model.train()for batch_id, (data, target) in enumerate(train_loader):data, target data.to(device), target.to(device)optimizer.zero_grad()output model(data)criterion nn.CrossEntropyLoss()loss criterion(output, target)loss.backward()optimizer.step()# wandb.log用来记录一些日志(accuracy,loss and epoch), 便于随时查看网路的性能
def test(model, device, test_loader, classes):model.eval()test_loss 0correct 0example_images []with torch.no_grad():for data, target in test_loader:data, target data.to(device), target.to(device)output model(data)criterion nn.CrossEntropyLoss()test_loss criterion(output, target).item()pred output.max(1, keepdimTrue)[1]correct pred.eq(target.view_as(pred)).sum().item()example_images.append(wandb.Image(data[0], captionPred:{} Truth:{}.format(classes[pred[0].item()], classes[target[0]])))# 使用wandb.log 记录你想记录的指标wandb.log({Examples: example_images,Test Accuracy: 100. * correct / len(test_loader.dataset),Test Loss: test_loss})wandb.watch_called False def main():use_cuda config.use_cuda and torch.cuda.is_available()device torch.device(cuda:0 if use_cuda else cpu)kwargs {num_workers: 1, pin_memory: True} if use_cuda else {}# 设置随机数set_seed(config.seed)torch.backends.cudnn.deterministic True# 数据预处理transform transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 加载数据train_loader DataLoader(datasets.CIFAR10(rootdataset,trainTrue,downloadTrue,transformtransform), batch_sizeconfig.batch_size, shuffleTrue, **kwargs)test_loader DataLoader(datasets.CIFAR10(rootdataset,trainFalse,downloadTrue,transformtransform), batch_sizeconfig.batch_size, shuffleFalse, **kwargs)classes (plane, car, bird, cat, deer, dog, frog, horse, ship, truck)model resnet18(pretrainedTrue).to(device)optimizer optim.SGD(model.parameters(), lrconfig.lr, momentumconfig.momentum)wandb.watch(model, logall)for epoch in range(1, config.epochs 1):train(model, device, train_loader, optimizer)test(model, device, test_loader, classes)# 本地和云端模型保存torch.save(model.state_dict(), model.pth)wandb.save(model.pth)if __name__ __main__:main()其他提供的功能模型的超参数搜索模型的版本控制模型的部署等。