Home > demo3 > kmeans2.m

kmeans2

PURPOSE ^

output: final centers

SYNOPSIS ^

function [centers,mincenter,mindist,q2,quality] = kmeans2(data,initcenters,method)

DESCRIPTION ^

 output: final centers
 input: data points and initial centers
 if initcenters is a number k, create k centers and start with these
 otherwise, use centers given as input
 method = 0: unoptimized, using n by k matrix of distances O(nk) space
          1: vectorized, using only O(n+k) space
          2: like 1, in addition using distance inequalities (default)
 (C) Charles Elkan

CROSS-REFERENCE INFORMATION ^

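This function calls: anchors, calcdist, alldist (helpers referenced in the source below; not cross-linked on this page)
This function is called by: (none listed)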

SOURCE CODE ^

function [centers,mincenter,mindist,q2,quality] = kmeans2(data,initcenters,method)
% KMEANS2  k-means clustering with optional distance-inequality speedups.
%
% [centers,mincenter,mindist,q2,quality] = kmeans2(data,initcenters,method)
%
% input:
%   data         n-by-dim matrix, one data point per row
%   initcenters  if a number k, k centers are created by calling anchors
%                and the iteration starts from these;
%                otherwise a k-by-dim matrix of centers to start from
%   method       0: unoptimized, using n by k matrix of distances O(nk) space
%                1: vectorized, using only O(n+k) space
%                2: like 1, in addition using distance inequalities
%                   (default, also used when method is empty)
% output:
%   centers      final centers, one per row
%   mincenter    for each point, the index of the center it is assigned to
%   mindist      for each point, the calcdist value to its assigned center
%                (with method 2 this is maintained as an upper bound)
%   q2           mean over all points of the squared calcdist value to the
%                assigned final center
%   quality      mean over all points of the calcdist value to the
%                assigned final center
%
% A progress row [iteration toc nchanged computed] is displayed after every
% iteration and [iteration toc quality q2 total] at the end; "computed" and
% "total" count distance calculations.
%
% Relies on the external helpers anchors, calcdist and alldist.
% (C) Charles Elkan

tic
if nargin < 3 || isempty(method), method = 2; end
[n,dim] = size(data);

if isscalar(initcenters)
    k = initcenters;
    % mean(data,1) keeps the column-wise mean even when data has one row
    [centers, mincenter, mindist, lower, computed] = anchors(mean(data,1),k,data);
    total = computed;
    skipestep = 1;      % anchors already supplied an assignment
else
    centers = initcenters;
    mincenter = zeros(n,1);
    total = 0;
    skipestep = 0;
    [k,dim2] = size(centers);
    if dim ~= dim2, error('dim(data) ~= dim(centers)'); end
end

nchanged = n;
iteration = 0;
oldmincenter = zeros(n,1);

% population of every cluster; preallocated so that a cluster which is empty
% at first and gains points later can be indexed in the incremental M step
pop = zeros(1,k);

while nchanged > 0
    % do one E step, then one M step
    computed = 0;

    if method == 0 && ~skipestep
        distmat = zeros(n,k);
        for i = 1:n
            for j = 1:k
                distmat(i,j) = calcdist(data(i,:),centers(j,:));
            end
        end
        [mindist,mincenter] = min(distmat,[],2);
        computed = k*n;

    elseif (method == 1 || (method == 2 && iteration == 0)) && ~skipestep
        mindist = Inf*ones(n,1);
        lower = zeros(n,k);
        for j = 1:k
            jdist = calcdist(data,centers(j,:));
            lower(:,j) = jdist;
            track = find(jdist < mindist);
            mindist(track) = jdist(track);
            mincenter(track) = j;
        end
        computed = k*n;

    elseif method == 2 && ~skipestep

% for each center, nndist is half the distance to the nearest center
% if d(x,center) < nndist then x cannot belong to any other center
% mindist is an upper bound on the distance of each point to its nearest center

        nndist = min(centdist,[],2);
% the following usually is not faster
%        ldist = min(lower,[],2);
%        mobile = find(mindist > max(nndist(mincenter),ldist));
        mobile = find(mindist > nndist(mincenter));

% recompute distances for point i and center j
%       only if j can possibly be the new nearest center
% for speed, the first check has been optimized by modifying centdist
% swapping the order of the checks is slower for data with natural clusters

        mdm = mindist(mobile);
        mcm = mincenter(mobile);

        for j = 1:k
% the following is incorrect: for j = unique(mcm)'
            track = find(mdm > centdist(mcm,j));
            if isempty(track), continue; end
            alt = find(mdm(track) > lower(mobile(track),j));
            if isempty(alt), continue; end
            track1 = mobile(track(alt));

% calculate exact distances to the mincenter
% recalculate separately for each jj to avoid copying too much of data
% redo may be empty, but we don't need to check this
            redo = find(~recalculated(track1));
            redo = track1(redo);
            c = mincenter(redo);
            computed = computed + size(redo,1);
            for jj = unique(c)'
                rp = redo(find(c == jj));
                udist = calcdist(data(rp,:),centers(jj,:));
                lower(rp,jj) = udist;
                mindist(rp) = udist;
            end
            recalculated(redo) = 1;

            % repeat the center-to-center check with the now exact mindist
            track2 = find(mindist(track1) > centdist(mincenter(track1),j));
            track1 = track1(track2);
            if isempty(track1), continue; end

            % calculate exact distances to center j
            track4 = find(lower(track1,j) < mindist(track1));
            if isempty(track4), continue; end
            track5 = track1(track4);
            jdist = calcdist(data(track5,:),centers(j,:));
            computed = computed + size(track5,1);
            lower(track5,j) = jdist;

            % find which points really are assigned to center j
            track2 = find(jdist < mindist(track5));
            track3 = track5(track2);
            mindist(track3) = jdist(track2);
            mincenter(track3) = j;
        end % for j=1:k
    end % if method

    oldcenters = centers;

% M step: recalculate the means for each cluster
% if a cluster is empty, its mean is left unchanged
% we minimize computations for clusters with little changed membership

    % points whose assignment changed, and the clusters they left or joined
    changed = find(mincenter ~= oldmincenter);
    diffj = unique([mincenter(changed);oldmincenter(changed)])';
    diffj = diffj(diffj > 0);   % 0 marks "not yet assigned"

    if size(changed,1) < n/3 && iteration > 0
        % few changes: update each affected mean incrementally
        for j = diffj
            gained = find(mincenter(changed) == j);
            lost = find(oldmincenter(changed) == j);
            oldpop = pop(j);
            pop(j) = pop(j) + size(gained,1) - size(lost,1);
            if pop(j) == 0, continue; end
            centers(j,:) = (centers(j,:)*oldpop + sum(data(changed(gained),:),1) - sum(data(changed(lost),:),1))/pop(j);
        end
    else
        % many changes (or first iteration): recompute affected means in full
        for j = diffj
            track = find(mincenter == j);
            pop(j) = size(track,1);
            if pop(j) == 0, continue; end
% it's correct to have mean(data(track,:),1) but this can make answer worse!
            centers(j,:) = mean(data(track,:),1);
        end
    end

    if method == 2
        % a center that moved by "offset" loosens the bounds by that amount
        for j = diffj
            offset = calcdist(centers(j,:),oldcenters(j,:));
            computed = computed + 1;
            if offset == 0, continue; end
            track = find(mincenter == j);
            mindist(track) = mindist(track) + offset;
            lower(:,j) = max(lower(:,j) - offset,0);
        end

% compute distance between each pair of centers
% modify centdist to make "find" using it faster
        recalculated = zeros(n,1);
        realdist = alldist(centers);
        centdist = 0.5*realdist + diag(Inf*ones(k,1));
        computed = computed + k + k*(k-1)/2;
    end

    % skipestep forces one more pass after an anchors initialization
    nchanged = size(changed,1) + skipestep;
    iteration = iteration+1;
    skipestep = 0;
    oldmincenter = mincenter;

%   difference = max(max(abs(oldcenters - centers)));
%   [iteration toc nchanged computed size(diffj,2)]
    % no semicolon: the progress row is displayed on purpose
    [iteration toc nchanged computed]
    total = total + computed;
end % while nchanged > 0

udist = calcdist(data,centers(mincenter,:));
quality = mean(udist);
q2 = mean(udist.^2);
% no semicolon: the final summary row is displayed on purpose
[iteration toc quality q2 total]

Generated on Sat 22-Aug-2009 22:15:36 by m2html © 2003