目录

漫话机器学习系列125.普拉托变换Platt-Scaling

【漫话机器学习系列】125.普拉托变换(Platt Scaling)

https://i-blog.csdnimg.cn/direct/366a01e781cd4a598fee01ae6bccd67d.png

普拉托变换(Platt Scaling)详解

1. 什么是普拉托变换(Platt Scaling)?

普拉托变换(Platt Scaling)是一种用于二分类支持向量机(SVM)的后处理方法,它的作用是将 SVM 的输出分数转换为概率值。由于 SVM 本身并不直接输出概率,而是输出一个决策值(即 f(x)),因此我们需要使用某种方法将这个决策值映射到 [0,1] 范围的概率值。普拉托变换正是通过拟合一个 Sigmoid 函数 来完成这个任务。


2. 线性支持向量机的输出问题

在 SVM 训练完成后,它会为每个输入样本 x 计算一个 决策函数值

https://latex.csdn.net/eq?f%28x%29%20%3D%20w%5ET%20x%20+%20b

这个值表示样本相对于决策边界的距离。一般来说:

  • 如果 f(x) > 0,则分类为正类( y = 1 )。
  • 如果 f(x) < 0,则分类为负类( y = 0 )。

但这个值只是一个距离,而不是概率。例如,f(x) = 2 和 f(x) = 100 都表示正类,但哪个更有可能是正类呢?这个问题就是 SVM 无法直接提供概率输出的原因。因此,我们需要一个方法将 SVM 的输出转换为概率。


3. 普拉托变换的基本思想

普拉托变换的核心思想是 使用 Sigmoid 函数拟合 SVM 的输出,使其变成概率

https://latex.csdn.net/eq?P%28y%3D1%7Cx%29%20%3D%20%5Cfrac%7B1%7D%7B1%20&plus;%20e%5E%7BA%20f%28x%29%20&plus;%20B%7D%7D

其中:

  • f(x) 是 SVM 计算得到的决策值(即得分)。
  • A 和 B 是两个需要拟合的参数。

换句话说,我们希望通过训练找到 A 和 B ,使得 SVM 的输出 f(x) 能够很好地拟合样本的真实概率分布。


4. 如何训练 A 和 B?

训练 A 和 B 的过程如下:

  1. 在数据集上训练 SVM ,获取所有样本的 得分 f(x) 及其真实类别 y。
  2. 使用交叉验证 构建一个 逻辑回归模型 ,以 f(x) 为输入,y 为输出,拟合 Sigmoid 函数 的参数A 和 B。
  3. 计算损失函数
    • 普拉托变换使用 最大似然估计(MLE) 进行参数拟合。

    • 具体来说,最小化以下 交叉熵损失

      https://latex.csdn.net/eq?v%5Csum_%7Bi%3D1%7D%5E%7BN%7D%20%5Cleft%5B%20y_i%20%5Clog%20P%28y%3D1%7Cx_i%29%20&plus;%20%281%20-%20y_i%29%20%5Clog%20%281%20-%20P%28y%3D1%7Cx_i%29%29%20%5Cright%5Di%3D1

    • 这个公式与 逻辑回归的损失函数 一样,因此可以使用标准的优化方法(如梯度下降)来求解 A 和 B。


5. 为什么要使用普拉托变换?

普拉托变换是 SVM 最常用 的概率校准方法,原因如下:

  • 简单高效 :只需要额外训练一个小的逻辑回归模型(拟合 Sigmoid),计算开销较小。
  • 适用于 SVM 及其他分类器 :虽然最初是为 SVM 提出的,但 Platt Scaling 也可用于其他分类器(如神经网络、随机森林等)。
  • 提高模型可解释性 :许多应用(如医学、金融等)需要输出概率,而不仅仅是分类结果。

6. Platt Scaling 的局限性

尽管普拉托变换在很多场景下都能有效校准概率输出,但它也有一些局限:

  1. 对样本不均衡较敏感 :如果数据集的类别不均衡,拟合的 Sigmoid 可能会偏向多数类。
  2. 假设数据符合 Sigmoid 分布 :Platt Scaling 假设 SVM 输出的决策值可以用 Sigmoid 函数拟合,但在某些复杂分布下,可能效果不好。
  3. 需要额外的数据 :Platt Scaling 需要使用交叉验证数据来训练 A 和 B,这可能会浪费一部分训练数据。

如果数据集较大、类别不均衡或者 Sigmoid 拟合效果不好,通常可以考虑 Isotonic Regression(等温回归) 作为替代方法。


7. 代码实现(Python 示例)

在实际应用中,Scikit-learn 提供了 Platt Scaling 的实现,示例如下:

from sklearn.svm import SVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# 生成数据
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 训练 SVM(不带概率)
svm = SVC(kernel='linear', probability=False)
svm.fit(X_train, y_train)

# 使用 Platt Scaling 校准概率
svm_calibrated = CalibratedClassifierCV(svm, method='sigmoid', cv=5)
svm_calibrated.fit(X_train, y_train)

# 预测概率
probabilities = svm_calibrated.predict_proba(X_test)

print("前 5 个样本的预测概率:")
print(probabilities[:5])

运行结果

 5 个样本的预测概率
[[0.39216467 0.60783533]
 [0.19236017 0.80763983]
 [0.51329367 0.48670633]
 [0.24076157 0.75923843]
 [0.0697871  0.9302129 ]]

8. 结论

普拉托变换(Platt Scaling)是一种 简单有效的概率校准方法 ,特别适用于 SVM 这种不直接输出概率的分类器。它通过 拟合 Sigmoid 函数 ,将分类得分映射到 [0,1] 范围,使得输出更具可解释性。

在实际应用中,我们可以:

  • 需要概率输出的任务 (如医疗诊断、金融风控)中使用 Platt Scaling。
  • 如果数据 类别不均衡 ,可以尝试 Isotonic Regression 作为替代方案。

希望这篇文章能够帮助你更好地理解 Platt Scaling 的原理和应用!