JAX study notes [20]

Table of Contents

  • one neuron in the ANN
  • references

one neuron in the ANN

  1. The design of the neuron model in an ANN was inspired by neurobiology; its core structure consists of inputs, weights, an activation function, and an output.
  2. The output of a neuron can be described by the following mathematical model (a small numeric sketch follows the symbol list below):
    $$y = f\left(\sum_{i=1}^{n} w_i x_i + b\right)$$
  • $x_i$: the i-th input, coming from a neuron in the previous layer or from the original input data.
  • $w_i$: the weight attached to the input $x_i$, expressing how important that input is.
  • $b$: the bias, which shifts the activation threshold of the neuron.
  • $f(\cdot)$: the activation function, which introduces nonlinearity so that the network can learn complex patterns.
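As a quick numeric illustration of the formula above, the sketch below computes $z$ and $y$ for one neuron with two inputs; the input, weight, and bias values are made up purely for the example.

import jax.numpy as jnp

x = jnp.array([0.5, 1.0])        # inputs x_1, x_2 (example values)
w = jnp.array([0.8, -0.3])       # weights w_1, w_2 (assumed for illustration)
b = 0.1                          # bias (assumed)

z = jnp.dot(w, x) + b            # weighted sum: w_1*x_1 + w_2*x_2 + b = 0.2
y = 1.0 / (1.0 + jnp.exp(-z))    # sigmoid activation, y ≈ 0.55
print(z, y)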
  3. First, the weighted sum of the inputs (a linear transformation) is computed:

$$z = \sum_{i=1}^{n} w_i x_i + b$$
Next, this sum $z$ is passed to the activation function $f(z)$; the nonlinear functions below are commonly used so that the neuron can approximate complicated functions (a short JAX sketch applying them follows the list).

  • Sigmoid: $f(z) = \frac{1}{1 + e^{-z}}$. The output lies in (0, 1), so it can be interpreted as a probability.
  • ReLU: $f(z) = \max(0, z)$. It helps mitigate the vanishing-gradient problem and is widely used in hidden layers.
  • Tanh: $f(z) = \tanh(z)$. The output lies in (-1, 1), which keeps the data centered around zero.
  • Softmax: used in the output layer for multi-class classification; it converts the raw outputs into a probability distribution.
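All of these activations are available in JAX; the sketch below simply applies them to an assumed vector of pre-activation values $z$.

import jax.numpy as jnp
from jax import nn

z = jnp.array([-2.0, 0.0, 3.0])   # example pre-activation values (assumed)

print(nn.sigmoid(z))              # each value in (0, 1)
print(nn.relu(z))                 # negative values clipped to 0
print(jnp.tanh(z))                # each value in (-1, 1)
print(nn.softmax(z))              # non-negative values that sum to 1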
  4. To process a batch of data stored in a matrix $\mathbf{X}$, the computation takes the following form:

$$\mathbf{y} = f(\mathbf{X}\mathbf{w} + \mathbf{b})$$

Here $\mathbf{w}$ is the weight vector and $\mathbf{b}$ is the bias vector.
This matrix multiplication can be accelerated on a GPU; a minimal sketch of the batched form follows.
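A minimal sketch of the batched forward pass (the weights here are random placeholders rather than trained values):

import jax
import jax.numpy as jnp

X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])   # batch of 4 samples, 2 features
key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (2,))                           # weight vector (placeholder values)
b = 0.0                                                    # bias

y = jax.nn.sigmoid(X @ w + b)    # one matrix multiply handles the whole batch
print(y.shape)                   # (4,)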
  5. The entire behaviour of a neuron, from the forward pass to gradient-based training, can be demonstrated with the following Python code using JAX.

import jax
import jax.numpy as jnp
from jax import grad, vmap, jit
import matplotlib.pyplot as plt

# ------------------------------
# 1. Define the neuron model
# ------------------------------
def neuron(params, x):
    """带激活函数的单个神经元"""
    z = jnp.dot(x, params['w']) + params['b']  # 加权和 + 偏置
    return jax.nn.sigmoid(z)  # Sigmoid激活函数 (可替换为 relu/tanh)

# ------------------------------
# 2. Initialize parameters and hyperparameters
# ------------------------------
input_dim = 2  # input feature dimension
learning_rate = 0.1
epochs = 1000

# Randomly initialize the weights and bias
key = jax.random.PRNGKey(42)
params = {
    'w': jax.random.normal(key, (input_dim,)),  # weight vector
    'b': 0.0  # bias
}

# ------------------------------
# 3. Generate synthetic data (OR logic gate)
# ------------------------------
X = jnp.array([
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1]
])
y = jnp.array([0, 1, 1, 1])  # outputs of the OR gate

# ------------------------------
# 4. Define the loss function and gradient computation
# ------------------------------
@jit  # JIT-compile for speed
def loss_fn(params, X_batch, y_batch):
    """均方误差损失"""
    predictions = vmap(neuron, in_axes=(None, 0))(params, X_batch)  # 批量预测
    return jnp.mean((predictions - y_batch) ** 2)

compute_grads = grad(loss_fn)  # automatic differentiation of the loss

# ------------------------------
# 5. Training loop
# ------------------------------
loss_history = []

for epoch in range(epochs):
    # Compute gradients and loss
    grads = compute_grads(params, X, y)
    loss = loss_fn(params, X, y)
    loss_history.append(loss)
    
    # Gradient-descent parameter update
    params = {
        'w': params['w'] - learning_rate * grads['w'],
        'b': params['b'] - learning_rate * grads['b']
    }
    
    # Print progress every 100 epochs
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")

# ------------------------------
# 6. Visualize the results
# ------------------------------
# Plot the loss curve
plt.plot(loss_history)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss")
plt.show()

# ------------------------------
# 7. Test predictions
# ------------------------------
# Define a batched prediction function
predict = vmap(neuron, in_axes=(None, 0))

# Evaluate on the training data
predictions = predict(params, X)
print("\nPredictions:")
for x, pred in zip(X, predictions):
    print(f"Input: {x}, Output: {pred:.4f} → Predicted class: {int(pred > 0.5)}")

# ------------------------------
# 8. Print the trained parameters
# ------------------------------
print("\nTrained parameters:")
print(f"weights: {params['w']}")
print(f"bias: {params['b']}")


references

  1. DeepSeek
  2. 《神经网络与机器学习》 (Neural Networks and Machine Learning)