diff --git a/keystroke.cpp b/keystroke.cpp index 18566d2..687cfca 100644 --- a/keystroke.cpp +++ b/keystroke.cpp @@ -12,7 +12,7 @@ void KeyStroke::subtract( KeyStroke * const prefix ) } } -string KeyStroke::diff_from( KeyStroke const & existing, int length_limit ) +string KeyStroke::diff_from( KeyStroke const & existing ) { string ret; @@ -26,8 +26,7 @@ string KeyStroke::diff_from( KeyStroke const & existing, int length_limit ) my_it++; } - while ( (my_it != user_bytes.end()) - && ( (length_limit < 0) ? true : (int(ret.size()) < length_limit) ) ) { + while ( my_it != user_bytes.end() ) { ret += string( &( *my_it ), 1 ); my_it++; } diff --git a/keystroke.hpp b/keystroke.hpp index 2836827..26a4aee 100644 --- a/keystroke.hpp +++ b/keystroke.hpp @@ -18,7 +18,7 @@ public: /* interface for Network::Transport */ void subtract( KeyStroke * const prefix ); - string diff_from( KeyStroke const & existing, int length_limit ); + string diff_from( KeyStroke const & existing ); void apply_string( string diff ); bool operator==( KeyStroke const &x ) const { return user_bytes == x.user_bytes; } }; diff --git a/networkinstruction.cpp b/networkinstruction.cpp index 971a4bd..b1dff22 100644 --- a/networkinstruction.cpp +++ b/networkinstruction.cpp @@ -11,6 +11,12 @@ static string network_order_string( uint64_t host_order ) return string( (char *)&net_int, sizeof( net_int ) ); } +static string network_order_string( uint16_t host_order ) +{ + uint16_t net_int = htobe16( host_order ); + return string( (char *)&net_int, sizeof( net_int ) ); +} + string Instruction::tostring( void ) { string ret; @@ -19,20 +25,92 @@ string Instruction::tostring( void ) ret += network_order_string( new_num ); ret += network_order_string( ack_num ); ret += network_order_string( throwaway_num ); + ret += network_order_string( fragment_num ); + + assert( ret.size() == inst_header_len ); + ret += diff; return ret; } Instruction::Instruction( string &x ) - : old_num( -1 ), new_num( -1 ), ack_num( -1 ), throwaway_num( -1 ), diff() + : old_num( -1 ), new_num( -1 ), ack_num( -1 ), throwaway_num( -1 ), fragment_num( -1 ), diff() { - assert( x.size() >= 4 * sizeof( uint64_t ) ); + assert( x.size() >= inst_header_len ); uint64_t *data = (uint64_t *)x.data(); + uint16_t *data16 = (uint16_t *)x.data(); old_num = be64toh( data[ 0 ] ); new_num = be64toh( data[ 1 ] ); ack_num = be64toh( data[ 2 ] ); throwaway_num = be64toh( data[ 3 ] ); + fragment_num = be16toh( data16[ 16 ] ); - diff = string( x.begin() + 4 * sizeof( uint64_t ), x.end() ); + diff = string( x.begin() + inst_header_len, x.end() ); +} + +bool FragmentAssembly::same_template( Instruction &a, Instruction &b ) +{ + return ( a.old_num == b.old_num ) && ( a.new_num == b.new_num ) && ( a.ack_num == b.ack_num ) + && ( a.throwaway_num == b.throwaway_num ); +} + +bool FragmentAssembly::add_fragment( Instruction &inst ) +{ + /* decode fragment num */ + bool last_fragment = inst.fragment_num > 32767; + uint16_t real_fragment_num = inst.fragment_num; + if ( last_fragment ) { + real_fragment_num -= 32768; + } + + /* see if this is a totally new packet */ + if ( !same_template( inst, current_template ) ) { + fragments.clear(); + current_template = inst; + fragments.resize( real_fragment_num + 1 ); + fragments[ real_fragment_num ] = inst; + fragments_arrived = 1; + fragments_total = -1; + } else { /* not a new packet */ + /* see if we already have this fragment */ + if ( fragments[ real_fragment_num ].old_num != uint64_t(-1) ) { + assert( fragments[ real_fragment_num ] == inst ); + } else { + if ( fragments_total == -1 ) { + fragments.resize( real_fragment_num + 1 ); + } + fragments.at( real_fragment_num ) = inst; + fragments_arrived++; + } + } + + if ( last_fragment ) { + fragments_total = real_fragment_num + 1; + fragments.resize( fragments_total ); + } + + /* see if we're done */ + return ( fragments_arrived == fragments_total ); +} + +Instruction FragmentAssembly::get_assembly( void ) +{ + assert( fragments_arrived == fragments_total ); + + Instruction ret( current_template ); + ret.diff = ""; + + for ( int i = 0; i < fragments_total; i++ ) { + ret.diff += fragments[ i ].diff; + } + + return ret; +} + +bool Instruction::operator==( const Instruction &x ) +{ + return ( old_num == x.old_num ) && ( new_num == x.new_num ) + && ( ack_num == x.ack_num ) && ( throwaway_num == x.throwaway_num ) + && ( fragment_num == x.fragment_num ) && ( diff == x.diff ); } diff --git a/networktransport.cpp b/networktransport.cpp index 54ab17c..c3ae5dc 100644 --- a/networktransport.cpp +++ b/networktransport.cpp @@ -14,7 +14,8 @@ Transport::Transport( MyState &initial_state, RemoteState sent_states( 1, TimestampedState( timestamp(), 0, initial_state ) ), assumed_receiver_state( sent_states.begin() ), received_states( 1, TimestampedState( timestamp(), 0, initial_remote ) ), - last_receiver_state( initial_remote ) + last_receiver_state( initial_remote ), + fragments() { /* server */ } @@ -28,7 +29,8 @@ Transport::Transport( MyState &initial_state, RemoteState sent_states( 1, TimestampedState( timestamp(), 0, initial_state ) ), assumed_receiver_state( sent_states.begin() ), received_states( 1, TimestampedState( timestamp(), 0, initial_remote ) ), - last_receiver_state( initial_remote ) + last_receiver_state( initial_remote ), + fragments() { /* client */ } @@ -61,34 +63,24 @@ int Transport::tick( void ) template void Transport::send_to_receiver( void ) { - /* We don't want to assume that this sequence of diffs will - necessarily bring the receiver to the _actual_ current - state. That requires perfect round-trip stability of the diff - mechanism -- stronger than we need (and probably too fragile). - Instead, we produce the full diff, unlimited by MTU, between - the assumed receiver state and current state, and apply that - diff to produce a target receiver state. Then we produce a - sequence of diffs (this time limited by MTU) that bring us to - that state. */ - if ( !connection.get_attached() ) { return; } - MyState target_receiver_state( assumed_receiver_state->state ); - target_receiver_state.apply_string( current_state.diff_from( target_receiver_state, -1 ) ); - - if ( assumed_receiver_state->state == target_receiver_state ) { + if ( (assumed_receiver_state->num == sent_states.back().num) + && (sent_states.back().state == current_state) ) { /* send empty ack */ if ( (!connection.pending_timestamp()) && (timestamp() - sent_states.back().timestamp < int64_t( ACK_INTERVAL )) ) { return; } + Instruction inst( assumed_receiver_state->num, assumed_receiver_state->num, received_states.back().num, sent_states.front().num, + 32768, "" ); string s = inst.tostring(); connection.send( s, false ); @@ -97,68 +89,43 @@ void Transport::send_to_receiver( void ) return; } - int tries = 0; - while ( !(assumed_receiver_state->state == target_receiver_state) ) { - if ( tries++ > 1024 ) { - fprintf( stderr, "BUG: Convergence limit exceeded.\n" ); - exit( 1 ); + string diff = current_state.diff_from( assumed_receiver_state->state ); + + uint64_t new_num; + if ( current_state == sent_states.back().state ) { /* previously sent */ + new_num = sent_states.back().num; + } else { /* new state */ + new_num = sent_states.back().num + 1; + } + + bool done = false; + int MTU_tries = 0; + while ( !done ) { + MTU_tries++; + + if ( MTU_tries > 20 ) { + fprintf( stderr, "Error, could not send fragments after 20 tries (MTU = %d).\n", + connection.get_MTU() ); } - Instruction inst( assumed_receiver_state->num, - -1, - received_states.back().num, - sent_states.front().num, - current_state.diff_from( assumed_receiver_state->state, - connection.get_MTU() - HEADER_LEN ) ); - MyState new_state = assumed_receiver_state->state; - new_state.apply_string( inst.diff ); - - /* Find the right "new_num" for this instruction. */ - /* Has this state previously been sent? */ - /* should replace with hash table if this becomes demanding */ - typename list< TimestampedState >::iterator previously_sent = sent_states.begin(); - while ( ( previously_sent != sent_states.end() ) - && ( !(previously_sent->state == new_state) ) ) { - previously_sent++; - } - - /* Reusing state numbers is only for intermediate states */ - /* If this is the final diff in a sequence, make sure it does get the highest - state number (even if we've retread to previously-seen ground ) */ - /* This will force the client to update to this state */ - if ( new_state == target_receiver_state ) { - if ( new_state == sent_states.back().state ) { - previously_sent = sent_states.end(); - previously_sent--; - } else { - previously_sent = sent_states.end(); - } - } - - if ( previously_sent == sent_states.end() ) { /* not previously sent */ - inst.new_num = sent_states.back().num + 1; - sent_states.push_back( TimestampedState( timestamp(), inst.new_num, new_state ) ); - previously_sent = sent_states.end(); - previously_sent--; - } else { - inst.new_num = previously_sent->num; - previously_sent->timestamp = timestamp(); - } - - /* send instruction */ - string s = inst.tostring(); - try { - fprintf( stderr, "Sent instruction (timeout %d, queues %d/%d) from %d => %d (terminal %d): %s\r\n", connection.timeout(), (int)sent_states.size(), (int)received_states.size(), int(inst.old_num), int(inst.new_num), int(sent_states.back().num), inst.diff.c_str() ); - connection.send( s ); + send_in_fragments( diff, new_num ); + done = true; } catch ( MTUException m ) { - continue; + done = false; } - - /* successfully sent, probably */ - /* ("probably" because the FIRST size-exceeded datagram doesn't get an error) */ - assumed_receiver_state = previously_sent; } + + if ( current_state == sent_states.back().state ) { + sent_states.back().timestamp = timestamp(); + } else { + sent_states.push_back( TimestampedState( timestamp(), new_num, current_state ) ); + } + + /* successfully sent, probably */ + /* ("probably" because the FIRST size-exceeded datagram doesn't get an error) */ + assumed_receiver_state = sent_states.end(); + assumed_receiver_state--; } template @@ -199,60 +166,58 @@ template void Transport::recv( void ) { string s( connection.recv() ); - Instruction inst( s ); - - process_acknowledgment_through( inst.ack_num ); + Instruction frag( s ); - /* first, make sure we don't already have the new state */ - for ( typename list< TimestampedState >::iterator i = received_states.begin(); - i != received_states.end(); - i++ ) { - if ( inst.new_num == i->num ) { - i->timestamp = timestamp(); + if ( fragments.add_fragment( frag ) ) { /* complete packet */ + Instruction inst = fragments.get_assembly(); + + process_acknowledgment_through( inst.ack_num ); + + /* first, make sure we don't already have the new state */ + for ( typename list< TimestampedState >::iterator i = received_states.begin(); + i != received_states.end(); + i++ ) { + if ( inst.new_num == i->num ) { + i->timestamp = timestamp(); + return; + } + } + + /* now, make sure we do have the old state */ + bool found = 0; + typename list< TimestampedState >::iterator reference_state = received_states.begin(); + while ( reference_state != received_states.end() ) { + if ( inst.old_num == reference_state->num ) { + found = true; + break; + } + reference_state++; + } + + if ( !found ) { + // fprintf( stderr, "Ignoring out-of-order packet. Reference state %d has been discarded or hasn't yet been received.\n", int(inst.old_num) ); return; } - } - - /* now, make sure we do have the old state */ - bool found = 0; - typename list< TimestampedState >::iterator reference_state = received_states.begin(); - while ( reference_state != received_states.end() ) { - if ( inst.old_num == reference_state->num ) { - found = true; - break; - } - reference_state++; - } - - if ( !found ) { - // fprintf( stderr, "Ignoring out-of-order packet. Reference state %d has been discarded or hasn't yet been received.\n", int(inst.old_num) ); - /* There may be some benefit to storing these diffs until they can be used later, - but my guess is that the benefit is slim -- the diffs are likely to be small enough - that the entire diff will usually fit in one datagram, and by the time of retransmission - the target state will be different anyway. */ - return; - } - - /* apply diff to reference state */ - TimestampedState new_state = *reference_state; - new_state.timestamp = timestamp(); - new_state.num = inst.new_num; - new_state.state.apply_string( inst.diff ); - - if ( new_state.num > received_states.back().num ) { + + /* apply diff to reference state */ + TimestampedState new_state = *reference_state; + new_state.timestamp = timestamp(); + new_state.num = inst.new_num; + new_state.state.apply_string( inst.diff ); + process_throwaway_until( inst.throwaway_num ); - } - /* Insert new state in sorted place */ - for ( typename list< TimestampedState >::iterator i = received_states.begin(); - i != received_states.end(); - i++ ) { - if ( i->num > new_state.num ) { - received_states.insert( i, new_state ); - return; + /* Insert new state in sorted place */ + for ( typename list< TimestampedState >::iterator i = received_states.begin(); + i != received_states.end(); + i++ ) { + if ( i->num > new_state.num ) { + received_states.insert( i, new_state ); + return; + } } + received_states.push_back( new_state ); } - received_states.push_back( new_state ); } template @@ -275,12 +240,12 @@ void Transport::process_acknowledgment_through( uint64_t a template void Transport::process_throwaway_until( uint64_t throwaway_num ) { - typename list< TimestampedState >::iterator i = received_states.begin(); + typename list< TimestampedState >::iterator i = received_states.begin(); while ( i != received_states.end() ) { - typename list< TimestampedState >::iterator inext = i; + typename list< TimestampedState >::iterator inext = i; inext++; if ( i->num < throwaway_num ) { - sent_states.erase( i ); + received_states.erase( i ); } i = inext; } @@ -293,7 +258,7 @@ string Transport::get_remote_diff( void ) { /* find diff between last receiver state and current remote state, then rationalize states */ - string ret( received_states.back().state.diff_from( last_receiver_state, -1 ) ); + string ret( received_states.back().state.diff_from( last_receiver_state ) ); MyState * const oldest_receiver_state = &received_states.front().state; @@ -307,3 +272,33 @@ string Transport::get_remote_diff( void ) return ret; } + +template +void Transport::send_in_fragments( string diff, uint64_t new_num ) +{ + uint16_t fragment_num = 0; + + while ( !diff.empty() ) { + string this_fragment; + + assert( fragment_num <= 32767 ); + + if ( int( diff.size() + HEADER_LEN ) > connection.get_MTU() ) { + this_fragment = string( diff.begin(), diff.begin() + connection.get_MTU() - HEADER_LEN ); + diff = string( diff.begin() + connection.get_MTU() - HEADER_LEN, diff.end() ); + } else { + this_fragment = diff; + diff.clear(); + fragment_num += 32768; /* last fragment */ + } + + Instruction inst( assumed_receiver_state->num, + new_num, + received_states.back().num, + sent_states.front().num, + fragment_num++, + this_fragment ); + string s = inst.tostring(); + connection.send( s ); + } +} diff --git a/networktransport.hpp b/networktransport.hpp index 28a1244..2ab827c 100644 --- a/networktransport.hpp +++ b/networktransport.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include "network.hpp" @@ -13,22 +14,48 @@ using namespace std; namespace Network { class Instruction { + private: + static const size_t inst_header_len = 4 * sizeof( uint64_t ) + 1 * sizeof( uint16_t ); + public: uint64_t old_num, new_num; uint64_t ack_num; uint64_t throwaway_num; + uint16_t fragment_num; string diff; + Instruction() : old_num( -1 ), new_num( -1 ), ack_num( -1 ), throwaway_num( -1 ), fragment_num( -1 ), + diff( "" ) + {} + Instruction( uint64_t s_old_num, uint64_t s_new_num, - uint64_t s_ack_num, uint64_t s_throwaway_num, string s_diff ) + uint64_t s_ack_num, uint64_t s_throwaway_num, uint16_t s_fragment_num, + string s_diff ) : old_num( s_old_num ), new_num( s_new_num ), - ack_num( s_ack_num ), throwaway_num( s_throwaway_num ), diff( s_diff ) + ack_num( s_ack_num ), throwaway_num( s_throwaway_num ), fragment_num( s_fragment_num ), + diff( s_diff ) {} Instruction( string &x ); string tostring( void ); + + bool operator==( const Instruction &x ); + }; + + class FragmentAssembly + { + private: + vector fragments; + Instruction current_template; + int fragments_arrived, fragments_total; + + public: + FragmentAssembly() : fragments(), current_template(), fragments_arrived( 0 ), fragments_total( -1 ) {} + static bool same_template( Instruction &a, Instruction &b ); + bool add_fragment( Instruction &inst ); + Instruction get_assembly( void ); }; template @@ -56,6 +83,7 @@ namespace Network { void update_assumed_receiver_state( void ); void rationalize_states( void ); void send_to_receiver( void ); + void send_in_fragments( string diff, uint64_t new_num ); /* helper methods for recv() */ void process_acknowledgment_through( uint64_t ack_num ); @@ -78,6 +106,8 @@ namespace Network { list< TimestampedState > received_states; MyState last_receiver_state; /* the state we were in when user last queried state */ + FragmentAssembly fragments; + public: Transport( MyState &initial_state, RemoteState &initial_remote ); Transport( MyState &initial_state, RemoteState &initial_remote, diff --git a/templates.cpp b/templates.cpp index c2d8575..1c0c930 100644 --- a/templates.cpp +++ b/templates.cpp @@ -25,5 +25,6 @@ template class vector; template class map; template class vector; template class deque; +template class vector; template class Network::Transport;