#include "Coverage.h"

// Message tags private to this file. kRemoteValueTag labels the point-to-point
// transfer in RemoteValue(); kCondenseTag is not referenced in this part of the file.
enum ECoverageTags { kCondenseTag, kRemoteValueTag };
// Byte value MergeData() writes into every byte of a cell that is not an active point.
const byte kBackGround = 255;

// Memory manager shared by all PointSetCoverage objects; created on first use in Setup().
MemoryManager* PointSetCoverage::fMemMgr = NULL;

#ifdef USE_MPI
// Communicator shared by all PointSetCoverage objects; 0 means it has not yet
// been obtained through Env::GetComm().
MPI_Comm PointSetCoverage::fComm = 0;
#endif

/*
*************************************************************************
*									*
* PointSetCoverage     					        *
*									*
*************************************************************************
*/

// Default constructor: builds an unattached coverage with no point set, no
// value storage, zero objects and a ghost size of 1.
PointSetCoverage::PointSetCoverage() {
  memset(fInfo, 0, 8);
  fPSet = NULL;
  fValue = NULL;
  fNObjects = 0;
  fGhostSize = 1;
#ifdef USE_MPI
  // No persistent requests have been created yet.
  fStatus = NULL;
  fReq = NULL;
  fNReq = 0;
  // fComm is a static member, so this clears it for every instance.
  fComm = 0;
#endif
}

// Constructs a coverage and immediately attaches it to a grid via Setup().
//   pset        - grid the coverage is defined on
//   gSize       - ghost size forwarded to Setup()
//   memoryIndex - memory index forwarded to Setup()
PointSetCoverage::PointSetCoverage( DistributedGrid* pset, int gSize, int memoryIndex ) { 
  memset(fInfo,0,8);
  fValue = NULL; 
  fPSet = NULL;
  // Give these the same defined starting values as the default constructor so
  // that nothing Setup() consults (e.g. Allocated()) reads an indeterminate
  // member, and so the object is well-formed if Setup() returns early.
  fNObjects = 0;
  fGhostSize = 1;
#ifdef USE_MPI
  fReq = NULL;
  fNReq = 0;
  fStatus = NULL;
  // fComm is a static member, so this clears it for every instance.
  fComm = 0;
#endif
  Setup(pset,gSize,memoryIndex); 
}


// Attaches this coverage to the same point set as `cov`, copying its base
// type, ghost size, tag and object count, then takes its storage from the
// shared memory manager (with `cov` passed along to GetVector).
//   cov         - coverage whose configuration is copied
//   memoryIndex - stored in fMemoryIndex
// Returns 0 without doing anything if Allocated() is already true, 1 otherwise.
int PointSetCoverage::Setup( PointSetCoverage& cov, int memoryIndex ) {
  if( Allocated() ) return 0;
  // Create the shared manager on demand, exactly as the
  // Setup(DistributedGrid*,int,int) overload does, instead of assuming some
  // earlier call has already created it.
  if( fMemMgr == NULL ) fMemMgr = new MemoryManager();
  fInfo[kBaseType] = cov.fInfo[kBaseType];
  fMemoryIndex = memoryIndex;
  cov.fPSet->CheckGhostSize(fGhostSize = cov.fGhostSize);
  fPSet = cov.fPSet;
  fTag = cov.fTag;
  fNObjects = cov.fNObjects;

  ShallowCopy( fMemMgr->GetVector( fNObjects, *this, &cov ) );

  return 1;
}

