NOTE欢迎来到“从零搭建AI小程序全栈实战”系列文章!
此为本系列的第五篇文章,你可以点击跳转到第一篇文章在系列文章中快速跳转。
一、这篇文章将要做的
1.1 回顾
- 上一篇我们完成了微信小程序前端,拍照上传→识别展示→历史记录全链路已通
- 当前状态:小程序能跑通,但后端 /predict 返回的还是 fake_prediction,AI 是假的
1.2 目标
- 训练一个真正的图像分类模型,替换掉假预测
- 四个里程碑:
- 数据集准备与预处
- 用 PyTorch 训练 ResNet18 分类模型
- 用训练好的模型替换 Flask API
- 调整 Spring Boot 后端适配新模型接口
1.3 最终效果预览
- 小程序拍照后显示 “organic(置信度: 93.21%)” 而非 “fake_prediction”
- 这一切跑在你自己的 GPU 上,不依赖任何第三方 AI 服务
二、 数据集准备与预处理
2.1 选择数据集
我们使用 Kaggle 垃圾分类数据集
-
包含 2.5 万张图片,分 Organic(厨余)和 Recyclable(可回收)两类
-
二分类最简单,8G 显存轻松训练
2.2 下载数据集并整理
在 Windows 上的 G:\DEMO_Project\demo_assist\api-server\ 中创建目录 dataset、目录 models 与文件 train.py 。
创建后 api-server 总体结构如下:
G:\DEMO_Project\demo_assist\api-server\├── dataset/│ ├── train/│ │ ├── organic/ # 厨余垃圾图片│ │ └── recyclable/ # 可回收垃圾图片│ └── test/│ ├── organic/│ └── recyclable/├── models/ # 训练好的模型存这里├── train.py # 训练脚本└── .... # 其他原有文件三、模型训练脚本
3.1 安装 PyTorch 依赖
打开 Anaconda Prompt,激活环境并安装新包:
conda activate ai-envpip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118pip install matplotlib tqdm pillow3.2 编写训练脚本
编写 train.py ,完整代码如下:
import torchimport torch.nn as nnimport torch.optim as optimfrom torchvision import datasets, transforms, modelsfrom torch.utils.data import DataLoaderimport osimport matplotlib.pyplot as plt
# ========== 1. 配置参数 ==========DATA_DIR = './dataset'MODEL_SAVE_PATH = './models/waste_classifier.pth'BATCH_SIZE = 32EPOCHS = 10LEARNING_RATE = 0.001IMG_SIZE = 224NUM_CLASSES = 2 # 根据你的类别数修改
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(f'使用设备: {device}')
# ========== 2. 数据预处理 ==========# 训练集:随机裁剪、翻转、旋转,增加数据多样性train_transforms = transforms.Compose([ transforms.RandomResizedCrop(IMG_SIZE), transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ColorJitter(brightness=0.2, contrast=0.2), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# 测试集:只需要缩放和归一化test_transforms = transforms.Compose([ transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# ========== 3. 加载数据 ==========train_dataset = datasets.ImageFolder( os.path.join(DATA_DIR, 'train'), transform=train_transforms)test_dataset = datasets.ImageFolder( os.path.join(DATA_DIR, 'test'), transform=test_transforms)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
# 自动获取类名映射class_names = train_dataset.classesprint(f'类别数量: {len(class_names)}, 类名: {class_names}')
# ========== 4. 构建模型 ==========# 使用预训练的 ResNet18,只改最后的全连接层model = models.resnet18(weights='DEFAULT')num_features = model.fc.in_featuresmodel.fc = nn.Linear(num_features, NUM_CLASSES)model = model.to(device)
# ========== 5. 损失函数和优化器 ==========criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# ========== 6. 训练循环 ==========train_losses = []test_accuracies = []
for epoch in range(EPOCHS): # --- 训练 --- model.train() running_loss = 0.0 for images, labels in train_loader: images, labels = images.to(device), labels.to(device)
optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step()
running_loss += loss.item()
avg_loss = running_loss / len(train_loader) train_losses.append(avg_loss)
# --- 测试 --- model.eval() correct = 0 total = 0 with torch.no_grad(): for images, labels in test_loader: images, labels = images.to(device), labels.to(device) outputs = model(images) _, predicted = torch.max(outputs, 1) total += labels.size(0) correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total test_accuracies.append(accuracy)
print(f'Epoch [{epoch+1}/{EPOCHS}], Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%')
# ========== 7. 保存模型 ==========os.makedirs('models', exist_ok=True)torch.save({ 'model_state_dict': model.state_dict(), 'class_names': class_names, 'img_size': IMG_SIZE}, MODEL_SAVE_PATH)print(f'模型已保存到 {MODEL_SAVE_PATH}')
# ========== 8. 绘制训练曲线 ==========plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(train_losses)plt.title('Training Loss')plt.xlabel('Epoch')plt.subplot(1, 2, 2)plt.plot(test_accuracies)plt.title('Test Accuracy')plt.xlabel('Epoch')plt.ylabel('%')plt.tight_layout()plt.savefig('./models/training_curve.png')plt.show()3.3 开始训练
cd "G:\DEMO_Project\demo_assist\api-server\"python train.py正常输出类似:
使用设备: cuda类别数量: 2, 类名: ['organic', 'recyclable']Epoch [1/10], Loss: 0.3456, Test Accuracy: 85.23%...Epoch [10/10], Loss: 0.0213, Test Accuracy: 95.67%模型已保存到 ./models/waste_classifier.pth我们使用 nvidia-smi 确认 GPU 在工作(显存占用约2~4GB)
预计10 个 epoch 大约跑 15-30 分钟,时长取决于数据集大小。
四、使用真实模型替换 Flask API
现在 Flask API 里 /predict 还在返回 “fake_prediction”,用刚训练好的模型替换掉。
4.1 改造 Flask API
在 api-server 文件夹 上找到之前创建的 app.py,完整替换为:
from flask import Flask, request, jsonifyimport torchimport torch.nn as nnfrom torchvision import transforms, modelsfrom PIL import Imageimport ioimport os
app = Flask(__name__)
# ========== 1. 加载模型 ==========MODEL_PATH = 'D:/ai-project/models/waste_classifier.pth'device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载保存的模型信息checkpoint = torch.load(MODEL_PATH, map_location=device)class_names = checkpoint['class_names']img_size = checkpoint.get('img_size', 224)
# 重建模型结构并加载权重model = models.resnet18(weights=None)num_features = model.fc.in_featuresmodel.fc = nn.Linear(num_features, len(class_names))model.load_state_dict(checkpoint['model_state_dict'])model = model.to(device)model.eval()
print(f'模型已加载,类别: {class_names}, 设备: {device}')
# ========== 2. 预处理函数 ==========transform = transforms.Compose([ transforms.Resize((img_size, img_size)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
def predict_image(image_bytes): """输入图片字节流,返回预测类名和置信度""" image = Image.open(io.BytesIO(image_bytes)).convert('RGB') image_tensor = transform(image).unsqueeze(0).to(device)
with torch.no_grad(): outputs = model(image_tensor) probabilities = torch.nn.functional.softmax(outputs, dim=1) confidence, predicted_idx = torch.max(probabilities, 1)
class_name = class_names[predicted_idx.item()] confidence_val = confidence.item()
return class_name, confidence_val
# ========== 3. API 接口 ==========@app.route('/health')def health(): return jsonify({ "status": "ok", "cuda": torch.cuda.is_available(), "classes": class_names })
@app.route('/predict', methods=['POST'])def predict(): # 方案A:接收图片URL(后端下载) if request.is_json: data = request.get_json() image_url = data.get('image_url', '') if image_url.startswith('http'): import requests response = requests.get(image_url, timeout=10) image_bytes = response.content else: # 本地路径 with open('.' + image_url, 'rb') as f: # 路径前加 . 因为Flask工作目录问题 image_bytes = f.read() class_name, confidence = predict_image(image_bytes) return jsonify({ "class": class_name, "confidence": round(confidence, 4) })
# 方案B:直接上传图片文件 if 'file' in request.files: file = request.files['file'] image_bytes = file.read() class_name, confidence = predict_image(image_bytes) return jsonify({ "class": class_name, "confidence": round(confidence, 4) })
return jsonify({"error": "no image provided"}), 400
if __name__ == '__main__': app.run(host='0.0.0.0', port=5000, debug=False)4.2 本地测试新 API
在 Windows 端的 Powershell 上启动 Flask:
cd "G:\DEMO_Project\demo_assist\api-server\"python app.py用 Windows 端的浏览器访问 http://localhost:5000/health,确认返回类名。
在 服务器端 上用 curl 测试:
curl -X POST http://localhost:5000/predict \ -F "file=@D:/test_organic.jpg"期望返回:
{ "class": "organic", "confidence": 0.9578}五、Spring Boot 后端适配新模型
现在 Flask 返回的真结果格式变了,Spring Boot 后端需要做一点适配。
5.1 新增 DTO 类
原来的预测方法返回 Map<String, Object>,现在需要解析新的字段。
让我们创建专门的 DTO 类来接收。
新建 src/main/java/com/example/demo/dto/AiPredictResult.java:
package com.example.demo.dto;
public class AiPredictResult { private String clazz; // 注意:class 是关键字,用 @JsonProperty private Double confidence;
// Getter / Setter public String getClazz() { return clazz; } public void setClazz(String clazz) { this.clazz = clazz; } public Double getConfidence() { return confidence; } public void setConfidence(Double confidence) { this.confidence = confidence; }}5.2 修改 AiService
修改 AiServce.java 中的调用:
public AiPredictResult predictByFile(MultipartFile file) throws IOException { // 把 MultipartFile 转成字节,发给 Flask // 注意:Spring Boot 在服务器上,笔记本 API 在 localhost:5000(通过 frp) String url = AI_BASE_URL + "/predict";
// 用 RestTemplate 上传文件 HttpHeaders headers = new HttpHeaders(); headers.setContentType(MediaType.MULTIPART_FORM_DATA);
// 先存为临时文件 File tempFile = File.createTempFile("upload_", ".jpg"); file.transferTo(tempFile);
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>(); body.add("file", new FileSystemResource(tempFile));
HttpEntity<MultiValueMap<String, Object>> requestEntity = new HttpEntity<>(body, headers); ResponseEntity<AiPredictResult> response = restTemplate.postForEntity( url, requestEntity, AiPredictResult.class);
// 清理临时文件 tempFile.delete();
return response.getBody();}5.3 修改 Controller
在 ImageController 的 uploadAndRecognize 方法中,调用新的 predictByFile 并保存结果:
@PostMapping("/upload-and-recognize")public Result<ImageRecord> uploadAndRecognize(@RequestParam("file") MultipartFile file) throws IOException { // 保存文件(原有逻辑) Files.createDirectories(Paths.get(uploadDir)); String filename = System.currentTimeMillis() + "_" + file.getOriginalFilename(); File dest = new File(uploadDir + "/" + filename); file.transferTo(dest);
// 调用真实 AI 预测(新逻辑) AiPredictResult aiResult = aiService.predictByFile(file); String resultStr = aiResult.getClazz() + " (置信度: " + String.format("%.2f", aiResult.getConfidence() * 100) + "%)";
// 存数据库 ImageRecord record = new ImageRecord(file.getOriginalFilename(), "/uploads/" + filename); record.setRecognitionResult(resultStr); return Result.success(repository.save(record));}5.4 确保网络穿透正常
确保:
-
笔记本上 Flask API 在跑(端口 5000)
-
笔记本上 frpc 在跑,连上了服务器
-
服务器上 frps 在跑
在服务器上验证:
curl -X POST http://localhost:5000/predict -F "file=@一张图片.jpg"返回真实分类结果即可。
六、全链路验证
-
笔记本:Flask + frpc 运行中
-
服务器:jar 运行中
-
小程序:拍照 → 上传 → 等待 2-3 秒 → 显示 “recyclable (置信度: 95.67%)”
-
历史记录:能看到带有真实识别结果的记录
七、验证清单
-
数据集下载或自建完成,目录结构正确
-
train.py训练完毕,测试准确率 > 85%(分类越少越容易达标) -
模型保存为
.pth文件,class_names正确保存 -
Flask
/predict接口能返回真实类名和置信度 -
Spring Boot 能正确解析 AI 返回结果并存入数据库
-
小程序拍照后显示的不是 “fake_prediction”,而是 “organic(置信度: 93.21%)”