function [alpha,d,b,activeset,supind,params,story] = SpicyMKL(K,yapp,C,options)
% SPICYMKL  Multiple kernel learning by an augmented-Lagrangian scheme.
%
%   [alpha,d,b,activeset,supind,params,story] = SpicyMKL(K,yapp,C,options)
%
%   The outer loop updates the multipliers (mu, cb and, for 'svm', clambda)
%   and grows the penalty parameters (cgamma, cgammab, ct) by options.calpha.
%   The inner loop minimizes the augmented objective over the dual vector
%   rho with Newton or BFGS directions and an Armijo backtracking line
%   search (sufficient-decrease constant 0.1, step halving).
%
%   Inputs
%     K       - N x N x M array; K(:,:,j) is the j-th kernel matrix.
%     yapp    - N x 1 vector of targets.
%     C       - regularization constant; kernel j is "active" when its
%               norm wj(j) exceeds C.
%     options - (optional) struct with fields
%       loss              'svm' | 'logit' | 'square'. If missing or not one
%                         of these: 'logit' when yapp has exactly two distinct
%                         values (which are relabelled to -1/+1), otherwise
%                         'square'.
%       tolOuter          outer tolerance (default 0.01)
%       tolInner          relative inner-step tolerance (default tolOuter/10000)
%       innerOptMeth      'Newton' (default) or 'BFGS'
%       outerMaxIter      maximum outer iterations (default 500)
%       innerMaxIter      maximum inner iterations (default 500); the legacy
%                         field InnerMaxIter is honoured when innerMaxIter is
%                         absent
%       calpha            penalty growth factor (default 10)
%       stopIneqViolation stop once the constraint violation ck < tolOuter
%                         (default 0)
%       stopdualitygap    stop once the relative duality gap < tolOuter
%                         (default 1)
%       display           >= 2 prints per outer iteration, >= 3 also per
%                         inner step (default 2)
%
%   Outputs
%     alpha     - N x 1 vector, -rho * (sum of the unnormalized weights d).
%     d         - 1 x M kernel weights normalized to sum to one (left as
%                 zeros when no kernel is active).
%     b         - bias, -cb.
%     activeset - row vector of indices of the active kernels (wj > C).
%     supind    - for 'svm' the indices with clambda{2} == 0, else (1:N)'.
%     params    - struct with cgamma, cgammab, wj, C, loss (plus ct and
%                 clambda for 'svm').
%     story     - per-outer-iteration history: primalobj, mod_dualobj,
%                 len_active, dualitygap.
%
%   NOTE(review): normKj and HessAugMexbias are defined elsewhere. normKj is
%   assumed to return per-kernel vectors v(:,j) and norms wj(j) of the
%   columns of its second argument, optionally restricted to the listed
%   kernel indices -- confirm against its implementation.

N = length(yapp);
M = size(K,3);

% exist('options') would also see files/functions named "options" on the
% path; nargin is the reliable test for an omitted argument.
if nargin < 4 || isempty(options)
    options = [];
end

loss_names = {'svm','logit','square'};
if ~isfield(options,'loss')
    [lossName, yapp] = defaultLoss(yapp);
    options.loss = lossName;
    fprintf('Please specify options.loss. We set options.loss=%s.\n',options.loss);
elseif ~any(strcmp(options.loss,loss_names))
    [lossName, yapp] = defaultLoss(yapp);
    options.loss = lossName;
    fprintf('Please specify options.loss correctly. We set options.loss=%s.\n',options.loss);
end

tol          = getOpt(options,'tolOuter',0.01);
tolInner     = getOpt(options,'tolInner',tol/10000);
opt_method   = getOpt(options,'innerOptMeth','Newton');
OuterMaxIter = getOpt(options,'outerMaxIter',500);
% The documented field is innerMaxIter; earlier code read InnerMaxIter.
InnerMaxIter = getOpt(options,'innerMaxIter',getOpt(options,'InnerMaxIter',500));
calpha       = getOpt(options,'calpha',10);
options.stopIneqViolation = getOpt(options,'stopIneqViolation',0);
options.stopdualitygap    = getOpt(options,'stopdualitygap',1);
options.display           = getOpt(options,'display',2);

story.primalobj = [];
story.mod_dualobj = [];
story.len_active = [];
story.dualitygap = [];

% Dual variable and augmented-Lagrangian multipliers / penalty parameters.
rho = -yapp/2;
mu = zeros(N,M);
cb = 0;
cgamma = ones(M,1)*10;
cgammab = 1;

