// Copyright 1999, 2002 Robert Buff
// Contact: http://robertbuff.com/uvm
//
// This file is part of Mtg-Book.
//
// Mtg-Book is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published
// by the Free Software Foundation; either version 2 of the License,
// or (at your option) any later version.
//
// Mtg-Book is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Mtg-Book; if not, write to the 
//
// Free Software Foundation, Inc.
// 59 Temple Place, Suite 330
// Boston, MA 02111-1307
// USA

#include "MtgIncl.h"
#include "MtgSocket.h"

#if defined(_WIN32)
    #include <process.h>
    #include <winsock2.h>
#elif defined(__GNUC__)
    #include <sys/socket.h>
    #include <sys/types.h>
    #include <sys/wait.h>
    #include <netinet/in.h>
    #include <arpa/inet.h>
    #include <netdb.h>
#endif

MTG_BEGIN_NAMESPACE


//
//   i n i t
//

// Puts a freshly constructed object into the "not open, no buffer" state.
// The raw members are zeroed first so that reset() (which closes the
// descriptor and releases the buffer) finds nothing to release.
void tSocket::init()

{
    m_sBuffer = 0;
    m_nSize = 0;
    m_nReceived = 0;
    m_nNextChar = (size_t) -1;    // sentinel: no read position yet
    m_nFd = 0;                    // 0 is this class's "not open" value

    reset();
}


//
//   t h r e a d S t u b
//

// Worker entry point used by multiServerAccept(): runs the task on the
// accepted connection, then frees the connection (closing it) and the
// tThread record.  Takes ownership of pThread and pThread->m_pSocket.
void tSocket::threadStub( tThread* pThread )

{
    MTG_ASSERT( pThread != 0 );

    tThread& T = *pThread;

    MTG_TRACE( "Started thread %d\n", T.m_nId );
    T.m_pServer->run( *T.m_pSocket );
    delete T.m_pSocket;
    MTG_TRACE( "Finished thread %d\n", T.m_nId );
    delete pThread;
}


//
//   g e t M y A d d r
//

// Queries the local endpoint of the open descriptor and caches it:
// m_nMyPort in host byte order, m_nMyNetAddr as stored in sin_addr
// (network byte order).  Returns OPEN_ERROR if getsockname() fails.
tRetCode tSocket::getMyAddr()

{
    MTG_ASSERT( m_nFd > 0 );

    // getsockname() takes an int* under Winsock but a socklen_t* under
    // POSIX; passing an int* there is a hard compile error in C++ on
    // platforms such as Linux where socklen_t is unsigned.
#if defined(_WIN32)
    int nSize;
#else
    socklen_t nSize;
#endif
    struct sockaddr_in S;

    nSize = sizeof(S);
    if( getsockname( m_nFd, (struct sockaddr*) &S, &nSize ) != 0 )
        return OPEN_ERROR;

    m_nMyPort = ntohs( S.sin_port );
    m_nMyNetAddr = S.sin_addr.s_addr;

    return OK;
}


//
//   g e t P e e r A d d r
//

// Queries the remote endpoint of the connected descriptor and caches it:
// m_nPeerPort in host byte order, m_nPeerNetAddr as stored in sin_addr
// (network byte order).  Returns OPEN_ERROR if getpeername() fails.
tRetCode tSocket::getPeerAddr()

{
    MTG_ASSERT( m_nFd > 0 );

    // getpeername() takes an int* under Winsock but a socklen_t* under
    // POSIX; passing an int* there is a hard compile error in C++ on
    // platforms such as Linux where socklen_t is unsigned.
#if defined(_WIN32)
    int nSize;
#else
    socklen_t nSize;
#endif
    struct sockaddr_in S;

    nSize = sizeof(S);
    if( getpeername( m_nFd, (struct sockaddr*) &S, &nSize ) != 0 )
        return OPEN_ERROR;

    m_nPeerPort = ntohs( S.sin_port );
    m_nPeerNetAddr = S.sin_addr.s_addr;

    return OK;
}


//
//   g e t A d d r
//

// Refreshes both the cached local and peer addresses of the open
// descriptor.  Stops at, and returns, the first failure.
tRetCode tSocket::getAddr()

