Add crypto to existing network class
This commit is contained in:
+5
-2
@@ -4,11 +4,14 @@
|
|||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
#include "crypto.hpp"
|
||||||
|
|
||||||
static void dos_detected( const char *expression, const char *file, int line, const char *function )
|
static void dos_detected( const char *expression, const char *file, int line, const char *function )
|
||||||
{
|
{
|
||||||
fprintf( stderr, "Illegal counterparty input (possible denial of service) in function %s at %s:%d, failed test: %s\n",
|
char buffer[ 2048 ];
|
||||||
|
snprintf( buffer, 2048, "Illegal counterparty input (possible denial of service) in function %s at %s:%d, failed test: %s\n",
|
||||||
function, file, line, expression );
|
function, file, line, expression );
|
||||||
exit( 1 );
|
throw CryptoException( buffer );
|
||||||
}
|
}
|
||||||
|
|
||||||
#define dos_assert(expr) \
|
#define dos_assert(expr) \
|
||||||
|
|||||||
+50
-43
@@ -6,30 +6,25 @@
|
|||||||
|
|
||||||
#include "dos_assert.hpp"
|
#include "dos_assert.hpp"
|
||||||
#include "network.hpp"
|
#include "network.hpp"
|
||||||
|
#include "crypto.hpp"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace Network;
|
using namespace Network;
|
||||||
|
|
||||||
template <class Payload>
|
template <class Payload>
|
||||||
Flow<Payload>::Packet::DecodingCache::DecodingCache( string coded_packet )
|
Flow<Payload>::Packet::DecodingCache::DecodingCache( string coded_packet, Session *session )
|
||||||
: direction( TO_CLIENT ), seq( -1 ), payload_string()
|
: direction( TO_CLIENT ), seq( -1 ), payload_string()
|
||||||
{
|
{
|
||||||
dos_assert( coded_packet.size() >= 8 );
|
Message message = session->decrypt( coded_packet );
|
||||||
|
|
||||||
/* Read in sequence number and direction */
|
direction = (message.nonce.val() & 8000000000000000) ? TO_CLIENT : TO_SERVER;
|
||||||
string seq_string( coded_packet.begin(), coded_packet.begin() + 8 );
|
seq = message.nonce.val() & 0x7FFFFFFFFFFFFFFF;
|
||||||
uint64_t *network_order_seq = (uint64_t *)seq_string.data();
|
payload_string = message.text;
|
||||||
uint64_t direction_seq = be64toh( *network_order_seq );
|
|
||||||
direction = (direction_seq & 8000000000000000) ? TO_CLIENT : TO_SERVER;
|
|
||||||
seq = direction_seq & 0x7FFFFFFFFFFFFFFF;
|
|
||||||
|
|
||||||
/* Read in payload */
|
|
||||||
payload_string = string( coded_packet.begin() + 8, coded_packet.end() );
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Payload>
|
template <class Payload>
|
||||||
Flow<Payload>::Packet::Packet( string coded_packet )
|
Flow<Payload>::Packet::Packet( string coded_packet, Session *session )
|
||||||
: decoding_cache( coded_packet ),
|
: decoding_cache( coded_packet, session ),
|
||||||
seq( decoding_cache.seq ),
|
seq( decoding_cache.seq ),
|
||||||
direction( decoding_cache.direction ),
|
direction( decoding_cache.direction ),
|
||||||
payload( decoding_cache.payload_string )
|
payload( decoding_cache.payload_string )
|
||||||
@@ -38,14 +33,11 @@ Flow<Payload>::Packet::Packet( string coded_packet )
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <class Payload>
|
template <class Payload>
|
||||||
string Flow<Payload>::Packet::tostring( void )
|
string Flow<Payload>::Packet::tostring( Session *session )
|
||||||
{
|
{
|
||||||
uint64_t direction_seq = (uint64_t( direction == TO_CLIENT ) << 63) | (seq & 0x7FFFFFFFFFFFFFFF);
|
uint64_t direction_seq = (uint64_t( direction == TO_CLIENT ) << 63) | (seq & 0x7FFFFFFFFFFFFFFF);
|
||||||
uint64_t network_order_seq = htobe64( direction_seq );
|
|
||||||
const char *seq_str = (const char *)&network_order_seq;
|
|
||||||
string seq_string( seq_str, 8 ); /* necessary in case there is a zero byte */
|
|
||||||
|
|
||||||
return seq_string + payload.tostring();
|
return session->encrypt( Message( direction_seq, payload.tostring() ) );
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Payload>
|
template <class Payload>
|
||||||
@@ -55,15 +47,8 @@ typename Flow<Payload>::Packet Flow<Payload>::new_packet( Payload &s_payload )
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <class Outgoing, class Incoming>
|
template <class Outgoing, class Incoming>
|
||||||
Connection<Outgoing, Incoming>::Connection( bool s_server )
|
void Connection<Outgoing, Incoming>::setup( void )
|
||||||
: flow( s_server ? TO_CLIENT : TO_SERVER ),
|
|
||||||
sock( -1 ),
|
|
||||||
remote_addr(),
|
|
||||||
server( s_server ),
|
|
||||||
attached( false ),
|
|
||||||
MTU( RECEIVE_MTU )
|
|
||||||
{
|
{
|
||||||
|
|
||||||
/* create socket */
|
/* create socket */
|
||||||
sock = socket( AF_INET, SOCK_DGRAM, 0 );
|
sock = socket( AF_INET, SOCK_DGRAM, 0 );
|
||||||
if ( sock < 0 ) {
|
if ( sock < 0 ) {
|
||||||
@@ -92,26 +77,31 @@ Connection<Outgoing, Incoming>::Connection( bool s_server )
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <class Outgoing, class Incoming>
|
template <class Outgoing, class Incoming>
|
||||||
void Connection<Outgoing, Incoming>::update_MTU( void )
|
Connection<Outgoing, Incoming>::Connection() /* server */
|
||||||
|
: sock( -1 ),
|
||||||
|
remote_addr(),
|
||||||
|
server( true ),
|
||||||
|
attached( false ),
|
||||||
|
MTU( RECEIVE_MTU ),
|
||||||
|
key(),
|
||||||
|
session( key ),
|
||||||
|
flow( TO_CLIENT, &session )
|
||||||
{
|
{
|
||||||
socklen_t optlen = sizeof( MTU );
|
setup();
|
||||||
if ( getsockopt( sock, IPPROTO_IP, IP_MTU, &MTU, &optlen ) < 0 ) {
|
|
||||||
perror( "getsockopt" );
|
|
||||||
exit( 1 );
|
|
||||||
}
|
|
||||||
|
|
||||||
if ( optlen != sizeof( MTU ) ) {
|
|
||||||
fprintf( stderr, "Error getting path MTU.\n" );
|
|
||||||
exit( 1 );
|
|
||||||
}
|
|
||||||
|
|
||||||
fprintf( stderr, "Path MTU: %d\n", MTU );
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Outgoing, class Incoming>
|
template <class Outgoing, class Incoming>
|
||||||
void Connection<Outgoing, Incoming>::client_connect( const char *ip, int port )
|
Connection<Outgoing, Incoming>::Connection( const char *key_str, const char *ip, int port ) /* client */
|
||||||
|
: sock( -1 ),
|
||||||
|
remote_addr(),
|
||||||
|
server( false ),
|
||||||
|
attached( false ),
|
||||||
|
MTU( RECEIVE_MTU ),
|
||||||
|
key( key_str ),
|
||||||
|
session( key ),
|
||||||
|
flow( TO_SERVER, &session )
|
||||||
{
|
{
|
||||||
assert( !server );
|
setup();
|
||||||
|
|
||||||
/* associate socket with remote host and port */
|
/* associate socket with remote host and port */
|
||||||
remote_addr.sin_family = AF_INET;
|
remote_addr.sin_family = AF_INET;
|
||||||
@@ -129,12 +119,29 @@ void Connection<Outgoing, Incoming>::client_connect( const char *ip, int port )
|
|||||||
attached = true;
|
attached = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class Outgoing, class Incoming>
|
||||||
|
void Connection<Outgoing, Incoming>::update_MTU( void )
|
||||||
|
{
|
||||||
|
socklen_t optlen = sizeof( MTU );
|
||||||
|
if ( getsockopt( sock, IPPROTO_IP, IP_MTU, &MTU, &optlen ) < 0 ) {
|
||||||
|
perror( "getsockopt" );
|
||||||
|
exit( 1 );
|
||||||
|
}
|
||||||
|
|
||||||
|
if ( optlen != sizeof( MTU ) ) {
|
||||||
|
fprintf( stderr, "Error getting path MTU.\n" );
|
||||||
|
exit( 1 );
|
||||||
|
}
|
||||||
|
|
||||||
|
fprintf( stderr, "Path MTU: %d\n", MTU );
|
||||||
|
}
|
||||||
|
|
||||||
template <class Outgoing, class Incoming>
|
template <class Outgoing, class Incoming>
|
||||||
bool Connection<Outgoing, Incoming>::send( Outgoing &s )
|
bool Connection<Outgoing, Incoming>::send( Outgoing &s )
|
||||||
{
|
{
|
||||||
assert( attached );
|
assert( attached );
|
||||||
|
|
||||||
string p = flow.new_packet( s ).tostring();
|
string p = flow.new_packet( s ).tostring( &session );
|
||||||
|
|
||||||
ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0,
|
ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0,
|
||||||
(sockaddr *)&remote_addr, sizeof( remote_addr ) );
|
(sockaddr *)&remote_addr, sizeof( remote_addr ) );
|
||||||
@@ -173,7 +180,7 @@ Incoming Connection<Outgoing, Incoming>::recv( void )
|
|||||||
exit( 1 );
|
exit( 1 );
|
||||||
}
|
}
|
||||||
|
|
||||||
typename Flow<Incoming>::Packet p( string( buf, received_len ) );
|
typename Flow<Incoming>::Packet p( string( buf, received_len ), &session );
|
||||||
dos_assert( p.direction == (server ? TO_SERVER : TO_CLIENT) ); /* prevent malicious playback to sender */
|
dos_assert( p.direction == (server ? TO_SERVER : TO_CLIENT) ); /* prevent malicious playback to sender */
|
||||||
|
|
||||||
/* server auto-adjusts to client */
|
/* server auto-adjusts to client */
|
||||||
|
|||||||
+21
-11
@@ -7,6 +7,8 @@
|
|||||||
#include <netinet/in.h>
|
#include <netinet/in.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "crypto.hpp"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
namespace Network {
|
namespace Network {
|
||||||
@@ -27,7 +29,7 @@ namespace Network {
|
|||||||
uint64_t seq;
|
uint64_t seq;
|
||||||
string payload_string;
|
string payload_string;
|
||||||
|
|
||||||
DecodingCache( string coded_packet );
|
DecodingCache( string coded_packet, Session *session );
|
||||||
DecodingCache() : direction( TO_CLIENT ), seq( -1 ), payload_string() {}
|
DecodingCache() : direction( TO_CLIENT ), seq( -1 ), payload_string() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -42,16 +44,17 @@ namespace Network {
|
|||||||
: decoding_cache(), seq( s_seq ), direction( s_direction ), payload( s_payload )
|
: decoding_cache(), seq( s_seq ), direction( s_direction ), payload( s_payload )
|
||||||
{}
|
{}
|
||||||
|
|
||||||
Packet( string coded_packet );
|
Packet( string coded_packet, Session *session );
|
||||||
|
|
||||||
string tostring( void );
|
string tostring( Session *session );
|
||||||
};
|
};
|
||||||
|
|
||||||
uint64_t next_seq;
|
uint64_t next_seq;
|
||||||
const Direction direction;
|
Direction direction;
|
||||||
|
Session *session;
|
||||||
|
|
||||||
Flow( Direction s_direction )
|
Flow( Direction s_direction, Session *s_session )
|
||||||
: next_seq( 0 ), direction( s_direction )
|
: next_seq( 0 ), direction( s_direction ), session( s_session )
|
||||||
{}
|
{}
|
||||||
|
|
||||||
Packet new_packet( Payload &s_payload );
|
Packet new_packet( Payload &s_payload );
|
||||||
@@ -62,8 +65,6 @@ namespace Network {
|
|||||||
private:
|
private:
|
||||||
static const int RECEIVE_MTU = 2048;
|
static const int RECEIVE_MTU = 2048;
|
||||||
|
|
||||||
Flow<Outgoing> flow;
|
|
||||||
|
|
||||||
int sock;
|
int sock;
|
||||||
struct sockaddr_in remote_addr;
|
struct sockaddr_in remote_addr;
|
||||||
|
|
||||||
@@ -72,16 +73,25 @@ namespace Network {
|
|||||||
|
|
||||||
int MTU;
|
int MTU;
|
||||||
|
|
||||||
|
Base64Key key;
|
||||||
|
Session session;
|
||||||
|
|
||||||
|
Flow<Outgoing> flow;
|
||||||
|
|
||||||
void update_MTU( void );
|
void update_MTU( void );
|
||||||
|
|
||||||
public:
|
void setup( void );
|
||||||
Connection( bool s_server );
|
|
||||||
|
public:
|
||||||
|
Connection();
|
||||||
|
Connection( const char *key_str, const char *ip, int port );
|
||||||
|
|
||||||
void client_connect( const char *ip, int port );
|
|
||||||
bool send( Outgoing &s );
|
bool send( Outgoing &s );
|
||||||
Incoming recv( void );
|
Incoming recv( void );
|
||||||
int fd( void ) { return sock; }
|
int fd( void ) { return sock; }
|
||||||
int port( void );
|
int port( void );
|
||||||
|
int get_MTU( void ) { return MTU; }
|
||||||
|
string get_key( void ) { return key.printable_key(); }
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+14
-10
@@ -4,26 +4,30 @@
|
|||||||
int main( int argc, char *argv[] )
|
int main( int argc, char *argv[] )
|
||||||
{
|
{
|
||||||
bool server = true;
|
bool server = true;
|
||||||
|
char *key;
|
||||||
char *ip;
|
char *ip;
|
||||||
int port;
|
int port;
|
||||||
|
|
||||||
|
Network::Connection<KeyStroke, KeyStroke> *n;
|
||||||
|
|
||||||
if ( argc > 1 ) {
|
if ( argc > 1 ) {
|
||||||
server = false;
|
server = false;
|
||||||
|
/* client */
|
||||||
|
|
||||||
ip = argv[ 1 ];
|
key = argv[ 1 ];
|
||||||
port = atoi( argv[ 2 ] );
|
ip = argv[ 2 ];
|
||||||
|
port = atoi( argv[ 3 ] );
|
||||||
|
|
||||||
|
n = new Network::Connection<KeyStroke, KeyStroke>( key, ip, port );
|
||||||
|
} else {
|
||||||
|
n = new Network::Connection<KeyStroke, KeyStroke>();
|
||||||
}
|
}
|
||||||
|
|
||||||
Network::Connection<KeyStroke, KeyStroke> n( server );
|
fprintf( stderr, "Port bound is %d, key is %s\n", n->port(), n->get_key().c_str() );
|
||||||
fprintf( stderr, "Port bound is %d\n", n.port() );
|
|
||||||
|
|
||||||
if ( !server ) {
|
|
||||||
n.client_connect( ip, port );
|
|
||||||
}
|
|
||||||
|
|
||||||
if ( server ) {
|
if ( server ) {
|
||||||
while ( true ) {
|
while ( true ) {
|
||||||
KeyStroke s = n.recv();
|
KeyStroke s = n->recv();
|
||||||
|
|
||||||
fprintf( stderr, "Got KeyStroke: %c\n", s.letter );
|
fprintf( stderr, "Got KeyStroke: %c\n", s.letter );
|
||||||
}
|
}
|
||||||
@@ -33,7 +37,7 @@ int main( int argc, char *argv[] )
|
|||||||
|
|
||||||
KeyStroke t( string( "x", 1 ) );
|
KeyStroke t( string( "x", 1 ) );
|
||||||
|
|
||||||
n.send( t );
|
n->send( t );
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user