// Attaches this coverage to the grid `pset` and takes its storage from the
// shared memory manager, which is created here on first use.
//   pset        - grid supplying base type, coverage tag and size
//   gSize       - ghost size; stored and passed to pset->CheckGhostSize()
//   memoryIndex - stored in fMemoryIndex
// Returns 0 without doing anything if Allocated() is already true, 1 otherwise.
int PointSetCoverage::Setup( DistributedGrid* pset, int gSize, int memoryIndex ) {
  if( Allocated() ) {
    return 0;
  }

  fMemoryIndex = memoryIndex;
  fInfo[kBaseType] = pset->GetInfo(DGrid::kBaseType);

  fGhostSize = gSize;
  pset->CheckGhostSize(fGhostSize);

  if( fMemMgr == NULL ) {
    fMemMgr = new MemoryManager();
  }

  fTag = pset->GetCoverageTag();
  fNObjects = pset->Size();
  fPSet = pset;

#ifdef USE_MPI
  // Start with no persistent requests; fComm is static and is cleared as well.
  fStatus = NULL;
  fReq = NULL;
  fNReq = 0;
  fComm = 0;
#endif

  ShallowCopy( fMemMgr->GetVector( fNObjects, *this ) );
  return 1;
}

// Attaches this coverage to the grid `pset` and takes its storage from the
// shared memory manager (with `cov` passed along to GetVector).
//   pset        - grid supplying base type, coverage tag and size
//   cov         - coverage handed to the memory manager's GetVector()
//   gSize       - ghost size; stored and passed to pset->CheckGhostSize()
//   memoryIndex - stored in fMemoryIndex
// Returns 0 without doing anything if Allocated() is already true, 1 otherwise.
int PointSetCoverage::Setup( DistributedGrid* pset, PointSetCoverage& cov, int gSize, int memoryIndex ) {
  if( Allocated() ) return 0;
  // Create the shared manager on demand, exactly as the
  // Setup(DistributedGrid*,int,int) overload does, instead of assuming some
  // earlier call has already created it.
  if( fMemMgr == NULL ) fMemMgr = new MemoryManager();
  fInfo[kBaseType] = pset->GetInfo(DGrid::kBaseType);
  fMemoryIndex = memoryIndex;
  pset->CheckGhostSize(fGhostSize = gSize);
  fPSet = pset;
  fTag = pset->GetCoverageTag();
  fNObjects = pset->Size();

  ShallowCopy( fMemMgr->GetVector( fNObjects, *this, &cov ) );

  return 1;
}

// Creates the persistent MPI requests used to exchange ghost cells: for each
// process, one receive for the kL_Input index map and one send for the
// kL_Output index map, whenever that map exists and has non-zero size.
//   do_buffer - non-zero selects MPI_Bsend_init (and grows the buffer size
//               registered through Env::SetMPIBufferSize); zero selects
//               MPI_Send_init
// Always returns 0. Without USE_MPI nothing is set up.
int PointSetCoverage::SetupCommunication(int do_buffer) {

  // Running total of buffered-send bytes across every call to this function.
  static int MPIBuffersize  = 0;
  int buff_size, nrmax = fPSet->NLinks()*2;
#ifdef USE_MPI
  // Requests already exist: building them again would leak the old arrays and
  // keep appending at the old fNReq, writing past the end of the new ones.
  if( fReq != NULL ) return 0;

  if( fComm == 0 )  { 
    if( Env::GetComm(&fComm) ) { 
      gFatal("Error in PointSetCoverage"); 
    }
    SME_Errors_dump_to_debugger( fComm, Env::Program(), Env::Display() ); 
  }
  MPI_Datatype cellMap;
  // NOTE(review): capacity is NLinks()*2; this assumes at most that many
  // requests are created in the loop below -- confirm against DistributedGrid.
  fNReq = 0;
  fReq = new MPI_Request[nrmax];
  fStatus = new MPI_Status[nrmax];
  for(int ip=0; ip<gNProc; ip++) {
    cellMap = fPSet->GetGhostIndexMap( fGhostSize, ip, kL_Input );
    if( cellMap != MPI_DATATYPE_NULL ) {     // post receive
      MPI_Type_size(cellMap,&buff_size);
      if( buff_size > 0 ) {
        MPI_Recv_init(fValue,1,cellMap,ip,fTag,fComm,fReq+(fNReq++));
        if(gDebug) {
          sprintf(gMsgStr,"\nPosted Receive Request, index(%d), Proc(%d), data size = %d, ghostSize = %d",fNReq-1,ip,buff_size,fGhostSize);
          gPrintLog(); 
        }
      }
    }
    cellMap = fPSet->GetGhostIndexMap( fGhostSize, ip, kL_Output );
    if( cellMap != MPI_DATATYPE_NULL ) {     // post send
      MPI_Type_size(cellMap,&buff_size);
      if( buff_size > 0 ) {
        if( do_buffer ) {
          MPIBuffersize += buff_size; 
          Env::SetMPIBufferSize( 4*MPIBuffersize, 8*MPIBuffersize );
          MPI_Bsend_init(fValue,1,cellMap,ip,fTag,fComm,fReq+(fNReq++));
        } else {
          MPI_Send_init(fValue,1,cellMap,ip,fTag,fComm,fReq+(fNReq++));
        }
        if(gDebug) {
          sprintf(gMsgStr,"\nPosted Send Request, index(%d), Proc(%d), data size = %d, ghostSize = %d",fNReq-1,ip,buff_size,fGhostSize);
          gPrintLog();
        }
      }
    }
  }
#endif
  return 0;
}

