From 1b3443befd7b50da1b7ad39b29cf755af1b3f0f7 Mon Sep 17 00:00:00 2001 From: Keith Winstein Date: Fri, 5 Aug 2011 19:44:34 -0400 Subject: [PATCH] Simplify network.cpp to transmit only strings. --- crypto.cpp | 20 +++++++- crypto.hpp | 105 ++++++++++++++++++++++-------------------- decrypt.cpp | 2 + dos_assert.hpp | 2 +- encrypt.cpp | 15 +----- network.cpp | 106 ++++++++++++++++++------------------------- network.hpp | 73 ++++++++++++----------------- networktransport.hpp | 27 +++++++++++ ntester.cpp | 53 ++++++++++++++-------- templates.cpp | 5 +- 10 files changed, 211 insertions(+), 197 deletions(-) create mode 100644 networktransport.hpp diff --git a/crypto.cpp b/crypto.cpp index 6431328..724d199 100644 --- a/crypto.cpp +++ b/crypto.cpp @@ -5,9 +5,25 @@ #include "base64.h" using namespace std; +using namespace Crypto; const char rdev[] = "/dev/urandom"; +long int myatoi( char *str ) +{ + char *end; + + errno = 0; + long int ret = strtol( str, &end, 10 ); + + if ( ( errno != 0 ) + || ( end != str + strlen( str ) ) ) { + throw CryptoException( "Bad integer." ); + } + + return ret; +} + static void * sse_alloc( int len ) { void *ptr = NULL; @@ -178,7 +194,7 @@ Message Session::decrypt( string ciphertext ) int body_len = ciphertext.size() - 8; int pt_len = body_len - 16; - if ( pt_len <= 0 ) { /* super-assertion that does not equal AE_INVALID */ + if ( pt_len < 0 ) { /* super-assertion that pt_len does not equal AE_INVALID */ fprintf( stderr, "BUG.\n" ); exit( 1 ); } @@ -200,7 +216,7 @@ Message Session::decrypt( string ciphertext ) AE_FINALIZE ) ) { /* final */ free( plaintext ); free( body ); - throw CryptoException( "ae_decrypt() returned error." ); + throw CryptoException( "Packet failed integrity check." ); } Message ret( nonce, string( plaintext, pt_len ) ); diff --git a/crypto.hpp b/crypto.hpp index e3be49f..5f2be4a 100644 --- a/crypto.hpp +++ b/crypto.hpp @@ -3,63 +3,68 @@ #include "ae.hpp" #include +#include using namespace std; -class CryptoException { -public: - string text; - CryptoException( string s_text ) : text( s_text ) {}; -}; +long int myatoi( char *str ); -class Base64Key { -private: - unsigned char key[ 16 ]; +namespace Crypto { + class CryptoException { + public: + string text; + CryptoException( string s_text ) : text( s_text ) {}; + }; -public: - Base64Key(); /* random key */ - Base64Key( string printable_key ); - string printable_key( void ); - unsigned char *data( void ) { return key; } -}; + class Base64Key { + private: + unsigned char key[ 16 ]; -class Nonce { -private: - char bytes[ 12 ]; + public: + Base64Key(); /* random key */ + Base64Key( string printable_key ); + string printable_key( void ); + unsigned char *data( void ) { return key; } + }; -public: - Nonce( uint64_t val ); - Nonce( char *s_bytes, size_t len ); + class Nonce { + private: + char bytes[ 12 ]; - string cpp_str( void ) { return string( (char *)( bytes + 4 ), 8 ); } - char *data( void ) { return bytes; } - uint64_t val( void ); -}; - -class Message { -public: - Nonce nonce; - string text; - - Message( char *nonce_bytes, size_t nonce_len, - char *text_bytes, size_t text_len ); - Message( Nonce s_nonce, string s_text ); -}; - -class Session { -private: - Base64Key key; - ae_ctx *ctx; - -public: - Session( Base64Key s_key ); - ~Session(); - - string encrypt( Message plaintext ); - Message decrypt( string ciphertext ); - - Session( const Session & ); - Session & operator=( const Session & ); -}; + public: + Nonce( uint64_t val ); + Nonce( char *s_bytes, size_t len ); + + string cpp_str( void ) { return string( (char *)( bytes + 4 ), 8 ); } + char *data( void ) { return bytes; } + uint64_t val( void ); + }; + + class Message { + public: + Nonce nonce; + string text; + + Message( char *nonce_bytes, size_t nonce_len, + char *text_bytes, size_t text_len ); + Message( Nonce s_nonce, string s_text ); + }; + + class Session { + private: + Base64Key key; + ae_ctx *ctx; + + public: + Session( Base64Key s_key ); + ~Session(); + + string encrypt( Message plaintext ); + Message decrypt( string ciphertext ); + + Session( const Session & ); + Session & operator=( const Session & ); + }; +} #endif diff --git a/decrypt.cpp b/decrypt.cpp index 5f2d13b..9f69ab2 100644 --- a/decrypt.cpp +++ b/decrypt.cpp @@ -8,6 +8,8 @@ #include "crypto.hpp" +using namespace Crypto; + int main( int argc, char *argv[] ) { if ( argc != 2 ) { diff --git a/dos_assert.hpp b/dos_assert.hpp index 24939cc..1afae24 100644 --- a/dos_assert.hpp +++ b/dos_assert.hpp @@ -11,7 +11,7 @@ static void dos_detected( const char *expression, const char *file, int line, co char buffer[ 2048 ]; snprintf( buffer, 2048, "Illegal counterparty input (possible denial of service) in function %s at %s:%d, failed test: %s\n", function, file, line, expression ); - throw CryptoException( buffer ); + throw Crypto::CryptoException( buffer ); } #define dos_assert(expr) \ diff --git a/encrypt.cpp b/encrypt.cpp index b6d6d95..c6b70e8 100644 --- a/encrypt.cpp +++ b/encrypt.cpp @@ -8,20 +8,7 @@ #include "crypto.hpp" -long int myatoi( char *str ) -{ - char *end; - - errno = 0; - long int ret = strtol( str, &end, 10 ); - - if ( ( errno != 0 ) - || ( end != str + strlen( str ) ) ) { - throw CryptoException( "Bad integer." ); - } - - return ret; -} +using namespace Crypto; int main( int argc, char *argv[] ) { diff --git a/network.cpp b/network.cpp index 8e5cdf9..f22dbb8 100644 --- a/network.cpp +++ b/network.cpp @@ -10,50 +10,40 @@ using namespace std; using namespace Network; +using namespace Crypto; -template -Flow::Packet::DecodingCache::DecodingCache( string coded_packet, Session *session ) - : direction( TO_CLIENT ), seq( -1 ), payload_string() +/* Read in packet from coded string */ +Packet::Packet( string coded_packet, Session *session ) + : seq( -1 ), + direction( TO_SERVER ), + payload() { Message message = session->decrypt( coded_packet ); direction = (message.nonce.val() & 8000000000000000) ? TO_CLIENT : TO_SERVER; seq = message.nonce.val() & 0x7FFFFFFFFFFFFFFF; - payload_string = message.text; + payload = message.text; } -template -Flow::Packet::Packet( string coded_packet, Session *session ) - : decoding_cache( coded_packet, session ), - seq( decoding_cache.seq ), - direction( decoding_cache.direction ), - payload( decoding_cache.payload_string ) -{ - decoding_cache = DecodingCache(); -} - -template -string Flow::Packet::tostring( Session *session ) +/* Output coded string from packet */ +string Packet::tostring( Session *session ) { uint64_t direction_seq = (uint64_t( direction == TO_CLIENT ) << 63) | (seq & 0x7FFFFFFFFFFFFFFF); - return session->encrypt( Message( direction_seq, payload.tostring() ) ); + return session->encrypt( Message( direction_seq, payload ) ); } -template -typename Flow::Packet Flow::new_packet( Payload &s_payload ) +Packet Connection::new_packet( string &s_payload ) { return Packet( next_seq++, direction, s_payload ); } -template -void Connection::setup( void ) +void Connection::setup( void ) { /* create socket */ sock = socket( AF_INET, SOCK_DGRAM, 0 ); if ( sock < 0 ) { - perror( "socket" ); - exit( 1 ); + throw NetworkException( "socket", errno ); } /* Bind to free local port. @@ -63,21 +53,18 @@ void Connection::setup( void ) local_addr.sin_port = htons( 0 ); local_addr.sin_addr.s_addr = INADDR_ANY; if ( bind( sock, (sockaddr *)&local_addr, sizeof( local_addr ) ) < 0 ) { - perror( "bind" ); - exit( 1 ); + throw NetworkException( "bind", errno ); } /* Enable path MTU discovery */ char flag = IP_PMTUDISC_DO; socklen_t optlen = sizeof( flag ); if ( setsockopt( sock, IPPROTO_IP, IP_MTU_DISCOVER, &flag, optlen ) < 0 ) { - perror( "setsockopt" ); - exit( 1 ); + throw NetworkException( "setsockopt", errno ); } } -template -Connection::Connection() /* server */ +Connection::Connection() /* server */ : sock( -1 ), remote_addr(), server( true ), @@ -85,13 +72,13 @@ Connection::Connection() /* server */ MTU( RECEIVE_MTU ), key(), session( key ), - flow( TO_CLIENT, &session ) + direction( TO_CLIENT ), + next_seq( 0 ) { setup(); } -template -Connection::Connection( const char *key_str, const char *ip, int port ) /* client */ +Connection::Connection( const char *key_str, const char *ip, int port ) /* client */ : sock( -1 ), remote_addr(), server( false ), @@ -99,7 +86,8 @@ Connection::Connection( const char *key_str, const char *ip, MTU( RECEIVE_MTU ), key( key_str ), session( key ), - flow( TO_SERVER, &session ) + direction( TO_SERVER ), + next_seq( 0 ) { setup(); @@ -107,59 +95,53 @@ Connection::Connection( const char *key_str, const char *ip, remote_addr.sin_family = AF_INET; remote_addr.sin_port = htons( port ); if ( !inet_aton( ip, &remote_addr.sin_addr ) ) { - fprintf( stderr, "Bad IP address %s\n", ip ); - exit( 1 ); + int saved_errno = errno; + char buffer[ 2048 ]; + snprintf( buffer, 2048, "Bad IP address (%s)", ip ); + throw NetworkException( buffer, saved_errno ); } if ( connect( sock, (sockaddr *)&remote_addr, sizeof( remote_addr ) ) < 0 ) { - perror( "connect" ); - exit( 1 ); + throw NetworkException( "connect", errno ); } attached = true; } -template -void Connection::update_MTU( void ) +void Connection::update_MTU( void ) { socklen_t optlen = sizeof( MTU ); if ( getsockopt( sock, IPPROTO_IP, IP_MTU, &MTU, &optlen ) < 0 ) { - perror( "getsockopt" ); - exit( 1 ); + throw NetworkException( "getsockopt", errno ); } if ( optlen != sizeof( MTU ) ) { - fprintf( stderr, "Error getting path MTU.\n" ); - exit( 1 ); + throw NetworkException( "Error getting path MTU", errno ); } fprintf( stderr, "Path MTU: %d\n", MTU ); } -template -bool Connection::send( Outgoing &s ) +void Connection::send( string &s ) { assert( attached ); - string p = flow.new_packet( s ).tostring( &session ); + string p = new_packet( s ).tostring( &session ); ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0, (sockaddr *)&remote_addr, sizeof( remote_addr ) ); if ( (bytes_sent < 0) && (errno == EMSGSIZE) ) { update_MTU(); - return false; + throw MTUException( MTU ); } else if ( bytes_sent == static_cast( p.size() ) ) { - return true; + return; } else { - perror( "sendto" ); - exit( 1 ); - return false; + throw NetworkException( "sendto", errno ); } } -template -Incoming Connection::recv( void ) +string Connection::recv( void ) { struct sockaddr_in packet_remote_addr; @@ -170,17 +152,17 @@ Incoming Connection::recv( void ) ssize_t received_len = recvfrom( sock, buf, RECEIVE_MTU, 0, (sockaddr *)&packet_remote_addr, &addrlen ); if ( received_len < 0 ) { - perror( "recvfrom" ); - exit( 1 ); + throw NetworkException( "recvfrom", errno ); } if ( received_len > RECEIVE_MTU ) { - fprintf( stderr, "Received oversize datagram (size %d) and limit is %d.\n", - static_cast( received_len ), RECEIVE_MTU ); - exit( 1 ); + char buffer[ 2048 ]; + snprintf( buffer, 2048, "Received oversize datagram (size %d) and limit is %d\n", + static_cast( received_len ), RECEIVE_MTU ); + throw NetworkException( buffer, errno ); } - typename Flow::Packet p( string( buf, received_len ), &session ); + Packet p( string( buf, received_len ), &session ); dos_assert( p.direction == (server ? TO_SERVER : TO_CLIENT) ); /* prevent malicious playback to sender */ /* server auto-adjusts to client */ @@ -199,15 +181,13 @@ Incoming Connection::recv( void ) return p.payload; } -template -int Connection::port( void ) +int Connection::port( void ) { struct sockaddr_in local_addr; socklen_t addrlen = sizeof( local_addr ); if ( getsockname( sock, (sockaddr *)&local_addr, &addrlen ) < 0 ) { - perror( "getsockname" ); - exit( 1 ); + throw NetworkException( "getsockname", errno ); } return ntohs( local_addr.sin_port ); diff --git a/network.hpp b/network.hpp index dc1af85..a430615 100644 --- a/network.hpp +++ b/network.hpp @@ -10,57 +10,42 @@ #include "crypto.hpp" using namespace std; +using namespace Crypto; namespace Network { + class MTUException { + public: + int MTU; + MTUException( int s_MTU ) : MTU( s_MTU ) {}; + }; + + class NetworkException { + public: + string function; + int the_errno; + NetworkException( string s_function, int s_errno ) : function( s_function ), the_errno( s_errno ) {} + }; + enum Direction { TO_SERVER = 0, TO_CLIENT = 1 }; - template - class Flow { + class Packet { public: - class Packet { - private: - class DecodingCache - { - public: - Direction direction; - uint64_t seq; - string payload_string; - - DecodingCache( string coded_packet, Session *session ); - DecodingCache() : direction( TO_CLIENT ), seq( -1 ), payload_string() {} - }; - - DecodingCache decoding_cache; - - public: - uint64_t seq; - Direction direction; - Payload payload; - - Packet( uint64_t s_seq, Direction s_direction, Payload s_payload ) - : decoding_cache(), seq( s_seq ), direction( s_direction ), payload( s_payload ) - {} - - Packet( string coded_packet, Session *session ); - - string tostring( Session *session ); - }; - - uint64_t next_seq; + uint64_t seq; Direction direction; - Session *session; - - Flow( Direction s_direction, Session *s_session ) - : next_seq( 0 ), direction( s_direction ), session( s_session ) + string payload; + + Packet( uint64_t s_seq, Direction s_direction, string s_payload ) + : seq( s_seq ), direction( s_direction ), payload( s_payload ) {} - - Packet new_packet( Payload &s_payload ); + + Packet( string coded_packet, Session *session ); + + string tostring( Session *session ); }; - template class Connection { private: static const int RECEIVE_MTU = 2048; @@ -76,18 +61,20 @@ namespace Network { Base64Key key; Session session; - Flow flow; - void update_MTU( void ); void setup( void ); + Direction direction; + uint64_t next_seq; + Packet new_packet( string &s_payload ); + public: Connection(); Connection( const char *key_str, const char *ip, int port ); - bool send( Outgoing &s ); - Incoming recv( void ); + void send( string &s ); + string recv( void ); int fd( void ) { return sock; } int port( void ); int get_MTU( void ) { return MTU; } diff --git a/networktransport.hpp b/networktransport.hpp new file mode 100644 index 0000000..f840dd5 --- /dev/null +++ b/networktransport.hpp @@ -0,0 +1,27 @@ +#ifndef NETWORK_TRANSPORT_HPP +#define NETWORK_TRANSPORT_HPP + +#include + +using google::dense_hash_map; + +namespace Network { + template + class Transport + { + private: + Connection connection; + + uint64_t last_acknowledged_state; + uint64_t assumed_receiver_state; + uint64_t last_sent_state; + + public: + Transport(); + Transport( const char *key_str, const char *ip, int port ); + + + }; +}; + +#endif diff --git a/ntester.cpp b/ntester.cpp index 31c0aba..28f23b1 100644 --- a/ntester.cpp +++ b/ntester.cpp @@ -11,7 +11,7 @@ int main( int argc, char *argv[] ) char *ip; int port; - Network::Connection *n; + Network::Connection *n; try { if ( argc > 1 ) { @@ -22,9 +22,9 @@ int main( int argc, char *argv[] ) ip = argv[ 2 ]; port = atoi( argv[ 3 ] ); - n = new Network::Connection( key, ip, port ); + n = new Network::Connection( key, ip, port ); } else { - n = new Network::Connection(); + n = new Network::Connection(); } } catch ( CryptoException e ) { fprintf( stderr, "Fatal error: %s\n", e.text.c_str() ); @@ -36,34 +36,47 @@ int main( int argc, char *argv[] ) if ( server ) { while ( true ) { try { - KeyStroke s = n->recv(); - printf( "%c", s.letter ); + string s = n->recv(); + printf( "%s", s.c_str() ); fflush( NULL ); } catch ( CryptoException e ) { - fprintf( stderr, "Error: %s\n", e.text.c_str() ); + fprintf( stderr, "Cryptographic error: %s\n", e.text.c_str() ); } } } else { - struct termios the_termios; + struct termios saved_termios; + struct termios the_termios; - if ( tcgetattr( STDIN_FILENO, &the_termios ) < 0 ) { - perror( "tcgetattr" ); - exit( 1 ); - } + if ( tcgetattr( STDIN_FILENO, &the_termios ) < 0 ) { + perror( "tcgetattr" ); + exit( 1 ); + } - cfmakeraw( &the_termios ); + saved_termios = the_termios; - if ( tcsetattr( STDIN_FILENO, TCSANOW, &the_termios ) < 0 ) { - perror( "tcsetattr" ); - exit( 1 ); - } + cfmakeraw( &the_termios ); + + if ( tcsetattr( STDIN_FILENO, TCSANOW, &the_termios ) < 0 ) { + perror( "tcsetattr" ); + exit( 1 ); + } while( true ) { char x = getchar(); - - KeyStroke t( string( &x, 1 ) ); - - n->send( t ); + + string prefix = "Key(" + string( &x, 1 ) + ")"; + + try { + n->send( prefix ); + } catch ( Network::NetworkException e ) { + fprintf( stderr, "%s: %s\r\n", e.function.c_str(), strerror( e.the_errno ) ); + break; + } } + + if ( tcsetattr( STDIN_FILENO, TCSANOW, &saved_termios ) < 0 ) { + perror( "tcsetattr" ); + exit( 1 ); + } } } diff --git a/templates.cpp b/templates.cpp index 441ec7c..ff66ccd 100644 --- a/templates.cpp +++ b/templates.cpp @@ -6,9 +6,6 @@ #include "terminal.hpp" -#include "network.cpp" -#include "keystroke.hpp" - namespace Parser { class Action; } @@ -24,4 +21,4 @@ template class vector; template class vector; template class map; template class vector; -template class Network::Connection; +