{
    tRetCode nRet = getMyAddr();

    if( nRet == OK )
        nRet = getPeerAddr();
    return nRet;
}


//
//   t S o c k e t
//

// Default constructor: an unopened socket with no receive buffer.
tSocket::tSocket()

{
    this->init();
}


//
//   t S o c k e t
//

// Copy constructor: establishes the clean state via init(), then copies
// Sock's state through the assignment operator.
// NOTE(review): the descriptor number itself is copied (copyFrom assigns
// m_nFd), so both objects refer to the same OS socket and both
// destructors will try to close it.
tSocket::tSocket( const tSocket& Sock )

{
    init();
    *this = Sock;
}


//
//   ~ t S o c k e t
//

// Destructor: closes the descriptor (if open) and releases the buffer.
tSocket::~tSocket()

{
    this->reset();
}


//
//   r e s e t
//

// Returns the object to its default state: descriptor closed, receive
// buffer released, cached addresses cleared, and the read timeout set
// back to 60 (seconds, as read() passes it to select()).
void tSocket::reset()

{
    close();
    setSize( 0 );

    m_nMyNetAddr = 0;
    m_nPeerNetAddr = 0;
    m_nMyPort = 0;
    m_nPeerPort = 0;

    m_nTimeout = 60;
}


//
//   o p e r a t o r =
//

// Copy assignment; self-assignment is a no-op.
tSocket& tSocket::operator=( const tSocket& Sock )

{
    if( this == &Sock )
        return *this;

    copyFrom( Sock );
    return *this;
}


//
//   c o p y F r o m
//

void tSocket::copyFrom( const tSocket& Sock )

{
    if( &Sock == this )
        return;

    reset();

    m_nFd = Sock.m_nFd;
    m_nMyNetAddr = Sock.m_nMyNetAddr;
    m_nMyPort = Sock.m_nMyPort;
    m_nPeerNetAddr = Sock.m_nPeerNetAddr;
    m_nPeerPort = Sock.m_nPeerPort;
    m_nTimeout = Sock.m_nTimeout;
    m_nReceived = Sock.m_nReceived;
    m_nNextChar = Sock.m_nNextChar;

    setSize( Sock.m_nSize );
}


//
//   s e t T i m e o u t
//

void tSocket::setTimeout( time_t nTimeout )

{
    m_nTimeout = nTimeout;
}


//
//   s e t S i z e
//

// (Re)allocates the receive buffer to hold nSize bytes; 0 releases it.
// Any previously buffered data is discarded (m_nReceived becomes 0).
void tSocket::setSize( size_t nSize )

{
    // The buffer is allocated with new[] below, so it must be released
    // with delete[]; plain delete on an array is undefined behavior.
    if( m_nSize > 0 )
        delete[] m_sBuffer;

    // Leave a consistent "no buffer" state before allocating, so that a
    // throwing new[] does not leave a dangling pointer behind.
    m_sBuffer = 0;
    m_nSize = 0;
    m_nReceived = 0;

    if( nSize > 0 ) {
        m_sBuffer = new char[nSize];
        m_nSize = nSize;
    }
}


//
//   c o n n e c t
//

// Opens a TCP socket and connects it to nNetAddr (stored as-is into
// sin_addr, i.e. network byte order) on nPort (host byte order).  On
// success the local and peer addresses are cached via getAddr().
// Returns OPEN_ERROR, CONNECT_ERROR, or getAddr()'s error on failure; in
// every failure case after socket() succeeds the descriptor is closed.
tRetCode tSocket::connect( unsigned long nNetAddr, int nPort )

