Home > ACA-Code > ToolGmm.m

ToolGmm

PURPOSE ^

Gaussian mixture model

SYNOPSIS ^

function [mu, sigma, state] = ToolGmm(V, k, numMaxIter, prevState)

DESCRIPTION ^

Gaussian mixture model
>
> @param V: features for all training observations (dimension iNumFeatures x iNumObservations)
> @param k: number of Gaussians
> @param numMaxIter: maximum number of iterations (stops earlier if converged; default 1000)
> @param prevState: internal state returned by a previous call, used to continue clustering later
>
> @retval mu means (dimension iNumFeatures x k)
> @retval sigma covariance matrices (dimension iNumFeatures x iNumFeatures x k)
> @retval state internal state (can be passed back as prevState)
 ======================================================================

CROSS-REFERENCE INFORMATION ^

This function calls:
This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

%gaussian mixture model
%>
%> @param V: features for all train observations (dimension iNumFeatures x iNumObservations)
%> @param k: number of gaussians
%> @param numMaxIter: maximum number of iterations (stop if not converged before), default 1000
%> @param prevState: internal state that can be stored to continue clustering later
%>        (omit or pass [] to start from a fresh random initialization)
%> @param tol: (optional) convergence threshold: iteration stops once the largest
%>        (over clusters) summed absolute change of a cluster mean is <= tol, default 1e-20
%>
%> @retval mu means (dimension iNumFeatures x k)
%> @retval sigma covariance matrices (dimension iNumFeatures x iNumFeatures x k)
%> @retval state result containing internal state (if needed)
% ======================================================================
function [mu, sigma, state] = ToolGmm(V, k, numMaxIter, prevState, tol)

    % default arguments (empty inputs fall back to the defaults as well)
    if (nargin < 3 || isempty(numMaxIter))
        numMaxIter = 1000;
    end
    if (nargin < 5 || isempty(tol))
        tol = 1e-20;
    end

    if (nargin >= 4 && ~isempty(prevState))
        % continue from a previously returned state
        state = prevState;
    else
        % initialize state
        state = initState_I(V, k);
    end

    for j = 1:numMaxIter
        % remember the means to measure how far they move in this iteration
        mPrev = state.m;

        % E-step: normalized, prior-weighted gaussian probabilities
        p = computeProb_I(V, state);

        % M-step: re-estimate priors, means and covariances
        state = updateGaussians_I(V, p, state);

        % if we have converged, break
        if (max(sum(abs(state.m - mPrev))) <= tol)
            break;
        end
    end

    mu = state.m;
    sigma = state.sigma;
end
0042 
function [state] = updateGaussians_I(FeatureMatrix, p, state)
% M-step: re-estimate priors, means and covariance matrices of all clusters.
%
% FeatureMatrix: observations (dimension iNumFeatures x iNumObservations)
% p:             cluster responsibilities (dimension iNumObservations x K)
% state:         current state with fields m, sigma, prior
%
% The prior is stored as a 1 x K row vector, the same orientation that
% initState_I uses. A cluster whose total responsibility is zero keeps
% its previous mean and covariance, because updating it would divide
% by zero and fill the state with NaNs.

    % number of clusters
    K = size(state.m, 2);

    % update priors (row vector)
    state.prior = mean(p, 1);

    for k = 1:K
        w = sum(p(:, k));

        % skip empty (or NaN-weighted) clusters instead of dividing by zero
        if ~(w > 0)
            continue;
        end

        % update means
        state.m(:, k) = FeatureMatrix * p(:, k) / w;

        % subtract mean
        F = bsxfun(@minus, FeatureMatrix, state.m(:, k));

        % weighted scatter matrix: sum_n p(n,k) * F(:,n) * F(:,n)'
        state.sigma(:, :, k) = (bsxfun(@times, F, p(:, k)') * F') / w;
    end
end
0066 
function [p] = computeProb_I(FeatureMatrix, state)
% E-step: posterior probability of each cluster for each observation.
%
% FeatureMatrix: observations (dimension iNumFeatures x iNumObservations)
% state:         current state with fields m, sigma, prior
% p:             (dimension iNumObservations x K), each row sums to 1
%
% The prior-weighted gaussian densities are computed in the log domain
% and normalized with the log-sum-exp trick. Without this, observations
% far from every mean (or high-dimensional data) underflow to 0 for all
% clusters, and the normalization gives 0/0 = NaN.

    [D, N] = size(FeatureMatrix);
    K = size(state.m, 2);
    logp = zeros(N, K);

    % for each cluster
    for k = 1:K
        S = state.sigma(:, :, k);

        % subtract mean
        F = bsxfun(@minus, FeatureMatrix, state.m(:, k));

        % Mahalanobis distance and log-determinant, preferably via Cholesky
        % (S = R'*R) instead of forming inv(S) explicitly
        [R, notPd] = chol(S);
        if (notPd == 0)
            Z = R' \ F;
            maha = sum(Z.^2, 1)';
            logDet = 2 * sum(log(diag(R)));
        else
            % not positive definite: fall back to a direct solve
            maha = sum(F .* (S \ F), 1)';
            logDet = log(det(S));
        end

        % log of the prior-weighted gaussian
        logp(:, k) = log(state.prior(k)) - 0.5 * (D * log(2*pi) + logDet + maha);
    end

    % norm over clusters (log-sum-exp: shift by the row maximum before exp)
    maxLog = max(logp, [], 2);
    p = exp(bsxfun(@minus, logp, maxLog));
    p = bsxfun(@rdivide, p, sum(p, 2));
end
0086 
function [state] = initState_I(FeatureMatrix, K)
% Build the initial state: K randomly chosen observations as cluster means,
% uniform priors (1 x K row vector), and the covariance of all observations
% as the initial covariance of every cluster.
%
% FeatureMatrix: observations (dimension iNumFeatures x iNumObservations)
% K:             number of clusters

    N = size(FeatureMatrix, 2);

    % pick random points as cluster means; sample without replacement when
    % possible so that no two clusters start at the same observation
    if (K <= N)
        perm = randperm(N);
        mIdx = perm(1:K);
    else
        % more clusters than observations: duplicates are unavoidable
        mIdx = randi(N, 1, K);
    end

    % assign means etc.
    m     = FeatureMatrix(:, mIdx);
    prior = ones(1, K) / K;
    sigma = repmat(cov(FeatureMatrix'), [1, 1, K]);

    % write initial state
    state = struct('m', m, 'sigma', sigma, 'prior', prior);
end

Generated on Fri 22-Apr-2022 20:59:51 by m2html © 2005