; -----------------------------------------------------------------
; Library        NN
; -----------------------------------------------------------------
;   Macro        nn2
; -----------------------------------------------------------------
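;   Description  More complex classification network: a neural
;                network and a linear discriminant rule are fitted to
;                a two-class sample of four Gaussian clusters, and
;                their misclassification rates on an independent test
;                sample are compared.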
; -----------------------------------------------------------------
  ; Display size and the libraries this quantlet needs
  ; (plotting and neural networks).
  setsize(640,480)
  library ("plot")
  library ("nn")
;
  ; Training sample: four clusters of n points each, drawn with
  ; normal(n,2) and shifted by the given mean row vectors, then stacked
  ; with "|". Rows 1..3n (first three clusters) form class 0, rows
  ; 3n+1..4n (cluster around (1,1)) form class 1.
  ; randomize(0) seeds the random generator (presumably so every run
  ; produces the same data).
  randomize(0)
  n  = 100
  xt = normal(n,2)+#(-1,-1)' | normal(n,2)+#(+1,-2)'
  xt = xt | normal(n,2)+#(+4, 0)' | normal(n,2)+#(+1,+1)'
;
  ; Per-row plot attributes: red circles for the 3n class-0 rows,
  ; blue triangles for the n class-1 rows. They are reused later to
  ; mark the test points, which follow the same row layout.
  color  = string("red",1:3*n) | string("blue",1:n)
  symbol = string("circle",1:3*n) | string("triangle",1:n)
  xt     = setmask(xt, color, symbol)
  ; Scatter plot of the training data; plot() provides the display
  ; referenced as plotdisplay below. xl/yl are also reused as axis
  ; labels for the later displays.
  plot(xt)
  xl="x1"
  yl="x2"
  tl="Training Data Set"
  setgopt(plotdisplay,1,1,"title",tl,"xlabel",xl,"ylabel",yl)
;
  ; Training labels: 0 for the first 3n rows (red clusters) and 1 for
  ; the last n rows (blue cluster). matrix(k) is assumed to be a k x 1
  ; column of ones -- confirm against the XploRe reference.
  yt = (matrix(3*n)-1)|matrix(n)
  ; One weight per training observation, all equal (presumably the
  ; case weights expected by nnrnet).
  w  = matrix(4*n)
  ; Fit the network on the training data. The last argument (3) is
  ; presumably the number of hidden units -- confirm against the
  ; nnrnet documentation. nnrinfo reports on the fitted net.
  ; (The former unused "param = 1" assignment has been dropped.)
  net = nnrnet(xt,yt,w,3)
  nnrinfo(net)
;
  ; Independent test sample drawn from the same four clusters and in
  ; the same row order as the training data, so the same labels,
  ; colours and symbols apply row by row.
  x = normal(n,2)+#(-1,-1)' | normal(n,2)+#(+1,-2)'
  x = x | normal(n,2)+#(+4, 0)' | normal(n,2)+#(+1,+1)'
  ; Network output for the test points. It is thresholded at 0.5
  ; below, so it is treated as the probability of class 1.
  pred  = nnrpredict(x, net)
  prob  = pred.result
;
  y  = (matrix(3*n)-1)|matrix(n) ; true class labels (0 = red, 1 = blue)
  yp = prob > 0.5                ; predicted label (1 if prob > 0.5)
  misc = paf(1:4*n,y!=yp)        ; row indices of misclassified points
  good = paf(1:4*n,y==yp)        ; row indices of correctly classified points
  ; NOTE(review): if no point is misclassified, misc is empty and the
  ; next lines (1:nm, x[misc]) may fail -- confirm paf's behaviour.
  nm = rows(misc)
  ; Misclassified points get a "fill" prefix on their symbol (filled
  ; marker) and the "huge" size so they stand out in the plot.
  sm = string("fill",1:nm)+symbol[misc]
  xm = setmask(x[misc],color[misc],sm,"huge")
  xg = setmask(x[good],color[good],symbol[good])
;
  pm = 100*nm/(4*n)                ; percentage of misclassified test points
  spm = string("%1.2f",pm)+"%"
  ; New display with correct and misclassified points; the title
  ; reports the network's test misclassification rate.
  Network = createdisplay(1,1)
  show(Network,1,1,xg,xm)
  tl="Network: misclassified = "+spm
  setgopt(Network,1,1,"title",tl,"xlabel",xl,"ylabel",yl)
;
  ; Linear discriminant benchmark fitted on the training data:
  ;   mu0, mu1 = class means (class 0 = rows 1..3n, class 1 = rows
  ;              3n+1..4n), mu = their midpoint (decision threshold),
  ;   lin      = inv(Cov) * (mu0 - mu1)' (discriminant direction).
  ; cov(xt) is the covariance of all training points, not the pooled
  ; within-class covariance. For two classes both give the same
  ; direction up to a positive factor, so the decisions are the same.
  ; The midpoint threshold ignores the unequal class sizes (3n vs n).
  mu0 = mean(xt[1:3*n])
  mu1 = mean(xt[3*n+1:4*n])
  mu  = (mu0+mu1)/2
  lin = inv(cov(xt))*(mu0-mu1)'
;
  y  = (matrix(3*n)-1)|matrix(n) ; true class labels (0 = red, 1 = blue)
  yp = (x-mu)*lin<=0             ; predicted: 1 on mu1's side of the boundary
  misc = paf(1:4*n,y!=yp)        ; row indices of misclassified points
  good = paf(1:4*n,y==yp)        ; row indices of correctly classified points
  nm = rows(misc)
  ; Filled, huge markers for the misclassified test points.
  sm = string("fill",1:nm)+symbol[misc]
  xm = setmask(x[misc],color[misc],sm,"huge")
  xg = setmask(x[good],color[good],symbol[good])
  ; (The former dead assignment "x = setmask(x, color, symbol)" was
  ; dropped: x is not used after this point.)
;
  pm = 100*nm/(4*n)            ; percentage of misclassified test points
  spm = string("%1.2f",pm)+"%"
  ; New display; the title reports the linear rule's test
  ; misclassification rate for comparison with the network.
  Discrim = createdisplay(1,1)
  show(Discrim,1,1,xg,xm)
  tl="Linear: misclassified = "+spm
  setgopt(Discrim,1,1,"title",tl,"xlabel",xl,"ylabel",yl)
;
  ; Write the three displays (training data, network result, linear
  ; discriminant result) to files, presumably as PostScript given the
  ; .ps names.
  print(plotdisplay,"nndata2.ps")
  print(Network,"nnnet2.ps")
  print(Discrim,"nndis2.ps")
  


