Add crypto to existing network class

This commit is contained in:
Keith Winstein
2011-08-04 04:52:47 -04:00
parent 215c75c6ea
commit 7824318c54
4 changed files with 96 additions and 72 deletions
+5 -2
View File
@@ -4,11 +4,14 @@
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include "crypto.hpp"
static void dos_detected( const char *expression, const char *file, int line, const char *function ) static void dos_detected( const char *expression, const char *file, int line, const char *function )
{ {
fprintf( stderr, "Illegal counterparty input (possible denial of service) in function %s at %s:%d, failed test: %s\n", char buffer[ 2048 ];
snprintf( buffer, 2048, "Illegal counterparty input (possible denial of service) in function %s at %s:%d, failed test: %s\n",
function, file, line, expression ); function, file, line, expression );
exit( 1 ); throw CryptoException( buffer );
} }
#define dos_assert(expr) \ #define dos_assert(expr) \
+50 -43
View File
@@ -6,30 +6,25 @@
#include "dos_assert.hpp" #include "dos_assert.hpp"
#include "network.hpp" #include "network.hpp"
#include "crypto.hpp"
using namespace std; using namespace std;
using namespace Network; using namespace Network;
template <class Payload> template <class Payload>
Flow<Payload>::Packet::DecodingCache::DecodingCache( string coded_packet ) Flow<Payload>::Packet::DecodingCache::DecodingCache( string coded_packet, Session *session )
: direction( TO_CLIENT ), seq( -1 ), payload_string() : direction( TO_CLIENT ), seq( -1 ), payload_string()
{ {
dos_assert( coded_packet.size() >= 8 ); Message message = session->decrypt( coded_packet );
/* Read in sequence number and direction */ direction = (message.nonce.val() & 8000000000000000) ? TO_CLIENT : TO_SERVER;
string seq_string( coded_packet.begin(), coded_packet.begin() + 8 ); seq = message.nonce.val() & 0x7FFFFFFFFFFFFFFF;
uint64_t *network_order_seq = (uint64_t *)seq_string.data(); payload_string = message.text;
uint64_t direction_seq = be64toh( *network_order_seq );
direction = (direction_seq & 8000000000000000) ? TO_CLIENT : TO_SERVER;
seq = direction_seq & 0x7FFFFFFFFFFFFFFF;
/* Read in payload */
payload_string = string( coded_packet.begin() + 8, coded_packet.end() );
} }
template <class Payload> template <class Payload>
Flow<Payload>::Packet::Packet( string coded_packet ) Flow<Payload>::Packet::Packet( string coded_packet, Session *session )
: decoding_cache( coded_packet ), : decoding_cache( coded_packet, session ),
seq( decoding_cache.seq ), seq( decoding_cache.seq ),
direction( decoding_cache.direction ), direction( decoding_cache.direction ),
payload( decoding_cache.payload_string ) payload( decoding_cache.payload_string )
@@ -38,14 +33,11 @@ Flow<Payload>::Packet::Packet( string coded_packet )
} }
template <class Payload> template <class Payload>
string Flow<Payload>::Packet::tostring( void ) string Flow<Payload>::Packet::tostring( Session *session )
{ {
uint64_t direction_seq = (uint64_t( direction == TO_CLIENT ) << 63) | (seq & 0x7FFFFFFFFFFFFFFF); uint64_t direction_seq = (uint64_t( direction == TO_CLIENT ) << 63) | (seq & 0x7FFFFFFFFFFFFFFF);
uint64_t network_order_seq = htobe64( direction_seq );
const char *seq_str = (const char *)&network_order_seq;
string seq_string( seq_str, 8 ); /* necessary in case there is a zero byte */
return seq_string + payload.tostring(); return session->encrypt( Message( direction_seq, payload.tostring() ) );
} }
template <class Payload> template <class Payload>
@@ -55,15 +47,8 @@ typename Flow<Payload>::Packet Flow<Payload>::new_packet( Payload &s_payload )
} }
template <class Outgoing, class Incoming> template <class Outgoing, class Incoming>
Connection<Outgoing, Incoming>::Connection( bool s_server ) void Connection<Outgoing, Incoming>::setup( void )
: flow( s_server ? TO_CLIENT : TO_SERVER ),
sock( -1 ),
remote_addr(),
server( s_server ),
attached( false ),
MTU( RECEIVE_MTU )
{ {
/* create socket */ /* create socket */
sock = socket( AF_INET, SOCK_DGRAM, 0 ); sock = socket( AF_INET, SOCK_DGRAM, 0 );
if ( sock < 0 ) { if ( sock < 0 ) {
@@ -92,26 +77,31 @@ Connection<Outgoing, Incoming>::Connection( bool s_server )
} }
template <class Outgoing, class Incoming> template <class Outgoing, class Incoming>
void Connection<Outgoing, Incoming>::update_MTU( void ) Connection<Outgoing, Incoming>::Connection() /* server */
: sock( -1 ),
remote_addr(),
server( true ),
attached( false ),
MTU( RECEIVE_MTU ),
key(),
session( key ),
flow( TO_CLIENT, &session )
{ {
socklen_t optlen = sizeof( MTU ); setup();
if ( getsockopt( sock, IPPROTO_IP, IP_MTU, &MTU, &optlen ) < 0 ) {
perror( "getsockopt" );
exit( 1 );
}
if ( optlen != sizeof( MTU ) ) {
fprintf( stderr, "Error getting path MTU.\n" );
exit( 1 );
}
fprintf( stderr, "Path MTU: %d\n", MTU );
} }
template <class Outgoing, class Incoming> template <class Outgoing, class Incoming>
void Connection<Outgoing, Incoming>::client_connect( const char *ip, int port ) Connection<Outgoing, Incoming>::Connection( const char *key_str, const char *ip, int port ) /* client */
: sock( -1 ),
remote_addr(),
server( false ),
attached( false ),
MTU( RECEIVE_MTU ),
key( key_str ),
session( key ),
flow( TO_SERVER, &session )
{ {
assert( !server ); setup();
/* associate socket with remote host and port */ /* associate socket with remote host and port */
remote_addr.sin_family = AF_INET; remote_addr.sin_family = AF_INET;
@@ -129,12 +119,29 @@ void Connection<Outgoing, Incoming>::client_connect( const char *ip, int port )
attached = true; attached = true;
} }
template <class Outgoing, class Incoming>
void Connection<Outgoing, Incoming>::update_MTU( void )
{
socklen_t optlen = sizeof( MTU );
if ( getsockopt( sock, IPPROTO_IP, IP_MTU, &MTU, &optlen ) < 0 ) {
perror( "getsockopt" );
exit( 1 );
}
if ( optlen != sizeof( MTU ) ) {
fprintf( stderr, "Error getting path MTU.\n" );
exit( 1 );
}
fprintf( stderr, "Path MTU: %d\n", MTU );
}
template <class Outgoing, class Incoming> template <class Outgoing, class Incoming>
bool Connection<Outgoing, Incoming>::send( Outgoing &s ) bool Connection<Outgoing, Incoming>::send( Outgoing &s )
{ {
assert( attached ); assert( attached );
string p = flow.new_packet( s ).tostring(); string p = flow.new_packet( s ).tostring( &session );
ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0, ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0,
(sockaddr *)&remote_addr, sizeof( remote_addr ) ); (sockaddr *)&remote_addr, sizeof( remote_addr ) );
@@ -173,7 +180,7 @@ Incoming Connection<Outgoing, Incoming>::recv( void )
exit( 1 ); exit( 1 );
} }
typename Flow<Incoming>::Packet p( string( buf, received_len ) ); typename Flow<Incoming>::Packet p( string( buf, received_len ), &session );
dos_assert( p.direction == (server ? TO_SERVER : TO_CLIENT) ); /* prevent malicious playback to sender */ dos_assert( p.direction == (server ? TO_SERVER : TO_CLIENT) ); /* prevent malicious playback to sender */
/* server auto-adjusts to client */ /* server auto-adjusts to client */
+21 -11
View File
@@ -7,6 +7,8 @@
#include <netinet/in.h> #include <netinet/in.h>
#include <string> #include <string>
#include "crypto.hpp"
using namespace std; using namespace std;
namespace Network { namespace Network {
@@ -27,7 +29,7 @@ namespace Network {
uint64_t seq; uint64_t seq;
string payload_string; string payload_string;
DecodingCache( string coded_packet ); DecodingCache( string coded_packet, Session *session );
DecodingCache() : direction( TO_CLIENT ), seq( -1 ), payload_string() {} DecodingCache() : direction( TO_CLIENT ), seq( -1 ), payload_string() {}
}; };
@@ -42,16 +44,17 @@ namespace Network {
: decoding_cache(), seq( s_seq ), direction( s_direction ), payload( s_payload ) : decoding_cache(), seq( s_seq ), direction( s_direction ), payload( s_payload )
{} {}
Packet( string coded_packet ); Packet( string coded_packet, Session *session );
string tostring( void ); string tostring( Session *session );
}; };
uint64_t next_seq; uint64_t next_seq;
const Direction direction; Direction direction;
Session *session;
Flow( Direction s_direction ) Flow( Direction s_direction, Session *s_session )
: next_seq( 0 ), direction( s_direction ) : next_seq( 0 ), direction( s_direction ), session( s_session )
{} {}
Packet new_packet( Payload &s_payload ); Packet new_packet( Payload &s_payload );
@@ -62,8 +65,6 @@ namespace Network {
private: private:
static const int RECEIVE_MTU = 2048; static const int RECEIVE_MTU = 2048;
Flow<Outgoing> flow;
int sock; int sock;
struct sockaddr_in remote_addr; struct sockaddr_in remote_addr;
@@ -72,16 +73,25 @@ namespace Network {
int MTU; int MTU;
Base64Key key;
Session session;
Flow<Outgoing> flow;
void update_MTU( void ); void update_MTU( void );
public: void setup( void );
Connection( bool s_server );
public:
Connection();
Connection( const char *key_str, const char *ip, int port );
void client_connect( const char *ip, int port );
bool send( Outgoing &s ); bool send( Outgoing &s );
Incoming recv( void ); Incoming recv( void );
int fd( void ) { return sock; } int fd( void ) { return sock; }
int port( void ); int port( void );
int get_MTU( void ) { return MTU; }
string get_key( void ) { return key.printable_key(); }
}; };
} }
+14 -10
View File
@@ -4,26 +4,30 @@
int main( int argc, char *argv[] ) int main( int argc, char *argv[] )
{ {
bool server = true; bool server = true;
char *key;
char *ip; char *ip;
int port; int port;
Network::Connection<KeyStroke, KeyStroke> *n;
if ( argc > 1 ) { if ( argc > 1 ) {
server = false; server = false;
/* client */
ip = argv[ 1 ]; key = argv[ 1 ];
port = atoi( argv[ 2 ] ); ip = argv[ 2 ];
port = atoi( argv[ 3 ] );
n = new Network::Connection<KeyStroke, KeyStroke>( key, ip, port );
} else {
n = new Network::Connection<KeyStroke, KeyStroke>();
} }
Network::Connection<KeyStroke, KeyStroke> n( server ); fprintf( stderr, "Port bound is %d, key is %s\n", n->port(), n->get_key().c_str() );
fprintf( stderr, "Port bound is %d\n", n.port() );
if ( !server ) {
n.client_connect( ip, port );
}
if ( server ) { if ( server ) {
while ( true ) { while ( true ) {
KeyStroke s = n.recv(); KeyStroke s = n->recv();
fprintf( stderr, "Got KeyStroke: %c\n", s.letter ); fprintf( stderr, "Got KeyStroke: %c\n", s.letter );
} }
@@ -33,7 +37,7 @@ int main( int argc, char *argv[] )
KeyStroke t( string( "x", 1 ) ); KeyStroke t( string( "x", 1 ) );
n.send( t ); n->send( t );
} }
} }
} }