From 7824318c5483fcec7acefbdc225d312020731aff Mon Sep 17 00:00:00 2001 From: Keith Winstein Date: Thu, 4 Aug 2011 04:52:47 -0400 Subject: [PATCH] Add crypto to existing network class --- dos_assert.hpp | 9 +++-- network.cpp | 105 ++++++++++++++++++++++++++----------------------- network.hpp | 30 +++++++++----- ntester.cpp | 24 ++++++----- 4 files changed, 96 insertions(+), 72 deletions(-) diff --git a/dos_assert.hpp b/dos_assert.hpp index f28ecf9..24939cc 100644 --- a/dos_assert.hpp +++ b/dos_assert.hpp @@ -4,11 +4,14 @@ #include #include +#include "crypto.hpp" + static void dos_detected( const char *expression, const char *file, int line, const char *function ) { - fprintf( stderr, "Illegal counterparty input (possible denial of service) in function %s at %s:%d, failed test: %s\n", - function, file, line, expression ); - exit( 1 ); + char buffer[ 2048 ]; + snprintf( buffer, 2048, "Illegal counterparty input (possible denial of service) in function %s at %s:%d, failed test: %s\n", + function, file, line, expression ); + throw CryptoException( buffer ); } #define dos_assert(expr) \ diff --git a/network.cpp b/network.cpp index 60c928b..8e5cdf9 100644 --- a/network.cpp +++ b/network.cpp @@ -6,30 +6,25 @@ #include "dos_assert.hpp" #include "network.hpp" +#include "crypto.hpp" using namespace std; using namespace Network; template -Flow::Packet::DecodingCache::DecodingCache( string coded_packet ) +Flow::Packet::DecodingCache::DecodingCache( string coded_packet, Session *session ) : direction( TO_CLIENT ), seq( -1 ), payload_string() { - dos_assert( coded_packet.size() >= 8 ); + Message message = session->decrypt( coded_packet ); - /* Read in sequence number and direction */ - string seq_string( coded_packet.begin(), coded_packet.begin() + 8 ); - uint64_t *network_order_seq = (uint64_t *)seq_string.data(); - uint64_t direction_seq = be64toh( *network_order_seq ); - direction = (direction_seq & 8000000000000000) ? TO_CLIENT : TO_SERVER; - seq = direction_seq & 0x7FFFFFFFFFFFFFFF; - - /* Read in payload */ - payload_string = string( coded_packet.begin() + 8, coded_packet.end() ); + direction = (message.nonce.val() & 8000000000000000) ? TO_CLIENT : TO_SERVER; + seq = message.nonce.val() & 0x7FFFFFFFFFFFFFFF; + payload_string = message.text; } template -Flow::Packet::Packet( string coded_packet ) - : decoding_cache( coded_packet ), +Flow::Packet::Packet( string coded_packet, Session *session ) + : decoding_cache( coded_packet, session ), seq( decoding_cache.seq ), direction( decoding_cache.direction ), payload( decoding_cache.payload_string ) @@ -38,14 +33,11 @@ Flow::Packet::Packet( string coded_packet ) } template -string Flow::Packet::tostring( void ) +string Flow::Packet::tostring( Session *session ) { uint64_t direction_seq = (uint64_t( direction == TO_CLIENT ) << 63) | (seq & 0x7FFFFFFFFFFFFFFF); - uint64_t network_order_seq = htobe64( direction_seq ); - const char *seq_str = (const char *)&network_order_seq; - string seq_string( seq_str, 8 ); /* necessary in case there is a zero byte */ - return seq_string + payload.tostring(); + return session->encrypt( Message( direction_seq, payload.tostring() ) ); } template @@ -55,15 +47,8 @@ typename Flow::Packet Flow::new_packet( Payload &s_payload ) } template -Connection::Connection( bool s_server ) - : flow( s_server ? TO_CLIENT : TO_SERVER ), - sock( -1 ), - remote_addr(), - server( s_server ), - attached( false ), - MTU( RECEIVE_MTU ) +void Connection::setup( void ) { - /* create socket */ sock = socket( AF_INET, SOCK_DGRAM, 0 ); if ( sock < 0 ) { @@ -88,7 +73,50 @@ Connection::Connection( bool s_server ) if ( setsockopt( sock, IPPROTO_IP, IP_MTU_DISCOVER, &flag, optlen ) < 0 ) { perror( "setsockopt" ); exit( 1 ); + } +} + +template +Connection::Connection() /* server */ + : sock( -1 ), + remote_addr(), + server( true ), + attached( false ), + MTU( RECEIVE_MTU ), + key(), + session( key ), + flow( TO_CLIENT, &session ) +{ + setup(); +} + +template +Connection::Connection( const char *key_str, const char *ip, int port ) /* client */ + : sock( -1 ), + remote_addr(), + server( false ), + attached( false ), + MTU( RECEIVE_MTU ), + key( key_str ), + session( key ), + flow( TO_SERVER, &session ) +{ + setup(); + + /* associate socket with remote host and port */ + remote_addr.sin_family = AF_INET; + remote_addr.sin_port = htons( port ); + if ( !inet_aton( ip, &remote_addr.sin_addr ) ) { + fprintf( stderr, "Bad IP address %s\n", ip ); + exit( 1 ); } + + if ( connect( sock, (sockaddr *)&remote_addr, sizeof( remote_addr ) ) < 0 ) { + perror( "connect" ); + exit( 1 ); + } + + attached = true; } template @@ -108,33 +136,12 @@ void Connection::update_MTU( void ) fprintf( stderr, "Path MTU: %d\n", MTU ); } -template -void Connection::client_connect( const char *ip, int port ) -{ - assert( !server ); - - /* associate socket with remote host and port */ - remote_addr.sin_family = AF_INET; - remote_addr.sin_port = htons( port ); - if ( !inet_aton( ip, &remote_addr.sin_addr ) ) { - fprintf( stderr, "Bad IP address %s\n", ip ); - exit( 1 ); - } - - if ( connect( sock, (sockaddr *)&remote_addr, sizeof( remote_addr ) ) < 0 ) { - perror( "connect" ); - exit( 1 ); - } - - attached = true; -} - template bool Connection::send( Outgoing &s ) { assert( attached ); - string p = flow.new_packet( s ).tostring(); + string p = flow.new_packet( s ).tostring( &session ); ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0, (sockaddr *)&remote_addr, sizeof( remote_addr ) ); @@ -173,7 +180,7 @@ Incoming Connection::recv( void ) exit( 1 ); } - typename Flow::Packet p( string( buf, received_len ) ); + typename Flow::Packet p( string( buf, received_len ), &session ); dos_assert( p.direction == (server ? TO_SERVER : TO_CLIENT) ); /* prevent malicious playback to sender */ /* server auto-adjusts to client */ diff --git a/network.hpp b/network.hpp index 3d1e2b7..dc1af85 100644 --- a/network.hpp +++ b/network.hpp @@ -7,6 +7,8 @@ #include #include +#include "crypto.hpp" + using namespace std; namespace Network { @@ -27,7 +29,7 @@ namespace Network { uint64_t seq; string payload_string; - DecodingCache( string coded_packet ); + DecodingCache( string coded_packet, Session *session ); DecodingCache() : direction( TO_CLIENT ), seq( -1 ), payload_string() {} }; @@ -42,16 +44,17 @@ namespace Network { : decoding_cache(), seq( s_seq ), direction( s_direction ), payload( s_payload ) {} - Packet( string coded_packet ); + Packet( string coded_packet, Session *session ); - string tostring( void ); + string tostring( Session *session ); }; uint64_t next_seq; - const Direction direction; + Direction direction; + Session *session; - Flow( Direction s_direction ) - : next_seq( 0 ), direction( s_direction ) + Flow( Direction s_direction, Session *s_session ) + : next_seq( 0 ), direction( s_direction ), session( s_session ) {} Packet new_packet( Payload &s_payload ); @@ -62,8 +65,6 @@ namespace Network { private: static const int RECEIVE_MTU = 2048; - Flow flow; - int sock; struct sockaddr_in remote_addr; @@ -72,16 +73,25 @@ namespace Network { int MTU; + Base64Key key; + Session session; + + Flow flow; + void update_MTU( void ); + void setup( void ); + public: - Connection( bool s_server ); + Connection(); + Connection( const char *key_str, const char *ip, int port ); - void client_connect( const char *ip, int port ); bool send( Outgoing &s ); Incoming recv( void ); int fd( void ) { return sock; } int port( void ); + int get_MTU( void ) { return MTU; } + string get_key( void ) { return key.printable_key(); } }; } diff --git a/ntester.cpp b/ntester.cpp index 9ea7fd9..e14551f 100644 --- a/ntester.cpp +++ b/ntester.cpp @@ -4,26 +4,30 @@ int main( int argc, char *argv[] ) { bool server = true; + char *key; char *ip; int port; + Network::Connection *n; + if ( argc > 1 ) { server = false; + /* client */ - ip = argv[ 1 ]; - port = atoi( argv[ 2 ] ); + key = argv[ 1 ]; + ip = argv[ 2 ]; + port = atoi( argv[ 3 ] ); + + n = new Network::Connection( key, ip, port ); + } else { + n = new Network::Connection(); } - Network::Connection n( server ); - fprintf( stderr, "Port bound is %d\n", n.port() ); - - if ( !server ) { - n.client_connect( ip, port ); - } + fprintf( stderr, "Port bound is %d, key is %s\n", n->port(), n->get_key().c_str() ); if ( server ) { while ( true ) { - KeyStroke s = n.recv(); + KeyStroke s = n->recv(); fprintf( stderr, "Got KeyStroke: %c\n", s.letter ); } @@ -33,7 +37,7 @@ int main( int argc, char *argv[] ) KeyStroke t( string( "x", 1 ) ); - n.send( t ); + n->send( t ); } } }