Home > ACA-Code > ToolSimpleKmeans.m

ToolSimpleKmeans

PURPOSE ^

Performs k-means clustering.

SYNOPSIS ^

function [clusterIdx, state] = ToolSimpleKmeans(V, K, numMaxIter, prevState)

DESCRIPTION ^

Performs k-means clustering.
>
> @param V: features for all train observations (dimension iNumFeatures x iNumObservations)
> @param K: number of clusters
> @param numMaxIter: maximum number of iterations (default 1000; stops earlier once converged)
> @param prevState: internal state from a previous call, used to continue clustering later
>
> @retval clusterIdx cluster index of each observation
> @retval state internal state (struct whose field m holds one cluster mean per column; only needed to continue clustering)
 ======================================================================

CROSS-REFERENCE INFORMATION ^

This function calls:
This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

%performs k-means clustering
%>
%> @param V: features for all train observations (dimension iNumFeatures x iNumObservations)
%> @param K: number of clusters
%> @param numMaxIter: maximum number of iterations (stop if not converged before);
%>        defaults to 1000 when omitted or empty
%> @param prevState: internal state that can be stored to continue clustering later
%>        (struct with field m holding one cluster mean per column); omit it or
%>        pass [] to start from a random initialization
%>
%> @retval clusterIdx cluster index of each observation (1 x iNumObservations)
%> @retval state internal state (only if needed)
% ======================================================================
function [clusterIdx, state] = ToolSimpleKmeans(V, K, numMaxIter, prevState)

    % default iteration limit (also applied when [] is passed explicitly)
    if (nargin < 3 || isempty(numMaxIter))
        numMaxIter = 1000;
    end

    if (nargin == 4 && ~isempty(prevState))
        % continue from a previously returned state
        state = prevState;
    else
        % initialize
        % use fixed seed for reproducibility (comment out if needed)
        %rng(42);

        % pick random observations as initial cluster means; draw without
        % replacement whenever possible so that no two means start at the
        % same observation (which would force an empty cluster)
        numObs = size(V, 2);
        if (K <= numObs)
            initIdx = randperm(numObs, K);
        else
            initIdx = randi(numObs, 1, K);
        end
        state = struct('m', V(:, initIdx));
    end

    % per-feature [min max] of the data, used to re-seed empty clusters
    range_V = [min(V, [], 2) max(V, [], 2)];

    % assign observations to clusters
    clusterIdx = assignClusterLabels_I(V, state);

    for iter = 1:numMaxIter
        oldState = state;

        % update means
        state = computeClusterMeans_I(V, clusterIdx, K);

        % reinitialize empty clusters
        state = reinitState_I(state, clusterIdx, K, range_V);

        % assign observations to clusters
        clusterIdx = assignClusterLabels_I(V, state);

        % converged once the means no longer change between iterations
        if (max(abs(state.m(:) - oldState.m(:))) == 0)
            break;
        end
    end
end
0049 
function [clusterIdx]  = assignClusterLabels_I(V, state)
    % assigns every observation (column of V) to the cluster whose mean
    % (column of state.m) is closest in Euclidean distance; ties go to the
    % lower cluster index (first minimum returned by min)

    K = size(state.m, 2);
    numObs = size(V, 2);

    % distance of every observation to every cluster mean (K x numObs),
    % preallocated instead of growing inside the loop
    D = zeros(K, numObs);
    for k = 1:K
        D(k, :) = sqrt(sum(bsxfun(@minus, V, state.m(:, k)).^2, 1));
    end

    % assign to closest
    [~, clusterIdx] = min(D, [], 1);
end
0062 
function [state] = computeClusterMeans_I(V, clusterIdx, K)
    % computes the mean of the observations (columns of V) assigned to each
    % of the K clusters; a cluster with no observations keeps a zero mean
    % @retval state struct whose field m is iNumFeatures x K

    m = zeros(size(V, 1), K);
    for k = 1:K
        % logical mask instead of calling find() twice per cluster
        isMember = (clusterIdx == k);
        if any(isMember)
            m(:, k) = mean(V(:, isMember), 2);
        end
    end
    state = struct('m', m);
end
0072 
function  [state] = reinitState_I(state, clusterIdx, K, range)
    % re-seeds every cluster k in 1..K that has no assigned observation with
    % a random mean drawn uniformly inside [range(:,1), range(:,2)] per feature
    % @param range: iNumFeatures x 2 matrix of [lower upper] bounds

    % number of features = rows of the mean matrix (the struct itself is 1x1)
    numFeatures = size(state.m, 1);
    for k = 1:K
        if (~any(clusterIdx == k))
            % one independent random fraction per feature dimension
            state.m(:, k) = rand(numFeatures, 1) .* (range(:, 2) - range(:, 1)) + range(:, 1);
        end
    end
end

Generated on Fri 22-Apr-2022 20:59:51 by m2html © 2005