Separate modules by subdirectory

This commit is contained in:
Keith Winstein
2012-02-06 18:26:45 -05:00
parent 7e56af8fcd
commit 38c9e99882
58 changed files with 79 additions and 16 deletions
+6
View File
@@ -0,0 +1,6 @@
AM_CPPFLAGS = -I$(srcdir)/../util -I$(srcdir)/../crypto -I$(builddir)/../protobufs
AM_CXXFLAGS = --std=c++0x -pedantic -Werror -Wall -Wextra -Weffc++ -fno-default-inline -pipe
noinst_LIBRARIES = libmoshnetwork.a
libmoshnetwork_a_SOURCES = network.cc network.h networktransport.cc networktransport.h transportfragment.cc transportfragment.h transportsender.cc transportsender.h transportstate.h
+359
View File
@@ -0,0 +1,359 @@
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <assert.h>
#include <endian.h>
#include "dos_assert.h"
#include "network.h"
#include "crypto.h"
using namespace std;
using namespace Network;
using namespace Crypto;
const uint64_t DIRECTION_MASK = uint64_t(1) << 63;
const uint64_t SEQUENCE_MASK = uint64_t(-1) ^ DIRECTION_MASK;
/* Read in packet from coded string */
Packet::Packet( string coded_packet, Session *session )
: seq( -1 ),
direction( TO_SERVER ),
timestamp( -1 ),
timestamp_reply( -1 ),
payload()
{
Message message = session->decrypt( coded_packet );
direction = (message.nonce.val() & DIRECTION_MASK) ? TO_CLIENT : TO_SERVER;
seq = message.nonce.val() & SEQUENCE_MASK;
dos_assert( message.text.size() >= 2 * sizeof( uint16_t ) );
uint16_t *data = (uint16_t *)message.text.data();
timestamp = be16toh( data[ 0 ] );
timestamp_reply = be16toh( data[ 1 ] );
payload = string( message.text.begin() + 2 * sizeof( uint16_t ), message.text.end() );
}
/* Output coded string from packet */
string Packet::tostring( Session *session )
{
uint64_t direction_seq = (uint64_t( direction == TO_CLIENT ) << 63) | (seq & SEQUENCE_MASK);
uint16_t ts_net[ 2 ] = { htobe16( timestamp ), htobe16( timestamp_reply ) };
string timestamps = string( (char *)ts_net, 2 * sizeof( uint16_t ) );
return session->encrypt( Message( Nonce( direction_seq ), timestamps + payload ) );
}
Packet Connection::new_packet( string &s_payload )
{
uint16_t outgoing_timestamp_reply = -1;
uint64_t now = timestamp();
if ( now - saved_timestamp_received_at < 1000 ) { /* we have a recent received timestamp */
/* send "corrected" timestamp advanced by how long we held it */
outgoing_timestamp_reply = saved_timestamp + (now - saved_timestamp_received_at);
saved_timestamp = -1;
saved_timestamp_received_at = 0;
}
Packet p( next_seq++, direction, timestamp16(), outgoing_timestamp_reply, s_payload );
return p;
}
void Connection::setup( void )
{
/* create socket */
sock = socket( AF_INET, SOCK_DGRAM, 0 );
if ( sock < 0 ) {
throw NetworkException( "socket", errno );
}
/* Enable path MTU discovery */
char flag = IP_PMTUDISC_DO;
socklen_t optlen = sizeof( flag );
if ( setsockopt( sock, IPPROTO_IP, IP_MTU_DISCOVER, &flag, optlen ) < 0 ) {
throw NetworkException( "setsockopt", errno );
}
}
Connection::Connection( const char *desired_ip ) /* server */
: sock( -1 ),
remote_addr(),
server( true ),
attached( false ),
MTU( SEND_MTU ),
key(),
session( key ),
direction( TO_CLIENT ),
next_seq( 0 ),
saved_timestamp( -1 ),
saved_timestamp_received_at( 0 ),
expected_receiver_seq( 0 ),
RTT_hit( false ),
SRTT( 1000 ),
RTTVAR( 500 )
{
setup();
/* Attempt to bind free local port, with
address client used to connect to us.
This usage does not seem to be endorsed by POSIX. */
struct sockaddr_in local_addr;
local_addr.sin_family = AF_INET;
local_addr.sin_port = htons( 0 );
if ( desired_ip
&& inet_aton( desired_ip, &local_addr.sin_addr )
&& (bind( sock, (sockaddr *)&local_addr, sizeof( local_addr ) ) == 0) ) {
return;
}
if ( desired_ip ) {
fprintf( stderr, "Could not bind to desired local address %s.\n", desired_ip );
}
/* Could not bind to that IP (maybe we are behind NAT).
Try again with any IP. */
local_addr.sin_addr.s_addr = INADDR_ANY;
if ( bind( sock, (sockaddr *)&local_addr, sizeof( local_addr ) ) < 0 ) {
throw NetworkException( "bind", errno );
}
}
Connection::Connection( const char *key_str, const char *ip, int port ) /* client */
: sock( -1 ),
remote_addr(),
server( false ),
attached( false ),
MTU( SEND_MTU ),
key( key_str ),
session( key ),
direction( TO_SERVER ),
next_seq( 0 ),
saved_timestamp( -1 ),
saved_timestamp_received_at( 0 ),
expected_receiver_seq( 0 ),
RTT_hit( false ),
SRTT( 1000 ),
RTTVAR( 500 )
{
setup();
/* associate socket with remote host and port */
remote_addr.sin_family = AF_INET;
remote_addr.sin_port = htons( port );
if ( !inet_aton( ip, &remote_addr.sin_addr ) ) {
int saved_errno = errno;
char buffer[ 2048 ];
snprintf( buffer, 2048, "Bad IP address (%s)", ip );
throw NetworkException( buffer, saved_errno );
}
attached = true;
}
void Connection::send( string s )
{
assert( attached );
Packet px = new_packet( s );
string p = px.tostring( &session );
ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0,
(sockaddr *)&remote_addr, sizeof( remote_addr ) );
if ( (bytes_sent < 0) && (errno == EMSGSIZE) ) {
update_MTU();
throw NetworkException( "Path MTU Discovery", EMSGSIZE );
} else if ( bytes_sent == static_cast<int>( p.size() ) ) {
return;
} else {
throw NetworkException( "sendto", errno );
}
}
string Connection::recv( void )
{
struct sockaddr_in packet_remote_addr;
char buf[ RECEIVE_MTU ];
socklen_t addrlen = sizeof( packet_remote_addr );
ssize_t received_len = recvfrom( sock, buf, RECEIVE_MTU, 0, (sockaddr *)&packet_remote_addr, &addrlen );
if ( received_len < 0 ) {
throw NetworkException( "recvfrom", errno );
}
if ( received_len > RECEIVE_MTU ) {
char buffer[ 2048 ];
snprintf( buffer, 2048, "Received oversize datagram (size %d) and limit is %d\n",
static_cast<int>( received_len ), RECEIVE_MTU );
throw NetworkException( buffer, errno );
}
Packet p( string( buf, received_len ), &session );
dos_assert( p.direction == (server ? TO_SERVER : TO_CLIENT) ); /* prevent malicious playback to sender */
if ( p.seq >= expected_receiver_seq ) { /* don't use out-of-order packets for timestamp or targeting */
expected_receiver_seq = p.seq + 1; /* this is security-sensitive because a replay attack could otherwise
screw up the timestamp and targeting */
if ( p.timestamp != uint16_t(-1) ) {
saved_timestamp = p.timestamp;
saved_timestamp_received_at = timestamp();
}
if ( p.timestamp_reply != uint16_t(-1) ) {
uint16_t now = timestamp16();
double R = timestamp_diff( now, p.timestamp_reply );
if ( R < 5000 ) { /* ignore large values, e.g. server was Ctrl-Zed */
if ( !RTT_hit ) { /* first measurement */
SRTT = R;
RTTVAR = R / 2;
RTT_hit = true;
} else {
const double alpha = 1.0 / 8.0;
const double beta = 1.0 / 4.0;
RTTVAR = (1 - beta) * RTTVAR + ( beta * fabs( SRTT - R ) );
SRTT = (1 - alpha) * SRTT + ( alpha * R );
}
}
}
/* auto-adjust to remote host */
attached = true;
if ( (remote_addr.sin_addr.s_addr != packet_remote_addr.sin_addr.s_addr)
|| (remote_addr.sin_port != packet_remote_addr.sin_port) ) {
remote_addr = packet_remote_addr;
if ( server ) {
fprintf( stderr, "Server now attached to client at %s:%d\n",
inet_ntoa( remote_addr.sin_addr ),
ntohs( remote_addr.sin_port ) );
}
}
}
return p.payload; /* we do return out-of-order or duplicated packets to caller */
}
int Connection::port( void ) const
{
struct sockaddr_in local_addr;
socklen_t addrlen = sizeof( local_addr );
if ( getsockname( sock, (sockaddr *)&local_addr, &addrlen ) < 0 ) {
throw NetworkException( "getsockname", errno );
}
return ntohs( local_addr.sin_port );
}
uint64_t Network::timestamp( void )
{
struct timespec tp;
if ( clock_gettime( CLOCK_MONOTONIC, &tp ) < 0 ) {
throw NetworkException( "clock_gettime", errno );
}
uint64_t millis = tp.tv_nsec / 1000000;
millis += uint64_t( tp.tv_sec ) * 1000;
return millis;
}
uint16_t Network::timestamp16( void )
{
uint16_t ts = timestamp() % 65536;
if ( ts == uint16_t(-1) ) {
ts++;
}
return ts;
}
uint16_t Network::timestamp_diff( uint16_t tsnew, uint16_t tsold )
{
int diff = tsnew - tsold;
if ( diff < 0 ) {
diff += 65536;
}
assert( diff >= 0 );
assert( diff <= 65535 );
return diff;
}
uint64_t Connection::timeout( void ) const
{
uint64_t RTO = lrint( ceil( SRTT + 4 * RTTVAR ) );
if ( RTO < MIN_RTO ) {
RTO = MIN_RTO;
} else if ( RTO > MAX_RTO ) {
RTO = MAX_RTO;
}
return RTO;
}
class Socket {
public:
int fd;
Socket( int domain, int type, int protocol )
: fd( socket( domain, type, protocol ) )
{
if ( fd < 0 ) {
throw NetworkException( "socket", errno );
}
}
~Socket()
{
if ( close( fd ) < 0 ) {
throw NetworkException( "close", errno );
}
}
};
void Connection::update_MTU( void )
{
if ( !attached ) {
return;
}
/* We don't want to use our main socket because we don't want to have to connect it */
Socket path_MTU_socket( AF_INET, SOCK_DGRAM, 0 );
/* Connect socket so we can retrieve path MTU */
if ( connect( path_MTU_socket.fd, (sockaddr *)&remote_addr, sizeof( remote_addr ) ) < 0 ) {
throw NetworkException( "connect", errno );
}
int PMTU;
socklen_t optlen = sizeof( PMTU );
if ( getsockopt( path_MTU_socket.fd, IPPROTO_IP, IP_MTU, &PMTU, &optlen ) < 0 ) {
throw NetworkException( "getsockopt", errno );
}
if ( optlen != sizeof( PMTU ) ) {
throw NetworkException( "Error getting path MTU", errno );
}
MTU = min( PMTU, int(SEND_MTU) ); /* need cast to compile without optimization! XXX */
}
+105
View File
@@ -0,0 +1,105 @@
#ifndef NETWORK_HPP
#define NETWORK_HPP
#include <stdint.h>
#include <deque>
#include <sys/socket.h>
#include <netinet/in.h>
#include <string>
#include <math.h>
#include "crypto.h"
using namespace std;
using namespace Crypto;
namespace Network {
static const unsigned int MOSH_PROTOCOL_VERSION = 1;
uint64_t timestamp( void );
uint16_t timestamp16( void );
uint16_t timestamp_diff( uint16_t tsnew, uint16_t tsold );
class NetworkException {
public:
string function;
int the_errno;
NetworkException( string s_function, int s_errno ) : function( s_function ), the_errno( s_errno ) {}
};
enum Direction {
TO_SERVER = 0,
TO_CLIENT = 1
};
class Packet {
public:
uint64_t seq;
Direction direction;
uint16_t timestamp, timestamp_reply;
string payload;
Packet( uint64_t s_seq, Direction s_direction,
uint16_t s_timestamp, uint16_t s_timestamp_reply, string s_payload )
: seq( s_seq ), direction( s_direction ),
timestamp( s_timestamp ), timestamp_reply( s_timestamp_reply ), payload( s_payload )
{}
Packet( string coded_packet, Session *session );
string tostring( Session *session );
};
class Connection {
private:
static const int RECEIVE_MTU = 2048;
static const int SEND_MTU = 1400;
static const uint64_t MIN_RTO = 50; /* ms */
static const uint64_t MAX_RTO = 1000; /* ms */
int sock;
struct sockaddr_in remote_addr;
bool server;
bool attached;
int MTU;
Base64Key key;
Session session;
void setup( void );
Direction direction;
uint64_t next_seq;
uint16_t saved_timestamp;
uint64_t saved_timestamp_received_at;
uint64_t expected_receiver_seq;
bool RTT_hit;
double SRTT;
double RTTVAR;
Packet new_packet( string &s_payload );
void update_MTU( void );
public:
Connection( const char *desired_ip ); /* server */
Connection( const char *key_str, const char *ip, int port ); /* client */
void send( string s );
string recv( void );
int fd( void ) const { return sock; }
int get_MTU( void ) const { return MTU; }
int port( void ) const;
string get_key( void ) const { return key.printable_key(); }
bool get_attached( void ) const { return attached; }
uint64_t timeout( void ) const;
double get_SRTT( void ) const { return SRTT; }
};
}
#endif
+154
View File
@@ -0,0 +1,154 @@
#include <assert.h>
#include <iostream>
#include "networktransport.h"
#include "transportsender.cc"
using namespace Network;
using namespace std;
template <class MyState, class RemoteState>
Transport<MyState, RemoteState>::Transport( MyState &initial_state, RemoteState &initial_remote,
const char *desired_ip )
: connection( desired_ip ),
sender( &connection, initial_state ),
received_states( 1, TimestampedState<RemoteState>( timestamp(), 0, initial_remote ) ),
last_receiver_state( initial_remote ),
sent_state_late_acked( 0 ),
fragments(),
verbose( false )
{
/* server */
}
template <class MyState, class RemoteState>
Transport<MyState, RemoteState>::Transport( MyState &initial_state, RemoteState &initial_remote,
const char *key_str, const char *ip, int port )
: connection( key_str, ip, port ),
sender( &connection, initial_state ),
received_states( 1, TimestampedState<RemoteState>( timestamp(), 0, initial_remote ) ),
last_receiver_state( initial_remote ),
sent_state_late_acked( 0 ),
fragments(),
verbose( false )
{
/* client */
}
template <class MyState, class RemoteState>
void Transport<MyState, RemoteState>::recv( void )
{
string s( connection.recv() );
Fragment frag( s );
if ( fragments.add_fragment( frag ) ) { /* complete packet */
Instruction inst = fragments.get_assembly();
if ( inst.protocol_version() != MOSH_PROTOCOL_VERSION ) {
throw NetworkException( "mosh protocol version mismatch", 0 );
}
sender.process_acknowledgment_through( inst.ack_num() );
if ( inst.late_ack_num() > sent_state_late_acked ) {
sent_state_late_acked = inst.late_ack_num();
}
/* first, make sure we don't already have the new state */
for ( typename list< TimestampedState<RemoteState> >::iterator i = received_states.begin();
i != received_states.end();
i++ ) {
if ( inst.new_num() == i->num ) {
return;
}
}
/* now, make sure we do have the old state */
bool found = 0;
typename list< TimestampedState<RemoteState> >::iterator reference_state = received_states.begin();
while ( reference_state != received_states.end() ) {
if ( inst.old_num() == reference_state->num ) {
found = true;
break;
}
reference_state++;
}
if ( !found ) {
// fprintf( stderr, "Ignoring out-of-order packet. Reference state %d has been discarded or hasn't yet been received.\n", int(inst.old_num) );
return; /* this is security-sensitive and part of how we enforce idempotency */
}
/* apply diff to reference state */
TimestampedState<RemoteState> new_state = *reference_state;
new_state.timestamp = timestamp();
new_state.num = inst.new_num();
if ( !inst.diff().empty() ) {
new_state.state.apply_string( inst.diff() );
}
process_throwaway_until( inst.throwaway_num() );
/* Insert new state in sorted place */
for ( typename list< TimestampedState<RemoteState> >::iterator i = received_states.begin();
i != received_states.end();
i++ ) {
if ( i->num > new_state.num ) {
received_states.insert( i, new_state );
if ( verbose ) {
fprintf( stderr, "[%u] Received OUT-OF-ORDER state %d [ack %d]\n",
(unsigned int)(timestamp() % 100000), (int)new_state.num, (int)inst.ack_num() );
}
return;
}
}
if ( verbose ) {
fprintf( stderr, "[%u] Received state %d [ack %d]\n",
(unsigned int)(timestamp() % 100000), (int)new_state.num, (int)inst.ack_num() );
}
received_states.push_back( new_state );
sender.set_ack_num( received_states.back().num );
if ( !inst.diff().empty() ) {
sender.set_data_ack();
}
}
}
/* The sender uses throwaway_num to tell us the earliest received state that we need to keep around */
template <class MyState, class RemoteState>
void Transport<MyState, RemoteState>::process_throwaway_until( uint64_t throwaway_num )
{
typename list< TimestampedState<RemoteState> >::iterator i = received_states.begin();
while ( i != received_states.end() ) {
typename list< TimestampedState<RemoteState> >::iterator inext = i;
inext++;
if ( i->num < throwaway_num ) {
received_states.erase( i );
}
i = inext;
}
assert( received_states.size() > 0 );
}
template <class MyState, class RemoteState>
string Transport<MyState, RemoteState>::get_remote_diff( void )
{
/* find diff between last receiver state and current remote state, then rationalize states */
string ret( received_states.back().state.diff_from( last_receiver_state ) );
const RemoteState *oldest_receiver_state = &received_states.front().state;
for ( typename list< TimestampedState<RemoteState> >::reverse_iterator i = received_states.rbegin();
i != received_states.rend();
i++ ) {
i->state.subtract( oldest_receiver_state );
}
last_receiver_state = received_states.back().state;
return ret;
}
+89
View File
@@ -0,0 +1,89 @@
#ifndef NETWORK_TRANSPORT_HPP
#define NETWORK_TRANSPORT_HPP
#include <string>
#include <signal.h>
#include <time.h>
#include <list>
#include <vector>
#include "network.h"
#include "transportsender.h"
#include "transportfragment.h"
using namespace std;
namespace Network {
template <class MyState, class RemoteState>
class Transport
{
private:
/* the underlying, encrypted network connection */
Connection connection;
/* sender side */
TransportSender<MyState> sender;
/* helper methods for recv() */
void process_throwaway_until( uint64_t throwaway_num );
/* simple receiver */
list< TimestampedState<RemoteState> > received_states;
RemoteState last_receiver_state; /* the state we were in when user last queried state */
uint64_t sent_state_late_acked;
FragmentAssembly fragments;
bool verbose;
public:
Transport( MyState &initial_state, RemoteState &initial_remote, const char *desired_ip );
Transport( MyState &initial_state, RemoteState &initial_remote,
const char *key_str, const char *ip, int port );
/* Send data or an ack if necessary. */
void tick( void ) { sender.tick(); }
/* Returns the number of ms to wait until next possible event. */
int wait_time( void ) { return sender.wait_time(); }
/* Blocks waiting for a packet. */
void recv( void );
/* Find diff between last receiver state and current remote state, then rationalize states. */
string get_remote_diff( void );
/* Shut down other side of connection. */
/* Illegal to change current_state after this. */
void start_shutdown( void ) { sender.start_shutdown(); }
bool shutdown_in_progress( void ) const { return sender.get_shutdown_in_progress(); }
bool shutdown_acknowledged( void ) const { return sender.get_shutdown_acknowledged(); }
bool shutdown_ack_timed_out( void ) const { return sender.shutdown_ack_timed_out(); }
bool attached( void ) const { return connection.get_attached(); }
/* Other side has requested shutdown and we have sent one ACK */
bool counterparty_shutdown_ack_sent( void ) const { return sender.get_counterparty_shutdown_acknowledged(); }
int port( void ) const { return connection.port(); }
string get_key( void ) const { return connection.get_key(); }
MyState &get_current_state( void ) { return sender.get_current_state(); }
void set_current_state( const MyState &x ) { sender.set_current_state( x ); }
uint64_t get_remote_state_num( void ) const { return received_states.back().num; }
const TimestampedState<RemoteState> & get_latest_remote_state( void ) const { return received_states.back(); }
int fd( void ) const { return connection.fd(); }
void set_verbose( void ) { sender.set_verbose(); verbose = true; }
void set_send_delay( int new_delay ) { sender.set_send_delay( new_delay ); }
uint64_t get_sent_state_acked( void ) const { return sender.get_sent_state_acked(); }
uint64_t get_sent_state_last( void ) const { return sender.get_sent_state_last(); }
uint64_t get_sent_state_late_acked( void ) const { return sent_state_late_acked; }
unsigned int send_interval( void ) const { return sender.send_interval(); }
};
}
#endif
+162
View File
@@ -0,0 +1,162 @@
#include <endian.h>
#include <assert.h>
#include "transportfragment.h"
#include "transportinstruction.pb.h"
using namespace Network;
using namespace TransportBuffers;
static string network_order_string( uint16_t host_order )
{
uint16_t net_int = htobe16( host_order );
return string( (char *)&net_int, sizeof( net_int ) );
}
static string network_order_string( uint64_t host_order )
{
uint64_t net_int = htobe64( host_order );
return string( (char *)&net_int, sizeof( net_int ) );
}
string Fragment::tostring( void )
{
assert( initialized );
string ret;
ret += network_order_string( id );
assert( !( fragment_num & 0x8000 ) ); /* effective limit on size of a terminal screen change or buffered user input */
uint16_t combined_fragment_num = ( final << 15 ) | fragment_num;
ret += network_order_string( combined_fragment_num );
assert( ret.size() == frag_header_len );
ret += contents;
return ret;
}
Fragment::Fragment( string &x )
: id( -1 ), fragment_num( -1 ), final( false ), initialized( true ),
contents( x.begin() + frag_header_len, x.end() )
{
assert( x.size() >= frag_header_len );
uint64_t *data64 = (uint64_t *)x.data();
uint16_t *data16 = (uint16_t *)x.data();
id = be64toh( data64[ 0 ] );
fragment_num = be16toh( data16[ 4 ] );
final = ( fragment_num & 0x8000 ) >> 15;
fragment_num &= 0x7FFF;
}
bool FragmentAssembly::add_fragment( Fragment &frag )
{
/* see if this is a totally new packet */
if ( current_id != frag.id ) {
fragments.clear();
fragments.resize( frag.fragment_num + 1 );
fragments.at( frag.fragment_num ) = frag;
fragments_arrived = 1;
fragments_total = -1; /* unknown */
current_id = frag.id;
} else { /* not a new packet */
/* see if we already have this fragment */
if ( (fragments.size() > frag.fragment_num)
&& (fragments.at( frag.fragment_num ).initialized) ) {
/* make sure new version is same as what we already have */
assert( fragments.at( frag.fragment_num ) == frag );
} else {
if ( (int)fragments.size() < frag.fragment_num + 1 ) {
fragments.resize( frag.fragment_num + 1 );
}
fragments.at( frag.fragment_num ) = frag;
fragments_arrived++;
}
}
if ( frag.final ) {
fragments_total = frag.fragment_num + 1;
assert( (int)fragments.size() <= fragments_total );
fragments.resize( fragments_total );
}
if ( fragments_total != -1 ) {
assert( fragments_arrived <= fragments_total );
}
/* see if we're done */
return ( fragments_arrived == fragments_total );
}
Instruction FragmentAssembly::get_assembly( void )
{
assert( fragments_arrived == fragments_total );
string encoded;
for ( int i = 0; i < fragments_total; i++ ) {
assert( fragments.at( i ).initialized );
encoded += fragments.at( i ).contents;
}
Instruction ret;
assert( ret.ParseFromString( encoded ) );
fragments.clear();
fragments_arrived = 0;
fragments_total = -1;
return ret;
}
bool Fragment::operator==( const Fragment &x )
{
return ( id == x.id ) && ( fragment_num == x.fragment_num ) && ( final == x.final )
&& ( initialized == x.initialized ) && ( contents == x.contents );
}
vector<Fragment> Fragmenter::make_fragments( Instruction &inst, int MTU )
{
if ( (inst.old_num() != last_instruction.old_num())
|| (inst.new_num() != last_instruction.new_num())
|| (inst.ack_num() != last_instruction.ack_num())
|| (inst.throwaway_num() != last_instruction.throwaway_num())
|| (inst.late_ack_num() != last_instruction.late_ack_num())
|| (inst.protocol_version() != last_instruction.protocol_version())
|| (last_MTU != MTU) ) {
next_instruction_id++;
}
if ( (inst.old_num() == last_instruction.old_num())
&& (inst.new_num() == last_instruction.new_num()) ) {
assert( inst.diff() == last_instruction.diff() );
}
last_instruction = inst;
last_MTU = MTU;
string payload = inst.SerializeAsString();
uint16_t fragment_num = 0;
vector<Fragment> ret;
while ( !payload.empty() ) {
string this_fragment;
bool final = false;
if ( int( payload.size() + HEADER_LEN ) > MTU ) {
this_fragment = string( payload.begin(), payload.begin() + MTU - HEADER_LEN );
payload = string( payload.begin() + MTU - HEADER_LEN, payload.end() );
} else {
this_fragment = payload;
payload.clear();
final = true;
}
ret.push_back( Fragment( next_instruction_id, fragment_num++, final, this_fragment ) );
}
return ret;
}
+78
View File
@@ -0,0 +1,78 @@
#ifndef TRANSPORT_FRAGMENT_HPP
#define TRANSPORT_FRAGMENT_HPP
#include <stdint.h>
#include <vector>
#include <string>
#include "transportinstruction.pb.h"
using namespace std;
using namespace TransportBuffers;
namespace Network {
static const int HEADER_LEN = 66;
class Fragment
{
private:
static const size_t frag_header_len = sizeof( uint64_t ) + sizeof( uint16_t );
public:
uint64_t id;
uint16_t fragment_num;
bool final;
bool initialized;
string contents;
Fragment()
: id( -1 ), fragment_num( -1 ), final( false ), initialized( false ), contents()
{}
Fragment( uint64_t s_id, uint16_t s_fragment_num, bool s_final, string s_contents )
: id( s_id ), fragment_num( s_fragment_num ), final( s_final ), initialized( true ),
contents( s_contents )
{}
Fragment( string &x );
string tostring( void );
bool operator==( const Fragment &x );
};
class FragmentAssembly
{
private:
vector<Fragment> fragments;
uint64_t current_id;
int fragments_arrived, fragments_total;
public:
FragmentAssembly() : fragments(), current_id( -1 ), fragments_arrived( 0 ), fragments_total( -1 ) {}
bool add_fragment( Fragment &inst );
Instruction get_assembly( void );
};
class Fragmenter
{
private:
uint64_t next_instruction_id;
Instruction last_instruction;
int last_MTU;
public:
Fragmenter() : next_instruction_id( 0 ), last_instruction(), last_MTU( -1 )
{
last_instruction.set_old_num( -1 );
last_instruction.set_new_num( -1 );
}
vector<Fragment> make_fragments( Instruction &inst, int MTU );
uint64_t last_ack_sent( void ) const { return last_instruction.ack_num(); }
};
}
#endif
+301
View File
@@ -0,0 +1,301 @@
#include <algorithm>
#include <list>
#include "transportsender.h"
#include "transportfragment.h"
using namespace Network;
template <class MyState>
TransportSender<MyState>::TransportSender( Connection *s_connection, MyState &initial_state )
: connection( s_connection ),
current_state( initial_state ),
sent_states( 1, TimestampedState<MyState>( timestamp(), 0, initial_state ) ),
assumed_receiver_state( sent_states.begin() ),
fragmenter(),
next_ack_time( timestamp() ),
next_send_time( timestamp() ),
verbose( false ),
shutdown_in_progress( false ),
shutdown_tries( 0 ),
ack_num( 0 ),
pending_data_ack( false ),
ack_timestamp( 0 ),
ack_history(),
SEND_MINDELAY( 15 )
{
}
/* Try to send roughly two frames per RTT, bounded by limits on frame rate */
template <class MyState>
unsigned int TransportSender<MyState>::send_interval( void ) const
{
int SEND_INTERVAL = lrint( ceil( connection->get_SRTT() / 2.0 ) );
if ( SEND_INTERVAL < SEND_INTERVAL_MIN ) {
SEND_INTERVAL = SEND_INTERVAL_MIN;
} else if ( SEND_INTERVAL > SEND_INTERVAL_MAX ) {
SEND_INTERVAL = SEND_INTERVAL_MAX;
}
return SEND_INTERVAL;
}
/* How many ms can the caller wait before we will have an event (empty ack or next frame)? */
template <class MyState>
int TransportSender<MyState>::wait_time( void )
{
if ( pending_data_ack && (next_ack_time > timestamp() + ACK_DELAY) ) {
next_ack_time = timestamp() + ACK_DELAY;
}
if ( !(current_state == sent_states.back().state) ) { /* pending data to send */
if ( next_send_time > timestamp() + SEND_MINDELAY ) {
next_send_time = timestamp() + SEND_MINDELAY;
}
if ( next_send_time < sent_states.back().timestamp + send_interval() ) {
next_send_time = sent_states.back().timestamp + send_interval();
}
}
/* speed up shutdown sequence */
if ( shutdown_in_progress || (ack_num == uint64_t(-1)) ) {
next_ack_time = sent_states.back().timestamp + send_interval();
}
uint64_t next_wakeup = next_ack_time;
if ( next_send_time < next_wakeup ) {
next_wakeup = next_send_time;
}
if ( !connection->get_attached() ) {
return -1;
}
if ( next_wakeup > timestamp() ) {
return next_wakeup - timestamp();
} else {
return 0;
}
}
/* Send data or an empty ack if necessary */
template <class MyState>
void TransportSender<MyState>::tick( void )
{
wait_time();
if ( !connection->get_attached() ) {
return;
}
if ( (timestamp() < next_ack_time)
&& (timestamp() < next_send_time) ) {
return;
}
/* Determine if a new diff or empty ack needs to be sent */
/* Update assumed receiver state */
update_assumed_receiver_state();
/* Cut out common prefix of all states */
rationalize_states();
string diff = current_state.diff_from( assumed_receiver_state->state );
if ( diff.empty() && (timestamp() >= next_ack_time) ) {
send_empty_ack();
return;
}
if ( !diff.empty() && ( (timestamp() >= next_send_time)
|| (timestamp() >= next_ack_time) ) ) {
/* Send diffs or ack */
send_to_receiver( diff );
}
}
template <class MyState>
void TransportSender<MyState>::send_empty_ack( void )
{
assert ( timestamp() >= next_ack_time );
uint64_t new_num = sent_states.back().num + 1;
/* special case for shutdown sequence */
if ( shutdown_in_progress ) {
new_num = uint64_t( -1 );
}
// sent_states.push_back( TimestampedState<MyState>( sent_states.back().timestamp, new_num, current_state ) );
add_sent_state( sent_states.back().timestamp, new_num, current_state );
send_in_fragments( "", new_num );
next_ack_time = timestamp() + ACK_INTERVAL;
}
template <class MyState>
void TransportSender<MyState>::add_sent_state( uint64_t the_timestamp, uint64_t num, MyState &state )
{
sent_states.push_back( TimestampedState<MyState>( the_timestamp, num, state ) );
if ( sent_states.size() > 32 ) { /* limit on state queue */
auto last = sent_states.end();
for ( int i = 0; i < 16; i++ ) { last--; }
sent_states.erase( last ); /* erase state from middle of queue */
}
}
template <class MyState>
void TransportSender<MyState>::send_to_receiver( string diff )
{
uint64_t new_num;
if ( current_state == sent_states.back().state ) { /* previously sent */
new_num = sent_states.back().num;
} else { /* new state */
new_num = sent_states.back().num + 1;
}
/* special case for shutdown sequence */
if ( shutdown_in_progress ) {
new_num = uint64_t( -1 );
}
if ( new_num == sent_states.back().num ) {
sent_states.back().timestamp = timestamp();
} else {
add_sent_state( timestamp(), new_num, current_state );
}
send_in_fragments( diff, new_num ); // Can throw NetworkException
/* successfully sent, probably */
/* ("probably" because the FIRST size-exceeded datagram doesn't get an error) */
assumed_receiver_state = sent_states.end();
assumed_receiver_state--;
next_ack_time = timestamp() + ACK_INTERVAL;
next_send_time = uint64_t(-1);
}
template <class MyState>
void TransportSender<MyState>::update_assumed_receiver_state( void )
{
uint64_t now = timestamp();
/* start from what is known and give benefit of the doubt to unacknowledged states
transmitted recently enough ago */
assumed_receiver_state = sent_states.begin();
typename list< TimestampedState<MyState> >::iterator i = sent_states.begin();
i++;
while ( i != sent_states.end() ) {
assert( now >= i->timestamp );
if ( uint64_t(now - i->timestamp) < connection->timeout() + ACK_DELAY ) {
assumed_receiver_state = i;
} else {
return;
}
i++;
}
}
template <class MyState>
void TransportSender<MyState>::rationalize_states( void )
{
const MyState * known_receiver_state = &sent_states.front().state;
current_state.subtract( known_receiver_state );
for ( typename list< TimestampedState<MyState> >::reverse_iterator i = sent_states.rbegin();
i != sent_states.rend();
i++ ) {
i->state.subtract( known_receiver_state );
}
}
template <class MyState>
void TransportSender<MyState>::send_in_fragments( string diff, uint64_t new_num )
{
Instruction inst;
uint64_t now = timestamp();
inst.set_protocol_version( MOSH_PROTOCOL_VERSION );
inst.set_old_num( assumed_receiver_state->num );
inst.set_new_num( new_num );
inst.set_ack_num( ack_num );
inst.set_throwaway_num( sent_states.front().num );
inst.set_late_ack_num( get_late_ack( now ) );
inst.set_diff( diff );
if ( new_num == uint64_t(-1) ) {
shutdown_tries++;
}
vector<Fragment> fragments = fragmenter.make_fragments( inst, connection->get_MTU() );
for ( auto i = fragments.begin(); i != fragments.end(); i++ ) {
connection->send( i->tostring() );
if ( verbose ) {
fprintf( stderr, "[%u] Sent [%d=>%d] id %d, frag %d ack=%d, late_ack=%d, throwaway=%d, len=%d, frame rate=%.2f, timeout=%d, srtt=%.1f age=%llu\n",
(unsigned int)(timestamp() % 100000), (int)inst.old_num(), (int)inst.new_num(), (int)i->id, (int)i->fragment_num,
(int)inst.ack_num(), (int)inst.late_ack_num(), (int)inst.throwaway_num(), (int)i->contents.size(),
1000.0 / (double)send_interval(),
(int)connection->timeout(), connection->get_SRTT(),
(long long)(now - ack_timestamp) );
}
}
pending_data_ack = false;
}
template <class MyState>
void TransportSender<MyState>::process_acknowledgment_through( uint64_t ack_num )
{
/* Ignore ack if we have culled the state it's acknowledging */
if ( sent_states.end() != find_if( sent_states.begin(), sent_states.end(),
[&]( const TimestampedState<MyState> &x ) { return x.num == ack_num; } ) ) {
sent_states.remove_if( [&]( const TimestampedState<MyState> &x ) { return x.num < ack_num; } );
}
assert( !sent_states.empty() );
}
/* give up on getting acknowledgement for shutdown */
template <class MyState>
bool TransportSender<MyState>::shutdown_ack_timed_out( void ) const
{
return shutdown_tries >= SHUTDOWN_RETRIES;
}
/* Executed upon entry to new receiver state */
template <class MyState>
void TransportSender<MyState>::set_ack_num( uint64_t s_ack_num )
{
ack_num = s_ack_num;
ack_timestamp = timestamp();
ack_history.push_back( make_pair( ack_num, ack_timestamp ) );
}
/* The "late" ack is for the input state that has had enough time on the host to have been echoed */
template <class MyState>
uint64_t TransportSender<MyState>::get_late_ack( uint64_t now )
{
uint64_t newest_echo_ack = 0;
for ( auto i = ack_history.begin(); i != ack_history.end(); i++ ) {
if ( i->second < now - ECHO_TIMEOUT ) {
newest_echo_ack = i->first;
}
}
ack_history.remove_if( [&]( const pair<uint64_t, uint64_t> &x ) { return x.first < newest_echo_ack; } );
return newest_echo_ack;
}
+116
View File
@@ -0,0 +1,116 @@
#ifndef TRANSPORT_SENDER_HPP
#define TRANSPORT_SENDER_HPP
#include <string>
#include <list>
#include "network.h"
#include "transportinstruction.pb.h"
#include "transportstate.h"
#include "transportfragment.h"
using namespace std;
using namespace TransportBuffers;
namespace Network {
template <class MyState>
class TransportSender
{
private:
/* timing parameters */
static const int SEND_INTERVAL_MIN = 20; /* ms between frames */
static const int SEND_INTERVAL_MAX = 250; /* ms between frames */
static const int ACK_INTERVAL = 3000; /* ms between empty acks */
static const int ACK_DELAY = 100; /* ms before delayed ack */
static const int SHUTDOWN_RETRIES = 3; /* number of shutdown packets to send before giving up */
static const int ECHO_TIMEOUT = 50; /* for late ack */
/* helper methods for tick() */
void update_assumed_receiver_state( void );
void rationalize_states( void );
void send_to_receiver( string diff );
void send_empty_ack( void );
void send_in_fragments( string diff, uint64_t new_num );
void add_sent_state( uint64_t the_timestamp, uint64_t num, MyState &state );
/* state of sender */
Connection *connection;
MyState current_state;
list< TimestampedState<MyState> > sent_states;
/* first element: known, acknowledged receiver state */
/* last element: last sent state */
/* somewhere in the middle: the assumed state of the receiver */
typename list< TimestampedState<MyState> >::iterator assumed_receiver_state;
/* for fragment creation */
Fragmenter fragmenter;
/* timing state */
uint64_t next_ack_time;
uint64_t next_send_time;
bool verbose;
bool shutdown_in_progress;
int shutdown_tries;
/* information about receiver state */
uint64_t ack_num;
bool pending_data_ack;
uint64_t ack_timestamp;
list< pair<uint64_t, uint64_t> > ack_history;
uint64_t get_late_ack( uint64_t now ); /* calculate delayed "echo" acknowledgment */
unsigned int SEND_MINDELAY; /* ms to collect all input */
public:
/* constructor */
TransportSender( Connection *s_connection, MyState &initial_state );
/* Send data or an ack if necessary */
void tick( void );
/* Returns the number of ms to wait until next possible event. */
int wait_time( void );
/* Executed upon receipt of ack */
void process_acknowledgment_through( uint64_t ack_num );
/* Executed upon entry to new receiver state */
void set_ack_num( uint64_t s_ack_num );
/* Accelerate reply ack */
void set_data_ack( void ) { pending_data_ack = true; }
/* Starts shutdown sequence */
void start_shutdown( void ) { shutdown_in_progress = true; }
/* Misc. getters and setters */
/* Cannot modify current_state while shutdown in progress */
MyState &get_current_state( void ) { assert( !shutdown_in_progress ); return current_state; }
void set_current_state( const MyState &x ) { assert( !shutdown_in_progress ); current_state = x; }
void set_verbose( void ) { verbose = true; }
bool get_shutdown_in_progress( void ) const { return shutdown_in_progress; }
bool get_shutdown_acknowledged( void ) const { return sent_states.front().num == uint64_t(-1); }
bool get_counterparty_shutdown_acknowledged( void ) const { return fragmenter.last_ack_sent() == uint64_t(-1); }
uint64_t get_sent_state_acked( void ) const { return sent_states.front().num; }
uint64_t get_sent_state_last( void ) const { return sent_states.back().num; }
bool shutdown_ack_timed_out( void ) const;
void set_send_delay( int new_delay ) { SEND_MINDELAY = new_delay; }
unsigned int send_interval( void ) const;
/* nonexistent methods to satisfy -Weffc++ */
TransportSender( const TransportSender &x );
TransportSender & operator=( const TransportSender &x );
};
}
#endif
+19
View File
@@ -0,0 +1,19 @@
#ifndef TRANSPORT_STATE_HPP
#define TRANSPORT_STATE_HPP
namespace Network {
template <class State>
class TimestampedState
{
public:
uint64_t timestamp;
uint64_t num;
State state;
TimestampedState( uint64_t s_timestamp, uint64_t s_num, State &s_state )
: timestamp( s_timestamp ), num( s_num ), state( s_state )
{}
};
}
#endif