Simplify network.cpp to transmit only strings.

This commit is contained in:
Keith Winstein
2011-08-05 19:44:34 -04:00
parent bffc099754
commit 1b3443befd
10 changed files with 211 additions and 197 deletions
+18 -2
View File
@@ -5,9 +5,25 @@
#include "base64.h" #include "base64.h"
using namespace std; using namespace std;
using namespace Crypto;
const char rdev[] = "/dev/urandom"; const char rdev[] = "/dev/urandom";
long int myatoi( char *str )
{
char *end;
errno = 0;
long int ret = strtol( str, &end, 10 );
if ( ( errno != 0 )
|| ( end != str + strlen( str ) ) ) {
throw CryptoException( "Bad integer." );
}
return ret;
}
static void * sse_alloc( int len ) static void * sse_alloc( int len )
{ {
void *ptr = NULL; void *ptr = NULL;
@@ -178,7 +194,7 @@ Message Session::decrypt( string ciphertext )
int body_len = ciphertext.size() - 8; int body_len = ciphertext.size() - 8;
int pt_len = body_len - 16; int pt_len = body_len - 16;
if ( pt_len <= 0 ) { /* super-assertion that does not equal AE_INVALID */ if ( pt_len < 0 ) { /* super-assertion that pt_len does not equal AE_INVALID */
fprintf( stderr, "BUG.\n" ); fprintf( stderr, "BUG.\n" );
exit( 1 ); exit( 1 );
} }
@@ -200,7 +216,7 @@ Message Session::decrypt( string ciphertext )
AE_FINALIZE ) ) { /* final */ AE_FINALIZE ) ) { /* final */
free( plaintext ); free( plaintext );
free( body ); free( body );
throw CryptoException( "ae_decrypt() returned error." ); throw CryptoException( "Packet failed integrity check." );
} }
Message ret( nonce, string( plaintext, pt_len ) ); Message ret( nonce, string( plaintext, pt_len ) );
+5
View File
@@ -3,9 +3,13 @@
#include "ae.hpp" #include "ae.hpp"
#include <string> #include <string>
#include <string.h>
using namespace std; using namespace std;
long int myatoi( char *str );
namespace Crypto {
class CryptoException { class CryptoException {
public: public:
string text; string text;
@@ -61,5 +65,6 @@ public:
Session( const Session & ); Session( const Session & );
Session & operator=( const Session & ); Session & operator=( const Session & );
}; };
}
#endif #endif
+2
View File
@@ -8,6 +8,8 @@
#include "crypto.hpp" #include "crypto.hpp"
using namespace Crypto;
int main( int argc, char *argv[] ) int main( int argc, char *argv[] )
{ {
if ( argc != 2 ) { if ( argc != 2 ) {
+1 -1
View File
@@ -11,7 +11,7 @@ static void dos_detected( const char *expression, const char *file, int line, co
char buffer[ 2048 ]; char buffer[ 2048 ];
snprintf( buffer, 2048, "Illegal counterparty input (possible denial of service) in function %s at %s:%d, failed test: %s\n", snprintf( buffer, 2048, "Illegal counterparty input (possible denial of service) in function %s at %s:%d, failed test: %s\n",
function, file, line, expression ); function, file, line, expression );
throw CryptoException( buffer ); throw Crypto::CryptoException( buffer );
} }
#define dos_assert(expr) \ #define dos_assert(expr) \
+1 -14
View File
@@ -8,20 +8,7 @@
#include "crypto.hpp" #include "crypto.hpp"
long int myatoi( char *str ) using namespace Crypto;
{
char *end;
errno = 0;
long int ret = strtol( str, &end, 10 );
if ( ( errno != 0 )
|| ( end != str + strlen( str ) ) ) {
throw CryptoException( "Bad integer." );
}
return ret;
}
int main( int argc, char *argv[] ) int main( int argc, char *argv[] )
{ {
+42 -62
View File
@@ -10,50 +10,40 @@
using namespace std; using namespace std;
using namespace Network; using namespace Network;
using namespace Crypto;
template <class Payload> /* Read in packet from coded string */
Flow<Payload>::Packet::DecodingCache::DecodingCache( string coded_packet, Session *session ) Packet::Packet( string coded_packet, Session *session )
: direction( TO_CLIENT ), seq( -1 ), payload_string() : seq( -1 ),
direction( TO_SERVER ),
payload()
{ {
Message message = session->decrypt( coded_packet ); Message message = session->decrypt( coded_packet );
direction = (message.nonce.val() & 8000000000000000) ? TO_CLIENT : TO_SERVER; direction = (message.nonce.val() & 8000000000000000) ? TO_CLIENT : TO_SERVER;
seq = message.nonce.val() & 0x7FFFFFFFFFFFFFFF; seq = message.nonce.val() & 0x7FFFFFFFFFFFFFFF;
payload_string = message.text; payload = message.text;
} }
template <class Payload> /* Output coded string from packet */
Flow<Payload>::Packet::Packet( string coded_packet, Session *session ) string Packet::tostring( Session *session )
: decoding_cache( coded_packet, session ),
seq( decoding_cache.seq ),
direction( decoding_cache.direction ),
payload( decoding_cache.payload_string )
{
decoding_cache = DecodingCache();
}
template <class Payload>
string Flow<Payload>::Packet::tostring( Session *session )
{ {
uint64_t direction_seq = (uint64_t( direction == TO_CLIENT ) << 63) | (seq & 0x7FFFFFFFFFFFFFFF); uint64_t direction_seq = (uint64_t( direction == TO_CLIENT ) << 63) | (seq & 0x7FFFFFFFFFFFFFFF);
return session->encrypt( Message( direction_seq, payload.tostring() ) ); return session->encrypt( Message( direction_seq, payload ) );
} }
template <class Payload> Packet Connection::new_packet( string &s_payload )
typename Flow<Payload>::Packet Flow<Payload>::new_packet( Payload &s_payload )
{ {
return Packet( next_seq++, direction, s_payload ); return Packet( next_seq++, direction, s_payload );
} }
template <class Outgoing, class Incoming> void Connection::setup( void )
void Connection<Outgoing, Incoming>::setup( void )
{ {
/* create socket */ /* create socket */
sock = socket( AF_INET, SOCK_DGRAM, 0 ); sock = socket( AF_INET, SOCK_DGRAM, 0 );
if ( sock < 0 ) { if ( sock < 0 ) {
perror( "socket" ); throw NetworkException( "socket", errno );
exit( 1 );
} }
/* Bind to free local port. /* Bind to free local port.
@@ -63,21 +53,18 @@ void Connection<Outgoing, Incoming>::setup( void )
local_addr.sin_port = htons( 0 ); local_addr.sin_port = htons( 0 );
local_addr.sin_addr.s_addr = INADDR_ANY; local_addr.sin_addr.s_addr = INADDR_ANY;
if ( bind( sock, (sockaddr *)&local_addr, sizeof( local_addr ) ) < 0 ) { if ( bind( sock, (sockaddr *)&local_addr, sizeof( local_addr ) ) < 0 ) {
perror( "bind" ); throw NetworkException( "bind", errno );
exit( 1 );
} }
/* Enable path MTU discovery */ /* Enable path MTU discovery */
char flag = IP_PMTUDISC_DO; char flag = IP_PMTUDISC_DO;
socklen_t optlen = sizeof( flag ); socklen_t optlen = sizeof( flag );
if ( setsockopt( sock, IPPROTO_IP, IP_MTU_DISCOVER, &flag, optlen ) < 0 ) { if ( setsockopt( sock, IPPROTO_IP, IP_MTU_DISCOVER, &flag, optlen ) < 0 ) {
perror( "setsockopt" ); throw NetworkException( "setsockopt", errno );
exit( 1 );
} }
} }
template <class Outgoing, class Incoming> Connection::Connection() /* server */
Connection<Outgoing, Incoming>::Connection() /* server */
: sock( -1 ), : sock( -1 ),
remote_addr(), remote_addr(),
server( true ), server( true ),
@@ -85,13 +72,13 @@ Connection<Outgoing, Incoming>::Connection() /* server */
MTU( RECEIVE_MTU ), MTU( RECEIVE_MTU ),
key(), key(),
session( key ), session( key ),
flow( TO_CLIENT, &session ) direction( TO_CLIENT ),
next_seq( 0 )
{ {
setup(); setup();
} }
template <class Outgoing, class Incoming> Connection::Connection( const char *key_str, const char *ip, int port ) /* client */
Connection<Outgoing, Incoming>::Connection( const char *key_str, const char *ip, int port ) /* client */
: sock( -1 ), : sock( -1 ),
remote_addr(), remote_addr(),
server( false ), server( false ),
@@ -99,7 +86,8 @@ Connection<Outgoing, Incoming>::Connection( const char *key_str, const char *ip,
MTU( RECEIVE_MTU ), MTU( RECEIVE_MTU ),
key( key_str ), key( key_str ),
session( key ), session( key ),
flow( TO_SERVER, &session ) direction( TO_SERVER ),
next_seq( 0 )
{ {
setup(); setup();
@@ -107,59 +95,53 @@ Connection<Outgoing, Incoming>::Connection( const char *key_str, const char *ip,
remote_addr.sin_family = AF_INET; remote_addr.sin_family = AF_INET;
remote_addr.sin_port = htons( port ); remote_addr.sin_port = htons( port );
if ( !inet_aton( ip, &remote_addr.sin_addr ) ) { if ( !inet_aton( ip, &remote_addr.sin_addr ) ) {
fprintf( stderr, "Bad IP address %s\n", ip ); int saved_errno = errno;
exit( 1 ); char buffer[ 2048 ];
snprintf( buffer, 2048, "Bad IP address (%s)", ip );
throw NetworkException( buffer, saved_errno );
} }
if ( connect( sock, (sockaddr *)&remote_addr, sizeof( remote_addr ) ) < 0 ) { if ( connect( sock, (sockaddr *)&remote_addr, sizeof( remote_addr ) ) < 0 ) {
perror( "connect" ); throw NetworkException( "connect", errno );
exit( 1 );
} }
attached = true; attached = true;
} }
template <class Outgoing, class Incoming> void Connection::update_MTU( void )
void Connection<Outgoing, Incoming>::update_MTU( void )
{ {
socklen_t optlen = sizeof( MTU ); socklen_t optlen = sizeof( MTU );
if ( getsockopt( sock, IPPROTO_IP, IP_MTU, &MTU, &optlen ) < 0 ) { if ( getsockopt( sock, IPPROTO_IP, IP_MTU, &MTU, &optlen ) < 0 ) {
perror( "getsockopt" ); throw NetworkException( "getsockopt", errno );
exit( 1 );
} }
if ( optlen != sizeof( MTU ) ) { if ( optlen != sizeof( MTU ) ) {
fprintf( stderr, "Error getting path MTU.\n" ); throw NetworkException( "Error getting path MTU", errno );
exit( 1 );
} }
fprintf( stderr, "Path MTU: %d\n", MTU ); fprintf( stderr, "Path MTU: %d\n", MTU );
} }
template <class Outgoing, class Incoming> void Connection::send( string &s )
bool Connection<Outgoing, Incoming>::send( Outgoing &s )
{ {
assert( attached ); assert( attached );
string p = flow.new_packet( s ).tostring( &session ); string p = new_packet( s ).tostring( &session );
ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0, ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0,
(sockaddr *)&remote_addr, sizeof( remote_addr ) ); (sockaddr *)&remote_addr, sizeof( remote_addr ) );
if ( (bytes_sent < 0) && (errno == EMSGSIZE) ) { if ( (bytes_sent < 0) && (errno == EMSGSIZE) ) {
update_MTU(); update_MTU();
return false; throw MTUException( MTU );
} else if ( bytes_sent == static_cast<int>( p.size() ) ) { } else if ( bytes_sent == static_cast<int>( p.size() ) ) {
return true; return;
} else { } else {
perror( "sendto" ); throw NetworkException( "sendto", errno );
exit( 1 );
return false;
} }
} }
template <class Outgoing, class Incoming> string Connection::recv( void )
Incoming Connection<Outgoing, Incoming>::recv( void )
{ {
struct sockaddr_in packet_remote_addr; struct sockaddr_in packet_remote_addr;
@@ -170,17 +152,17 @@ Incoming Connection<Outgoing, Incoming>::recv( void )
ssize_t received_len = recvfrom( sock, buf, RECEIVE_MTU, 0, (sockaddr *)&packet_remote_addr, &addrlen ); ssize_t received_len = recvfrom( sock, buf, RECEIVE_MTU, 0, (sockaddr *)&packet_remote_addr, &addrlen );
if ( received_len < 0 ) { if ( received_len < 0 ) {
perror( "recvfrom" ); throw NetworkException( "recvfrom", errno );
exit( 1 );
} }
if ( received_len > RECEIVE_MTU ) { if ( received_len > RECEIVE_MTU ) {
fprintf( stderr, "Received oversize datagram (size %d) and limit is %d.\n", char buffer[ 2048 ];
snprintf( buffer, 2048, "Received oversize datagram (size %d) and limit is %d\n",
static_cast<int>( received_len ), RECEIVE_MTU ); static_cast<int>( received_len ), RECEIVE_MTU );
exit( 1 ); throw NetworkException( buffer, errno );
} }
typename Flow<Incoming>::Packet p( string( buf, received_len ), &session ); Packet p( string( buf, received_len ), &session );
dos_assert( p.direction == (server ? TO_SERVER : TO_CLIENT) ); /* prevent malicious playback to sender */ dos_assert( p.direction == (server ? TO_SERVER : TO_CLIENT) ); /* prevent malicious playback to sender */
/* server auto-adjusts to client */ /* server auto-adjusts to client */
@@ -199,15 +181,13 @@ Incoming Connection<Outgoing, Incoming>::recv( void )
return p.payload; return p.payload;
} }
template <class Outgoing, class Incoming> int Connection::port( void )
int Connection<Outgoing, Incoming>::port( void )
{ {
struct sockaddr_in local_addr; struct sockaddr_in local_addr;
socklen_t addrlen = sizeof( local_addr ); socklen_t addrlen = sizeof( local_addr );
if ( getsockname( sock, (sockaddr *)&local_addr, &addrlen ) < 0 ) { if ( getsockname( sock, (sockaddr *)&local_addr, &addrlen ) < 0 ) {
perror( "getsockname" ); throw NetworkException( "getsockname", errno );
exit( 1 );
} }
return ntohs( local_addr.sin_port ); return ntohs( local_addr.sin_port );
+23 -36
View File
@@ -10,38 +10,35 @@
#include "crypto.hpp" #include "crypto.hpp"
using namespace std; using namespace std;
using namespace Crypto;
namespace Network { namespace Network {
class MTUException {
public:
int MTU;
MTUException( int s_MTU ) : MTU( s_MTU ) {};
};
class NetworkException {
public:
string function;
int the_errno;
NetworkException( string s_function, int s_errno ) : function( s_function ), the_errno( s_errno ) {}
};
enum Direction { enum Direction {
TO_SERVER = 0, TO_SERVER = 0,
TO_CLIENT = 1 TO_CLIENT = 1
}; };
template <class Payload>
class Flow {
public:
class Packet { class Packet {
private:
class DecodingCache
{
public:
Direction direction;
uint64_t seq;
string payload_string;
DecodingCache( string coded_packet, Session *session );
DecodingCache() : direction( TO_CLIENT ), seq( -1 ), payload_string() {}
};
DecodingCache decoding_cache;
public: public:
uint64_t seq; uint64_t seq;
Direction direction; Direction direction;
Payload payload; string payload;
Packet( uint64_t s_seq, Direction s_direction, Payload s_payload ) Packet( uint64_t s_seq, Direction s_direction, string s_payload )
: decoding_cache(), seq( s_seq ), direction( s_direction ), payload( s_payload ) : seq( s_seq ), direction( s_direction ), payload( s_payload )
{} {}
Packet( string coded_packet, Session *session ); Packet( string coded_packet, Session *session );
@@ -49,18 +46,6 @@ namespace Network {
string tostring( Session *session ); string tostring( Session *session );
}; };
uint64_t next_seq;
Direction direction;
Session *session;
Flow( Direction s_direction, Session *s_session )
: next_seq( 0 ), direction( s_direction ), session( s_session )
{}
Packet new_packet( Payload &s_payload );
};
template <class Outgoing, class Incoming>
class Connection { class Connection {
private: private:
static const int RECEIVE_MTU = 2048; static const int RECEIVE_MTU = 2048;
@@ -76,18 +61,20 @@ namespace Network {
Base64Key key; Base64Key key;
Session session; Session session;
Flow<Outgoing> flow;
void update_MTU( void ); void update_MTU( void );
void setup( void ); void setup( void );
Direction direction;
uint64_t next_seq;
Packet new_packet( string &s_payload );
public: public:
Connection(); Connection();
Connection( const char *key_str, const char *ip, int port ); Connection( const char *key_str, const char *ip, int port );
bool send( Outgoing &s ); void send( string &s );
Incoming recv( void ); string recv( void );
int fd( void ) { return sock; } int fd( void ) { return sock; }
int port( void ); int port( void );
int get_MTU( void ) { return MTU; } int get_MTU( void ) { return MTU; }
+27
View File
@@ -0,0 +1,27 @@
#ifndef NETWORK_TRANSPORT_HPP
#define NETWORK_TRANSPORT_HPP
#include <google/dense_hash_map>
using google::dense_hash_map;
namespace Network {
template <class MyState, class RemoteState>
class Transport
{
private:
Connection<typename MyState::Conveyance, typename RemoteState::Conveyance> connection;
uint64_t last_acknowledged_state;
uint64_t assumed_receiver_state;
uint64_t last_sent_state;
public:
Transport();
Transport( const char *key_str, const char *ip, int port );
};
};
#endif
+21 -8
View File
@@ -11,7 +11,7 @@ int main( int argc, char *argv[] )
char *ip; char *ip;
int port; int port;
Network::Connection<KeyStroke, KeyStroke> *n; Network::Connection *n;
try { try {
if ( argc > 1 ) { if ( argc > 1 ) {
@@ -22,9 +22,9 @@ int main( int argc, char *argv[] )
ip = argv[ 2 ]; ip = argv[ 2 ];
port = atoi( argv[ 3 ] ); port = atoi( argv[ 3 ] );
n = new Network::Connection<KeyStroke, KeyStroke>( key, ip, port ); n = new Network::Connection( key, ip, port );
} else { } else {
n = new Network::Connection<KeyStroke, KeyStroke>(); n = new Network::Connection();
} }
} catch ( CryptoException e ) { } catch ( CryptoException e ) {
fprintf( stderr, "Fatal error: %s\n", e.text.c_str() ); fprintf( stderr, "Fatal error: %s\n", e.text.c_str() );
@@ -36,14 +36,15 @@ int main( int argc, char *argv[] )
if ( server ) { if ( server ) {
while ( true ) { while ( true ) {
try { try {
KeyStroke s = n->recv(); string s = n->recv();
printf( "%c", s.letter ); printf( "%s", s.c_str() );
fflush( NULL ); fflush( NULL );
} catch ( CryptoException e ) { } catch ( CryptoException e ) {
fprintf( stderr, "Error: %s\n", e.text.c_str() ); fprintf( stderr, "Cryptographic error: %s\n", e.text.c_str() );
} }
} }
} else { } else {
struct termios saved_termios;
struct termios the_termios; struct termios the_termios;
if ( tcgetattr( STDIN_FILENO, &the_termios ) < 0 ) { if ( tcgetattr( STDIN_FILENO, &the_termios ) < 0 ) {
@@ -51,6 +52,8 @@ int main( int argc, char *argv[] )
exit( 1 ); exit( 1 );
} }
saved_termios = the_termios;
cfmakeraw( &the_termios ); cfmakeraw( &the_termios );
if ( tcsetattr( STDIN_FILENO, TCSANOW, &the_termios ) < 0 ) { if ( tcsetattr( STDIN_FILENO, TCSANOW, &the_termios ) < 0 ) {
@@ -61,9 +64,19 @@ int main( int argc, char *argv[] )
while( true ) { while( true ) {
char x = getchar(); char x = getchar();
KeyStroke t( string( &x, 1 ) ); string prefix = "Key(" + string( &x, 1 ) + ")";
n->send( t ); try {
n->send( prefix );
} catch ( Network::NetworkException e ) {
fprintf( stderr, "%s: %s\r\n", e.function.c_str(), strerror( e.the_errno ) );
break;
}
}
if ( tcsetattr( STDIN_FILENO, TCSANOW, &saved_termios ) < 0 ) {
perror( "tcsetattr" );
exit( 1 );
} }
} }
} }
+1 -4
View File
@@ -6,9 +6,6 @@
#include "terminal.hpp" #include "terminal.hpp"
#include "network.cpp"
#include "keystroke.hpp"
namespace Parser { namespace Parser {
class Action; class Action;
} }
@@ -24,4 +21,4 @@ template class vector<wchar_t>;
template class vector<int>; template class vector<int>;
template class map<string, Function>; template class map<string, Function>;
template class vector<bool>; template class vector<bool>;
template class Network::Connection<KeyStroke, KeyStroke>;