// Starts every persistent request, creating them first through
// SetupCommunication() when none exist yet.
// Returns the result of MPI_Startall, or 0 when built without USE_MPI.
int PointSetCoverage::StartLinkEdges() {
#ifdef USE_MPI
  if( fReq == NULL ) {
    SetupCommunication();
  }
  return MPI_Startall(fNReq,fReq);
#else
  return 0;
#endif
}

// Performs one complete ghost-cell exchange: starts the persistent requests
// and waits for them, bracketed by the communication timer.
//   name - label used in the debug messages of CompleteLinkEdges()
// Returns the result of CompleteLinkEdges(), or 0 when built without USE_MPI.
int PointSetCoverage::LinkEdges( const char* name ) {
#ifdef USE_MPI
  Env::StartCommTimer();
  StartLinkEdges();
  // Hold the result so the timer can be stopped before returning; it used to
  // be stopped after the return statement and therefore never was.
  const int rv = CompleteLinkEdges( name );
  Env::StopCommTimer();
  return rv;
#else
  return 0;
#endif
}

// Waits for all persistent requests to finish. With gDebug above 2 it reports
// to the screen before and after the wait and logs one line per completed request.
//   name - label inserted into the debug messages
// Returns the result of MPI_Waitall; 0 if there are no requests or the build
// has no USE_MPI.
int PointSetCoverage::CompleteLinkEdges( const char* name ) {
#ifdef USE_MPI
  if( !fReq ) {
    return 0;
  }

  if( gDebug > 2 ) {
    sprintf(gMsgStr,"Link Edges %s: Waiting for completion of %d MPI requests.", name, fNReq );
    gPrintScreen(gMsgStr,TRUE);
  }

  const int waitResult = MPI_Waitall(fNReq,fReq,fStatus);

  if( gDebug > 2 ) {
    sprintf(gMsgStr,"Link Edges %s Done", name );
    gPrintScreen(gMsgStr,TRUE);
    int nElem;
    for( int iReq = 0; iReq < fNReq; iReq++ ) {
      MPI_Status* done = fStatus + iReq;
      MPI_Get_count( done, MPI_FLOAT, &nElem );
      sprintf(gMsgStr,"\nLinkMsg Completed: source(%d), tag(%d), nElem(%d): ",
        done->MPI_SOURCE, done->MPI_TAG, nElem );
      gPrintLog();
    }
  }
  return waitResult;
#else
  return 0;
#endif
}

