新闻发布网站模板,企业信用网查询,php网站平台,wordpress主题模板百度云引言
随着深度学习的发展#xff0c;图像分类已成为一项基础的技术#xff0c;被广泛应用于各种场景之中。本文将介绍如何使用Flask框架和PyTorch库来构建一个简单的图像分类Web服务。通过这个服务#xff0c;用户可以通过HTTP POST请求上传花朵图片#xff0c;然后由后端…引言
随着深度学习的发展图像分类已成为一项基础的技术被广泛应用于各种场景之中。本文将介绍如何使用Flask框架和PyTorch库来构建一个简单的图像分类Web服务。通过这个服务用户可以通过HTTP POST请求上传花朵图片然后由后端的深度学习模型对其进行分类并返回分类结果。
环境搭建
首先确保安装了以下Python库
Flask用于构建Web应用。PyTorch用于加载和运行深度学习模型。torchvision用于图像处理和加载预训练模型。PIL用于图像处理。
1. 初始化Flask应用
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models# 初始化Flask app
app flask.Flask(__name__)# 创建一个新的Flask应用程序实例
# __name__参数通常被传递给FasK应用程序来定位应用程序的根路径这样FlasK就可以知道在哪里找到模板、静态文件等。
# 总体来说app flask.Flask(__name_)是FLaSK应用程序的起点。它初始化了一个新的FLaSK应用程序实例。为后续添加路由、配置等莫定
2. 加载模型
为了方便我们将预训练好的ResNet18模型保存在一个名为best.pth的检查点文件中。我们将加载这个模型并准备好用于推理。
def load_model():Load the pre-trained model, you can use your model just as easily.global model# 加载resnet18网络。ResNet残差网络是一种深度学习架构设计用于解决深层神经网络中的梯度消失问题。model models.resnet18()# num_ftrs 被赋值为模型全连接层fc的输入特征数量。num_ftrs model.fc.in_featuresmodel.fc nn.Sequential(nn.Linear(num_ftrs, 102)) # 类别数自己根据自己任务来# print(model)#导入最优模型#这行代码实际上是加载了一个预先训练好的模型的权重。# torch.load(best.pth) 会加载保存在 best.pth 文件中的模型检查点# 通常这个检查点包含模型的状态字典state dict即模型所有层的权重和偏置。# model.load_state_dict(checkpoint[state_dict]) 会将加载的状态字典应用到我们的模型上使模型具有之前训练时学到的参数。checkpoint torch.load(best.pth)model.load_state_dict(checkpoint[state_dict])# 将模型指定为测试格式model.eval()# 是否使用gpuif use_gpu:model.cuda()
3. 预处理图像
为了使图像符合模型的要求我们需要对其进行预处理包括调整大小、转换为张量以及标准化。
def prepare_image(image, target_size):# 检查输入图像的颜色模式是否为 RGB。如果不是则将其转换为 RGB 模式。if image.mode ! RGB:image image.convert(RGB)# Resize the input image and preprocess it.(按照所使用的模型将输入图片的尺寸修改并转为tensor)# 使用 transforms.Resize 对象将图像调整为目标尺寸 target_size。image transforms.Resize(target_size)(image)# 使用 transforms.ToTensor() 将图像转换为 PyTorch 的 Tensor 类型。image transforms.ToTensor()(image)# Convert to Torch, Tensor and normalize. mean与std# 对图像张量进行标准化处理。# 标准化的参数 [0.485, 0.456, 0.406] 是均值代表每个颜色通道红、绿、蓝的平均值# [0.229, 0.224, 0.225] 是标准差代表每个颜色通道的标准差。image transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)# Add batch_size axis 增加一个维度用于按batch测试本次这里一次测试一张image image[None]if use_gpu:image image.cuda() # return torch.tensor(imagereturn torch.tensor(image)
4. 设置路由和处理请求
使用Flask设置路由并处理POST请求中的图像数据。
# 定义了一个名为 predict 的视图函数并通过装饰器 app.route 绑定了路由 /predict允许该路由接收 HTTP POST 请求。
app.route(/predict, methods[POST])
def predict():# 做一个标志刚开始无图像传入时为false,传入图像时为truedata {success: False}if flask.request.method POST: # 检查请求的方法是否为 POSTif flask.request.files.get(image): # 判断是否为图像image flask.request.files[image].read() # 将收到的图像进行读取内容为二进制image Image.open(io.BytesIO(image)) # 将这个二进制字符串转换为一个 PIL 图像对象。# 利用上面的预处理函数将读入的图像进行预处理image prepare_image(image, target_size(224, 224))# 将预处理后的图像输入到模型中并得到一个未归一化的输出向量。# 使用 F.softmax 函数将这个输出向量转换为概率分布这表示模型对于每个类别的预测概率。preds F.softmax(model(image), dim1) # 得到各个类别的概率# cpu().data 确保结果在 CPU 上并且不包含梯度信息。dim1 表示沿着列方向查找最大值。results torch.topk(preds.cpu().data, k3, dim1) # 概率最大的前3个结果# torch.topk用于返回输入张量中每行最大的k个元素及其对应的索引# 将结果从 PyTorch 张量转换为 NumPy 数组以便更容易地处理。results[0] 包含了概率值而 results[1] 包含了类别索引。results (results[0].cpu().numpy(), results[1].cpu().numpy())# 将data字典增加一个keyvalue,其中value为ist格式data[predictions] list()for probability, label in zip(results[0][0], results[1][0]):# Label name idx2labellstr(label)]r {label: str(label), probability: float(probability)}# 将预测结果添加至data字典data[predictions].append(r)# Indicate that the reguest was a success.data[success] Truereturn flask.jsonify(data) # 将最后结果以json格式文件传出,并返回给客户端。
5. 启动服务
最后在主入口处启动Flask服务并加载模型。
if __name__ __main__:print(Loading PyTorch model and Flask starting server ...)print(Please wait until server has fully started)load_model() #加载模型app.run(host192.168.24.45, port5012) #启动服务器IP地址端口
我们点击运行即可启动服务器保持程序运行客户端即可通过ip地址和端口访问 接口客户端实现
在上一部分中我们完成了基于Flask和PyTorch的图像分类Web服务的搭建。接下来我们将继续探讨如何编写客户端代码来与该服务进行交互。通过编写一个简单的Python脚本来发送HTTP请求我们可以测试我们的Web服务是否正常工作。
客户端代码实现
为了测试我们的图像分类服务我们需要编写一段代码来模拟客户端的行为。这段代码将负责向服务端发送包含图像的POST请求并接收返回的分类结果。
import requestsflask_url http://192.168.24.45:5012/predict# 定义一个名为 predict_result 的函数该函数接受一个参数 image_path表示要发送给 Flask 应用的图像文件的路径。
def predict_result(image_path):# 使用 open 函数以二进制模式 (rb) 打开图像文件并读取其内容。image open(image_path, rb).read()# 将图像内容包装到一个字典 payload 中键为 image值为图像的二进制内容。payload {image: image}# 使用 requests.post 方法发送一个 POST 请求到 Flask 应用其中 files 参数用于上传文件。# filespayload 表示将 payload 字典中的内容作为文件上传。r requests.post(flask_url, filespayload).json() # .json() 方法将响应内容解析为 Python 字典形式方便后续处理。if r[success]: # 检查响应中的 success 键是否为 True。如果为 True则意味着请求成功并且会打印出预测结果。for (i, result) in enumerate(r[predictions]): print({}.预测类别为{}:的概率:{}.format(i 1, result[label], result[probability]))print(OK) # 预测结果存储在 r[predictions] 列表中每个预测结果都是一个字典包含类别标签 (label) 和概率 (probability)。else: # 失败打印print(Request failed)
if __name__ __main__:predict_result(../data/6/image_07162.jpg)
预测图像
本次实验随机采用一张花的图片上传到到服务端 预测结果 客户端访问记录
当我们通过客户端访问服务端时可通过后台查看访问记录
总结
通过以上步骤我们构建了一个简单的图像分类Web服务。用户可以通过发送POST请求并将图像作为附件上传然后服务端会对图像进行分类并返回最有可能的三个类别及其概率。这种服务可以用于各种场合如在线图像识别、产品分类等。
希望这篇文章能帮助你了解如何使用Flask和PyTorch快速搭建一个图像分类的服务并激发你在实际项目中的应用。