1. Build your own dataset and split it into a training set and a validation set
2. Background on the model
3. model.py — define the VGG network model
4. train.py — load the dataset and train: compute the loss on the training set and the accuracy on the validation set, then save the trained network weights
5. predict.py — use the trained network weights to classify images you found yourself
I. Building the dataset and splitting it into training and validation sets
1. Create the data folders. First decide which classes to use for this classification task; the images can come from a crawler, an official dataset, or photos you take yourself. Prepare one folder that contains one sub-folder per class. The sub-folder names can be anything, but English class names work best, and each class should hold roughly the same number of photos, ideally 500 or more. For example, I chose three classes: dandelion, rose, and tulip. The layout used here is data_set, which contains flower_data, which contains flower_photos, which in turn contains the three class folders.
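Before splitting, it can be worth verifying the folder layout and how many images each class holds. This is a small sketch I added, not part of the original tutorial; the path is assumed to match the layout described above:

import os

# Hypothetical check: print how many images each class folder contains.
# Assumes the layout data_set/flower_data/flower_photos/<class name>/
photos_dir = "data_set/flower_data/flower_photos"
for cla in sorted(os.listdir(photos_dir)):
    cla_dir = os.path.join(photos_dir, cla)
    if os.path.isdir(cla_dir):
        print(cla, len(os.listdir(cla_dir)), "images")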
2. Split into training and validation sets
Use the generic data-splitting script below; this time it is run from the same directory that contains flower_data.
import os
from shutil import copy
import random


def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)


# Get every folder name under flower_photos except .txt files, i.e. the 3 class names
file_path = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file_path) if ".txt" not in cla]

# Create the train folder and one sub-folder per class inside it
mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/' + cla)

# Create the val folder and one sub-folder per class inside it
mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/' + cla)

# Split ratio, training set : validation set = 9 : 1
split_rate = 0.1

# Walk through all images of the 3 flower classes and split them by the ratio above
for cla in flower_class:
    cla_path = file_path + '/' + cla + '/'  # sub-folder of one class
    images = os.listdir(cla_path)           # images holds the names of all images in that folder
    num = len(images)
    eval_index = random.sample(images, k=int(num * split_rate))  # randomly pick k image names from images
    for index, image in enumerate(images):
        # eval_index holds the image names that go to the validation set val
        if image in eval_index:
            image_path = cla_path + image
            new_path = 'flower_data/val/' + cla
            copy(image_path, new_path)  # copy the selected image to the new path
        # the remaining images go to the training set train
        else:
            image_path = cla_path + image
            new_path = 'flower_data/train/' + cla
            copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar
    print()

print("processing done!")
After running the script, two new folders appear inside flower_data: train and val, i.e. the training set and the validation set.
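To confirm the 9 : 1 split actually happened, a quick count of both folders can help. This is a small sketch I added; the folder names are assumed to match the script above:

import os

# Hypothetical sanity check: count images per class in train/ and val/ after the split.
for subset in ("train", "val"):
    for cla in sorted(os.listdir(os.path.join("flower_data", subset))):
        n = len(os.listdir(os.path.join("flower_data", subset, cla)))
        print("{:5s} {:12s} {} images".format(subset, cla, n))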
II. Background on the model
An earlier article introduced this model; if you are not familiar with it, follow the link below to read up first:
Deep learning convolutional neural networks (CNN): the VGGNet model — a detailed explanation of the vgg16 and vgg19 networks (theory)
III. model.py — defining the VGG network model
The original model code is copied here as-is; none of its parameters need to be changed.
import torch.nn as nn
import torch

# official pretrain weights
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}


class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.features(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)
        # N x 512*7*7
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def make_features(cfg: list):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)


cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg(model_name="vgg16", **kwargs):
    assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
    cfg = cfgs[model_name]

    model = VGG(make_features(cfg), **kwargs)
    return model
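Before moving on to training, a quick smoke test can confirm that the model builds and maps a 224×224 input to the right number of classes. This is a minimal sketch I added, assuming model.py sits in the same directory:

import torch
from model import vgg

# Hypothetical smoke test: build vgg16 with 3 output classes and push a dummy batch through it.
net = vgg(model_name="vgg16", num_classes=3, init_weights=True)
x = torch.randn(1, 3, 224, 224)   # N x C x H x W, the input size assumed throughout this article
print(net(x).shape)               # expected: torch.Size([1, 3])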
IV. train.py — model training: load the dataset and train, compute the loss on the training set and the accuracy on the validation set, and save the trained network weights
On line 63 of the original script, change num_classes to 3, because there are only three classes here:
net = vgg(model_name=model_name, num_classes=3, init_weights=True)
import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path; adjust to your folder names
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # class_to_idx, e.g. {'dandelion': 0, 'roses': 1, 'tulips': 2}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 64
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()

    model_name = "vgg16"
    net = vgg(model_name=model_name, num_classes=3, init_weights=True)  # <-- the line to modify: num_classes=3
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 10
    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate the number of correct predictions per epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()
A screenshot of the training results:
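The script only prints the per-epoch loss and validation accuracy to the console. If you also want a plotted training curve, one option is to append running_loss / train_steps and val_accurate to two lists inside the epoch loop and plot them at the end; a minimal sketch I added, with hypothetical list names train_losses and val_accs:

import matplotlib.pyplot as plt

# Hypothetical helper: train_losses and val_accs are lists filled once per epoch,
# e.g. train_losses.append(running_loss / train_steps) and val_accs.append(val_accurate).
def plot_history(train_losses, val_accs):
    epochs = range(1, len(train_losses) + 1)
    plt.plot(epochs, train_losses, label="train loss")
    plt.plot(epochs, val_accs, label="val accuracy")
    plt.xlabel("epoch")
    plt.legend()
    plt.savefig("training_curve.png")  # keep a copy on disk as well
    plt.show()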
V. predict.py — use the trained network weights to classify images you found yourself
import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "1.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model (num_classes must match training, i.e. 3 classes here)
    model = vgg(model_name="vgg16", num_classes=3).to(device)
    # load model weights
    weights_path = "./vgg16Net.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()
I downloaded a picture of a tulip from the internet and used the trained VGG16 network to check whether its class is identified correctly.
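If you want to test several downloaded images at once rather than a single 1.jpg, the prediction code can be wrapped around a folder. A minimal sketch reusing the same transform, class_indices.json, and weights; the folder name test_images is hypothetical:

import os
import json

import torch
from PIL import Image
from torchvision import transforms

from model import vgg

# Hypothetical batch prediction over a folder of test images.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_transform = transforms.Compose([transforms.Resize((224, 224)),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

with open("./class_indices.json", "r") as f:
    class_indict = json.load(f)

model = vgg(model_name="vgg16", num_classes=3).to(device)
model.load_state_dict(torch.load("./vgg16Net.pth", map_location=device))
model.eval()

with torch.no_grad():
    for name in sorted(os.listdir("test_images")):           # hypothetical folder of downloaded images
        img = Image.open(os.path.join("test_images", name)).convert("RGB")
        img = torch.unsqueeze(data_transform(img), dim=0)
        probs = torch.softmax(torch.squeeze(model(img.to(device))).cpu(), dim=0)
        idx = int(torch.argmax(probs))
        print("{}: {} ({:.3f})".format(name, class_indict[str(idx)], probs[idx].item()))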