
Derivation of the ZCA Whitening Transform (Learning Multiple Layers of Features from Tiny Images)

Reference: Learning Multiple Layers of Features from Tiny Images, Appendix

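Let the dataset $X$ be a $d \times n$ matrix ($n$ samples, each a $d$-dimensional column vector) that has already been centered, so every dimension has zero mean.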

Its covariance matrix is then

$$\operatorname{cov}(X) = \frac{1}{n-1} X X^\top$$
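
A minimal NumPy sketch of the centering and covariance step, assuming the samples are the columns of a $d \times n$ array. The random data and the helper names `center` and `covariance` are hypothetical. `np.cov` also treats each row as a variable and divides by $n-1$ by default, so it serves as a cross-check.

```python
import numpy as np

def center(X):
    """Subtract each row's mean so every dimension of the d x n matrix X has zero mean."""
    return X - X.mean(axis=1, keepdims=True)

def covariance(X):
    """Sample covariance 1/(n-1) * X X^T of an already-centered d x n matrix."""
    n = X.shape[1]
    return X @ X.T / (n - 1)

rng = np.random.default_rng(0)
X = center(rng.normal(size=(3, 500)))   # hypothetical data: d = 3, n = 500
C = covariance(X)

# np.cov uses rows as variables and the 1/(n-1) normalization by default
print(np.allclose(C, np.cov(X)))        # True
```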

We want any two of the $d$ dimensions to be uncorrelated across the $n$ vectors. Let $W$ be the decorrelating matrix, so that

$$Y = WX$$

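For $W$ to decorrelate the data, $YY^\top$ must be diagonal. We can go one step further and constrain $Y$ to satisfy

$$YY^\top = (n-1)\,I,$$

i.e. the covariance of $Y$ is the identity, so every dimension also has unit variance.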
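
The difference between "diagonal" and "$(n-1)I$" can be seen numerically. In this sketch (random data and the helpers `is_decorrelated` / `is_white` are hypothetical), rotating the data onto the eigenvectors of $XX^\top$ already makes $YY^\top$ diagonal, but only rescaling each axis afterwards gives $YY^\top = (n-1)I$.

```python
import numpy as np

def is_decorrelated(Y, rtol=1e-9):
    """True if Y Y^T is diagonal up to round-off (off-diagonals tiny relative to the largest entry)."""
    G = Y @ Y.T
    off_diag = G - np.diag(np.diag(G))
    return bool(np.max(np.abs(off_diag)) <= rtol * np.max(np.abs(G)))

def is_white(Y):
    """True if Y Y^T = (n - 1) I, i.e. the sample covariance of Y is the identity."""
    d, n = Y.shape
    return np.allclose(Y @ Y.T, (n - 1) * np.eye(d), atol=1e-9 * n)

# Hypothetical correlated, centered 2 x n data
rng = np.random.default_rng(5)
n = 500
X = np.array([[2.0, 0.0], [1.0, 0.5]]) @ rng.normal(size=(2, n))
X -= X.mean(axis=1, keepdims=True)

lam, P = np.linalg.eigh(X @ X.T)                     # X X^T = P diag(lam) P^T
Y_rot = P.T @ X                                      # rotation only: Y Y^T = diag(lam)
Y_white = np.diag(np.sqrt((n - 1) / lam)) @ Y_rot    # rescale each axis: Y Y^T = (n-1) I

print(is_decorrelated(X), is_white(X))               # False False
print(is_decorrelated(Y_rot), is_white(Y_rot))       # True False
print(is_decorrelated(Y_white), is_white(Y_white))   # True True
```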

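Many matrices satisfy this condition, so we also restrict $W$ to be symmetric (mainly to simplify the derivation below):

$$W = W^\top$$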
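
Why an extra restriction is needed: if $W$ whitens the data, so does $RW$ for any orthogonal $R$. This sketch compares PCA whitening, $W_{pca} = \sqrt{n-1}\,D^{-1/2}P^\top$ (a standard alternative that is not part of the original derivation), with $P\,W_{pca}$. Both whiten, but only the second is symmetric. The data are random and hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 400
X = np.array([[2.0, 0.0], [1.5, 0.5]]) @ rng.normal(size=(2, n))  # hypothetical correlated data
X -= X.mean(axis=1, keepdims=True)

lam, P = np.linalg.eigh(X @ X.T)                         # X X^T = P diag(lam) P^T
W_pca = np.sqrt(n - 1) * np.diag(lam ** -0.5) @ P.T      # PCA whitening: rotate, then rescale
W_zca = P @ W_pca                                        # additionally rotate back: symmetric

for name, W in (("PCA", W_pca), ("ZCA", W_zca)):
    Y = W @ X
    white = np.allclose(Y @ Y.T, (n - 1) * np.eye(2))
    symmetric = np.allclose(W, W.T)
    print(name, white, symmetric)
# PCA True False
# ZCA True True
```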

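Substituting $Y = WX$ into the whitening condition gives

$$YY^\top = WXX^\top W^\top = (n-1)\,I.$$

Left-multiplying both sides by $W^\top$ and using $W^\top = W$:

$$W^2 XX^\top W = (n-1)\,W.$$

Since $WXX^\top W = (n-1)I$ has full rank, $W$ is invertible; right-multiplying by $W^{-1}$ yields

$$W^2 XX^\top = (n-1)\,I.$$

Hence $W^2 = (n-1)(XX^\top)^{-1}$, and taking the symmetric positive definite square root,

$$W = \sqrt{n-1}\,(XX^\top)^{-1/2}.$$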
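
A small worked example of this formula, for a hypothetical two-dimensional case in which the data are already uncorrelated, with variances 4 and 9:

```latex
XX^\top = (n-1)\,\mathrm{diag}(4,\ 9)
\;\Rightarrow\;
(XX^\top)^{-1/2} = (n-1)^{-1/2}\,\mathrm{diag}\!\left(\tfrac{1}{2},\ \tfrac{1}{3}\right)
\;\Rightarrow\;
W = \sqrt{n-1}\,(XX^\top)^{-1/2} = \mathrm{diag}\!\left(\tfrac{1}{2},\ \tfrac{1}{3}\right),
\qquad
WXX^\top W = (n-1)\,\mathrm{diag}\!\left(\tfrac{4}{4},\ \tfrac{9}{9}\right) = (n-1)\,I.
```

In this case the transform simply divides each dimension by its standard deviation (2 and 3).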

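Since $XX^\top$ is symmetric positive semidefinite, it can be decomposed as $XX^\top = PDP^\top$, where $D$ is diagonal and $P$ is orthogonal, and then $(XX^\top)^a = P D^a P^\top$. For $a = -1/2$ this requires $XX^\top$ to be positive definite (no zero eigenvalues), so that every diagonal entry of $D$ is nonzero.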
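
The identity $(XX^\top)^a = PD^aP^\top$ can be checked numerically. In this sketch `sym_power` is a hypothetical helper built on `np.linalg.eigh`, and `A` is a random symmetric positive definite matrix standing in for $XX^\top$.

```python
import numpy as np

def sym_power(A, a):
    """A^a for a symmetric positive definite A, computed as P diag(lam^a) P^T."""
    lam, P = np.linalg.eigh(A)          # A = P diag(lam) P^T, P orthogonal
    if np.any(lam <= 0):
        raise ValueError("A must be positive definite for arbitrary real powers")
    return P @ np.diag(lam ** a) @ P.T

rng = np.random.default_rng(2)
B = rng.normal(size=(4, 4))
A = B @ B.T + 4 * np.eye(4)             # hypothetical SPD test matrix

half = sym_power(A, 0.5)
inv_half = sym_power(A, -0.5)
print(np.allclose(half @ half, A))                        # (A^{1/2})^2 = A
print(np.allclose(sym_power(A, -1), np.linalg.inv(A)))    # A^{-1} matches np.linalg.inv
print(np.allclose(inv_half @ A @ inv_half, np.eye(4)))    # A^{-1/2} A A^{-1/2} = I
```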

Therefore

$$W = \sqrt{n-1}\,P D^{-1/2} P^\top.$$
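
A minimal NumPy translation of this formula, assuming centered data and a full-rank $XX^\top$. The mixing matrix used to generate correlated test data is hypothetical. The script checks the three relations from the derivation: $YY^\top = (n-1)I$, $W = W^\top$, and $W^2XX^\top = (n-1)I$.

```python
import numpy as np

def zca_whitening(X):
    """ZCA-whiten a centered d x n matrix X. Returns (Y, W) with Y = W X."""
    d, n = X.shape
    lam, P = np.linalg.eigh(X @ X.T)                        # X X^T = P D P^T
    W = np.sqrt(n - 1) * P @ np.diag(lam ** -0.5) @ P.T     # W = sqrt(n-1) P D^{-1/2} P^T
    return W @ X, W

rng = np.random.default_rng(3)
n = 1000
mix = np.array([[1.0, 0.0, 0.0], [0.8, 0.6, 0.0], [0.3, 0.3, 0.9]])  # hypothetical mixing
X = mix @ rng.normal(size=(3, n))
X -= X.mean(axis=1, keepdims=True)

Y, W = zca_whitening(X)
I = np.eye(3)
print(np.allclose(Y @ Y.T, (n - 1) * I))             # Y Y^T = (n-1) I
print(np.allclose(W, W.T))                           # W is symmetric
print(np.allclose(W @ W @ X @ X.T, (n - 1) * I))     # W^2 X X^T = (n-1) I
```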


MATLAB code

testZCA.m

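clear; clc
% Load an RGB image; every pixel is one 3-dimensional sample (one value per channel)
tx = double(imread('test.jpg'));   % alternative input: load('pcaData.txt','-ascii');
% H x W x 3 -> (H*W) x 3, one row per pixel (same result as an explicit index loop)
x = reshape(tx, [], size(tx, 3));
patches = x';                      % d x n = 3 x (H*W), one column per sample
% The derivation assumes zero-mean data, so center every channel (row)
patches = bsxfun(@minus, patches, mean(patches, 2));
%for i = 1 : size(x,2);
%     im = imread(strcat('train\', num2str(i), '.png'));
%     im = reshape(im, [1, 32*32*3]);
%     im = double(im);
%     % centralize
%     im(1:32*32) = im(1:32*32) - mean(im(1:32*32));
%     im(32*32+1:2*32*32) = im(32*32+1:2*32*32) - mean(im(32*32+1:2*32*32));
%     im(2*32*32+1:3*32*32) = im(2*32*32+1:3*32*32) - mean(im(2*32*32+1:3*32*32));
%     
%     patches = [patches, im'];
%end
n = size(patches, 2);
y = ZCA_whitening(patches);
y = y * y';                        % equals (n-1)*I up to round-off
y = ZCA_normalize(y);
covx = 1/(n-1) * patches * patches';   % covariance before whitening
covx = ZCA_normalize(covx);
subplot(1,2,1)
imshow(uint8(covx), 'InitialMagnification', 'fit')
title('Covariance before whitening')
subplot(1,2,2)
imshow(uint8(y), 'InitialMagnification', 'fit')
title('Covariance after ZCA whitening')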


function y = ZCA_normalize(x)
% Min-max rescale all entries of x to [0, 255] so the matrix can be shown with imshow
y = x - min(x(:));
y = y / max(y(:)) * 255;
end

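%% Assume every d-dimensional sample has zero mean; the covariance matrix of the
%% d x n data matrix X is then
%% covX = 1/(n-1)*X*X'
%% To remove the correlation between dimensions, apply a transform W: Y = W*X
%% Next, solve for the decorrelating matrix W. Since W removes the correlation,
%% Y*Y' is diagonal, so we require W to satisfy
%% Y*Y' = (n-1)*I
%% Many W satisfy this condition, so we additionally assume W = W'
%% The rest of the derivation is straightforward; the only non-obvious step is
%% the orthogonal decomposition X*X' = P*D*P'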

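function y = ZCA_whitening(x)
% ZCA whitening of a centered d x n matrix x: y = W*x with y*y' = (n-1)*I
[~, n] = size(x);
C = x * x';
C = (C + C') / 2;                 % enforce exact symmetry so eig returns an orthogonal P
[P, D] = eig(C);                  % C = P*D*P', D diagonal, P orthogonal
w = sqrt(n-1) * P * diag(1 ./ sqrt(diag(D))) * P';   % W = sqrt(n-1)*P*D^(-1/2)*P'
y = w * x;
end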

Result figure: the dataset is taken from the Tiny Image Classification competition on Kaggle.
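
For readers without MATLAB, here is a rough NumPy sketch of what testZCA.m does, under stated assumptions. test.jpg is not available, so a randomly generated H×W×3 array with correlated channels stands in for the image, and the helper names `zca_whitening` and `to_display_range` are hypothetical. It flattens the image to a 3×(H·W) matrix, centers each channel, whitens, and min-max rescales both covariance matrices to [0, 255] as ZCA_normalize does.

```python
import numpy as np

def zca_whitening(X):
    """ZCA whitening of a centered d x n matrix: returns W X with W = sqrt(n-1) P D^{-1/2} P^T."""
    n = X.shape[1]
    lam, P = np.linalg.eigh(X @ X.T)
    W = np.sqrt(n - 1) * P @ np.diag(lam ** -0.5) @ P.T
    return W @ X

def to_display_range(M):
    """Min-max rescale all entries of M to [0, 255], like ZCA_normalize."""
    M = M - M.min()
    return M / M.max() * 255

# Hypothetical stand-in for imread('test.jpg'): an H x W x 3 image with correlated channels
rng = np.random.default_rng(4)
base = rng.uniform(0, 255, size=(64, 48, 1))
img = np.concatenate([base,
                      0.8 * base + 20 * rng.normal(size=base.shape),
                      0.5 * base + 40 * rng.normal(size=base.shape)], axis=2)

# One row per channel, one column per pixel: 3 x (H*W). The pixel order differs from
# MATLAB's column-major tx(:), which does not change the covariance.
patches = img.reshape(-1, 3).T
patches = patches - patches.mean(axis=1, keepdims=True)   # center each channel

n = patches.shape[1]
Y = zca_whitening(patches)
cov_before = to_display_range(patches @ patches.T / (n - 1))
cov_after = to_display_range(Y @ Y.T)
print(np.round(cov_before))
print(np.round(cov_after))   # diagonal entries 255, off-diagonal entries 0
```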