# -*- coding: utf-8 -*-
"""
Created on Fri Jan 23 14:22:38 2026
@author: 1103596
"""
# -*- coding: utf-8 -*-
"""
✅ CNN特征可视化系统 · Spyder 兼容安全版(纯4空格缩进|无Tab|免报错)
✅ 功能完整:自动识别L/T/十字结构 → 生成带箭头GIF → 输出教学PPTX
✅ 已修复:
• IndentationError(全手工4空格对齐)
• AssertionError帧尺寸不一致(PIL强制resize到480x480)
• object类型.npy拦截|(8,8)校验|PNG尺寸归一化
"""
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import os
from datetime import datetime
import imageio.v2 as imageio
from pptx import Presentation
from pptx.util import Inches, Pt
from pptx.dml.color import RGBColor
from pptx.enum.text import PP_ALIGN
from scipy.ndimage import zoom
from PIL import Image
# -----------------------------
# 🛠️ 配置 & 路径
# -----------------------------
# Input/output locations. Everything is resolved relative to SAVE_DIR.
SAVE_DIR = "."
INPUT_FILE = "my_layout.npy"
FULL_INPUT = os.path.join(SAVE_DIR, INPUT_FILE)

# A single timestamp is shared by both outputs so that the GIF and the PPTX
# produced by one run carry the same suffix.
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
OUTPUT_GIF = "cnn_activation_path_" + timestamp + ".gif"
OUTPUT_PPTX = "cnn_tutorial_" + timestamp + ".pptx"
PATH_GIF, PATH_PPTX = (
    os.path.join(SAVE_DIR, file_name) for file_name in (OUTPUT_GIF, OUTPUT_PPTX)
)
# -----------------------------
# 🔍 安全加载用户版图(四重防护)
# -----------------------------
def load_layout():
    """Load the user's 8x8 binary layout from ``FULL_INPUT``.

    If the file does not exist, a small sample layout (three pixels at
    (6, 6), (6, 7) and (7, 6)) is written to ``FULL_INPUT`` and returned.

    Returns:
        numpy.ndarray: array of shape (8, 8), dtype float32, values 0.0 / 1.0.

    Raises:
        RuntimeError: the file exists but ``np.load`` failed; the original
            exception is chained as ``__cause__``.
        ValueError: the array has object dtype, a shape other than (8, 8),
            or contains values other than 0 and 1.
        TypeError: the array dtype is not numeric.
    """
    if not os.path.exists(FULL_INPUT):
        print(f"⚠️ 未找到 '{INPUT_FILE}',正在生成示例 L 形结构...")
        ex = np.zeros((8, 8), dtype=np.float32)
        ex[6, 6:8] = 1
        ex[6:8, 6] = 1
        np.save(FULL_INPUT, ex)
        print(f"✅ 已创建示例文件:{FULL_INPUT}")
        return ex

    try:
        arr = np.load(FULL_INPUT)
    except Exception as e:
        # Chain the cause so the underlying I/O or format error stays visible
        # in the traceback instead of being flattened into the message only.
        raise RuntimeError(f"❌ 读取 {INPUT_FILE} 失败:{e}") from e

    # Validation order matters: reject object arrays and wrong shapes before
    # running element-wise checks on the contents.
    if arr.dtype == object:
        raise ValueError(f"❌ {INPUT_FILE} 是 object 类型数组,请用 np.array([[...]]) 正确保存!")
    if arr.shape != (8, 8):
        raise ValueError(f"❌ {INPUT_FILE} 形状必须是 (8, 8),实际为 {arr.shape}")
    if not np.issubdtype(arr.dtype, np.number):
        raise TypeError(f"❌ {INPUT_FILE} 数据类型必须为数字,当前为 {arr.dtype}")
    if not np.all(np.isin(arr, [0, 1])):
        bad = np.unique(arr[(arr != 0) & (arr != 1)])
        raise ValueError(f"❌ {INPUT_FILE} 含非法值 {bad.tolist()},仅允许 0 或 1")

    # Every value is exactly 0 or 1 at this point, so a plain cast already
    # yields the binary float32 mask; no extra thresholding is needed.
    arr = arr.astype(np.float32)
    print(f"✅ 成功加载用户版图:{INPUT_FILE}({arr.shape}, {arr.dtype})")
    return arr
