Adaboost in Matlab: a worked example

Posted by Kriz on 2017-04-18

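Adaboost is one of the simpler machine learning algorithms. In short, it combines several weak classifiers into one strong classifier. After training, the weak classifiers, each carrying its own weight, vote to decide the final class. Many write-ups quote the Chinese proverb "three humble cobblers together outwit Zhuge Liang" (roughly, two heads are better than one), and that is a perfectly good way to remember it.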
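For reference, these are the standard AdaBoost rules that the code below follows, with labels $y_i \in \{-1, +1\}$, weak classifiers $h_t$ and sample weights $w_{t,i}$. The code calls the weighted error `sigma`, keeps the weights in column 4 and stores $y_i h_t(x_i)$ (+1 if correct, -1 if wrong) in column 5. Note that $\alpha_t > 0$ only when $\varepsilon_t < 0.5$.

```latex
\begin{aligned}
\varepsilon_t &= \sum_{i:\,h_t(x_i)\neq y_i} w_{t,i}, \qquad
\alpha_t = \frac{1}{2}\ln\frac{1-\varepsilon_t}{\varepsilon_t},\\
w_{t+1,i} &= \frac{w_{t,i}\exp\left(-\alpha_t\, y_i\, h_t(x_i)\right)}{Z_t}, \qquad
Z_t = \sum_{i} w_{t,i}\exp\left(-\alpha_t\, y_i\, h_t(x_i)\right),\\
H(x) &= \operatorname{sign}\left(\sum_{t=1}^{T}\alpha_t\, h_t(x)\right).
\end{aligned}
```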

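It is said to be fairly resistant to overfitting, though I don't know why (the explanations I've seen online don't seem very convincing, and I haven't found a paper that settles the question; look into it yourself if you're curious).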

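I knocked together some Matlab code as an example. To fit the exercise it was written for, the dataset is small and only three weak classifiers are used, so the results aren't great, but it is good enough for understanding the algorithm and as a record.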

```matlab
% Adaboost
% by Kriz
% Input the coordinates: columns are x, y, label (+1 / -1)
data = [80 144 1; 93 232 1; 136 275 -1; 147 131 -1; 159 69 1; ...
        214 31 1; 214 152 -1; 257 83 1; 307 62 -1; 307 231 -1];
[num, ~] = size(data);
weight = 1 / num;
% Column 4: sample weights (uniform to start); column 5: scratch column
% (per-round correctness, later the strong classifier's score)
data = [data weight*ones(num,1) zeros(num,1)];
% Train the weak classifiers in turn, collecting their alphas
alpha_vec = [];
[data, alpha] = weakClassifiers(data, 1, 0, -100);  % predicts +1 where x < 100
alpha_vec = [alpha_vec alpha];
[data, alpha] = weakClassifiers(data, 1, 0, -300);  % predicts +1 where x < 300
alpha_vec = [alpha_vec alpha];
[data, alpha] = weakClassifiers(data, 0, 1, -130);  % predicts +1 where y < 130
alpha_vec = [alpha_vec alpha];
% Combine them into the strong classifier and check every training point
final_data = strongClassifier(data, alpha_vec);
if all(sign(final_data(:,5)) == final_data(:,3))
    fprintf('Classification succeeded!\n')
else
    fprintf('Eww? Something wrong!\n')
end
```
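The script above calls weakClassifiers three times by hand. As a sketch (my own restructuring, not the author's code), the same three rounds can be driven by a table of (a, b, c) rows, here a hypothetical matrix named `stumps`, with the per-point loops replaced by vector operations. It uses the same convention as weakClassifiers, predicting +1 where a*x + b*y + c < 0, and, like the original, it does not guard against a weighted error of exactly 0.

```matlab
% Sketch (not the author's code): the three boosting rounds as one loop.
data = [80 144 1; 93 232 1; 136 275 -1; 147 131 -1; 159 69 1; ...
        214 31 1; 214 152 -1; 257 83 1; 307 62 -1; 307 231 -1];
X = data(:, 1:2);                 % coordinates
y = data(:, 3);                   % labels in {-1, +1}
stumps = [1 0 -100; 1 0 -300; 0 1 -130];   % hypothetical table: one (a, b, c) per row
T = size(stumps, 1);
w = ones(size(y)) / numel(y);     % uniform initial weights
alpha_vec = zeros(1, T);
H = zeros(numel(y), T);           % H(:, t) holds the predictions of stump t
for t = 1:T
    s = X * stumps(t, 1:2)' + stumps(t, 3);   % a*x + b*y + c for every point
    H(:, t) = 2 * (s < 0) - 1;                 % +1 where s < 0, else -1
    margin = y .* H(:, t);                     % +1 if correct, -1 if wrong
    err = sum(w(margin < 0));                  % weighted error
    if err >= 0.5
        continue                               % no better than chance: alpha stays 0
    end
    alpha_vec(t) = 0.5 * log((1 - err) / err);
    w = w .* exp(-alpha_vec(t) * margin);
    w = w / sum(w);                            % divide by Z_t
end
F = H * alpha_vec';               % weighted vote: sum_t alpha_t * h_t(x)
fprintf('alpha = %s\n', mat2str(alpha_vec, 4));
fprintf('training accuracy = %g\n', mean(sign(F) == y));
```

On this data the three rounds see weighted errors of 0.3, 3/14 and 3/22, giving alphas of about 0.4236, 0.6496 and 0.9229, and the vote classifies all ten training points correctly.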

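Here are the two helper functions (put each in its own .m file, or at the end of the script in MATLAB R2016b or later):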

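```matlab
function [data, alpha] = weakClassifiers(data, a, b, c)
% One boosting round with the linear weak classifier a*x + b*y + c,
% which predicts +1 where a*x + b*y + c < 0 and -1 where it is > 0
% (points exactly on the line are counted as correct).
% Column 4 holds the sample weights (updated in place); column 5 is set to
% +1 for a correctly classified sample and -1 for a misclassified one.
sigma = 0;                               % weighted error
[num, ~] = size(data);
for i = 1:num
    if data(i,3) * (a * data(i,1) + b * data(i,2) + c) > 0
        sigma = sigma + data(i,4);       % misclassified
        data(i,5) = -1;
    else
        data(i,5) = 1;
    end
end
alpha = 0;    % a classifier no better than chance gets no vote
if sigma < 0.5
    alpha = log((1-sigma)/sigma)/2;
    % Updating weight process starts
    Z = 0;
    for i = 1:num
        Z = Z + data(i,4)*exp(-alpha*data(i,5));
    end
    % Z is ready: normalize so the weights sum to 1
    for i = 1:num
        data(i,4) = data(i,4)*exp(-alpha*data(i,5))/Z;
    end
end
end
```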
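One consequence of this weight update is worth checking numerically: after reweighting, the misclassified samples hold exactly half of the total weight, so the classifier just used would score an error of 0.5 in the next round, and the normalizer works out to Z = 2*sqrt(sigma*(1 - sigma)). The sketch below is my own vectorised restatement of one round of weakClassifiers (variable names mirror the function's, `w_new` is hypothetical), applied to the first weak classifier, +1 where x < 100.

```matlab
% Sketch (not the author's code): one vectorised round of the weight update.
data = [80 144 1; 93 232 1; 136 275 -1; 147 131 -1; 159 69 1; ...
        214 31 1; 214 152 -1; 257 83 1; 307 62 -1; 307 231 -1];
y = data(:, 3);
w = ones(size(y)) / numel(y);            % column 4 before the first round
h = 2 * (data(:, 1) < 100) - 1;          % predictions of the first weak classifier
margin = y .* h;                         % what weakClassifiers stores in column 5
sigma = sum(w(margin < 0));              % weighted error
alpha = log((1 - sigma) / sigma) / 2;
Z = sum(w .* exp(-alpha * margin));      % normalisation factor
w_new = w .* exp(-alpha * margin) / Z;   % updated weights
fprintf('sigma = %.4f, alpha = %.4f, Z = %.4f\n', sigma, alpha, Z);
fprintf('2*sqrt(sigma*(1-sigma)) = %.4f\n', 2 * sqrt(sigma * (1 - sigma)));
fprintf('weight now on misclassified points: %.4f\n', sum(w_new(margin < 0)));
```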
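```matlab
function data = strongClassifier(data, alpha_vec)
% Weighted vote of the three weak classifiers: column 5 receives the score
% F(x) = sum_t alpha_t * h_t(x), whose sign is the predicted label.
[num, ~] = size(data);
% Thresholds must match the (a, b, c) arguments passed to weakClassifiers
h1x = 100; h2x = 300; h3y = 130;
for j = 1:num
    if data(j,1) < h1x        % h1: +1 where x < 100
        data(j,5) = alpha_vec(1);
    else
        data(j,5) = -alpha_vec(1);
    end
    if data(j,1) < h2x        % h2: +1 where x < 300
        data(j,5) = data(j,5) + alpha_vec(2);
    else
        data(j,5) = data(j,5) - alpha_vec(2);
    end
    if data(j,2) < h3y        % h3: +1 where y < 130
        data(j,5) = data(j,5) + alpha_vec(3);
    else
        data(j,5) = data(j,5) - alpha_vec(3);
    end
end
end
```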
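strongClassifier hard-codes the three thresholds, so they have to be kept in sync with the weakClassifiers calls by hand. A possible alternative, sketched below with my own hypothetical names (`stumps`, `h`, `predict`), builds the strong classifier as a function handle from the same (a, b, c) table, so it can also score new points. The alphas are the values the three training rounds produce on this dataset (from weighted errors 0.3, 3/14 and 3/22). Adding the row `stumps(:, 3)'` to an n-by-3 matrix relies on implicit expansion, which MATLAB supports from R2016b and Octave supports as broadcasting.

```matlab
% Sketch (not the author's code): the strong classifier as a function handle.
data = [80 144 1; 93 232 1; 136 275 -1; 147 131 -1; 159 69 1; ...
        214 31 1; 214 152 -1; 257 83 1; 307 62 -1; 307 231 -1];
stumps = [1 0 -100; 1 0 -300; 0 1 -130];     % hypothetical (a, b, c) table, one row per stump
alpha_vec = 0.5 * log([7/3, 11/3, 19/3]);    % alphas from the three rounds on this data
% h(P): n-by-3 matrix of predictions, +1 where a*x + b*y + c < 0, else -1
h = @(P) 2 * ((P * stumps(:, 1:2)' + stumps(:, 3)') < 0) - 1;
% predict(P): sign of the weighted vote sum_t alpha_t * h_t(x)
predict = @(P) sign(h(P) * alpha_vec');
train_pred = predict(data(:, 1:2));
fprintf('training points correct: %d of %d\n', ...
        sum(train_pred == data(:, 3)), size(data, 1));
new_pred = predict([50 50; 350 200]);        % two made-up query points
disp(new_pred')
```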

…yep, Matlab not having a ternary operator is honestly a real pain.
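Matlab has no `cond ? a : b`, but for the ±1 choices used here two common idioms do the job without an if/else: arithmetic on a logical value, and indexing a two-element vector with the logical plus one. The sketch below (my own; `choices`, `h_arith` and `h_index` are hypothetical names) applies both to three of the sample x-coordinates with the first threshold.

```matlab
% Sketch (not the author's code): two stand-ins for a ternary operator.
x = [80; 136; 307];       % x-coordinates of three of the sample points
h1x = 100;
% 1) Arithmetic on a logical: true -> +1, false -> -1, element-wise
h_arith = 2 * (x < h1x) - 1;
% 2) Indexing: the logical (0 or 1) plus 1 picks the first or second entry
choices = [-1; 1];
h_index = choices((x < h1x) + 1);
% Either way, an if/else block in strongClassifier could become one line, e.g.
% data(j,5) = alpha_vec(1) * (2 * (data(j,1) < h1x) - 1);
disp([h_arith h_index])
```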