From c0092a6e7ee481fc897661ff107e3fa940547574 Mon Sep 17 00:00:00 2001 From: Keith Winstein Date: Fri, 5 Oct 2012 02:51:25 -0400 Subject: [PATCH] Hop new ports, but keep the old [for a minute, and up to 10 at a time] (One is silver and the other gold...) --- src/examples/ntester.cc | 37 +++++++-- src/frontend/mosh-server.cc | 9 ++- src/frontend/stmclient.cc | 32 ++++++-- src/network/network.cc | 134 +++++++++++++++++++++++++------ src/network/network.h | 34 +++++++- src/network/networktransport.h | 2 +- src/network/transportfragment.cc | 3 +- 7 files changed, 205 insertions(+), 46 deletions(-) diff --git a/src/examples/ntester.cc b/src/examples/ntester.cc index 0b8f673..25bef36 100644 --- a/src/examples/ntester.cc +++ b/src/examples/ntester.cc @@ -74,10 +74,14 @@ int main( int argc, char *argv[] ) if ( server ) { Select &sel = Select::get_instance(); - sel.add_fd( n->fd() ); uint64_t last_num = n->get_remote_state_num(); while ( true ) { try { + sel.clear_fds(); + std::vector< int > fd_list( n->fds() ); + assert( fd_list.size() == 1 ); /* servers don't hop */ + int network_fd = fd_list.back(); + sel.add_fd( network_fd ); if ( sel.select( n->wait_time() ) < 0 ) { perror( "select" ); exit( 1 ); @@ -85,7 +89,7 @@ int main( int argc, char *argv[] ) n->tick(); - if ( sel.read( n->fd() ) ) { + if ( sel.read( network_fd ) ) { n->recv(); if ( n->get_remote_state_num() != last_num ) { @@ -116,10 +120,18 @@ int main( int argc, char *argv[] ) } Select &sel = Select::get_instance(); - sel.add_fd( STDIN_FILENO ); - sel.add_fd( n->fd() ); while( true ) { + sel.clear_fds(); + sel.add_fd( STDIN_FILENO ); + + std::vector< int > fd_list( n->fds() ); + for ( std::vector< int >::const_iterator it = fd_list.begin(); + it != fd_list.end(); + it++ ) { + sel.add_fd( *it ); + } + try { if ( sel.select( n->wait_time() ) < 0 ) { perror( "select" ); @@ -133,7 +145,22 @@ int main( int argc, char *argv[] ) n->get_current_state().push_back( Parser::UserByte( x ) ); } - if ( sel.read( n->fd() ) ) { + bool network_ready_to_read = false; + for ( std::vector< int >::const_iterator it = fd_list.begin(); + it != fd_list.end(); + it++ ) { + if ( sel.read( *it ) ) { + /* packet received from the network */ + /* we only read one socket each run */ + network_ready_to_read = true; + } + + if ( sel.error( *it ) ) { + break; + } + } + + if ( network_ready_to_read ) { n->recv(); } } catch ( NetworkException e ) { diff --git a/src/frontend/mosh-server.cc b/src/frontend/mosh-server.cc index b118daa..d8be58e 100644 --- a/src/frontend/mosh-server.cc +++ b/src/frontend/mosh-server.cc @@ -534,7 +534,10 @@ void serve( int host_fd, Terminal::Complete &terminal, ServerConnection &network /* poll for events */ sel.clear_fds(); - sel.add_fd( network.fd() ); + std::vector< int > fd_list( network.fds() ); + assert( fd_list.size() == 1 ); /* servers don't hop */ + int network_fd = fd_list.back(); + sel.add_fd( network_fd ); sel.add_fd( host_fd ); int active_fds = sel.select( timeout ); @@ -546,7 +549,7 @@ void serve( int host_fd, Terminal::Complete &terminal, ServerConnection &network now = Network::timestamp(); uint64_t time_since_remote_state = now - network.get_latest_remote_state().timestamp; - if ( sel.read( network.fd() ) ) { + if ( sel.read( network_fd ) ) { /* packet received from the network */ network.recv(); @@ -652,7 +655,7 @@ void serve( int host_fd, Terminal::Complete &terminal, ServerConnection &network } } - if ( sel.error( network.fd() ) ) { + if ( sel.error( network_fd ) ) { /* network problem */ break; } diff --git a/src/frontend/stmclient.cc b/src/frontend/stmclient.cc index 18c9dbe..07238bc 100644 --- a/src/frontend/stmclient.cc +++ b/src/frontend/stmclient.cc @@ -324,7 +324,12 @@ void STMClient::main( void ) /* poll for events */ /* network->fd() can in theory change over time */ sel.clear_fds(); - sel.add_fd( network->fd() ); + std::vector< int > fd_list( network->fds() ); + for ( std::vector< int >::const_iterator it = fd_list.begin(); + it != fd_list.end(); + it++ ) { + sel.add_fd( *it ); + } sel.add_fd( STDIN_FILENO ); int active_fds = sel.select( wait_time ); @@ -333,8 +338,24 @@ void STMClient::main( void ) break; } - if ( sel.read( network->fd() ) ) { - /* packet received from the network */ + bool network_ready_to_read = false; + + for ( std::vector< int >::const_iterator it = fd_list.begin(); + it != fd_list.end(); + it++ ) { + if ( sel.read( *it ) ) { + /* packet received from the network */ + /* we only read one socket each run */ + network_ready_to_read = true; + } + + if ( sel.error( *it ) ) { + /* network problem */ + break; + } + } + + if ( network_ready_to_read ) { if ( !process_network_input() ) { return; } } @@ -370,11 +391,6 @@ void STMClient::main( void ) } } - if ( sel.error( network->fd() ) ) { - /* network problem */ - break; - } - if ( sel.error( STDIN_FILENO ) ) { /* user problem */ if ( !network->has_remote_addr() ) { diff --git a/src/network/network.cc b/src/network/network.cc index 60df783..782f78e 100644 --- a/src/network/network.cc +++ b/src/network/network.cc @@ -111,35 +111,54 @@ void Connection::hop_port( void ) { assert( !server ); - if ( close( sock ) < 0 ) { - throw NetworkException( "close", errno ); - } - setup(); + + prune_sockets(); } -void Connection::setup( void ) +void Connection::prune_sockets( void ) { - /* create socket */ - sock = socket( AF_INET, SOCK_DGRAM, 0 ); - if ( sock < 0 ) { - throw NetworkException( "socket", errno ); + /* don't keep old sockets if the new socket has been working for long enough */ + if ( socks.size() > 1 ) { + if ( timestamp() - last_port_choice > MAX_OLD_SOCKET_AGE ) { + int num_to_kill = socks.size() - 1; + for ( int i = 0; i < num_to_kill; i++ ) { + socks.pop_front(); + } + } + } else { + return; } - last_port_choice = timestamp(); + /* make sure we don't have too many receive sockets open */ + if ( socks.size() > MAX_PORTS_OPEN ) { + int num_to_kill = socks.size() - MAX_PORTS_OPEN; + for ( int i = 0; i < num_to_kill; i++ ) { + socks.pop_front(); + } + } +} + +Connection::Socket::Socket() + : _fd( socket( AF_INET, SOCK_DGRAM, 0 ) ), + _moved( false ) +{ + if ( _fd < 0 ) { + throw NetworkException( "socket", errno ); + } /* Disable path MTU discovery */ #ifdef HAVE_IP_MTU_DISCOVER char flag = IP_PMTUDISC_DONT; socklen_t optlen = sizeof( flag ); - if ( setsockopt( sock, IPPROTO_IP, IP_MTU_DISCOVER, &flag, optlen ) < 0 ) { + if ( setsockopt( _fd, IPPROTO_IP, IP_MTU_DISCOVER, &flag, optlen ) < 0 ) { throw NetworkException( "setsockopt", errno ); } #endif /* set diffserv values to AF42 + ECT */ uint8_t dscp = 0x92; - if ( setsockopt( sock, IPPROTO_IP, IP_TOS, &dscp, 1) < 0 ) { + if ( setsockopt( _fd, IPPROTO_IP, IP_TOS, &dscp, 1) < 0 ) { // perror( "setsockopt( IP_TOS )" ); } @@ -147,14 +166,35 @@ void Connection::setup( void ) #ifdef HAVE_IP_RECVTOS char tosflag = true; socklen_t tosoptlen = sizeof( tosflag ); - if ( setsockopt( sock, IPPROTO_IP, IP_RECVTOS, &tosflag, tosoptlen ) < 0 ) { + if ( setsockopt( _fd, IPPROTO_IP, IP_RECVTOS, &tosflag, tosoptlen ) < 0 ) { perror( "setsockopt( IP_RECVTOS )" ); } #endif } +void Connection::setup( void ) +{ + /* create socket */ + socks.push_back( Socket() ); + + last_port_choice = timestamp(); +} + +const std::vector< int > Connection::fds( void ) const +{ + std::vector< int > ret; + + for ( std::deque< Socket >::const_iterator it = socks.begin(); + it != socks.end(); + it++ ) { + ret.push_back( it->fd() ); + } + + return ret; +} + Connection::Connection( const char *desired_ip, const char *desired_port ) /* server */ - : sock( -1 ), + : socks(), has_remote_addr( false ), remote_addr(), server( true ), @@ -213,7 +253,7 @@ Connection::Connection( const char *desired_ip, const char *desired_port ) /* se /* try to bind to desired IP first */ if ( desired_ip_addr != INADDR_ANY ) { try { - if ( try_bind( sock, desired_ip_addr, desired_port_no ) ) { return; } + if ( try_bind( sock(), desired_ip_addr, desired_port_no ) ) { return; } } catch ( const NetworkException& e ) { struct in_addr sin_addr; sin_addr.s_addr = desired_ip_addr; @@ -225,7 +265,7 @@ Connection::Connection( const char *desired_ip, const char *desired_port ) /* se /* now try any local interface */ try { - if ( try_bind( sock, INADDR_ANY, desired_port_no ) ) { return; } + if ( try_bind( sock(), INADDR_ANY, desired_port_no ) ) { return; } } catch ( const NetworkException& e ) { fprintf( stderr, "Error binding to any interface: %s: %s\n", e.function.c_str(), strerror( e.the_errno ) ); @@ -266,7 +306,7 @@ bool Connection::try_bind( int socket, uint32_t addr, int port ) } Connection::Connection( const char *key_str, const char *ip, int port ) /* client */ - : sock( -1 ), + : socks(), has_remote_addr( false ), remote_addr(), server( false ), @@ -312,7 +352,7 @@ void Connection::send( string s ) string p = px.tostring( &session ); - ssize_t bytes_sent = sendto( sock, p.data(), p.size(), 0, + ssize_t bytes_sent = sendto( sock(), p.data(), p.size(), 0, (sockaddr *)&remote_addr, sizeof( remote_addr ) ); if ( bytes_sent == static_cast( p.size() ) ) { @@ -340,6 +380,34 @@ void Connection::send( string s ) } string Connection::recv( void ) +{ + assert( !socks.empty() ); + for ( std::deque< Socket >::const_iterator it = socks.begin(); + it != socks.end(); + it++ ) { + bool islast = (it + 1) == socks.end(); + string payload; + try { + payload = recv_one( it->fd(), !islast ); + } catch ( NetworkException & e ) { + if ( (e.the_errno == EAGAIN) + || (e.the_errno == EWOULDBLOCK) ) { + assert( !islast ); + continue; + } else { + throw e; + } + } + + /* succeeded */ + prune_sockets(); + return payload; + } + assert( false ); + return ""; +} + +string Connection::recv_one( int sock_to_recv, bool nonblocking ) { /* receive source address, ECN, and payload in msghdr structure */ struct sockaddr_in packet_remote_addr; @@ -366,10 +434,10 @@ string Connection::recv( void ) /* receive flags */ header.msg_flags = 0; - ssize_t received_len = recvmsg( sock, &header, 0 ); + ssize_t received_len = recvmsg( sock_to_recv, &header, nonblocking ? MSG_DONTWAIT : 0 ); if ( received_len < 0 ) { - throw NetworkException( "recvfrom", errno ); + throw NetworkException( "recvmsg", errno ); } if ( header.msg_flags & MSG_TRUNC ) { @@ -456,7 +524,7 @@ int Connection::port( void ) const struct sockaddr_in local_addr; socklen_t addrlen = sizeof( local_addr ); - if ( getsockname( sock, (sockaddr *)&local_addr, &addrlen ) < 0 ) { + if ( getsockname( sock(), (sockaddr *)&local_addr, &addrlen ) < 0 ) { throw NetworkException( "getsockname", errno ); } @@ -501,9 +569,27 @@ uint64_t Connection::timeout( void ) const return RTO; } -Connection::~Connection() +Connection::Socket::~Socket() { - if ( close( sock ) < 0 ) { - throw NetworkException( "close", errno ); + if ( !_moved ) { + if ( close( _fd ) < 0 ) { + throw NetworkException( "close", errno ); + } } } + +Connection::Socket::Socket( const Socket & other ) + : _fd( other._fd ), + _moved( false ) +{ + other.move(); +} + +const Connection::Socket & Connection::Socket::operator=( const Socket & other ) +{ + _fd = other._fd; + + other.move(); + + return *this; +} diff --git a/src/network/network.h b/src/network/network.h index 3d27039..f64b93b 100644 --- a/src/network/network.h +++ b/src/network/network.h @@ -39,6 +39,8 @@ #include #include #include +#include +#include #include "crypto.h" @@ -92,13 +94,32 @@ namespace Network { static const int PORT_RANGE_HIGH = 60999; static const unsigned int SERVER_ASSOCIATION_TIMEOUT = 40000; - static const unsigned int PORT_HOP_INTERVAL = 30000; + static const unsigned int PORT_HOP_INTERVAL = 10000; + + static const unsigned int MAX_PORTS_OPEN = 10; + static const unsigned int MAX_OLD_SOCKET_AGE = 60000; static const int CONGESTION_TIMESTAMP_PENALTY = 500; /* ms */ static bool try_bind( int socket, uint32_t addr, int port ); - int sock; + class Socket + { + private: + int _fd; + mutable bool _moved; + + public: + int fd( void ) const { assert( !_moved ); return _fd; } + void move( void ) const { assert( !_moved ); _moved = true; } + Socket(); + ~Socket(); + + Socket( const Socket & other ); + const Socket & operator=( const Socket & other ); + }; + + std::deque< Socket > socks; bool has_remote_addr; struct sockaddr_in remote_addr; @@ -134,14 +155,19 @@ namespace Network { void hop_port( void ); + int sock( void ) const { assert( !socks.empty() ); return socks.back().fd(); } + + void prune_sockets( void ); + + string recv_one( int sock_to_recv, bool nonblocking ); + public: Connection( const char *desired_ip, const char *desired_port ); /* server */ Connection( const char *key_str, const char *ip, int port ); /* client */ - ~Connection(); void send( string s ); string recv( void ); - int fd( void ) const { return sock; } + const std::vector< int > fds( void ) const; int get_MTU( void ) const { return MTU; } int port( void ) const; diff --git a/src/network/networktransport.h b/src/network/networktransport.h index cadaf82..99ece2c 100644 --- a/src/network/networktransport.h +++ b/src/network/networktransport.h @@ -103,7 +103,7 @@ namespace Network { const TimestampedState & get_latest_remote_state( void ) const { return received_states.back(); } - int fd( void ) const { return connection.fd(); } + const std::vector< int > fds( void ) const { return connection.fds(); } void set_verbose( void ) { sender.set_verbose(); verbose = true; } diff --git a/src/network/transportfragment.cc b/src/network/transportfragment.cc index ed199d6..bbd41eb 100644 --- a/src/network/transportfragment.cc +++ b/src/network/transportfragment.cc @@ -74,9 +74,10 @@ string Fragment::tostring( void ) Fragment::Fragment( string &x ) : id( -1 ), fragment_num( -1 ), final( false ), initialized( true ), - contents( x.begin() + frag_header_len, x.end() ) + contents() { assert( x.size() >= frag_header_len ); + contents = string( x.begin() + frag_header_len, x.end() ); uint64_t data64; uint16_t *data16 = (uint16_t *)x.data();