Home > ACA-Code > ToolSimpleNmf.m

ToolSimpleNmf

PURPOSE ^

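computes a non-negative matrix factorization (NMF); implementation inspired by https://github.com/cwu307/NmfDrumToolbox/blob/master/src/PfNmf.m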

SYNOPSIS ^

function [W, H, err] = ToolSimpleNmf(X, iRank, iMaxIteration, fSparsity)

DESCRIPTION ^

computes a non-negative matrix factorization (NMF) X ~ W*H (implementation inspired by
https://github.com/cwu307/NmfDrumToolbox/blob/master/src/PfNmf.m)
>
> @param X: non-negative matrix to factorize (usually ifreq x iObservations)
> @param iRank: nmf rank
> @param iMaxIteration: maximum number of iterations (default: 300)
> @param fSparsity: sparsity weight (default: 0)
>
> @retval W dictionary matrix
> @retval H activation matrix
> @retval err loss function result
 ======================================================================

CROSS-REFERENCE INFORMATION ^

This function calls:
This function is called by:

SUBFUNCTIONS ^

function [D] = KlDivergence_I(p, q)

SOURCE CODE ^

%computes nmf (implementation inspired by
%https://github.com/cwu307/NmfDrumToolbox/blob/master/src/PfNmf.m)
%>
%> Factorizes X ~ W*H with multiplicative updates that minimize the
%> generalized Kullback-Leibler divergence between X and W*H plus the L1
%> penalty fSparsity * sum(H(:)) on the activations.
%> After every iteration the columns of W are scaled to unit L1 norm and the
%> rows of H are scaled inversely, so the product W*H is not altered.
%> W and H are initialized with rand, so the result depends on the state of
%> the random number generator.
%>
%> @param X: non-negative matrix to factorize (usually ifreq x iObservations)
%> @param iRank: nmf rank
%> @param iMaxIteration: maximum number of iterations (default: 300)
%> @param fSparsity: sparsity weight (default: 0)
%>
%> @retval W dictionary matrix (size(X, 1) x iRank, columns of unit L1 norm)
%> @retval H activation matrix (iRank x size(X, 2))
%> @retval err loss function result per performed iteration (1 x #iterations)
% ======================================================================
function [W, H, err] = ToolSimpleNmf(X, iRank, iMaxIteration, fSparsity)

    % missing or empty optional arguments fall back to their defaults
    if nargin < 4 || isempty(fSparsity)
        fSparsity = 0;
    end
    if nargin < 3 || isempty(iMaxIteration)
        iMaxIteration = 300;
    end
    %rng(42);

    % avoid zero input: a strictly positive X keeps X./X_hat (and thus the
    % multiplicative updates) away from exact zeros
    X = X + realmin;

    % initialization
    [iFreq, iFrames] = size(X);
    err = zeros(1, iMaxIteration);

    W = rand(iFreq, iRank);
    H = rand(iRank, iFrames);

    % normalize the columns of W (H is random, so no compensation is needed here)
    for i = 1:iRank
        W(:, i) = W(:, i) ./ norm(W(:, i), 1);
    end

    count = 0;
    rep   = ones(iFreq, iFrames);

    % iteration
    while (count < iMaxIteration)

        % update H; the sparsity weight is the derivative of the penalty
        % fSparsity * sum(H(:)) and therefore enters the denominator
        X_hat = W * H;
        H = H .* (W' * (X ./ X_hat)) ./ (W' * rep + fSparsity);

        % update W with an estimate that already reflects the new H
        X_hat = W * H;
        W = W .* ((X ./ X_hat) * H') ./ (rep * H');

        % normalize the columns of W to unit L1 norm and move the scale into
        % the matching row of H so that W*H stays the same
        for i = 1:iRank
            fNorm = norm(W(:, i), 1);
            W(:, i) = W(:, i) ./ fNorm;
            H(i, :) = H(i, :) .* fNorm;
        end

        % loss of the current factorization: divergence + L1 penalty on H
        count = count + 1;
        err(count) = KlDivergence_I(X, W * H) + fSparsity * sum(H(:));

        % stop once the change between consecutive iterations is below 0.1%
        % of the total decrease since the first iteration (a loss that has
        % risen above err(1) makes the ratio negative and also stops)
        if (count >= 2)
            if (abs(err(count) - err(count - 1)) / ...
                    (err(1) - err(count) + realmin)) < 0.001
                break;
            end
        end
    end
    err = err(1:count);
end
0075 
function [D] = KlDivergence_I(p, q)
    % generalized Kullback-Leibler divergence of q from p, summed over all
    % entries: sum of p.*log(p./q) - p + q, with p and q combined
    % element-wise; realmin is added inside each logarithm so that zero
    % entries never produce log(0)
    logRatio = log(p + realmin) - log(q + realmin);
    elementTerms = p .* logRatio - p + q;
    D = sum(sum(elementTerms));
end

Generated on Fri 22-Apr-2022 20:59:51 by m2html © 2005