function [W, bias, z, status]=lrds_dual(X, Y, lambda, varargin)
% LRDS_DUAL  Log-barrier Newton solver that iterates on dual variables alpha.
%
%   [W, bias, z, status] = lrds_dual(X, Y, lambda, 'name', value, ...)
%   [W, bias, z, status] = lrds_dual(X, Y, lambda, opt)
%
% Inputs
%   X       Samples, either a [C,C,n] array (one square matrix per sample)
%           or a [C*C,n] matrix whose columns are the vectorised samples.
%   Y       Labels, one per sample (length n). Only the sign of Y is used
%           to build the default starting point; presumably +1/-1
%           (NOTE(review): not validated here).
%   lambda  Scalar bound: every iterate keeps lambda*I-A and lambda*I+A
%           Cholesky-factorisable, where A = sym(sum_i alpha_i*Y_i*X_i).
%
% Options (name/value list parsed by propertylist2struct, or a struct
% passed as the 4th argument when that parser fails)
%   tol      Stop once the duality-gap estimate drops below this (1e-6).
%   tolX     Relative step-size tolerance used in the inner stop test (1e-6).
%   tmul     Factor applied to the barrier parameter t after each inner
%            solve (20).
%   maxiter  Upper bound on the total number of Newton steps (1000).
%   display  'none','final','iter','every','all' or the numeric level
%            0..4 ('iter').
%   alpha    Optional starting point (must be strictly feasible).
%   t        Optional initial barrier parameter.
%
% Outputs
%   W       [C,C] symmetric matrix recovered from the last iterate.
%   bias    Scalar; the multiplier nu of the constraint Y'*delta = 0.
%   z       [n,1] vector Y.*(W(:)'*Xf + bias)' - beta1 + beta2.
%   status  Struct with fields opt, niter, t, gap, obj, beta1, beta2,
%           alpha and time (CPU seconds).

% --- option handling ------------------------------------------------------
try
    opt = propertylist2struct(varargin{:});
catch
    % Fall back to treating the first extra argument as an options struct.
    if nargin>3
        opt = varargin{1};
    else
        opt = [];
    end
end

opt = setDefaults(opt, struct('tol',     1e-6, ...
                              'tolX',    1e-6, ...
                              'tmul',    20,   ...
                              'maxiter', 1000, ...
                              'display', 'iter'));

% Map a display name onto its numeric level 0..4.
if ~isnumeric(opt.display)
    opt.display = find(strcmp(opt.display,...
                              {'none','final','iter','every','all'}))-1;
end

% --- input shapes: X as [C,C,n], Xf as [C*C,n] ----------------------------
if ndims(X)==3 && size(X,1)==size(X,2)
    C  = size(X,1);
    n  = size(X,3);
    Xf = reshape(X, [C*C, n]);
elseif ndims(X)==3
    % A 3-D array with non-square slices cannot be reshaped to [C,C,n].
    error('lrds_dual:badInput', ...
          'A 3-D X must be [C,C,n] with square slices (got [%d,%d,%d]).', ...
          size(X,1), size(X,2), size(X,3));
else
    [CC,n] = size(X);
    C = sqrt(CC);
    if C~=floor(C)
        error('lrds_dual:badInput', ...
              'A 2-D X must have C*C rows for an integer C (got %d rows).', CC);
    end
    Xf = X;
    X  = reshape(X,[C,C,n]);
end

if n~=length(Y)
    error('Sample size mismatch!');
end

% Make Y a column (drops leading singleton dimensions).
Y=shiftdim(Y);

% --- starting point --------------------------------------------------------
if isfield(opt,'alpha')
    % Force a column so that alpha.*Y stays element-wise.
    alpha = opt.alpha(:);
else
    % Same total mass on each class, so that Y'*alpha = 0 for +1/-1 labels.
    alpha = zeros(n,1);
    alpha(Y>0) = min(lambda*0.01,1)/sum(Y>0);
    alpha(Y<0) = min(lambda*0.01,1)/sum(Y<0);
end

cc = 0;                                  % total Newton steps taken
display_line_search = opt.display==4;

if isfield(opt,'t')
    t = opt.t;
else
    t = 2*(C+n)/(n*log(2));
end

cc0    = 0;                              % step count at the last t update
time0  = cputime;                        % start of the current inner solve
time00 = time0;                          % start of the whole run

% --- outer loop: increase t; inner loop: Newton steps for fixed t ----------
while cc<opt.maxiter
    while cc<opt.maxiter
        cc = cc + 1;

        A = reshape(Xf*(alpha.*Y), [C,C]); A=(A+A')/2;

        % chol errors out if the iterate left the feasible region.
        R1 = chol(lambda*eye(C)-A);
        R2 = chol(lambda*eye(C)+A);

        % Gradient / diagonal Hessian of the entropy term.
        [loss, gl, hl]=lossDual(alpha);

        % Contributions of the two matrix barriers: column i holds the
        % vectorised R'\(Y_i*X_i)/R, gS(i) the difference of their traces.
        SX1 = zeros(C*C,n);
        SX2 = zeros(C*C,n);
        gS  = zeros(n,1);
        for i=1:n
            D1 = R1'\(Y(i)*X(:,:,i))/R1;
            D2 = R2'\(Y(i)*X(:,:,i))/R2;
            SX1(:,i)=reshape(D1, [C*C,1]);
            SX2(:,i)=reshape(D2, [C*C,1]);
            gS(i) = sum(diag(D1-D2));
        end

        H1 = SX1'*SX1;
        H2 = SX2'*SX2;

        % Gradient of the barrier objective.
        g = gl+((2*alpha-1)./(alpha.*(1-alpha)) + gS)/t;

        % Hessian: diagonal part plus the dense barrier part.
        Hd = hl + (alpha.^(-2)+(1-alpha).^(-2))/t;
        Hr = (H1+H2)/t;
        H  = diag(Hd)+Hr;

        HIg = H\g;
        HIy = H\Y;

        % Multiplier chosen so that the step satisfies Y'*delta = 0.
        nu = - (Y'*HIg)/(Y'*HIy);

        % Newton direction.
        delta = -H\(g+Y*nu);

        alpha0 = alpha;

        % Eigenvalues describing how the two matrix barriers change along
        % delta; the line search needs 1+s*Sd1 > 0 and 1+s*Sd2 > 0.
        Sd0 = reshape(Xf*(delta.*Y), [C,C]); Sd0=(Sd0+Sd0')/2;
        Sd1 = -eig(R1'\Sd0/R1);
        Sd2 =  eig(R2'\Sd0/R2);

        [s, dloss] = lineSearch(alpha, delta, t, Sd1, Sd2, Y*nu, opt.tolX/max(abs(delta)./abs(alpha)), display_line_search);

        alpha = alpha0 + s*delta;

        A = reshape(Xf*(alpha.*Y), [C,C]); A=(A+A')/2;

        % RR'*RR = lambda^2*I - A*A'
        RR = chol(lambda^2*eye(C)-A*A');

        % W = 2*inv(lambda^2*I - A*A')*A/t, symmetrised.
        W = 2*(RR\((RR')\A))/t;
        W = (W+W')/2;
        trQ = 2*lambda*trace(((RR')\eye(C))/RR)/t;

        beta1= 1./(t*alpha);
        beta2= 1./(t*(1-alpha));

        bias = nu;

        z=Y.*(reshape(W,[1,C*C])*Xf+bias)'-beta1+beta2;

        % Primal value built from the current iterate.
        loss_prim = lossPrime(z)+sum(beta2)+lambda*trQ;

        % Dual value is -loss.
        loss=lossDual(alpha);

        gap(cc) = loss_prim - (-loss);

        % Barrier objective; -2*sum(log(diag(RR))) = -log det(lambda^2*I-A*A').
        obj(cc) = loss +1/t*(-2*sum(log(diag(RR)))...
                             -sum(log(alpha))-sum(log(1-alpha)));

        % Largest gradient component at the start of this step.
        gg(cc) = max(abs(g+Y*nu));

        if opt.display>=3
            fprintf('[%d] t=%g gap=%g(>=%g) gg=%g y*alpha=%g nu=%g Hmin=%g s=%g obj=%g',...
                    cc, t, gap(cc), 2*(C+n)/t, gg(cc), Y'*alpha, nu, min(eig(H)), s, obj(cc));

            if cc>1
                fprintf(' dloss=(%g/%g)\n', obj(cc)-obj(cc-1), dloss);
            else
                fprintf('\n');
            end
        end

        % Inner stop test: small gradient, or (moderately small gradient
        % or small gap) together with a small relative step.
        if gg(cc)<opt.tol || ((gg(cc)<opt.tol*min(100,gap(cc)/opt.tol) || gap(cc)<opt.tol) && ...
                              s*max(abs(delta)./abs(alpha))<opt.tolX)

            tlap = cputime-time0;

            if opt.display>=2
                fprintf('t=%g: gap=%g gg=%g nsteps=%d time=%g\n',...
                        t,gap(cc),gg(cc),cc-cc0,tlap);
            end
            cc0 = cc;
            time0 = cputime;

            break;
        end
    end

    if gap(cc)<opt.tol
        break;
    else
        t = t*opt.tmul;
    end
end

status = struct('opt',opt,...
                'niter',cc,...
                't',t,...
                'gap',gap,...
                'obj',obj,...
                'beta1',beta1,...
                'beta2',beta2,...
                'alpha', alpha,...
                'time', cputime-time00);

if opt.display>0
    fprintf('[%d] gap=%g total time=%g\n', cc, gap(end),cputime-time00);
end


function loss = lossPrime(z)
% LOSSPRIME  Summed logistic loss sum(log(1+exp(-z))), evaluated stably.
%
%   z     Array of margins.
%   loss  Sum of log(1+exp(-z)) over the entries of z.
%
% Negative entries use log(1+exp(z)) - z so that exp never overflows.
% The result is cross-checked against the naive formula whenever the naive
% value did not overflow to Inf; a mismatch raises an error.
z1 = z(z<0);
z2 = z(z>=0);
loss = sum(log(exp(z1)+1))-sum(z1)+sum(log(1+exp(-z2)));

loss0=sum(log(1+exp(-z)));

% Tolerance is relative for large sums (accumulated rounding grows with the
% magnitude of the loss) and stays at 1e-9 absolute for small ones.
if ~isinf(loss0) & abs(loss-loss0)>1e-9*max(1,abs(loss0))
    error('lrds_dual:lossPrime:inconsistent', ...
          'Stable and naive logistic-loss evaluations disagree (%g vs %g).', ...
          loss, loss0);
end


function [loss, g, h] = lossDual(alpha)
% LOSSDUAL  Entropy-type term sum(a.*log(a) + (1-a).*log(1-a)) and its
% element-wise first and second derivatives.
%
%   alpha  Array of values; entries equal to 0 or 1 are treated as
%          boundary points.
%   loss   Sum of the per-entry terms (boundary entries contribute 0).
%   g      Per-entry derivative log(a./(1-a)); NaN on boundary entries.
%   h      Per-entry second derivative 1./a + 1./(1-a); NaN on boundary
%          entries.

% Entries that are neither 0 nor 1.
interior = ~(alpha==0 | alpha==1);

a = alpha(interior);
b = 1 - a;

% Per-entry loss; boundary entries keep their initial 0.
terms = zeros(size(alpha));
terms(interior) = a.*log(a) + b.*log(b);

% Derivatives start as NaN everywhere, then the interior is filled in.
g = zeros(size(alpha));
g(:) = nan;
g(interior) = log(a./b);

h = zeros(size(alpha));
h(:) = nan;
h(interior) = 1./a + 1./b;

loss = sum(terms);


function [s,dloss] = lineSearch(alpha,delta,t,Sd1,Sd2,gnu,tolX,display);
% LINESEARCH  Bracketing search for a step size along delta.
%
%   alpha    Current point; trial points are alpha + s*delta and must stay
%            strictly inside (0,1).
%   delta    Search direction.
%   t        Barrier parameter (barrier terms are scaled by 1/t).
%   Sd1,Sd2  Vectors for which 1+s*Sd1 and 1+s*Sd2 must stay positive.
%   gnu      Vector whose inner product with delta, times s, is added to
%            the predicted change.
%   tolX     Step-size threshold used by the second stop test (see NOTE).
%   display  When true, print one trace line per trial.
%
%   s        The last trial step (feasible by construction).
%   dloss    Predicted objective change at that step.
%
% The search keeps a bracket [lo, hi] around the best step seen so far.
% While no upper end is known (hi is NaN) the trial step is doubled;
% afterwards a randomly perturbed midpoint of the bracket is tried.

stepNext = 1;
lo       = 0;
hi       = nan;
bestStep = 0;
bestDrop = 0;
start    = alpha;
lossAtStart = lossDual(start);
nTrial   = 1;

while 1
    s = stepNext;
    nTrial = nTrial + 1;

    if display
        if bestStep < s
            fprintf(' %02d: s1=%.2f * s=%.2f s2=%.2f',nTrial,lo,s,hi);
        elseif bestStep > s
            fprintf(' %02d: s1=%.2f s=%.2f * s2=%.2f',nTrial,lo,s,hi);
        else
            fprintf(' %02d: _s1=%.2f s=%.2f s2=%.2f',nTrial,lo,s,hi);
        end
    end

    lm1 = 1+s*Sd1;
    lm2 = 1+s*Sd2;
    alpha = start + s*delta;

    feasible = ~(any(lm1<=0) | any(lm2<=0) | any(alpha<=0) | any(alpha>=1));

    if ~feasible
        % Outside the domain: shrink the upper end, but never cut the
        % trial step by more than half.
        hi = min(s, hi);
        stepNext = max((lo+hi)/2, s/2);
        if display
            fprintf(' dloss_best=%g (%s)\n',bestDrop, '!');
        end
        continue;
    end

    % Predicted change: entropy term, barrier terms, linear term.
    entropyChange = lossDual(alpha)-lossAtStart;
    barrierChange = 1/t*(-sum(log(lm1))-sum(log(lm2))...
                         -sum(log(1+s*delta./start))-sum(log(1-s*delta./(1-start))));
    dloss = entropyChange + barrierChange + s*gnu'*delta;

    if dloss < bestDrop
        % New best: the previous best becomes one end of the bracket.
        tag = '-';
        bestDrop = dloss;
        if bestStep < s
            lo = bestStep;
        elseif s < bestStep
            hi = bestStep;
        end
        bestStep = s;
    else
        % No improvement: this trial becomes one end of the bracket.
        tag = '+';
        if bestStep < s
            hi = s;
        elseif s < bestStep
            lo = s;
        end
    end

    if isnan(hi)
        stepNext = s*2;
    else
        r = 0.5 + rand(1)*0.1-0.05;
        stepNext = lo*r+hi*(1-r);
    end

    if display
        fprintf(' dloss_best=%g (%s)\n',bestDrop, tag);
    end

    % Stop once the bracket is narrow.
    if hi-lo < 0.01
        break;
    end
    % NOTE(review): lo starts at 0 and is only ever assigned finite values,
    % so isnan(lo) looks unreachable and tolX has no effect; confirm
    % whether isnan(hi) was intended.
    if isnan(lo) && s < tolX
        break;
    end
end


function obj = objectiveLocal(alpha0, delta, x, Xf, Y, lambda, t)
% OBJECTIVELOCAL  Barrier objective evaluated at alpha0 + delta*x(:,i).
%
%   alpha0  Base point.
%   delta   Direction.
%   x       Step values; one objective value is produced per column, so
%           x(:,i) is expected to be a scalar multiplier of delta
%           (NOTE(review): confirm against callers; none are visible in
%           this file).
%   Xf      Matrix with C*C rows, one vectorised sample per column.
%   Y       Labels.
%   lambda  Bound used in lambda*I -/+ A.
%   t       Barrier parameter.
%
%   obj     Objective values shaped like x with its first dimension
%           removed; NaN where the point is infeasible (alpha outside
%           [0,1] or a negative eigenvalue of lambda*I -/+ A).

sz = size(x); sz(1)=1;

obj = squeeze(zeros(sz));

C = sqrt(size(Xf,1));

for i=1:numel(obj)
    alpha = alpha0 + delta*x(:,i);
    A = reshape(Xf*(alpha.*Y), [C,C]); A=(A+A')/2;

    % One eigen-decomposition per matrix, reused below.
    ev1 = eig(lambda*eye(C)-A);
    ev2 = eig(lambda*eye(C)+A);

    % Test feasibility before taking logarithms so that no complex
    % intermediate value is ever written into obj.
    if any(alpha<0) || any(alpha>1) || any(ev1<0) || any(ev2<0)
        obj(i) = nan;
    else
        obj(i) = lossDual(alpha) +1/t*(-sum(log(ev1))-sum(log(ev2))-sum(log(alpha))-sum(log(1-alpha)));
    end
end
