package javanet.udp;

import java.io.*;
import java.net.*;
import java.text.*;
import java.util.*;

/**
 * Reliable DatagramSocket. Extends DatagramSocket, and adds
 * the 'sendReceive' method, for reliably sending a DatagramPacket
 * and getting its response, subject to an adaptive timeout and exponential backoff.
 *<p>
 * The underlying assumption here is that the server acknowledges
 * every DatagramPacket with another DatagramPacket.
 * Whether the response is appropriate to the transmission, e.g. indicates success
 * or failure, is up to the calling application.
 *<p>
 * The socket maintains a transmission sequence number which a server is expected
 * to duplicate in the reply. This uses up 8 bytes of the datagram.
 *<p>
 * N.B.: as with all datagram servers, the server at the other end of this
 * must be idempotent: it must be able to cope with retransmissions of the same
 * request without introducing errors into the application domain,
 * e.g. not crediting a bank account twice. To accomplish this the server
 * must keep track of sequence numbers already processed and acknowledge but
 * ignore duplicates.
 *
 * From 'C' code in W. Richard Stevens: <u>Unix Network Programming</u>
 * Vol I (Prentice Hall 1998) pp. 542ff;
 * this also appears in a mathematically and algorithmically
 * more complicated form in the first edition (Prentice Hall 1990), pp. 407ff.
 *
 * See also Van Jacobson, <u>Congestion Avoidance and Control</u>,
 * in Computer Communications Review, vol. 18. no. 4, pp. 314-329,
 * Proceedings of the ACM SIGCOMM '88 Workshop, August, 1988,
 * which "describes modifications made to the 4.3BSD TCP/IP software that provide:
 * slow start, improved round-trip timing, and congestion avoidance".
 *
 * @author Van Jacobson (original paper)
 * @author W. Richard Stevens (C version)
 * @author Esmond Pitt, esmond.pitt@rmiproxy.com (Java version)
 *
 * @version $Revision: 2 $
 */
