Separate modules by subdirectory
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
AM_CPPFLAGS = -I$(srcdir)/../util -I$(srcdir)/../crypto -I$(builddir)/../protobufs
|
||||
AM_CXXFLAGS = --std=c++0x -pedantic -Werror -Wall -Wextra -Weffc++ -fno-default-inline -pipe
|
||||
|
||||
noinst_LIBRARIES = libmoshnetwork.a
|
||||
|
||||
libmoshnetwork_a_SOURCES = network.cc network.h networktransport.cc networktransport.h transportfragment.cc transportfragment.h transportsender.cc transportsender.h transportstate.h
|
||||
@@ -0,0 +1,359 @@
|
||||
#include <sys/types.h>
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <assert.h>
|
||||
#include <endian.h>
|
||||
|
||||
#include "dos_assert.h"
|
||||
#include "network.h"
|
||||
#include "crypto.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace Network;
|
||||
using namespace Crypto;
|
||||
|
||||
const uint64_t DIRECTION_MASK = uint64_t(1) << 63;
|
||||
const uint64_t SEQUENCE_MASK = uint64_t(-1) ^ DIRECTION_MASK;
|
||||
|
||||
/* Read in packet from coded string */
|
||||
Packet::Packet( string coded_packet, Session *session )
|
||||
: seq( -1 ),
|
||||
direction( TO_SERVER ),
|
||||
timestamp( -1 ),
|
||||
timestamp_reply( -1 ),
|
||||
payload()
|
||||
{
|
||||
Message message = session->decrypt( coded_packet );
|
||||
|
||||
direction = (message.nonce.val() & DIRECTION_MASK) ? TO_CLIENT : TO_SERVER;
|
||||
seq = message.nonce.val() & SEQUENCE_MASK;
|
||||
|
||||
dos_assert( message.text.size() >= 2 * sizeof( uint16_t ) );
|
||||
|
||||
uint16_t *data = (uint16_t *)message.text.data();
|
||||
timestamp = be16toh( data[ 0 ] );
|
||||
timestamp_reply = be16toh( data[ 1 ] );
|
||||
|
||||
payload = string( message.text.begin() + 2 * sizeof( uint16_t ), message.text.end() );
|
||||
}
|
||||
|
||||
/* Output coded string from packet */
|
||||
string Packet::tostring( Session *session )
|
||||
{
|
||||
uint64_t direction_seq = (uint64_t( direction == TO_CLIENT ) << 63) | (seq & SEQUENCE_MASK);
|
||||
|
||||
uint16_t ts_net[ 2 ] = { htobe16( timestamp ), htobe16( timestamp_reply ) };
|
||||
|
||||
string timestamps = string( (char *)ts_net, 2 * sizeof( uint16_t ) );
|
||||
|
||||
return session->encrypt( Message( Nonce( direction_seq ), timestamps + payload ) );
|
||||
}
|
||||
|
||||
Packet Connection::new_packet( string &s_payload )
|
||||
{
|
||||
uint16_t outgoing_timestamp_reply = -1;
|
||||
|
||||
uint64_t now = timestamp();
|
||||
|
||||
if ( now - saved_timestamp_received_at < 1000 ) { /* we have a recent received timestamp */
|
||||
/* send "corrected" timestamp advanced by how long we held it */
|
||||
outgoing_timestamp_reply = saved_timestamp + (now - saved_timestamp_received_at);
|
||||
saved_timestamp = -1;
|
||||
saved_timestamp_received_at = 0;
|
||||
}
|
||||
|
||||
Packet p( next_seq++, direction, timestamp16(), outgoing_timestamp_reply, s_payload );
|
||||
|
||||
return p;
|
||||
}
|
||||
|
||||
void Connection::setup( void )
|
||||
{
|
||||
/* create socket */
|
||||
sock = socket( AF_INET, SOCK_DGRAM, 0 );
|
||||
if ( sock < 0 ) {
|
||||
throw NetworkException( "socket", errno );
|
||||
}
|
||||
|
||||
/* Enable path MTU discovery */
|
||||
char flag = IP_PMTUDISC_DO;
|
||||
socklen_t optlen = sizeof( flag );
|
||||
if ( setsockopt( sock, IPPROTO_IP, IP_MTU_DISCOVER, &flag, optlen ) < 0 ) {
|
||||
throw NetworkException( "setsockopt", errno );
|
||||
}
|
||||
}
|
||||
|
||||
Connection::Connection( const char *desired_ip ) /* server */
|
||||
: sock( -1 ),
|
||||
remote_addr(),
|
||||
server( true ),
|
||||
attached( false ),
|
||||
MTU( SEND_MTU ),
|
||||
key(),
|
||||
session( key ),
|
||||
direction( TO_CLIENT ),
|
||||
next_seq( 0 ),
|
||||
saved_timestamp( -1 ),
|
||||
saved_timestamp_received_at( 0 ),
|
||||
expected_receiver_seq( 0 ),
|
||||
RTT_hit( false ),
|
||||
SRTT( 1000 ),
|
||||
RTTVAR( 500 )
|
||||
{
|
||||
setup();
|
||||
|
||||
/* Attempt to bind free local port, with
|
||||
address client used to connect to us.
|
||||
|
||||
This usage does not seem to be endorsed by POSIX. */
|
||||
|
||||
struct sockaddr_in local_addr;
|
||||
local_addr.sin_family = AF_INET;
|
||||
local_addr.sin_port = htons( 0 );
|
||||
if ( desired_ip
|
||||
&& inet_aton( desired_ip, &local_addr.sin_addr )
|
||||
&& (bind( sock, (sockaddr *)&local_addr, sizeof( local_addr ) ) == 0) ) {
|
||||
return;
|
||||
}
|
||||
|
||||
if ( desired_ip ) {
|
||||
fprintf( stderr, "Could not bind to desired local address %s.\n", desired_ip );
|
||||
}
|
||||
|
||||
/* Could not bind to that IP (maybe we are behind NAT).
|
||||
Try again with any IP. */
|
||||
local_addr.sin_addr.s_addr = INADDR_ANY;
|
||||
if ( bind( sock, (sockaddr *)&local_addr, sizeof( local_addr ) ) < 0 ) {
|
||||
throw NetworkException( "bind", errno );
|
||||
}
|
||||
}
|
||||
|
||||
Connection::Connection( const char *key_str, const char *ip, int port ) /* client */
|
||||
: sock( -1 ),
|
||||
remote_addr(),
|
||||
server( false ),
|
||||
attached( false ),
|
||||
MTU( SEND_MTU ),
|
||||
key( key_str ),
|
||||
session( key ),
|
||||
direction( TO_SERVER ),
|
||||
next_seq( 0 ),
|
||||
saved_timestamp( -1 ),
|
||||
saved_timestamp_received_at( 0 ),
|
||||
expected_receiver_seq( 0 ),
|
||||
RTT_hit( false ),
|
||||
SRTT( 1000 ),
|
||||
RTTVAR( 500 )
|
||||
{
|
||||
setup();
|
||||
|
||||
/* associate socket with remote host and port */
|
||||
remote_addr.sin_family = AF_INET;
|
||||
remote_addr.sin_port = htons( port );
|
||||
if ( !inet_aton( ip, &remote_addr.sin_addr ) ) {
|
||||
int saved_errno = errno;
|
||||
char buffer[ 2048 ];
|
||||
snprintf( buffer, 2048, "Bad IP address (%s)", ip );
|
||||
throw NetworkException( buffer, saved_errno );
|
||||
}
|
||||
|
||||
attached = true;
|
||||
}
|
||||
|
||||
void Connection::send( string s )
|
||||
{
|
||||
assert( attached );
|
||||
|
||||
Packet px = new_packet( s );
|
||||
|
||||
string p = px.tostring( &session );
|
||||
|
||||
ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0,
|
||||
(sockaddr *)&remote_addr, sizeof( remote_addr ) );
|
||||
|
||||
if ( (bytes_sent < 0) && (errno == EMSGSIZE) ) {
|
||||
update_MTU();
|
||||
throw NetworkException( "Path MTU Discovery", EMSGSIZE );
|
||||
} else if ( bytes_sent == static_cast<int>( p.size() ) ) {
|
||||
return;
|
||||
} else {
|
||||
throw NetworkException( "sendto", errno );
|
||||
}
|
||||
}
|
||||
|
||||
string Connection::recv( void )
|
||||
{
|
||||
struct sockaddr_in packet_remote_addr;
|
||||
|
||||
char buf[ RECEIVE_MTU ];
|
||||
|
||||
socklen_t addrlen = sizeof( packet_remote_addr );
|
||||
|
||||
ssize_t received_len = recvfrom( sock, buf, RECEIVE_MTU, 0, (sockaddr *)&packet_remote_addr, &addrlen );
|
||||
|
||||
if ( received_len < 0 ) {
|
||||
throw NetworkException( "recvfrom", errno );
|
||||
}
|
||||
|
||||
if ( received_len > RECEIVE_MTU ) {
|
||||
char buffer[ 2048 ];
|
||||
snprintf( buffer, 2048, "Received oversize datagram (size %d) and limit is %d\n",
|
||||
static_cast<int>( received_len ), RECEIVE_MTU );
|
||||
throw NetworkException( buffer, errno );
|
||||
}
|
||||
|
||||
Packet p( string( buf, received_len ), &session );
|
||||
|
||||
dos_assert( p.direction == (server ? TO_SERVER : TO_CLIENT) ); /* prevent malicious playback to sender */
|
||||
|
||||
if ( p.seq >= expected_receiver_seq ) { /* don't use out-of-order packets for timestamp or targeting */
|
||||
expected_receiver_seq = p.seq + 1; /* this is security-sensitive because a replay attack could otherwise
|
||||
screw up the timestamp and targeting */
|
||||
|
||||
if ( p.timestamp != uint16_t(-1) ) {
|
||||
saved_timestamp = p.timestamp;
|
||||
saved_timestamp_received_at = timestamp();
|
||||
}
|
||||
|
||||
if ( p.timestamp_reply != uint16_t(-1) ) {
|
||||
uint16_t now = timestamp16();
|
||||
double R = timestamp_diff( now, p.timestamp_reply );
|
||||
|
||||
if ( R < 5000 ) { /* ignore large values, e.g. server was Ctrl-Zed */
|
||||
if ( !RTT_hit ) { /* first measurement */
|
||||
SRTT = R;
|
||||
RTTVAR = R / 2;
|
||||
RTT_hit = true;
|
||||
} else {
|
||||
const double alpha = 1.0 / 8.0;
|
||||
const double beta = 1.0 / 4.0;
|
||||
|
||||
RTTVAR = (1 - beta) * RTTVAR + ( beta * fabs( SRTT - R ) );
|
||||
SRTT = (1 - alpha) * SRTT + ( alpha * R );
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* auto-adjust to remote host */
|
||||
attached = true;
|
||||
|
||||
if ( (remote_addr.sin_addr.s_addr != packet_remote_addr.sin_addr.s_addr)
|
||||
|| (remote_addr.sin_port != packet_remote_addr.sin_port) ) {
|
||||
remote_addr = packet_remote_addr;
|
||||
if ( server ) {
|
||||
fprintf( stderr, "Server now attached to client at %s:%d\n",
|
||||
inet_ntoa( remote_addr.sin_addr ),
|
||||
ntohs( remote_addr.sin_port ) );
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return p.payload; /* we do return out-of-order or duplicated packets to caller */
|
||||
}
|
||||
|
||||
int Connection::port( void ) const
|
||||
{
|
||||
struct sockaddr_in local_addr;
|
||||
socklen_t addrlen = sizeof( local_addr );
|
||||
|
||||
if ( getsockname( sock, (sockaddr *)&local_addr, &addrlen ) < 0 ) {
|
||||
throw NetworkException( "getsockname", errno );
|
||||
}
|
||||
|
||||
return ntohs( local_addr.sin_port );
|
||||
}
|
||||
|
||||
uint64_t Network::timestamp( void )
|
||||
{
|
||||
struct timespec tp;
|
||||
|
||||
if ( clock_gettime( CLOCK_MONOTONIC, &tp ) < 0 ) {
|
||||
throw NetworkException( "clock_gettime", errno );
|
||||
}
|
||||
|
||||
uint64_t millis = tp.tv_nsec / 1000000;
|
||||
millis += uint64_t( tp.tv_sec ) * 1000;
|
||||
|
||||
return millis;
|
||||
}
|
||||
|
||||
uint16_t Network::timestamp16( void )
|
||||
{
|
||||
uint16_t ts = timestamp() % 65536;
|
||||
if ( ts == uint16_t(-1) ) {
|
||||
ts++;
|
||||
}
|
||||
return ts;
|
||||
}
|
||||
|
||||
uint16_t Network::timestamp_diff( uint16_t tsnew, uint16_t tsold )
|
||||
{
|
||||
int diff = tsnew - tsold;
|
||||
if ( diff < 0 ) {
|
||||
diff += 65536;
|
||||
}
|
||||
|
||||
assert( diff >= 0 );
|
||||
assert( diff <= 65535 );
|
||||
|
||||
return diff;
|
||||
}
|
||||
|
||||
uint64_t Connection::timeout( void ) const
|
||||
{
|
||||
uint64_t RTO = lrint( ceil( SRTT + 4 * RTTVAR ) );
|
||||
if ( RTO < MIN_RTO ) {
|
||||
RTO = MIN_RTO;
|
||||
} else if ( RTO > MAX_RTO ) {
|
||||
RTO = MAX_RTO;
|
||||
}
|
||||
return RTO;
|
||||
}
|
||||
|
||||
class Socket {
|
||||
public:
|
||||
int fd;
|
||||
|
||||
Socket( int domain, int type, int protocol )
|
||||
: fd( socket( domain, type, protocol ) )
|
||||
{
|
||||
if ( fd < 0 ) {
|
||||
throw NetworkException( "socket", errno );
|
||||
}
|
||||
}
|
||||
|
||||
~Socket()
|
||||
{
|
||||
if ( close( fd ) < 0 ) {
|
||||
throw NetworkException( "close", errno );
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void Connection::update_MTU( void )
|
||||
{
|
||||
if ( !attached ) {
|
||||
return;
|
||||
}
|
||||
|
||||
/* We don't want to use our main socket because we don't want to have to connect it */
|
||||
Socket path_MTU_socket( AF_INET, SOCK_DGRAM, 0 );
|
||||
|
||||
/* Connect socket so we can retrieve path MTU */
|
||||
if ( connect( path_MTU_socket.fd, (sockaddr *)&remote_addr, sizeof( remote_addr ) ) < 0 ) {
|
||||
throw NetworkException( "connect", errno );
|
||||
}
|
||||
|
||||
int PMTU;
|
||||
socklen_t optlen = sizeof( PMTU );
|
||||
if ( getsockopt( path_MTU_socket.fd, IPPROTO_IP, IP_MTU, &PMTU, &optlen ) < 0 ) {
|
||||
throw NetworkException( "getsockopt", errno );
|
||||
}
|
||||
|
||||
if ( optlen != sizeof( PMTU ) ) {
|
||||
throw NetworkException( "Error getting path MTU", errno );
|
||||
}
|
||||
|
||||
MTU = min( PMTU, int(SEND_MTU) ); /* need cast to compile without optimization! XXX */
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
#ifndef NETWORK_HPP
|
||||
#define NETWORK_HPP
|
||||
|
||||
#include <stdint.h>
|
||||
#include <deque>
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/in.h>
|
||||
#include <string>
|
||||
#include <math.h>
|
||||
|
||||
#include "crypto.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace Crypto;
|
||||
|
||||
namespace Network {
|
||||
static const unsigned int MOSH_PROTOCOL_VERSION = 1;
|
||||
|
||||
uint64_t timestamp( void );
|
||||
uint16_t timestamp16( void );
|
||||
uint16_t timestamp_diff( uint16_t tsnew, uint16_t tsold );
|
||||
|
||||
class NetworkException {
|
||||
public:
|
||||
string function;
|
||||
int the_errno;
|
||||
NetworkException( string s_function, int s_errno ) : function( s_function ), the_errno( s_errno ) {}
|
||||
};
|
||||
|
||||
enum Direction {
|
||||
TO_SERVER = 0,
|
||||
TO_CLIENT = 1
|
||||
};
|
||||
|
||||
class Packet {
|
||||
public:
|
||||
uint64_t seq;
|
||||
Direction direction;
|
||||
uint16_t timestamp, timestamp_reply;
|
||||
string payload;
|
||||
|
||||
Packet( uint64_t s_seq, Direction s_direction,
|
||||
uint16_t s_timestamp, uint16_t s_timestamp_reply, string s_payload )
|
||||
: seq( s_seq ), direction( s_direction ),
|
||||
timestamp( s_timestamp ), timestamp_reply( s_timestamp_reply ), payload( s_payload )
|
||||
{}
|
||||
|
||||
Packet( string coded_packet, Session *session );
|
||||
|
||||
string tostring( Session *session );
|
||||
};
|
||||
|
||||
class Connection {
|
||||
private:
|
||||
static const int RECEIVE_MTU = 2048;
|
||||
static const int SEND_MTU = 1400;
|
||||
static const uint64_t MIN_RTO = 50; /* ms */
|
||||
static const uint64_t MAX_RTO = 1000; /* ms */
|
||||
|
||||
int sock;
|
||||
struct sockaddr_in remote_addr;
|
||||
|
||||
bool server;
|
||||
bool attached;
|
||||
|
||||
int MTU;
|
||||
|
||||
Base64Key key;
|
||||
Session session;
|
||||
|
||||
void setup( void );
|
||||
|
||||
Direction direction;
|
||||
uint64_t next_seq;
|
||||
uint16_t saved_timestamp;
|
||||
uint64_t saved_timestamp_received_at;
|
||||
uint64_t expected_receiver_seq;
|
||||
|
||||
bool RTT_hit;
|
||||
double SRTT;
|
||||
double RTTVAR;
|
||||
|
||||
Packet new_packet( string &s_payload );
|
||||
|
||||
void update_MTU( void );
|
||||
|
||||
public:
|
||||
Connection( const char *desired_ip ); /* server */
|
||||
Connection( const char *key_str, const char *ip, int port ); /* client */
|
||||
|
||||
void send( string s );
|
||||
string recv( void );
|
||||
int fd( void ) const { return sock; }
|
||||
int get_MTU( void ) const { return MTU; }
|
||||
|
||||
int port( void ) const;
|
||||
string get_key( void ) const { return key.printable_key(); }
|
||||
bool get_attached( void ) const { return attached; }
|
||||
|
||||
uint64_t timeout( void ) const;
|
||||
double get_SRTT( void ) const { return SRTT; }
|
||||
};
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,154 @@
|
||||
#include <assert.h>
|
||||
#include <iostream>
|
||||
|
||||
#include "networktransport.h"
|
||||
|
||||
#include "transportsender.cc"
|
||||
|
||||
using namespace Network;
|
||||
using namespace std;
|
||||
|
||||
template <class MyState, class RemoteState>
|
||||
Transport<MyState, RemoteState>::Transport( MyState &initial_state, RemoteState &initial_remote,
|
||||
const char *desired_ip )
|
||||
: connection( desired_ip ),
|
||||
sender( &connection, initial_state ),
|
||||
received_states( 1, TimestampedState<RemoteState>( timestamp(), 0, initial_remote ) ),
|
||||
last_receiver_state( initial_remote ),
|
||||
sent_state_late_acked( 0 ),
|
||||
fragments(),
|
||||
verbose( false )
|
||||
{
|
||||
/* server */
|
||||
}
|
||||
|
||||
template <class MyState, class RemoteState>
|
||||
Transport<MyState, RemoteState>::Transport( MyState &initial_state, RemoteState &initial_remote,
|
||||
const char *key_str, const char *ip, int port )
|
||||
: connection( key_str, ip, port ),
|
||||
sender( &connection, initial_state ),
|
||||
received_states( 1, TimestampedState<RemoteState>( timestamp(), 0, initial_remote ) ),
|
||||
last_receiver_state( initial_remote ),
|
||||
sent_state_late_acked( 0 ),
|
||||
fragments(),
|
||||
verbose( false )
|
||||
{
|
||||
/* client */
|
||||
}
|
||||
|
||||
template <class MyState, class RemoteState>
|
||||
void Transport<MyState, RemoteState>::recv( void )
|
||||
{
|
||||
string s( connection.recv() );
|
||||
Fragment frag( s );
|
||||
|
||||
if ( fragments.add_fragment( frag ) ) { /* complete packet */
|
||||
Instruction inst = fragments.get_assembly();
|
||||
|
||||
if ( inst.protocol_version() != MOSH_PROTOCOL_VERSION ) {
|
||||
throw NetworkException( "mosh protocol version mismatch", 0 );
|
||||
}
|
||||
|
||||
sender.process_acknowledgment_through( inst.ack_num() );
|
||||
if ( inst.late_ack_num() > sent_state_late_acked ) {
|
||||
sent_state_late_acked = inst.late_ack_num();
|
||||
}
|
||||
|
||||
/* first, make sure we don't already have the new state */
|
||||
for ( typename list< TimestampedState<RemoteState> >::iterator i = received_states.begin();
|
||||
i != received_states.end();
|
||||
i++ ) {
|
||||
if ( inst.new_num() == i->num ) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/* now, make sure we do have the old state */
|
||||
bool found = 0;
|
||||
typename list< TimestampedState<RemoteState> >::iterator reference_state = received_states.begin();
|
||||
while ( reference_state != received_states.end() ) {
|
||||
if ( inst.old_num() == reference_state->num ) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
reference_state++;
|
||||
}
|
||||
|
||||
if ( !found ) {
|
||||
// fprintf( stderr, "Ignoring out-of-order packet. Reference state %d has been discarded or hasn't yet been received.\n", int(inst.old_num) );
|
||||
return; /* this is security-sensitive and part of how we enforce idempotency */
|
||||
}
|
||||
|
||||
/* apply diff to reference state */
|
||||
TimestampedState<RemoteState> new_state = *reference_state;
|
||||
new_state.timestamp = timestamp();
|
||||
new_state.num = inst.new_num();
|
||||
|
||||
if ( !inst.diff().empty() ) {
|
||||
new_state.state.apply_string( inst.diff() );
|
||||
}
|
||||
|
||||
process_throwaway_until( inst.throwaway_num() );
|
||||
|
||||
/* Insert new state in sorted place */
|
||||
for ( typename list< TimestampedState<RemoteState> >::iterator i = received_states.begin();
|
||||
i != received_states.end();
|
||||
i++ ) {
|
||||
if ( i->num > new_state.num ) {
|
||||
received_states.insert( i, new_state );
|
||||
if ( verbose ) {
|
||||
fprintf( stderr, "[%u] Received OUT-OF-ORDER state %d [ack %d]\n",
|
||||
(unsigned int)(timestamp() % 100000), (int)new_state.num, (int)inst.ack_num() );
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
if ( verbose ) {
|
||||
fprintf( stderr, "[%u] Received state %d [ack %d]\n",
|
||||
(unsigned int)(timestamp() % 100000), (int)new_state.num, (int)inst.ack_num() );
|
||||
}
|
||||
received_states.push_back( new_state );
|
||||
sender.set_ack_num( received_states.back().num );
|
||||
|
||||
if ( !inst.diff().empty() ) {
|
||||
sender.set_data_ack();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* The sender uses throwaway_num to tell us the earliest received state that we need to keep around */
|
||||
template <class MyState, class RemoteState>
|
||||
void Transport<MyState, RemoteState>::process_throwaway_until( uint64_t throwaway_num )
|
||||
{
|
||||
typename list< TimestampedState<RemoteState> >::iterator i = received_states.begin();
|
||||
while ( i != received_states.end() ) {
|
||||
typename list< TimestampedState<RemoteState> >::iterator inext = i;
|
||||
inext++;
|
||||
if ( i->num < throwaway_num ) {
|
||||
received_states.erase( i );
|
||||
}
|
||||
i = inext;
|
||||
}
|
||||
|
||||
assert( received_states.size() > 0 );
|
||||
}
|
||||
|
||||
template <class MyState, class RemoteState>
|
||||
string Transport<MyState, RemoteState>::get_remote_diff( void )
|
||||
{
|
||||
/* find diff between last receiver state and current remote state, then rationalize states */
|
||||
|
||||
string ret( received_states.back().state.diff_from( last_receiver_state ) );
|
||||
|
||||
const RemoteState *oldest_receiver_state = &received_states.front().state;
|
||||
|
||||
for ( typename list< TimestampedState<RemoteState> >::reverse_iterator i = received_states.rbegin();
|
||||
i != received_states.rend();
|
||||
i++ ) {
|
||||
i->state.subtract( oldest_receiver_state );
|
||||
}
|
||||
|
||||
last_receiver_state = received_states.back().state;
|
||||
|
||||
return ret;
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
#ifndef NETWORK_TRANSPORT_HPP
|
||||
#define NETWORK_TRANSPORT_HPP
|
||||
|
||||
#include <string>
|
||||
#include <signal.h>
|
||||
#include <time.h>
|
||||
#include <list>
|
||||
#include <vector>
|
||||
|
||||
#include "network.h"
|
||||
#include "transportsender.h"
|
||||
#include "transportfragment.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Network {
|
||||
template <class MyState, class RemoteState>
|
||||
class Transport
|
||||
{
|
||||
private:
|
||||
/* the underlying, encrypted network connection */
|
||||
Connection connection;
|
||||
|
||||
/* sender side */
|
||||
TransportSender<MyState> sender;
|
||||
|
||||
/* helper methods for recv() */
|
||||
void process_throwaway_until( uint64_t throwaway_num );
|
||||
|
||||
/* simple receiver */
|
||||
list< TimestampedState<RemoteState> > received_states;
|
||||
RemoteState last_receiver_state; /* the state we were in when user last queried state */
|
||||
uint64_t sent_state_late_acked;
|
||||
FragmentAssembly fragments;
|
||||
bool verbose;
|
||||
|
||||
public:
|
||||
Transport( MyState &initial_state, RemoteState &initial_remote, const char *desired_ip );
|
||||
Transport( MyState &initial_state, RemoteState &initial_remote,
|
||||
const char *key_str, const char *ip, int port );
|
||||
|
||||
/* Send data or an ack if necessary. */
|
||||
void tick( void ) { sender.tick(); }
|
||||
|
||||
/* Returns the number of ms to wait until next possible event. */
|
||||
int wait_time( void ) { return sender.wait_time(); }
|
||||
|
||||
/* Blocks waiting for a packet. */
|
||||
void recv( void );
|
||||
|
||||
/* Find diff between last receiver state and current remote state, then rationalize states. */
|
||||
string get_remote_diff( void );
|
||||
|
||||
/* Shut down other side of connection. */
|
||||
/* Illegal to change current_state after this. */
|
||||
void start_shutdown( void ) { sender.start_shutdown(); }
|
||||
bool shutdown_in_progress( void ) const { return sender.get_shutdown_in_progress(); }
|
||||
bool shutdown_acknowledged( void ) const { return sender.get_shutdown_acknowledged(); }
|
||||
bool shutdown_ack_timed_out( void ) const { return sender.shutdown_ack_timed_out(); }
|
||||
bool attached( void ) const { return connection.get_attached(); }
|
||||
|
||||
/* Other side has requested shutdown and we have sent one ACK */
|
||||
bool counterparty_shutdown_ack_sent( void ) const { return sender.get_counterparty_shutdown_acknowledged(); }
|
||||
|
||||
int port( void ) const { return connection.port(); }
|
||||
string get_key( void ) const { return connection.get_key(); }
|
||||
|
||||
MyState &get_current_state( void ) { return sender.get_current_state(); }
|
||||
void set_current_state( const MyState &x ) { sender.set_current_state( x ); }
|
||||
|
||||
uint64_t get_remote_state_num( void ) const { return received_states.back().num; }
|
||||
|
||||
const TimestampedState<RemoteState> & get_latest_remote_state( void ) const { return received_states.back(); }
|
||||
|
||||
int fd( void ) const { return connection.fd(); }
|
||||
|
||||
void set_verbose( void ) { sender.set_verbose(); verbose = true; }
|
||||
|
||||
void set_send_delay( int new_delay ) { sender.set_send_delay( new_delay ); }
|
||||
|
||||
uint64_t get_sent_state_acked( void ) const { return sender.get_sent_state_acked(); }
|
||||
uint64_t get_sent_state_last( void ) const { return sender.get_sent_state_last(); }
|
||||
uint64_t get_sent_state_late_acked( void ) const { return sent_state_late_acked; }
|
||||
|
||||
unsigned int send_interval( void ) const { return sender.send_interval(); }
|
||||
};
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,162 @@
|
||||
#include <endian.h>
|
||||
#include <assert.h>
|
||||
|
||||
#include "transportfragment.h"
|
||||
#include "transportinstruction.pb.h"
|
||||
|
||||
using namespace Network;
|
||||
using namespace TransportBuffers;
|
||||
|
||||
static string network_order_string( uint16_t host_order )
|
||||
{
|
||||
uint16_t net_int = htobe16( host_order );
|
||||
return string( (char *)&net_int, sizeof( net_int ) );
|
||||
}
|
||||
|
||||
static string network_order_string( uint64_t host_order )
|
||||
{
|
||||
uint64_t net_int = htobe64( host_order );
|
||||
return string( (char *)&net_int, sizeof( net_int ) );
|
||||
}
|
||||
|
||||
string Fragment::tostring( void )
|
||||
{
|
||||
assert( initialized );
|
||||
|
||||
string ret;
|
||||
|
||||
ret += network_order_string( id );
|
||||
|
||||
assert( !( fragment_num & 0x8000 ) ); /* effective limit on size of a terminal screen change or buffered user input */
|
||||
uint16_t combined_fragment_num = ( final << 15 ) | fragment_num;
|
||||
ret += network_order_string( combined_fragment_num );
|
||||
|
||||
assert( ret.size() == frag_header_len );
|
||||
|
||||
ret += contents;
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
Fragment::Fragment( string &x )
|
||||
: id( -1 ), fragment_num( -1 ), final( false ), initialized( true ),
|
||||
contents( x.begin() + frag_header_len, x.end() )
|
||||
{
|
||||
assert( x.size() >= frag_header_len );
|
||||
|
||||
uint64_t *data64 = (uint64_t *)x.data();
|
||||
uint16_t *data16 = (uint16_t *)x.data();
|
||||
id = be64toh( data64[ 0 ] );
|
||||
fragment_num = be16toh( data16[ 4 ] );
|
||||
final = ( fragment_num & 0x8000 ) >> 15;
|
||||
fragment_num &= 0x7FFF;
|
||||
}
|
||||
|
||||
bool FragmentAssembly::add_fragment( Fragment &frag )
|
||||
{
|
||||
/* see if this is a totally new packet */
|
||||
if ( current_id != frag.id ) {
|
||||
fragments.clear();
|
||||
fragments.resize( frag.fragment_num + 1 );
|
||||
fragments.at( frag.fragment_num ) = frag;
|
||||
fragments_arrived = 1;
|
||||
fragments_total = -1; /* unknown */
|
||||
current_id = frag.id;
|
||||
} else { /* not a new packet */
|
||||
/* see if we already have this fragment */
|
||||
if ( (fragments.size() > frag.fragment_num)
|
||||
&& (fragments.at( frag.fragment_num ).initialized) ) {
|
||||
/* make sure new version is same as what we already have */
|
||||
assert( fragments.at( frag.fragment_num ) == frag );
|
||||
} else {
|
||||
if ( (int)fragments.size() < frag.fragment_num + 1 ) {
|
||||
fragments.resize( frag.fragment_num + 1 );
|
||||
}
|
||||
fragments.at( frag.fragment_num ) = frag;
|
||||
fragments_arrived++;
|
||||
}
|
||||
}
|
||||
|
||||
if ( frag.final ) {
|
||||
fragments_total = frag.fragment_num + 1;
|
||||
assert( (int)fragments.size() <= fragments_total );
|
||||
fragments.resize( fragments_total );
|
||||
}
|
||||
|
||||
if ( fragments_total != -1 ) {
|
||||
assert( fragments_arrived <= fragments_total );
|
||||
}
|
||||
|
||||
/* see if we're done */
|
||||
return ( fragments_arrived == fragments_total );
|
||||
}
|
||||
|
||||
Instruction FragmentAssembly::get_assembly( void )
|
||||
{
|
||||
assert( fragments_arrived == fragments_total );
|
||||
|
||||
string encoded;
|
||||
|
||||
for ( int i = 0; i < fragments_total; i++ ) {
|
||||
assert( fragments.at( i ).initialized );
|
||||
encoded += fragments.at( i ).contents;
|
||||
}
|
||||
|
||||
Instruction ret;
|
||||
assert( ret.ParseFromString( encoded ) );
|
||||
|
||||
fragments.clear();
|
||||
fragments_arrived = 0;
|
||||
fragments_total = -1;
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool Fragment::operator==( const Fragment &x )
|
||||
{
|
||||
return ( id == x.id ) && ( fragment_num == x.fragment_num ) && ( final == x.final )
|
||||
&& ( initialized == x.initialized ) && ( contents == x.contents );
|
||||
}
|
||||
|
||||
vector<Fragment> Fragmenter::make_fragments( Instruction &inst, int MTU )
|
||||
{
|
||||
if ( (inst.old_num() != last_instruction.old_num())
|
||||
|| (inst.new_num() != last_instruction.new_num())
|
||||
|| (inst.ack_num() != last_instruction.ack_num())
|
||||
|| (inst.throwaway_num() != last_instruction.throwaway_num())
|
||||
|| (inst.late_ack_num() != last_instruction.late_ack_num())
|
||||
|| (inst.protocol_version() != last_instruction.protocol_version())
|
||||
|| (last_MTU != MTU) ) {
|
||||
next_instruction_id++;
|
||||
}
|
||||
|
||||
if ( (inst.old_num() == last_instruction.old_num())
|
||||
&& (inst.new_num() == last_instruction.new_num()) ) {
|
||||
assert( inst.diff() == last_instruction.diff() );
|
||||
}
|
||||
|
||||
last_instruction = inst;
|
||||
last_MTU = MTU;
|
||||
|
||||
string payload = inst.SerializeAsString();
|
||||
uint16_t fragment_num = 0;
|
||||
vector<Fragment> ret;
|
||||
|
||||
while ( !payload.empty() ) {
|
||||
string this_fragment;
|
||||
bool final = false;
|
||||
|
||||
if ( int( payload.size() + HEADER_LEN ) > MTU ) {
|
||||
this_fragment = string( payload.begin(), payload.begin() + MTU - HEADER_LEN );
|
||||
payload = string( payload.begin() + MTU - HEADER_LEN, payload.end() );
|
||||
} else {
|
||||
this_fragment = payload;
|
||||
payload.clear();
|
||||
final = true;
|
||||
}
|
||||
|
||||
ret.push_back( Fragment( next_instruction_id, fragment_num++, final, this_fragment ) );
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
#ifndef TRANSPORT_FRAGMENT_HPP
|
||||
#define TRANSPORT_FRAGMENT_HPP
|
||||
|
||||
#include <stdint.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "transportinstruction.pb.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace TransportBuffers;
|
||||
|
||||
namespace Network {
|
||||
static const int HEADER_LEN = 66;
|
||||
|
||||
class Fragment
|
||||
{
|
||||
private:
|
||||
static const size_t frag_header_len = sizeof( uint64_t ) + sizeof( uint16_t );
|
||||
|
||||
public:
|
||||
uint64_t id;
|
||||
uint16_t fragment_num;
|
||||
bool final;
|
||||
|
||||
bool initialized;
|
||||
|
||||
string contents;
|
||||
|
||||
Fragment()
|
||||
: id( -1 ), fragment_num( -1 ), final( false ), initialized( false ), contents()
|
||||
{}
|
||||
|
||||
Fragment( uint64_t s_id, uint16_t s_fragment_num, bool s_final, string s_contents )
|
||||
: id( s_id ), fragment_num( s_fragment_num ), final( s_final ), initialized( true ),
|
||||
contents( s_contents )
|
||||
{}
|
||||
|
||||
Fragment( string &x );
|
||||
|
||||
string tostring( void );
|
||||
|
||||
bool operator==( const Fragment &x );
|
||||
};
|
||||
|
||||
class FragmentAssembly
|
||||
{
|
||||
private:
|
||||
vector<Fragment> fragments;
|
||||
uint64_t current_id;
|
||||
int fragments_arrived, fragments_total;
|
||||
|
||||
public:
|
||||
FragmentAssembly() : fragments(), current_id( -1 ), fragments_arrived( 0 ), fragments_total( -1 ) {}
|
||||
bool add_fragment( Fragment &inst );
|
||||
Instruction get_assembly( void );
|
||||
};
|
||||
|
||||
class Fragmenter
|
||||
{
|
||||
private:
|
||||
uint64_t next_instruction_id;
|
||||
Instruction last_instruction;
|
||||
int last_MTU;
|
||||
|
||||
public:
|
||||
Fragmenter() : next_instruction_id( 0 ), last_instruction(), last_MTU( -1 )
|
||||
{
|
||||
last_instruction.set_old_num( -1 );
|
||||
last_instruction.set_new_num( -1 );
|
||||
}
|
||||
vector<Fragment> make_fragments( Instruction &inst, int MTU );
|
||||
uint64_t last_ack_sent( void ) const { return last_instruction.ack_num(); }
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,301 @@
|
||||
#include <algorithm>
|
||||
#include <list>
|
||||
|
||||
#include "transportsender.h"
|
||||
#include "transportfragment.h"
|
||||
|
||||
using namespace Network;
|
||||
|
||||
template <class MyState>
|
||||
TransportSender<MyState>::TransportSender( Connection *s_connection, MyState &initial_state )
|
||||
: connection( s_connection ),
|
||||
current_state( initial_state ),
|
||||
sent_states( 1, TimestampedState<MyState>( timestamp(), 0, initial_state ) ),
|
||||
assumed_receiver_state( sent_states.begin() ),
|
||||
fragmenter(),
|
||||
next_ack_time( timestamp() ),
|
||||
next_send_time( timestamp() ),
|
||||
verbose( false ),
|
||||
shutdown_in_progress( false ),
|
||||
shutdown_tries( 0 ),
|
||||
ack_num( 0 ),
|
||||
pending_data_ack( false ),
|
||||
ack_timestamp( 0 ),
|
||||
ack_history(),
|
||||
SEND_MINDELAY( 15 )
|
||||
{
|
||||
}
|
||||
|
||||
/* Try to send roughly two frames per RTT, bounded by limits on frame rate */
|
||||
template <class MyState>
|
||||
unsigned int TransportSender<MyState>::send_interval( void ) const
|
||||
{
|
||||
int SEND_INTERVAL = lrint( ceil( connection->get_SRTT() / 2.0 ) );
|
||||
if ( SEND_INTERVAL < SEND_INTERVAL_MIN ) {
|
||||
SEND_INTERVAL = SEND_INTERVAL_MIN;
|
||||
} else if ( SEND_INTERVAL > SEND_INTERVAL_MAX ) {
|
||||
SEND_INTERVAL = SEND_INTERVAL_MAX;
|
||||
}
|
||||
|
||||
return SEND_INTERVAL;
|
||||
}
|
||||
|
||||
/* How many ms can the caller wait before we will have an event (empty ack or next frame)? */
|
||||
template <class MyState>
|
||||
int TransportSender<MyState>::wait_time( void )
|
||||
{
|
||||
if ( pending_data_ack && (next_ack_time > timestamp() + ACK_DELAY) ) {
|
||||
next_ack_time = timestamp() + ACK_DELAY;
|
||||
}
|
||||
|
||||
if ( !(current_state == sent_states.back().state) ) { /* pending data to send */
|
||||
if ( next_send_time > timestamp() + SEND_MINDELAY ) {
|
||||
next_send_time = timestamp() + SEND_MINDELAY;
|
||||
}
|
||||
|
||||
if ( next_send_time < sent_states.back().timestamp + send_interval() ) {
|
||||
next_send_time = sent_states.back().timestamp + send_interval();
|
||||
}
|
||||
}
|
||||
|
||||
/* speed up shutdown sequence */
|
||||
if ( shutdown_in_progress || (ack_num == uint64_t(-1)) ) {
|
||||
next_ack_time = sent_states.back().timestamp + send_interval();
|
||||
}
|
||||
|
||||
uint64_t next_wakeup = next_ack_time;
|
||||
|
||||
if ( next_send_time < next_wakeup ) {
|
||||
next_wakeup = next_send_time;
|
||||
}
|
||||
|
||||
if ( !connection->get_attached() ) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if ( next_wakeup > timestamp() ) {
|
||||
return next_wakeup - timestamp();
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
/* Send data or an empty ack if necessary */
|
||||
template <class MyState>
|
||||
void TransportSender<MyState>::tick( void )
|
||||
{
|
||||
wait_time();
|
||||
|
||||
if ( !connection->get_attached() ) {
|
||||
return;
|
||||
}
|
||||
|
||||
if ( (timestamp() < next_ack_time)
|
||||
&& (timestamp() < next_send_time) ) {
|
||||
return;
|
||||
}
|
||||
|
||||
/* Determine if a new diff or empty ack needs to be sent */
|
||||
/* Update assumed receiver state */
|
||||
update_assumed_receiver_state();
|
||||
|
||||
/* Cut out common prefix of all states */
|
||||
rationalize_states();
|
||||
|
||||
string diff = current_state.diff_from( assumed_receiver_state->state );
|
||||
|
||||
if ( diff.empty() && (timestamp() >= next_ack_time) ) {
|
||||
send_empty_ack();
|
||||
return;
|
||||
}
|
||||
|
||||
if ( !diff.empty() && ( (timestamp() >= next_send_time)
|
||||
|| (timestamp() >= next_ack_time) ) ) {
|
||||
/* Send diffs or ack */
|
||||
send_to_receiver( diff );
|
||||
}
|
||||
}
|
||||
|
||||
template <class MyState>
|
||||
void TransportSender<MyState>::send_empty_ack( void )
|
||||
{
|
||||
assert ( timestamp() >= next_ack_time );
|
||||
|
||||
uint64_t new_num = sent_states.back().num + 1;
|
||||
|
||||
/* special case for shutdown sequence */
|
||||
if ( shutdown_in_progress ) {
|
||||
new_num = uint64_t( -1 );
|
||||
}
|
||||
|
||||
// sent_states.push_back( TimestampedState<MyState>( sent_states.back().timestamp, new_num, current_state ) );
|
||||
add_sent_state( sent_states.back().timestamp, new_num, current_state );
|
||||
send_in_fragments( "", new_num );
|
||||
|
||||
next_ack_time = timestamp() + ACK_INTERVAL;
|
||||
}
|
||||
|
||||
template <class MyState>
|
||||
void TransportSender<MyState>::add_sent_state( uint64_t the_timestamp, uint64_t num, MyState &state )
|
||||
{
|
||||
sent_states.push_back( TimestampedState<MyState>( the_timestamp, num, state ) );
|
||||
if ( sent_states.size() > 32 ) { /* limit on state queue */
|
||||
auto last = sent_states.end();
|
||||
for ( int i = 0; i < 16; i++ ) { last--; }
|
||||
sent_states.erase( last ); /* erase state from middle of queue */
|
||||
}
|
||||
}
|
||||
|
||||
template <class MyState>
|
||||
void TransportSender<MyState>::send_to_receiver( string diff )
|
||||
{
|
||||
uint64_t new_num;
|
||||
if ( current_state == sent_states.back().state ) { /* previously sent */
|
||||
new_num = sent_states.back().num;
|
||||
} else { /* new state */
|
||||
new_num = sent_states.back().num + 1;
|
||||
}
|
||||
|
||||
/* special case for shutdown sequence */
|
||||
if ( shutdown_in_progress ) {
|
||||
new_num = uint64_t( -1 );
|
||||
}
|
||||
|
||||
if ( new_num == sent_states.back().num ) {
|
||||
sent_states.back().timestamp = timestamp();
|
||||
} else {
|
||||
add_sent_state( timestamp(), new_num, current_state );
|
||||
}
|
||||
|
||||
send_in_fragments( diff, new_num ); // Can throw NetworkException
|
||||
|
||||
/* successfully sent, probably */
|
||||
/* ("probably" because the FIRST size-exceeded datagram doesn't get an error) */
|
||||
assumed_receiver_state = sent_states.end();
|
||||
assumed_receiver_state--;
|
||||
next_ack_time = timestamp() + ACK_INTERVAL;
|
||||
next_send_time = uint64_t(-1);
|
||||
}
|
||||
|
||||
template <class MyState>
|
||||
void TransportSender<MyState>::update_assumed_receiver_state( void )
|
||||
{
|
||||
uint64_t now = timestamp();
|
||||
|
||||
/* start from what is known and give benefit of the doubt to unacknowledged states
|
||||
transmitted recently enough ago */
|
||||
assumed_receiver_state = sent_states.begin();
|
||||
|
||||
typename list< TimestampedState<MyState> >::iterator i = sent_states.begin();
|
||||
i++;
|
||||
|
||||
while ( i != sent_states.end() ) {
|
||||
assert( now >= i->timestamp );
|
||||
|
||||
if ( uint64_t(now - i->timestamp) < connection->timeout() + ACK_DELAY ) {
|
||||
assumed_receiver_state = i;
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
template <class MyState>
|
||||
void TransportSender<MyState>::rationalize_states( void )
|
||||
{
|
||||
const MyState * known_receiver_state = &sent_states.front().state;
|
||||
|
||||
current_state.subtract( known_receiver_state );
|
||||
|
||||
for ( typename list< TimestampedState<MyState> >::reverse_iterator i = sent_states.rbegin();
|
||||
i != sent_states.rend();
|
||||
i++ ) {
|
||||
i->state.subtract( known_receiver_state );
|
||||
}
|
||||
}
|
||||
|
||||
template <class MyState>
|
||||
void TransportSender<MyState>::send_in_fragments( string diff, uint64_t new_num )
|
||||
{
|
||||
Instruction inst;
|
||||
|
||||
uint64_t now = timestamp();
|
||||
|
||||
inst.set_protocol_version( MOSH_PROTOCOL_VERSION );
|
||||
inst.set_old_num( assumed_receiver_state->num );
|
||||
inst.set_new_num( new_num );
|
||||
inst.set_ack_num( ack_num );
|
||||
inst.set_throwaway_num( sent_states.front().num );
|
||||
inst.set_late_ack_num( get_late_ack( now ) );
|
||||
inst.set_diff( diff );
|
||||
|
||||
if ( new_num == uint64_t(-1) ) {
|
||||
shutdown_tries++;
|
||||
}
|
||||
|
||||
vector<Fragment> fragments = fragmenter.make_fragments( inst, connection->get_MTU() );
|
||||
|
||||
for ( auto i = fragments.begin(); i != fragments.end(); i++ ) {
|
||||
connection->send( i->tostring() );
|
||||
|
||||
if ( verbose ) {
|
||||
fprintf( stderr, "[%u] Sent [%d=>%d] id %d, frag %d ack=%d, late_ack=%d, throwaway=%d, len=%d, frame rate=%.2f, timeout=%d, srtt=%.1f age=%llu\n",
|
||||
(unsigned int)(timestamp() % 100000), (int)inst.old_num(), (int)inst.new_num(), (int)i->id, (int)i->fragment_num,
|
||||
(int)inst.ack_num(), (int)inst.late_ack_num(), (int)inst.throwaway_num(), (int)i->contents.size(),
|
||||
1000.0 / (double)send_interval(),
|
||||
(int)connection->timeout(), connection->get_SRTT(),
|
||||
(long long)(now - ack_timestamp) );
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
pending_data_ack = false;
|
||||
}
|
||||
|
||||
template <class MyState>
|
||||
void TransportSender<MyState>::process_acknowledgment_through( uint64_t ack_num )
|
||||
{
|
||||
/* Ignore ack if we have culled the state it's acknowledging */
|
||||
|
||||
if ( sent_states.end() != find_if( sent_states.begin(), sent_states.end(),
|
||||
[&]( const TimestampedState<MyState> &x ) { return x.num == ack_num; } ) ) {
|
||||
sent_states.remove_if( [&]( const TimestampedState<MyState> &x ) { return x.num < ack_num; } );
|
||||
}
|
||||
|
||||
assert( !sent_states.empty() );
|
||||
}
|
||||
|
||||
/* give up on getting acknowledgement for shutdown */
|
||||
template <class MyState>
|
||||
bool TransportSender<MyState>::shutdown_ack_timed_out( void ) const
|
||||
{
|
||||
return shutdown_tries >= SHUTDOWN_RETRIES;
|
||||
}
|
||||
|
||||
/* Executed upon entry to new receiver state */
|
||||
template <class MyState>
|
||||
void TransportSender<MyState>::set_ack_num( uint64_t s_ack_num )
|
||||
{
|
||||
ack_num = s_ack_num;
|
||||
ack_timestamp = timestamp();
|
||||
ack_history.push_back( make_pair( ack_num, ack_timestamp ) );
|
||||
}
|
||||
|
||||
/* The "late" ack is for the input state that has had enough time on the host to have been echoed */
|
||||
template <class MyState>
|
||||
uint64_t TransportSender<MyState>::get_late_ack( uint64_t now )
|
||||
{
|
||||
uint64_t newest_echo_ack = 0;
|
||||
|
||||
for ( auto i = ack_history.begin(); i != ack_history.end(); i++ ) {
|
||||
if ( i->second < now - ECHO_TIMEOUT ) {
|
||||
newest_echo_ack = i->first;
|
||||
}
|
||||
}
|
||||
|
||||
ack_history.remove_if( [&]( const pair<uint64_t, uint64_t> &x ) { return x.first < newest_echo_ack; } );
|
||||
|
||||
return newest_echo_ack;
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
|
||||
#ifndef TRANSPORT_SENDER_HPP
|
||||
#define TRANSPORT_SENDER_HPP
|
||||
|
||||
#include <string>
|
||||
#include <list>
|
||||
|
||||
#include "network.h"
|
||||
#include "transportinstruction.pb.h"
|
||||
#include "transportstate.h"
|
||||
#include "transportfragment.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace TransportBuffers;
|
||||
|
||||
namespace Network {
|
||||
template <class MyState>
|
||||
class TransportSender
|
||||
{
|
||||
private:
|
||||
/* timing parameters */
|
||||
static const int SEND_INTERVAL_MIN = 20; /* ms between frames */
|
||||
static const int SEND_INTERVAL_MAX = 250; /* ms between frames */
|
||||
static const int ACK_INTERVAL = 3000; /* ms between empty acks */
|
||||
static const int ACK_DELAY = 100; /* ms before delayed ack */
|
||||
static const int SHUTDOWN_RETRIES = 3; /* number of shutdown packets to send before giving up */
|
||||
static const int ECHO_TIMEOUT = 50; /* for late ack */
|
||||
|
||||
/* helper methods for tick() */
|
||||
void update_assumed_receiver_state( void );
|
||||
void rationalize_states( void );
|
||||
void send_to_receiver( string diff );
|
||||
void send_empty_ack( void );
|
||||
void send_in_fragments( string diff, uint64_t new_num );
|
||||
void add_sent_state( uint64_t the_timestamp, uint64_t num, MyState &state );
|
||||
|
||||
/* state of sender */
|
||||
Connection *connection;
|
||||
|
||||
MyState current_state;
|
||||
|
||||
list< TimestampedState<MyState> > sent_states;
|
||||
/* first element: known, acknowledged receiver state */
|
||||
/* last element: last sent state */
|
||||
|
||||
/* somewhere in the middle: the assumed state of the receiver */
|
||||
typename list< TimestampedState<MyState> >::iterator assumed_receiver_state;
|
||||
|
||||
/* for fragment creation */
|
||||
Fragmenter fragmenter;
|
||||
|
||||
/* timing state */
|
||||
uint64_t next_ack_time;
|
||||
uint64_t next_send_time;
|
||||
|
||||
bool verbose;
|
||||
bool shutdown_in_progress;
|
||||
int shutdown_tries;
|
||||
|
||||
/* information about receiver state */
|
||||
uint64_t ack_num;
|
||||
bool pending_data_ack;
|
||||
uint64_t ack_timestamp;
|
||||
|
||||
list< pair<uint64_t, uint64_t> > ack_history;
|
||||
uint64_t get_late_ack( uint64_t now ); /* calculate delayed "echo" acknowledgment */
|
||||
|
||||
unsigned int SEND_MINDELAY; /* ms to collect all input */
|
||||
|
||||
public:
|
||||
/* constructor */
|
||||
TransportSender( Connection *s_connection, MyState &initial_state );
|
||||
|
||||
/* Send data or an ack if necessary */
|
||||
void tick( void );
|
||||
|
||||
/* Returns the number of ms to wait until next possible event. */
|
||||
int wait_time( void );
|
||||
|
||||
/* Executed upon receipt of ack */
|
||||
void process_acknowledgment_through( uint64_t ack_num );
|
||||
|
||||
/* Executed upon entry to new receiver state */
|
||||
void set_ack_num( uint64_t s_ack_num );
|
||||
|
||||
/* Accelerate reply ack */
|
||||
void set_data_ack( void ) { pending_data_ack = true; }
|
||||
|
||||
/* Starts shutdown sequence */
|
||||
void start_shutdown( void ) { shutdown_in_progress = true; }
|
||||
|
||||
/* Misc. getters and setters */
|
||||
/* Cannot modify current_state while shutdown in progress */
|
||||
MyState &get_current_state( void ) { assert( !shutdown_in_progress ); return current_state; }
|
||||
void set_current_state( const MyState &x ) { assert( !shutdown_in_progress ); current_state = x; }
|
||||
void set_verbose( void ) { verbose = true; }
|
||||
|
||||
bool get_shutdown_in_progress( void ) const { return shutdown_in_progress; }
|
||||
bool get_shutdown_acknowledged( void ) const { return sent_states.front().num == uint64_t(-1); }
|
||||
bool get_counterparty_shutdown_acknowledged( void ) const { return fragmenter.last_ack_sent() == uint64_t(-1); }
|
||||
uint64_t get_sent_state_acked( void ) const { return sent_states.front().num; }
|
||||
uint64_t get_sent_state_last( void ) const { return sent_states.back().num; }
|
||||
|
||||
bool shutdown_ack_timed_out( void ) const;
|
||||
|
||||
void set_send_delay( int new_delay ) { SEND_MINDELAY = new_delay; }
|
||||
|
||||
unsigned int send_interval( void ) const;
|
||||
|
||||
/* nonexistent methods to satisfy -Weffc++ */
|
||||
TransportSender( const TransportSender &x );
|
||||
TransportSender & operator=( const TransportSender &x );
|
||||
};
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,19 @@
|
||||
#ifndef TRANSPORT_STATE_HPP
|
||||
#define TRANSPORT_STATE_HPP
|
||||
|
||||
namespace Network {
|
||||
template <class State>
|
||||
class TimestampedState
|
||||
{
|
||||
public:
|
||||
uint64_t timestamp;
|
||||
uint64_t num;
|
||||
State state;
|
||||
|
||||
TimestampedState( uint64_t s_timestamp, uint64_t s_num, State &s_state )
|
||||
: timestamp( s_timestamp ), num( s_num ), state( s_state )
|
||||
{}
|
||||
};
|
||||
}
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user