diff --git a/Makefile b/Makefile index 08e52e8..b962c2e 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,7 @@ repos = templates.rpo executables = parse termemu ntester CXX = g++ -CXXFLAGS = -g -O2 --std=c++0x -pedantic -Werror -Wall -Wextra -Weffc++ -fno-implicit-templates -fno-default-inline -pipe -D_FILE_OFFSET_BITS=64 -D_XOPEN_SOURCE=500 -D_GNU_SOURCE +CXXFLAGS = -g --std=c++0x -pedantic -Werror -Wall -Wextra -Weffc++ -fno-implicit-templates -fno-default-inline -pipe -D_FILE_OFFSET_BITS=64 -D_XOPEN_SOURCE=500 -D_GNU_SOURCE LIBS = -lutil all: $(executables) diff --git a/dos_assert.hpp b/dos_assert.hpp new file mode 100644 index 0000000..3b1acf2 --- /dev/null +++ b/dos_assert.hpp @@ -0,0 +1,19 @@ +#ifndef DOS_ASSERT_HPP +#define DOS_ASSERT_HPP + +#include +#include + +static void dos_detected( const char *expression, const char *file, int line, const char *function ) +{ + fprintf( stderr, "Illegal counterparty input (possible DOS) in function %s at %s:%d, failed test: %s\n", + function, file, line, expression ); + exit( 1 ); +} + +#define dos_assert(expr) \ + ((expr) \ + ? (void)0 \ + : dos_detected (__STRING(expr), __FILE__, __LINE__, __PRETTY_FUNCTION__ )) + +#endif diff --git a/keystroke.hpp b/keystroke.hpp new file mode 100644 index 0000000..924b177 --- /dev/null +++ b/keystroke.hpp @@ -0,0 +1,30 @@ +#ifndef KEYSTROKE_HPP +#define KEYSTROKE_HPP + +#include +#include + +class KeyStroke +{ +public: + char letter; + + string tostring( void ) + { + return string( &letter, 1 ); + }; + + KeyStroke( const string x ) + : letter() + { + assert( x.size() == 1 ); + + letter = x[ 0 ]; + } + + KeyStroke() + : letter( 0 ) + {} +}; + +#endif diff --git a/network.cpp b/network.cpp index b6a8c18..78871a2 100644 --- a/network.cpp +++ b/network.cpp @@ -1,32 +1,171 @@ +#include +#include +#include +#include #include +#include "dos_assert.hpp" #include "network.hpp" using namespace std; using namespace Network; template -Connection::Packet::Packet( int64_t s_seq, int64_t s_ack, Packet *s_previous, Payload s_state ) - : seq( s_seq ), ack( s_ack ), previous( s_previous ), state( s_state ) +Flow::Packet::DecodingCache::DecodingCache( string coded_packet ) + : direction( TO_CLIENT ), seq( -1 ), payload_string() { + dos_assert( coded_packet.size() >= 8 ); + /* Read in sequence number and direction */ + string seq_string( coded_packet.begin(), coded_packet.begin() + 8 ); + uint64_t *network_order_seq = (uint64_t *)seq_string.data(); + uint64_t direction_seq = be64toh( *network_order_seq ); + direction = (direction_seq & 8000000000000000) ? TO_CLIENT : TO_SERVER; + seq = direction_seq & 0x7FFFFFFFFFFFFFFF; + + /* Read in payload */ + payload_string = string( coded_packet.begin() + 8, coded_packet.end() ); } template -Connection::Packet::Packet( string wire ) +Flow::Packet::Packet( string coded_packet ) + : decoding_cache( coded_packet ), + seq( decoding_cache.seq ), + direction( decoding_cache.direction ), + payload( decoding_cache.payload_string ) +{} + +template +string Flow::Packet::tostring( void ) { - assert( wire.length() >= 32 ); + uint64_t direction_seq = ((uint64_t( direction == TO_CLIENT ) & 0x1) << 63) | (seq & 0x7FFFFFFFFFFFFFFF); + uint64_t network_order_seq = htobe64( direction_seq ); + const char *seq_str = (const char *)&network_order_seq; + string seq_string( seq_str, 8 ); /* necessary in case there is a zero byte */ - seq = be64toh( (uint64_t) *wire_c ); - reference_seq = be64toh( (uint64_t) *( wire_c + 8 ) ); - - tag = string( wire.begin() + 16, wire.begin() + 32 ); - - ack = be64toh( (uint64_t) *( wire_c + 16 ) ); + return seq_string + payload.tostring(); } template -Connection::Connection( const char *ip, const char *port, bool server ) +typename Flow::Packet Flow::new_packet( Payload &s_payload ) +{ + return Packet( next_seq++, direction, s_payload ); +} + +template +Connection::Connection( bool s_server ) + : flow( s_server ? TO_CLIENT : TO_SERVER ), + sock( -1 ), + remote_addr(), + server( s_server ), + attached( false ) { + /* create socket */ + sock = socket( AF_INET, SOCK_DGRAM, 0 ); + if ( sock < 0 ) { + perror( "socket" ); + exit( 1 ); + } + + /* Bind to free local port. + This usage does not seem to be endorsed by POSIX. */ + struct sockaddr_in local_addr; + local_addr.sin_family = AF_INET; + local_addr.sin_port = htons( 0 ); + local_addr.sin_addr.s_addr = INADDR_ANY; + if ( bind( sock, (sockaddr *)&local_addr, sizeof( local_addr ) ) < 0 ) { + perror( "bind" ); + exit( 1 ); + } +} + +template +void Connection::client_connect( const char *ip, int port ) +{ + assert( !server ); + + /* associate socket with remote host and port */ + remote_addr.sin_family = AF_INET; + remote_addr.sin_port = htons( port ); + if ( !inet_aton( ip, &remote_addr.sin_addr ) ) { + fprintf( stderr, "Bad IP address %s\n", ip ); + exit( 1 ); + } + + if ( connect( sock, (sockaddr *)&remote_addr, sizeof( remote_addr ) ) < 0 ) { + perror( "connect" ); + exit( 1 ); + } + + attached = true; +} + +template +void Connection::send( Outgoing &s ) +{ + assert( attached ); + + string p = flow.new_packet( s ).tostring(); + + if ( sendto( sock, p.data(), p.size(), 0, + (sockaddr *)&remote_addr, sizeof( remote_addr ) ) < 0 ) { + perror( "sendto" ); + exit( 1 ); + } +} + +template +Incoming Connection::recv( void ) +{ + struct sockaddr_in packet_remote_addr; + + char buf[ RECEIVE_MTU ]; + + socklen_t addrlen = sizeof( packet_remote_addr ); + + ssize_t received_len = recvfrom( sock, buf, RECEIVE_MTU, 0, (sockaddr *)&packet_remote_addr, &addrlen ); + + if ( received_len < 0 ) { + perror( "recvfrom" ); + exit( 1 ); + } + + if ( received_len > RECEIVE_MTU ) { + fprintf( stderr, "Received oversize datagram (size %d) and limit is %d.\n", + static_cast( received_len ), RECEIVE_MTU ); + exit( 1 ); + } + + typename Flow::Packet p( string( buf, received_len ) ); + dos_assert( p.direction == (server ? TO_SERVER : TO_CLIENT) ); /* prevent malicious playback to sender */ + + /* server auto-adjusts to client */ + if ( server ) { + attached = true; + + if ( (remote_addr.sin_addr.s_addr != packet_remote_addr.sin_addr.s_addr) + || (remote_addr.sin_port != packet_remote_addr.sin_port) ) { + remote_addr = packet_remote_addr; + fprintf( stderr, "Server now attached to client at %s:%d\n", + inet_ntoa( remote_addr.sin_addr ), + ntohs( remote_addr.sin_port ) ); + } + } + + return p.payload; +} + +template +int Connection::port( void ) +{ + struct sockaddr_in local_addr; + socklen_t addrlen = sizeof( local_addr ); + + if ( getsockname( sock, (sockaddr *)&local_addr, &addrlen ) < 0 ) { + perror( "getsockname" ); + exit( 1 ); + } + + return ntohs( local_addr.sin_port ); } diff --git a/network.hpp b/network.hpp index 867295f..f74403d 100644 --- a/network.hpp +++ b/network.hpp @@ -7,37 +7,78 @@ #include #include +using namespace std; + namespace Network { + enum Direction { + TO_SERVER = 0, + TO_CLIENT = 1 + }; + template - class Connection { - private: + class Flow { + public: class Packet { + private: + class DecodingCache + { + public: + Direction direction; + uint64_t seq; + string payload_string; + + DecodingCache( string coded_packet ); + DecodingCache() : direction( TO_CLIENT ), seq( -1 ), payload_string() {} + }; + + DecodingCache decoding_cache; + public: - int64_t seq; - int64_t reference_seq; - - std::string tag; - - int64_t ack; - - Payload state; - - Packet( int64_t s_seq, int64_t s_ack, Packet *s_previous, Payload s_state ); - Packet( std::string wire ); + uint64_t seq; + Direction direction; + Payload payload; + + Packet( uint64_t s_seq, Direction s_direction, Payload s_payload ) + : decoding_cache(), seq( s_seq ), direction( s_direction ), payload( s_payload ) + {} + + Packet( string coded_packet ); + + string tostring( void ); }; - int64_t next_seq; - int64_t next_ack; - int sequence_increment; + uint64_t next_seq; + const Direction direction; + int MTU; + + Flow( Direction s_direction ) + : next_seq( 0 ), direction( s_direction ), MTU( 2048 ) + {} + + Packet new_packet( Payload &s_payload ); + }; + + template + class Connection { + private: + static const int RECEIVE_MTU = 2048; + + Flow flow; int sock; - struct sockaddr_in addr; + struct sockaddr_in remote_addr; - std::deque send_queue; - std::deque recv_queue; + bool server; + bool attached; public: - Connection( const char *ip, const char *port, bool server ); + Connection( bool s_server ); + + void client_connect( const char *ip, int port ); + void send( Outgoing &s ); + Incoming recv( void ); + int fd( void ) { return sock; } + int port( void ); }; } diff --git a/ntester.cpp b/ntester.cpp index fbca0c7..9ea7fd9 100644 --- a/ntester.cpp +++ b/ntester.cpp @@ -1,12 +1,39 @@ #include "network.hpp" +#include "keystroke.hpp" -class KeyStroke +int main( int argc, char *argv[] ) { -public: - char letter; -}; + bool server = true; + char *ip; + int port; -int main( void ) -{ - Network::Connection n(); + if ( argc > 1 ) { + server = false; + + ip = argv[ 1 ]; + port = atoi( argv[ 2 ] ); + } + + Network::Connection n( server ); + fprintf( stderr, "Port bound is %d\n", n.port() ); + + if ( !server ) { + n.client_connect( ip, port ); + } + + if ( server ) { + while ( true ) { + KeyStroke s = n.recv(); + + fprintf( stderr, "Got KeyStroke: %c\n", s.letter ); + } + } else { + while( true ) { + sleep( 1 ); + + KeyStroke t( string( "x", 1 ) ); + + n.send( t ); + } + } } diff --git a/templates.cpp b/templates.cpp index d8b2a0b..441ec7c 100644 --- a/templates.cpp +++ b/templates.cpp @@ -6,6 +6,9 @@ #include "terminal.hpp" +#include "network.cpp" +#include "keystroke.hpp" + namespace Parser { class Action; } @@ -21,3 +24,4 @@ template class vector; template class vector; template class map; template class vector; +template class Network::Connection;