#include #include #include #include #include #include #include "dos_assert.hpp" #include "network.hpp" #include "crypto.hpp" using namespace std; using namespace Network; using namespace Crypto; const uint64_t DIRECTION_MASK = uint64_t(1) << 63; const uint64_t SEQUENCE_MASK = uint64_t(-1) ^ DIRECTION_MASK; /* Read in packet from coded string */ Packet::Packet( string coded_packet, Session *session ) : seq( -1 ), direction( TO_SERVER ), timestamp( -1 ), timestamp_reply( -1 ), payload() { Message message = session->decrypt( coded_packet ); direction = (message.nonce.val() & DIRECTION_MASK) ? TO_CLIENT : TO_SERVER; seq = message.nonce.val() & SEQUENCE_MASK; assert( message.text.size() >= 2 * sizeof( uint16_t ) ); uint16_t *data = (uint16_t *)message.text.data(); timestamp = be16toh( data[ 0 ] ); timestamp_reply = be16toh( data[ 1 ] ); payload = string( message.text.begin() + 2 * sizeof( uint16_t ), message.text.end() ); } /* Output coded string from packet */ string Packet::tostring( Session *session ) { uint64_t direction_seq = (uint64_t( direction == TO_CLIENT ) << 63) | (seq & SEQUENCE_MASK); uint16_t ts_net[ 2 ] = { htobe16( timestamp ), htobe16( timestamp_reply ) }; string timestamps = string( (char *)ts_net, 2 * sizeof( uint16_t ) ); return session->encrypt( Message( Nonce( direction_seq ), timestamps + payload ) ); } Packet Connection::new_packet( string &s_payload ) { uint16_t outgoing_timestamp_reply = -1; uint64_t now = timestamp(); if ( now - saved_timestamp_received_at < 1000 ) { /* we have a recent received timestamp */ /* send "corrected" timestamp advanced by how long we held it */ outgoing_timestamp_reply = saved_timestamp + (now - saved_timestamp_received_at); saved_timestamp = -1; saved_timestamp_received_at = 0; } Packet p( next_seq++, direction, timestamp16(), outgoing_timestamp_reply, s_payload ); return p; } void Connection::setup( void ) { /* create socket */ sock = socket( AF_INET, SOCK_DGRAM, 0 ); if ( sock < 0 ) { throw NetworkException( "socket", errno ); } /* Enable path MTU discovery */ char flag = IP_PMTUDISC_WANT; socklen_t optlen = sizeof( flag ); if ( setsockopt( sock, IPPROTO_IP, IP_MTU_DISCOVER, &flag, optlen ) < 0 ) { throw NetworkException( "setsockopt", errno ); } } Connection::Connection() /* server */ : sock( -1 ), remote_addr(), server( true ), attached( false ), MTU( SEND_MTU ), key(), session( key ), direction( TO_CLIENT ), next_seq( 0 ), saved_timestamp( -1 ), saved_timestamp_received_at( 0 ), expected_receiver_seq( 0 ), RTT_hit( false ), SRTT( 1000 ), RTTVAR( 500 ) { setup(); /* Bind to free local port. This usage does not seem to be endorsed by POSIX. */ struct sockaddr_in local_addr; local_addr.sin_family = AF_INET; local_addr.sin_port = htons( 0 ); local_addr.sin_addr.s_addr = INADDR_ANY; if ( bind( sock, (sockaddr *)&local_addr, sizeof( local_addr ) ) < 0 ) { throw NetworkException( "bind", errno ); } } Connection::Connection( const char *key_str, const char *ip, int port ) /* client */ : sock( -1 ), remote_addr(), server( false ), attached( false ), MTU( SEND_MTU ), key( key_str ), session( key ), direction( TO_SERVER ), next_seq( 0 ), saved_timestamp( -1 ), saved_timestamp_received_at( 0 ), expected_receiver_seq( 0 ), RTT_hit( false ), SRTT( 1000 ), RTTVAR( 500 ) { setup(); /* associate socket with remote host and port */ remote_addr.sin_family = AF_INET; remote_addr.sin_port = htons( port ); if ( !inet_aton( ip, &remote_addr.sin_addr ) ) { int saved_errno = errno; char buffer[ 2048 ]; snprintf( buffer, 2048, "Bad IP address (%s)", ip ); throw NetworkException( buffer, saved_errno ); } attached = true; } void Connection::send( string s ) { assert( attached ); Packet px = new_packet( s ); string p = px.tostring( &session ); ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0, (sockaddr *)&remote_addr, sizeof( remote_addr ) ); if ( (bytes_sent < 0) && (errno == EMSGSIZE) ) { update_MTU(); throw NetworkException( "Path MTU Discovery", EMSGSIZE ); } else if ( bytes_sent == static_cast( p.size() ) ) { return; } else { throw NetworkException( "sendto", errno ); } } string Connection::recv( void ) { struct sockaddr_in packet_remote_addr; char buf[ RECEIVE_MTU ]; socklen_t addrlen = sizeof( packet_remote_addr ); ssize_t received_len = recvfrom( sock, buf, RECEIVE_MTU, 0, (sockaddr *)&packet_remote_addr, &addrlen ); if ( received_len < 0 ) { throw NetworkException( "recvfrom", errno ); } if ( received_len > RECEIVE_MTU ) { char buffer[ 2048 ]; snprintf( buffer, 2048, "Received oversize datagram (size %d) and limit is %d\n", static_cast( received_len ), RECEIVE_MTU ); throw NetworkException( buffer, errno ); } Packet p( string( buf, received_len ), &session ); dos_assert( p.direction == (server ? TO_SERVER : TO_CLIENT) ); /* prevent malicious playback to sender */ if ( p.seq >= expected_receiver_seq ) { /* don't use out-of-order packets for timestamp or targeting */ expected_receiver_seq = p.seq + 1; /* this is security-sensitive because a replay attack could otherwise screw up the timestamp and targeting */ if ( p.timestamp != uint16_t(-1) ) { saved_timestamp = p.timestamp; saved_timestamp_received_at = timestamp(); } if ( p.timestamp_reply != uint16_t(-1) ) { uint16_t now = timestamp16(); double R = timestamp_diff( now, p.timestamp_reply ); if ( R < 5000 ) { /* ignore large values, e.g. server was Ctrl-Zed */ if ( !RTT_hit ) { /* first measurement */ SRTT = R; RTTVAR = R / 2; RTT_hit = true; } else { const double alpha = 1.0 / 8.0; const double beta = 1.0 / 4.0; RTTVAR = (1 - beta) * RTTVAR + ( beta * fabs( SRTT - R ) ); SRTT = (1 - alpha) * SRTT + ( alpha * R ); } } } /* server auto-adjusts to client */ if ( server ) { attached = true; if ( (remote_addr.sin_addr.s_addr != packet_remote_addr.sin_addr.s_addr) || (remote_addr.sin_port != packet_remote_addr.sin_port) ) { remote_addr = packet_remote_addr; fprintf( stderr, "Server now attached to client at %s:%d\n", inet_ntoa( remote_addr.sin_addr ), ntohs( remote_addr.sin_port ) ); } } } return p.payload; /* we do return out-of-order or duplicated packets to caller */ } int Connection::port( void ) const { struct sockaddr_in local_addr; socklen_t addrlen = sizeof( local_addr ); if ( getsockname( sock, (sockaddr *)&local_addr, &addrlen ) < 0 ) { throw NetworkException( "getsockname", errno ); } return ntohs( local_addr.sin_port ); } uint64_t Network::timestamp( void ) { struct timespec tp; if ( clock_gettime( CLOCK_MONOTONIC, &tp ) < 0 ) { throw NetworkException( "clock_gettime", errno ); } uint64_t millis = tp.tv_nsec / 1000000; millis += uint64_t( tp.tv_sec ) * 1000; return millis; } uint16_t Network::timestamp16( void ) { uint16_t ts = timestamp() % 65536; if ( ts == uint16_t(-1) ) { ts++; } return ts; } uint16_t Network::timestamp_diff( uint16_t tsnew, uint16_t tsold ) { int diff = tsnew - tsold; if ( diff < 0 ) { diff += 65536; } assert( diff >= 0 ); assert( diff <= 65535 ); return diff; } uint64_t Connection::timeout( void ) const { uint64_t RTO = lrint( ceil( SRTT + 4 * RTTVAR ) ); if ( RTO < MIN_RTO ) { RTO = MIN_RTO; } else if ( RTO > MAX_RTO ) { RTO = MAX_RTO; } return RTO; } class Socket { public: int fd; Socket( int domain, int type, int protocol ) : fd( socket( domain, type, protocol ) ) { if ( fd < 0 ) { throw NetworkException( "socket", errno ); } } ~Socket() { if ( close( fd ) < 0 ) { throw NetworkException( "close", errno ); } } }; void Connection::update_MTU( void ) { if ( !attached ) { return; } /* We don't want to use our main socket because we don't want to have to connect it */ Socket path_MTU_socket( AF_INET, SOCK_DGRAM, 0 ); /* Connect socket so we can retrieve path MTU */ if ( connect( path_MTU_socket.fd, (sockaddr *)&remote_addr, sizeof( remote_addr ) ) < 0 ) { throw NetworkException( "connect", errno ); } int PMTU; socklen_t optlen = sizeof( PMTU ); if ( getsockopt( path_MTU_socket.fd, IPPROTO_IP, IP_MTU, &PMTU, &optlen ) < 0 ) { throw NetworkException( "getsockopt", errno ); } if ( optlen != sizeof( PMTU ) ) { throw NetworkException( "Error getting path MTU", errno ); } MTU = max( PMTU, SEND_MTU ); }