# -----------------------------
# 🧩 自动识别几何结构(L/T/十字/线/点)
# -----------------------------
def _label_components(binary):
    """Label the 8-connected components of the truthy cells of a 2-D array.

    Returns:
        tuple: ``(labels, count)`` where ``labels`` is an int array of the
        same shape (0 = background, 1..count = component id) and ``count``
        is the number of components found.
    """
    height, width = binary.shape
    labels = np.zeros(binary.shape, dtype=int)
    offsets = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]
    count = 0
    for r in range(height):
        for c in range(width):
            if not binary[r, c] or labels[r, c]:
                continue
            # New unlabeled seed: flood-fill its component with an explicit stack.
            count += 1
            labels[r, c] = count
            stack = [(r, c)]
            while stack:
                cr, cc = stack.pop()
                for dr, dc in offsets:
                    nr, nc = cr + dr, cc + dc
                    if (0 <= nr < height and 0 <= nc < width
                            and binary[nr, nc] and not labels[nr, nc]):
                        labels[nr, nc] = count
                        stack.append((nr, nc))
    return labels, count


def detect_shape(arr):
    """Classify the pattern of 1-valued cells in a 2-D binary grid.

    The checks run in this order and the first match wins:
    empty, point, horizontal/vertical line, multiple components, cross,
    T shape, L shape, irregular.

    Args:
        arr: 2-D numpy array; cells equal to 1 are treated as set pixels.

    Returns:
        tuple: ``(name, confidence, meta)`` where ``name`` is one of
        ``"empty"``, ``"point"``, ``"horizontal_line"``, ``"vertical_line"``,
        ``"multiple_shapes"``, ``"cross"``, ``"t_shape"``, ``"l_shape"`` or
        ``"irregular"``; ``confidence`` is a float; ``meta`` is a dict with
        shape-specific details (possibly empty).

    Notes:
        * A line must be contiguous: a single row/column with gaps is not a
          line and is handled by the connected-component check instead.
        * Cross / T detection only inspects the rounded centroid of the set
          pixels as the junction candidate.
        * L detection only recognises a corner pixel that has set neighbours
          to its right and below.
    """
    height, width = arr.shape
    ones = np.argwhere(arr == 1)
    if len(ones) == 0:
        return "empty", 1.0, {}
    if len(ones) == 1:
        return "point", 1.0, {"center": tuple(ones[0])}

    rows, cols = ones[:, 0], ones[:, 1]
    row_span = rows.max() - rows.min() + 1
    col_span = cols.max() - cols.min() + 1

    # All pixels in one row/column AND no gaps (span equals pixel count).
    if row_span == 1 and len(cols) >= 3 and col_span == len(cols):
        return "horizontal_line", min(1.0, 0.7 + 0.3 * (len(cols) / width)), {"row": rows[0]}
    if col_span == 1 and len(rows) >= 3 and row_span == len(rows):
        return "vertical_line", min(1.0, 0.7 + 0.3 * (len(rows) / height)), {"col": cols[0]}

    _, n_cc = _label_components(arr)
    if n_cc > 1:
        return "multiple_shapes", 0.5, {"components": n_cc}

    # Junction candidate: rounded centroid of the set pixels, kept in bounds.
    center_r = np.clip(int(round(np.mean(rows))), 0, height - 1)
    center_c = np.clip(int(round(np.mean(cols))), 0, width - 1)
    neighbors = [
        (center_r - 1, center_c),  # up
        (center_r + 1, center_c),  # down
        (center_r, center_c - 1),  # left
        (center_r, center_c + 1),  # right
    ]
    # Out-of-grid neighbours count as unset.
    nb_vals = [arr[r, c] if 0 <= r < height and 0 <= c < width else 0 for r, c in neighbors]

    if arr[center_r, center_c] == 1:
        if all(nb_vals):
            return "cross", 0.95, {"center": (center_r, center_c)}
        if sum(nb_vals) == 3:
            missing = ["up", "down", "left", "right"][nb_vals.index(0)]
            return "t_shape", 0.9, {"center": (center_r, center_c), "missing": missing}

    # Corner pixel = set pixel whose right and lower neighbours are both set.
    corners = [
        (r, c)
        for r, c in ones
        if r < height - 1 and c < width - 1 and arr[r, c + 1] and arr[r + 1, c]
    ]
    if corners and max(row_span, col_span) >= 4 and min(row_span, col_span) in (2, 3):
        return "l_shape", 0.85, {"corner": corners[0], "span": (row_span, col_span)}

    return "irregular", 0.6, {"pixel_count": len(ones)}