{
    MTG_ASSERT( m_nFd <= 0 );

    struct sockaddr_in Peer;

    memset( &Peer, 0, sizeof(Peer) );
    Peer.sin_family = AF_INET;
    Peer.sin_port = htons( nPort );
    Peer.sin_addr.s_addr = nNetAddr;

    m_nFd = socket( AF_INET, SOCK_STREAM, 0 );
    if( m_nFd < 0 ) {
        m_nFd = 0;
        return OPEN_ERROR;
    }

#if defined(_WIN32)
    // Enable SO_LINGER with a 2-second linger time.
    struct linger Linger;

    Linger.l_onoff = 1;
    Linger.l_linger = 2;
    if( setsockopt( m_nFd, SOL_SOCKET, SO_LINGER,
            (const char*) &Linger, sizeof(Linger) ) != 0 ) {
        closesocket( m_nFd );
        m_nFd = 0;
        return OPEN_ERROR;
    }
#endif

    if( ::connect( m_nFd, (struct sockaddr*) &Peer, sizeof(Peer) ) < 0 ) {
        close();
        return CONNECT_ERROR;
    }

    tRetCode nRet = getAddr();

    if( nRet != OK )
        close();
    return nRet;
}


//
//   c o n n e c t
//

// Resolves sHost (dotted-decimal string or host name) with host2Addr()
// and connects to the resulting address on nPort.
tRetCode tSocket::connect( const char* sHost, int nPort )

{
    unsigned long nNetAddr;
    tRetCode nRet = host2Addr( sHost, nNetAddr );

    if( nRet != OK )
        return nRet;
    return connect( nNetAddr, nPort );
}


//
//   l i s t e n
//

// Creates a TCP socket bound to nPort on all local interfaces
// (INADDR_ANY) and puts it into listening mode with a backlog of
// nOpenConnections.  On success the actual local address is re-read via
// getMyAddr().  Returns OPEN_ERROR, BIND_ERROR, LISTEN_ERROR, or
// getMyAddr()'s error; the descriptor is closed on every failure after
// socket() succeeds.
tRetCode tSocket::listen( int nPort, int nOpenConnections )

{
    MTG_ASSERT( m_nFd <= 0 );

    struct sockaddr_in Local;

    m_nMyPort = nPort;
    m_nMyNetAddr = htonl( INADDR_ANY );

    memset( &Local, 0, sizeof(Local) );
    Local.sin_family = AF_INET;
    Local.sin_port = htons( m_nMyPort );
    Local.sin_addr.s_addr = m_nMyNetAddr;

    m_nFd = socket( AF_INET, SOCK_STREAM, 0 );
    if( m_nFd < 0 ) {
        m_nFd = 0;
        return OPEN_ERROR;
    }

    tRetCode nRet;

    if( bind( m_nFd, (struct sockaddr*) &Local, sizeof(Local) ) < 0 )
        nRet = BIND_ERROR;
    else if( ::listen( m_nFd, nOpenConnections ) < 0 )
        nRet = LISTEN_ERROR;
    else
        nRet = getMyAddr();

    if( nRet != OK )
        close();
    return nRet;
}


//
//   s i n g l e S e r v e r A c c e p t
//

// Waits in accept() on this listening socket and hands the new connection
// to Child.  Child's descriptor and cached addresses are overwritten.
// Returns ACCEPT_ERROR if accept() fails, or getAddr()'s error, in which
// case Child is closed again.
tRetCode tSocket::singleServerAccept( tSocket& Child )

{
    MTG_ASSERT( m_nFd > 0 );

    int nNewFd;
    // accept() takes an int* under Winsock but a socklen_t* under POSIX;
    // passing an int* there is a hard compile error in C++ on e.g. Linux.
#if defined(_WIN32)
    int nSize;
#else
    socklen_t nSize;
#endif
    struct sockaddr_in S;
    tRetCode nRet;

    nSize = sizeof(S);
    if( ( nNewFd = accept( m_nFd, (struct sockaddr*) &S, &nSize ) ) <= 0 )
        return ACCEPT_ERROR;

    Child.m_nFd = nNewFd;
    if( ( nRet = Child.getAddr() ) != OK ) {
        Child.close();
        return nRet;
    }

    return OK;
}


//
//   s i n g l e S e r v e r A c c e p t
//

// Accepts one connection and runs Server on it synchronously, in the
// calling thread.  The temporary child socket is also closed by its
// destructor when this function returns.
tRetCode tSocket::singleServerAccept( tTask& Server )

{
    tSocket Child;
    tRetCode nRet = singleServerAccept( Child );

    if( nRet == OK )
        Server.run( Child );
    return nRet;
}


