```
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[1], line 22
     20 from sklearn.preprocessing import StandardScaler
     21 from sklearn.feature_extraction.image import extract_patches_2d
---> 22 from skimage.feature import hog
     23 from scipy.ndimage import gaussian_filter, map_coordinates
     24 import seaborn as sns

ModuleNotFoundError: No module named 'skimage'
```
This error means the `scikit-image` package is not installed in the current environment (`skimage` is scikit-image's import name), while the code relies on its `hog` feature extractor. Let's install the missing packages and, at the same time, make the code handle optional dependencies more gracefully:
### Solution
1. **Install the missing dependencies** (a notebook-friendly variant follows the command):
```bash
pip install scikit-image seaborn
```
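In a Jupyter notebook the same thing can be done from a cell with `!pip install scikit-image seaborn`. If you want the script to attempt the installation by itself (acceptable in throwaway environments such as notebooks or Colab, but generally not recommended for production code, where dependencies should be declared up front), a minimal sketch of the install-on-failure pattern looks like this:
```python
import subprocess
import sys

try:
    from skimage.feature import hog
except ImportError:
    # One-off attempt to install into the current interpreter's environment;
    # this needs network access and a working pip, and may fail without permissions.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "scikit-image"])
    from skimage.feature import hog
```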
2. **Make the code handle missing dependencies gracefully**:
```python
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, simpledialog
from PIL import Image, ImageDraw, ImageOps
import cv2
import os
import csv
import joblib
import time
import sys
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
from scipy.ndimage import gaussian_filter, map_coordinates
# 设置中文字体和负号显示
plt.rcParams["font.family"] = ["SimHei", "Microsoft YaHei"]
plt.rcParams["axes.unicode_minus"] = False
# 尝试导入可选库
HOG_INSTALLED = False
try:
from skimage.feature import hog
HOG_INSTALLED = True
except ImportError:
print("警告: 未安装scikit-image库,无法使用HOG特征")
SEABORN_INSTALLED = False
try:
import seaborn as sns
SEABORN_INSTALLED = True
except ImportError:
print("警告: 未安装seaborn库,无法使用高级可视化")
# 尝试导入XGBoost和LightGBM
XGB_INSTALLED = False
LGB_INSTALLED = False
try:
import xgboost as xgb
XGB_INSTALLED = True
except ImportError:
print("警告: 未安装XGBoost库,无法使用XGBoost模型")
try:
import lightgbm as lgb
LGB_INSTALLED = True
except ImportError:
print("警告: 未安装LightGBM库,无法使用LightGBM模型")
# 定义模型元数据常量(优化参数)
MODEL_METADATA = {
'svm': ('支持向量机(SVM)', SVC, StandardScaler, {'probability': True, 'random_state': 42}),
'dt': ('决策树(DT)', DecisionTreeClassifier, None, {'random_state': 42}),
'rf': ('随机森林(RF)', RandomForestClassifier, None, {'n_estimators': 100, 'random_state': 42}),
'mlp': ('多层感知机(MLP)', MLPClassifier, StandardScaler, {'hidden_layer_sizes': (100, 50), 'max_iter': 500, 'random_state': 42}),
'knn': ('K最近邻(KNN)', KNeighborsClassifier, StandardScaler, {'n_neighbors': 5, 'weights': 'distance'}),
'nb': ('高斯朴素贝叶斯(NB)', GaussianNB, None, {}),
'ensemble': ('集成模型(Ensemble)', None, StandardScaler, {}),
}
# 添加可选模型
if XGB_INSTALLED:
MODEL_METADATA['xgb'] = ('XGBoost(XGB)', xgb.XGBClassifier, None, {
'objective': 'multi:softmax',
'random_state': 42,
'n_estimators': 150,
'max_depth': 7,
'learning_rate': 0.1
})
if LGB_INSTALLED:
MODEL_METADATA['lgb'] = ('LightGBM(LGB)', lgb.LGBMClassifier, None, {
'objective': 'multiclass',
'random_state': 42,
'num_class': 10,
'max_depth': 5,
'min_child_samples': 10,
'learning_rate': 0.1,
'n_estimators': 150,
'force_col_wise': True
})
class ModelFactory:
@staticmethod
def get_split_data(digits_dataset, test_size=0.2):
"""数据集划分"""
X, y = digits_dataset.data, digits_dataset.target
return train_test_split(X, y, test_size=test_size, random_state=42, stratify=y)
@staticmethod
def extract_features(images):
"""提取高级特征(如果HOG可用则使用HOG + 原始像素)"""
features = []
for img in images:
img_2d = img.reshape(8, 8)
# 原始像素特征
pixel_features = img
# 如果HOG可用,添加HOG特征
if HOG_INSTALLED:
fd, _ = hog(
img_2d,
orientations=8,
pixels_per_cell=(4, 4),
cells_per_block=(1, 1),
visualize=True
)
combined = np.concatenate([pixel_features, fd])
else:
combined = pixel_features
features.append(combined)
return np.array(features)
@staticmethod
def augment_training_data(X_train, y_train):
"""数据增强方法"""
augmented_X = []
augmented_y = []
for img, label in zip(X_train, y_train):
img_2d = img.reshape(8, 8)
# 原始图像
augmented_X.append(img)
augmented_y.append(label)
# 1. 轻微旋转 (±3度)
for angle in [-3, 3]:
M = cv2.getRotationMatrix2D((4, 4), angle, 1)
rotated = cv2.warpAffine(img_2d, M, (8, 8), borderValue=0)  # 背景填充为0,与digits数据一致
augmented_X.append(rotated.flatten())
augmented_y.append(label)
# 2. 轻微平移 (±1像素)
for dx, dy in [(1, 0), (-1, 0), (0, 1), (0, -1)]:
M = np.float32([[1, 0, dx], [0, 1, dy]])
translated = cv2.warpAffine(img_2d, M, (8, 8), borderValue=0)
augmented_X.append(translated.flatten())
augmented_y.append(label)
return np.array(augmented_X), np.array(augmented_y)
@classmethod
def create_model(cls, model_type):
"""创建模型和数据标准化器"""
if model_type not in MODEL_METADATA:
raise ValueError(f"未知模型类型: {model_type}")
name, model_cls, scaler_cls, params = MODEL_METADATA[model_type]
# 处理集成模型
if model_type == 'ensemble':
# 创建基模型
svm = SVC(probability=True, random_state=42, C=10, gamma=0.01)
rf = RandomForestClassifier(n_estimators=150, random_state=42, max_depth=15)
mlp = MLPClassifier(hidden_layer_sizes=(128, 64), max_iter=500, random_state=42)
# 创建集成模型
ensemble = VotingClassifier(
estimators=[
('svm', svm),
('rf', rf),
('mlp', mlp)
],
voting='soft',
weights=[1, 1.5, 1] # 给随机森林更高权重
)
return ensemble, StandardScaler()
if not model_cls:
raise ImportError(f"{name}模型依赖库未安装")
model = model_cls(**params)
scaler = scaler_cls() if scaler_cls else None
return model, scaler
@classmethod
def tune_model(cls, model_type, X_train, y_train):
"""使用网格搜索调优模型参数"""
model, scaler = cls.create_model(model_type)
if model_type == 'svm':
param_grid = {
'C': [0.1, 1, 10, 100],
'gamma': [0.001, 0.01, 0.1, 1],
'kernel': ['rbf', 'poly']
}
grid = GridSearchCV(SVC(probability=True, random_state=42), param_grid, refit=True, verbose=0, cv=3, n_jobs=-1)
grid.fit(X_train, y_train)
return grid.best_estimator_, scaler
elif model_type == 'rf':
param_grid = {
'n_estimators': [50, 100, 150, 200],
'max_depth': [5, 10, 15, None],
'min_samples_split': [2, 5, 10]
}
grid = GridSearchCV(RandomForestClassifier(random_state=42), param_grid, refit=True, verbose=0, cv=3, n_jobs=-1)
grid.fit(X_train, y_train)
return grid.best_estimator_, scaler
elif model_type == 'mlp':
param_grid = {
'hidden_layer_sizes': [(50,), (100,), (128, 64), (100, 50)],
'alpha': [0.0001, 0.001, 0.01],
'learning_rate_init': [0.001, 0.01, 0.1]
}
grid = GridSearchCV(MLPClassifier(max_iter=500, random_state=42), param_grid, refit=True, verbose=0, cv=3, n_jobs=-1)
grid.fit(X_train, y_train)
return grid.best_estimator_, scaler
# 其他模型的调优参数...
return model, scaler
@staticmethod
def train_model(model, X_train, y_train, scaler=None, model_type=None):
"""训练模型"""
if scaler:
X_train = scaler.fit_transform(X_train)
if model_type == 'lgb' and isinstance(X_train, np.ndarray):
X_train = pd.DataFrame(X_train)
model.fit(X_train, y_train)
return model
@staticmethod
def evaluate_model(model, X_test, y_test, scaler=None, model_type=None):
"""评估模型"""
if scaler:
X_test = scaler.transform(X_test)
if model_type == 'lgb' and isinstance(X_test, np.ndarray) and hasattr(model, 'feature_name_'):
X_test = pd.DataFrame(X_test, columns=model.feature_name_)
y_pred = model.predict(X_test)
return accuracy_score(y_test, y_pred)
@classmethod
def train_and_evaluate(cls, model_type, X_train, y_train, X_test, y_test, tune=False):
"""训练并评估模型"""
try:
# 数据增强
if model_type != 'ensemble': # 集成模型使用基础模型,不需要单独增强
X_train, y_train = cls.augment_training_data(X_train, y_train)
# 特征工程
if model_type in ['svm', 'rf', 'mlp', 'ensemble']:
X_train = cls.extract_features(X_train)
X_test = cls.extract_features(X_test)
# 模型创建
if tune:
model, scaler = cls.tune_model(model_type, X_train, y_train)
else:
model, scaler = cls.create_model(model_type)
# 训练模型
model = cls.train_model(model, X_train, y_train, scaler, model_type)
# 评估模型
accuracy = cls.evaluate_model(model, X_test, y_test, scaler, model_type)
return model, scaler, accuracy
except Exception as e:
print(f"模型 {model_type} 训练/评估错误: {str(e)}")
raise
@classmethod
def evaluate_all_models(cls, digits_dataset, tune=False):
"""评估所有可用模型"""
print("\n=== 模型评估 ===")
X_train, X_test, y_train, y_test = cls.get_split_data(digits_dataset)
results = []
for model_type in MODEL_METADATA:
name = MODEL_METADATA[model_type][0]
print(f"评估模型: {name} ({model_type})")
if not MODEL_METADATA[model_type][1] and model_type != 'ensemble':
results.append({"模型名称": name, "准确率": "N/A"})
continue
try:
start_time = time.time()
_, _, accuracy = cls.train_and_evaluate(
model_type, X_train, y_train, X_test, y_test, tune=tune
)
elapsed = time.time() - start_time
results.append({"模型名称": name, "准确率": f"{accuracy:.4f}", "训练时间": f"{elapsed:.2f}s"})
except Exception as e:
results.append({"模型名称": name, "准确率": f"错误: {str(e)}", "训练时间": "N/A"})
# 按准确率排序
results.sort(
key=lambda x: float(x["准确率"])
if isinstance(x["准确率"], str) and x["准确率"].replace('.', '', 1).isdigit()
else -1,
reverse=True
)
print(pd.DataFrame(results))
return results
@staticmethod
def analyze_errors(model, X_test, y_test, scaler=None, model_type=None):
"""分析模型错误预测样本"""
if scaler:
X_test = scaler.transform(X_test)
if model_type == 'lgb' and isinstance(X_test, np.ndarray) and hasattr(model, 'feature_name_'):
X_test = pd.DataFrame(X_test, columns=model.feature_name_)
y_pred = model.predict(X_test)
# 找出错误预测的样本
errors = np.where(y_pred != y_test)[0]
# 分析错误类型
confusion = confusion_matrix(y_test, y_pred)
class_errors = {}
for i in range(10):
# 计算每个类别的错误率
total = np.sum(confusion[i])
correct = confusion[i, i]
error_rate = (total - correct) / total
class_errors[i] = error_rate
return confusion, class_errors, errors, y_pred
class HandwritingBoard:
CANVAS_SIZE = 300 # 固定画布尺寸
BRUSH_SIZE = 12 # 画笔大小
def __init__(self, root, model_factory, digits):
self.root = root
self.root.title("手写数字识别系统 (优化版)")
self.root.geometry("1000x700")
self.model_factory = model_factory
self.digits = digits
self.model_cache = {}
self.current_model = None
self.scaler = None
self.current_model_type = None
self.has_drawn = False
self.custom_data = []
self.drawing = False
self.last_x = self.last_y = 0
self.performance_results = []
# 自定义数据目录
self.data_dir = "custom_digits_data"
os.makedirs(self.data_dir, exist_ok=True)
# 初始化画布
self.image = Image.new("L", (self.CANVAS_SIZE, self.CANVAS_SIZE), 255)
self.draw_obj = ImageDraw.Draw(self.image)
self.create_widgets()
self.init_default_model()
def create_widgets(self):
"""使用grid布局管理器创建界面组件"""
# 创建主框架
main_frame = tk.Frame(self.root)
main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
# 使用grid布局管理器
# 第一行:模型选择区域
model_frame = tk.LabelFrame(main_frame, text="模型选择", font=("Arial", 10, "bold"))
model_frame.grid(row=0, column=0, columnspan=2, sticky="ew", padx=5, pady=5)
model_frame.grid_columnconfigure(1, weight=1) # 让模型标签可以扩展
tk.Label(model_frame, text="选择模型:", font=("Arial", 10)).grid(row=0, column=0, padx=5, pady=5, sticky="w")
self.available_models = []
for model_type, (name, _, _, _) in MODEL_METADATA.items():
if MODEL_METADATA[model_type][1] or model_type == 'ensemble':
self.available_models.append((model_type, name))
self.model_var = tk.StringVar()
self.model_combobox = ttk.Combobox(
model_frame,
textvariable=self.model_var,
values=[name for _, name in self.available_models],
state="readonly",
width=25,
font=("Arial", 10)
)
self.model_combobox.current(0)
self.model_combobox.bind("<<ComboboxSelected>>", self.on_model_select)
self.model_combobox.grid(row=0, column=1, padx=5, pady=5, sticky="ew")
self.model_label = tk.Label(
model_frame,
text="",
font=("Arial", 10),
relief=tk.SUNKEN,
padx=5,
pady=2
)
self.model_label.grid(row=0, column=2, padx=5, pady=5, sticky="ew")
# 调优选项
self.tune_var = tk.BooleanVar(value=False)
tune_check = tk.Checkbutton(
model_frame,
text="调优模型",
variable=self.tune_var,
font=("Arial", 10),
command=self.on_tune_change
)
tune_check.grid(row=0, column=3, padx=5, pady=5, sticky="e")
# 第二行:左侧绘图区域和右侧结果区域
# 左侧绘图区域
left_frame = tk.LabelFrame(main_frame, text="绘制区域", font=("Arial", 10, "bold"))
left_frame.grid(row=1, column=0, padx=5, pady=5, sticky="nsew")
self.canvas = tk.Canvas(left_frame, bg="white", width=self.CANVAS_SIZE, height=self.CANVAS_SIZE)
self.canvas.pack(padx=10, pady=10)
self.canvas.bind("<Button-1>", self.start_draw)
self.canvas.bind("<B1-Motion>", self.draw)
self.canvas.bind("<ButtonRelease-1>", self.stop_draw)
# 添加绘制提示
self.canvas.create_text(
self.CANVAS_SIZE / 2, self.CANVAS_SIZE / 2,
text="绘制数字", fill="gray", font=("Arial", 16)
)
# 绘图控制按钮
btn_frame = tk.Frame(left_frame)
btn_frame.pack(fill=tk.X, pady=(0, 10))
tk.Button(btn_frame, text="识别", command=self.recognize, width=8).pack(side=tk.LEFT, padx=5)
tk.Button(btn_frame, text="清除", command=self.clear_canvas, width=8).pack(side=tk.LEFT, padx=5)
tk.Button(btn_frame, text="样本", command=self.show_samples, width=8).pack(side=tk.LEFT, padx=5)
tk.Button(btn_frame, text="错误分析", command=self.show_error_analysis, width=10).pack(side=tk.LEFT, padx=5)
# 右侧结果区域
right_frame = tk.Frame(main_frame)
right_frame.grid(row=1, column=1, padx=5, pady=5, sticky="nsew")
# 识别结果
result_frame = tk.LabelFrame(right_frame, text="识别结果", font=("Arial", 10, "bold"))
result_frame.pack(fill=tk.X, padx=5, pady=5)
self.result_label = tk.Label(
result_frame,
text="请绘制数字",
font=("Arial", 24),
pady=10
)
self.result_label.pack()
self.prob_label = tk.Label(
result_frame,
text="",
font=("Arial", 12)
)
self.prob_label.pack()
# 置信度可视化
confidence_frame = tk.LabelFrame(right_frame, text="识别置信度", font=("Arial", 10, "bold"))
confidence_frame.pack(fill=tk.X, padx=5, pady=5)
self.confidence_canvas = tk.Canvas(
confidence_frame,
bg="white",
height=50
)
self.confidence_canvas.pack(fill=tk.X, padx=10, pady=10)
self.confidence_canvas.create_text(
150, 25,
text="识别后显示置信度",
fill="gray",
font=("Arial", 10)
)
# 候选数字
candidates_frame = tk.LabelFrame(right_frame, text="可能的数字", font=("Arial", 10, "bold"))
candidates_frame.pack(fill=tk.X, padx=5, pady=5)
columns = ("数字", "概率")
self.candidates_tree = ttk.Treeview(
candidates_frame,
columns=columns,
show="headings",
height=4
)
for col in columns:
self.candidates_tree.heading(col, text=col)
self.candidates_tree.column(col, width=80, anchor=tk.CENTER)
scrollbar = ttk.Scrollbar(
candidates_frame,
orient=tk.VERTICAL,
command=self.candidates_tree.yview
)
self.candidates_tree.configure(yscroll=scrollbar.set)
self.candidates_tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5)
scrollbar.pack(side=tk.RIGHT, fill=tk.Y, padx=5, pady=5)
# 第三行:模型性能对比和训练集管理
# 模型性能对比
performance_frame = tk.LabelFrame(main_frame, text="模型性能对比", font=("Arial", 10, "bold"))
performance_frame.grid(row=2, column=0, padx=5, pady=5, sticky="nsew")
columns = ("模型名称", "准确率", "训练时间")
self.performance_tree = ttk.Treeview(
performance_frame,
columns=columns,
show="headings",
height=8
)
for col in columns:
self.performance_tree.heading(col, text=col)
self.performance_tree.column(col, width=120, anchor=tk.CENTER)
scrollbar = ttk.Scrollbar(
performance_frame,
orient=tk.VERTICAL,
command=self.performance_tree.yview
)
self.performance_tree.configure(yscroll=scrollbar.set)
self.performance_tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5)
scrollbar.pack(side=tk.RIGHT, fill=tk.Y, padx=5, pady=5)
# 训练集管理
train_frame = tk.LabelFrame(main_frame, text="训练集管理", font=("Arial", 10, "bold"))
train_frame.grid(row=2, column=1, padx=5, pady=5, sticky="nsew")
# 使用grid布局训练集管理按钮
tk.Button(
train_frame,
text="保存为训练样本",
command=self.save_as_training_sample,
width=18,
height=2
).grid(row=0, column=0, padx=5, pady=5, sticky="ew")
tk.Button(
train_frame,
text="保存全部训练集",
command=self.save_all_training_data,
width=18,
height=2
).grid(row=0, column=1, padx=5, pady=5, sticky="ew")
tk.Button(
train_frame,
text="加载训练集",
command=self.load_training_data,
width=18,
height=2
).grid(row=1, column=0, padx=5, pady=5, sticky="ew")
tk.Button(
train_frame,
text="性能图表",
command=self.show_performance_chart,
width=18,
height=2
).grid(row=1, column=1, padx=5, pady=5, sticky="ew")
tk.Button(
train_frame,
text="保存当前模型",
command=self.save_current_model,
width=18,
height=2
).grid(row=2, column=0, padx=5, pady=5, sticky="ew")
tk.Button(
train_frame,
text="加载模型",
command=self.load_model,
width=18,
height=2
).grid(row=2, column=1, padx=5, pady=5, sticky="ew")
# 状态信息
self.status_var = tk.StringVar(value="就绪")
status_bar = tk.Label(
self.root,
textvariable=self.status_var,
bd=1,
relief=tk.SUNKEN,
anchor=tk.W,
font=("Arial", 10)
)
status_bar.pack(side=tk.BOTTOM, fill=tk.X)
# 配置权重
main_frame.grid_columnconfigure(0, weight=1)
main_frame.grid_columnconfigure(1, weight=1)
main_frame.grid_rowconfigure(1, weight=1)
main_frame.grid_rowconfigure(2, weight=1)
def start_draw(self, event):
"""开始绘制"""
self.drawing = True
self.last_x, self.last_y = event.x, event.y
def draw(self, event):
"""绘制"""
if not self.drawing:
return
x, y = event.x, event.y
# 在画布上绘制
self.canvas.create_line(
self.last_x, self.last_y, x, y,
fill="black",
width=self.BRUSH_SIZE,
capstyle=tk.ROUND,
smooth=True
)
# 在图像上绘制
self.draw_obj.line(
[self.last_x, self.last_y, x, y],
fill=0,
width=self.BRUSH_SIZE
)
self.last_x, self.last_y = x, y
def stop_draw(self, event):
"""停止绘制"""
self.drawing = False
self.has_drawn = True
self.status_var.set("已绘制数字,点击'识别'进行识别")
def clear_canvas(self):
"""清除画布"""
self.canvas.delete("all")
self.image = Image.new("L", (self.CANVAS_SIZE, self.CANVAS_SIZE), 255)
self.draw_obj = ImageDraw.Draw(self.image)
# 添加绘制提示
self.canvas.create_text(
self.CANVAS_SIZE / 2, self.CANVAS_SIZE / 2,
text="绘制数字", fill="gray", font=("Arial", 16)
)
self.result_label.config(text="请绘制数字")
self.prob_label.config(text="")
self.clear_confidence_display()
self.has_drawn = False
self.status_var.set("画布已清除")
def clear_confidence_display(self):
"""清除置信度显示"""
self.confidence_canvas.delete("all")
self.confidence_canvas.create_text(
150, 25,
text="识别后显示置信度",
fill="gray",
font=("Arial", 10)
)
for item in self.candidates_tree.get_children():
self.candidates_tree.delete(item)
def preprocess_image(self):
"""改进的图像预处理方法"""
img_array = np.array(self.image)
# 1. 自适应阈值二值化
img_array = cv2.adaptiveThreshold(
img_array, 255,
cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
cv2.THRESH_BINARY_INV, 11, 2
)
# 2. 形态学操作 - 去除小噪点
kernel = np.ones((3, 3), np.uint8)
img_array = cv2.morphologyEx(img_array, cv2.MORPH_OPEN, kernel)
# 3. 轮廓检测
contours, _ = cv2.findContours(img_array, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if not contours:
self.status_var.set("未检测到有效数字,请重新绘制")
return None
# 4. 找到最大轮廓
c = max(contours, key=cv2.contourArea)
x, y, w, h = cv2.boundingRect(c)
# 5. 提取数字区域
digit = img_array[y:y+h, x:x+w]
# 6. 填充为正方形并保留宽高比
size = max(w, h)
padded = np.ones((size, size), dtype=np.uint8) * 255
offset_x = (size - w) // 2
offset_y = (size - h) // 2
padded[offset_y:offset_y+h, offset_x:offset_x+w] = digit
# 7. 双线性插值缩放
resized = cv2.resize(padded, (8, 8), interpolation=cv2.INTER_LINEAR)
# 8. 归一化并增加对比度
normalized = cv2.normalize(resized, None, 0, 16, cv2.NORM_MINMAX)
return normalized.flatten()
def recognize(self):
"""识别手写数字"""
if not self.has_drawn:
self.status_var.set("请先绘制数字再识别")
return
if self.current_model is None:
self.status_var.set("模型未加载,请选择模型")
return
# 预处理图像
img_array = self.preprocess_image()
if img_array is None:
return
img_input = img_array.reshape(1, -1)
try:
# 特征工程(与训练时一致)
if self.current_model_type in ['svm', 'rf', 'mlp', 'ensemble']:
img_input = self.model_factory.extract_features(img_input)
# 标准化
if self.scaler:
img_input = self.scaler.transform(img_input)
# LightGBM特殊处理
if self.current_model_type == 'lgb' and hasattr(self.current_model, 'feature_name_'):
img_input = pd.DataFrame(img_input, columns=self.current_model.feature_name_)
# 预测
pred = self.current_model.predict(img_input)[0]
self.result_label.config(text=f"识别结果: {pred}")
# 概率预测
if hasattr(self.current_model, 'predict_proba'):
probs = self.current_model.predict_proba(img_input)[0]
confidence = probs[pred]
# 更新UI
self.prob_label.config(text=f"置信度: {confidence:.2%}")
self.update_confidence_display(confidence)
# 显示候选数字
top3 = sorted(enumerate(probs), key=lambda x: -x[1])[:3]
self.update_candidates_display(top3)
else:
self.prob_label.config(text="该模型不支持概率输出")
self.clear_confidence_display()
self.status_var.set(f"识别完成: 数字 {pred}")
except Exception as e:
self.status_var.set(f"识别错误: {str(e)}")
self.clear_confidence_display()
def update_confidence_display(self, confidence):
"""更新置信度可视化"""
self.confidence_canvas.delete("all")
# 画布尺寸
canvas_width = self.confidence_canvas.winfo_width() or 300
# 绘制背景
self.confidence_canvas.create_rectangle(
10, 10, canvas_width - 10, 40,
fill="#f0f0f0",
outline="#cccccc"
)
# 绘制置信度条
bar_width = int((canvas_width - 20) * confidence)
color = self.get_confidence_color(confidence)
self.confidence_canvas.create_rectangle(
10, 10, 10 + bar_width, 40,
fill=color,
outline=""
)
# 绘制文本
self.confidence_canvas.create_text(
canvas_width / 2, 25,
text=f"{confidence:.1%}",
font=("Arial", 10, "bold")
)
# 绘制刻度
for i in range(0, 11):
x_pos = 10 + i * (canvas_width - 20) / 10
self.confidence_canvas.create_line(x_pos, 40, x_pos, 45, width=1)
if i % 2 == 0:
self.confidence_canvas.create_text(x_pos, 55, text=f"{i*10}%", font=("Arial", 8))
def get_confidence_color(self, confidence):
"""根据置信度获取颜色"""
if confidence >= 0.9:
return "#4CAF50" # 绿色
elif confidence >= 0.7:
return "#FFC107" # 黄色
else:
return "#F44336" # 红色
def update_candidates_display(self, candidates):
"""更新候选数字显示"""
# 清空现有项
for item in self.candidates_tree.get_children():
self.candidates_tree.delete(item)
# 添加新项
for digit, prob in candidates:
self.candidates_tree.insert(
"", tk.END,
values=(digit, f"{prob:.2%}")
)
def show_samples(self):
"""显示样本图像"""
plt.figure(figsize=(10, 4))
for i in range(10):
plt.subplot(2, 5, i+1)
sample_idx = np.where(self.digits.target == i)[0][0]
plt.imshow(self.digits.images[sample_idx], cmap="gray")
plt.title(f"数字 {i}", fontsize=9)
plt.axis("off")
plt.tight_layout()
plt.show()
def show_error_analysis(self):
"""显示错误分析"""
if self.current_model is None:
self.status_var.set("请先加载模型")
return
self.status_var.set("正在执行错误分析...")
self.root.update()
try:
X_train, X_test, y_train, y_test = self.model_factory.get_split_data(self.digits)
# 特征工程
if self.current_model_type in ['svm', 'rf', 'mlp', 'ensemble']:
X_test = self.model_factory.extract_features(X_test)
confusion, class_errors, errors, y_pred = self.model_factory.analyze_errors(
self.current_model, X_test, y_test, self.scaler, self.current_model_type
)
# 可视化混淆矩阵
plt.figure(figsize=(10, 8))
if SEABORN_INSTALLED:
sns.heatmap(confusion, annot=True, fmt='d', cmap='Blues')
else:
plt.imshow(confusion, cmap='Blues')
for i in range(confusion.shape[0]):
for j in range(confusion.shape[1]):
plt.text(j, i, confusion[i, j],
ha="center", va="center", color="black")
plt.xlabel('预测值')
plt.ylabel('真实值')
plt.title(f'{MODEL_METADATA[self.current_model_type][0]} 混淆矩阵')
plt.tight_layout()
plt.show()
# 显示典型错误样本
plt.figure(figsize=(12, 6))
plt.suptitle(f"{MODEL_METADATA[self.current_model_type][0]} 错误样本分析", fontsize=14)
for i, idx in enumerate(errors[:min(10, len(errors))]):
plt.subplot(2, 5, i+1)
plt.imshow(X_test[idx][:64].reshape(8, 8), cmap='gray')  # 前64维为原始像素
plt.title(f'真实: {y_test[idx]}\n预测: {y_pred[idx]}')
plt.axis('off')
plt.tight_layout()
plt.show()
self.status_var.set("错误分析完成")
except Exception as e:
self.status_var.set(f"错误分析失败: {str(e)}")
def on_model_select(self, event):
"""模型选择事件处理"""
selected_name = self.model_var.get()
model_type = next(
(k for k, v in self.available_models if v == selected_name),
None
)
if model_type:
self.change_model(model_type)
def on_tune_change(self):
"""调优选项改变事件"""
if self.model_var.get():
selected_name = self.model_var.get()
model_type = next(
(k for k, v in self.available_models if v == selected_name),
None
)
if model_type:
self.change_model(model_type)
def change_model(self, model_type):
"""切换模型"""
model_name = MODEL_METADATA[model_type][0]
tune = self.tune_var.get()
# 从缓存加载
cache_key = f"{model_type}_{tune}"
if cache_key in self.model_cache:
self.current_model, self.scaler, accuracy, self.current_model_type = self.model_cache[cache_key]
self.model_label.config(text=f"{model_name} (准确率:{accuracy:.4f})")
self.status_var.set(f"已加载模型: {model_name}")
return
self.status_var.set(f"正在加载模型: {model_name}{' (调优中...)' if tune else ''}")
self.root.update() # 更新UI显示状态
try:
X_train, X_test, y_train, y_test = self.model_factory.get_split_data(self.digits)
self.current_model, self.scaler, accuracy = self.model_factory.train_and_evaluate(
model_type, X_train, y_train, X_test, y_test, tune=tune
)
self.current_model_type = model_type
self.model_cache[cache_key] = (self.current_model, self.scaler, accuracy, self.current_model_type)
self.model_label.config(text=f"{model_name} (准确率:{accuracy:.4f})")
self.status_var.set(f"模型加载完成: {model_name}, 准确率: {accuracy:.4f}")
self.clear_canvas()
# 更新性能表格
self.load_performance_data(tune=tune)
except Exception as e:
self.status_var.set(f"模型加载失败: {str(e)}")
self.model_label.config(text="模型加载失败")
def init_default_model(self):
"""初始化默认模型"""
self.model_var.set(self.available_models[0][1])
self.change_model(self.available_models[0][0])
def load_performance_data(self, tune=False):
"""加载性能数据"""
self.performance_results = self.model_factory.evaluate_all_models(self.digits, tune=tune)
# 清空表格
for item in self.performance_tree.get_children():
self.performance_tree.delete(item)
# 添加数据
for i, result in enumerate(self.performance_results):
tag = "highlight" if i == 0 else ""
self.performance_tree.insert(
"", tk.END,
values=(result["模型名称"], result["准确率"], result["训练时间"]),
tags=(tag,)
)
self.performance_tree.tag_configure("highlight", background="#e6f7ff")
def show_performance_chart(self):
"""显示性能图表"""
if not self.performance_results:
self.load_performance_data()
# 提取有效结果
valid_results = []
for result in self.performance_results:
try:
accuracy = float(result["准确率"])
valid_results.append((result["模型名称"], accuracy))
except ValueError:
continue
if not valid_results:
messagebox.showinfo("提示", "没有可用的性能数据")
return
# 排序
valid_results.sort(key=lambda x: x[1], reverse=True)
models, accuracies = zip(*valid_results)
# 创建图表
plt.figure(figsize=(10, 5))
bars = plt.barh(models, accuracies, color='#2196F3')
plt.xlabel('准确率', fontsize=10)
plt.ylabel('模型', fontsize=10)
plt.title('模型性能对比', fontsize=12)
plt.xlim(0, 1.05)
# 添加数值标签
for bar in bars:
width = bar.get_width()
plt.text(
width + 0.01,
bar.get_y() + bar.get_height()/2,
f'{width:.4f}',
ha='left',
va='center',
fontsize=8
)
plt.tight_layout()
plt.show()
def save_as_training_sample(self):
"""保存为训练样本"""
if not self.has_drawn:
self.status_var.set("请先绘制数字再保存")
return
img_array = self.preprocess_image()
if img_array is None:
return
# 弹出标签输入窗口
label_window = tk.Toplevel(self.root)
label_window.title("输入标签")
label_window.geometry("300x150")
label_window.transient(self.root)
label_window.grab_set()
tk.Label(
label_window,
text="请输入数字标签 (0-9):",
font=("Arial", 10)
).pack(pady=10)
entry = tk.Entry(label_window, font=("Arial", 12), width=5)
entry.pack(pady=5)
entry.focus_set()
def save_with_label():
try:
label = int(entry.get())
if label < 0 or label > 9:
raise ValueError("标签必须是0-9的数字")
self.custom_data.append((img_array.tolist(), label))
self.status_var.set(f"已保存数字 {label} (共 {len(self.custom_data)} 个样本)")
label_window.destroy()
except ValueError as e:
self.status_var.set(f"保存错误: {str(e)}")
tk.Button(
label_window,
text="保存",
command=save_with_label,
width=10
).pack(pady=5)
def save_all_training_data(self):
"""保存全部训练数据"""
if not self.custom_data:
self.status_var.set("没有训练数据可保存")
return
file_path = filedialog.asksaveasfilename(
defaultextension=".csv",
filetypes=[("CSV文件", "*.csv")],
initialfile="custom_digits.csv",
title="保存训练集"
)
if not file_path:
return
try:
with open(file_path, 'w', newline='', encoding='utf-8') as f:
writer = csv.writer(f)
writer.writerow([f'pixel{i}' for i in range(64)] + ['label'])
for img_data, label in self.custom_data:
writer.writerow(img_data + [label])
self.status_var.set(f"已保存 {len(self.custom_data)} 个样本到 {os.path.basename(file_path)}")
except Exception as e:
self.status_var.set(f"保存失败: {str(e)}")
def load_training_data(self):
"""加载训练数据"""
file_path = filedialog.askopenfilename(
filetypes=[("CSV文件", "*.csv")],
title="加载训练集"
)
if not file_path:
return
try:
self.custom_data = []
with open(file_path, 'r', newline='', encoding='utf-8') as f:
reader = csv.reader(f)
next(reader) # 跳过标题
for row in reader:
if len(row) != 65:
continue
img_data = [float(pixel) for pixel in row[:64]]
label = int(row[64])
self.custom_data.append((img_data, label))
self.status_var.set(f"已加载 {len(self.custom_data)} 个样本")
except Exception as e:
self.status_var.set(f"加载失败: {str(e)}")
def save_current_model(self):
"""保存当前模型"""
if self.current_model is None:
self.status_var.set("没有模型可保存")
return
file_path = filedialog.asksaveasfilename(
defaultextension=".joblib",
filetypes=[("Joblib文件", "*.joblib")],
initialfile=f"{self.current_model_type}_model.joblib",
title="保存模型"
)
if not file_path:
return
try:
# 保存模型和缩放器
joblib.dump({
'model': self.current_model,
'scaler': self.scaler,
'model_type': self.current_model_type
}, file_path)
self.status_var.set(f"模型已保存到 {os.path.basename(file_path)}")
except Exception as e:
self.status_var.set(f"保存失败: {str(e)}")
def load_model(self):
"""加载模型"""
file_path = filedialog.askopenfilename(
filetypes=[("Joblib文件", "*.joblib")],
title="加载模型"
)
if not file_path:
return
try:
data = joblib.load(file_path)
self.current_model = data['model']
self.scaler = data['scaler']
self.current_model_type = data['model_type']
# 更新UI
model_name = MODEL_METADATA[self.current_model_type][0]
self.model_var.set(model_name)
self.model_label.config(text=f"{model_name} (已加载)")
self.status_var.set(f"模型 {model_name} 加载成功")
self.clear_canvas()
except Exception as e:
self.status_var.set(f"加载失败: {str(e)}")
def run(self):
"""运行应用"""
self.root.mainloop()
if __name__ == "__main__":
digits = load_digits()
root = tk.Tk()
app = HandwritingBoard(root, ModelFactory, digits)
app.run()
```
### Key improvements
1. **Graceful handling of missing dependencies**:
```python
HOG_INSTALLED = False
try:
from skimage.feature import hog
HOG_INSTALLED = True
except ImportError:
print("警告: 未安装scikit-image库,无法使用HOG特征")
SEABORN_INSTALLED = False
try:
import seaborn as sns
SEABORN_INSTALLED = True
except ImportError:
print("警告: 未安装seaborn库,无法使用高级可视化")
```
2. **HOG features are used only when available** (the sketch right after this snippet checks the resulting feature length):
```python
@staticmethod
def extract_features(images):
"""提取高级特征(如果HOG可用则使用HOG + 原始像素)"""
if HOG_INSTALLED:
fd, _ = hog(...)
combined = np.concatenate([pixel_features, fd])
else:
combined = pixel_features
```
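As a quick sanity check of the combined feature length under these HOG settings (a sketch using the 8×8 sklearn digits; the descriptor should contain 2 × 2 cells × 8 orientations = 32 values, i.e. 96 features once concatenated with the 64 raw pixels):
```python
from sklearn.datasets import load_digits
from skimage.feature import hog

img = load_digits().images[0]  # one 8x8 digit, pixel values 0..16
fd, _ = hog(img, orientations=8, pixels_per_cell=(4, 4),
            cells_per_block=(1, 1), visualize=True)
print(len(fd), 64 + len(fd))  # expected: 32 96
```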
3. **Fallback visualization when seaborn is missing**:
```python
if SEABORN_INSTALLED:
sns.heatmap(confusion, annot=True, fmt='d', cmap='Blues')
else:
plt.imshow(confusion, cmap='Blues')
for i in range(confusion.shape[0]):
for j in range(confusion.shape[1]):
plt.text(j, i, confusion[i, j],
ha="center", va="center", color="black")
```
4. **Dependency checks and install hints** (see the sketch after this list):
- All optional dependencies are checked once at startup
- Each missing dependency produces a clear installation hint
- The core features keep working even when optional dependencies are absent
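A minimal sketch of such a startup check (the helper name and the dictionary below are illustrative, not part of the code above):
```python
import importlib.util

# Map import names to pip package names (illustrative list).
OPTIONAL_DEPS = {
    "skimage": "scikit-image",
    "seaborn": "seaborn",
    "xgboost": "xgboost",
    "lightgbm": "lightgbm",
}

def report_missing_optional_deps():
    """Print one clear install hint covering every missing optional package."""
    missing = [pkg for mod, pkg in OPTIONAL_DEPS.items()
               if importlib.util.find_spec(mod) is None]
    if missing:
        print("Optional features disabled. To enable them, run:")
        print("    pip install " + " ".join(missing))
    return missing
```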
### Installing all dependencies
```bash
# Core dependencies
pip install numpy matplotlib pandas scikit-learn pillow opencv-python joblib
# Optional dependencies
pip install scikit-image seaborn xgboost lightgbm
```
For distribution, it is cleaner to pin these in a `requirements.txt` rather than listing them in prose.
### Usage notes
1. **Minimal setup**: the core dependencies alone are enough for the basic features (a GUI-free smoke test follows this list)
2. **Full functionality**: install all optional dependencies for the extra models and best accuracy
3. **Error handling**: when an optional dependency is missing, the program prints a warning and keeps running
4. **Model choice**: prefer models that need no extra dependencies (e.g. SVM or random forest)
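For comparing the models without the Tkinter GUI, a minimal smoke test could look like this (assuming the code above is saved as `digits_app.py`; the module name is only an example):
```python
from sklearn.datasets import load_digits

from digits_app import ModelFactory  # hypothetical module name for the code above

digits = load_digits()
# Trains and evaluates every model registered in MODEL_METADATA
# and prints a table sorted by accuracy.
results = ModelFactory.evaluate_all_models(digits, tune=False)
```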
These changes make the code more robust: it runs across different environments and gives clear error messages and sensible fallbacks when an optional dependency is missing.