-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathboost_train.m
59 lines (52 loc) · 1.71 KB
/
boost_train.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
function [model, accuracy] = boost_train(conf, Y, D, weak_learners, hyps)
% BOOST_TRAIN Train a boosted classifier from pre-computed weak-learner outputs.
%   [model, accuracy] = boost_train(conf, Y, D, weak_learners, hyps)
%
% INPUT
%   conf:          configuration struct; fields read here are alg_name,
%                  num_boostIter, mult_target, algorthmId and use_mex
%   Y:             training labels, with shape (N x 1); they are compared
%                  against +1/-1 predictions, so labels are presumably in
%                  {-1, +1} (TODO confirm)
%   D:             domain ids, 0: source, 1: target
%   weak_learners: pre-trained weak learners (stored in the model as-is)
%   hyps:          pre-computed hypotheses of the weak learners, indexed as
%                  hyps(:, t): one row per training sample, one column per
%                  weak learner
%
% OUTPUT
%   model:    struct with fields algorithm, weak_learners, alphas, best_t
%   accuracy: training accuracy of the resulting strong classifier
model = struct;
model.algorithm = conf.alg_name;
model.weak_learners = weak_learners;
T = conf.num_boostIter; % number of boosting iterations

% Initial weight of samples, computed from the domain ids and mult_target
wi = init_weights(D, conf.mult_target);

% Parameters for the boosting routine
boost_Para = struct( ...
    'NAME_ALGORITHM', conf.algorthmId, ... % 0: adaboost, 1: tradaboost, 2: d-tradaboost
    'MAX_ITERATION', T);

fprintf('\nStart boosting training ...\n');
th = tic();
if conf.use_mex
    % for mex interface:
    [alphas, best_t] = mex_boosting(boost_Para, Y, D, hyps, wi);
    timing_fmt = '\nBoosting mex took %.4f seconds\n';
else
    % for matlab interface:
    [alphas, best_t] = mat_boosting(boost_Para, Y, D, hyps, wi);
    timing_fmt = '\nBoosting matlab took %.4f seconds\n';
end
fprintf(timing_fmt, toc(th));

% Drop non-positive entries (presumably padding for iterations that did not
% produce a learner -- TODO confirm against mex/mat_boosting). The two
% filtered lists must still pair up one-to-one, which is checked by length.
best_t = best_t(best_t > 0);
alphas = alphas(alphas > 0);
assert(length(alphas) == length(best_t));
assert(~isempty(alphas));

% For the transfer variants (algorthmId > 0) keep only the later learners,
% starting at position round(T/2), or at the last learner when fewer were
% kept. The start is clamped to 1 so that T < 1 cannot yield index 0.
if conf.algorthmId > 0
    ind = max(1, min(round(T/2), length(alphas)));
    alphas = alphas(ind:end);
    best_t = best_t(ind:end);
end
model.alphas = alphas;
model.best_t = best_t;

% Evaluate the strong classifier on the training set: sign of the
% alpha-weighted vote, with a zero score mapped to -1. alphas(:) and Y(:)
% force column orientation so the product is (N x 1) and the comparison
% does not broadcast to (N x N) when either is returned as a row vector.
score = hyps(:, best_t) * alphas(:);
y_ensemble = (score > 0) * 2 - 1;
accuracy = mean(y_ensemble == Y(:));
end