// Assigns `value` to every point of the point set for which
// subset_contains(point, iset) is non-zero.
void PointSetCoverage::SubSet( byte iset, float value ) {
  Pix cursor = fPSet->first();
  while( cursor ) {
    // Narrowed to byte before testing, as the membership flag is a byte.
    const byte inSubset = fPSet->subset_contains( cursor, iset );
    if( inSubset ) {
      SetValue(cursor,value);
    }
    fPSet->next(cursor);
  }
}

// Assigns from another coverage that may live on a different point set.
// All values are first zeroed; then, when this point set's Scale() is less
// than or equal to the source's, each point takes the value of the source
// point MapPixFast() maps it to; otherwise each point takes
// rhs.AggregateValue() of the matching source point. Points with no
// counterpart in rhs stay at zero.
// Self-assignment is a no-op.
PointSetCoverage&  PointSetCoverage::operator = (PointSetCoverage& rhs)
{
  // Check for self-assignment before clearing: clearing first turned `a = a`
  // into "zero every value".
  if (&rhs == this) return *this;

  fill(0.0);
  DistributedGrid* rPSet = rhs.PointSet();
  if( fPSet->Scale() <= rPSet->Scale() ) {
    for (Pix i = fPSet->first(); i; fPSet->next(i)) {
      Pix rp = rhs.MapPixFast( i, *fPSet);
      if(rp) SetValue(i,rhs(rp));
    }
  } else {
    for (Pix i = fPSet->first(); i; fPSet->next(i)) {
      OrderedPoint& pt = (OrderedPoint&) (*fPSet)(i);
      OrderedPoint* rpt = (OrderedPoint*) rhs.GetPoint(pt);
      if(rpt) SetValue( i, rhs.AggregateValue( *rpt, fPSet->Scale() ) );
    }
  }
  return *this;
}


// Delivers the value stored at `p` to process `recvProc`.
//   p        - point handle; should be non-zero on exactly one process
//   recvProc - rank that needs the value
// With USE_MPI: the process holding `p` returns its own value, after sending
// it to recvProc if that is a different rank; recvProc, if it does not hold
// `p`, returns the received value; every other process returns 0.
// Without USE_MPI: returns Value(p).
float PointSetCoverage::RemoteValue( Pix p, int recvProc ) {    // p should be non-zero on only one proc!
#ifdef USE_MPI
  MPI_Request request;
  MPI_Status status;
  // Defined result for ranks that neither own the point nor receive it.
  float value = 0.0f;
  if( p ) {
    value = Value(p);
    if( gIProc != recvProc ) {
      MPI_Isend( &value, 1, MPI_FLOAT, recvProc, kRemoteValueTag, fComm, &request );
      // `value` lives on this stack frame, so the send must complete before
      // returning; this also releases the request.
      MPI_Wait(&request,&status);
    }
  }
  else if( gIProc == recvProc) {    
    MPI_Irecv( &value, 1, MPI_FLOAT, MPI_ANY_SOURCE, kRemoteValueTag, fComm, &request ); 
    MPI_Wait(&request,&status); 
  }
  return value;
#else
  return Value(p);
#endif
}