//
//   m u l t i S e r v e r A c c e p t
//

// Convenience overload of multiServerAccept() that discards the process id.
tRetCode tSocket::multiServerAccept( tTask& Server )

{
    int nIgnoredPid;

    return multiServerAccept( Server, nIgnoredPid );
}


//
//   m u l t i S e r v e r A c c e p t
//

// Accepts one connection and serves it concurrently:
//   - on Windows, in a new thread started with _beginthread() (nPid = 0);
//   - elsewhere, in a child process created with fork() (on return, nPid
//     holds the child's pid).
// In both cases threadStub() runs Server on the connection and then frees
// the connection and the tThread record.  On POSIX the parent also reaps
// already-finished children without blocking.
// Returns ACCEPT_ERROR, getAddr()'s error, or FORK_ERROR on failure.
tRetCode tSocket::multiServerAccept( tTask& Server, int& nPid )

{
    MTG_ASSERT( m_nFd > 0 );

    int nNewFd;
    // accept() takes an int* under Winsock but a socklen_t* under POSIX;
    // passing an int* there is a hard compile error in C++ on e.g. Linux.
#if defined(_WIN32)
    int nSize;
#else
    socklen_t nSize;
#endif
    struct sockaddr_in S;
    tRetCode nRet;

    nSize = sizeof(S);
    if( ( nNewFd = accept( m_nFd, (struct sockaddr*) &S, &nSize ) ) <= 0 )
        return ACCEPT_ERROR;

    tSocket* pChild = new tSocket;

    pChild->m_nFd = nNewFd;
    if( ( nRet = pChild->getAddr() ) != OK ) {
        delete pChild;          // destructor closes nNewFd
        return nRet;
    }

    tThread* pThread = new tThread;

    pThread->m_nId = rand();
    pThread->m_pServer = &Server;
    pThread->m_pSocket = pChild;

#if defined(_WIN32)

    nPid = 0;
    if( _beginthread( (void (*)( void* )) threadStub, 0, pThread ) == -1 ) {
        delete pChild;
        delete pThread;
        return FORK_ERROR;
    }
    return OK;      // threadStub() now owns and frees pThread and pChild

#else

    if( ( nPid = fork() ) < 0 ) {
        delete pChild;
        delete pThread;
        return FORK_ERROR;
    }
    else
        if( nPid == 0 ) {
                // child: drop the inherited listening descriptor, serve the
                // connection, and terminate without returning to the caller
            close();
            threadStub( pThread );
            exit( 0 );
        }

        // parent: the child process works on its own copies; release the
        // parent's (deleting pChild closes the parent's connection fd)
    delete pChild;
    delete pThread;

        // reap children that have already finished, without blocking,
        // so they do not linger as zombies

    while( waitpid( -1, 0, WNOHANG ) > 0 );
    return OK;

#endif
}


//
//   c l o s e
//

// Closes the descriptor, if one is open, and discards the buffered read
// state.  Does nothing when the socket is not open (m_nFd <= 0).
void tSocket::close()

{
    if( m_nFd <= 0 )
        return;

#if defined(_WIN32)
    closesocket( m_nFd );
#else
    ::close( m_nFd );
#endif

    m_nFd = 0;
    m_nReceived = 0;
    m_nNextChar = (size_t) -1;
}


//
//   r e a d P e n d i n g
//

// Waits, without a time limit, until the socket becomes readable.
tRetCode tSocket::readPending()

{
    const time_t nForever = (time_t) -1;

    return readPending( nForever );
}


//
//   r e a d P e n d i n g
//

// Waits up to nTimeout seconds ((time_t) -1 = no limit) for the socket to
// become readable.  Returns OK if it is, TIMEOUT if the time ran out, and
// SELECT_ERROR for any other select() result.
tRetCode tSocket::readPending( time_t nTimeout )