public class ReliableDatagramSocket
	extends DatagramSocket
{
	/** Minimum retransmit timeout value in seconds	*/
	public static final int	RXTMIN	= 1;	// Stevens has 2 seconds
	/**
	 * Maximum retransmit timeout value in seconds
	 * Stevens had 120; this seems pretty high. 64 agrees with the 4.3BSD TCP
	 * implementation - EJP
	 */
	public static final int	RXTMAX	= 64;
	/** Maximum retransmissions per datagram	*/
	public static final int	MAX_RETRANSMISSIONS	= 3;	// Stevens has 4: suggest 3 - EJP
	/**
	 * Exponential backoff factor.
	 * Stevens and Van Jacobson and 4.3BSD and everybody else
	 * have 2 here.
	 * In fact any multiplier > 1 also gives exponential backoff.
	 * You could use e.g. the square root of 2 which would
	 * slow down twice as slowly as 2, if you see what I mean. EJP
	 */
	public static final double	EXP_BACKOFF_FACTOR	= 2.0;	// Math.sqrt(2.0);

	/**
	 * Round-trip-time estimator and retransmit-timeout calculator.
	 *<p>
	 * Stevens states that "one of these [is] required per socket being timed".
	 *<p>
	 * In fact one is required for each connection, i.e. each
	 * source/destination socket pair; otherwise RTT data from a conversation with
	 * server A is used irrelevantly if a subsequent conversation is started
	 * with server B via the same local DatagramSocket.
	 *<p>
	 * I've overcome the contradiction by making the 'sendReceive' method synchronized,
	 * so that you can't execute more than one send/receive exchange at a time per DatagramSocket.
	 *<p>
	 * You could also require that the DatagramSocket be connected to the destination.
	 * You might think that this slightly defeats the purpose of using UDP, which is partly
	 * to avoid the 3-way connect and 4-way close handshakes, but in fact connect()
	 * on a UDP socket is not a network operation at all, it just conditions the local API.
	 * Java DatagramSockets carry this to such an extent that they don't even call
	 * the underlying connect(), they just mimic its actions at the Java level.
	 *<p>
	 * All times held by this class are in seconds.
	 */
	static class RoundTripTimer
	{
		// Initial values from Stevens's init() method.
		// N.B. declaration order matters: currentTimeout's initializer reads
		// smoothedTripTime and deviation, so they must be declared first.
		/** most recent round-trip time in seconds	*/
		float	roundTripTime		= 0.0f;
		/** smoothed round-trip time in seconds	*/
		float	smoothedTripTime	= 0.0f;
		/** smoothed mean deviation, seconds	*/
		float	deviation			= 0.75f;
		/** # times retransmitted: 0,1,2,....	*/
		short	retransmissions		= 0;
		/** current retransmit timeout in seconds	*/
		float	currentTimeout		= minmax(calculateRTO());

		/**
		 * Computes the raw retransmit timeout from the current estimators:
		 * smoothed RTT plus four times the smoothed mean deviation.
		 * The result keeps its fractional part; truncating it to whole
		 * seconds would discard most of what the estimators have learned.
		 *
		 * @return unclamped retransmit timeout in seconds
		 */
		private float	calculateRTO()
		{
			return (float)(smoothedTripTime+4.0*deviation);
		}

		/**
		 * Clamps a retransmit timeout into the range [RXTMIN, RXTMAX].
		 *
		 * @param rto candidate timeout in seconds
		 * @return the clamped timeout in seconds
		 */
		private float	minmax(float rto)
		{
			return Math.min(Math.max(rto,RXTMIN),RXTMAX);
		}

		/**
		 * Called to initialize everything for a given 'connection'
		 * The difference between init() and newPacket() is that the
		 * former knows nothing about the connection, while the latter
		 * makes use of previous RTT information for an existing connection.
		 * Note that this does not touch the retransmission counter;
		 * newPacket() resets that.
		 */
		void	init()
		{
			roundTripTime		= 0.0f;
			smoothedTripTime	= 0.0f;
			deviation			= 0.75f;
			currentTimeout		= minmax(calculateRTO());
		}

		/**
		 * Called before each new packet is transmitted on a 'connection'.
		 * Initializes retransmit counter to 0.
		 * The difference between init() and newPacket() is that the
		 * former knows nothing about the connection, while the latter
		 * makes use of previous RTT information for an existing connection.
		 */
		void	newPacket()
		{
			retransmissions = 0;
		}

		/**
		 * Called before each packet is either transmitted or retransmitted.
		 * This implementation keeps no clock of its own: the caller measures
		 * the elapsed time and passes it to stop().
		 *
		 * @return timeout value for this round trip, in seconds
		 */
		float	start()
		{
			return currentTimeout;
		}

		/**
		 * Called after a response has been received.
		 * Records the round-trip time of the exchange and updates the
		 * smoothed RTT, the smoothed mean deviation and the retransmit timeout.
		 *
		 * @param ms Milliseconds taken by the exchange, as measured by the caller.
		 *	A negative value (e.g. the wall clock was set back) is treated as zero.
		 */
		void	stop(long ms)
		{
			/*
			 * Calculate the round-trip time (RTT) for this packet.
			 * The division must be done in floating point: integer division
			 * would record every sub-second round trip as zero.
			 */
			roundTripTime = Math.max(ms,0L)/1000.0f;
			/*
			 * Update our estimators of RTT and mean deviation of RTT.
			 * See Jacobson's SIGCOMM '88 paper, Appendix A, for the details.
			 * This appendix also contains a fixed-point, integer implementation
			 * (that is actually used in all the post-4.3 TCP code).
			 * We'll use floating point here for simplicity.
			 *
			 * First
			 *		delta = (rtt - old_srtt) = difference between this measured value
			 *							and current estimator.
			 * and
			 *		new_srtt = old_srtt*7/8 + rtt/8;
			 * Then
			 *		new_srtt = old_srtt + delta/8;
			 *
			 * Also
			 *		new_rttdev = old_rttdev + (|delta| - old_rttdev)/4.
			 *
			 * The above comments are from Stevens. The second
			 * and third formulas (assignments to new_srtt) express
			 * a tautology rather than an algorithm: the first of the
			 * two specifies the algorithm and the second gives the implementation.
			 *	EJP February 1999.
			 */
			double	delta = roundTripTime-smoothedTripTime;
			smoothedTripTime += delta/8.0;
			deviation += (Math.abs(delta)-deviation)/4.0;	/* |delta|	*/
			currentTimeout = minmax(calculateRTO());
		}

		/**
		 * Called after a timeout has occurred.
		 * Backs off the retransmit timeout by EXP_BACKOFF_FACTOR, keeping it
		 * within [RXTMIN, RXTMAX], counts the retransmission and
		 * tells you whether to retransmit or give up.
		 *
		 * @return true iff it's time to give up, i.e. more than
		 *	MAX_RETRANSMISSIONS timeouts have occurred since newPacket().
		 */
		boolean	timeout()
		{
			currentTimeout = minmax((float)(currentTimeout*EXP_BACKOFF_FACTOR));	// next RTO
			retransmissions++;
			return retransmissions > MAX_RETRANSMISSIONS;
		}

		/**
		 * Message format for debugging output of the four timer values.
		 * Not used by toString(), which uses plain concatenation.
		 */
		static final MessageFormat	format = new MessageFormat
			(
			// Format floats
			"rtt={0,number,##0.###} srtt={1,number,##0.###} deviation={2,number,#0.###} rto={3,number,##0.###}"
			);

		/**
		 * Display debugging information.
		 *
		 * @return the RTT, smoothed RTT, deviation and current timeout, in seconds
		 */
		public String	toString()
		{
			return "rtt="+roundTripTime+" srtt="+smoothedTripTime+" dev="+deviation+" rto="+currentTimeout;
		}

	}

	/** RTT statistics for the current 'connection'; replaced by init(). */
	/*package*/ RoundTripTimer	roundTripTimer = new RoundTripTimer();
	/** True when the statistics must be reset before the next exchange (set when an exchange gives up). */
	private boolean	reinit	= false;	// because we have already initialized
	/** Sequence number that the next send() will write. */
	private long	sendSequenceNo	= 0;
	/** Sequence number read from the most recently received datagram. */
	private long	recvSequenceNo	= 0;

	/**
	 * Default constructor: use default port and interface (ANY)
	 * @exception SocketException can't create the socket
	 */
	public ReliableDatagramSocket()
		throws SocketException
	{
		super();
	}

	/**
	 * @param port Local port: use default interface (ANY)
	 * @exception SocketException can't create the socket
	 */
	public ReliableDatagramSocket(int port)
		throws SocketException
	{
		super(port);
	}

	/**
	 * @param port Local port
	 * @param localAddr local interface address to use
	 * @exception SocketException can't create the socket
	 */
	public ReliableDatagramSocket(int port,InetAddress localAddr)
		throws SocketException
	{
		super(port,localAddr);
		// No explicit init(): the field initializer has already created the timer,
		// exactly as for the other constructors.
	}

	/**
	 * Override for connect(). Calls DatagramSocket.connect()
	 * and also initializes statistics for the connection.
	 *
	 * @param dest Destination address
	 * @param port Destination port
	 */
	public void	connect(InetAddress dest,int port)
	{
		super.connect(dest,port);
		init();
	}

	/**
	 * Discards the RTT statistics gathered so far and starts afresh.
	 */
	private void	init()
	{
		this.roundTripTimer = new RoundTripTimer();
	}

	/**
	 * Send and receive reliably, retrying adaptively with
	 * exponential backoff until the response is received
	 * or timeout occurs.
	 *<p>
	 * The request is (re)transmitted with the same sequence number until a
	 * datagram carrying that sequence number is received. Datagrams carrying
	 * any other sequence number are discarded.
	 *<p>
	 * The round-trip sample fed to the timer is measured from before the first
	 * transmission until the matching reply arrives, so it includes any time
	 * spent waiting for retransmissions.
	 *<p>
	 * N.B.: this method sets SO_TIMEOUT on the socket and does not restore it.
	 *
	 * @param sendPacket data to send
	 * @param recvPacket data to receive; its buffer must have room for the
	 *	8-byte sequence number as well as the reply data. See receive()
	 *	for how the packet is altered.
	 * @exception InterruptedIOException no matching reply was received after
	 *	MAX_RETRANSMISSIONS retransmissions
	 * @exception IOException on most any error
	 */
	public synchronized void	sendReceive(DatagramPacket sendPacket,DatagramPacket recvPacket)
		throws IOException
	{
		// only initialize first time or after timeout
		if (reinit)
		{
			init();
			reinit = false;
		}
		roundTripTimer.newPacket();
		// Remember the caller's receive buffer. receive() replaces the packet's
		// buffer with a payload-sized copy, so if a reply has to be discarded
		// the original buffer must be put back before receiving again.
		byte[]	recvBuffer = recvPacket.getData();
		int	recvOffset = recvPacket.getOffset();
		int	recvLength = recvPacket.getLength();
		long	start = System.currentTimeMillis();
		long	seqno = getSendSequenceNo();
		for (;;)
		{
			// always use same seqno while retrying
			setSendSequenceNo(seqno);
			send(sendPacket);	// may throw
			int	timeout = (int)(roundTripTimer.start()*1000.0+0.5);
			long	soTimeoutStart = System.currentTimeMillis();
			try
			{
				for (;;)
				{
					// Adjust timeout for time already elapsed
					long	remaining = timeout-(System.currentTimeMillis()-soTimeoutStart);
					// setSoTimeout(0) would block forever and a negative value is
					// illegal, so an exhausted timeout is reported as a timeout here.
					if (remaining <= 0)
						throw new SocketTimeoutException("Receive timed out");
					// never wait longer than the full timeout, even if the clock steps back
					setSoTimeout((int)Math.min(remaining,(long)timeout));
					receive(recvPacket);
					if (getRecvSequenceNo() == seqno)
						break;
					// Not the reply we are waiting for: discard it and
					// restore the caller's buffer for the next attempt.
					recvPacket.setData(recvBuffer,recvOffset,recvLength);
				}
				// Got the correct reply: exit the retransmit loop
				break;
			}
			catch (InterruptedIOException e)
			{
				// timeout: retransmit unless the retry limit has been reached
				if (roundTripTimer.timeout())
				{
					reinit = true;
					throw e;
				}
			}
			// may throw other SocketException or IOException
		}	// end retransmit loop
		long	ms = System.currentTimeMillis()-start;
		roundTripTimer.stop(ms);	// stop timer, calculate new RTT values
	}

	/**
	 * @return the sequence number that the next call to send() will write
	 */
	private long		getSendSequenceNo()						{ return sendSequenceNo;	}

	/**
	 * Set the next send sequence number. Used by servers to set the reply seqno.
	 * @param sendSequenceNo Next sequence number to send.
	 */
	public void	setSendSequenceNo(long sendSequenceNo)	{ this.sendSequenceNo = sendSequenceNo;	}

	/**
	 * @return the last received sequence number; used by servers to obtain the seqno for the reply.
	 */
	public long		getRecvSequenceNo()						{ return recvSequenceNo;	}

	/**
	 * Receives a datagram and strips the 8-byte sequence number from the front of it.
	 * The sequence number is then available from getRecvSequenceNo().
	 *<p>
	 * N.B.: on return the packet's buffer has been <em>replaced</em> by a new array
	 * holding exactly the remaining data, at offset 0. A packet which is to be reused
	 * for another receive must therefore be given a full-sized buffer again
	 * (DatagramPacket.setData()) beforehand.
	 *
	 * @param packet packet to receive into; its buffer must have room for the
	 *	8-byte sequence number as well as the data
	 * @exception EOFException the datagram received was shorter than 8 bytes
	 * @exception IOException as for DatagramSocket.receive()
	 */
	public void	receive(DatagramPacket packet)
		throws IOException
	{
		super.receive(packet);
		ByteArrayInputStream	bais = new ByteArrayInputStream(packet.getData(),packet.getOffset(),packet.getLength());
		DataInputStream	dis = new DataInputStream(bais);
		recvSequenceNo = dis.readLong();
		// whatever follows the sequence number is the application's data
		byte[]	buffer = new byte[dis.available()];
		dis.readFully(buffer);
		packet.setData(buffer,0,buffer.length);
	}

	/**
	 * Sends the packet's data prefixed by the current send sequence number (8 bytes),
	 * then increments the send sequence number. The caller's packet is not modified.
	 *
	 * @param packet data to send. If it carries no destination address, none is set
	 *	on the datagram actually sent, so that a connected socket can supply its own.
	 * @exception IOException as for DatagramSocket.send()
	 */
	public void	send(DatagramPacket packet)
		throws IOException
	{
		ByteArrayOutputStream	baos = new ByteArrayOutputStream();
		DataOutputStream		dos = new DataOutputStream(baos);
		// Write the seqno then the original packet data
		dos.writeLong(sendSequenceNo++);
		dos.write(packet.getData(),packet.getOffset(),packet.getLength());
		dos.flush();
		byte[]	data = baos.toByteArray();
		DatagramPacket	localPacket = new DatagramPacket(data,data.length);
		// Only copy the destination if there is one: a packet without an address
		// reports port -1, which the address-taking DatagramPacket constructor rejects.
		if (packet.getAddress() != null)
		{
			localPacket.setAddress(packet.getAddress());
			localPacket.setPort(packet.getPort());
		}
		super.send(localPacket);
	}
}
