adaboost算法java_AdaBoost装袋提升算法

本文详细介绍了AdaBoost算法,它是基于bagging算法的一种迭代方法,通过赋予错误分类样本更高的权重来逐步提高分类准确性。文章通过实例展示了如何利用AdaBoost构建组合分类器,并给出了Java代码实现,最后解释了为何要增加错误分类样本的权重以达到提升整体分类性能的目的。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

介绍

在介绍AdaBoost算法之前,需要了解一个类似的算法,装袋算法(bagging),bagging是一种提高分类准确率的算法,通过给定组合投票的方式,获得最优解。比如你生病了,去n个医院看了n个医生,每个医生给你开了药方,最后的结果中,哪个药方的出现的次数多,那就说明这个药方就越有可能性是最由解,这个很好理解。而bagging算法就是这个思想。

算法原理

而AdaBoost算法的核心思想还是基于bagging算法,但是他又一点点的改进,上面的每个医生的投票结果都是一样的,说明地位平等,如果在这里加上一个权重,大城市的医生权重高点,小县城的医生权重低,这样通过最终计算权重和的方式,会更加的合理,这就是AdaBoost算法。AdaBoost算法是一种迭代算法,只有最终分类误差率小于阈值算法才能停止,针对同一训练集数据训练不同的分类器,我们称弱分类器,最后按照权重和的形式组合起来,构成一个组合分类器,就是一个强分类器了。算法的只要过程:

1、对D训练集数据训练处一个分类器Ci

2、通过分类器Ci对数据进行分类,计算此时误差率

3、把上步骤中的分错的数据的权重提高,分对的权重降低,以此凸显了分错的数据。为什么这么做呢,后面会做出解释。

完整的adaboost算法如下

45e930afc38d55555abbbeffa3b81814.png

最后的sign函数是符号函数,如果最后的值为正,则分为+1类,否则即使-1类。

我们举个例子代入上面的过程,这样能够更好的理解。

adaboost的实现过程:

wp-display-data.php?filename=image00212953420561302596151.png&type=image%2Fpng&width=368&height=320

图中,“+”和“-”分别表示两种类别,在这个过程中,我们使用水平或者垂直的直线作为分类器,来进行分类。

第一步:

wp-display-data.php?filename=image003129534212413025961911302602682.png&type=image%2Fpng&width=835&height=310

根据分类的正确率,得到一个新的样本分布D2­,一个子分类器h1

其中划圈的样本表示被分错的。在右边的途中,比较大的“+”表示对该样本做了加权。

算法最开始给了一个均匀分布 D 。所以h1 里的每个点的值是0.1。ok,当划分后,有三个点划分错了,根据算法误差表达式

381845f32e92a5268dc30e52d619c9c0.png得到

误差为分错了的三个点的值之和,所以ɛ1=(0.1+0.1+0.1)=0.3,而ɑ1 根据表达式 的可以算出来为0.42. 然后就根据算法 把分错的点权值变大。如此迭代,最终完成adaboost算法。

第二步:

wp-display-data.php?filename=image00412953421681302596228.png&type=image%2Fpng&width=837&height=310

根据分类的正确率,得到一个新的样本分布D3,一个子分类器h2

第三步:

wp-display-data.php?filename=image005129534221113025962651302602796.png&type=image%2Fpng&width=483&height=326

得到一个子分类器h3

整合所有子分类器:

wp-display-data.php?filename=image00612953422551302596303.png&type=image%2Fpng&width=626&height=382

因此可以得到整合的结果,从结果中看,及时简单的分类器,组合起来也能获得很好的分类效果,在例子中所有的。后面的代码实现时,举出的也是这个例子,可以做对比,这里有一点比较重要,就是点的权重经过大小变化之后,需要进行归一化,确保总和为1.0,这个容易遗忘。

算法的代码实现

输入测试数据,与上图的例子相对应(数据格式:x坐标 y坐标 已分类结果):

1 5 1

2 3 1

3 1 -1

4 5 -1

5 6 1

6 4 -1

6 7 1

7 6 1

8 7 -1

8 2 -1

Point.java

package DataMining_AdaBoost;

/**

* 坐标点类

*

* @author lyq

*

*/

public class Point {

// 坐标点x坐标

private int x;

// 坐标点y坐标

private int y;

// 坐标点的分类类别

private int classType;

//如果此节点被划错,他的误差率,不能用个数除以总数,因为不同坐标点的权重不一定相等

private double probably;

public Point(int x, int y, int classType){

this.x = x;

this.y = y;

this.classType = classType;

}

public Point(String x, String y, String classType){

this.x = Integer.parseInt(x);

this.y = Integer.parseInt(y);

this.classType = Integer.parseInt(classType);

}

public int getX() {

return x;

}

public void setX(int x) {

this.x = x;

}

public int getY() {

return y;

}

public void setY(int y) {

this.y = y;

}

public int getClassType() {

return classType;

}

public void setClassType(int classType) {

this.classType = classType;

}

public double getProbably() {

return probably;

}

public void setProbably(double probably) {

this.probably = probably;

}

}AdaBoost.java

