diff --git a/network.cpp b/network.cpp index 8ffb896..442a5ed 100644 --- a/network.cpp +++ b/network.cpp @@ -82,6 +82,7 @@ Connection::Connection() /* server */ remote_addr(), server( true ), attached( false ), + MTU( SEND_MTU ), key(), session( key ), direction( TO_CLIENT ), @@ -111,6 +112,7 @@ Connection::Connection( const char *key_str, const char *ip, int port ) /* clien remote_addr(), server( false ), attached( false ), + MTU( SEND_MTU ), key( key_str ), session( key ), direction( TO_SERVER ), @@ -134,12 +136,6 @@ Connection::Connection( const char *key_str, const char *ip, int port ) /* clien throw NetworkException( buffer, saved_errno ); } - /* - if ( connect( sock, (sockaddr *)&remote_addr, sizeof( remote_addr ) ) < 0 ) { - throw NetworkException( "connect", errno ); - } - */ - attached = true; } @@ -154,7 +150,10 @@ void Connection::send( string s ) ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0, (sockaddr *)&remote_addr, sizeof( remote_addr ) ); - if ( bytes_sent == static_cast( p.size() ) ) { + if ( (bytes_sent < 0) && (errno == EMSGSIZE) ) { + update_MTU(); + throw NetworkException( "Path MTU Discovery", EMSGSIZE ); + } else if ( bytes_sent == static_cast( p.size() ) ) { return; } else { throw NetworkException( "sendto", errno ); @@ -289,3 +288,50 @@ uint64_t Connection::timeout( void ) const } return RTO; } + +class Socket { +public: + int fd; + + Socket( int domain, int type, int protocol ) + : fd( socket( domain, type, protocol ) ) + { + if ( fd < 0 ) { + throw NetworkException( "socket", errno ); + } + } + + ~Socket() + { + if ( close( fd ) < 0 ) { + throw NetworkException( "close", errno ); + } + } +}; + +void Connection::update_MTU( void ) +{ + if ( !attached ) { + return; + } + + /* We don't want to use our main socket because we don't want to have to connect it */ + Socket path_MTU_socket( AF_INET, SOCK_DGRAM, 0 ); + + /* Connect socket so we can retrieve path MTU */ + if ( connect( path_MTU_socket.fd, (sockaddr *)&remote_addr, sizeof( remote_addr ) ) < 0 ) { + throw NetworkException( "connect", errno ); + } + + int PMTU; + socklen_t optlen = sizeof( PMTU ); + if ( getsockopt( path_MTU_socket.fd, IPPROTO_IP, IP_MTU, &PMTU, &optlen ) < 0 ) { + throw NetworkException( "getsockopt", errno ); + } + + if ( optlen != sizeof( PMTU ) ) { + throw NetworkException( "Error getting path MTU", errno ); + } + + MTU = max( PMTU, SEND_MTU ); +} diff --git a/network.hpp b/network.hpp index 9d8f7d8..0bac534 100644 --- a/network.hpp +++ b/network.hpp @@ -63,6 +63,8 @@ namespace Network { bool server; bool attached; + int MTU; + Base64Key key; Session session; @@ -80,6 +82,8 @@ namespace Network { Packet new_packet( string &s_payload ); + void update_MTU( void ); + public: Connection(); Connection( const char *key_str, const char *ip, int port ); @@ -87,7 +91,7 @@ namespace Network { void send( string s ); string recv( void ); int fd( void ) const { return sock; } - int get_MTU( void ) const { return SEND_MTU; } + int get_MTU( void ) const { return MTU; } int port( void ) const; string get_key( void ) const { return key.printable_key(); }