前馈神经网络-参数学习梯度下降法-多分类任务
前馈神经网络 - 参数学习(梯度下降法 - 多分类任务)
之前的博文中,对于前馈神经网络,我们学习过基于二分类任务的参数学习,本文我们来学习两个多分类任务的参数学习的例子,来进一步加深对反向传播算法的理解。
例子1:前馈神经网络在多分类任务中的参数学习(梯度下降法)
以下通过一个 3分类任务 的具体例子,详细说明前馈神经网络如何使用梯度下降法进行参数学习。假设网络结构如下:
- 输入层 :2个神经元(输入特征 x1,x2)
- 隐藏层 :3个神经元(使用 ReLU 激活函数)
- 输出层 :3个神经元(使用 Softmax 激活函数)
1. 网络参数初始化
假设输入样本为 x=[1,0],真实标签为类别 2(标签编码为 one-hot 向量 y=[0,0,1])。
参数定义 :
输入层到隐藏层 :
隐藏层到输出层 :
2. 前向传播
隐藏层计算 :
线性变换:
ReLU 激活:
输出层计算 :
线性变换:
Softmax 激活(转换为概率):
3. 损失计算(交叉熵损失)
真实标签为类别 2(y=[0,0,1]),预测概率为 y^≈[0.31,0.20,0.49]:
4. 反向传播计算梯度
输出层梯度
Softmax + 交叉熵的梯度简化 :
参数梯度 :
隐藏层梯度
误差信号传播 :
ReLU 导数:
上游误差:
逐元素相乘:
参数梯度 :
5. 参数更新(学习率 η=0.1)
输出层参数 :
隐藏层参数 :
6. 验证更新后的预测
更新参数后,重新进行前向传播:
- 隐藏层输出可能更接近真实类别 2,损失 L 应减小。例如,若新的预测概率为 [0.25,0.15,0.60],则损失为 −log(0.60)≈0.511<0.713,表明参数学习有效。
关键总结
Softmax + 交叉熵 :
- 多分类任务的标准组合,梯度计算简化为
。
- 多分类任务的标准组合,梯度计算简化为
ReLU 导数特性 :
- 激活导数为 0 或 1,加速计算并缓解梯度消失问题。
梯度下降步骤 :
- 通过链式法则逐层计算梯度,参数沿负梯度方向更新。
实际应用注意点 :
- 学习率需调参(过大震荡,过小收敛慢)。
- 参数初始化影响收敛(如 Xavier 初始化)。
例子2: 一个简单的多层感知器(MLP)
下面给出一个基于多分类任务的前馈神经网络参数学习过程,展示如何使用梯度下降法(GD)结合反向传播计算梯度,逐步优化参数。我们以一个简单的多层感知器(MLP)来处理三分类问题为例。
1. 网络结构设定
假设我们的任务是将输入样本分为三个类别(类别1、类别2、类别3)。网络结构如下:
- 输入层 :假设输入向量 x∈R^d(例如 d=4)。
- 隐藏层 :设有一层隐藏层,包含 h 个神经元,激活函数使用 ReLU。
- 输出层 :有 3 个神经元,对应三个类别,激活函数采用 Softmax,将输出转换为概率分布。
具体数学描述:
- 隐藏层:
- 输出层:
Softmax 的定义为:
2. 损失函数
对于多分类任务,我们通常采用多类别交叉熵损失函数。假设真实标签 y 使用 one-hot 编码,交叉熵损失为:
3. 具体例子
网络参数设定(示例数值)
假设输入维度 d=4,隐藏层神经元数量 h=3:
- 输入 x = [1.0, 0.5, -1.0, 2.0]^T。
- 隐藏层权重:
隐藏层偏置:
- 输出层权重:
输出层偏置:
假设真实标签为类别3,即 one-hot 编码 y = [0, 0, 1]^T。
前向传播计算
隐藏层计算 :
- 计算
逐个神经元计算:
神经元1:
神经元2:
神经元3:
得到
.
- 通过 ReLU 激活函数计算 a^{(1)}:
输出层计算 :
- 计算
对每个类别计算:
类别1:
类别2:
类别3:
- 通过 Softmax 激活函数计算预测概率:
此时,模型预测概率为:
- 类别1:42.1%,
- 类别2:27.7%,
- 类别3:30.2%。
假设真实标签为类别3,则 one-hot 编码 y = [0,0,1]^T。
4. 损失计算
采用多类别交叉熵损失函数:
由于 y=[0,0,1] ,损失为:
5. 反向传播与参数更新(简要描述)
输出层梯度 :
隐藏层梯度 :
参数更新 : 使用梯度下降法(例如学习率 η),更新各层参数:
经过多次迭代和大量样本训练,网络参数逐渐调整使得损失函数最小化,模型预测准确率不断提升。
总结
利用梯度下降法对前馈神经网络进行参数学习的过程包括:
- 前向传播 :将输入数据通过网络各层计算,得到预测概率。
- 损失计算 :利用多类别交叉熵损失函数衡量预测与真实标签之间的差距。
- 反向传播 :使用链式法则,从输出层到隐藏层逐层计算梯度。
- 参数更新 :依据计算得到的梯度,采用梯度下降(或其变种)更新各层权重和偏置。
通过具体的多分类任务示例(例如一个三类别分类问题),我们可以看到如何从输入、前向传播、损失计算、反向传播到参数更新的整个流程,最终实现神经网络参数的优化和任务性能的提升。