function K = denoise_daub4_subband_adaptive(I, N, type, mult)
%
% DENOISE_DAUB4_SUBBAND_ADAPTIVE   wavelet denoising with D4 wavelet
%    K = DENOISE_DAUB4_SUBBAND_ADAPTIVE(I, N, type, mult)
%    K = DENOISE_DAUB4_SUBBAND_ADAPTIVE(I, N, type)
%    K = DENOISE_DAUB4_SUBBAND_ADAPTIVE(I, N)
%
%    K: denoised image (same size as I)
%    I: input image; must be a square 2-D matrix whose side length is
%       divisible by 2^N
%    N: # wavelet levels (non-negative integer; N = 0 returns double(I))
%    type: 'H' for hard threshold, 'S' for soft threshold
%          (case-insensitive; any other value raises an error)
%    mult: scalar multiplier
%
% NOTE: if not specified (or passed as []), hard thresholding with a
%       multiplier of 1.0 is assumed. The threshold is subband-adaptive:
%       at each decomposition level, the standard deviation of that
%       level's HH subband serves as an estimate of the noise standard
%       deviation, and is multiplied by mult to obtain the threshold
%       applied to that level's HH, LH and HL detail subbands.
%
%       The transforms use DAUB4_BASIS(sz) and PERMUTATION_MATRIX(sz),
%       which are defined outside this function.
%

I = double(I);

% defaults: mult = 1, hard thresholding ([] also selects the default)
if nargin < 4 || isempty(mult), mult = 1; end
if nargin < 3 || isempty(type), type = 'H'; end

% Resolve the thresholding rule once, up front. Previously an unknown
% type (e.g. 'h' or 'soft') matched neither branch and the denoising
% step was silently skipped; now it is rejected before any transform work.
switch upper(type)
    case 'H'
        thresh = @hard_thresh;
    case 'S'
        thresh = @soft_thresh;
    otherwise
        error('denoise_daub4_subband_adaptive:badType', ...
              'type must be ''H'' (hard) or ''S'' (soft).');
end

% sanity checks on N and on the image geometry the transform relies on
if ~isscalar(N) || ~isreal(N) || N < 0 || N ~= fix(N)
    error('denoise_daub4_subband_adaptive:badLevels', ...
          'N must be a non-negative integer.');
end
[nrows, ncols] = size(I);
if ndims(I) ~= 2 || nrows ~= ncols
    error('denoise_daub4_subband_adaptive:notSquare', ...
          'I must be a square 2-D matrix.');
end
if mod(nrows, 2^N) ~= 0
    error('denoise_daub4_subband_adaptive:badSize', ...
          'The side length of I (%d) must be divisible by 2^N = %d.', ...
          nrows, 2^N);
end
if nrows == 0
    K = I;   % empty image: nothing to transform
    return;
end

% N-level D4 forward wavelet transform on the image I. Level k acts on
% the top-left (nrows/2^(k-1))-square, i.e. the previous level's LL
% subband. Each level's combined matrix P*H is kept for the inverse.
W = cell(1, N);
J = I;
for level = 1:N
    sz = nrows / 2^(level-1);
    W{level} = permutation_matrix(sz) * daub4_basis(sz);
    J(1:sz,1:sz) = W{level} * J(1:sz,1:sz) * W{level}';
end

% subband-adaptive denoising step: level k's detail subbands are the
% top-right (LH), bottom-left (HL) and bottom-right (HH) quadrants of
% its top-left (2*sz)-square
for level = 1:N
    sz = nrows / 2^level;
    lo = 1:sz;
    hi = sz+1:2*sz;
    HH = J(hi, hi);
    % per-level threshold from the HH subband's standard deviation;
    % std(HH(:)) equals std2(HH) but needs no toolbox
    lambda = mult * std(HH(:));
    J(hi, hi) = thresh(HH, lambda);          % HH
    J(lo, hi) = thresh(J(lo, hi), lambda);   % LH
    J(hi, lo) = thresh(J(hi, lo), lambda);   % HL
end

% N-level D4 inverse wavelet transform, coarsest level first. The
% \ and / solves apply inv(P*H) and inv(H'*P') without forming inverses.
K = J;
for level = N:-1:1
    sz = nrows / 2^(level-1);
    K(1:sz,1:sz) = W{level} \ K(1:sz,1:sz) / W{level}';
end

function thresholded = hard_thresh(detail_coefs, T)

% HARD_THRESH   "keep or kill" thresholding of detail coefficients.
%    Entries whose magnitude is strictly greater than T pass through
%    unchanged; every other entry (including |x| == T) becomes zero.
keep = abs(detail_coefs) > T;
thresholded = detail_coefs .* keep;

function thresholded = soft_thresh(detail_coefs, T)

% SOFT_THRESH   shrink detail coefficients toward zero by T.
%
% In the text we define the soft threshold operator as so:
%
%            x-T, if x >= T
% Tsoft(x) = 0,   if |x| < T
%            x+T, if x <= -T
%
% An equivalent means of stating the above is
%
% Tsoft(x) = sgn(x)*max(|x|-T, 0)
%
% i.e. any coefficient with magnitude at most T is set to zero, and every
% other coefficient keeps its sign while its magnitude is reduced by T.

thresholded = sign(detail_coefs) .* max(abs(detail_coefs) - T, 0);