if strcmp(options.loss,'svm')
    numt = 2;
    ct = cell(1,numt);
    clambda = cell(1,numt);
    for i=1:numt
        ct{i} = ones(N,1)*10;
        clambda{i} = zeros(N,1);
    end
end

ck = Inf;      % best constraint violation seen so far
cbeta = 1/2;   % a violation above cbeta*ck triggers multiplicative penalty growth

u = rho*ones(1,M) + mu;
[v, wj] = normKj(K,u);
oneN = ones(N,1);
oneM = ones(1,M);

if strcmp(opt_method,'BFGS')
    Hk = eye(N);
end

activeset = find(wj > C)';
supind = (1:N)';

for l=1:OuterMaxIter

    % ---------------- Inner loop: minimize over rho ----------------
    for step = 1:InnerMaxIter
        sumrho = sum(rho);
        if isempty(activeset)
            activeset = [];
        end
        yrho = rho.*yapp;

        % Objective and gradient of the augmented function at rho.
        if strcmp(options.loss,'svm')
            [fval,wvec1,wvec2] = funcevalsvm(wj,yrho,sumrho,cgamma,cgammab,cb,clambda,ct,C);
            grad = gradSVM(wvec1,wvec2,yapp);
        elseif strcmp(options.loss,'logit')
            fval = funcevallogit(wj,yrho,cgamma,C,sumrho,cgammab,cb);
            grad = gradLogit(yrho,yapp);
        elseif strcmp(options.loss,'square')
            fval = funcevalsquare(wj,yapp,rho,cgamma,C,sumrho,cgammab,cb);
            grad = gradSquare(rho,yapp);
        end
        grad = gradAug(v,cgamma,activeset,wj,C,grad,sumrho,cgammab,cb);

        % Search direction.
        if strcmp(opt_method,'Newton')
            if strcmp(options.loss,'svm')
                Hessian = HessSVM(ct,yapp,wvec1,wvec2);
            elseif strcmp(options.loss,'logit')
                Hessian = HessLogit(yrho);
            elseif strcmp(options.loss,'square')
                Hessian = HessSquare(length(rho));
            end
            Hessian = HessAugMexbias(K,wj,activeset,v,C,cgamma,cgammab,Hessian);
            if strcmp(options.loss,'svm')
                % HessSVM is zero on rows where wvec1 and wvec2 vanish, so a
                % small ridge is added; large systems are solved with PCG.
                if N < 2000
                    dk = - (Hessian + eye(N)*1e-8)\grad;
                else
                    dk = - pcg(Hessian + eye(N)*1e-8,grad,1e-4,200);
                end
            else
                dk = - (Hessian)\grad;
            end
            graddotd = grad'*dk;
        elseif strcmp(opt_method,'BFGS')
            dk = - Hk*grad;
            graddotd = grad'*dk;
            if graddotd >=0 || step == 1
                % Not a descent direction (or first step): restart from steepest descent.
                dk = - grad;
                Hk = eye(N);
            else
                deltakBFGS = rho - old_rho_BFGS;
                gammakBFGS = grad - old_grad_BFGS;
                dgBFGS = (deltakBFGS'*gammakBFGS);
                Vk = eye(N) - deltakBFGS*(gammakBFGS/dgBFGS)';
                Hk = Vk*Hk*Vk' + deltakBFGS*(deltakBFGS/dgBFGS)';
            end
            old_grad_BFGS = grad;
            old_rho_BFGS = rho;
        end

        if graddotd > 0
            dk = - grad;
        end

        % Full trial step.
        step_size = 1;
        old_fval = fval;
        old_rho = rho;
        old_wj = wj;
        old_u = u; old_v = v;
        old_activeset = activeset;
        old_yrho = yrho;

        rho = old_rho + step_size*dk;
        if strcmp(options.loss,'logit')
            % Keep yapp.*rho strictly inside (-1,0), the domain of the logit dual.
            yrho = rho.*yapp;
            if any(yrho <= -1 | yrho >= 0)
                yd = yapp.*dk;
                ss = min([(-1-old_yrho(yd<0))./yd(yd<0);-(old_yrho(yd>0)./(yd(yd>0)))])*0.99;
                step_size = min(step_size,ss);
                rho = old_rho + step_size*dk;
            end
        end
        u = rho*oneM + mu;
        [v, wj] = normKj(K,u);

        activeset = find(wj>C)';
        if isempty(activeset)
            activeset = [];
        end

        sumrho = sum(rho);

        if strcmp(options.loss,'svm')
            fval = funcevalsvm(wj,rho.*yapp,sumrho,cgamma,cgammab,cb,clambda,ct,C);
        elseif strcmp(options.loss,'logit')
            fval = funcevallogit(wj,rho.*yapp,cgamma,C,sumrho,cgammab,cb);
        elseif strcmp(options.loss,'square')
            fval = funcevalsquare(wj,yapp,rho,cgamma,C,sum(rho),cgammab,cb);
        end

        dirvec = rho - old_rho;

        tmp_activeset = union(activeset,old_activeset);

        % After a backtracked step only the active columns of v are refreshed
        % (below), so recompute old_v/old_wj for kernels that left the set.
        actdif = setdiff(tmp_activeset,activeset);
        old_u(:,actdif) = old_rho*ones(1,length(actdif))+mu(:,actdif);
        [old_v(:,actdif), old_wj(actdif)] = normKj(K,old_u,actdif);

        % Assuming v(:,j) is linear in u(:,j), wj(j)^2 is a quadratic in the
        % step length along dirvec; these coefficients let the line search
        % update wj without calling normKj.
        dirnorm = 0*wj;
        dirdotu = 0*wj;
        dirnorm(tmp_activeset) = dirvec'*(v(:,tmp_activeset)-old_v(:,tmp_activeset));
        dirdotu(tmp_activeset) = dirvec'*old_v(:,tmp_activeset);
        steplen = 1;

        % Armijo backtracking.
        while fval > old_fval + steplen*0.1*(dirvec'*grad)
            steplen = steplen/2;
            rho = old_rho + dirvec*steplen;
            wj = 0*wj;

            wj(tmp_activeset) = sqrt(max(0,old_wj(tmp_activeset).^2 + 2*steplen*dirdotu(tmp_activeset) + steplen^2*(dirnorm(tmp_activeset))));

            if strcmp(options.loss,'svm')
                fval = funcevalsvm(wj,rho.*yapp,sum(rho),cgamma,cgammab,cb,clambda,ct,C);
            elseif strcmp(options.loss,'logit')
                fval = funcevallogit(wj,rho.*yapp,cgamma,C,sum(rho),cgammab,cb);
            elseif strcmp(options.loss,'square')
                fval = funcevalsquare(wj,yapp,rho,cgamma,C,sum(rho),cgammab,cb);
            end
        end

        if steplen ~= 1
            activeset = find(wj > C)';
            u = rho*oneM + mu;
            v(:,activeset) = normKj(K,u,activeset);
        end

        if ~isreal(fval)
            % Previously a bare "norm(v)" debug print; report it explicitly.
            warning('SpicyMKL:complexObjective', ...
                'Objective became complex at inner step %d (norm(v)=%g).', step, norm(v));
        end
        if options.display>=3
            fprintf(' [%d] fval:%g steplen:%g\n',step,fval,steplen);
        end
        if norm(old_rho - rho)/norm(old_rho) <= tolInner
            break;
        end
    end

    % ------------- Constraint violations after the inner solve -------------
    sumrho = sum(rho);
    activeset = find(wj > C)';
    if isempty(activeset)
        activeset = [];
    end
    yrho = yapp.*rho;

    % uu: u with each active column scaled back to norm C; hresid(j) is the
    % per-kernel residual between that and rho.
    uu = u;
    if ~isempty(activeset)
        uu(:,activeset) = u(:,activeset).*(oneN*min(1,C./wj(activeset))');
    end
    hresid = sqrt(sum((uu-rho*oneM).^2,1));
    maxgap = 0;
    if strcmp(options.loss,'svm')
        work = abs(max( - 1 - yrho, - clambda{1}./ct{1}));
        ctI{1} = find( work > cbeta*ck );
        maxgap = max(max(work),maxgap);
        work = abs(max( yrho, - clambda{2}./ct{2}));
        ctI{2} = find( work > cbeta*ck );
        maxgap = max(max(work),maxgap);
    end
    work = abs(sumrho);
    b_cb = ( work > cbeta*ck);
    maxgap = max(work,maxgap);

    I3 = find(hresid > cbeta*ck);
    maxgap = max(max(abs(hresid)),maxgap);

    if options.display>=2 && options.stopIneqViolation
        fprintf('[[%d]] maxgap:%g ck:%g ',l,maxgap,ck);
    end

    % Multiplier update, only when the violation did not increase.
    if maxgap <= ck
        ck = maxgap;
        if strcmp(options.loss,'svm')
            clambda{1} = max(clambda{1} + ct{1}.*(-yrho-1),0);
            clambda{2} = max(clambda{2} + ct{2}.*(yrho),0);
        end
        cb = cb + cgammab*sumrho;

        % mu(:,j) = group soft-threshold of u(:,j) at level C.
        mu = mu*0;
        if ~isempty(activeset)
            mu(:,activeset) = u(:,activeset).*(oneN*max((wj(activeset) - C)./wj(activeset),0)');
        end

        if ck < tol && options.stopIneqViolation
            break;
        end
    end

    % Penalty update: multiply by calpha where the violation stayed above
    % cbeta*ck, otherwise add calpha. mu is rescaled alongside cgamma.
    allind = 1:N;
    if strcmp(options.loss,'svm')
        for i=1:2
            ct{i}(ctI{i}) = calpha * ct{i}(ctI{i});
            ctIcomp = setdiff(allind,ctI{i});
            ct{i}(ctIcomp) = calpha + ct{i}(ctIcomp);
        end
    end
    if b_cb
        cgammab = calpha * cgammab;
    else
        cgammab = calpha + cgammab;
    end
    cgamma(I3) = calpha * cgamma(I3);
    mu(:,I3) = mu(:,I3)/calpha;
    cI3 = setdiff(1:M,I3);
    cgamma(cI3) = calpha + cgamma(cI3);
    mu(:,cI3) = mu(:,cI3).*(oneN*((cgamma(cI3)-calpha)./cgamma(cI3))');

    u = rho*oneM + mu;
    [v, wj] = normKj(K,u);
    activeset = find(wj>C)';

    % ---------------- Primal / dual objectives and gap ----------------
    [vmu, muwj] = normKj(K,mu,activeset);
    alpha = -mu.*(oneN*cgamma');
    [vv, rhowj] = normKj(K,rho*oneM,activeset);
    if ~isempty(activeset)
        modrho = rho*min(1,C/max(rhowj));
    else
        modrho = rho;
    end
    modrho = modrho - oneN*sum(modrho)/N;
    aa = zeros(N,1);
    if strcmp(options.loss,'svm')
        supind = find(clambda{2}==0);
    end
    for j=activeset
        aa = aa - (K(:,supind,j)*alpha(supind,j));
    end
    aa = aa + cb;
    if strcmp(options.loss,'svm')
        aa = sum(max(aa.*yapp,-1) + 1);
        modyrho = yapp.*modrho;
        mod_dualobj = -losssvm(modyrho);
    elseif strcmp(options.loss,'logit')
        aa = sum(log(1+exp(aa.*yapp)));
        modyrho = yapp.*modrho;
        % Clamp into the open interval (-1,0) where losslogit is real and
        % finite. The upper bound used to be +1e-8, which let positive
        % values through and made log(-modyrho) complex.
        modyrho = min(max(modyrho,-0.9999999),-0.00000001);
        mod_dualobj = - losslogit(modyrho);
    elseif strcmp(options.loss,'square')
        aa = sum((aa + yapp).^2)*0.5;
        mod_dualobj = -losssquare(yapp,modrho);
    end

    if ~isempty(activeset)
        primalobj = aa + C*sum(muwj.*cgamma(activeset));
    else
        primalobj = aa;
    end

    story.primalobj = [story.primalobj primalobj];
    story.mod_dualobj = [story.mod_dualobj mod_dualobj];
    story.len_active = [story.len_active length(activeset)];
    dualitygap = (abs(primalobj - mod_dualobj)/abs(primalobj));
    story.dualitygap = [story.dualitygap dualitygap];

    if options.display>=2
        if ~options.stopIneqViolation
            fprintf('[[%d]] ',l);
        end
        fprintf('primal:%g dual:%g duality_gap:%g\n',primalobj,mod_dualobj,dualitygap);
    end

    if options.stopdualitygap && dualitygap < tol
        break;
    end
end

if options.display>=2
    fprintf('\n');
end

params.cgamma = cgamma;
params.cgammab = cgammab;
params.wj = wj;
params.C = C;
params.loss = options.loss;
if strcmp(options.loss,'svm')
    params.ct = ct;
    params.clambda = clambda;
end

% d(j) is the least-squares coefficient fitting K_j*alpha_j by K_j*rho.
alpha = mu.*(oneN*cgamma');
d = zeros(1,M);
for j=activeset
    Krho = K(:,:,j)*rho;
    d(j) = (K(:,:,j)*alpha(:,j))'*Krho/norm(Krho)^2;
end
dsum = sum(d);
alpha = -rho*dsum;
% With no active kernel dsum is 0; keep d at zeros instead of 0/0 = NaN.
if dsum ~= 0
    d = d/dsum;
end

if strcmp(options.loss,'svm')
    supind = find(clambda{2}==0);
else
    supind = (1:N)';
end
b = - cb;


function val = getOpt(options, name, default)
% GETOPT  Return options.(name) if that field exists, otherwise default.
if isfield(options, name)
    val = options.(name);
else
    val = default;
end


function [loss, yapp] = defaultLoss(yapp)
% DEFAULTLOSS  Choose a loss when options.loss is missing or unknown.
%   Two distinct target values -> 'logit', with the smaller value mapped to
%   -1 and the larger to +1; anything else -> 'square' (yapp unchanged).
yval = unique(yapp);
if length(yval) == 2
    loss = 'logit';
    % Build both masks over the samples before assigning, so relabelling
    % cannot collide (e.g. original labels {-3,-1}).
    isLow = (yapp == yval(1));
    isHigh = (yapp == yval(2));
    yapp(isLow) = -1;
    yapp(isHigh) = 1;
else
    loss = 'square';
end
0495
function grad = gradAug(v,cgamma,activeset,wj,C,grad,sumrho,cgammab,cb)
% GRADAUG  Add the augmented-Lagrangian penalty terms to a loss gradient.
%
%   grad = gradAug(v,cgamma,activeset,wj,C,grad)
%   grad = gradAug(v,cgamma,activeset,wj,C,grad,sumrho,cgammab,cb)
%
%   For every active kernel j this adds v(:,j)*cgamma(j)*(wj(j)-C)/wj(j).
%   When all bias arguments are supplied it also adds the scalar
%   cgammab*sumrho + cb, the derivative of cgammab*sumrho^2/2 + cb*sumrho
%   with respect to each entry of rho (sumrho = sum(rho)).

% Guard kept on purpose: "for j = zeros(0,1)" would run the body once
% with an empty j.
if ~isempty(activeset)
    for j = activeset
        grad = grad + v(:,j)*(cgamma(j)*(wj(j)-C)/wj(j));
    end
end
% The bias term needs sumrho, cgammab AND cb. The old guard (nargin >= 7)
% failed with an undefined cgammab when only sumrho was passed.
if nargin >= 9
    grad = grad + (cgammab*sumrho+cb);
end
0506
function [HessBarrier] = HessSVM(ct,yapp,wvec1,wvec2)
% HESSSVM  Diagonal Hessian of the SVM augmented terms with respect to rho.
%   Entry i is ct{1}(i) when wvec1(i) > 0, plus ct{2}(i) when wvec2(i) > 0,
%   and zero where neither shifted multiplier is positive.
n = length(yapp);
diagVals = zeros(n,1);
on1 = (wvec1 > 0);
on2 = (wvec2 > 0);
diagVals(on1) = ct{1}(on1);
diagVals(on2) = diagVals(on2) + ct{2}(on2);
HessBarrier = diag(diagVals);
function grad = gradSVM(wvec1,wvec2,yapp)
% GRADSVM  Gradient of the SVM augmented terms with respect to rho:
%   yapp.*(1 + wvec2 - wvec1), written so rounding matches term by term.
grad = yapp + yapp.*(wvec2 - wvec1);
function [val,wvec1,wvec2] = funcevalsvm(wj,yrho,sumrho,cgamma,cgammab,cb,clambda,ct,C)
% FUNCEVALSVM  Augmented objective for the 'svm' loss.
%
%   Also returns the shifted multipliers
%     wvec1 = max(0, clambda{1} + ct{1}.*(-1 - yrho))
%     wvec2 = max(0, clambda{2} + ct{2}.*yrho)
%   which appear to enforce -1 <= yrho <= 0 through the augmented
%   Lagrangian (compare the clambda updates in SpicyMKL).
lam1 = clambda{1};  pen1 = ct{1};
lam2 = clambda{2};  pen2 = ct{2};

wvec1 = max(0, lam1 + pen1.*(-1 - yrho));
wvec2 = max(0, lam2 + pen2.*(yrho));

kernelPenalty = (cgamma'*(max(wj-C,0).^2))/2;
boxTerm1 = sum((wvec1.^2 - lam1.^2)./pen1)/2;
boxTerm2 = sum((wvec2.^2 - lam2.^2)./pen2)/2;

val = losssvm(yrho) + kernelPenalty;
val = val + boxTerm1 + boxTerm2;
val = val + cgammab*sumrho^2/2 + cb*sumrho;
function val = losssvm(yrho)
% LOSSSVM  Linear dual term of the 'svm' loss: the sum of yrho. The box
%   -1 <= yrho <= 0 is presumably handled separately via clambda/ct.
total = sum(yrho);
val = total;
0527
0528
function Hessian = HessLogit(yrho)
% HESSLOGIT  Diagonal Hessian of losslogit: entries 1./(-yrho.*(1+yrho)).
%   Equals the Hessian with respect to rho when yapp = +/-1 (yapp.^2 = 1).
denom = -yrho.*(1+yrho);
Hessian = diag(1./denom);
function grad = gradLogit(yrho,yapp)
% GRADLOGIT  Gradient of losslogit(yapp.*rho) with respect to rho:
%   yapp.*log((1+yrho)./(-yrho)).
ratio = (1+yrho)./(-yrho);
grad = yapp.*log(ratio);
function [val] = funcevallogit(wj,yrho,cgamma,C, sumrho,cgammab,cb)
% FUNCEVALLOGIT  Augmented objective for the 'logit' loss.
%
%   val = funcevallogit(wj,yrho,cgamma,C)
%   val = funcevallogit(wj,yrho,cgamma,C,sumrho,cgammab,cb)
%
%   val = losslogit(yrho) + cgamma'*max(wj-C,0).^2/2, plus
%   cgammab*sumrho^2/2 + cb*sumrho when the bias arguments are given.
%   Returns Inf when any entry of yrho lies outside the open interval
%   (-1,0). At the endpoints losslogit evaluates 0*log(0) = NaN, and a NaN
%   objective would pass the caller's line-search test (NaN comparisons
%   are false). This matches the feasibility check in SpicyMKL.
if any(yrho >= 0 | yrho <= -1)
    val = inf;
    return;
end

val = losslogit(yrho) + (cgamma'*(max(wj-C,0).^2))/2;

% The bias term uses sumrho, cgammab and cb (arguments 5-7). The old
% guard nargin >= 5 failed when only sumrho was passed.
if nargin >= 7
    val = val + cgammab*sumrho^2/2 + cb*sumrho;
end
function val = losslogit(yrho)
% LOSSLOGIT  Dual (conjugate) term of the logistic loss:
%   sum((1+yrho).*log(1+yrho) - yrho.*log(-yrho)), real for -1 < yrho < 0.
onePlus = 1+yrho;
terms = onePlus.*log(onePlus) - yrho.*log(-yrho);
val = sum(terms);
0550
function [val] = funcevalsquare(wj,yapp,rho,cgamma,C,sumrho,cgammab,cb)
% FUNCEVALSQUARE  Augmented objective for the 'square' loss.
%
%   val = funcevalsquare(wj,yapp,rho,cgamma,C)
%   val = funcevalsquare(wj,yapp,rho,cgamma,C,sumrho,cgammab,cb)
%
%   val = losssquare(yapp,rho) + cgamma'*max(wj-C,0).^2/2, plus
%   cgammab*sumrho^2/2 + cb*sumrho when the bias arguments are given.
val = losssquare(yapp,rho) + (cgamma'*(max(wj-C,0).^2))/2;

% The bias term uses sumrho, cgammab and cb (arguments 6-8). The old
% guard nargin >= 6 failed when only sumrho was passed.
if nargin >= 8
    val = val + cgammab*sumrho^2/2 + cb*sumrho;
end
function val = losssquare(yapp,rho)
% LOSSSQUARE  Dual term of the squared loss: 0.5*(rho'*rho + 2*rho'*yapp).
quadTerm = rho'*rho;
linTerm = 2*rho'*yapp;
val = 0.5*(quadTerm + linTerm);
function Hessian = HessSquare(N)
% HESSSQUARE  Hessian of losssquare with respect to rho: the N x N identity.
Hessian = eye(N, N);
function grad = gradSquare(rho,yapp)
% GRADSQUARE  Gradient of losssquare with respect to rho: rho + yapp.
grad = yapp + rho;
0567