Simplify network.cpp to transmit only strings.

This commit is contained in:
Keith Winstein
2011-08-05 19:44:34 -04:00
parent bffc099754
commit 1b3443befd
10 changed files with 211 additions and 197 deletions
+18 -2
View File
@@ -5,9 +5,25 @@
#include "base64.h" #include "base64.h"
using namespace std; using namespace std;
using namespace Crypto;
const char rdev[] = "/dev/urandom"; const char rdev[] = "/dev/urandom";
long int myatoi( char *str )
{
char *end;
errno = 0;
long int ret = strtol( str, &end, 10 );
if ( ( errno != 0 )
|| ( end != str + strlen( str ) ) ) {
throw CryptoException( "Bad integer." );
}
return ret;
}
static void * sse_alloc( int len ) static void * sse_alloc( int len )
{ {
void *ptr = NULL; void *ptr = NULL;
@@ -178,7 +194,7 @@ Message Session::decrypt( string ciphertext )
int body_len = ciphertext.size() - 8; int body_len = ciphertext.size() - 8;
int pt_len = body_len - 16; int pt_len = body_len - 16;
if ( pt_len <= 0 ) { /* super-assertion that does not equal AE_INVALID */ if ( pt_len < 0 ) { /* super-assertion that pt_len does not equal AE_INVALID */
fprintf( stderr, "BUG.\n" ); fprintf( stderr, "BUG.\n" );
exit( 1 ); exit( 1 );
} }
@@ -200,7 +216,7 @@ Message Session::decrypt( string ciphertext )
AE_FINALIZE ) ) { /* final */ AE_FINALIZE ) ) { /* final */
free( plaintext ); free( plaintext );
free( body ); free( body );
throw CryptoException( "ae_decrypt() returned error." ); throw CryptoException( "Packet failed integrity check." );
} }
Message ret( nonce, string( plaintext, pt_len ) ); Message ret( nonce, string( plaintext, pt_len ) );
+55 -50
View File
@@ -3,63 +3,68 @@
#include "ae.hpp" #include "ae.hpp"
#include <string> #include <string>
#include <string.h>
using namespace std; using namespace std;
class CryptoException { long int myatoi( char *str );
public:
string text;
CryptoException( string s_text ) : text( s_text ) {};
};
class Base64Key { namespace Crypto {
private: class CryptoException {
unsigned char key[ 16 ]; public:
string text;
CryptoException( string s_text ) : text( s_text ) {};
};
public: class Base64Key {
Base64Key(); /* random key */ private:
Base64Key( string printable_key ); unsigned char key[ 16 ];
string printable_key( void );
unsigned char *data( void ) { return key; }
};
class Nonce { public:
private: Base64Key(); /* random key */
char bytes[ 12 ]; Base64Key( string printable_key );
string printable_key( void );
unsigned char *data( void ) { return key; }
};
public: class Nonce {
Nonce( uint64_t val ); private:
Nonce( char *s_bytes, size_t len ); char bytes[ 12 ];
string cpp_str( void ) { return string( (char *)( bytes + 4 ), 8 ); } public:
char *data( void ) { return bytes; } Nonce( uint64_t val );
uint64_t val( void ); Nonce( char *s_bytes, size_t len );
};
string cpp_str( void ) { return string( (char *)( bytes + 4 ), 8 ); }
class Message { char *data( void ) { return bytes; }
public: uint64_t val( void );
Nonce nonce; };
string text;
class Message {
Message( char *nonce_bytes, size_t nonce_len, public:
char *text_bytes, size_t text_len ); Nonce nonce;
Message( Nonce s_nonce, string s_text ); string text;
};
Message( char *nonce_bytes, size_t nonce_len,
class Session { char *text_bytes, size_t text_len );
private: Message( Nonce s_nonce, string s_text );
Base64Key key; };
ae_ctx *ctx;
class Session {
public: private:
Session( Base64Key s_key ); Base64Key key;
~Session(); ae_ctx *ctx;
string encrypt( Message plaintext ); public:
Message decrypt( string ciphertext ); Session( Base64Key s_key );
~Session();
Session( const Session & );
Session & operator=( const Session & ); string encrypt( Message plaintext );
}; Message decrypt( string ciphertext );
Session( const Session & );
Session & operator=( const Session & );
};
}
#endif #endif
+2
View File
@@ -8,6 +8,8 @@
#include "crypto.hpp" #include "crypto.hpp"
using namespace Crypto;
int main( int argc, char *argv[] ) int main( int argc, char *argv[] )
{ {
if ( argc != 2 ) { if ( argc != 2 ) {
+1 -1
View File
@@ -11,7 +11,7 @@ static void dos_detected( const char *expression, const char *file, int line, co
char buffer[ 2048 ]; char buffer[ 2048 ];
snprintf( buffer, 2048, "Illegal counterparty input (possible denial of service) in function %s at %s:%d, failed test: %s\n", snprintf( buffer, 2048, "Illegal counterparty input (possible denial of service) in function %s at %s:%d, failed test: %s\n",
function, file, line, expression ); function, file, line, expression );
throw CryptoException( buffer ); throw Crypto::CryptoException( buffer );
} }
#define dos_assert(expr) \ #define dos_assert(expr) \
+1 -14
View File
@@ -8,20 +8,7 @@
#include "crypto.hpp" #include "crypto.hpp"
long int myatoi( char *str ) using namespace Crypto;
{
char *end;
errno = 0;
long int ret = strtol( str, &end, 10 );
if ( ( errno != 0 )
|| ( end != str + strlen( str ) ) ) {
throw CryptoException( "Bad integer." );
}
return ret;
}
int main( int argc, char *argv[] ) int main( int argc, char *argv[] )
{ {
+43 -63
View File
@@ -10,50 +10,40 @@
using namespace std; using namespace std;
using namespace Network; using namespace Network;
using namespace Crypto;
template <class Payload> /* Read in packet from coded string */
Flow<Payload>::Packet::DecodingCache::DecodingCache( string coded_packet, Session *session ) Packet::Packet( string coded_packet, Session *session )
: direction( TO_CLIENT ), seq( -1 ), payload_string() : seq( -1 ),
direction( TO_SERVER ),
payload()
{ {
Message message = session->decrypt( coded_packet ); Message message = session->decrypt( coded_packet );
direction = (message.nonce.val() & 8000000000000000) ? TO_CLIENT : TO_SERVER; direction = (message.nonce.val() & 8000000000000000) ? TO_CLIENT : TO_SERVER;
seq = message.nonce.val() & 0x7FFFFFFFFFFFFFFF; seq = message.nonce.val() & 0x7FFFFFFFFFFFFFFF;
payload_string = message.text; payload = message.text;
} }
template <class Payload> /* Output coded string from packet */
Flow<Payload>::Packet::Packet( string coded_packet, Session *session ) string Packet::tostring( Session *session )
: decoding_cache( coded_packet, session ),
seq( decoding_cache.seq ),
direction( decoding_cache.direction ),
payload( decoding_cache.payload_string )
{
decoding_cache = DecodingCache();
}
template <class Payload>
string Flow<Payload>::Packet::tostring( Session *session )
{ {
uint64_t direction_seq = (uint64_t( direction == TO_CLIENT ) << 63) | (seq & 0x7FFFFFFFFFFFFFFF); uint64_t direction_seq = (uint64_t( direction == TO_CLIENT ) << 63) | (seq & 0x7FFFFFFFFFFFFFFF);
return session->encrypt( Message( direction_seq, payload.tostring() ) ); return session->encrypt( Message( direction_seq, payload ) );
} }
template <class Payload> Packet Connection::new_packet( string &s_payload )
typename Flow<Payload>::Packet Flow<Payload>::new_packet( Payload &s_payload )
{ {
return Packet( next_seq++, direction, s_payload ); return Packet( next_seq++, direction, s_payload );
} }
template <class Outgoing, class Incoming> void Connection::setup( void )
void Connection<Outgoing, Incoming>::setup( void )
{ {
/* create socket */ /* create socket */
sock = socket( AF_INET, SOCK_DGRAM, 0 ); sock = socket( AF_INET, SOCK_DGRAM, 0 );
if ( sock < 0 ) { if ( sock < 0 ) {
perror( "socket" ); throw NetworkException( "socket", errno );
exit( 1 );
} }
/* Bind to free local port. /* Bind to free local port.
@@ -63,21 +53,18 @@ void Connection<Outgoing, Incoming>::setup( void )
local_addr.sin_port = htons( 0 ); local_addr.sin_port = htons( 0 );
local_addr.sin_addr.s_addr = INADDR_ANY; local_addr.sin_addr.s_addr = INADDR_ANY;
if ( bind( sock, (sockaddr *)&local_addr, sizeof( local_addr ) ) < 0 ) { if ( bind( sock, (sockaddr *)&local_addr, sizeof( local_addr ) ) < 0 ) {
perror( "bind" ); throw NetworkException( "bind", errno );
exit( 1 );
} }
/* Enable path MTU discovery */ /* Enable path MTU discovery */
char flag = IP_PMTUDISC_DO; char flag = IP_PMTUDISC_DO;
socklen_t optlen = sizeof( flag ); socklen_t optlen = sizeof( flag );
if ( setsockopt( sock, IPPROTO_IP, IP_MTU_DISCOVER, &flag, optlen ) < 0 ) { if ( setsockopt( sock, IPPROTO_IP, IP_MTU_DISCOVER, &flag, optlen ) < 0 ) {
perror( "setsockopt" ); throw NetworkException( "setsockopt", errno );
exit( 1 );
} }
} }
template <class Outgoing, class Incoming> Connection::Connection() /* server */
Connection<Outgoing, Incoming>::Connection() /* server */
: sock( -1 ), : sock( -1 ),
remote_addr(), remote_addr(),
server( true ), server( true ),
@@ -85,13 +72,13 @@ Connection<Outgoing, Incoming>::Connection() /* server */
MTU( RECEIVE_MTU ), MTU( RECEIVE_MTU ),
key(), key(),
session( key ), session( key ),
flow( TO_CLIENT, &session ) direction( TO_CLIENT ),
next_seq( 0 )
{ {
setup(); setup();
} }
template <class Outgoing, class Incoming> Connection::Connection( const char *key_str, const char *ip, int port ) /* client */
Connection<Outgoing, Incoming>::Connection( const char *key_str, const char *ip, int port ) /* client */
: sock( -1 ), : sock( -1 ),
remote_addr(), remote_addr(),
server( false ), server( false ),
@@ -99,7 +86,8 @@ Connection<Outgoing, Incoming>::Connection( const char *key_str, const char *ip,
MTU( RECEIVE_MTU ), MTU( RECEIVE_MTU ),
key( key_str ), key( key_str ),
session( key ), session( key ),
flow( TO_SERVER, &session ) direction( TO_SERVER ),
next_seq( 0 )
{ {
setup(); setup();
@@ -107,59 +95,53 @@ Connection<Outgoing, Incoming>::Connection( const char *key_str, const char *ip,
remote_addr.sin_family = AF_INET; remote_addr.sin_family = AF_INET;
remote_addr.sin_port = htons( port ); remote_addr.sin_port = htons( port );
if ( !inet_aton( ip, &remote_addr.sin_addr ) ) { if ( !inet_aton( ip, &remote_addr.sin_addr ) ) {
fprintf( stderr, "Bad IP address %s\n", ip ); int saved_errno = errno;
exit( 1 ); char buffer[ 2048 ];
snprintf( buffer, 2048, "Bad IP address (%s)", ip );
throw NetworkException( buffer, saved_errno );
} }
if ( connect( sock, (sockaddr *)&remote_addr, sizeof( remote_addr ) ) < 0 ) { if ( connect( sock, (sockaddr *)&remote_addr, sizeof( remote_addr ) ) < 0 ) {
perror( "connect" ); throw NetworkException( "connect", errno );
exit( 1 );
} }
attached = true; attached = true;
} }
template <class Outgoing, class Incoming> void Connection::update_MTU( void )
void Connection<Outgoing, Incoming>::update_MTU( void )
{ {
socklen_t optlen = sizeof( MTU ); socklen_t optlen = sizeof( MTU );
if ( getsockopt( sock, IPPROTO_IP, IP_MTU, &MTU, &optlen ) < 0 ) { if ( getsockopt( sock, IPPROTO_IP, IP_MTU, &MTU, &optlen ) < 0 ) {
perror( "getsockopt" ); throw NetworkException( "getsockopt", errno );
exit( 1 );
} }
if ( optlen != sizeof( MTU ) ) { if ( optlen != sizeof( MTU ) ) {
fprintf( stderr, "Error getting path MTU.\n" ); throw NetworkException( "Error getting path MTU", errno );
exit( 1 );
} }
fprintf( stderr, "Path MTU: %d\n", MTU ); fprintf( stderr, "Path MTU: %d\n", MTU );
} }
template <class Outgoing, class Incoming> void Connection::send( string &s )
bool Connection<Outgoing, Incoming>::send( Outgoing &s )
{ {
assert( attached ); assert( attached );
string p = flow.new_packet( s ).tostring( &session ); string p = new_packet( s ).tostring( &session );
ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0, ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0,
(sockaddr *)&remote_addr, sizeof( remote_addr ) ); (sockaddr *)&remote_addr, sizeof( remote_addr ) );
if ( (bytes_sent < 0) && (errno == EMSGSIZE) ) { if ( (bytes_sent < 0) && (errno == EMSGSIZE) ) {
update_MTU(); update_MTU();
return false; throw MTUException( MTU );
} else if ( bytes_sent == static_cast<int>( p.size() ) ) { } else if ( bytes_sent == static_cast<int>( p.size() ) ) {
return true; return;
} else { } else {
perror( "sendto" ); throw NetworkException( "sendto", errno );
exit( 1 );
return false;
} }
} }
template <class Outgoing, class Incoming> string Connection::recv( void )
Incoming Connection<Outgoing, Incoming>::recv( void )
{ {
struct sockaddr_in packet_remote_addr; struct sockaddr_in packet_remote_addr;
@@ -170,17 +152,17 @@ Incoming Connection<Outgoing, Incoming>::recv( void )
ssize_t received_len = recvfrom( sock, buf, RECEIVE_MTU, 0, (sockaddr *)&packet_remote_addr, &addrlen ); ssize_t received_len = recvfrom( sock, buf, RECEIVE_MTU, 0, (sockaddr *)&packet_remote_addr, &addrlen );
if ( received_len < 0 ) { if ( received_len < 0 ) {
perror( "recvfrom" ); throw NetworkException( "recvfrom", errno );
exit( 1 );
} }
if ( received_len > RECEIVE_MTU ) { if ( received_len > RECEIVE_MTU ) {
fprintf( stderr, "Received oversize datagram (size %d) and limit is %d.\n", char buffer[ 2048 ];
static_cast<int>( received_len ), RECEIVE_MTU ); snprintf( buffer, 2048, "Received oversize datagram (size %d) and limit is %d\n",
exit( 1 ); static_cast<int>( received_len ), RECEIVE_MTU );
throw NetworkException( buffer, errno );
} }
typename Flow<Incoming>::Packet p( string( buf, received_len ), &session ); Packet p( string( buf, received_len ), &session );
dos_assert( p.direction == (server ? TO_SERVER : TO_CLIENT) ); /* prevent malicious playback to sender */ dos_assert( p.direction == (server ? TO_SERVER : TO_CLIENT) ); /* prevent malicious playback to sender */
/* server auto-adjusts to client */ /* server auto-adjusts to client */
@@ -199,15 +181,13 @@ Incoming Connection<Outgoing, Incoming>::recv( void )
return p.payload; return p.payload;
} }
template <class Outgoing, class Incoming> int Connection::port( void )
int Connection<Outgoing, Incoming>::port( void )
{ {
struct sockaddr_in local_addr; struct sockaddr_in local_addr;
socklen_t addrlen = sizeof( local_addr ); socklen_t addrlen = sizeof( local_addr );
if ( getsockname( sock, (sockaddr *)&local_addr, &addrlen ) < 0 ) { if ( getsockname( sock, (sockaddr *)&local_addr, &addrlen ) < 0 ) {
perror( "getsockname" ); throw NetworkException( "getsockname", errno );
exit( 1 );
} }
return ntohs( local_addr.sin_port ); return ntohs( local_addr.sin_port );
+30 -43
View File
@@ -10,57 +10,42 @@
#include "crypto.hpp" #include "crypto.hpp"
using namespace std; using namespace std;
using namespace Crypto;
namespace Network { namespace Network {
class MTUException {
public:
int MTU;
MTUException( int s_MTU ) : MTU( s_MTU ) {};
};
class NetworkException {
public:
string function;
int the_errno;
NetworkException( string s_function, int s_errno ) : function( s_function ), the_errno( s_errno ) {}
};
enum Direction { enum Direction {
TO_SERVER = 0, TO_SERVER = 0,
TO_CLIENT = 1 TO_CLIENT = 1
}; };
template <class Payload> class Packet {
class Flow {
public: public:
class Packet { uint64_t seq;
private:
class DecodingCache
{
public:
Direction direction;
uint64_t seq;
string payload_string;
DecodingCache( string coded_packet, Session *session );
DecodingCache() : direction( TO_CLIENT ), seq( -1 ), payload_string() {}
};
DecodingCache decoding_cache;
public:
uint64_t seq;
Direction direction;
Payload payload;
Packet( uint64_t s_seq, Direction s_direction, Payload s_payload )
: decoding_cache(), seq( s_seq ), direction( s_direction ), payload( s_payload )
{}
Packet( string coded_packet, Session *session );
string tostring( Session *session );
};
uint64_t next_seq;
Direction direction; Direction direction;
Session *session; string payload;
Flow( Direction s_direction, Session *s_session ) Packet( uint64_t s_seq, Direction s_direction, string s_payload )
: next_seq( 0 ), direction( s_direction ), session( s_session ) : seq( s_seq ), direction( s_direction ), payload( s_payload )
{} {}
Packet new_packet( Payload &s_payload ); Packet( string coded_packet, Session *session );
string tostring( Session *session );
}; };
template <class Outgoing, class Incoming>
class Connection { class Connection {
private: private:
static const int RECEIVE_MTU = 2048; static const int RECEIVE_MTU = 2048;
@@ -76,18 +61,20 @@ namespace Network {
Base64Key key; Base64Key key;
Session session; Session session;
Flow<Outgoing> flow;
void update_MTU( void ); void update_MTU( void );
void setup( void ); void setup( void );
Direction direction;
uint64_t next_seq;
Packet new_packet( string &s_payload );
public: public:
Connection(); Connection();
Connection( const char *key_str, const char *ip, int port ); Connection( const char *key_str, const char *ip, int port );
bool send( Outgoing &s ); void send( string &s );
Incoming recv( void ); string recv( void );
int fd( void ) { return sock; } int fd( void ) { return sock; }
int port( void ); int port( void );
int get_MTU( void ) { return MTU; } int get_MTU( void ) { return MTU; }
+27
View File
@@ -0,0 +1,27 @@
#ifndef NETWORK_TRANSPORT_HPP
#define NETWORK_TRANSPORT_HPP
#include <google/dense_hash_map>
using google::dense_hash_map;
namespace Network {
template <class MyState, class RemoteState>
class Transport
{
private:
Connection<typename MyState::Conveyance, typename RemoteState::Conveyance> connection;
uint64_t last_acknowledged_state;
uint64_t assumed_receiver_state;
uint64_t last_sent_state;
public:
Transport();
Transport( const char *key_str, const char *ip, int port );
};
};
#endif
+33 -20
View File
@@ -11,7 +11,7 @@ int main( int argc, char *argv[] )
char *ip; char *ip;
int port; int port;
Network::Connection<KeyStroke, KeyStroke> *n; Network::Connection *n;
try { try {
if ( argc > 1 ) { if ( argc > 1 ) {
@@ -22,9 +22,9 @@ int main( int argc, char *argv[] )
ip = argv[ 2 ]; ip = argv[ 2 ];
port = atoi( argv[ 3 ] ); port = atoi( argv[ 3 ] );
n = new Network::Connection<KeyStroke, KeyStroke>( key, ip, port ); n = new Network::Connection( key, ip, port );
} else { } else {
n = new Network::Connection<KeyStroke, KeyStroke>(); n = new Network::Connection();
} }
} catch ( CryptoException e ) { } catch ( CryptoException e ) {
fprintf( stderr, "Fatal error: %s\n", e.text.c_str() ); fprintf( stderr, "Fatal error: %s\n", e.text.c_str() );
@@ -36,34 +36,47 @@ int main( int argc, char *argv[] )
if ( server ) { if ( server ) {
while ( true ) { while ( true ) {
try { try {
KeyStroke s = n->recv(); string s = n->recv();
printf( "%c", s.letter ); printf( "%s", s.c_str() );
fflush( NULL ); fflush( NULL );
} catch ( CryptoException e ) { } catch ( CryptoException e ) {
fprintf( stderr, "Error: %s\n", e.text.c_str() ); fprintf( stderr, "Cryptographic error: %s\n", e.text.c_str() );
} }
} }
} else { } else {
struct termios the_termios; struct termios saved_termios;
struct termios the_termios;
if ( tcgetattr( STDIN_FILENO, &the_termios ) < 0 ) { if ( tcgetattr( STDIN_FILENO, &the_termios ) < 0 ) {
perror( "tcgetattr" ); perror( "tcgetattr" );
exit( 1 ); exit( 1 );
} }
cfmakeraw( &the_termios ); saved_termios = the_termios;
if ( tcsetattr( STDIN_FILENO, TCSANOW, &the_termios ) < 0 ) { cfmakeraw( &the_termios );
perror( "tcsetattr" );
exit( 1 ); if ( tcsetattr( STDIN_FILENO, TCSANOW, &the_termios ) < 0 ) {
} perror( "tcsetattr" );
exit( 1 );
}
while( true ) { while( true ) {
char x = getchar(); char x = getchar();
KeyStroke t( string( &x, 1 ) ); string prefix = "Key(" + string( &x, 1 ) + ")";
n->send( t ); try {
n->send( prefix );
} catch ( Network::NetworkException e ) {
fprintf( stderr, "%s: %s\r\n", e.function.c_str(), strerror( e.the_errno ) );
break;
}
} }
if ( tcsetattr( STDIN_FILENO, TCSANOW, &saved_termios ) < 0 ) {
perror( "tcsetattr" );
exit( 1 );
}
} }
} }
+1 -4
View File
@@ -6,9 +6,6 @@
#include "terminal.hpp" #include "terminal.hpp"
#include "network.cpp"
#include "keystroke.hpp"
namespace Parser { namespace Parser {
class Action; class Action;
} }
@@ -24,4 +21,4 @@ template class vector<wchar_t>;
template class vector<int>; template class vector<int>;
template class map<string, Function>; template class map<string, Function>;
template class vector<bool>; template class vector<bool>;
template class Network::Connection<KeyStroke, KeyStroke>;