// Reduces the coverage values at the points of `set` to a single number.
//   set      - points whose values take part
//   op       - kR_Max, kR_Min, kR_Sum or kR_Ave
//   rootProc - rank that receives the combined result (USE_MPI builds only)
// With USE_MPI the local partial results are combined with MPI_Reduce; the
// result is meaningful on rootProc only, and every other rank returns 0.
// Without USE_MPI the local result is returned.
// For kR_Ave the sum is divided by the number of points; when that number is
// zero the undivided sum (0) is returned.
float PointSetCoverage::Reduce ( OrderedPointSet& set, EReduceOp op, int rootProc ) {
  float ftmp = 0.0; 
  int fp=1;   // true until the first value has seeded max/min
  for (Pix i = set.first(); i; set.next(i)) {
    const OrderedPoint& point = set.GetPoint(i);
    float value = Value(point);
    switch(op) {
    case kR_Max:
      if(fp) { ftmp = value; fp = 0; }
      else { ftmp = ( ftmp < value ) ? value : ftmp; }
      break;
    case kR_Min:
      if(fp) { ftmp = value; fp = 0; }
      else { ftmp = ( ftmp > value ) ? value : ftmp; }
      break;
    case kR_Sum:
      ftmp+= value;
      break;
    case kR_Ave:
      ftmp+= value;
      break;
    }
  }
#ifdef USE_MPI
  if( fComm == 0 )  { 
    if( Env::GetComm(&fComm) ) { 
      gFatal("Error in PointSetCoverage"); 
    }
    SME_Errors_dump_to_debugger(fComm, Env::Program(), Env::Display() ); 
  }
  // MPI_Reduce writes the receive buffer on the root only; start from a
  // defined value so other ranks do not return garbage.
  float rv = 0.0f;
  switch(op) {
  case kR_Max:
    MPI_Reduce(&ftmp,&rv,1,MPI_FLOAT,MPI_MAX,rootProc,fComm);
    break;
  case kR_Min:
    MPI_Reduce(&ftmp,&rv,1,MPI_FLOAT,MPI_MIN,rootProc,fComm);
    break;
  case kR_Sum:
    MPI_Reduce(&ftmp,&rv,1,MPI_FLOAT,MPI_SUM,rootProc,fComm);
    break;
  case kR_Ave: {
    int ctmp = 0, count = set.length();
    MPI_Reduce(&ftmp,&rv,1,MPI_FLOAT,MPI_SUM,rootProc,fComm);
    MPI_Reduce(&count,&ctmp,1,MPI_INT,MPI_SUM,rootProc,fComm);
    if(gIProc == rootProc && ctmp > 0) { rv /= ctmp; }
  } break;
  }
  return rv;
#else
  // The serial build must average too; it used to return the plain sum.
  if( op == kR_Ave ) {
    const int count = set.length();
    if( count > 0 ) { ftmp /= count; }
  }
  return ftmp;
#endif
}

// Returns the value stored at `point`. The `scale` argument is accepted but
// not used by this implementation.
float  PointSetCoverage::AggregateValue( const OrderedPoint& point, int scale ) {
  const float stored = Value(point);
  return stored;
}

