From 4a513ff496ec118700d3cb38ecc5b53c99e46c3b Mon Sep 17 00:00:00 2001 From: Keith Winstein Date: Fri, 26 Aug 2011 05:08:30 -0400 Subject: [PATCH] Graceful shutdown on signal kill --- networktransport.hpp | 1 + stm-server.cpp | 36 ++++++++++++++++++++++++++++++++++-- stm.cpp | 32 ++++++++++++++++++++++++++++++-- 3 files changed, 65 insertions(+), 4 deletions(-) diff --git a/networktransport.hpp b/networktransport.hpp index d8a76f9..4805b04 100644 --- a/networktransport.hpp +++ b/networktransport.hpp @@ -55,6 +55,7 @@ namespace Network { void start_shutdown( void ) { sender.start_shutdown(); } bool shutdown_in_progress( void ) { return sender.get_shutdown_in_progress(); } bool shutdown_acknowledged( void ) { return sender.get_shutdown_acknowledged(); } + bool attached( void ) { return connection.get_attached(); } /* Other side has requested shutdown and we have sent one ACK */ bool counterparty_shutdown_ack_sent( void ) { return sender.get_counterparty_shutdown_acknowledged(); } diff --git a/stm-server.cpp b/stm-server.cpp index a7c0bd3..9a0998f 100644 --- a/stm-server.cpp +++ b/stm-server.cpp @@ -11,6 +11,8 @@ #include #include #include +#include +#include #include "networktransport.hpp" #include "completeterminal.hpp" @@ -99,6 +101,24 @@ int main( void ) void serve( int host_fd ) { + /* establish fd for shutdown signals */ + sigset_t signal_mask; + + assert( sigemptyset( &signal_mask ) == 0 ); + assert( sigaddset( &signal_mask, SIGTERM ) == 0 ); + assert( sigaddset( &signal_mask, SIGINT ) == 0 ); + assert( sigaddset( &signal_mask, SIGHUP ) == 0 ); + assert( sigaddset( &signal_mask, SIGPIPE ) == 0 ); + + /* don't let signals kill us */ + assert( sigprocmask( SIG_BLOCK, &signal_mask, NULL ) == 0 ); + + int shutdown_signal_fd = signalfd( -1, &signal_mask, 0 ); + if ( shutdown_signal_fd < 0 ) { + perror( "signalfd" ); + return; + } + /* get initial window size */ struct winsize window_size; if ( ioctl( STDIN_FILENO, TIOCGWINSZ, &window_size ) < 0 ) { @@ -124,7 +144,7 @@ void serve( int host_fd ) printf( "key= %s port= %d\n", network.get_key().c_str(), network.port() ); /* prepare to poll for events */ - struct pollfd pollfds[ 2 ]; + struct pollfd pollfds[ 3 ]; pollfds[ 0 ].fd = network.fd(); pollfds[ 0 ].events = POLLIN; @@ -132,11 +152,14 @@ void serve( int host_fd ) pollfds[ 1 ].fd = host_fd; pollfds[ 1 ].events = POLLIN; + pollfds[ 2 ].fd = shutdown_signal_fd; + pollfds[ 2 ].events = POLLIN; + uint64_t last_remote_num = network.get_remote_state_num(); while ( 1 ) { try { - int active_fds = poll( pollfds, 2, network.wait_time() ); + int active_fds = poll( pollfds, 3, network.wait_time() ); if ( active_fds < 0 ) { perror( "poll" ); break; @@ -205,6 +228,15 @@ void serve( int host_fd ) break; } } + + if ( pollfds[ 2 ].revents & POLLIN ) { + /* shutdown signal */ + if ( network.attached() ) { + network.start_shutdown(); + } else { + break; + } + } if ( (pollfds[ 0 ].revents) & (POLLERR | POLLHUP | POLLNVAL) ) { diff --git a/stm.cpp b/stm.cpp index ce3296b..4bde25f 100644 --- a/stm.cpp +++ b/stm.cpp @@ -105,6 +105,22 @@ void client( const char *ip, int port, const char *key ) return; } + /* establish fd for shutdown signals */ + assert( sigemptyset( &signal_mask ) == 0 ); + assert( sigaddset( &signal_mask, SIGTERM ) == 0 ); + assert( sigaddset( &signal_mask, SIGINT ) == 0 ); + assert( sigaddset( &signal_mask, SIGHUP ) == 0 ); + assert( sigaddset( &signal_mask, SIGPIPE ) == 0 ); + + /* don't let signals kill us */ + assert( sigprocmask( SIG_BLOCK, &signal_mask, NULL ) == 0 ); + + int shutdown_signal_fd = signalfd( -1, &signal_mask, 0 ); + if ( shutdown_signal_fd < 0 ) { + perror( "signalfd" ); + return; + } + /* get initial window size */ struct winsize window_size; if ( ioctl( STDIN_FILENO, TIOCGWINSZ, &window_size ) < 0 ) { @@ -128,7 +144,7 @@ void client( const char *ip, int port, const char *key ) network.get_current_state().push_back( Parser::Resize( window_size.ws_col, window_size.ws_row ) ); /* prepare to poll for events */ - struct pollfd pollfds[ 3 ]; + struct pollfd pollfds[ 4 ]; pollfds[ 0 ].fd = network.fd(); pollfds[ 0 ].events = POLLIN; @@ -139,11 +155,14 @@ void client( const char *ip, int port, const char *key ) pollfds[ 2 ].fd = winch_fd; pollfds[ 2 ].events = POLLIN; + pollfds[ 3 ].fd = shutdown_signal_fd; + pollfds[ 3 ].events = POLLIN; + uint64_t last_remote_num = network.get_remote_state_num(); while ( 1 ) { try { - int active_fds = poll( pollfds, 3, network.wait_time() ); + int active_fds = poll( pollfds, 4, network.wait_time() ); if ( active_fds < 0 ) { perror( "poll" ); break; @@ -208,6 +227,15 @@ void client( const char *ip, int port, const char *key ) } } + if ( pollfds[ 3 ].revents & POLLIN ) { + /* shutdown signal */ + if ( network.attached() ) { + network.start_shutdown(); + } else { + break; + } + } + if ( (pollfds[ 0 ].revents) & (POLLERR | POLLHUP | POLLNVAL) ) { /* network problem */