{
    MTG_ASSERT( m_nFd > 0 );

    struct timeval Tout;
    struct timeval* pTout = 0;

    if( nTimeout != (time_t) -1 ) {
        Tout.tv_sec = nTimeout;
        Tout.tv_usec = 0;
        pTout = &Tout;
    }

    fd_set Readset;

    FD_ZERO( &Readset );
    FD_SET( (unsigned int) m_nFd, &Readset );

    int nReady = select( m_nFd + 1, &Readset, NULL, NULL, pTout );

    if( nReady == 1 )
        return OK;
    if( nReady == 0 )
        return TIMEOUT;
    return SELECT_ERROR;
}


//
//   r e a d P e n d i n g
//

// Waits, without a time limit, until Sock1 and/or Sock2 becomes readable.
// nWho receives 1 (Sock1), 2 (Sock2), 3 (both), or 0 (neither).
tRetCode tSocket::readPending( tSocket& Sock1, tSocket& Sock2, int& nWho )

{
    const time_t nForever = (time_t) -1;

    return readPending( Sock1, Sock2, nWho, nForever );
}


//
//   r e a d P e n d i n g
//

// Waits up to nTimeout seconds ((time_t) -1 = no limit) for Sock1 and/or
// Sock2 to become readable.  On return nWho tells which ones are:
//   0 = neither, 1 = Sock1 only, 2 = Sock2 only, 3 = both.
// Returns OK if at least one is readable, TIMEOUT if none is, and
// SELECT_ERROR if select() fails or reports an impossible count.
tRetCode tSocket::readPending( tSocket& Sock1, tSocket& Sock2, int& nWho,
    time_t nTimeout )

{
    MTG_ASSERT( Sock1.m_nFd > 0 && Sock2.m_nFd > 0 );

    struct timeval Tout, *pTout;

    if( nTimeout == (time_t) -1 ) {
        pTout = 0;
    }
    else {
        Tout.tv_sec = nTimeout;
        Tout.tv_usec = 0;
        pTout = &Tout;
    }

    int nFd = Sock1.m_nFd;
    if( nFd < Sock2.m_nFd )
        nFd = Sock2.m_nFd;

    fd_set Set;

    FD_ZERO( &Set );
    FD_SET( (unsigned) Sock1.m_nFd, &Set );
    FD_SET( (unsigned) Sock2.m_nFd, &Set );

    nWho = 0;

    // select()'s first argument is the highest descriptor *plus one*.
    // Passing the highest descriptor itself meant that descriptor was never
    // examined on POSIX systems (Winsock ignores the argument).
    int n = select( nFd + 1, &Set, 0, 0, pTout );

    if( n < 0 || n > 2 )
        return SELECT_ERROR;

    if( FD_ISSET( Sock1.m_nFd, &Set ) ) {
        if( FD_ISSET( Sock2.m_nFd, &Set ) )
            nWho = 3;
        else
            nWho = 1;
    }
    else {
        if( FD_ISSET( Sock2.m_nFd, &Set ) )
            nWho = 2;
    }

    return nWho ? OK : TIMEOUT;
}


//
//   r e a d
//

// Waits up to m_nTimeout seconds ((time_t) -1 = no limit) for data, then
// reads at most nSize bytes into sBuffer.  On return nSize holds the
// number of bytes actually read (0 on any failure).
// Returns OK, END_OF_FILE (zero bytes read), TIMEOUT, READ_ERROR, or
// SELECT_ERROR.
tRetCode tSocket::read( char* sBuffer, size_t& nSize )

{
    MTG_ASSERT( m_nFd > 0 );

    struct timeval Tout;
    struct timeval* pTout = 0;

    if( m_nTimeout != (time_t) -1 ) {
        Tout.tv_sec = m_nTimeout;
        Tout.tv_usec = 0;
        pTout = &Tout;
    }

    fd_set Readset;

    FD_ZERO( &Readset );
    FD_SET( (unsigned int) m_nFd, &Readset );

    int nReady = select( m_nFd + 1, &Readset, NULL, NULL, pTout );

    if( nReady != 1 ) {
        nSize = 0;
        return ( nReady == 0 ) ? TIMEOUT : SELECT_ERROR;
    }

#if defined(_WIN32)
    int n = ::recv( m_nFd, sBuffer, nSize, 0 );
#else
    int n = ::read( m_nFd, sBuffer, nSize );
#endif

    if( n < 0 ) {
        nSize = 0;
        return READ_ERROR;
    }

    nSize = (size_t) n;
    return ( nSize == 0 ) ? END_OF_FILE : OK;
}