// Gathers every process's values into one byte array on `targetProc`.
//   cData      - destination on targetProc; replaced by what MergeData()
//                returns (MergeData allocates it when NULL)
//   targetProc - rank that collects the data
//   bsize      - bytes per value, clamped to [0, sizeof(float)]; below
//                sizeof(float) each value is encoded as (value-o)/s clamped
//                to [0, Util::pow2(8*bsize)-1], otherwise the raw float bytes
//                are copied
//   layout     - kCompressed additionally runs Compress() on a non-NULL cData
//   s, o       - scale and offset used by the integer encoding
// Does nothing when built without USE_MPI.
void PointSetCoverage::Condense( byte*& cData, int targetProc, int bsize, EDataLayout layout, float s, float o )
{
#ifdef USE_MPI
  // Obtain the communicator on demand, as CondenseToMap() and Reduce() do.
  if( fComm == 0 )  { 
    if( Env::GetComm(&fComm) ) { 
      gFatal("Error in PointSetCoverage"); 
    }
    SME_Errors_dump_to_debugger( fComm, Env::Program(), Env::Display() ); 
  }
  bsize = Util::iclamp(bsize,0,sizeof(float));
  int dLen = fNObjects*bsize;
  byte* cDataTmp = new byte[dLen];
  if( dLen > 0 ) {
    if( bsize < static_cast<int>(sizeof(float)) ) {	
      unsigned long itmp, imax = Util::pow2(bsize*8) - 1; 
      byte* dPtr = cDataTmp;
      for(int i=0; i<fNObjects; i++)  {
        itmp = (unsigned long) Util::fclamp(((fValue[i]-o)/s),0.0,imax);
        Util::enc_Nb(&dPtr,itmp,bsize);
      }
    } else {
      for(int i=0; i<fNObjects; i++)  memcpy(cDataTmp+i*bsize,fValue+i,bsize);
    }
  }
  if( gIProc == targetProc ) {
    MPI_Status status;
    cData = MergeData(cData,cDataTmp,targetProc,bsize,fNObjects);
    // Expect one message from each other rank. Counting up to gNProc also
    // ends immediately for a single-process run, where the previous
    // probe-until-count loop never terminated.
    for( int msgCnt = 1; msgCnt < gNProc; msgCnt++ ) {
      int count = 0;
      MPI_Probe(MPI_ANY_SOURCE,fTag,fComm,&status);
      MPI_Get_count(&status,MPI_CHAR,&count);
      const int source = status.MPI_SOURCE;
      // Heap buffer: variable-length arrays are not standard C++ and a large
      // message could exhaust the stack.
      byte* dTemp = new byte[count];
      MPI_Recv(dTemp,count,MPI_CHAR,source,fTag,fComm,&status);
      // bsize may have been clamped to 0; avoid dividing by it.
      const int nRecv = ( bsize > 0 ) ? count/bsize : 0;
      cData = MergeData(cData,dTemp,source,bsize,nRecv);
      delete[] dTemp;
    }
  } else {
    MPI_Send(cDataTmp,dLen,MPI_CHAR,targetProc,fTag,fComm);
  }
  delete[] cDataTmp; 
  // cData is only filled on targetProc; never hand a NULL buffer to Compress().
  if( layout == kCompressed && cData != NULL ) Compress(cData,bsize);
#endif  
}


// Gathers every process's values into `map` on `targetProc`.
//   map        - destination; on targetProc it is re-allocated to the global
//                region if it does not already cover it. Its NBytes(),
//                clamped to [0, sizeof(float)], is the number of bytes per
//                value; below sizeof(float) values are encoded with
//                map.Scale(), otherwise the raw float bytes are copied
//   targetProc - rank that collects the data
// Does nothing when built without USE_MPI.
void PointSetCoverage::CondenseToMap( TMap2& map, int targetProc )
{
#ifdef USE_MPI
  if( fComm == 0 )  { 
    if( Env::GetComm(&fComm) ) { 
      gFatal("Error in PointSetCoverage"); 
    }
    SME_Errors_dump_to_debugger( fComm, Env::Program(), Env::Display() ); 
  }
  int bsize = Util::iclamp(map.NBytes(),0,sizeof(float));
  int dLen = fNObjects*bsize;
  byte* cDataTmp = new byte[dLen];
  if( dLen > 0 ) {
    if( bsize < static_cast<int>(sizeof(float)) ) {	
      byte* dPtr = cDataTmp;
      for(int i=0; i<fNObjects; i++)  {
        unsigned long itmp = map.Scale(fValue[i]);
        Util::enc_Nb(&dPtr,itmp,bsize);
      }
    } else {
      for(int i=0; i<fNObjects; i++)  memcpy(cDataTmp+i*bsize,fValue+i,bsize);
    }
  }
  if( gIProc == targetProc ) {
    Region2& globalRegion = fPSet->GlobalRegion();
    if( ((Region2&)map) != globalRegion) map.ReAlloc(globalRegion);	
    MPI_Status status;
    MergeData(map.Data(),cDataTmp,targetProc,bsize,fNObjects);
    // Expect one message from each other rank. Counting up to gNProc also
    // ends immediately for a single-process run, where the previous
    // probe-until-count loop never terminated.
    for( int msgCnt = 1; msgCnt < gNProc; msgCnt++ ) {
      int count = 0;
      MPI_Probe(MPI_ANY_SOURCE,fTag,fComm,&status);
      MPI_Get_count(&status,MPI_CHAR,&count);
      const int source = status.MPI_SOURCE;
      // Heap buffer: variable-length arrays are not standard C++ and a large
      // message could exhaust the stack.
      byte* dTemp = new byte[count];
      MPI_Recv(dTemp,count,MPI_CHAR,source,fTag,fComm,&status);
      // bsize may have been clamped to 0; avoid dividing by it.
      const int nRecv = ( bsize > 0 ) ? count/bsize : 0;
      MergeData(map.Data(),dTemp,source,bsize,nRecv);
      delete[] dTemp;
    }
  } else {
    MPI_Send(cDataTmp,dLen,MPI_CHAR,targetProc,fTag,fComm);
  }
  delete[] cDataTmp; 

#endif  
}