package DataMining_AdaBoost;

import java.io.BufferedReader;

import java.io.File;

import java.io.FileReader;

import java.io.IOException;

import java.text.MessageFormat;

import java.util.ArrayList;

import java.util.HashMap;

import java.util.Map;

/**

* AdaBoost提升算法工具类

*

* @author lyq

*

*/

public class AdaBoostTool {

// 分类的类别,程序默认为正类1和负类-1

public static final int CLASS_POSITIVE = 1;

public static final int CLASS_NEGTIVE = -1;

// 事先假设的3个分类器(理论上应该重新对数据集进行训练得到)

public static final String CLASSIFICATION1 = "X=2.5";

public static final String CLASSIFICATION2 = "X=7.5";

public static final String CLASSIFICATION3 = "Y=5.5";

// 分类器组

public static final String[] ClASSIFICATION = new String[] {

CLASSIFICATION1, CLASSIFICATION2, CLASSIFICATION3 };

// 分类权重组

private double[] CLASSIFICATION_WEIGHT;

// 测试数据文件地址

private String filePath;

// 误差率阈值

private double errorValue;

// 所有的数据点

private ArrayList totalPoint;

public AdaBoostTool(String filePath, double errorValue) {

this.filePath = filePath;

this.errorValue = errorValue;

readDataFile();

}

/**

* 从文件中读取数据

*/

private void readDataFile() {

File file = new File(filePath);

ArrayList dataArray = new ArrayList();

try {

BufferedReader in = new BufferedReader(new FileReader(file));

String str;

String[] tempArray;

while ((str = in.readLine()) != null) {

tempArray = str.split(" ");

dataArray.add(tempArray);

}

in.close();

} catch (IOException e) {

e.getStackTrace();

}

Point temp;

totalPoint = new ArrayList<>();

for (String[] array : dataArray) {

temp = new Point(array[0], array[1], array[2]);

temp.setProbably(1.0 / dataArray.size());

totalPoint.add(temp);

}

}

/**

* 根据当前的误差值算出所得的权重

*

* @param errorValue

* 当前划分的坐标点误差率

* @return

*/

private double calculateWeight(double errorValue) {

double alpha = 0;

double temp = 0;

temp = (1 - errorValue) / errorValue;

alpha = 0.5 * Math.log(temp);

return alpha;

}

/**

* 计算当前划分的误差率

*

* @param pointMap

* 划分之后的点集

* @param weight

* 本次划分得到的分类器权重

* @return

*/

private double calculateErrorValue(

HashMap> pointMap) {

double resultValue = 0;

double temp = 0;

double weight = 0;

int tempClassType;

ArrayList pList;

for (Map.Entry entry : pointMap.entrySet()) {

tempClassType = (int) entry.getKey();

pList = (ArrayList) entry.getValue();

for (Point p : pList) {

temp = p.getProbably();

// 如果划分类型不相等,代表划错了

if (tempClassType != p.getClassType()) {

resultValue += temp;

}

}

}

weight = calculateWeight(resultValue);

for (Map.Entry entry : pointMap.entrySet()) {

tempClassType = (int) entry.getKey();

pList = (ArrayList) entry.getValue();

for (Point p : pList) {

temp = p.getProbably();

// 如果划分类型不相等,代表划错了

if (tempClassType != p.getClassType()) {

// 划错的点的权重比例变大

temp *= Math.exp(weight);

p.setProbably(temp);

} else {

// 划对的点的权重比减小

temp *= Math.exp(-weight);

p.setProbably(temp);

}

}

}

// 如果误差率没有小于阈值,继续处理

dataNormalized();

return resultValue;

}

/**

* 概率做归一化处理

*/

private void dataNormalized() {

double sumProbably = 0;

double temp = 0;

for (Point p : totalPoint) {

sumProbably += p.getProbably();

}

// 归一化处理

for (Point p : totalPoint) {

temp = p.getProbably();

p.setProbably(temp / sumProbably);

}

}

/**

* 用AdaBoost算法得到的组合分类器对数据进行分类

*

*/

public void adaBoostClassify() {

double value = 0;

Point p;

calculateWeightArray();

for (int i = 0; i < ClASSIFICATION.length; i++) {

System.out.println(MessageFormat.format("分类器{0}权重为:{1}", (i+1), CLASSIFICATION_WEIGHT[i]));

}

for (int j = 0; j < totalPoint.size(); j++) {

p = totalPoint.get(j);

value = 0;

for (int i = 0; i < ClASSIFICATION.length; i++) {

value += 1.0 * classifyData(ClASSIFICATION[i], p)

* CLASSIFICATION_WEIGHT[i];

}

//进行符号判断

if (value > 0) {

System.out

.println(MessageFormat.format(

"点({0}, {1})的组合分类结果为:1,该点的实际分类为{2}", p.getX(), p.getY(),

p.getClassType()));

} else {

System.out.println(MessageFormat.format(

"点({0}, {1})的组合分类结果为:-1,该点的实际分类为{2}", p.getX(), p.getY(),

p.getClassType()));

}

}

}

/**

* 计算分类器权重数组

*/

private void calculateWeightArray() {

int tempClassType = 0;

double errorValue = 0;

ArrayList posPointList;

ArrayList negPointList;

HashMap> mapList;

CLASSIFICATION_WEIGHT = new double[ClASSIFICATION.length];

for (int i = 0; i < CLASSIFICATION_WEIGHT.length; i++) {

mapList = new HashMap<>();

posPointList = new ArrayList<>();

negPointList = new ArrayList<>();

for (Point p : totalPoint) {

tempClassType = classifyData(ClASSIFICATION[i], p);

if (tempClassType == CLASS_POSITIVE) {

posPointList.add(p);

} else {

negPointList.add(p);

}

}

mapList.put(CLASS_POSITIVE, posPointList);

mapList.put(CLASS_NEGTIVE, negPointList);

if (i == 0) {

// 最开始的各个点的权重一样,所以传入0,使得e的0次方等于1

errorValue = calculateErrorValue(mapList);

} else {

// 每次把上次计算所得的权重代入,进行概率的扩大或缩小

errorValue = calculateErrorValue(mapList);

}

// 计算当前分类器的所得权重

CLASSIFICATION_WEIGHT[i] = calculateWeight(errorValue);

}

}

/**

* 用各个子分类器进行分类

*

* @param classification

* 分类器名称

* @param p

* 待划分坐标点

* @return

*/

private int classifyData(String classification, Point p) {

// 分割线所属坐标轴

String position;

// 分割线的值

double value = 0;

double posProbably = 0;

double negProbably = 0;

// 划分是否是大于一边的划分

boolean isLarger = false;

String[] array;

ArrayList pList = new ArrayList<>();

array = classification.split("=");

position = array[0];

value = Double.parseDouble(array[1]);

if (position.equals("X")) {

if (p.getX() > value) {

isLarger = true;

}

// 将训练数据中所有属于这边的点加入

for (Point point : totalPoint) {

if (isLarger && point.getX() > value) {

pList.add(point);

} else if (!isLarger && point.getX() < value) {

pList.add(point);

}

}

} else if (position.equals("Y")) {

if (p.getY() > value) {

isLarger = true;

}

// 将训练数据中所有属于这边的点加入

for (Point point : totalPoint) {

if (isLarger && point.getY() > value) {

pList.add(point);

} else if (!isLarger && point.getY() < value) {

pList.add(point);

}

}

}

for (Point p2 : pList) {

if (p2.getClassType() == CLASS_POSITIVE) {

posProbably++;

} else {

negProbably++;

}

}

//分类按正负类数量进行划分

if (posProbably > negProbably) {

return CLASS_POSITIVE;

} else {

return CLASS_NEGTIVE;

}

}

}调用类Client.java:

