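対数尤度関数の中にシグマ(和)とパイ(積)が混在しているため、その偏微分を直接計算するのは容易ではない。

$$\ln p(\mathbf{X}\mid\boldsymbol{\mu},\boldsymbol{\pi}) = \sum_{n=1}^{N} \ln \left\{ \sum_{k=1}^{K} \pi_k \prod_{i=1}^{D} \mu_{ki}^{x_{ni}} (1-\mu_{ki})^{1-x_{ni}} \right\}$$

(パイ(積)だけであれば、対数を取ると積が和(シグマ)に変換されるので簡単に計算できる。しかしここでは対数の中に和 $\sum_k$ があるため、そのようには分解できない。)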
https://github.com/shohei/mnist
尤度がうまく計算できていない。
作成途中のコード：
function classify_by_em_algorithm
% CLASSIFY_BY_EM_ALGORITHM  Fit a K-component Bernoulli mixture to the
% binarized MNIST training images with the EM algorithm and save the
% learned parameters to result.mat.
%
%   Reads 'train-images-idx3-ubyte' from the current folder, binarizes the
%   pixels, runs NUM_ITERATIONS EM iterations and saves
%     pk : 1 x K mixing weights,  pk(k)   = p(z = k)
%     uk : D x K Bernoulli means, uk(d,k) = p(x_d = 1 | z = k)
%   to result.mat. The working arrays are also left in the global variables
%   img, img2, img3, uk, pk, gamma_nk and count.
%
%   The likelihood of one image is a product of D (= 784) probabilities and
%   underflows to 0 in double precision, so the E-step works in the log
%   domain and normalizes with the log-sum-exp trick.
clear all; close all;
global img; global img2; global img3;
global uk; global pk; global gamma_nk; global count;

K = 10;                 % number of mixture components
num_iterations = 20;    % number of EM iterations
prob_floor = 1e-10;     % keeps uk strictly inside (0,1) so log(uk), log(1-uk) stay finite

count = 0;
open_and_read_mnist();
binarize();
prepare_serialized_image();

% Double-precision copy of the data for the matrix products. img3 stays
% uint8; mixing uint8 and double in arithmetic would round every
% probability to an integer.
X = double(img3);       % D x N, entries 0 or 1
[D, N] = size(X);

init_generator();
init_pickup_probability();
for idx = 1:num_iterations
    compute_generation_probability();   % E-step
    recreate_generator();               % M-step
end
disp 'comutation done';
save_result();

    function open_and_read_mnist
        % Read the big-endian IDX3 image file. img is rows x cols x items
        % (double); img2 holds the same images transposed, as uint8.
        fid = fopen('train-images-idx3-ubyte', 'r', 'b');
        if fid < 0
            error('classify_by_em_algorithm:fileNotFound', ...
                  'Cannot open train-images-idx3-ubyte.');
        end
        closer = onCleanup(@() fclose(fid));   % close the file even if a read fails
        magic_number      = fread(fid, 1, 'int32');
        number_of_items   = fread(fid, 1, 'int32');
        number_of_rows    = fread(fid, 1, 'int32');
        number_of_columns = fread(fid, 1, 'int32');
        fprintf('magic_number=%d number_of_items=%d rows=%d columns=%d\n', ...
                magic_number, number_of_items, number_of_rows, number_of_columns);
        if magic_number ~= 2051
            error('classify_by_em_algorithm:badFormat', ...
                  'Unexpected IDX magic number %d (expected 2051).', magic_number);
        end
        img = fread(fid, [number_of_rows*number_of_columns number_of_items], 'uint8');
        img = reshape(img, number_of_rows, number_of_columns, number_of_items);
        % IDX stores each image row by row while MATLAB is column-major,
        % so transpose every image (same as img(:,:,n)' for each n).
        img2 = uint8(permute(img, [2 1 3]));
    end

    function binarize
        % Threshold img2 in place: intensities <= thresh become 1, all
        % others 0 (low-intensity pixels are marked with 1).
        thresh = 10;
        img2 = uint8(img2 <= thresh);
    end

    function prepare_serialized_image
        % Flatten every image into one column: img3 is D x N (uint8, 0/1).
        img3 = reshape(img2, [], size(img2, 3));
    end

    function init_generator
        % Random Bernoulli means in [0.25, 0.75]. They must lie strictly
        % inside (0,1): a mean of exactly 0 or 1 gives zero likelihood to
        % every image that disagrees in that pixel.
        uk = 0.25 + 0.5 * rand(D, K);
    end

    function init_pickup_probability
        % Uniform mixing weights, stored as a 1 x K row vector.
        pk = ones(1, K) / K;
    end

    function compute_generation_probability
        % E-step: responsibilities
        %   gamma_nk(k,n) = pk(k) p(x_n|u_k) / sum_j pk(j) p(x_n|u_j)
        % with, in log form,
        %   log p(x_n|u_k) = sum_d x_dn log u_dk + (1 - x_dn) log(1 - u_dk)
        %                  = (log u_k - log(1 - u_k))' x_n + sum_d log(1 - u_dk)
        log_u   = log(uk);
        log_1mu = log(1 - uk);
        log_px    = bsxfun(@plus, (log_u - log_1mu)' * X, sum(log_1mu, 1)');  % K x N
        log_joint = bsxfun(@plus, log_px, log(pk(:)));                        % K x N
        % log-sum-exp over the components, shifted by the column maximum
        % so that exp() does not underflow.
        col_max  = max(log_joint, [], 1);                                       % 1 x N
        log_norm = col_max + log(sum(exp(bsxfun(@minus, log_joint, col_max)), 1));
        gamma_nk = exp(bsxfun(@minus, log_joint, log_norm));                   % K x N
        fprintf('log-likelihood: %.6g\n', sum(log_norm));
    end

    function recreate_generator
        % M-step:
        %   Nk      = sum_n gamma_nk(k,n)
        %   uk(:,k) = sum_n gamma_nk(k,n) x_n / Nk
        %   pk(k)   = Nk / N
        Nk = sum(gamma_nk, 2)';                                    % 1 x K
        % realmin guards against 0/0 for a component with no responsibility.
        uk = bsxfun(@rdivide, X * gamma_nk', max(Nk, realmin));    % D x K
        uk = min(max(uk, prob_floor), 1 - prob_floor);
        pk = Nk / N;
        count = count + 1;
        disp( sprintf( '** Epoch %d finished. **', count ) );
    end

    function save_result
        % Persist the learned mixture parameters.
        save('result.mat', 'pk', 'uk');
    end
end
分類実行
function do_classify
% DO_CLASSIFY  Assign an MNIST training image to one of the mixture
% components stored in result.mat (pk: 1 x K weights, uk: D x K Bernoulli
% means, as written by classify_by_em_algorithm).
%
%   The printed index is the component k with the largest posterior
%   p(z = k | x). EM is unsupervised, so k is a cluster id; it is not
%   guaranteed to equal the digit drawn in the image.
clear all; close all;
% This function contains nested functions, so its workspace is static and
% load() may not create variables implicitly: load into a struct instead.
S = load('result.mat', 'pk', 'uk');
pk = S.pk;
uk = S.uk;
global img; global img2; global img3;
open_and_read_mnist();
binarize();
prepare_serialized_image();
classify_image(img3(:,1));
% classify_image(img3(:,2));
% classify_image(img3(:,3));
% classify_image(img3(:,4));
% classify_image(img3(:,5));
% classify_image(img3(:,6));
disp 'stop here.';

    function open_and_read_mnist
        % Read the big-endian IDX3 image file. img is rows x cols x items
        % (double); img2 holds the same images transposed, as uint8.
        fid = fopen('train-images-idx3-ubyte', 'r', 'b');
        if fid < 0
            error('do_classify:fileNotFound', ...
                  'Cannot open train-images-idx3-ubyte.');
        end
        closer = onCleanup(@() fclose(fid));   % close the file even if a read fails
        magic_number      = fread(fid, 1, 'int32');
        number_of_items   = fread(fid, 1, 'int32');
        number_of_rows    = fread(fid, 1, 'int32');
        number_of_columns = fread(fid, 1, 'int32');
        fprintf('magic_number=%d number_of_items=%d rows=%d columns=%d\n', ...
                magic_number, number_of_items, number_of_rows, number_of_columns);
        if magic_number ~= 2051
            error('do_classify:badFormat', ...
                  'Unexpected IDX magic number %d (expected 2051).', magic_number);
        end
        img = fread(fid, [number_of_rows*number_of_columns number_of_items], 'uint8');
        img = reshape(img, number_of_rows, number_of_columns, number_of_items);
        % Transpose every image (IDX is row-major, MATLAB column-major);
        % must match the preprocessing used during training.
        img2 = uint8(permute(img, [2 1 3]));
    end

    function binarize
        % Same thresholding as training: intensities <= thresh become 1,
        % all others 0.
        thresh = 10;
        img2 = uint8(img2 <= thresh);
    end

    function prepare_serialized_image
        % Flatten every image into one column: img3 is D x N (uint8, 0/1).
        img3 = reshape(img2, [], size(img2, 3));
    end

    function classify_image(image_column)
        % Print the component with the highest posterior for one D x 1 image.
        prob_floor = 1e-10;                    % keep log() finite
        x = double(image_column(:));           % avoid uint8 integer arithmetic
        u = min(max(double(uk), prob_floor), 1 - prob_floor);
        % log pk(k) + log p(x|u_k), for every k at once (K x 1)
        log_joint = (log(u) - log(1 - u))' * x + sum(log(1 - u), 1)' ...
                    + log(double(pk(:)));
        % Normalize over ALL components with log-sum-exp.
        shift = max(log_joint);
        log_norm = shift + log(sum(exp(log_joint - shift)));
        posterior = exp(log_joint - log_norm);
        [~, index] = max(posterior);
        disp( sprintf( 'input number classified as %d !', index ) );
    end
end