
dal

PURPOSE

dal - dual augmented Lagrangian method for sparse learning/reconstruction

SYNOPSIS

function [xx, uu, status]=dal(prob, ww0, uu0, A, B, lambda, varargin)

DESCRIPTION

 dal - dual augmented Lagrangian method for sparse learning/reconstruction

 Overview:
  Solves the optimization problem
   xx = argmin f(x) + lambda*c(x)
  where f is a user-specified convex, smooth loss function and c
  is a sparsity-inducing regularizer (currently L1 or grouped L1).
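As a concrete, hypothetical instance of this objective, the Python sketch below assumes a squared loss f(z) = 0.5*||z - y||^2 evaluated at z = A*x, with a made-up target vector y, and shows both regularizers. It is only an illustration of the formula, not DAL's loss interface (DAL receives its loss through prob.floss).

```python
# Sketch: evaluate f(A x) + lambda * c(x).
# Assumption: squared loss f(z) = 0.5 * ||z - y||^2 with a hypothetical target y.
import math

def matvec(A, x):
    """Return A @ x for A given as a list of rows."""
    return [sum(a * v for a, v in zip(row, x)) for row in A]

def l1(x):
    """L1 regularizer: sum of absolute values."""
    return sum(abs(v) for v in x)

def grouped_l1(x, groups):
    """Grouped L1 regularizer: sum of the Euclidean norms of the index groups."""
    return sum(math.sqrt(sum(x[j] ** 2 for j in g)) for g in groups)

def objective(A, y, x, lam, reg=l1):
    """Squared loss on A x plus lam times the regularizer reg(x)."""
    z = matvec(A, x)
    loss = 0.5 * sum((zi - yi) ** 2 for zi, yi in zip(z, y))
    return loss + lam * reg(x)

A = [[1.0, 0.0], [0.0, 1.0]]
y = [1.0, 2.0]
print(objective(A, y, [1.0, 0.0], 0.5))  # 0.5*(0 + 4) + 0.5*1 = 2.5
```

Passing reg=lambda x: grouped_l1(x, groups) switches the sketch to the grouped L1 case.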

 Syntax:
  [ww, uu, status] = dal(prob, ww0, uu0, A, B, lambda, <opt>)

 Inputs:
  prob   : structure that contains the following fields:
   .obj      : DAL inner objective, minimized with respect to the
               dual variable aa
   .floss    : structure with three fields (p: primal loss, d: dual loss, args: arguments to the loss functions)
   .fspec    : function handle to the regularizer spectrum function
               (absolute values for L1, vector of norms for grouped L1, etc.)
   .dnorm    : function handle to the dual norm of the regularizer
               (max(abs(x)) for L1, max(norms) for grouped L1, etc.)
   .softth   : function handle to the soft-threshold function
               (the proximal operator of the regularizer)
   .mm       : number of samples (scalar)
   .nn       : number of unknown variables (scalar)
   .ll       : lower bound on the Lagrangian multipliers (mm x 1, default -inf)
   .uu       : upper bound on the Lagrangian multipliers (mm x 1, default +inf)
   .Ac       : inequality constraint Ac*aa<=bc on the multipliers (pp x mm, default [])
   .bc       : right-hand side of the inequality constraint (pp x 1, default [])
   .info     : auxiliary variables for the objective function
   .stopcond : function handle for the stopping condition, called as
               [ret,fval,spec,res]=stopcond(ww,uu,aa,tol,prob,A,B,lambda)
   .hessMult : function handle to the Hessian-vector product (H*x),
               used by the 'cg' solver
  ww0    : initial solution ((nn x 1) or (ns x nc) with ns*nc=nn)
  uu0    : initial unregularized component (nu x 1)
  A      : design matrix (mm x nn)
  B      : design matrix for the unregularized component (mm x nu)
  lambda : regularization constant (scalar)
  <opt>  : list of 'fieldname1', value1, 'fieldname2', value2, ...
   aa        : initial Lagrangian multiplier (mm x 1) (default zeros(mm,1))
   tol       : tolerance (default 1e-3)
   maxiter   : maximum number of outer iterations (default 100)
   eta       : initial barrier parameter (default 1)
   eps       : initial internal tolerance parameter (default 1)
   eta_multp : factor applied to eta after each outer iteration whose
               inner solve succeeded (default 2)
   eps_multp : factor applied to eps under the same condition (default 0.99)
   solver    : internal solver; one of:
               'nt'   : Newton method with Cholesky factorization (default)
               'ntsv' : memory-saving Newton method (slightly slower)
               'cg'   : Newton method with preconditioned conjugate gradient (PCG)
               'qn'   : quasi-Newton (L-BFGS) method
   display   : display level (0: none, 1: last iteration only,
               2: every outer iteration (default), 3: every inner iteration)
   iter      : if true, the first output holds the iterates [ww(:); uu(:)]
               as its columns (boolean, default 0)
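The regularizer-specific handles .fspec, .dnorm and .softth can be illustrated for L1 and grouped L1 with the following Python sketch. The function names are hypothetical, representing groups as lists of indices is an assumption, and the real DAL handles are MATLAB functions (the listing calls the threshold as fevals(prob.softth, v, eta*lambda, info), with an extra info argument omitted here).

```python
# Hypothetical Python stand-ins for prob.fspec, prob.dnorm and prob.softth.
# Assumption: groups are given as lists of indices into the vector.
import math

def l1_fspec(x):
    """Spectrum for L1: elementwise absolute values."""
    return [abs(v) for v in x]

def l1_dnorm(x):
    """Dual norm of L1: the largest absolute value."""
    return max(abs(v) for v in x)

def l1_softth(v, t):
    """Elementwise soft threshold: shrink each entry toward zero by t."""
    return [vi - t if vi > t else vi + t if vi < -t else 0.0 for vi in v]

def gl_fspec(x, groups):
    """Spectrum for grouped L1: Euclidean norm of each group."""
    return [math.sqrt(sum(x[j] ** 2 for j in g)) for g in groups]

def gl_dnorm(x, groups):
    """Dual norm of grouped L1: the largest group norm."""
    return max(gl_fspec(x, groups))

def gl_softth(v, t, groups):
    """Block soft threshold: scale each group by max(0, 1 - t/||v_g||)."""
    out = list(v)
    for g, n in zip(groups, gl_fspec(v, groups)):
        scale = max(0.0, 1.0 - t / n) if n > 0 else 0.0
        for j in g:
            out[j] = scale * v[j]
    return out

print(l1_softth([3.0, -0.5, -2.0], 1.0))  # [2.0, 0.0, -1.0]
print(gl_softth([3.0, 4.0, 0.1], 1.0, [[0, 1], [2]]))
```

In the second call the first group (norm 5) is shrunk by the factor 0.8, while the second group (norm 0.1, below the threshold) is set to zero.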
 Outputs:
  ww     : the final solution (a matrix of iterates if iter is true)
  uu     : the final unregularized component
  status : structure with fields aa, niter, eta, pred, time, res, opt, info, fval

 Reference:
 "Dual Augmented Lagrangian Method for Efficient Sparse Reconstruction"
 Ryota Tomioka and Masashi Sugiyama
 http://arxiv.org/abs/0904.0584
 
 Copyright (c) 2009 Ryota Tomioka
 This software is distributed under the MIT license. See license.txt

SOURCE CODE

% dal - dual augmented Lagrangian method for sparse learning/reconstruction
%
% Overview:
%  Solves the optimization problem
%   xx = argmin f(x) + lambda*c(x)
%  where f is a user-specified convex, smooth loss function and c
%  is a sparsity-inducing regularizer (currently L1 or grouped L1).
%
% Syntax:
%  [ww, uu, status] = dal(prob, ww0, uu0, A, B, lambda, <opt>)
%
% Inputs:
%  prob   : structure that contains the following fields:
%   .obj      : DAL inner objective, minimized with respect to the
%               dual variable aa
%   .floss    : structure with three fields (p: primal loss, d: dual loss, args: arguments to the loss functions)
%   .fspec    : function handle to the regularizer spectrum function
%               (absolute values for L1, vector of norms for grouped L1, etc.)
%   .dnorm    : function handle to the dual norm of the regularizer
%               (max(abs(x)) for L1, max(norms) for grouped L1, etc.)
%   .softth   : function handle to the soft-threshold function
%               (the proximal operator of the regularizer)
%   .mm       : number of samples (scalar)
%   .nn       : number of unknown variables (scalar)
%   .ll       : lower bound on the Lagrangian multipliers (mm x 1, default -inf)
%   .uu       : upper bound on the Lagrangian multipliers (mm x 1, default +inf)
%   .Ac       : inequality constraint Ac*aa<=bc on the multipliers (pp x mm, default [])
%   .bc       : right-hand side of the inequality constraint (pp x 1, default [])
%   .info     : auxiliary variables for the objective function
%   .stopcond : function handle for the stopping condition, called as
%               [ret,fval,spec,res]=stopcond(ww,uu,aa,tol,prob,A,B,lambda)
%   .hessMult : function handle to the Hessian-vector product (H*x),
%               used by the 'cg' solver
%  ww0    : initial solution ((nn x 1) or (ns x nc) with ns*nc=nn)
%  uu0    : initial unregularized component (nu x 1)
%  A      : design matrix (mm x nn)
%  B      : design matrix for the unregularized component (mm x nu)
%  lambda : regularization constant (scalar)
%  <opt>  : list of 'fieldname1', value1, 'fieldname2', value2, ...
%   aa        : initial Lagrangian multiplier (mm x 1) (default zeros(mm,1))
%   tol       : tolerance (default 1e-3)
%   maxiter   : maximum number of outer iterations (default 100)
%   eta       : initial barrier parameter (default 1)
%   eps       : initial internal tolerance parameter (default 1)
%   eta_multp : factor applied to eta after each outer iteration whose
%               inner solve succeeded (default 2)
%   eps_multp : factor applied to eps under the same condition (default 0.99)
%   solver    : internal solver; one of:
%               'nt'   : Newton method with Cholesky factorization (default)
%               'ntsv' : memory-saving Newton method (slightly slower)
%               'cg'   : Newton method with preconditioned conjugate gradient (PCG)
%               'qn'   : quasi-Newton (L-BFGS) method
%   display   : display level (0: none, 1: last iteration only,
%               2: every outer iteration (default), 3: every inner iteration)
%   iter      : if true, the first output holds the iterates [ww(:); uu(:)]
%               as its columns (boolean, default 0)
% Outputs:
%  ww     : the final solution (a matrix of iterates if iter is true)
%  uu     : the final unregularized component
%  status : structure with fields aa, niter, eta, pred, time, res, opt, info, fval
%
% Reference:
% "Dual Augmented Lagrangian Method for Efficient Sparse Reconstruction"
% Ryota Tomioka and Masashi Sugiyama
% http://arxiv.org/abs/0904.0584
%
% Copyright (c) 2009 Ryota Tomioka
% This software is distributed under the MIT license. See license.txt

function [xx, uu, status]=dal(prob, ww0, uu0, A, B, lambda, varargin)

opt=propertylist2struct(varargin{:});
opt=set_defaults(opt, 'aa', [],...
                      'tol', 1e-3, ...
                      'iter', 0, ...
                      'maxiter', 100,...
                      'eta', 1,...
                      'eps', 1, ...
                      'eps_multp', 0.99,...
                      'eta_multp', 2, ...
                      'solver', 'nt', ...
                      'display',2);


prob=set_defaults(prob, 'll', -inf*ones(prob.mm,1), ...
                        'uu', inf*ones(prob.mm,1), ...
                        'Ac', [], ...
                        'bc', [], ...
                        'info', [], ...
                        'finddir', []);

if opt.display>0
  if ~isempty(uu0)
    nuu = length(uu0);
    vstr=sprintf('%d+%d',prob.nn,nuu);
  else
    vstr=sprintf('%d',prob.nn);
  end
  
  lstr=func2str(prob.floss.p); lstr=lstr(6:end-1);
  fprintf(['DAL ver0.98d\n#samples=%d #variables=%s lambda=%g ' ...
             'loss=%s solver=%s\n'],prob.mm, vstr, lambda, lstr, ...
            opt.solver);
end


if opt.iter
  xx  = [[ww0(:); uu0(:)], ones(length(ww0(:))+length(uu0(:)),opt.maxiter-1)*nan];
end

res    = nan*ones(1,opt.maxiter);
fval   = nan*ones(1,opt.maxiter);
etaout = nan*ones(1,opt.maxiter);
time   = nan*ones(1,opt.maxiter);
pred   = nan*ones(1,opt.maxiter);


time0=cputime;
ww   = ww0;
uu   = uu0;
gtmp = zeros(size(ww));
if isempty(opt.aa)
  aa = zeros(prob.mm,1);
else
  aa = opt.aa;
end

eta  = opt.eta;
epsl = opt.eps;
info = prob.info;
info.solver=opt.solver;
for ii=1:opt.maxiter-1
  etaout(ii)=eta;
  time(ii)=cputime-time0;

  %% Evaluate objective and check stopping condition
  [ret,fval(ii),spec,res(ii)]=feval(prob.stopcond, ww, uu, aa, opt.tol, prob, A, B, lambda);

  %% Display (every outer iteration if display>1, otherwise only at the end)
  if opt.display>1 || (opt.display>0 && ret~=0)
    if ii>1
      fval1 = fval(ii-1)-pred(ii-1);
    else
      fval1 = nan;
    end
    nnzw = full(sum(spec>0));   % number of nonzero coefficients/groups
    fprintf('[[%d]] fval=%g (pred=%g) #(xx~=0)=%d res=%g\n', ii, fval(ii), ...
            fval1,...
            nnzw, res(ii));
  end

  if ret~=0
    break;
  end

  %% Save the original dual variable for daltv2d
  info.aa0 = aa;
  
  %% Solve minimization with respect to aa
  % fun  = @(aa,info)prob.obj(aa, prob, ww, uu, A, AT, B, lambda, eta, info);
  args = {prob,ww,uu,A,B,lambda,eta};
  switch(opt.solver)
   case {'nt','ntsv'}
    [aa,dfval,dgg,stat] = newton(prob.obj, aa, prob.ll, prob.uu, prob.Ac, ...
                                 prob.bc, epsl, prob.finddir, info, opt.display>2, args{:});
   case 'cg'
    funh = {prob.hessMult,A,eta};
    fh = {prob.obj, funh};
    [aa,dfval,dgg,stat] = newton(fh, aa, prob.ll, prob.uu, prob.Ac, ...
                                 prob.bc, epsl, prob.finddir, info, opt.display>2, args{:});
   case 'qn'
    optlbfgs=struct('epsginfo',epsl,'display',opt.display-1);
    [aa,stat]=lbfgs(prob.obj,aa,prob.ll,prob.uu,prob.Ac,prob.bc,info,optlbfgs,args{:});
   case 'fminunc'
    optfm=optimset('LargeScale','on','GradObj','on','Hessian', ...
                   'on','TolFun',1e-16,'TolX',0,'MaxIter',1000,'display','iter');
    [aa,fvalin,exitflag]=fminunc(@(xx)objdall1fminunc(xx,prob,ww, ...
                                                      uu,A,B,lambda,eta,epsl), aa, optfm);
    stat.info=info;
    stat.ret=exitflag~=1;
   otherwise
    error('Unknown method [%s]',opt.solver);
  end
  info=stat.info;


  if isfield(prob,'Aeq')
    gtmp(:) = [A', prob.Aeq']*aa;
  else    
    gtmp(:) = A'*aa;
  end
  
  ww1     = fevals(prob.softth, ww+eta*gtmp,eta*lambda,info);

  %% Predicted decrease in the objective
  %pred(ii) = norm(ww1(:)-ww(:))^2/(2*eta);
  %if ~isempty(uu)
  %pred(ii)  = pred(ii) + 0.5*eta*norm(B'*aa)^2;
  %end
  
  %% Update primal variable
  ww      = ww1;
  if ~isempty(uu)
    if isfield(prob,'Aeq')
      uu  = uu+eta*B'*aa(1:end-prob.meq);
    else
      uu  = uu+eta*B'*aa;
    end
  end

 
  %% Update barrier parameter eta and tolerance parameter epsl
  %% (only when the inner solver succeeded, i.e. stat.ret==0)
  eta     = eta*opt.eta_multp^(stat.ret==0);
  epsl    = epsl*opt.eps_multp^(stat.ret==0);
  if opt.iter
    xx(:,ii+1)=[ww(:);uu(:)];
  end
end

res(ii+1:end)=[];
fval(ii+1:end)=[];
time(ii+1:end)=[];
etaout(ii+1:end)=[];
pred(ii+1:end)=[];
if opt.iter
  xx(:,ii+1:end)=[];
else
  xx = ww;
end


status=struct('aa', aa,...
              'niter',length(res),...
              'eta', etaout,...
              'pred', pred,...
              'time', time,...
              'res', res,...
              'opt', opt, ...
              'info', info,...
              'fval', fval);
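The primal and parameter updates at the end of each outer iteration can be sketched in Python as follows. This is a hypothetical illustration, not the package's code: it assumes dense matrices stored as lists of rows, the L1 soft threshold, no prob.Aeq, and a boolean inner_ok standing in for stat.ret==0. It keeps the update order of the listing (uu is updated with the old eta) but omits the inner minimization over aa.

```python
# Sketch of one DAL outer update after the inner solver has returned aa.
# Assumptions: list-of-rows matrices, L1 soft threshold, no equality
# constraints (prob.Aeq), and inner_ok playing the role of stat.ret == 0.

def soft_threshold(v, t):
    """Elementwise soft threshold (prox of t * ||.||_1)."""
    return [vi - t if vi > t else vi + t if vi < -t else 0.0 for vi in v]

def transpose_times(M, a):
    """Return M' * a for M given as a list of rows."""
    return [sum(M[i][j] * a[i] for i in range(len(a))) for j in range(len(M[0]))]

def outer_update(ww, uu, aa, A, B, eta, epsl, lam, inner_ok,
                 eta_multp=2.0, eps_multp=0.99):
    # Primal update: ww <- softth(ww + eta * A' * aa, eta * lambda)
    gtmp = transpose_times(A, aa)
    ww = soft_threshold([w + eta * g for w, g in zip(ww, gtmp)], eta * lam)
    # Unregularized component: uu <- uu + eta * B' * aa (still the old eta)
    if uu:
        uu = [u + eta * b for u, b in zip(uu, transpose_times(B, aa))]
    # eta grows and epsl shrinks only when the inner solve succeeded
    if inner_ok:
        eta *= eta_multp
        epsl *= eps_multp
    return ww, uu, eta, epsl

A = [[1.0, 0.0], [0.0, 1.0]]
B = [[1.0], [1.0]]
print(outer_update([0.0, 0.0], [0.0], [1.0, -1.0], A, B, 1.0, 1.0, 0.5, True))
# ([0.5, -0.5], [0.0], 2.0, 0.99)
```

With inner_ok set to False, eta and epsl come back unchanged, which is what the exponent form opt.eta_multp^(stat.ret==0) achieves in the MATLAB code.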

Generated on Sat 22-Aug-2009 22:15:36 by m2html © 2003