//
//   r e a d
//

// Refills the internal receive buffer, allocating it via setSize()'s
// default argument on first use.  Received-but-unconsumed bytes (those
// from m_nNextChar onward) are first moved to the front of the buffer;
// everything else is discarded.  New data is appended after them.  The
// read position m_nNextChar is reset to "none" ((size_t) -1) in any case.
// NOTE(review): if the unconsumed bytes already fill the whole buffer, the
// inner read() is asked for 0 bytes, which reports END_OF_FILE once the
// socket is readable, even though the connection may still be open.
tRetCode tSocket::read()

{
    MTG_ASSERT( m_nFd > 0 );

    if( m_nSize == 0 )
        setSize();

    bool bKeepTail =
        m_nNextChar != (size_t) -1 && m_nNextChar < m_nReceived;

    if( bKeepTail ) {
        m_nReceived -= m_nNextChar;
        memmove( m_sBuffer, m_sBuffer + m_nNextChar, m_nReceived );
    }
    else {
        m_nReceived = 0;
    }
    m_nNextChar = (size_t) -1;

    size_t nFree = m_nSize - m_nReceived;
    tRetCode nRet = read( m_sBuffer + m_nReceived, nFree );

    if( nRet == OK )
        m_nReceived += nFree;
    return nRet;
}


//
//   w r i t e
//

// Writes a NUL-terminated string, excluding the terminator.
tRetCode tSocket::write( const char* sString )

{
    size_t nLen = strlen( sString );

    return write( sString, nLen );
}


//
//   w r i t e
//

// Writes all nSize bytes of sBuffer, looping over partial writes.
// Returns WRITE_ERROR as soon as a send()/write() call fails.
tRetCode tSocket::write( const char* sBuffer, size_t nSize )

{
    MTG_ASSERT( m_nFd > 0 );

    const char* pNext = sBuffer;

    for( size_t nLeft = nSize; nLeft > 0; ) {
#if defined(_WIN32)
        int n = ::send( m_nFd, pNext, nLeft, 0 );
        if( n == SOCKET_ERROR )
            n = -1;
#else
        int n = ::write( m_nFd, pNext, nLeft );
#endif

        if( n < 0 )
            return WRITE_ERROR;

        pNext += n;
        nLeft -= n;
    }

    return OK;
}


//
//   r e a d C h a r
//

// Returns the next byte from the connection in c.  The internal buffer is
// refilled via read() when it is exhausted or no read position is set.
// On failure returns read()'s error and leaves no read position set.
tRetCode tSocket::readChar( char& c )

{
    // The "no position" sentinel (size_t) -1 is the largest size_t, so it
    // satisfies this test as well as a position past the received data.
    if( m_nNextChar >= m_nReceived ) {
        tRetCode nRet = read();

        if( nRet != OK ) {
            m_nNextChar = (size_t) -1;
            return nRet;
        }
        m_nNextChar = 0;
    }

    MTG_ASSERT( m_nReceived > 0 );

    c = m_sBuffer[m_nNextChar];
    ++m_nNextChar;
    return OK;
}


//
//   u n r e a d C h a r
//

// Steps the read position back by one byte, so the next readChar()
// returns the most recently read byte again.  Has no effect at the start
// of the buffer or when no read position is set.
void tSocket::unreadChar()

{
    // (size_t) -1 marks "no read position".  It is also > 0, so without the
    // explicit check it would be decremented into a bogus position.
    if( m_nNextChar != (size_t) -1 && m_nNextChar > 0 )
        --m_nNextChar;
}


//
//   g e t M y H o s t N a m e
//

// Returns the cached local address in dotted-decimal form, or 0 if the
// conversion fails.
// NOTE(review): the text comes from inet_ntoa() (via addr2DecDot), whose
// result may live in a static buffer that the next conversion overwrites;
// copy it if it must be kept.
const char* tSocket::getMyHostName() const