# -----------------------------
# ⚙️ 卷积核定义
# -----------------------------
# Layer 1: three single-channel 3x3 kernels -> weight shape (3, 1, 3, 3).
#   [0] horizontal run of three, [1] vertical run of three,
#   [2] corner: centre pixel plus its right and lower neighbours.
kernels1 = torch.tensor([
    [[[0, 0, 0], [1, 1, 1], [0, 0, 0]]],
    [[[0, 1, 0], [0, 1, 0], [0, 1, 0]]],
    [[[0, 0, 0], [0, 1, 1], [0, 1, 0]]],
], dtype=torch.float32)

# Layer 2: two output channels over the three layer-1 maps -> (2, 3, 3, 3).
kernels2 = torch.zeros(2, 3, 3, 3, dtype=torch.float32)

# Output 0 re-applies the corner kernel to the layer-1 corner map only.
kernels2[0, 2] = torch.tensor([[0, 0, 0],
                               [0, 1, 1],
                               [0, 1, 0]])

# Output 1 is a per-pixel weighted sum 0.5*horizontal + 0.3*vertical + 0.2*corner.
# The weights go on the centre tap only: assigning a (3, 1, 1) tensor to
# kernels2[1] would broadcast each weight over the whole 3x3 window and turn
# the fusion into a box filter instead of a weighted sum.
kernels2[1, :, 1, 1] = torch.tensor([0.5, 0.3, 0.2])
# -----------------------------
# 🔼 主流程
# -----------------------------
# Load the layout; on any failure report it and stop the script.
try:
    layout_np = load_layout()
except Exception as e:
    print("\n🛑 程序终止:", str(e))
    # ``exit`` is a helper added by the ``site`` module for interactive use;
    # raising SystemExit directly has the same effect without that dependency.
    raise SystemExit(1)

shape_name, confidence, meta = detect_shape(layout_np)
print(f"🔍 自动识别结构:{shape_name.upper()}(置信度 {confidence:.2f})")
if meta:
    print(" 📌 细节:", meta)