// Packs the active cells of a global-region byte array to the front of the
// same array, in the order the global region is traversed here.
//   data  - array indexed by globalRegion.bindex(row,col)*bsize; NULL is ignored
//   bsize - bytes per cell
// NOTE(review): this relies on each cell's packed slot never lying after its
// source slot -- confirm against Region2::bindex.
void PointSetCoverage::Compress( byte* data, int bsize ) {
  if( data == NULL ) return;   // nothing to pack
  Region2& globalRegion = fPSet->GlobalRegion();
  ByteGrid* regionMap = fPSet->RegionMap();
  byte mVal; int pCnt = 0, sIndex, dIndex;   
  for(int ir=globalRegion.lower(0); ir<=globalRegion.upper(0); ir+=globalRegion.increment(0) ) {
    for(int ic=globalRegion.lower(1); ic<=globalRegion.upper(1); ic+=globalRegion.increment(1) ) { 
      mVal = regionMap->BValue(ir,ic,0);
      if( fPSet->ActivePoint(mVal) ) {
        sIndex = bsize*pCnt++;                      // packed slot
        dIndex = globalRegion.bindex(ir,ic)*bsize;  // slot in the full grid
        // A cell already in its packed position needs no copy, and memcpy
        // must not be given identical source and destination ranges.
        if( sIndex != dIndex ) memcpy(data+sIndex,data+dIndex,bsize);
      }
    }
  }
}

// inline void my_memset(byte* dest, byte fillVal, int nLoc) {
//   for(int i=0; i<nLoc; i++) dest[i] = fillVal;
// }

// Scatters one process's packed values into a global-region byte array.
//   dest    - global array indexed by globalRegion.bindex(row,col)*bsize;
//             allocated here (uninitialised) when NULL
//   src     - packed values, one per active cell of srcProc's region in
//             traversal order
//   srcProc - rank whose region is traversed
//   bsize   - bytes per value
//   nelem   - number of values available in src
// Inactive cells that are visited are filled with kBackGround. The walk stops
// right after the nelem-th value has been copied, so cells beyond that point
// are left untouched.
// Returns dest (the new array if one was allocated).
byte* PointSetCoverage::MergeData(byte* dest, byte* src, int srcProc, int bsize, int nelem) {
  ByteGrid* regionMap = fPSet->RegionMap();
  Region2& srcRegion = fPSet->Region(srcProc);
  Region2& globalRegion = fPSet->GlobalRegion();
  if(dest==NULL) dest = new byte[globalRegion.nelem()*bsize];
  byte mVal; int pCnt = 0, sIndex, dIndex;   
  for(int ir=srcRegion.lower(0); ir<=srcRegion.upper(0); ir+=srcRegion.increment(0) ) {
    for(int ic=srcRegion.lower(1); ic<=srcRegion.upper(1); ic+=srcRegion.increment(1) ) { 
      mVal = regionMap->BValue(ir,ic,0);
      dIndex = globalRegion.bindex(ir,ic)*bsize;
      if( fPSet->ActivePoint(mVal) ) {
        // Never read past the end of src: with nelem == 0 there is nothing
        // to copy, yet the old code still read bsize bytes per active cell.
        if( pCnt < nelem ) {
          sIndex = bsize*pCnt;
          memcpy(dest+dIndex,src+sIndex,bsize);
          if( ++pCnt == nelem ) return dest;
        }
      } else {
        memset(dest+dIndex,kBackGround,bsize);
      }
    }
  }
  return dest;
}