/**

* AdaBoost提升算法调用类

* @author lyq

*

*/

public class Client {

public static void main(String[] agrs){

String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";

//误差率阈值

double errorValue = 0.2;

AdaBoostTool tool = new AdaBoostTool(filePath, errorValue);

tool.adaBoostClassify();

}

}

输出结果:

分类器1权重为:0.424

分类器2权重为:0.65

分类器3权重为:0.923

点(1, 5)的组合分类结果为:1,该点的实际分类为1

点(2, 3)的组合分类结果为:1,该点的实际分类为1

点(3, 1)的组合分类结果为:-1,该点的实际分类为-1

点(4, 5)的组合分类结果为:-1,该点的实际分类为-1

点(5, 6)的组合分类结果为:1,该点的实际分类为1

点(6, 4)的组合分类结果为:-1,该点的实际分类为-1

点(6, 7)的组合分类结果为:1,该点的实际分类为1

点(7, 6)的组合分类结果为:1,该点的实际分类为1

点(8, 7)的组合分类结果为:-1,该点的实际分类为-1

点(8, 2)的组合分类结果为:-1,该点的实际分类为-1

我们可以看到,如果3个分类单独分类,都没有百分百分对,而尽管组合结果之后,全部分类正确。

我对AdaBoost算法的理解

到了算法的末尾,有必要解释一下每次分类自后需要把错的点的权重增大,正确的减少的理由了,加入上次分类之后,(1,5)已经分错了,如果这次又分错,由于上次的权重已经提升,所以误差率更大,则代入公式ln(1-误差率/误差率)所得的权重越小,也就是说,如果同个数据,你分类的次数越多,你的权重越小,所以这就造成整体好的分类器的权重会越大,内部就会同时有各种权重的分类器,形成了一种互补的结果,如果好的分类器结果分错

,可以由若干弱一点的分类器进行弥补。

AdaBoost算法的应用

可以运用在诸如特征识别,二分类的一些应用上,与单个模型相比,组合的形式能显著的提高准确率。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值