{
    const char* sHost = 0;

    return ( addr2DecDot( m_nMyNetAddr, sHost ) == OK ) ? sHost : 0;
}


//
//   g e t P e e r H o s t N a m e
//

// Returns the cached peer address in dotted-decimal form, or 0 if the
// conversion fails.
// NOTE(review): the text comes from inet_ntoa() (via addr2DecDot), whose
// result may live in a static buffer that the next conversion overwrites;
// copy it if it must be kept.
const char* tSocket::getPeerHostName() const

{
    const char* sHost = 0;

    return ( addr2DecDot( m_nPeerNetAddr, sHost ) == OK ) ? sHost : 0;
}


//
//   g e t L o c a l A d d r
//

// Resolves this machine's own host name (gethostname) to an address via
// host2Addr().  On failure nNetAddr is set to 0 and NOT_FOUND is returned.
tRetCode tSocket::getLocalAddr( unsigned long& nNetAddr )

{
    char sName[256];

    // POSIX leaves it unspecified whether a truncated host name is
    // NUL-terminated, so reserve the last byte and terminate explicitly.
    if( gethostname( sName, sizeof(sName) - 1 ) != 0 ) {
        nNetAddr = 0;
        return NOT_FOUND;
    }
    sName[sizeof(sName) - 1] = '\0';

    if( host2Addr( sName, nNetAddr ) != OK ) {
        nNetAddr = 0;
        return NOT_FOUND;
    }
    return OK;
}


//
//   h o s t 2 A d d r
//

// Converts sHost, either a dotted-decimal IPv4 string or a host name, to
// an IPv4 address as stored in sin_addr (network byte order).
// Returns NOT_FOUND if sHost is null or cannot be resolved to IPv4.
// NOTE(review): "255.255.255.255" equals inet_addr()'s INADDR_NONE error
// value and is therefore handed on to gethostbyname().
tRetCode tSocket::host2Addr( const char* sHost, unsigned long& nNetAddr )

{
    if( sHost == 0 )
        return NOT_FOUND;

    if( ( nNetAddr = inet_addr( sHost ) ) != INADDR_NONE )
        return OK;

    struct hostent* H = gethostbyname( sHost );

    // Only dereference the result when it actually carries an IPv4
    // address; otherwise the first bytes of some other address (or a
    // null pointer) would be read.
    if( H == 0 || H->h_addrtype != AF_INET || H->h_addr_list[0] == 0 )
        return NOT_FOUND;

    nNetAddr = ( (struct in_addr*) H->h_addr_list[0] )->s_addr;
    return OK;
}


//
//   a d d r 2 D e c D o t
//

// Formats an address (as stored in sin_addr, network byte order) as a
// dotted-decimal string and always returns OK.  sHost points at
// inet_ntoa()'s result, which may be a static buffer that later calls
// overwrite.
tRetCode tSocket::addr2DecDot( unsigned long nNetAddr, const char *&sHost )

{
    struct in_addr Addr;

    Addr.s_addr = nNetAddr;
    const char* sText = inet_ntoa( Addr );

    sHost = sText;
    return OK;
}


//
//   s t a r t u p
//

// Process-wide socket initialization.  On Windows this starts Winsock,
// requesting version 1.1, and returns WINSOCK_ERROR if startup fails or a
// different version is negotiated.  On other platforms it does nothing.
tRetCode tSocket::startup()

{
#if defined(_WIN32)
    WSADATA Data;

    if( WSAStartup( MAKEWORD( 1, 1 ), &Data ) != 0 )
        return WINSOCK_ERROR;

    bool bVersionOk = LOBYTE( Data.wVersion ) == 1 &&
        HIBYTE( Data.wVersion ) == 1;

    if( !bVersionOk ) {
        WSACleanup();
        return WINSOCK_ERROR;
    }
#endif

    return OK;
}


//
//   c l e a n u p
//

// Counterpart of startup(): releases Winsock on Windows; no-op elsewhere.
void tSocket::cleanup()

{
#if defined(_WIN32)
    WSACleanup();
#else
    // Nothing to release on this platform.
#endif
}

MTG_END_NAMESPACE
