1990 字
10 分钟
从零搭建AI小程序全栈实战(四):模型训练
NOTE

欢迎来到“从零搭建AI小程序全栈实战”系列文章!
此为本系列的第五篇文章,你可以点击跳转到第一篇文章在系列文章中快速跳转。

一、这篇文章将要做的#

1.1 回顾#

  • 上一篇我们完成了微信小程序前端,拍照上传→识别展示→历史记录全链路已通
  • 当前状态:小程序能跑通,但后端 /predict 返回的还是 fake_prediction,AI 是假的

1.2 目标#

  • 训练一个真正的图像分类模型,替换掉假预测
  • 四个里程碑:
    • 数据集准备与预处
    • 用 PyTorch 训练 ResNet18 分类模型
    • 用训练好的模型替换 Flask API
    • 调整 Spring Boot 后端适配新模型接口

1.3 最终效果预览#

  • 小程序拍照后显示 “organic(置信度: 93.21%)” 而非 “fake_prediction”
  • 这一切跑在你自己的 GPU 上,不依赖任何第三方 AI 服务

二、 数据集准备与预处理#

2.1 选择数据集#

我们使用 Kaggle 垃圾分类数据集

  • 地址:Waste Classification data | Kaggle

  • 包含 2.5 万张图片,分 Organic(厨余)和 Recyclable(可回收)两类

  • 二分类最简单,8G 显存轻松训练

2.2 下载数据集并整理#

在 Windows 上的 G:\DEMO_Project\demo_assist\api-server\ 中创建目录 dataset、目录 models 与文件 train.py
创建后 api-server 总体结构如下:

G:\DEMO_Project\demo_assist\api-server\
├── dataset/
│ ├── train/
│ │ ├── organic/ # 厨余垃圾图片
│ │ └── recyclable/ # 可回收垃圾图片
│ └── test/
│ ├── organic/
│ └── recyclable/
├── models/ # 训练好的模型存这里
├── train.py # 训练脚本
└── .... # 其他原有文件

三、模型训练脚本#

3.1 安装 PyTorch 依赖#

打开 Anaconda Prompt,激活环境并安装新包:

Terminal window
conda activate ai-env
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install matplotlib tqdm pillow

3.2 编写训练脚本#

