Files
mosh/networktransport.cpp
T
2012-02-02 17:47:54 -05:00

150 lines
4.7 KiB
C++

#include <assert.h>
#include <iostream>
#include "networktransport.hpp"
#include "transportsender.cpp"
using namespace Network;
using namespace std;
/* Server-side constructor (see the body comment).
   - connection is built from desired_ip alone.
   - sender is handed a pointer to connection plus initial_state.
   - received_states starts with exactly one entry, constructed from the
     current timestamp(), the number 0 and initial_remote.
   - last_receiver_state is initialised from initial_remote.
   - fragments is value-initialised and verbose output is off.
   NOTE(review): sender receives &connection, so connection presumably has to
   be declared before sender in the class -- confirm against the header. */
template <class MyState, class RemoteState>
Transport<MyState, RemoteState>::Transport( MyState &initial_state, RemoteState &initial_remote,
const char *desired_ip )
: connection( desired_ip ),
sender( &connection, initial_state ),
received_states( 1, TimestampedState<RemoteState>( timestamp(), 0, initial_remote ) ),
last_receiver_state( initial_remote ),
fragments(),
verbose( false )
{
/* server */
}
/* Client-side constructor (see the body comment).
   - connection is built from key_str, ip and port.
   - sender is handed a pointer to connection plus initial_state.
   - received_states starts with exactly one entry, constructed from the
     current timestamp(), the number 0 and initial_remote.
   - last_receiver_state is initialised from initial_remote.
   - fragments is value-initialised and verbose output is off.
   Apart from how connection is constructed, the initialiser list is the same
   as the three-argument constructor's. */
template <class MyState, class RemoteState>
Transport<MyState, RemoteState>::Transport( MyState &initial_state, RemoteState &initial_remote,
const char *key_str, const char *ip, int port )
: connection( key_str, ip, port ),
sender( &connection, initial_state ),
received_states( 1, TimestampedState<RemoteState>( timestamp(), 0, initial_remote ) ),
last_receiver_state( initial_remote ),
fragments(),
verbose( false )
{
/* client */
}
/* Pull one datagram off the connection and hand it to the fragment assembler.
   Nothing further happens until a complete Instruction has been assembled.
   Once it has:
     1. a protocol-version mismatch throws NetworkException;
     2. the sender is told which of our states the peer has acknowledged;
     3. the instruction is dropped if its new state number is already held,
        or if the state it was diffed against (old_num) is not held;
     4. otherwise the new state is built from the reference state, old states
        are thrown away as requested, and the new state is stored in
        ascending-number order.
   Only a state that lands at the end of the list updates the ack number
   (and the data-ack flag when the diff was non-empty). */
template <class MyState, class RemoteState>
void Transport<MyState, RemoteState>::recv( void )
{
  typedef typename list< TimestampedState<RemoteState> >::iterator state_iterator;

  string payload( connection.recv() );
  Fragment frag( payload );

  if ( !fragments.add_fragment( frag ) ) {
    return; /* packet is not complete yet */
  }

  Instruction inst = fragments.get_assembly();

  if ( inst.protocol_version() != MOSH_PROTOCOL_VERSION ) {
    throw NetworkException( "mosh protocol version mismatch", 0 );
  }

  sender.process_acknowledgment_through( inst.ack_num() );

  /* Duplicate check: a state carrying this number is already stored. */
  for ( state_iterator known = received_states.begin();
        known != received_states.end();
        ++known ) {
    if ( inst.new_num() == known->num ) {
      return;
    }
  }

  /* Locate the state the diff is relative to. */
  state_iterator reference_state = received_states.begin();
  for ( ; reference_state != received_states.end(); ++reference_state ) {
    if ( inst.old_num() == reference_state->num ) {
      break;
    }
  }

  if ( reference_state == received_states.end() ) {
    /* Out-of-order packet: the reference state has been discarded or has not
       been received yet.  Ignoring it is security-sensitive and part of how
       idempotency is enforced. */
    return;
  }

  /* Build the new state from a copy of the reference state. */
  TimestampedState<RemoteState> fresh = *reference_state;
  fresh.timestamp = timestamp();
  fresh.num = inst.new_num();

  if ( !inst.diff().empty() ) {
    fresh.state.apply_string( inst.diff() );
  }

  /* The copy above is taken first: this call may erase the reference state. */
  process_throwaway_until( inst.throwaway_num() );

  /* Find the first stored state with a larger number, if there is one. */
  state_iterator slot = received_states.begin();
  while ( ( slot != received_states.end() ) && !( slot->num > fresh.num ) ) {
    ++slot;
  }

  if ( slot != received_states.end() ) {
    /* A newer state is already held: insert in front of it and stop here,
       leaving the ack number untouched. */
    received_states.insert( slot, fresh );

    if ( verbose ) {
      fprintf( stderr, "[%u] Received OUT-OF-ORDER state %d [ack %d]\n",
               static_cast<unsigned int>( timestamp() % 100000 ),
               static_cast<int>( fresh.num ),
               static_cast<int>( inst.ack_num() ) );
    }

    return;
  }

  if ( verbose ) {
    fprintf( stderr, "[%u] Received state %d [ack %d]\n",
             static_cast<unsigned int>( timestamp() % 100000 ),
             static_cast<int>( fresh.num ),
             static_cast<int>( inst.ack_num() ) );
  }

  received_states.push_back( fresh );
  sender.set_ack_num( received_states.back().num );

  if ( !inst.diff().empty() ) {
    sender.set_data_ack();
  }
}
/* The sender uses throwaway_num to name the earliest received state we still
   have to keep.  Every stored state whose number is below throwaway_num is
   erased; at least one state is asserted to remain afterwards. */
template <class MyState, class RemoteState>
void Transport<MyState, RemoteState>::process_throwaway_until( uint64_t throwaway_num )
{
  typedef typename list< TimestampedState<RemoteState> >::iterator state_iterator;

  state_iterator cursor = received_states.begin();
  while ( cursor != received_states.end() ) {
    if ( cursor->num < throwaway_num ) {
      /* list::erase returns the element after the erased one */
      cursor = received_states.erase( cursor );
    } else {
      ++cursor;
    }
  }

  assert( received_states.size() > 0 );
}
/* Return the diff that takes last_receiver_state to the newest received
   state, then rationalize the stored states against the oldest one and
   remember the (rationalized) newest state as last_receiver_state. */
template <class MyState, class RemoteState>
string Transport<MyState, RemoteState>::get_remote_diff( void )
{
  typedef typename list< TimestampedState<RemoteState> >::iterator state_iterator;

  /* The diff is computed before any state is modified below. */
  string diff( received_states.back().state.diff_from( last_receiver_state ) );

  /* Every state gets subtract() called with a pointer to the oldest state.
     Walking from the back of the list to the front means the oldest state
     itself is the last one to be touched. */
  const RemoteState *baseline = &received_states.front().state;
  state_iterator cursor = received_states.end();
  while ( cursor != received_states.begin() ) {
    --cursor;
    cursor->state.subtract( baseline );
  }

  last_receiver_state = received_states.back().state;

  return diff;
}