# Add batch and channel axes: (8, 8) -> (1, 1, 8, 8), as conv2d expects.
layout = torch.tensor(layout_np, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
x1 = F.relu(F.conv2d(layout, kernels1, padding=1))
x2 = F.relu(F.conv2d(x1, kernels2, padding=1))

# Index 0 is the raw layout, 1-3 are the layer-1 maps, 4-5 the layer-2 maps.
data = [
    layout_np,
    x1[0, 0].detach().numpy(),
    x1[0, 1].detach().numpy(),
    x1[0, 2].detach().numpy(),
    x2[0, 0].detach().numpy(),
    x2[0, 1].detach().numpy(),
]

# Safety net: bring any map that is not 8x8 back to 8x8 by linear resampling.
data = [
    fmap if fmap.shape == (8, 8)
    else zoom(fmap, (8 / fmap.shape[0], 8 / fmap.shape[1]), order=1)
    for fmap in data
]

# Shared colour-scale maximum across all feature maps. Fall back to 1.0 when
# every activation is zero so the colour range is never the degenerate 0..0.
vmax_global = max(float(fmap.max()) for fmap in data[1:]) or 1.0
# -----------------------------
# 🎨 绘图工具函数
# -----------------------------
def add_annotations(ax, im_array, k=3):
    """Mark the ``k`` largest cells of ``im_array`` on ``ax``.

    Each selected cell gets a red dot and a boxed label of the form
    ``(col,row)`` followed by the value with two decimals.

    Args:
        ax: axes-like object providing ``plot`` and ``text``.
        im_array: 2-D numpy array of any shape.
        k: number of top cells to annotate, largest first.
    """
    # Flat indices of the k largest values, then back to (row, col) using the
    # array's real shape so non-8-wide inputs are located correctly.
    top_flat = np.argsort(im_array, axis=None)[::-1][:k]
    top_rows, top_cols = np.unravel_index(top_flat, im_array.shape)
    for r, c in zip(top_rows, top_cols):
        # Plot coordinates are (x=col, y=row).
        ax.plot(c, r, 'o', color='red', markersize=8, markeredgecolor='white', markeredgewidth=1.2)
        ax.text(c, r, f'({c},{r})\n{im_array[r,c]:.2f}',
                ha='center', va='center', fontsize=7, color='white', fontweight='bold',
                bbox=dict(boxstyle='round,pad=0.2', facecolor='black', alpha=0.7))
def draw_arrow(ax, start, end, label="", color="gold"):
    """Draw a labelled arrow between two grid cells on ``ax``.

    Args:
        ax: axes-like object providing ``plot``, ``annotate`` and ``text``.
        start: ``(row, col)`` of the arrow tail (drawn as a circle).
        end: ``(row, col)`` of the arrow head (drawn as a square).
        label: text placed next to the head.
        color: colour used for the markers, the arrow and the label text.
    """
    # Inputs are (row, col); plotting needs (x=col, y=row).
    tail_y, tail_x = start
    head_y, head_x = end

    ax.plot(tail_x, tail_y, 'o', color=color, markersize=6, alpha=0.9)
    ax.plot(head_x, head_y, 's', color=color, markersize=8, fillstyle='full',
            markeredgecolor='white', markeredgewidth=1.2, alpha=0.9)

    arrow_style = {
        'arrowstyle': '->',
        'color': color,
        'lw': 2.5,
        'connectionstyle': "arc3,rad=0.1",
        'alpha': 0.9,
    }
    ax.annotate('', xy=(head_x, head_y), xytext=(tail_x, tail_y), arrowprops=arrow_style)

    # Push the label away from the head, in the direction the arrow travels.
    if head_x >= tail_x:
        shift_x = 0.3
    else:
        shift_x = -0.3
    if head_y >= tail_y:
        shift_y = 0.3
    else:
        shift_y = -0.3

    label_box = {'boxstyle': 'round,pad=0.2', 'facecolor': 'black', 'alpha': 0.6}
    ax.text(head_x + shift_x, head_y + shift_y, label,
            fontsize=8, color=color, fontweight='bold', bbox=label_box)
# -----------------------------
# 🎞️ 生成 GIF 帧(✅ Spyder 安全版:无缩进错误|PIL 强制 resize)
# -----------------------------
frames = []
# Per-frame PNGs are kept in this directory; the PPTX step reads them back.
temp_dir = os.path.join(SAVE_DIR, "_tmp_cnn_viz")
os.makedirs(temp_dir, exist_ok=True)

shape_label = shape_name.upper().replace('_', ' ')
# One scene per frame/slide. ``data_idx`` selects the map in ``data`` to show;
# None means "show the raw layout without annotations".
scenes = [
    {
        "title": f"🎯 第0步:原始版图输入 —— {shape_label}",
        "data_idx": None,
        "subtitle": f"用户提供的 8×8 结构({np.count_nonzero(layout_np == 1)} 个像素)|AI判定:{shape_label}",
    },
    {"title": "🔍 第1层:横向边缘检测完成 ✔", "data_idx": 1, "subtitle": "水平线核响应最强位置"},
    {"title": "🔍 第1层:纵向边缘检测完成 ✔", "data_idx": 2, "subtitle": "竖直线核响应最强位置"},
    {"title": "🔥 第1层:L形拐角被成功检测到!⚠️", "data_idx": 3, "subtitle": "L核对右下角结构高度敏感"},
    {"title": "🔁 第2层:再次确认L结构存在 ✅", "data_idx": 4, "subtitle": "第二层强化关键区域响应"},
    {"title": "📊 第2层:多特征融合输出最终得分", "data_idx": 5, "subtitle": "加权融合生成最终决策分数"},
]

# Attach an "arrow" spec to every scene, pointing at the strongest response.
# Start and end are the same cell, so in practice only the markers and the
# value label are visible.
arrow_colors = ["gold", "cyan", "limegreen", "magenta", "orange"]
for i, s in enumerate(scenes):
    if s["data_idx"] is None:
        s["arrow"] = None
        continue
    fmap = data[s["data_idx"]]
    r, c = np.unravel_index(np.argmax(fmap), fmap.shape)
    s["arrow"] = {
        "start": (r, c),
        "end": (r, c),
        "label": f"→ {fmap[r, c]:.1f}",
        "color": arrow_colors[min(i - 1, 4)],
    }

# Every frame is resized to this size so the GIF writer gets uniform frames.
frame_size = (480, 480)
for i, s in enumerate(scenes):
    fp = os.path.join(temp_dir, f"frame_{i:03d}.png")
    fig, ax = plt.subplots(figsize=(7, 7))
    try:
        data_idx = s["data_idx"]
        if data_idx is not None:
            im_arr = data[data_idx]
            is_feature_map = data_idx > 0
            im = ax.imshow(
                im_arr,
                cmap="hot" if is_feature_map else "gray",
                vmin=0,
                vmax=vmax_global if is_feature_map else 1,
                interpolation='none',
            )
            add_annotations(ax, im_arr, k=3)
            if is_feature_map:
                # Build the colorbar from the drawn image so it always uses
                # exactly the same norm and colormap as the heat map.
                fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02)
        else:
            ax.imshow(data[0], cmap="gray", vmin=0, vmax=1, interpolation='none')

        ax.set_xticks(np.arange(8))
        ax.set_yticks(np.arange(8))
        ax.tick_params(length=0, labelsize=8)
        ax.grid(True, color='limegreen', linewidth=0.5, alpha=0.6)
        ax.text(0.5, 1.05, s["title"], transform=ax.transAxes, fontsize=14, ha='center', va='center',
                bbox=dict(boxstyle='round,pad=0.5', facecolor='wheat', alpha=0.9), fontweight='bold')
        ax.text(0.5, -0.12, s["subtitle"], transform=ax.transAxes, fontsize=10, ha='center', va='top',
                color='darkslategray', style='italic')
        if s["arrow"]:
            a = s["arrow"]
            draw_arrow(ax, a["start"], a["end"], a["label"], a["color"])

        fig.tight_layout(rect=[0, 0, 1, 0.90])
        fig.savefig(fp, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    # ``bbox_inches='tight'`` makes each PNG a slightly different size, so
    # every frame is resampled to one fixed size before going into the GIF.
    img = imageio.imread(fp)
    img_resized = np.array(Image.fromarray(img).resize(frame_size, Image.Resampling.LANCZOS))
    frames.append(img_resized)

imageio.mimsave(PATH_GIF, frames, duration=1000, loop=0)
print(f"🎉 GIF 已生成:{os.path.abspath(PATH_GIF)}")
# -----------------------------
# 📄 生成 PPTX
# -----------------------------
# Per-slide technical note, keyed by scene index. The kernel descriptions
# mirror the layer-1 kernels: a 3-wide horizontal run, a 3-tall vertical run,
# and a corner made of the centre pixel plus its right and lower neighbours.
tech_notes = {
    0: "输入:8×8 二值图像",
    1: "卷积核:[1,1,1] 检测水平连续段",
    2: "卷积核:[1;1;1] 检测垂直连续段",
    3: "L核:匹配 ┌ 类型结构",
    4: "第二层使用相同L核进行再确认",
    5: "加权融合:Horiz×0.5 + Vert×0.3 + L×0.2",
}

prs = Presentation()
for i, s in enumerate(scenes):
    slide = prs.slides.add_slide(prs.slide_layouts[1])

    title = slide.shapes.title
    title.text = s["title"]
    title_par = title.text_frame.paragraphs[0]
    title_par.font.size = Pt(24)
    title_par.font.bold = True

    # Give only the height so the picture keeps its own aspect ratio instead
    # of being stretched into a fixed 6 x 4.5 inch box.
    fp = os.path.join(temp_dir, f"frame_{i:03d}.png")
    slide.shapes.add_picture(fp, Inches(0.5), Inches(1.5), height=Inches(4.5))

    tx = slide.shapes.add_textbox(Inches(6.7), Inches(1.5), Inches(3.0), Inches(4.5))
    tf = tx.text_frame
    # Wrap long lines inside the 3-inch box rather than running off the slide.
    tf.word_wrap = True
    # A new text frame already holds one empty paragraph; reuse it so the box
    # does not start with a blank line.
    p = tf.paragraphs[0]
    p.text = s["subtitle"]
    p.font.size = Pt(14)
    p = tf.add_paragraph()
    p.text = "🔧 技术备注:" + tech_notes[i]
    p.font.size = Pt(11)
    p.font.color.rgb = RGBColor(100, 100, 100)

    foot = slide.shapes.add_textbox(Inches(0.5), Inches(6.5), Inches(9), Inches(0.3))
    ftf = foot.text_frame
    ftf.text = f"Slide {i+1}/{len(scenes)} | CNN Feature Visualization | {datetime.now().year}"
    foot_par = ftf.paragraphs[0]
    foot_par.font.size = Pt(9)
    foot_par.font.color.rgb = RGBColor(150, 150, 150)
    foot_par.alignment = PP_ALIGN.LEFT

prs.save(PATH_PPTX)
print(f"📄 PPTX 已生成:{os.path.abspath(PATH_PPTX)}")
print("\n✅ 全部完成!您现在拥有:")
print(f" • 动态 GIF:{os.path.basename(PATH_GIF)}(480×480,带箭头路径)")
print(f" • 教学 PPTX:{os.path.basename(PATH_PPTX)}({len(scenes)}页,图文并茂)")
print("\n💡 下一步建议:")
print(" • 修改 my_layout.npy 后重新运行,观察不同结构的响应差异")
print(" • 将 PPTX 导出为 PDF 或视频用于分享")
# === 错误捕获增强(防闪退)===
import sys
import traceback


def pause_on_unhandled_error(exc_type, exc_value, exc_tb):
    """Print the traceback of an uncaught exception, then wait for Enter.

    Intended as ``sys.excepthook`` so a console window does not close before
    the error can be read.
    """
    traceback.print_exception(exc_type, exc_value, exc_tb)
    try:
        input("❌ 按回车键退出...")
    except (EOFError, OSError):
        # No usable stdin (closed, redirected or captured): there is nobody to
        # wait for, and failing inside the hook would only add a second error.
        pass


# NOTE(review): the hook is installed after the main flow above has already
# run, so it only affects exceptions raised later in the same interpreter.
# Move this section to the top of the file to cover the script itself.
sys.excepthook = pause_on_unhandled_error