SOM
基本原理
计算过程
Python实现
(1)SOM类实现
import numpy as np
import matplotlib.pyplot as plt
class SOM:
    """Self-organizing map (Kohonen map) on a rectangular 2-D grid.

    Every grid node ``(i, j)`` owns a weight vector of length ``n_features``
    stored in ``self.weights[i, j]``. Training repeatedly picks a random
    sample, finds its best matching unit (BMU, the node whose weight vector
    is closest in Euclidean distance) and pulls the weights of the BMU and
    its grid neighbors towards that sample.

    Args:
        grid_size: ``(rows, cols)`` of the node grid.
        n_features: Length of each sample / weight vector.
        sigma: Initial neighborhood radius, in grid units.
        learning_rate: Initial step size of the weight update.
        max_iter: Number of single-sample update steps run by ``train``.
        decay: Factor applied to the learning rate and to the radius after
            every training step.
    """

    # Lower bound for the effective neighborhood radius. Without it the
    # radius reaches exactly 0 (at t == max_iter, or through floating-point
    # underflow in long runs) and the Gaussian evaluates 0/0 = NaN at the
    # BMU. At this radius every node other than the BMU already gets weight
    # 0, so the floor does not alter any non-degenerate result.
    _MIN_RADIUS = 1e-8

    def __init__(self, grid_size=(10, 10), n_features=2, sigma=1.0,
                 learning_rate=0.5, max_iter=100, decay=0.9):
        self.grid_size = grid_size
        self.n_features = n_features
        self.sigma = sigma
        self.learning_rate = learning_rate
        self.max_iter = max_iter
        self.decay = decay
        # Weights start uniformly random in [0, 1).
        self.weights = np.random.rand(grid_size[0], grid_size[1], n_features)

    def _find_bmu(self, x):
        """Return the grid index ``(i, j)`` of the node closest to ``x``."""
        distances = np.linalg.norm(self.weights - x, axis=2)
        return np.unravel_index(np.argmin(distances), distances.shape)

    def _neighborhood(self, i, j, t, sigma=None):
        """Return the Gaussian neighborhood weights around node ``(i, j)``.

        Args:
            i: Row index of the center node (the BMU).
            j: Column index of the center node.
            t: Current iteration; the radius is scaled by
                ``1 - t / max_iter``.
            sigma: Base radius before that scaling. Defaults to
                ``self.sigma``.

        Returns:
            Array of shape ``grid_size`` with value 1 at ``(i, j)`` that
            falls off with squared grid distance from it.
        """
        if sigma is None:
            sigma = self.sigma
        rows = np.arange(self.grid_size[0])[:, None]
        cols = np.arange(self.grid_size[1])[None, :]
        sq_dist = (rows - i) ** 2 + (cols - j) ** 2
        radius = max(sigma * (1 - t / self.max_iter), self._MIN_RADIUS)
        return np.exp(-sq_dist / (2 * radius ** 2))

    def train(self, X):
        """Fit the map to ``X`` with ``max_iter`` single-sample updates.

        The learning rate and radius are decayed on local copies, so
        ``self.learning_rate`` and ``self.sigma`` keep their configured
        values and ``train`` can be called again with the same schedule.

        Args:
            X: Sequence of samples, each of length ``n_features``.

        Raises:
            ValueError: If ``X`` contains no samples.
        """
        X = np.asarray(X)
        if len(X) == 0:
            raise ValueError("X must contain at least one sample")

        learning_rate = self.learning_rate
        sigma = self.sigma
        for t in range(self.max_iter):
            # Draw one random sample per step.
            x = X[np.random.randint(0, len(X))]
            bmu_i, bmu_j = self._find_bmu(x)
            h = self._neighborhood(bmu_i, bmu_j, t, sigma=sigma)
            # Move every node towards x, scaled by its neighborhood weight
            # (h is broadcast over the feature axis).
            self.weights += learning_rate * h[:, :, None] * (x - self.weights)
            # The radius shrinks both through this factor and through the
            # (1 - t / max_iter) term inside _neighborhood.
            learning_rate *= self.decay
            sigma *= self.decay

    def predict(self, X):
        """Return the BMU grid index of every sample in ``X``.

        Args:
            X: Sequence of samples, each of length ``n_features``.

        Returns:
            Integer array of shape ``(len(X), 2)``; row ``k`` holds the
            ``(i, j)`` index of the BMU of ``X[k]``.
        """
        bmus = [self._find_bmu(x) for x in X]
        # The reshape keeps the (n, 2) layout even when X is empty.
        return np.array(bmus, dtype=np.intp).reshape(-1, 2)
(2)示例数据
# Build a 2-D toy data set: 100 points in polar form with a uniformly random
# angle and a uniformly random radius in [0, 1), i.e. points scattered inside
# the unit disk.
np.random.seed(42)
n_samples = 100
theta = np.random.uniform(0, 2 * np.pi, n_samples)
r = np.random.uniform(0, 1, n_samples)
X = np.column_stack((r * np.cos(theta), r * np.sin(theta)))

# Fit a 10x10 map to the data.
som = SOM(
    grid_size=(10, 10),
    n_features=2,
    sigma=1.0,
    learning_rate=0.5,
    max_iter=100,
)
som.train(X)

# Look up the best matching unit (grid index) of every sample.
predictions = som.predict(X)
(3)可视化结果
# Figure 1: the map itself. Every node is drawn at its (row, column) grid
# position and colored by the Euclidean norm of its weight vector. A scalar
# per node is required here: scatter's `c` must be one value (or one color)
# per point, so a raw 2-component weight vector cannot be passed for a
# single point.
grid_rows, grid_cols = np.indices(som.weights.shape[:2])
node_norms = np.linalg.norm(som.weights, axis=2)
plt.figure(figsize=(8, 8))
nodes = plt.scatter(grid_rows.ravel(), grid_cols.ravel(),
                    c=node_norms.ravel(), cmap='viridis', s=100)
plt.title("SOM Weights")
# Pass the scatter explicitly so the colorbar is tied to its color mapping.
plt.colorbar(nodes, label='Weight vector norm')
plt.show()

# Figure 2: the samples together with their BMUs. The BMUs are drawn at
# their weight vectors (data space), not at their grid indices, so both
# point sets share the same coordinate system.
bmu_rows, bmu_cols = np.asarray(predictions).T
bmu_weights = som.weights[bmu_rows, bmu_cols]
plt.figure(figsize=(8, 8))
plt.scatter(X[:, 0], X[:, 1], c='blue', alpha=0.5, label='Data Points')
plt.scatter(bmu_weights[:, 0], bmu_weights[:, 1],
            c='red', marker='x', s=50, label='BMUs')
plt.title("Data Points and BMUs")
plt.legend()
plt.show()