Simplify network.cpp to transmit only strings.

This commit is contained in:
Keith Winstein
2011-08-05 19:44:34 -04:00
parent bffc099754
commit 1b3443befd
10 changed files with 211 additions and 197 deletions
+18 -2
View File
@@ -5,9 +5,25 @@
#include "base64.h"
using namespace std;
using namespace Crypto;
const char rdev[] = "/dev/urandom";
long int myatoi( char *str )
{
char *end;
errno = 0;
long int ret = strtol( str, &end, 10 );
if ( ( errno != 0 )
|| ( end != str + strlen( str ) ) ) {
throw CryptoException( "Bad integer." );
}
return ret;
}
static void * sse_alloc( int len )
{
void *ptr = NULL;
@@ -178,7 +194,7 @@ Message Session::decrypt( string ciphertext )
int body_len = ciphertext.size() - 8;
int pt_len = body_len - 16;
if ( pt_len <= 0 ) { /* super-assertion that does not equal AE_INVALID */
if ( pt_len < 0 ) { /* super-assertion that pt_len does not equal AE_INVALID */
fprintf( stderr, "BUG.\n" );
exit( 1 );
}
@@ -200,7 +216,7 @@ Message Session::decrypt( string ciphertext )
AE_FINALIZE ) ) { /* final */
free( plaintext );
free( body );
throw CryptoException( "ae_decrypt() returned error." );
throw CryptoException( "Packet failed integrity check." );
}
Message ret( nonce, string( plaintext, pt_len ) );
+49 -44
View File
@@ -3,63 +3,68 @@
#include "ae.hpp"
#include <string>
#include <string.h>
using namespace std;
class CryptoException {
public:
string text;
CryptoException( string s_text ) : text( s_text ) {};
};
long int myatoi( char *str );
class Base64Key {
private:
unsigned char key[ 16 ];
namespace Crypto {
class CryptoException {
public:
string text;
CryptoException( string s_text ) : text( s_text ) {};
};
public:
Base64Key(); /* random key */
Base64Key( string printable_key );
string printable_key( void );
unsigned char *data( void ) { return key; }
};
class Base64Key {
private:
unsigned char key[ 16 ];
class Nonce {
private:
char bytes[ 12 ];
public:
Base64Key(); /* random key */
Base64Key( string printable_key );
string printable_key( void );
unsigned char *data( void ) { return key; }
};
public:
Nonce( uint64_t val );
Nonce( char *s_bytes, size_t len );
class Nonce {
private:
char bytes[ 12 ];
string cpp_str( void ) { return string( (char *)( bytes + 4 ), 8 ); }
char *data( void ) { return bytes; }
uint64_t val( void );
};
public:
Nonce( uint64_t val );
Nonce( char *s_bytes, size_t len );
class Message {
public:
Nonce nonce;
string text;
string cpp_str( void ) { return string( (char *)( bytes + 4 ), 8 ); }
char *data( void ) { return bytes; }
uint64_t val( void );
};
Message( char *nonce_bytes, size_t nonce_len,
char *text_bytes, size_t text_len );
Message( Nonce s_nonce, string s_text );
};
class Message {
public:
Nonce nonce;
string text;
class Session {
private:
Base64Key key;
ae_ctx *ctx;
Message( char *nonce_bytes, size_t nonce_len,
char *text_bytes, size_t text_len );
Message( Nonce s_nonce, string s_text );
};
public:
Session( Base64Key s_key );
~Session();
class Session {
private:
Base64Key key;
ae_ctx *ctx;
string encrypt( Message plaintext );
Message decrypt( string ciphertext );
public:
Session( Base64Key s_key );
~Session();
Session( const Session & );
Session & operator=( const Session & );
};
string encrypt( Message plaintext );
Message decrypt( string ciphertext );
Session( const Session & );
Session & operator=( const Session & );
};
}
#endif
+2
View File
@@ -8,6 +8,8 @@
#include "crypto.hpp"
using namespace Crypto;
int main( int argc, char *argv[] )
{
if ( argc != 2 ) {
+1 -1
View File
@@ -11,7 +11,7 @@ static void dos_detected( const char *expression, const char *file, int line, co
char buffer[ 2048 ];
snprintf( buffer, 2048, "Illegal counterparty input (possible denial of service) in function %s at %s:%d, failed test: %s\n",
function, file, line, expression );
throw CryptoException( buffer );
throw Crypto::CryptoException( buffer );
}
#define dos_assert(expr) \
+1 -14
View File
@@ -8,20 +8,7 @@
#include "crypto.hpp"
long int myatoi( char *str )
{
char *end;
errno = 0;
long int ret = strtol( str, &end, 10 );
if ( ( errno != 0 )
|| ( end != str + strlen( str ) ) ) {
throw CryptoException( "Bad integer." );
}
return ret;
}
using namespace Crypto;
int main( int argc, char *argv[] )
{
+43 -63
View File
@@ -10,50 +10,40 @@
using namespace std;
using namespace Network;
using namespace Crypto;
template <class Payload>
Flow<Payload>::Packet::DecodingCache::DecodingCache( string coded_packet, Session *session )
: direction( TO_CLIENT ), seq( -1 ), payload_string()
/* Read in packet from coded string */
Packet::Packet( string coded_packet, Session *session )
: seq( -1 ),
direction( TO_SERVER ),
payload()
{
Message message = session->decrypt( coded_packet );
direction = (message.nonce.val() & 8000000000000000) ? TO_CLIENT : TO_SERVER;
seq = message.nonce.val() & 0x7FFFFFFFFFFFFFFF;
payload_string = message.text;
payload = message.text;
}
template <class Payload>
Flow<Payload>::Packet::Packet( string coded_packet, Session *session )
: decoding_cache( coded_packet, session ),
seq( decoding_cache.seq ),
direction( decoding_cache.direction ),
payload( decoding_cache.payload_string )
{
decoding_cache = DecodingCache();
}
template <class Payload>
string Flow<Payload>::Packet::tostring( Session *session )
/* Output coded string from packet */
string Packet::tostring( Session *session )
{
uint64_t direction_seq = (uint64_t( direction == TO_CLIENT ) << 63) | (seq & 0x7FFFFFFFFFFFFFFF);
return session->encrypt( Message( direction_seq, payload.tostring() ) );
return session->encrypt( Message( direction_seq, payload ) );
}
template <class Payload>
typename Flow<Payload>::Packet Flow<Payload>::new_packet( Payload &s_payload )
Packet Connection::new_packet( string &s_payload )
{
return Packet( next_seq++, direction, s_payload );
}
template <class Outgoing, class Incoming>
void Connection<Outgoing, Incoming>::setup( void )
void Connection::setup( void )
{
/* create socket */
sock = socket( AF_INET, SOCK_DGRAM, 0 );
if ( sock < 0 ) {
perror( "socket" );
exit( 1 );
throw NetworkException( "socket", errno );
}
/* Bind to free local port.
@@ -63,21 +53,18 @@ void Connection<Outgoing, Incoming>::setup( void )
local_addr.sin_port = htons( 0 );
local_addr.sin_addr.s_addr = INADDR_ANY;
if ( bind( sock, (sockaddr *)&local_addr, sizeof( local_addr ) ) < 0 ) {
perror( "bind" );
exit( 1 );
throw NetworkException( "bind", errno );
}
/* Enable path MTU discovery */
char flag = IP_PMTUDISC_DO;
socklen_t optlen = sizeof( flag );
if ( setsockopt( sock, IPPROTO_IP, IP_MTU_DISCOVER, &flag, optlen ) < 0 ) {
perror( "setsockopt" );
exit( 1 );
throw NetworkException( "setsockopt", errno );
}
}
template <class Outgoing, class Incoming>
Connection<Outgoing, Incoming>::Connection() /* server */
Connection::Connection() /* server */
: sock( -1 ),
remote_addr(),
server( true ),
@@ -85,13 +72,13 @@ Connection<Outgoing, Incoming>::Connection() /* server */
MTU( RECEIVE_MTU ),
key(),
session( key ),
flow( TO_CLIENT, &session )
direction( TO_CLIENT ),
next_seq( 0 )
{
setup();
}
template <class Outgoing, class Incoming>
Connection<Outgoing, Incoming>::Connection( const char *key_str, const char *ip, int port ) /* client */
Connection::Connection( const char *key_str, const char *ip, int port ) /* client */
: sock( -1 ),
remote_addr(),
server( false ),
@@ -99,7 +86,8 @@ Connection<Outgoing, Incoming>::Connection( const char *key_str, const char *ip,
MTU( RECEIVE_MTU ),
key( key_str ),
session( key ),
flow( TO_SERVER, &session )
direction( TO_SERVER ),
next_seq( 0 )
{
setup();
@@ -107,59 +95,53 @@ Connection<Outgoing, Incoming>::Connection( const char *key_str, const char *ip,
remote_addr.sin_family = AF_INET;
remote_addr.sin_port = htons( port );
if ( !inet_aton( ip, &remote_addr.sin_addr ) ) {
fprintf( stderr, "Bad IP address %s\n", ip );
exit( 1 );
int saved_errno = errno;
char buffer[ 2048 ];
snprintf( buffer, 2048, "Bad IP address (%s)", ip );
throw NetworkException( buffer, saved_errno );
}
if ( connect( sock, (sockaddr *)&remote_addr, sizeof( remote_addr ) ) < 0 ) {
perror( "connect" );
exit( 1 );
throw NetworkException( "connect", errno );
}
attached = true;
}
template <class Outgoing, class Incoming>
void Connection<Outgoing, Incoming>::update_MTU( void )
void Connection::update_MTU( void )
{
socklen_t optlen = sizeof( MTU );
if ( getsockopt( sock, IPPROTO_IP, IP_MTU, &MTU, &optlen ) < 0 ) {
perror( "getsockopt" );
exit( 1 );
throw NetworkException( "getsockopt", errno );
}
if ( optlen != sizeof( MTU ) ) {
fprintf( stderr, "Error getting path MTU.\n" );
exit( 1 );
throw NetworkException( "Error getting path MTU", errno );
}
fprintf( stderr, "Path MTU: %d\n", MTU );
}
template <class Outgoing, class Incoming>
bool Connection<Outgoing, Incoming>::send( Outgoing &s )
void Connection::send( string &s )
{
assert( attached );
string p = flow.new_packet( s ).tostring( &session );
string p = new_packet( s ).tostring( &session );
ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0,
(sockaddr *)&remote_addr, sizeof( remote_addr ) );
if ( (bytes_sent < 0) && (errno == EMSGSIZE) ) {
update_MTU();
return false;
throw MTUException( MTU );
} else if ( bytes_sent == static_cast<int>( p.size() ) ) {
return true;
return;
} else {
perror( "sendto" );
exit( 1 );
return false;
throw NetworkException( "sendto", errno );
}
}
template <class Outgoing, class Incoming>
Incoming Connection<Outgoing, Incoming>::recv( void )
string Connection::recv( void )
{
struct sockaddr_in packet_remote_addr;
@@ -170,17 +152,17 @@ Incoming Connection<Outgoing, Incoming>::recv( void )
ssize_t received_len = recvfrom( sock, buf, RECEIVE_MTU, 0, (sockaddr *)&packet_remote_addr, &addrlen );
if ( received_len < 0 ) {
perror( "recvfrom" );
exit( 1 );
throw NetworkException( "recvfrom", errno );
}
if ( received_len > RECEIVE_MTU ) {
fprintf( stderr, "Received oversize datagram (size %d) and limit is %d.\n",
static_cast<int>( received_len ), RECEIVE_MTU );
exit( 1 );
char buffer[ 2048 ];
snprintf( buffer, 2048, "Received oversize datagram (size %d) and limit is %d\n",
static_cast<int>( received_len ), RECEIVE_MTU );
throw NetworkException( buffer, errno );
}
typename Flow<Incoming>::Packet p( string( buf, received_len ), &session );
Packet p( string( buf, received_len ), &session );
dos_assert( p.direction == (server ? TO_SERVER : TO_CLIENT) ); /* prevent malicious playback to sender */
/* server auto-adjusts to client */
@@ -199,15 +181,13 @@ Incoming Connection<Outgoing, Incoming>::recv( void )
return p.payload;
}
template <class Outgoing, class Incoming>
int Connection<Outgoing, Incoming>::port( void )
int Connection::port( void )
{
struct sockaddr_in local_addr;
socklen_t addrlen = sizeof( local_addr );
if ( getsockname( sock, (sockaddr *)&local_addr, &addrlen ) < 0 ) {
perror( "getsockname" );
exit( 1 );
throw NetworkException( "getsockname", errno );
}
return ntohs( local_addr.sin_port );
+28 -41
View File
@@ -10,57 +10,42 @@
#include "crypto.hpp"
using namespace std;
using namespace Crypto;
namespace Network {
class MTUException {
public:
int MTU;
MTUException( int s_MTU ) : MTU( s_MTU ) {};
};
class NetworkException {
public:
string function;
int the_errno;
NetworkException( string s_function, int s_errno ) : function( s_function ), the_errno( s_errno ) {}
};
enum Direction {
TO_SERVER = 0,
TO_CLIENT = 1
};
template <class Payload>
class Flow {
class Packet {
public:
class Packet {
private:
class DecodingCache
{
public:
Direction direction;
uint64_t seq;
string payload_string;
DecodingCache( string coded_packet, Session *session );
DecodingCache() : direction( TO_CLIENT ), seq( -1 ), payload_string() {}
};
DecodingCache decoding_cache;
public:
uint64_t seq;
Direction direction;
Payload payload;
Packet( uint64_t s_seq, Direction s_direction, Payload s_payload )
: decoding_cache(), seq( s_seq ), direction( s_direction ), payload( s_payload )
{}
Packet( string coded_packet, Session *session );
string tostring( Session *session );
};
uint64_t next_seq;
uint64_t seq;
Direction direction;
Session *session;
string payload;
Flow( Direction s_direction, Session *s_session )
: next_seq( 0 ), direction( s_direction ), session( s_session )
Packet( uint64_t s_seq, Direction s_direction, string s_payload )
: seq( s_seq ), direction( s_direction ), payload( s_payload )
{}
Packet new_packet( Payload &s_payload );
Packet( string coded_packet, Session *session );
string tostring( Session *session );
};
template <class Outgoing, class Incoming>
class Connection {
private:
static const int RECEIVE_MTU = 2048;
@@ -76,18 +61,20 @@ namespace Network {
Base64Key key;
Session session;
Flow<Outgoing> flow;
void update_MTU( void );
void setup( void );
Direction direction;
uint64_t next_seq;
Packet new_packet( string &s_payload );
public:
Connection();
Connection( const char *key_str, const char *ip, int port );
bool send( Outgoing &s );
Incoming recv( void );
void send( string &s );
string recv( void );
int fd( void ) { return sock; }
int port( void );
int get_MTU( void ) { return MTU; }
+27
View File
@@ -0,0 +1,27 @@
#ifndef NETWORK_TRANSPORT_HPP
#define NETWORK_TRANSPORT_HPP
#include <google/dense_hash_map>
using google::dense_hash_map;
namespace Network {
template <class MyState, class RemoteState>
class Transport
{
private:
Connection<typename MyState::Conveyance, typename RemoteState::Conveyance> connection;
uint64_t last_acknowledged_state;
uint64_t assumed_receiver_state;
uint64_t last_sent_state;
public:
Transport();
Transport( const char *key_str, const char *ip, int port );
};
};
#endif
+31 -18
View File
@@ -11,7 +11,7 @@ int main( int argc, char *argv[] )
char *ip;
int port;
Network::Connection<KeyStroke, KeyStroke> *n;
Network::Connection *n;
try {
if ( argc > 1 ) {
@@ -22,9 +22,9 @@ int main( int argc, char *argv[] )
ip = argv[ 2 ];
port = atoi( argv[ 3 ] );
n = new Network::Connection<KeyStroke, KeyStroke>( key, ip, port );
n = new Network::Connection( key, ip, port );
} else {
n = new Network::Connection<KeyStroke, KeyStroke>();
n = new Network::Connection();
}
} catch ( CryptoException e ) {
fprintf( stderr, "Fatal error: %s\n", e.text.c_str() );
@@ -36,34 +36,47 @@ int main( int argc, char *argv[] )
if ( server ) {
while ( true ) {
try {
KeyStroke s = n->recv();
printf( "%c", s.letter );
string s = n->recv();
printf( "%s", s.c_str() );
fflush( NULL );
} catch ( CryptoException e ) {
fprintf( stderr, "Error: %s\n", e.text.c_str() );
fprintf( stderr, "Cryptographic error: %s\n", e.text.c_str() );
}
}
} else {
struct termios the_termios;
struct termios saved_termios;
struct termios the_termios;
if ( tcgetattr( STDIN_FILENO, &the_termios ) < 0 ) {
perror( "tcgetattr" );
exit( 1 );
}
if ( tcgetattr( STDIN_FILENO, &the_termios ) < 0 ) {
perror( "tcgetattr" );
exit( 1 );
}
cfmakeraw( &the_termios );
saved_termios = the_termios;
if ( tcsetattr( STDIN_FILENO, TCSANOW, &the_termios ) < 0 ) {
perror( "tcsetattr" );
exit( 1 );
}
cfmakeraw( &the_termios );
if ( tcsetattr( STDIN_FILENO, TCSANOW, &the_termios ) < 0 ) {
perror( "tcsetattr" );
exit( 1 );
}
while( true ) {
char x = getchar();
KeyStroke t( string( &x, 1 ) );
string prefix = "Key(" + string( &x, 1 ) + ")";
n->send( t );
try {
n->send( prefix );
} catch ( Network::NetworkException e ) {
fprintf( stderr, "%s: %s\r\n", e.function.c_str(), strerror( e.the_errno ) );
break;
}
}
if ( tcsetattr( STDIN_FILENO, TCSANOW, &saved_termios ) < 0 ) {
perror( "tcsetattr" );
exit( 1 );
}
}
}
+1 -4
View File
@@ -6,9 +6,6 @@
#include "terminal.hpp"
#include "network.cpp"
#include "keystroke.hpp"
namespace Parser {
class Action;
}
@@ -24,4 +21,4 @@ template class vector<wchar_t>;
template class vector<int>;
template class map<string, Function>;
template class vector<bool>;
template class Network::Connection<KeyStroke, KeyStroke>;