// Assigns from a byte grid: every point of this coverage's point set takes
// the value `map` reports for that point.
PointSetCoverage&  PointSetCoverage::operator=( ByteGrid& map )
{
  Pix cursor = fPSet->first();
  while( cursor ) {
    OrderedPoint& here = (OrderedPoint&) (*fPSet)(cursor);
    SetValue( cursor, map.Value(here) );
    fPSet->next(cursor);
  }
  return *this;
}


// Writes the coverage into `map`.
// With USE_MPI this simply gathers onto process 0 through CondenseToMap();
// `swap` and `directScale` are not used in that build.
// Without USE_MPI the map is re-allocated to this coverage's region if it
// differs, set to 1 throughout, and then each point's value is stored with
// SetValueWithRescale().
//   swap        - when set, coordinate 0 becomes (extents(0)-1) - coordinate 0
//   directScale - forwarded to SetValueWithRescale()
void  PointSetCoverage::CopyToMap( TMap2& map, Bool swap, Bool directScale )
{
#ifdef USE_MPI
  CondenseToMap( map, 0 );
#else
  Region2 region(Region());
  if( ((Region2&)map) != region) {
    map.ReAlloc(region);
  }

  OrderedPointSet* pSet = PointSet();
  if( !( map.NBytes() > 0 ) ) {
    return;
  }

  map.Set((byte)1);
  const int lastIndex = region.extents(0) - 1;
  Pix cursor = pSet->first();
  while( cursor ) {
    OrderedPoint& source = (OrderedPoint&) (*pSet)(cursor);
    OrderedPoint target(source(0), source(1));
    if (swap) {
      target.elem(0) = lastIndex-target(0);
    }
    map.SetValueWithRescale(&target,Value(cursor),directScale);
    pSet->next(cursor);
  }
#endif
}

	// Loads values from a flat float array.
	//   data   - source values
	//   layOut - only PSCserial copies anything; PSCfullGrid, PSCsubGrid and
	//            any other layout leave the coverage unchanged
	//   size   - number of floats in data; at most fSize of them are copied
	void PointSetCoverage::SetDataValues( float* data, int layOut, int size ) {
		if( layOut != PSCserial ) {
			return;
		}
		const int nCopy = ( size > fSize ) ? fSize : size;
		for( int k = 0; k < nCopy; k++ ) {
			fValue[k] = data[k];
		}
	}
/*
void  PointSetCoverage::CopyToMapRescaled( TMap2& map, float min, float max,
   float umin, float umax, Bool rescale)
{
  Region2 region(Region());
  
  if (((Region2&)map) != region) 
    map.ReAlloc(region);
  OrderedPointSet* pSet = PointSet();
  if( map.NBytes() > 0 ) {

    map.Set('\0');

    int y = region.extents(0) -1;

    float nb = 1.0;
    for (int i = 0; i < map.NBytes(); i++)
      nb *= 256.0;
    nb -= 1.0;
    
    Bool umx = (umin == 0.0 && umax == 0.0);

    if (umx) {
      umin = min;
      umax = max;
    }

    float cf =  nb/(umax - umin);

    for( Pix p = pSet->first(); p; pSet->next(p)) {

      OrderedPoint& point = (OrderedPoint&) (*pSet)(p);
      // TBC swap row/column to x/y ???
      OrderedPoint pt(y-point(0),point(1));

      float v = Value(p);
      if (!umx && (v < umin || v > umax)) 
	v = 0.0;
      else if (rescale) 
	v = (v - umin)*cf;
      map.SetValue(&pt, (long)v);
    }
  } 
}
*/

