Home > dal > objdalgl.m

objdalgl

PURPOSE ^

objdalgl - objective function of DAL with grouped L1 regularization

SYNOPSIS ^

function varargout=objdalgl(aa, info, prob, ww, uu, A, B, lambda, eta)

DESCRIPTION ^

 objdalgl - objective function of DAL with grouped L1 regularization

 Copyright(c) 2009 Ryota Tomioka
 This software is distributed under the MIT license. See license.txt

CROSS-REFERENCE INFORMATION ^

This function calls:
This function is called by:

SOURCE CODE ^

% objdalgl - objective function of DAL with grouped L1 regularization
%
% Syntax:
%  [fval, info]           = objdalgl(aa, info, prob, ww, uu, A, B, lambda, eta)
%  [fval, gg, info]       = objdalgl(aa, info, prob, ww, uu, A, B, lambda, eta)
%  [fval, gg, hess, info] = objdalgl(aa, info, prob, ww, uu, A, B, lambda, eta)
%
% Evaluates
%   fval = floss + eta/2*sum(ss.^2)  [+ eta/2*||uu/eta + B'*aa||^2 if uu is non-empty]
% where floss is returned by prob.floss.d at aa and ss is the second output
% of the group soft-thresholding prob.softth applied to vv = ww/eta + A'*aa.
%
% Inputs:
%  aa     : point at which the function is evaluated; passed to
%           prob.floss.d and multiplied by A' (and B'). Presumably the
%           dual variable of DAL -- confirm against the caller.
%  info   : struct. info.blks holds the group sizes (the entries of ww are
%           grouped consecutively, group j occupying
%           sum(info.blks(1:j-1))+(1:info.blks(j))). info.solver ('cg',
%           'nt' or 'ntsv') selects the third output when four outputs
%           are requested.
%  prob   : struct with fields
%             softth     - group soft-thresholding, invoked as
%                          [vsth,ss]=fevals(prob.softth, vv, lambda, info)
%             floss.d    - loss function handle, called as
%                          [floss,gloss,hmin]=floss.d(aa, floss.args{:}) or
%                          [floss,gloss,hloss,hmin]=floss.d(aa, floss.args{:})
%             floss.args - cell array of extra arguments for floss.d
%  ww     : grouped coefficient vector (same size as A'*aa)
%  uu     : additional coefficients entering through B (same size as
%           B'*aa); pass [] if there are none
%  A, B   : matrices coupling aa with ww and uu respectively
%  lambda : regularization constant, passed to prob.softth
%  eta    : scalar weighting the quadratic terms (presumably the
%           augmented-Lagrangian penalty parameter -- confirm)
%
% Outputs:
%  fval : objective value
%  gg   : gloss + eta*A*vsth [+ eta*B*(uu/eta+B'*aa)], the gradient of fval
%         with respect to aa provided gloss is the gradient of floss
%  hess : info.solver='nt' or 'ntsv': the matrix
%           hloss + eta*sum_g F_g*F_g' [+ eta*B*B'] over active groups g,
%           with F_g as computed by grouphessfactor below;
%         info.solver='cg': a struct with fields blks, hloss, I, vv, nm,
%           prec, lambda, B for use by the caller;
%         any other solver: []
%  info : the input info; when gg is computed, info.ginfo is set to
%         norm(gg)/sqrt(eta*hmin*soc) (inf if soc==0), where
%         soc = ||vsth-ww/eta||^2 [+ ||B'*aa||^2]
%
% Copyright(c) 2009 Ryota Tomioka
% This software is distributed under the MIT license. See license.txt

function varargout=objdalgl(aa, info, prob, ww, uu, A, B, lambda, eta)

% Group soft-thresholding of vv; groups with ss>0 are the active ones.
vv        = ww/eta+A'*aa;
[vsth,ss] = fevals(prob.softth, vv, lambda, info);
% Intended to be the norm of each group of vv; only correct for the
% active groups (ss>0).
nm = ss+lambda;
I  = find(ss>0);

% The loss Hessian hloss is requested only when the caller asks for the
% fourth output.
if nargout<=3
  [floss, gloss, hmin]=feval(prob.floss.d, aa, prob.floss.args{:});
else
  [floss, gloss, hloss, hmin]=feval(prob.floss.d, aa, prob.floss.args{:});
end

fval = floss+0.5*eta*sum(ss.^2);
if ~isempty(uu)
  u1   = uu/eta+B'*aa;
  fval = fval+0.5*eta*sum(u1.^2);
end
varargout{1} = fval;

if nargout<=2
  varargout{2} = info;
  return;
end

% Gradient with respect to aa, and soc used for the ratio info.ginfo.
gg  = gloss+eta*A*vsth;
soc = sum((vsth-ww/eta).^2);
if ~isempty(uu)
  gg  = gg+eta*B*u1;
  soc = soc+sum((B'*aa).^2);
end

if soc>0
  info.ginfo = norm(gg)/(sqrt(eta*hmin*soc));
else
  info.ginfo = inf;
end
varargout{2} = gg;

if nargout==3
  varargout{3} = info;
  return;
end

switch(info.solver)
 case 'cg'
  % The Hessian is not formed explicitly; return its ingredients instead.
  varargout{3} = struct('blks',info.blks,'hloss',hloss,...
                        'I',I,'vv',vv,'nm',nm,'prec',hloss,'lambda',lambda,'B',B);
 case 'nt'
  % Stack the factors of all active groups side by side in AF so that
  % sum_g F_g*F_g' is obtained from the single product AF*AF'.
  H = hloss;
  if ~isempty(I)
    offs = cumsum([0; info.blks(:)]); % offs(j) = number of entries before group j
    AF   = zeros(size(gloss,1), sum(info.blks(I)));
    ix0  = 0;
    for kk=1:length(I)
      jj   = I(kk);
      J    = offs(jj)+(1:info.blks(jj));
      Iout = ix0+(1:info.blks(jj));
      ix0  = ix0+info.blks(jj);
      AF(:,Iout) = grouphessfactor(A(:,J), vv(J), nm(jj), lambda);
    end
    H = H+eta*AF*AF';
  end
  if ~isempty(uu)
    H = H+eta*B*B';
  end
  varargout{3} = H;
 case 'ntsv'
  % Same matrix as 'nt', accumulated one group at a time so that only a
  % single group's factor is held in memory.
  H    = hloss;
  offs = cumsum([0; info.blks(:)]); % offs(j) = number of entries before group j
  for kk=1:length(I)
    jj = I(kk);
    J  = offs(jj)+(1:info.blks(jj));
    % vv = ww/eta+A'*aa is a vector: index its entries (vv(J)), not its
    % columns.
    AF = grouphessfactor(A(:,J), vv(J), nm(jj), lambda);
    H  = H+eta*AF*AF';
  end
  if ~isempty(uu)
    H = H+eta*B*B';
  end
  varargout{3} = H;
 otherwise
  % Unknown solver: no Hessian information is produced.
  varargout{3} = [];
end
varargout{4} = info;


function F=grouphessfactor(AJ, vJ, nmj, lambda)
% grouphessfactor - Hessian factor of one active group.
%  Returns F = ff*AJ + (1-ff)*(AJ*vn)*vn' with vn = vJ/nmj and
%  ff = sqrt(1-lambda/nmj); the group contributes eta*F*F' to the Hessian.
vn = vJ/nmj;
ff = sqrt(1-lambda/nmj);
F  = ff*AJ+(1-ff)*(AJ*vn)*vn';

Generated on Sat 22-Aug-2009 22:15:36 by m2html © 2003