Separate modules by subdirectory
This commit is contained in:
@@ -0,0 +1,359 @@
|
||||
#include <sys/types.h>
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <assert.h>
|
||||
#include <endian.h>
|
||||
|
||||
#include "dos_assert.h"
|
||||
#include "network.h"
|
||||
#include "crypto.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace Network;
|
||||
using namespace Crypto;
|
||||
|
||||
const uint64_t DIRECTION_MASK = uint64_t(1) << 63;
|
||||
const uint64_t SEQUENCE_MASK = uint64_t(-1) ^ DIRECTION_MASK;
|
||||
|
||||
/* Read in packet from coded string */
|
||||
Packet::Packet( string coded_packet, Session *session )
|
||||
: seq( -1 ),
|
||||
direction( TO_SERVER ),
|
||||
timestamp( -1 ),
|
||||
timestamp_reply( -1 ),
|
||||
payload()
|
||||
{
|
||||
Message message = session->decrypt( coded_packet );
|
||||
|
||||
direction = (message.nonce.val() & DIRECTION_MASK) ? TO_CLIENT : TO_SERVER;
|
||||
seq = message.nonce.val() & SEQUENCE_MASK;
|
||||
|
||||
dos_assert( message.text.size() >= 2 * sizeof( uint16_t ) );
|
||||
|
||||
uint16_t *data = (uint16_t *)message.text.data();
|
||||
timestamp = be16toh( data[ 0 ] );
|
||||
timestamp_reply = be16toh( data[ 1 ] );
|
||||
|
||||
payload = string( message.text.begin() + 2 * sizeof( uint16_t ), message.text.end() );
|
||||
}
|
||||
|
||||
/* Output coded string from packet */
|
||||
string Packet::tostring( Session *session )
|
||||
{
|
||||
uint64_t direction_seq = (uint64_t( direction == TO_CLIENT ) << 63) | (seq & SEQUENCE_MASK);
|
||||
|
||||
uint16_t ts_net[ 2 ] = { htobe16( timestamp ), htobe16( timestamp_reply ) };
|
||||
|
||||
string timestamps = string( (char *)ts_net, 2 * sizeof( uint16_t ) );
|
||||
|
||||
return session->encrypt( Message( Nonce( direction_seq ), timestamps + payload ) );
|
||||
}
|
||||
|
||||
Packet Connection::new_packet( string &s_payload )
|
||||
{
|
||||
uint16_t outgoing_timestamp_reply = -1;
|
||||
|
||||
uint64_t now = timestamp();
|
||||
|
||||
if ( now - saved_timestamp_received_at < 1000 ) { /* we have a recent received timestamp */
|
||||
/* send "corrected" timestamp advanced by how long we held it */
|
||||
outgoing_timestamp_reply = saved_timestamp + (now - saved_timestamp_received_at);
|
||||
saved_timestamp = -1;
|
||||
saved_timestamp_received_at = 0;
|
||||
}
|
||||
|
||||
Packet p( next_seq++, direction, timestamp16(), outgoing_timestamp_reply, s_payload );
|
||||
|
||||
return p;
|
||||
}
|
||||
|
||||
void Connection::setup( void )
|
||||
{
|
||||
/* create socket */
|
||||
sock = socket( AF_INET, SOCK_DGRAM, 0 );
|
||||
if ( sock < 0 ) {
|
||||
throw NetworkException( "socket", errno );
|
||||
}
|
||||
|
||||
/* Enable path MTU discovery */
|
||||
char flag = IP_PMTUDISC_DO;
|
||||
socklen_t optlen = sizeof( flag );
|
||||
if ( setsockopt( sock, IPPROTO_IP, IP_MTU_DISCOVER, &flag, optlen ) < 0 ) {
|
||||
throw NetworkException( "setsockopt", errno );
|
||||
}
|
||||
}
|
||||
|
||||
Connection::Connection( const char *desired_ip ) /* server */
|
||||
: sock( -1 ),
|
||||
remote_addr(),
|
||||
server( true ),
|
||||
attached( false ),
|
||||
MTU( SEND_MTU ),
|
||||
key(),
|
||||
session( key ),
|
||||
direction( TO_CLIENT ),
|
||||
next_seq( 0 ),
|
||||
saved_timestamp( -1 ),
|
||||
saved_timestamp_received_at( 0 ),
|
||||
expected_receiver_seq( 0 ),
|
||||
RTT_hit( false ),
|
||||
SRTT( 1000 ),
|
||||
RTTVAR( 500 )
|
||||
{
|
||||
setup();
|
||||
|
||||
/* Attempt to bind free local port, with
|
||||
address client used to connect to us.
|
||||
|
||||
This usage does not seem to be endorsed by POSIX. */
|
||||
|
||||
struct sockaddr_in local_addr;
|
||||
local_addr.sin_family = AF_INET;
|
||||
local_addr.sin_port = htons( 0 );
|
||||
if ( desired_ip
|
||||
&& inet_aton( desired_ip, &local_addr.sin_addr )
|
||||
&& (bind( sock, (sockaddr *)&local_addr, sizeof( local_addr ) ) == 0) ) {
|
||||
return;
|
||||
}
|
||||
|
||||
if ( desired_ip ) {
|
||||
fprintf( stderr, "Could not bind to desired local address %s.\n", desired_ip );
|
||||
}
|
||||
|
||||
/* Could not bind to that IP (maybe we are behind NAT).
|
||||
Try again with any IP. */
|
||||
local_addr.sin_addr.s_addr = INADDR_ANY;
|
||||
if ( bind( sock, (sockaddr *)&local_addr, sizeof( local_addr ) ) < 0 ) {
|
||||
throw NetworkException( "bind", errno );
|
||||
}
|
||||
}
|
||||
|
||||
Connection::Connection( const char *key_str, const char *ip, int port ) /* client */
|
||||
: sock( -1 ),
|
||||
remote_addr(),
|
||||
server( false ),
|
||||
attached( false ),
|
||||
MTU( SEND_MTU ),
|
||||
key( key_str ),
|
||||
session( key ),
|
||||
direction( TO_SERVER ),
|
||||
next_seq( 0 ),
|
||||
saved_timestamp( -1 ),
|
||||
saved_timestamp_received_at( 0 ),
|
||||
expected_receiver_seq( 0 ),
|
||||
RTT_hit( false ),
|
||||
SRTT( 1000 ),
|
||||
RTTVAR( 500 )
|
||||
{
|
||||
setup();
|
||||
|
||||
/* associate socket with remote host and port */
|
||||
remote_addr.sin_family = AF_INET;
|
||||
remote_addr.sin_port = htons( port );
|
||||
if ( !inet_aton( ip, &remote_addr.sin_addr ) ) {
|
||||
int saved_errno = errno;
|
||||
char buffer[ 2048 ];
|
||||
snprintf( buffer, 2048, "Bad IP address (%s)", ip );
|
||||
throw NetworkException( buffer, saved_errno );
|
||||
}
|
||||
|
||||
attached = true;
|
||||
}
|
||||
|
||||
void Connection::send( string s )
|
||||
{
|
||||
assert( attached );
|
||||
|
||||
Packet px = new_packet( s );
|
||||
|
||||
string p = px.tostring( &session );
|
||||
|
||||
ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0,
|
||||
(sockaddr *)&remote_addr, sizeof( remote_addr ) );
|
||||
|
||||
if ( (bytes_sent < 0) && (errno == EMSGSIZE) ) {
|
||||
update_MTU();
|
||||
throw NetworkException( "Path MTU Discovery", EMSGSIZE );
|
||||
} else if ( bytes_sent == static_cast<int>( p.size() ) ) {
|
||||
return;
|
||||
} else {
|
||||
throw NetworkException( "sendto", errno );
|
||||
}
|
||||
}
|
||||
|
||||
string Connection::recv( void )
|
||||
{
|
||||
struct sockaddr_in packet_remote_addr;
|
||||
|
||||
char buf[ RECEIVE_MTU ];
|
||||
|
||||
socklen_t addrlen = sizeof( packet_remote_addr );
|
||||
|
||||
ssize_t received_len = recvfrom( sock, buf, RECEIVE_MTU, 0, (sockaddr *)&packet_remote_addr, &addrlen );
|
||||
|
||||
if ( received_len < 0 ) {
|
||||
throw NetworkException( "recvfrom", errno );
|
||||
}
|
||||
|
||||
if ( received_len > RECEIVE_MTU ) {
|
||||
char buffer[ 2048 ];
|
||||
snprintf( buffer, 2048, "Received oversize datagram (size %d) and limit is %d\n",
|
||||
static_cast<int>( received_len ), RECEIVE_MTU );
|
||||
throw NetworkException( buffer, errno );
|
||||
}
|
||||
|
||||
Packet p( string( buf, received_len ), &session );
|
||||
|
||||
dos_assert( p.direction == (server ? TO_SERVER : TO_CLIENT) ); /* prevent malicious playback to sender */
|
||||
|
||||
if ( p.seq >= expected_receiver_seq ) { /* don't use out-of-order packets for timestamp or targeting */
|
||||
expected_receiver_seq = p.seq + 1; /* this is security-sensitive because a replay attack could otherwise
|
||||
screw up the timestamp and targeting */
|
||||
|
||||
if ( p.timestamp != uint16_t(-1) ) {
|
||||
saved_timestamp = p.timestamp;
|
||||
saved_timestamp_received_at = timestamp();
|
||||
}
|
||||
|
||||
if ( p.timestamp_reply != uint16_t(-1) ) {
|
||||
uint16_t now = timestamp16();
|
||||
double R = timestamp_diff( now, p.timestamp_reply );
|
||||
|
||||
if ( R < 5000 ) { /* ignore large values, e.g. server was Ctrl-Zed */
|
||||
if ( !RTT_hit ) { /* first measurement */
|
||||
SRTT = R;
|
||||
RTTVAR = R / 2;
|
||||
RTT_hit = true;
|
||||
} else {
|
||||
const double alpha = 1.0 / 8.0;
|
||||
const double beta = 1.0 / 4.0;
|
||||
|
||||
RTTVAR = (1 - beta) * RTTVAR + ( beta * fabs( SRTT - R ) );
|
||||
SRTT = (1 - alpha) * SRTT + ( alpha * R );
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* auto-adjust to remote host */
|
||||
attached = true;
|
||||
|
||||
if ( (remote_addr.sin_addr.s_addr != packet_remote_addr.sin_addr.s_addr)
|
||||
|| (remote_addr.sin_port != packet_remote_addr.sin_port) ) {
|
||||
remote_addr = packet_remote_addr;
|
||||
if ( server ) {
|
||||
fprintf( stderr, "Server now attached to client at %s:%d\n",
|
||||
inet_ntoa( remote_addr.sin_addr ),
|
||||
ntohs( remote_addr.sin_port ) );
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return p.payload; /* we do return out-of-order or duplicated packets to caller */
|
||||
}
|
||||
|
||||
int Connection::port( void ) const
|
||||
{
|
||||
struct sockaddr_in local_addr;
|
||||
socklen_t addrlen = sizeof( local_addr );
|
||||
|
||||
if ( getsockname( sock, (sockaddr *)&local_addr, &addrlen ) < 0 ) {
|
||||
throw NetworkException( "getsockname", errno );
|
||||
}
|
||||
|
||||
return ntohs( local_addr.sin_port );
|
||||
}
|
||||
|
||||
uint64_t Network::timestamp( void )
|
||||
{
|
||||
struct timespec tp;
|
||||
|
||||
if ( clock_gettime( CLOCK_MONOTONIC, &tp ) < 0 ) {
|
||||
throw NetworkException( "clock_gettime", errno );
|
||||
}
|
||||
|
||||
uint64_t millis = tp.tv_nsec / 1000000;
|
||||
millis += uint64_t( tp.tv_sec ) * 1000;
|
||||
|
||||
return millis;
|
||||
}
|
||||
|
||||
uint16_t Network::timestamp16( void )
|
||||
{
|
||||
uint16_t ts = timestamp() % 65536;
|
||||
if ( ts == uint16_t(-1) ) {
|
||||
ts++;
|
||||
}
|
||||
return ts;
|
||||
}
|
||||
|
||||
uint16_t Network::timestamp_diff( uint16_t tsnew, uint16_t tsold )
|
||||
{
|
||||
int diff = tsnew - tsold;
|
||||
if ( diff < 0 ) {
|
||||
diff += 65536;
|
||||
}
|
||||
|
||||
assert( diff >= 0 );
|
||||
assert( diff <= 65535 );
|
||||
|
||||
return diff;
|
||||
}
|
||||
|
||||
uint64_t Connection::timeout( void ) const
|
||||
{
|
||||
uint64_t RTO = lrint( ceil( SRTT + 4 * RTTVAR ) );
|
||||
if ( RTO < MIN_RTO ) {
|
||||
RTO = MIN_RTO;
|
||||
} else if ( RTO > MAX_RTO ) {
|
||||
RTO = MAX_RTO;
|
||||
}
|
||||
return RTO;
|
||||
}
|
||||
|
||||
class Socket {
|
||||
public:
|
||||
int fd;
|
||||
|
||||
Socket( int domain, int type, int protocol )
|
||||
: fd( socket( domain, type, protocol ) )
|
||||
{
|
||||
if ( fd < 0 ) {
|
||||
throw NetworkException( "socket", errno );
|
||||
}
|
||||
}
|
||||
|
||||
~Socket()
|
||||
{
|
||||
if ( close( fd ) < 0 ) {
|
||||
throw NetworkException( "close", errno );
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void Connection::update_MTU( void )
|
||||
{
|
||||
if ( !attached ) {
|
||||
return;
|
||||
}
|
||||
|
||||
/* We don't want to use our main socket because we don't want to have to connect it */
|
||||
Socket path_MTU_socket( AF_INET, SOCK_DGRAM, 0 );
|
||||
|
||||
/* Connect socket so we can retrieve path MTU */
|
||||
if ( connect( path_MTU_socket.fd, (sockaddr *)&remote_addr, sizeof( remote_addr ) ) < 0 ) {
|
||||
throw NetworkException( "connect", errno );
|
||||
}
|
||||
|
||||
int PMTU;
|
||||
socklen_t optlen = sizeof( PMTU );
|
||||
if ( getsockopt( path_MTU_socket.fd, IPPROTO_IP, IP_MTU, &PMTU, &optlen ) < 0 ) {
|
||||
throw NetworkException( "getsockopt", errno );
|
||||
}
|
||||
|
||||
if ( optlen != sizeof( PMTU ) ) {
|
||||
throw NetworkException( "Error getting path MTU", errno );
|
||||
}
|
||||
|
||||
MTU = min( PMTU, int(SEND_MTU) ); /* need cast to compile without optimization! XXX */
|
||||
}
|
||||
Reference in New Issue
Block a user