#include "stdafx.h"
#include "MMSEFilter.h"

using namespace std;

void MMSEFilter::processImage(Image8bpp *pNoisyImage, Image8bpp *pProcessedImage)
{
  // sanity check
  if (pNoisyImage->getHeight()!=pProcessedImage->getHeight() ||
	  pNoisyImage->getWidth()!=pProcessedImage->getWidth())
    throw runtime_error("invalid args: images must be same size!");

  int nr=pNoisyImage->getHeight(), 
	  nc=pNoisyImage->getWidth(), 
	  stride=pNoisyImage->getStride();
  float N = (float)m_kernelSize*m_kernelSize;

  // adaptive filtering occurs in 2 steps, first compute local stats
  // and then filter the image based on those states: so here we setup
  // buffers to store these descriptive statistics.
  int nStats = (nr-2*m_khw)*(nc-2*m_khw);
  vector<float> localMeans(nStats);
  // localVariances allocated via Intel library function because later
  // on we're going to use ipps functions to estimate noise power.
  Ipp32f *localVariances = (Ipp32f *)ippsMalloc_32f(nStats);

  // compute localized statistics within the interior of the image
  vector<float>::iterator pLocalMeans = localMeans.begin();
  Ipp32f *pLocalVariances = localVariances;
  for (int iRow=m_khw; iRow<nr-m_khw; ++iRow) {
	Ipp8u *pIn  = pNoisyImage->getPixels() + (iRow-m_khw)*stride;
	for (int iCol=m_khw; iCol<nc-m_khw; ++iCol, pIn++) {

	  // compute local mean and local variance
	  unsigned int sum = 0, sumSquared = 0;
	  for (int iw=0; iw<m_kernelSize; ++iw) {
        for (int ih=0; ih<m_kernelSize; ++ih) {
		  sum += pIn[iw+ih*stride];
		  sumSquared += pIn[iw+ih*stride]*pIn[iw+ih*stride];
		}
	  }
	  float mean = (float)sum/N,
			avgSumSquared = (float)sumSquared/N;
	  Ipp32f variance = avgSumSquared - mean*mean;
	  *pLocalMeans++ = mean;
	  *pLocalVariances++ = variance;

	} // end (for each column)
  } // end (for each row)

  // estimate noise power based on local variances
  Ipp32f noiseVariance;
  if (ippStsNoErr != ippsMean_32f(localVariances, nStats, &noiseVariance, ippAlgHintAccurate))
	  throw runtime_error("Failed to estimate noise variance!");

  // and now filter the image
  pLocalMeans = localMeans.begin();
  pLocalVariances = localVariances;
  for (int iRow=m_khw; iRow<nr-m_khw; ++iRow) {
	Ipp8u *pIn  = pNoisyImage->getPixels() + (iRow-m_khw)*stride,
	      *pOut = pProcessedImage->getPixels() + (iRow)*stride + m_khw;
	for (int iCol=m_khw; iCol<nc-m_khw; ++iCol, pIn++, pOut++) {
	  float localMean = *pLocalMeans++,
		    localVar  = *pLocalVariances++;
	  // This is the heart of the MMSE adaptive filter:
	  // Linearly interpolate between the local mean and the input pixel.
	  // If the local variance is roughly the same as the noise variance,
	  // give greater weight to the mean else tilt towards the input
	  float alpha = noiseVariance/localVar;
	  // pin alpha such that if the noise variance dominates output
	  // pixel contribution comes entirely from the local mean
	  if (alpha > 1.0) 
		  alpha = 1.0;
	  float outPixel = (1-alpha)*(*pIn) + alpha*localMean;

	  // handle out-of-range pixels
	  if (outPixel < 0.f)
		  outPixel = localMean;
	  else if (outPixel > 255.f)
		  outPixel = localMean;
	  *pOut = (Ipp8u)outPixel+0.5; // clamp to 8 bit range
	} // end (for each column)
  } // end (for each row)

  ippsFree(localVariances);
}