diff --git a/keystroke.cpp b/keystroke.cpp index 8d5b129..18566d2 100644 --- a/keystroke.cpp +++ b/keystroke.cpp @@ -21,6 +21,7 @@ string KeyStroke::diff_from( KeyStroke const & existing, int length_limit ) for ( deque::const_iterator i = existing.user_bytes.begin(); i != existing.user_bytes.end(); i++ ) { + assert( my_it != user_bytes.end() ); assert( *i == *my_it ); my_it++; } diff --git a/network.cpp b/network.cpp index 3e20491..9cee3de 100644 --- a/network.cpp +++ b/network.cpp @@ -20,8 +20,9 @@ Packet::Packet( string coded_packet, Session *session ) { Message message = session->decrypt( coded_packet ); - direction = (message.nonce.val() & 8000000000000000) ? TO_CLIENT : TO_SERVER; + direction = (message.nonce.val() & 0x8000000000000000) ? TO_CLIENT : TO_SERVER; seq = message.nonce.val() & 0x7FFFFFFFFFFFFFFF; + payload = message.text; } @@ -128,6 +129,11 @@ void Connection::send( string &s ) string p = new_packet( s ).tostring( &session ); + /* XXX synthetic packet loss */ + if ( rand() < RAND_MAX / 2 ) { + return; + } + ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0, (sockaddr *)&remote_addr, sizeof( remote_addr ) ); @@ -163,6 +169,7 @@ string Connection::recv( void ) } Packet p( string( buf, received_len ), &session ); + dos_assert( p.direction == (server ? TO_SERVER : TO_CLIENT) ); /* prevent malicious playback to sender */ /* server auto-adjusts to client */ diff --git a/networkinstruction.cpp b/networkinstruction.cpp index 906488b..971a4bd 100644 --- a/networkinstruction.cpp +++ b/networkinstruction.cpp @@ -1,4 +1,5 @@ #include +#include #include "networktransport.hpp" @@ -17,7 +18,21 @@ string Instruction::tostring( void ) ret += network_order_string( old_num ); ret += network_order_string( new_num ); ret += network_order_string( ack_num ); + ret += network_order_string( throwaway_num ); ret += diff; return ret; } + +Instruction::Instruction( string &x ) + : old_num( -1 ), new_num( -1 ), ack_num( -1 ), throwaway_num( -1 ), diff() +{ + assert( x.size() >= 4 * sizeof( uint64_t ) ); + uint64_t *data = (uint64_t *)x.data(); + old_num = be64toh( data[ 0 ] ); + new_num = be64toh( data[ 1 ] ); + ack_num = be64toh( data[ 2 ] ); + throwaway_num = be64toh( data[ 3 ] ); + + diff = string( x.begin() + 4 * sizeof( uint64_t ), x.end() ); +} diff --git a/networktransport.cpp b/networktransport.cpp index 9e85d36..04e38dc 100644 --- a/networktransport.cpp +++ b/networktransport.cpp @@ -22,20 +22,20 @@ uint64_t Transport::timestamp( void ) } template -Transport::Transport( MyState &initial_state ) +Transport::Transport( MyState &initial_state, RemoteState &initial_remote ) : connection(), server( true ), current_state( initial_state ), sent_states( 1, TimestampedState( timestamp(), 0, initial_state ) ), assumed_receiver_state( sent_states.begin() ), timeout( INITIAL_TIMEOUT ), - highest_state_received( 0 ) + received_states( 1, TimestampedState( timestamp(), 0, initial_remote ) ) { /* server */ } template -Transport::Transport( MyState &initial_state, +Transport::Transport( MyState &initial_state, RemoteState &initial_remote, const char *key_str, const char *ip, int port ) : connection( key_str, ip, port ), server( false ), @@ -43,7 +43,7 @@ Transport::Transport( MyState &initial_state, sent_states( 1, TimestampedState( timestamp(), 0, initial_state ) ), assumed_receiver_state( sent_states.begin() ), timeout( INITIAL_TIMEOUT ), - highest_state_received( 0 ) + received_states( 1, TimestampedState( timestamp(), 0, initial_remote ) ) { /* client */ } @@ -54,10 +54,6 @@ void Transport::tick( void ) /* Update assumed receiver state */ update_assumed_receiver_state(); - fprintf( stderr, "Assumed receiver state: %d/%d\r\n", - int(assumed_receiver_state->num), - int(sent_states.back().num) ); - /* Cut out common prefix of all states */ rationalize_states(); @@ -88,14 +84,13 @@ void Transport::send_to_receiver( void ) /* send empty ack */ Instruction inst( assumed_receiver_state->num, assumed_receiver_state->num, - highest_state_received, + received_states.back().num, + sent_states.front().num, "" ); string s = inst.tostring(); connection.send( s ); assumed_receiver_state->timestamp = timestamp(); - fprintf( stderr, "Empty ack.\r\n" ); - return; } @@ -106,7 +101,10 @@ void Transport::send_to_receiver( void ) exit( 1 ); } - Instruction inst( assumed_receiver_state->num, -1, highest_state_received, + Instruction inst( assumed_receiver_state->num, + -1, + received_states.back().num, + sent_states.front().num, current_state.diff_from( assumed_receiver_state->state, connection.get_MTU() - HEADER_LEN ) ); MyState new_state = assumed_receiver_state->state; @@ -121,6 +119,17 @@ void Transport::send_to_receiver( void ) previously_sent++; } + /* Reusing state numbers is only for intermediate states */ + /* If this is the final diff in a sequence, make sure it does get the highest + state number (even if we've retread to previously-seen ground ) */ + /* This will force the client to update to this state */ + typename list< TimestampedState >::iterator last = sent_states.end(); + last--; + if ( (previously_sent != last) + && (new_state == target_receiver_state) ) { + previously_sent = sent_states.end(); + } + if ( previously_sent == sent_states.end() ) { /* not previously sent */ inst.new_num = sent_states.back().num + 1; sent_states.push_back( TimestampedState( timestamp(), inst.new_num, new_state ) ); @@ -136,11 +145,6 @@ void Transport::send_to_receiver( void ) string s = inst.tostring(); try { - fprintf( stderr, "Sending: " ); - for ( size_t i = 0; i < s.size(); i++ ) { - fprintf( stderr, "%c", s[ i ] ); - } - fprintf( stderr, "\r\n" ); connection.send( s ); } catch ( MTUException m ) { continue; @@ -177,11 +181,101 @@ void Transport::rationalize_states( void ) { MyState * const known_receiver_state = &sent_states.front().state; - for ( typename list< TimestampedState >::iterator i = sent_states.begin(); - i != sent_states.end(); + current_state.subtract( known_receiver_state ); + + for ( typename list< TimestampedState >::reverse_iterator i = sent_states.rbegin(); + i != sent_states.rend(); i++ ) { i->state.subtract( known_receiver_state ); } - - current_state.subtract( known_receiver_state ); +} + +template +void Transport::recv( void ) +{ + string s( connection.recv() ); + Instruction inst( s ); + + process_acknowledgment_through( inst.ack_num ); + // process_throwaway_until( inst.throwaway.num ); + + /* first, make sure we don't already have the new state */ + for ( typename list< TimestampedState >::iterator i = received_states.begin(); + i != received_states.end(); + i++ ) { + if ( inst.new_num == i->num ) { + i->timestamp = timestamp(); + return; + } + } + + /* now, make sure we do have the old state */ + bool found = 0; + typename list< TimestampedState >::iterator reference_state = received_states.begin(); + while ( reference_state != received_states.end() ) { + if ( inst.old_num == reference_state->num ) { + found = true; + break; + } + reference_state++; + } + + if ( !found ) { + fprintf( stderr, "Ignoring out-of-order packet. Reference state %d has been discarded.\n", int(inst.old_num) ); + return; + } + + /* apply diff to reference state */ + TimestampedState new_state = *reference_state; + new_state.timestamp = timestamp(); + new_state.num = inst.new_num; + new_state.state.apply_string( inst.diff ); + + if ( new_state.num > received_states.back().num ) { + process_throwaway_until( inst.throwaway_num ); + } + + /* Insert new state in sorted place */ + for ( typename list< TimestampedState >::iterator i = received_states.begin(); + i != received_states.end(); + i++ ) { + if ( i->num > new_state.num ) { + received_states.insert( i, new_state ); + return; + } + } + received_states.push_back( new_state ); +} + +template +void Transport::process_acknowledgment_through( uint64_t ack_num ) +{ + typename list< TimestampedState >::iterator i = sent_states.begin(); + while ( i != sent_states.end() ) { + typename list< TimestampedState >::iterator inext = i; + inext++; + if ( i->num < ack_num ) { + sent_states.erase( i ); + } + i = inext; + } + + assert( sent_states.size() > 0 ); + assert( sent_states.front().num == ack_num ); +} + +template +void Transport::process_throwaway_until( uint64_t throwaway_num ) +{ + typename list< TimestampedState >::iterator i = received_states.begin(); + while ( i != received_states.end() ) { + typename list< TimestampedState >::iterator inext = i; + inext++; + if ( i->num < throwaway_num ) { + sent_states.erase( i ); + } + i = inext; + } + + assert( received_states.size() > 0 ); } diff --git a/networktransport.hpp b/networktransport.hpp index 48580c5..bb7cbe3 100644 --- a/networktransport.hpp +++ b/networktransport.hpp @@ -16,13 +16,18 @@ namespace Network { public: uint64_t old_num, new_num; uint64_t ack_num; + uint64_t throwaway_num; string diff; - Instruction( uint64_t s_old_num, uint64_t s_new_num, uint64_t s_ack_num, string s_diff ) - : old_num( s_old_num ), new_num( s_new_num ), ack_num( s_ack_num ), diff( s_diff ) + Instruction( uint64_t s_old_num, uint64_t s_new_num, + uint64_t s_ack_num, uint64_t s_throwaway_num, string s_diff ) + : old_num( s_old_num ), new_num( s_new_num ), + ack_num( s_ack_num ), throwaway_num( s_throwaway_num ), diff( s_diff ) {} + Instruction( string &x ); + string tostring( void ); }; @@ -45,13 +50,17 @@ namespace Network { private: static const int INITIAL_TIMEOUT = 1000; /* ms, same as TCP */ static const int SEND_INTERVAL = 20; /* ms between frames */ - static const int HEADER_LEN = 40; + static const int HEADER_LEN = 80; /* helper methods for tick() */ void update_assumed_receiver_state( void ); void rationalize_states( void ); void send_to_receiver( void ); + /* helper methods for recv() */ + void process_acknowledgment_through( uint64_t ack_num ); + void process_throwaway_until( uint64_t throwaway_num ); + Connection connection; bool server; @@ -70,18 +79,25 @@ namespace Network { int timeout; /* simple receiver */ - uint64_t highest_state_received; + list< TimestampedState > received_states; public: - Transport( MyState &initial_state ); - Transport( MyState &initial_state, const char *key_str, const char *ip, int port ); + Transport( MyState &initial_state, RemoteState &initial_remote ); + Transport( MyState &initial_state, RemoteState &initial_remote, + const char *key_str, const char *ip, int port ); void tick( void ); + void recv( void ); + int port( void ) { return connection.port(); } string get_key( void ) { return connection.get_key(); } MyState &get_current_state( void ) { return current_state; } + RemoteState &get_remote_state( void ) { return received_states.back().state; } + uint64_t get_remote_state_num( void ) { return received_states.back().num; } + + int fd( void ) { return connection.fd(); } }; } diff --git a/ntester.cpp b/ntester.cpp index 45e383f..abd3a47 100644 --- a/ntester.cpp +++ b/ntester.cpp @@ -1,9 +1,25 @@ #include #include +#include #include "keystroke.hpp" #include "networktransport.hpp" +bool readable( int fd ) +{ + struct pollfd my_pollfd; + my_pollfd.fd = fd; + my_pollfd.events = POLLIN; + + int num = poll( &my_pollfd, 1, 0 ); + if ( num < 0 ) { + perror( "poll" ); + exit( 1 ); + } + + return my_pollfd.revents & POLLIN; +} + int main( int argc, char *argv[] ) { bool server = true; @@ -11,7 +27,7 @@ int main( int argc, char *argv[] ) char *ip; int port; - KeyStroke user; + KeyStroke me, remote; Network::Transport *n; @@ -24,9 +40,9 @@ int main( int argc, char *argv[] ) ip = argv[ 2 ]; port = atoi( argv[ 3 ] ); - n = new Network::Transport( user, key, ip, port ); + n = new Network::Transport( me, remote, key, ip, port ); } else { - n = new Network::Transport( user ); + n = new Network::Transport( me, remote ); } } catch ( CryptoException e ) { fprintf( stderr, "Fatal error: %s\n", e.text.c_str() ); @@ -36,18 +52,20 @@ int main( int argc, char *argv[] ) fprintf( stderr, "Port bound is %d, key is %s\n", n->port(), n->get_key().c_str() ); if ( server ) { - /* while ( true ) { try { - string s = n->recv(); - printf( "%s", s.c_str() ); - fflush( NULL ); + n->recv(); + n->tick(); + fprintf( stderr, "Num: %d. Contents: ", + (int)n->get_remote_state_num() ); + for ( size_t i = 0; i < n->get_remote_state().user_bytes.size(); i++ ) { + fprintf( stderr, "%c", n->get_remote_state().user_bytes[ i ] ); + } + fprintf( stderr, "\n" ); } catch ( CryptoException e ) { fprintf( stderr, "Cryptographic error: %s\n", e.text.c_str() ); } } - */ - sleep( 6000 ); } else { struct termios saved_termios; struct termios the_termios; @@ -72,10 +90,15 @@ int main( int argc, char *argv[] ) n->get_current_state().key_hit( x ); try { + if ( readable( n->fd() ) ) { + n->recv(); + } n->tick(); } catch ( Network::NetworkException e ) { fprintf( stderr, "%s: %s\r\n", e.function.c_str(), strerror( e.the_errno ) ); break; + } catch ( CryptoException e ) { + fprintf( stderr, "Cryptographic error: %s\n", e.text.c_str() ); } }