diff --git a/configure.ac b/configure.ac index b07291a..83d6ceb 100644 --- a/configure.ac +++ b/configure.ac @@ -170,7 +170,28 @@ AC_SEARCH_LIBS([socket], [socket]) AC_SEARCH_LIBS([inet_addr], [nsl]) # Checks for header files. -AC_CHECK_HEADERS([arpa/inet.h fcntl.h langinfo.h limits.h locale.h netinet/in.h stddef.h stdint.h inttypes.h stdlib.h string.h sys/ioctl.h sys/resource.h sys/socket.h sys/stat.h sys/time.h termios.h unistd.h wchar.h wctype.h], [], [AC_MSG_ERROR([Missing required header file.])]) +AC_CHECK_HEADERS([m4_normalize([ + fcntl.h + langinfo.h + limits.h + locale.h + netdb.h + netinet/in.h + stddef.h + stdint.h + inttypes.h + stdlib.h + string.h + sys/ioctl.h + sys/resource.h + sys/socket.h + sys/stat.h + sys/time.h + termios.h + unistd.h + wchar.h + wctype.h + ])], [], [AC_MSG_ERROR([Missing required header file.])]) AC_CHECK_HEADERS([pty.h util.h libutil.h paths.h]) AC_CHECK_HEADERS([endian.h sys/endian.h]) @@ -195,7 +216,30 @@ AC_TYPE_UINTPTR_T # Checks for library functions. AC_FUNC_FORK AC_FUNC_MBRTOWC -AC_CHECK_FUNCS([gettimeofday setrlimit inet_ntoa iswprint memchr memset nl_langinfo posix_memalign setenv setlocale sigaction socket strchr strdup strncasecmp strtok strerror strtol wcwidth cfmakeraw pselect]) +AC_CHECK_FUNCS([m4_normalize([ + gettimeofday + setrlimit + iswprint + memchr + memset + nl_langinfo + posix_memalign + setenv + setlocale + sigaction + socket + strchr + strdup + strncasecmp + strtok + strerror + strtol + wcwidth + cfmakeraw + pselect + getaddrinfo + getnameinfo + ])]) AC_SEARCH_LIBS([clock_gettime], [rt], [AC_DEFINE([HAVE_CLOCK_GETTIME], [1], [Define if clock_gettime is available.])]) diff --git a/src/frontend/mosh-client.cc b/src/frontend/mosh-client.cc index e4f4b48..d3f9744 100644 --- a/src/frontend/mosh-client.cc +++ b/src/frontend/mosh-client.cc @@ -127,13 +127,6 @@ int main( int argc, char *argv[] ) desired_port = argv[ optind + 1 ]; /* Sanity-check arguments */ - if ( ip - && ( strspn( ip, "0123456789." ) != strlen( ip ) ) ) { - fprintf( stderr, "%s: Bad IP address (%s)\n\n", argv[ 0 ], ip ); - usage( argv[ 0 ] ); - exit( 1 ); - } - if ( desired_port && ( strspn( desired_port, "0123456789" ) != strlen( desired_port ) ) ) { fprintf( stderr, "%s: Bad UDP port (%s)\n\n", argv[ 0 ], desired_port ); diff --git a/src/frontend/mosh-server.cc b/src/frontend/mosh-server.cc index ee27fea..d91cbc7 100644 --- a/src/frontend/mosh-server.cc +++ b/src/frontend/mosh-server.cc @@ -50,8 +50,7 @@ #include #endif #include -#include -#include +#include #include #include @@ -227,13 +226,6 @@ int main( int argc, char *argv[] ) } /* Sanity-check arguments */ - if ( desired_ip - && ( strspn( desired_ip, "0123456789." ) != strlen( desired_ip ) ) ) { - fprintf( stderr, "%s: Bad IP address (%s)\n", argv[ 0 ], desired_ip ); - print_usage( argv[ 0 ] ); - exit( 1 ); - } - int dpl, dph; if ( desired_port && ! Connection::parse_portrange( desired_port, dpl, dph ) ) { fprintf( stderr, "%s: Bad UDP port range (%s)\n", argv[ 0 ], desired_port ); @@ -532,8 +524,8 @@ void serve( int host_fd, Terminal::Complete &terminal, ServerConnection &network #ifdef HAVE_UTEMPTER bool connected_utmp = false; - struct in_addr saved_addr; - saved_addr.s_addr = 0; + Addr saved_addr; + socklen_t saved_addr_len = 0; #endif while ( 1 ) { @@ -616,13 +608,24 @@ void serve( int host_fd, Terminal::Complete &terminal, ServerConnection &network #ifdef HAVE_UTEMPTER /* update utmp entry if we have become "connected" */ if ( (!connected_utmp) - || ( saved_addr.s_addr != network.get_remote_ip().s_addr ) ) { + || saved_addr_len != network.get_remote_addr_len() + || memcmp( &saved_addr, &network.get_remote_addr(), + saved_addr_len ) != 0 ) { utempter_remove_record( host_fd ); - saved_addr = network.get_remote_ip(); + saved_addr = network.get_remote_addr(); + saved_addr_len = network.get_remote_addr_len(); + + char host[ NI_MAXHOST ]; + int errcode = getnameinfo( &saved_addr.sa, saved_addr_len, + host, sizeof( host ), NULL, 0, + NI_NUMERICHOST ); + if ( errcode != 0 ) { + throw NetworkException( std::string( "serve: getnameinfo: " ) + gai_strerror( errcode ), 0 ); + } char tmp[ 64 ]; - snprintf( tmp, 64, "%s via mosh [%d]", inet_ntoa( saved_addr ), getpid() ); + snprintf( tmp, 64, "%s via mosh [%d]", host, getpid() ); utempter_add_record( host_fd, tmp ); connected_utmp = true; diff --git a/src/network/network.cc b/src/network/network.cc index a214b1a..08b263a 100644 --- a/src/network/network.cc +++ b/src/network/network.cc @@ -37,13 +37,14 @@ #ifdef HAVE_SYS_UIO_H #include #endif +#include #include -#include #include #include #include #include "dos_assert.h" +#include "fatal_assert.h" #include "byteorder.h" #include "network.h" #include "crypto.h" @@ -119,7 +120,8 @@ void Connection::hop_port( void ) assert( !server ); setup(); - socks.push_back( Socket() ); + assert( remote_addr_len != 0 ); + socks.push_back( Socket( remote_addr.sa.sa_family ) ); prune_sockets(); } @@ -147,8 +149,8 @@ void Connection::prune_sockets( void ) } } -Connection::Socket::Socket() - : _fd( socket( AF_INET, SOCK_DGRAM, 0 ) ) +Connection::Socket::Socket( int family ) + : _fd( socket( family, SOCK_DGRAM, 0 ) ) { if ( _fd < 0 ) { throw NetworkException( "socket", errno ); @@ -197,10 +199,28 @@ const std::vector< int > Connection::fds( void ) const return ret; } +class AddrInfo { +public: + struct addrinfo *res; + AddrInfo( const char *node, const char *service, + const struct addrinfo *hints ) : + res( NULL ) { + int errcode = getaddrinfo( node, service, hints, &res ); + if ( errcode != 0 ) { + throw NetworkException( std::string( "Bad IP address (" ) + node + "): " + gai_strerror( errcode ), 0 ); + } + } + ~AddrInfo() { freeaddrinfo(res); } +private: + AddrInfo(const AddrInfo &); + AddrInfo &operator=(const AddrInfo &); +}; + Connection::Connection( const char *desired_ip, const char *desired_port ) /* server */ : socks(), has_remote_addr( false ), remote_addr(), + remote_addr_len( 0 ), server( true ), MTU( DEFAULT_SEND_MTU ), key(), @@ -235,33 +255,20 @@ Connection::Connection( const char *desired_ip, const char *desired_port ) /* se throw NetworkException("Invalid port range", 0); } - /* convert desired IP */ - uint32_t desired_ip_addr = INADDR_ANY; - - if ( desired_ip ) { - struct in_addr sin_addr; - if ( inet_aton( desired_ip, &sin_addr ) == 0 ) { - throw NetworkException( "Invalid IP address", errno ); - } - desired_ip_addr = sin_addr.s_addr; - } - /* try to bind to desired IP first */ - if ( desired_ip_addr != INADDR_ANY ) { + if ( desired_ip ) { try { - if ( try_bind( desired_ip_addr, desired_port_low, desired_port_high ) ) { return; } + if ( try_bind( desired_ip, desired_port_low, desired_port_high ) ) { return; } } catch ( const NetworkException& e ) { - struct in_addr sin_addr; - sin_addr.s_addr = desired_ip_addr; fprintf( stderr, "Error binding to IP %s: %s: %s\n", - inet_ntoa( sin_addr ), + desired_ip, e.function.c_str(), strerror( e.the_errno ) ); } } /* now try any local interface */ try { - if ( try_bind( INADDR_ANY, desired_port_low, desired_port_high ) ) { return; } + if ( try_bind( NULL, desired_port_low, desired_port_high ) ) { return; } } catch ( const NetworkException& e ) { fprintf( stderr, "Error binding to any interface: %s: %s\n", e.function.c_str(), strerror( e.the_errno ) ); @@ -272,11 +279,18 @@ Connection::Connection( const char *desired_ip, const char *desired_port ) /* se throw NetworkException( "Could not bind", errno ); } -bool Connection::try_bind( uint32_t addr, int port_low, int port_high ) +bool Connection::try_bind( const char *addr, int port_low, int port_high ) { - struct sockaddr_in local_addr; - local_addr.sin_family = AF_INET; - local_addr.sin_addr.s_addr = addr; + struct addrinfo hints; + memset( &hints, 0, sizeof( hints ) ); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_DGRAM; + hints.ai_flags = AI_PASSIVE | AI_NUMERICHOST | AI_NUMERICSERV; + AddrInfo ai( addr, 0, &hints ); + + Addr local_addr; + socklen_t local_addr_len = ai.res->ai_addrlen; + memcpy( &local_addr.sa, ai.res->ai_addr, local_addr_len ); int search_low = PORT_RANGE_LOW, search_high = PORT_RANGE_HIGH; @@ -287,18 +301,34 @@ bool Connection::try_bind( uint32_t addr, int port_low, int port_high ) search_high = port_high; } - socks.push_back( Socket() ); + socks.push_back( Socket( local_addr.sa.sa_family ) ); for ( int i = search_low; i <= search_high; i++ ) { - local_addr.sin_port = htons( i ); + switch (local_addr.sa.sa_family) { + case AF_INET: + local_addr.sin.sin_port = htons( i ); + break; + case AF_INET6: + local_addr.sin6.sin6_port = htons( i ); + break; + default: + throw NetworkException( "Unknown address family", 0 ); + } - if ( bind( sock(), (sockaddr *)&local_addr, sizeof( local_addr ) ) == 0 ) { + if ( bind( sock(), &local_addr.sa, local_addr_len ) == 0 ) { return true; } else if ( i == search_high ) { /* last port to search */ - fprintf( stderr, "Failed binding to %s:%d\n", - inet_ntoa( local_addr.sin_addr ), - ntohs( local_addr.sin_port ) ); + int saved_errno = errno; socks.pop_back(); - throw NetworkException( "bind", errno ); + char host[ NI_MAXHOST ], serv[ NI_MAXSERV ]; + int errcode = getnameinfo( &local_addr.sa, local_addr_len, + host, sizeof( host ), serv, sizeof( serv ), + NI_DGRAM | NI_NUMERICHOST | NI_NUMERICSERV ); + if ( errcode != 0 ) { + throw NetworkException( std::string( "bind: getnameinfo: " ) + gai_strerror( errcode ), 0 ); + } + fprintf( stderr, "Failed binding to %s:%s\n", + host, serv ); + throw NetworkException( "bind", saved_errno ); } } @@ -310,6 +340,7 @@ Connection::Connection( const char *key_str, const char *ip, const char *port ) : socks(), has_remote_addr( false ), remote_addr(), + remote_addr_len( 0 ), server( false ), MTU( DEFAULT_SEND_MTU ), key( key_str ), @@ -331,18 +362,19 @@ Connection::Connection( const char *key_str, const char *ip, const char *port ) setup(); /* associate socket with remote host and port */ - remote_addr.sin_family = AF_INET; - remote_addr.sin_port = htons( myatoi( port ) ); - if ( !inet_aton( ip, &remote_addr.sin_addr ) ) { - int saved_errno = errno; - char buffer[ 2048 ]; - snprintf( buffer, 2048, "Bad IP address (%s)", ip ); - throw NetworkException( buffer, saved_errno ); - } + struct addrinfo hints; + memset( &hints, 0, sizeof( hints ) ); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_DGRAM; + hints.ai_flags = AI_NUMERICHOST | AI_NUMERICSERV; + AddrInfo ai( ip, port, &hints ); + fatal_assert( ai.res->ai_addrlen <= sizeof( remote_addr ) ); + remote_addr_len = ai.res->ai_addrlen; + memcpy( &remote_addr.sa, ai.res->ai_addr, remote_addr_len ); has_remote_addr = true; - socks.push_back( Socket() ); + socks.push_back( Socket( remote_addr.sa.sa_family ) ); } void Connection::send( string s ) @@ -356,7 +388,7 @@ void Connection::send( string s ) string p = px.tostring( &session ); ssize_t bytes_sent = sendto( sock(), p.data(), p.size(), MSG_DONTWAIT, - (sockaddr *)&remote_addr, sizeof( remote_addr ) ); + &remote_addr.sa, remote_addr_len ); if ( bytes_sent == static_cast( p.size() ) ) { have_send_exception = false; @@ -417,7 +449,7 @@ string Connection::recv( void ) string Connection::recv_one( int sock_to_recv, bool nonblocking ) { /* receive source address, ECN, and payload in msghdr structure */ - struct sockaddr_in packet_remote_addr; + Addr packet_remote_addr; struct msghdr header; struct iovec msg_iovec; @@ -425,7 +457,7 @@ string Connection::recv_one( int sock_to_recv, bool nonblocking ) char msg_control[ Session::RECEIVE_MTU ]; /* receive source address */ - header.msg_name = &packet_remote_addr; + header.msg_name = &packet_remote_addr.sa; header.msg_namelen = sizeof( packet_remote_addr ); /* receive payload */ @@ -513,12 +545,19 @@ string Connection::recv_one( int sock_to_recv, bool nonblocking ) last_heard = timestamp(); if ( server ) { /* only client can roam */ - if ( (remote_addr.sin_addr.s_addr != packet_remote_addr.sin_addr.s_addr) - || (remote_addr.sin_port != packet_remote_addr.sin_port) ) { + if ( remote_addr_len != header.msg_namelen || + memcmp( &remote_addr, &packet_remote_addr, remote_addr_len ) != 0 ) { remote_addr = packet_remote_addr; - fprintf( stderr, "Server now attached to client at %s:%d\n", - inet_ntoa( remote_addr.sin_addr ), - ntohs( remote_addr.sin_port ) ); + remote_addr_len = header.msg_namelen; + char host[ NI_MAXHOST ], serv[ NI_MAXSERV ]; + int errcode = getnameinfo( &remote_addr.sa, remote_addr_len, + host, sizeof( host ), serv, sizeof( serv ), + NI_DGRAM | NI_NUMERICHOST | NI_NUMERICSERV ); + if ( errcode != 0 ) { + throw NetworkException( std::string( "recv_one: getnameinfo: " ) + gai_strerror( errcode ), 0 ); + } + fprintf( stderr, "Server now attached to client at %s:%s\n", + host, serv ); } } } @@ -528,16 +567,22 @@ string Connection::recv_one( int sock_to_recv, bool nonblocking ) std::string Connection::port( void ) const { - struct sockaddr_in local_addr; + Addr local_addr; socklen_t addrlen = sizeof( local_addr ); - if ( getsockname( sock(), (sockaddr *)&local_addr, &addrlen ) < 0 ) { + if ( getsockname( sock(), &local_addr.sa, &addrlen ) < 0 ) { throw NetworkException( "getsockname", errno ); } - char buf[ 32 ]; - snprintf( buf, sizeof( buf ), "%d", ntohs( local_addr.sin_port ) ); - return std::string( buf ); + char serv[ NI_MAXSERV ]; + int errcode = getnameinfo( &local_addr.sa, addrlen, + NULL, 0, serv, sizeof( serv ), + NI_DGRAM | NI_NUMERICSERV ); + if ( errcode != 0 ) { + throw NetworkException( std::string( "port: getnameinfo: " ) + gai_strerror( errcode ), 0 ); + } + + return std::string( serv ); } uint64_t Network::timestamp( void ) diff --git a/src/network/network.h b/src/network/network.h index 90749d5..c248ff7 100644 --- a/src/network/network.h +++ b/src/network/network.h @@ -84,6 +84,13 @@ namespace Network { string tostring( Session *session ); }; + union Addr { + struct sockaddr sa; + struct sockaddr_in sin; + struct sockaddr_in6 sin6; + struct sockaddr_storage ss; + }; + class Connection { private: static const int DEFAULT_SEND_MTU = 1300; @@ -101,7 +108,7 @@ namespace Network { static const int CONGESTION_TIMESTAMP_PENALTY = 500; /* ms */ - bool try_bind( uint32_t addr, int port_low, int port_high ); + bool try_bind( const char *addr, int port_low, int port_high ); class Socket { @@ -110,7 +117,7 @@ namespace Network { public: int fd( void ) const { return _fd; } - Socket(); + Socket( int family ); ~Socket(); Socket( const Socket & other ); @@ -119,7 +126,8 @@ namespace Network { std::deque< Socket > socks; bool has_remote_addr; - struct sockaddr_in remote_addr; + Addr remote_addr; + socklen_t remote_addr_len; bool server; @@ -175,7 +183,8 @@ namespace Network { uint64_t timeout( void ) const; double get_SRTT( void ) const { return SRTT; } - const struct in_addr & get_remote_ip( void ) const { return remote_addr.sin_addr; } + const Addr &get_remote_addr( void ) const { return remote_addr; } + socklen_t get_remote_addr_len( void ) const { return remote_addr_len; } const NetworkException *get_send_exception( void ) const { diff --git a/src/network/networktransport.h b/src/network/networktransport.h index 5cea434..6f3efd5 100644 --- a/src/network/networktransport.h +++ b/src/network/networktransport.h @@ -116,7 +116,8 @@ namespace Network { unsigned int send_interval( void ) const { return sender.send_interval(); } - const struct in_addr & get_remote_ip( void ) const { return connection.get_remote_ip(); } + const Addr &get_remote_addr( void ) const { return connection.get_remote_addr(); } + socklen_t get_remote_addr_len( void ) const { return connection.get_remote_addr_len(); } const NetworkException *get_send_exception( void ) const { return connection.get_send_exception(); } };