编写 train.py ,完整代码如下:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
# ========== 1. 配置参数 ==========
DATA_DIR = './dataset'
MODEL_SAVE_PATH = './models/waste_classifier.pth'
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.001
IMG_SIZE = 224
NUM_CLASSES = 2 # 根据你的类别数修改
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'使用设备: {device}')
# ========== 2. 数据预处理 ==========
# 训练集:随机裁剪、翻转、旋转,增加数据多样性
train_transforms = transforms.Compose([
transforms.RandomResizedCrop(IMG_SIZE),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ColorJitter(brightness=0.2, contrast=0.2),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 测试集:只需要缩放和归一化
test_transforms = transforms.Compose([
transforms.Resize((IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# ========== 3. 加载数据 ==========
train_dataset = datasets.ImageFolder(
os.path.join(DATA_DIR, 'train'), transform=train_transforms)
test_dataset = datasets.ImageFolder(
os.path.join(DATA_DIR, 'test'), transform=test_transforms)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
# 自动获取类名映射
class_names = train_dataset.classes
print(f'类别数量: {len(class_names)}, 类名: {class_names}')
# ========== 4. 构建模型 ==========
# 使用预训练的 ResNet18,只改最后的全连接层
model = models.resnet18(weights='DEFAULT')
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, NUM_CLASSES)
model = model.to(device)
# ========== 5. 损失函数和优化器 ==========
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# ========== 6. 训练循环 ==========
train_losses = []
test_accuracies = []
for epoch in range(EPOCHS):
# --- 训练 ---
model.train()
running_loss = 0.0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
avg_loss = running_loss / len(train_loader)
train_losses.append(avg_loss)
# --- 测试 ---
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
test_accuracies.append(accuracy)
print(f'Epoch [{epoch+1}/{EPOCHS}], Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%')
# ========== 7. 保存模型 ==========
os.makedirs('models', exist_ok=True)
torch.save({
'model_state_dict': model.state_dict(),
'class_names': class_names,
'img_size': IMG_SIZE
}, MODEL_SAVE_PATH)
print(f'模型已保存到 {MODEL_SAVE_PATH}')
# ========== 8. 绘制训练曲线 ==========
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.subplot(1, 2, 2)
plt.plot(test_accuracies)
plt.title('Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('%')
plt.tight_layout()
plt.savefig('./models/training_curve.png')
plt.show()

3.3 开始训练#

Terminal window
cd "G:\DEMO_Project\demo_assist\api-server\"
python train.py

正常输出类似:

使用设备: cuda
类别数量: 2, 类名: ['organic', 'recyclable']
Epoch [1/10], Loss: 0.3456, Test Accuracy: 85.23%
...
Epoch [10/10], Loss: 0.0213, Test Accuracy: 95.67%
模型已保存到 ./models/waste_classifier.pth

我们使用 nvidia-smi 确认 GPU 在工作(显存占用约2~4GB)

预计10 个 epoch 大约跑 15-30 分钟,时长取决于数据集大小。

四、使用真实模型替换 Flask API#

现在 Flask API 里 /predict 还在返回 “fake_prediction”,用刚训练好的模型替换掉。

4.1 改造 Flask API#

api-server 文件夹 上找到之前创建的 app.py,完整替换为:

from flask import Flask, request, jsonify
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import io
import os
app = Flask(__name__)
# ========== 1. 加载模型 ==========
MODEL_PATH = 'D:/ai-project/models/waste_classifier.pth'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载保存的模型信息
checkpoint = torch.load(MODEL_PATH, map_location=device)
class_names = checkpoint['class_names']
img_size = checkpoint.get('img_size', 224)
# 重建模型结构并加载权重
model = models.resnet18(weights=None)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, len(class_names))
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()
print(f'模型已加载,类别: {class_names}, 设备: {device}')
# ========== 2. 预处理函数 ==========
transform = transforms.Compose([
transforms.Resize((img_size, img_size)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def predict_image(image_bytes):
"""输入图片字节流,返回预测类名和置信度"""
image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
image_tensor = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
outputs = model(image_tensor)
probabilities = torch.nn.functional.softmax(outputs, dim=1)
confidence, predicted_idx = torch.max(probabilities, 1)
class_name = class_names[predicted_idx.item()]
confidence_val = confidence.item()
return class_name, confidence_val
# ========== 3. API 接口 ==========
@app.route('/health')
def health():
return jsonify({
"status": "ok",
"cuda": torch.cuda.is_available(),
"classes": class_names
})
@app.route('/predict', methods=['POST'])
def predict():
# 方案A:接收图片URL(后端下载)
if request.is_json:
data = request.get_json()
image_url = data.get('image_url', '')
if image_url.startswith('http'):
import requests
response = requests.get(image_url, timeout=10)
image_bytes = response.content
else:
# 本地路径
with open('.' + image_url, 'rb') as f: # 路径前加 . 因为Flask工作目录问题
image_bytes = f.read()
class_name, confidence = predict_image(image_bytes)
return jsonify({
"class": class_name,
"confidence": round(confidence, 4)
})
# 方案B:直接上传图片文件
if 'file' in request.files:
file = request.files['file']
image_bytes = file.read()
class_name, confidence = predict_image(image_bytes)
return jsonify({
"class": class_name,
"confidence": round(confidence, 4)
})
return jsonify({"error": "no image provided"}), 400
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000, debug=False)

4.2 本地测试新 API#

在 Windows 端的 Powershell 上启动 Flask:

Terminal window
cd "G:\DEMO_Project\demo_assist\api-server\"
python app.py

用 Windows 端的浏览器访问 http://localhost:5000/health,确认返回类名。

在 服务器端 上用 curl 测试:

Terminal window
curl -X POST http://localhost:5000/predict \
-F "file=@D:/test_organic.jpg"

期望返回:

{
"class": "organic",
"confidence": 0.9578
}

五、Spring Boot 后端适配新模型#

现在 Flask 返回的真结果格式变了,Spring Boot 后端需要做一点适配。

5.1 新增 DTO 类#

原来的预测方法返回 Map<String, Object>,现在需要解析新的字段。
让我们创建专门的 DTO 类来接收。

新建 src/main/java/com/example/demo/dto/AiPredictResult.java

package com.example.demo.dto;
public class AiPredictResult {
private String clazz; // 注意:class 是关键字,用 @JsonProperty
private Double confidence;
// Getter / Setter
public String getClazz() { return clazz; }
public void setClazz(String clazz) { this.clazz = clazz; }
public Double getConfidence() { return confidence; }
public void setConfidence(Double confidence) { this.confidence = confidence; }
}

5.2 修改 AiService#

修改 AiServce.java 中的调用:

public AiPredictResult predictByFile(MultipartFile file) throws IOException {
// 把 MultipartFile 转成字节,发给 Flask
// 注意:Spring Boot 在服务器上,笔记本 API 在 localhost:5000(通过 frp)
String url = AI_BASE_URL + "/predict";
// 用 RestTemplate 上传文件
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.MULTIPART_FORM_DATA);
// 先存为临时文件
File tempFile = File.createTempFile("upload_", ".jpg");
file.transferTo(tempFile);
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
body.add("file", new FileSystemResource(tempFile));
HttpEntity<MultiValueMap<String, Object>> requestEntity = new HttpEntity<>(body, headers);
ResponseEntity<AiPredictResult> response = restTemplate.postForEntity(
url, requestEntity, AiPredictResult.class);
// 清理临时文件
tempFile.delete();
return response.getBody();
}

5.3 修改 Controller#

ImageControlleruploadAndRecognize 方法中,调用新的 predictByFile 并保存结果:

@PostMapping("/upload-and-recognize")
public Result<ImageRecord> uploadAndRecognize(@RequestParam("file") MultipartFile file) throws IOException {
// 保存文件(原有逻辑)
Files.createDirectories(Paths.get(uploadDir));
String filename = System.currentTimeMillis() + "_" + file.getOriginalFilename();
File dest = new File(uploadDir + "/" + filename);
file.transferTo(dest);
// 调用真实 AI 预测(新逻辑)
AiPredictResult aiResult = aiService.predictByFile(file);
String resultStr = aiResult.getClazz() + " (置信度: " +
String.format("%.2f", aiResult.getConfidence() * 100) + "%)";
// 存数据库
ImageRecord record = new ImageRecord(file.getOriginalFilename(), "/uploads/" + filename);
record.setRecognitionResult(resultStr);
return Result.success(repository.save(record));
}

5.4 确保网络穿透正常#

确保:

  • 笔记本上 Flask API 在跑(端口 5000)

  • 笔记本上 frpc 在跑,连上了服务器

  • 服务器上 frps 在跑

在服务器上验证:

Terminal window
curl -X POST http://localhost:5000/predict -F "file=@一张图片.jpg"

返回真实分类结果即可。

六、全链路验证#

  • 笔记本:Flask + frpc 运行中

  • 服务器:jar 运行中

  • 小程序:拍照 → 上传 → 等待 2-3 秒 → 显示 “recyclable (置信度: 95.67%)”

  • 历史记录:能看到带有真实识别结果的记录

七、验证清单#

  • 数据集下载或自建完成,目录结构正确

  • train.py 训练完毕,测试准确率 > 85%(分类越少越容易达标)

  • 模型保存为 .pth 文件,class_names 正确保存

  • Flask /predict 接口能返回真实类名和置信度

  • Spring Boot 能正确解析 AI 返回结果并存入数据库

  • 小程序拍照后显示的不是 “fake_prediction”,而是 “organic(置信度: 93.21%)”

从零搭建AI小程序全栈实战(四):模型训练
https://47.113.107.125:80/posts/ai-miniapp/04-模型训练/
作者
TerryC
发布于
2026-01-05
许可协议
CC BY-NC-SA 4.0