import sys
import subprocess
import zipfile
import pkg_resources
import requests
# Check for and install any missing dependencies
required = {
'torch', 'torchvision', 'numpy', 'matplotlib',
    'tqdm', 'requests', 'pillow', 'scikit-learn', 'pyqt5', 'torchsummary'  # torchsummary added for the model summary
}
installed = {pkg.key for pkg in pkg_resources.working_set}
missing = required - installed
if missing:
print(f"安装缺失的依赖: {', '.join(missing)}")
python = sys.executable
subprocess.check_call([python, '-m', 'pip', 'install', *missing])
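# Note (editorial): pkg_resources is deprecated in newer setuptools releases; importlib.metadata
# is the usual replacement if this dependency check ever needs to be reworked.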
# Now import the remaining modules
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os
import shutil
from PIL import Image
from tqdm import tqdm
import matplotlib
from matplotlib import font_manager
import json
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# PyQt5 imports
from PyQt5.QtWidgets import (QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton,
QLabel, QScrollArea, QFileDialog, QMessageBox, QTextEdit)
from PyQt5.QtGui import QPixmap
from PyQt5.QtCore import Qt, QObject, pyqtSignal
import threading
import time
# torchsummary import
from torchsummary import summary
# Set up Chinese font support for matplotlib
try:
    # Note: assigning rcParams does not verify that SimHei is actually installed,
    # so this branch rarely raises and the download fallback below may never run.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
except Exception:
    try:
        font_url = "https://github.com/googlefonts/noto-cjk/raw/main/Sans/OTF/SimplifiedChinese/NotoSansSC-Regular.otf"
        font_path = "NotoSansSC-Regular.otf"
        if not os.path.exists(font_path):
            response = requests.get(font_url)
            with open(font_path, 'wb') as f:
                f.write(response.content)
        font_prop = font_manager.FontProperties(fname=font_path)
        plt.rcParams['font.family'] = font_prop.get_name()
    except Exception:
        print("警告: 无法设置中文字体")
matplotlib.use('Agg')
# Part 2: download and set up the dataset
def download_and_extract_dataset():
base_dir = "data"
data_path = os.path.join(base_dir, "dogs-vs-cats")
train_folder = os.path.join(data_path, 'train')
test_folder = os.path.join(data_path, 'test')
os.makedirs(train_folder, exist_ok=True)
os.makedirs(test_folder, exist_ok=True)
    # Check whether the dataset already looks complete
cat_files = [f for f in os.listdir(train_folder) if f.startswith('cat')]
dog_files = [f for f in os.listdir(train_folder) if f.startswith('dog')]
if len(cat_files) > 1000 and len(dog_files) > 1000:
print("数据集已存在,跳过下载")
return
print("正在下载数据集...")
dataset_url = "https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip"
try:
zip_path = os.path.join(base_dir, "catsdogs.zip")
        # Download the archive
if not os.path.exists(zip_path):
response = requests.get(dataset_url, stream=True)
total_size = int(response.headers.get('content-length', 0))
with open(zip_path, 'wb') as f, tqdm(
desc="下载进度",
total=total_size,
unit='B',
unit_scale=True,
unit_divisor=1024,
) as bar:
for data in response.iter_content(chunk_size=1024):
size = f.write(data)
bar.update(size)
print("下载完成,正在解压...")
        # Extract the archive
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(base_dir)
print("数据集解压完成!")
        # Move the extracted images into the train folder
        extracted_dir = os.path.join(base_dir, "PetImages")
        # Move cat images
cat_source = os.path.join(extracted_dir, "Cat")
for file in os.listdir(cat_source):
src = os.path.join(cat_source, file)
dst = os.path.join(train_folder, f"cat.{file}")
if os.path.exists(src) and not os.path.exists(dst):
shutil.move(src, dst)
        # Move dog images
dog_source = os.path.join(extracted_dir, "Dog")
for file in os.listdir(dog_source):
src = os.path.join(dog_source, file)
dst = os.path.join(train_folder, f"dog.{file}")
if os.path.exists(src) and not os.path.exists(dst):
shutil.move(src, dst)
        # Build the test set by moving 20% of the training images
train_files = os.listdir(train_folder)
np.random.seed(42)
test_files = np.random.choice(train_files, size=int(len(train_files) * 0.2), replace=False)
for file in test_files:
src = os.path.join(train_folder, file)
dst = os.path.join(test_folder, file)
if os.path.exists(src) and not os.path.exists(dst):
shutil.move(src, dst)
        # Clean up temporary files
if os.path.exists(extracted_dir):
shutil.rmtree(extracted_dir)
if os.path.exists(zip_path):
os.remove(zip_path)
print(
f"数据集设置完成!训练集: {len(os.listdir(train_folder))} 张图片, 测试集: {len(os.listdir(test_folder))} 张图片")
except Exception as e:
print(f"下载或设置数据集时出错: {str(e)}")
print("请手动下载数据集并解压到 data/dogs-vs-cats 目录")
print("下载地址: https://www.microsoft.com/en-us/download/details.aspx?id=54765")
# Download and extract the dataset
download_and_extract_dataset()
# Part 3: custom dataset
class DogsVSCats(Dataset):
def __init__(self, data_dir, transform=None):
self.image_paths = []
self.labels = []
for file in os.listdir(data_dir):
if file.lower().endswith(('.png', '.jpg', '.jpeg')):
img_path = os.path.join(data_dir, file)
try:
                    # Verify that the image file is intact
with Image.open(img_path) as img:
img.verify()
self.image_paths.append(img_path)
                    # Label from the file-name prefix
if file.startswith('cat'):
self.labels.append(0)
elif file.startswith('dog'):
self.labels.append(1)
else:
                        # Files with an unrecognized prefix default to cat
self.labels.append(0)
except (IOError, SyntaxError) as e:
print(f"跳过损坏图片: {img_path} - {str(e)}")
if not self.image_paths:
print(f"错误: 在 {data_dir} 中没有找到有效图片!")
for i in range(10):
img_path = os.path.join(data_dir, f"example_{i}.jpg")
img = Image.new('RGB', (224, 224), color=(i * 25, i * 25, i * 25))
img.save(img_path)
self.image_paths.append(img_path)
self.labels.append(0 if i % 2 == 0 else 1)
print(f"已创建 {len(self.image_paths)} 个示例图片")
self.transform = transform or transforms.Compose([
            transforms.Resize((150, 150)),  # 150x150 to match the CNN input
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
try:
image = Image.open(self.image_paths[idx]).convert('RGB')
except Exception as e:
print(f"无法加载图片: {self.image_paths[idx]}, 使用占位符 - {str(e)}")
image = Image.new('RGB', (150, 150), color=(100, 100, 100))
image = self.transform(image)
label = torch.tensor(self.labels[idx], dtype=torch.long)
return image, label
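# Optional helper (illustrative sketch, not called anywhere in the original flow): peek at a
# few samples from a DogsVSCats dataset to confirm tensor shapes and label encoding.
# The function name and the n parameter are assumptions added for this example.
def preview_dataset_samples(data_dir, n=3):
    ds = DogsVSCats(data_dir)
    for i in range(min(n, len(ds))):
        image, label = ds[i]
        # image is a 3x150x150 float tensor; label is 0 for cat, 1 for dog
        print(f"sample {i}: shape={tuple(image.shape)}, label={'cat' if label.item() == 0 else 'dog'}")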
# Part 6: custom CNN model (with an extra Dropout layer added)
class CatDogCNN(nn.Module):
def __init__(self):
super(CatDogCNN, self).__init__()
        # Conv layer 1: 3 input channels (RGB), 32 output channels, 3x3 kernel
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        # Conv layer 2: 32 input channels, 64 output channels, 3x3 kernel
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Conv layer 3: 64 input channels, 128 output channels, 3x3 kernel
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # Conv layer 4: 128 input channels, 256 output channels, 3x3 kernel
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        # Max pooling layer
        self.pool = nn.MaxPool2d(2, 2)
        # Fully connected layers
        self.fc1 = nn.Linear(256 * 9 * 9, 512)  # spatial size: 150 -> 75 -> 37 -> 18 -> 9 after four 2x2 pools
        self.fc2 = nn.Linear(512, 2)  # two output classes (cat and dog)
        # Dropout against overfitting (with an extra Dropout layer)
        self.dropout1 = nn.Dropout(0.5)  # first Dropout layer
        self.dropout2 = nn.Dropout(0.5)  # second, newly added Dropout layer
    def forward(self, x):
        # Conv block 1: conv + ReLU + pool
        x = self.pool(F.relu(self.conv1(x)))
        # Conv block 2: conv + ReLU + pool
        x = self.pool(F.relu(self.conv2(x)))
        # Conv block 3: conv + ReLU + pool
        x = self.pool(F.relu(self.conv3(x)))
        # Conv block 4: conv + ReLU + pool
        x = self.pool(F.relu(self.conv4(x)))
        # Flatten the feature maps
        x = x.view(-1, 256 * 9 * 9)
        # Fully connected layer + Dropout
        x = self.dropout1(F.relu(self.fc1(x)))
        # Second Dropout layer
        x = self.dropout2(x)
        # Output layer
        x = self.fc2(x)
        return x
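# Optional sanity check (illustrative sketch, not invoked by the original flow): confirms that a
# 150x150 input really flattens to 256 * 9 * 9 features and that the network produces two logits
# per sample. The function name is an assumption added for this example.
def check_catdog_cnn_shapes():
    model = CatDogCNN()
    with torch.no_grad():
        out = model(torch.zeros(2, 3, 150, 150))
    assert out.shape == (2, 2), f"unexpected output shape: {tuple(out.shape)}"
    return tuple(out.shape)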
# Part 7: model training and visualization
class Trainer:
def __init__(self, model, train_loader, val_loader):
self.train_loader = train_loader
self.val_loader = val_loader
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {self.device}")
self.model = model.to(self.device)
self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
self.criterion = nn.CrossEntropyLoss()
        # Scheduler setup chosen for broader compatibility (the verbose argument was removed)
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer, mode='max', factor=0.1, patience=2)
        # Metric history
self.train_losses = []
self.train_accuracies = []
self.val_losses = []
self.val_accuracies = []
def train(self, num_epochs):
best_accuracy = 0.0
for epoch in range(num_epochs):
            # Training phase
self.model.train()
running_loss = 0.0
correct = 0
total = 0
train_bar = tqdm(self.train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [训练]")
for images, labels in train_bar:
images, labels = images.to(self.device), labels.to(self.device)
self.optimizer.zero_grad()
outputs = self.model(images)
loss = self.criterion(outputs, labels)
loss.backward()
self.optimizer.step()
running_loss += loss.item() * images.size(0)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
train_loss = running_loss / total
train_acc = correct / total
train_bar.set_postfix(loss=train_loss, acc=train_acc)
            # Epoch-level training metrics
epoch_train_loss = running_loss / total
epoch_train_acc = correct / total
self.train_losses.append(epoch_train_loss)
self.train_accuracies.append(epoch_train_acc)
            # Validation phase
val_loss, val_acc = self.validate()
self.val_losses.append(val_loss)
self.val_accuracies.append(val_acc)
            # Update the learning rate
self.scheduler.step(val_acc)
            # Save the best model so far
if val_acc > best_accuracy:
best_accuracy = val_acc
torch.save(self.model.state_dict(), 'best_cnn_model.pth')
print(f"保存最佳模型,验证准确率: {best_accuracy:.4f}")
            # Print the epoch results
print(f"Epoch {epoch + 1}/{num_epochs} | "
f"训练损失: {epoch_train_loss:.4f} | 训练准确率: {epoch_train_acc:.4f} | "
f"验证损失: {val_loss:.4f} | 验证准确率: {val_acc:.4f}")
        # Visualize the results once training finishes
self.visualize_training_results()
def validate(self):
self.model.eval()
running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
val_bar = tqdm(self.val_loader, desc="[验证]")
for images, labels in val_bar:
images, labels = images.to(self.device), labels.to(self.device)
outputs = self.model(images)
loss = self.criterion(outputs, labels)
running_loss += loss.item() * images.size(0)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
val_loss = running_loss / total
val_acc = correct / total
val_bar.set_postfix(loss=val_loss, acc=val_acc)
return running_loss / total, correct / total
def visualize_training_results(self):
"""可视化训练和验证的准确率与损失"""
epochs = range(1, len(self.train_accuracies) + 1)
        # Accuracy subplot
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(epochs, self.train_accuracies, 'bo-', label='训练准确率')
plt.plot(epochs, self.val_accuracies, 'ro-', label='验证准确率')
plt.title('训练和验证准确率')
plt.xlabel('Epoch')
plt.ylabel('准确率')
plt.legend()
plt.grid(True)
        # Loss subplot
plt.subplot(1, 2, 2)
plt.plot(epochs, self.train_losses, 'bo-', label='训练损失')
plt.plot(epochs, self.val_losses, 'ro-', label='验证损失')
plt.title('训练和验证损失')
plt.xlabel('Epoch')
plt.ylabel('损失')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('training_visualization.png')
print("训练结果可视化图表已保存为 training_visualization.png")
        # Save the accuracy curve on its own
plt.figure(figsize=(8, 6))
plt.plot(epochs, self.train_accuracies, 'bo-', label='训练准确率')
plt.plot(epochs, self.val_accuracies, 'ro-', label='验证准确率')
plt.title('训练和验证准确率')
plt.xlabel('Epoch')
plt.ylabel('准确率')
plt.legend()
plt.grid(True)
plt.savefig('accuracy_curve.png')
print("准确率曲线已保存为 accuracy_curve.png")
        # Save the loss curve on its own
plt.figure(figsize=(8, 6))
plt.plot(epochs, self.train_losses, 'bo-', label='训练损失')
plt.plot(epochs, self.val_losses, 'ro-', label='验证损失')
plt.title('训练和验证损失')
plt.xlabel('Epoch')
plt.ylabel('损失')
plt.legend()
plt.grid(True)
plt.savefig('loss_curve.png')
print("损失曲线已保存为 loss_curve.png")
        # Save the training results as JSON
results = {
'epochs': list(epochs),
'train_losses': self.train_losses,
'train_accuracies': self.train_accuracies,
'val_losses': self.val_losses,
'val_accuracies': self.val_accuracies
}
with open('training_results.json', 'w') as f:
json.dump(results, f)
print("训练结果已保存为 training_results.json")
# Image inference worker
class ImageProcessor(QObject):
    result_signal = pyqtSignal(str, str)  # signal payload: (filename, result text)
def __init__(self, model, device, filename):
super().__init__()
self.model = model
self.device = device
self.filename = filename
self.transform = transforms.Compose([
transforms.Resize((150, 150)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def process_image(self):
try:
            # Load the image
image = Image.open(self.filename).convert('RGB')
image_tensor = self.transform(image).unsqueeze(0).to(self.device)
            # Run the model
self.model.eval()
with torch.no_grad():
output = self.model(image_tensor)
probabilities = F.softmax(output, dim=1)
_, predicted = torch.max(output, 1)
            # Cat/dog probabilities
cat_prob = probabilities[0][0].item()
dog_prob = probabilities[0][1].item()
            # Pick the predicted class and its confidence
result = "猫" if predicted.item() == 0 else "狗"
confidence = cat_prob if result == "猫" else dog_prob
            # Format the result string
formatted_result = f"{result} ({confidence * 100:.1f}%置信度)"
self.result_signal.emit(self.filename, formatted_result)
except Exception as e:
self.result_signal.emit(self.filename, f"处理错误: {str(e)}")
# Main application window
class CatDogClassifierApp(QWidget):
def __init__(self, model, device):
super().__init__()
self.setWindowTitle("猫狗识别系统")
self.setGeometry(100, 100, 1000, 700)
self.model = model
self.device = device
self.initUI()
self.image_processors = []
def initUI(self):
        # Main layout
main_layout = QVBoxLayout()
        # Title
title = QLabel("猫狗识别系统")
title.setAlignment(Qt.AlignCenter)
title.setStyleSheet("font-size: 24px; font-weight: bold; margin: 10px;")
main_layout.addWidget(title)
        # Button row
button_layout = QHBoxLayout()
self.upload_button = QPushButton("上传图像")
self.upload_button.setStyleSheet("font-size: 16px; padding: 10px;")
self.upload_button.clicked.connect(self.uploadImage)
button_layout.addWidget(self.upload_button)
self.batch_process_button = QPushButton("批量处理")
self.batch_process_button.setStyleSheet("font-size: 16px; padding: 10px;")
self.batch_process_button.clicked.connect(self.batchProcess)
button_layout.addWidget(self.batch_process_button)
self.clear_button = QPushButton("清除所有")
self.clear_button.setStyleSheet("font-size: 16px; padding: 10px;")
self.clear_button.clicked.connect(self.clearAll)
button_layout.addWidget(self.clear_button)
self.results_button = QPushButton("查看训练结果")
self.results_button.setStyleSheet("font-size: 16px; padding: 10px;")
self.results_button.clicked.connect(self.showTrainingResults)
button_layout.addWidget(self.results_button)
        # Button for viewing the model structure
self.model_summary_button = QPushButton("查看模型结构")
self.model_summary_button.setStyleSheet("font-size: 16px; padding: 10px;")
self.model_summary_button.clicked.connect(self.showModelSummary)
button_layout.addWidget(self.model_summary_button)
main_layout.addLayout(button_layout)
        # Status label
self.status_label = QLabel("就绪")
self.status_label.setStyleSheet("font-size: 14px; color: #666; margin: 5px;")
main_layout.addWidget(self.status_label)
        # Image preview area
self.preview_area = QScrollArea()
self.preview_area.setWidgetResizable(True)
self.preview_area.setStyleSheet("background-color: #f0f0f0;")
self.preview_widget = QWidget()
self.preview_layout = QHBoxLayout()
self.preview_layout.setAlignment(Qt.AlignTop | Qt.AlignLeft)
self.preview_widget.setLayout(self.preview_layout)
self.preview_area.setWidget(self.preview_widget)
main_layout.addWidget(self.preview_area)
        # Footer information
info_label = QLabel("基于卷积神经网络(CNN)的猫狗识别系统 | 支持上传单张或多张图片")
info_label.setAlignment(Qt.AlignCenter)
info_label.setStyleSheet("font-size: 12px; color: #888; margin: 10px;")
main_layout.addWidget(info_label)
self.setLayout(main_layout)
def uploadImage(self):
self.status_label.setText("正在选择图像...")
filename, _ = QFileDialog.getOpenFileName(
self,
"选择图像",
"",
"图像文件 (*.png *.jpg *.jpeg)"
)
if filename:
self.status_label.setText(f"正在处理: {os.path.basename(filename)}")
self.displayImage(filename)
def batchProcess(self):
self.status_label.setText("正在选择多张图像...")
filenames, _ = QFileDialog.getOpenFileNames(
self,
"选择多张图像",
"",
"图像文件 (*.png *.jpg *.jpeg)"
)
if filenames:
self.status_label.setText(f"正在批量处理 {len(filenames)} 张图像...")
for filename in filenames:
self.displayImage(filename)
def displayImage(self, filename):
if not os.path.isfile(filename):
QMessageBox.warning(self, "警告", "文件路径不安全或文件不存在")
self.status_label.setText("错误: 文件不存在")
return
        # If a container for this file already exists, remove it first
for i in reversed(range(self.preview_layout.count())):
item = self.preview_layout.itemAt(i)
if item.widget() and item.widget().objectName().startswith(f"container_{filename}"):
widget_to_remove = item.widget()
self.preview_layout.removeWidget(widget_to_remove)
widget_to_remove.deleteLater()
        # Create the image container
container = QWidget()
container.setObjectName(f"container_{filename}")
container.setStyleSheet("""
background-color: white;
border: 1px solid #ddd;
border-radius: 5px;
padding: 10px;
margin: 5px;
""")
container.setFixedSize(300, 350)
container_layout = QVBoxLayout(container)
container_layout.setContentsMargins(5, 5, 5, 5)
container_layout.setSpacing(5)
        # Show the file name
filename_label = QLabel(os.path.basename(filename))
filename_label.setStyleSheet("font-size: 12px; color: #555;")
filename_label.setAlignment(Qt.AlignCenter)
container_layout.addWidget(filename_label)
        # Image preview
pixmap = QPixmap(filename)
if pixmap.width() > 280 or pixmap.height() > 200:
pixmap = pixmap.scaled(280, 200, Qt.KeepAspectRatio, Qt.SmoothTransformation)
preview_label = QLabel(container)
preview_label.setPixmap(pixmap)
preview_label.setAlignment(Qt.AlignCenter)
preview_label.setFixedSize(280, 200)
preview_label.setStyleSheet("border: 1px solid #eee;")
container_layout.addWidget(preview_label)
        # Result label
result_label = QLabel("识别中...", container)
result_label.setObjectName(f"result_{filename}")
result_label.setAlignment(Qt.AlignCenter)
result_label.setStyleSheet("font-size: 16px; font-weight: bold; padding: 5px;")
container_layout.addWidget(result_label)
        # Delete button
delete_button = QPushButton("删除", container)
delete_button.setObjectName(f"button_{filename}")
delete_button.setStyleSheet("""
QPushButton {
background-color: #ff6b6b;
color: white;
border: none;
border-radius: 3px;
padding: 5px;
}
QPushButton:hover {
background-color: #ff5252;
}
""")
delete_button.clicked.connect(lambda _, fn=filename: self.deleteImage(fn))
container_layout.addWidget(delete_button)
        # Add to the preview area
self.preview_layout.addWidget(container)
        # Create and start the image-processing worker thread
processor = ImageProcessor(self.model, self.device, filename)
processor.result_signal.connect(self.updateUIWithResult)
threading.Thread(target=processor.process_image).start()
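        # Note: the worker runs in a plain Python thread; the pyqtSignal it emits should be
        # delivered to the GUI thread via Qt's queued-connection mechanism, so the UI update
        # in updateUIWithResult runs on the main thread.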
self.image_processors.append(processor)
        # Cap the number of images processed at once
if self.preview_layout.count() > 20:
QMessageBox.warning(self, "警告", "最多只能同时处理20张图像")
self.image_processors.clear()
def deleteImage(self, filename):
container_name = f"container_{filename}"
container = self.findChild(QWidget, container_name)
if container:
self.preview_layout.removeWidget(container)
container.deleteLater()
self.status_label.setText(f"已删除: {os.path.basename(filename)}")
def updateUIWithResult(self, filename, result):
container = self.findChild(QWidget, f"container_{filename}")
if container:
result_label = container.findChild(QLabel, f"result_{filename}")
if result_label:
                # Color the label according to the result
if "猫" in result:
result_label.setStyleSheet("color: #1a73e8; font-size: 16px; font-weight: bold;")
elif "狗" in result:
result_label.setStyleSheet("color: #e91e63; font-size: 16px; font-weight: bold;")
else:
result_label.setStyleSheet("color: #f57c00; font-size: 16px; font-weight: bold;")
result_label.setText(result)
self.status_label.setText(f"完成识别: {os.path.basename(filename)} -> {result}")
def clearAll(self):
        # Remove all image containers
while self.preview_layout.count():
item = self.preview_layout.takeAt(0)
widget = item.widget()
if widget is not None:
widget.deleteLater()
self.image_processors = []
self.status_label.setText("已清除所有图像")
def showTrainingResults(self):
"""显示训练结果可视化图表"""
if not os.path.exists('training_visualization.png'):
QMessageBox.information(self, "提示", "训练结果可视化图表尚未生成")
return
try:
            # Create the results window
results_window = QWidget()
results_window.setWindowTitle("训练结果可视化")
results_window.setGeometry(200, 200, 1200, 800)
layout = QVBoxLayout()
            # Title
title = QLabel("模型训练结果可视化")
title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
title.setAlignment(Qt.AlignCenter)
layout.addWidget(title)
            # Combined chart
layout.addWidget(QLabel("训练和验证准确率/损失:"))
pixmap1 = QPixmap('training_visualization.png')
label1 = QLabel()
label1.setPixmap(pixmap1.scaled(1000, 500, Qt.KeepAspectRatio, Qt.SmoothTransformation))
layout.addWidget(label1)
            # Horizontal layout for the two individual charts
h_layout = QHBoxLayout()
            # Accuracy chart
vbox1 = QVBoxLayout()
vbox1.addWidget(QLabel("准确率曲线:"))
pixmap2 = QPixmap('accuracy_curve.png')
label2 = QLabel()
label2.setPixmap(pixmap2.scaled(450, 350, Qt.KeepAspectRatio, Qt.SmoothTransformation))
vbox1.addWidget(label2)
h_layout.addLayout(vbox1)
            # Loss chart
vbox2 = QVBoxLayout()
vbox2.addWidget(QLabel("损失曲线:"))
pixmap3 = QPixmap('loss_curve.png')
label3 = QLabel()
label3.setPixmap(pixmap3.scaled(450, 350, Qt.KeepAspectRatio, Qt.SmoothTransformation))
vbox2.addWidget(label3)
h_layout.addLayout(vbox2)
layout.addLayout(h_layout)
            # Close button
close_button = QPushButton("关闭")
close_button.setStyleSheet("font-size: 16px; padding: 8px;")
close_button.clicked.connect(results_window.close)
layout.addWidget(close_button, alignment=Qt.AlignCenter)
results_window.setLayout(layout)
results_window.show()
except Exception as e:
QMessageBox.critical(self, "错误", f"加载训练结果时出错: {str(e)}")
def showModelSummary(self):
"""显示模型结构摘要"""
        # Create the summary window
summary_window = QWidget()
summary_window.setWindowTitle("模型结构摘要")
summary_window.setGeometry(200, 200, 800, 600)
layout = QVBoxLayout()
        # Title
title = QLabel("模型各层参数状况")
title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
title.setAlignment(Qt.AlignCenter)
layout.addWidget(title)
        # Text box that displays the summary
summary_text = QTextEdit()
summary_text.setReadOnly(True)
summary_text.setStyleSheet("font-family: monospace; font-size: 12px;")
        # Capture the model summary
        try:
            # Use StringIO to capture summary()'s printed output
            from io import StringIO
            # Redirect standard output while the summary is generated
            original_stdout = sys.stdout
            sys.stdout = StringIO()
            try:
                # Generate the model summary
                summary(self.model, input_size=(3, 150, 150), device=self.device.type)
                # Grab the captured output
                summary_output = sys.stdout.getvalue()
            finally:
                # Always restore standard output, even if summary() fails
                sys.stdout = original_stdout
            # Show the summary
            summary_text.setPlainText(summary_output)
        except Exception as e:
            summary_text.setPlainText(f"生成模型摘要时出错: {str(e)}")
layout.addWidget(summary_text)
        # Close button
close_button = QPushButton("关闭")
close_button.setStyleSheet("font-size: 16px; padding: 8px;")
close_button.clicked.connect(summary_window.close)
layout.addWidget(close_button, alignment=Qt.AlignCenter)
summary_window.setLayout(layout)
summary_window.show()
# Entry point
if __name__ == "__main__":
    # Dataset paths
data_path = os.path.join("data", "dogs-vs-cats")
train_folder = os.path.join(data_path, 'train')
test_folder = os.path.join(data_path, 'test')
    # Check whether a trained model already exists
    model_path = "catdog_cnn_model_with_extra_dropout.pth"  # file name reflects the extra-Dropout version
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
    # Create the model instance (the version with the extra Dropout layer)
model = CatDogCNN()
if os.path.exists(model_path):
print("加载已训练的模型...")
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)
model.eval()
print("模型加载完成")
else:
print("未找到训练好的模型,开始训练新模型...")
        # Build the full training and test sets (with data augmentation)
        # The training set uses the augmented transform
        train_transform = transforms.Compose([
            transforms.RandomRotation(15),  # random rotation of up to 15 degrees
            transforms.RandomHorizontalFlip(),  # random horizontal flip
            transforms.Resize((150, 150)),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),  # random brightness/contrast jitter
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # Validation and test sets use the plain transform (no augmentation)
base_transform = transforms.Compose([
transforms.Resize((150, 150)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
full_train_dataset = DogsVSCats(train_folder, transform=train_transform)
test_dataset = DogsVSCats(test_folder, transform=base_transform)
        # Split into training and validation sets (80% / 20%)
train_size = int(0.8 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
gen = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
full_train_dataset,
[train_size, val_size],
generator=gen
)
        # Data loaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
        # Train the model
trainer = Trainer(model, train_loader, val_loader)
num_epochs = 15
print(f"开始训练(带额外Dropout层和数据增强),共 {num_epochs} 个epoch...")
trainer.train(num_epochs)
        # Save the final model (the best validation checkpoint was saved separately as best_cnn_model.pth)
torch.save(model.state_dict(), model_path)
print(f"模型已保存为 {model_path}")
    # Print the per-layer parameter summary
print("\n模型各层参数状况:")
summary(model, input_size=(3, 150, 150), device=device.type)
    # Launch the GUI application
app = QApplication(sys.argv)
window = CatDogClassifierApp(model, device)
window.show()
    sys.exit(app.exec_())