Fixed deadlock on unexpected connection loss

This commit is contained in:
loki
2020-03-14 14:38:09 +01:00
parent 1362abc70d
commit b4f1ef1127
3 changed files with 82 additions and 100 deletions
+55 -45
View File
@@ -110,46 +110,10 @@ public:
_map_addr_session->emplace(addr, std::make_pair(0u, &session)); _map_addr_session->emplace(addr, std::make_pair(0u, &session));
} }
void erase_session(session_t &session) {
auto lg = _map_addr_session.lock();
auto pos = std::find_if(std::begin(_map_addr_session.raw), std::end(_map_addr_session.raw), [session_p=&session](auto &current_port_and_session) {
return session_p == current_port_and_session.second.second;
});
_map_addr_session->erase(pos);
}
// Get session associated with address. // Get session associated with address.
// If none are found, try to find a session not yet claimed. (It will be marked by a port of value 0 // If none are found, try to find a session not yet claimed. (It will be marked by a port of value 0
// If none of those are found, return nullptr // If none of those are found, return nullptr
session_t *get_session(const ENetAddress &address) { session_t *get_session(const net::peer_t peer);
TUPLE_2D(port, addr_string, platf::from_sockaddr_ex((sockaddr*)&address.address));
auto lg = _map_addr_session.lock();
TUPLE_2D(begin, end, _map_addr_session->equal_range(addr_string));
auto it = std::end(_map_addr_session.raw);
for(auto pos = begin; pos != end; ++pos) {
TUPLE_2D_REF(session_port, session_p, pos->second);
if(port == session_port) {
return session_p;
}
else if(session_port == 0) {
it = pos;
}
}
if(it != std::end(_map_addr_session.raw)) {
TUPLE_2D_REF(session_port, session_p, it->second);
session_port = port;
return session_p;
}
return nullptr;
}
// Circular dependency: // Circular dependency:
// iterate refers to session // iterate refers to session
@@ -158,11 +122,6 @@ public:
// Therefore, iterate is implemented further down the source file // Therefore, iterate is implemented further down the source file
void iterate(std::chrono::milliseconds timeout); void iterate(std::chrono::milliseconds timeout);
template<class T, class X>
void iterate(std::chrono::duration<T, X> timeout) {
iterate(std::chrono::floor<std::chrono::milliseconds>(timeout));
}
void map(uint16_t type, std::function<void(session_t *, const std::string_view&)> cb) { void map(uint16_t type, std::function<void(session_t *, const std::string_view&)> cb) {
_map_type_cb.emplace(type, std::move(cb)); _map_type_cb.emplace(type, std::move(cb));
} }
@@ -227,10 +186,16 @@ struct session_t {
udp::endpoint peer; udp::endpoint peer;
} audio; } audio;
struct {
net::peer_t peer;
} control;
crypto::aes_t gcm_key; crypto::aes_t gcm_key;
crypto::aes_t iv; crypto::aes_t iv;
safe::signal_t shutdown_event; safe::signal_t shutdown_event;
safe::signal_t controlEnd;
std::atomic<session::state_e> state; std::atomic<session::state_e> state;
}; };
@@ -242,12 +207,42 @@ std::shared_ptr<input::input_t> input;
static auto broadcast = safe::make_shared<broadcast_ctx_t>(start_broadcast, end_broadcast); static auto broadcast = safe::make_shared<broadcast_ctx_t>(start_broadcast, end_broadcast);
safe::signal_t broadcast_shutdown_event; safe::signal_t broadcast_shutdown_event;
session_t *control_server_t::get_session(const net::peer_t peer) {
TUPLE_2D(port, addr_string, platf::from_sockaddr_ex((sockaddr*)&peer->address.address));
auto lg = _map_addr_session.lock();
TUPLE_2D(begin, end, _map_addr_session->equal_range(addr_string));
auto it = std::end(_map_addr_session.raw);
for(auto pos = begin; pos != end; ++pos) {
TUPLE_2D_REF(session_port, session_p, pos->second);
if(port == session_port) {
return session_p;
}
else if(session_port == 0) {
it = pos;
}
}
if(it != std::end(_map_addr_session.raw)) {
TUPLE_2D_REF(session_port, session_p, it->second);
session_p->control.peer = peer;
session_port = port;
return session_p;
}
return nullptr;
}
void control_server_t::iterate(std::chrono::milliseconds timeout) { void control_server_t::iterate(std::chrono::milliseconds timeout) {
ENetEvent event; ENetEvent event;
auto res = enet_host_service(_host.get(), &event, timeout.count()); auto res = enet_host_service(_host.get(), &event, timeout.count());
if(res > 0) { if(res > 0) {
auto session = get_session(event.peer->address); auto session = get_session(event.peer);
if(!session) { if(!session) {
BOOST_LOG(warning) << "Rejected connection from ["sv << platf::from_sockaddr((sockaddr*)&event.peer->address.address) << "]: it's not properly set up"sv; BOOST_LOG(warning) << "Rejected connection from ["sv << platf::from_sockaddr((sockaddr*)&event.peer->address.address) << "]: it's not properly set up"sv;
enet_peer_disconnect_now(event.peer, 0); enet_peer_disconnect_now(event.peer, 0);
@@ -466,13 +461,26 @@ void controlBroadcastThread(safe::signal_t *shutdown_event, control_server_t *se
auto lg = server->_map_addr_session.lock(); auto lg = server->_map_addr_session.lock();
auto now = std::chrono::steady_clock::now(); auto now = std::chrono::steady_clock::now();
for(auto &[addr,port_session] : server->_map_addr_session.raw) {
KITTY_WHILE_LOOP(auto pos = std::begin(*server->_map_addr_session), pos != std::end(*server->_map_addr_session), {
TUPLE_2D_REF(addr, port_session, *pos);
auto session = port_session.second; auto session = port_session.second;
if(now > session->pingTimeout) { if(now > session->pingTimeout) {
BOOST_LOG(info) << addr << ": Ping Timeout"sv; BOOST_LOG(info) << addr << ": Ping Timeout"sv;
session::stop(*session); session::stop(*session);
} }
if(session->state.load(std::memory_order_acquire) == session::state_e::STOPPING) {
pos = server->_map_addr_session->erase(pos);
enet_peer_disconnect_now(session->control.peer, 0);
session->controlEnd.raise(true);
continue;
} }
++pos;
})
} }
if(proc::proc.running() == -1) { if(proc::proc.running() == -1) {
@@ -833,7 +841,6 @@ void stop(session_t &session) {
return; return;
} }
session.broadcast_ref->control_server.erase_session(session);
session.shutdown_event.raise(true); session.shutdown_event.raise(true);
} }
@@ -842,6 +849,8 @@ void join(session_t &session) {
session.videoThread.join(); session.videoThread.join();
BOOST_LOG(debug) << "Waiting for audio to end..."sv; BOOST_LOG(debug) << "Waiting for audio to end..."sv;
session.audioThread.join(); session.audioThread.join();
BOOST_LOG(debug) << "Waiting for control to end..."sv;
session.controlEnd.view();
BOOST_LOG(debug) << "Session ended"sv; BOOST_LOG(debug) << "Session ended"sv;
} }
@@ -869,6 +878,7 @@ std::shared_ptr<session_t> alloc(config_t &config, crypto::aes_t &gcm_key, crypt
session->audio.frame = 1; session->audio.frame = 1;
session->control.peer = nullptr;
session->state.store(state_e::STOPPED, std::memory_order_relaxed); session->state.store(state_e::STOPPED, std::memory_order_relaxed);
return session; return session;
+26 -53
View File
@@ -11,99 +11,72 @@
namespace util { namespace util {
template<class T, std::size_t N = 1> template<class T, class M = std::mutex>
class sync_t { class sync_t {
public: public:
static_assert(N > 0, "sync_t should have more than zero mutexes"); using value_t = T;
using value_type = T; using mutex_t = M;
template<std::size_t I = 0> std::lock_guard<mutex_t> lock() {
std::lock_guard<std::mutex> lock() { return std::lock_guard { _lock };
return std::lock_guard { std::get<I>(_lock) };
} }
template<class ...Args> template<class ...Args>
sync_t(Args&&... args) : raw {std::forward<Args>(args)... } {} sync_t(Args&&... args) : raw {std::forward<Args>(args)... } {}
sync_t &operator=(sync_t &&other) noexcept { sync_t &operator=(sync_t &&other) noexcept {
for(auto &l : _lock) { std::lock(_lock, other._lock);
l.lock();
}
for(auto &l : other._lock) {
l.lock();
}
raw = std::move(other.raw); raw = std::move(other.raw);
for(auto &l : _lock) { _lock.unlock();
l.unlock(); other._lock.unlock();
}
for(auto &l : other._lock) {
l.unlock();
}
return *this; return *this;
} }
sync_t &operator=(sync_t &other) noexcept { sync_t &operator=(sync_t &other) noexcept {
for(auto &l : _lock) { std::lock(_lock, other._lock);
l.lock();
}
for(auto &l : other._lock) {
l.lock();
}
raw = other.raw; raw = other.raw;
for(auto &l : _lock) { _lock.unlock();
l.unlock(); other._lock.unlock();
}
for(auto &l : other._lock) {
l.unlock();
}
return *this; return *this;
} }
sync_t &operator=(const value_type &val) noexcept { sync_t &operator=(const value_t &val) noexcept {
for(auto &l : _lock) { auto lg = lock();
l.lock();
}
raw = val; raw = val;
for(auto &l : _lock) {
l.unlock();
}
return *this; return *this;
} }
sync_t &operator=(value_type &&val) noexcept { sync_t &operator=(value_t &&val) noexcept {
for(auto &l : _lock) { auto lg = lock();
l.lock();
}
raw = std::move(val); raw = std::move(val);
for(auto &l : _lock) {
l.unlock();
}
return *this; return *this;
} }
value_type *operator->() { value_t *operator->() {
return &raw; return &raw;
} }
value_type raw; value_t &operator*() {
return raw;
}
const value_t &operator*() const {
return raw;
}
value_t raw;
private: private:
std::array<std::mutex, N> _lock; mutex_t _lock;
}; };
} }
-1
View File
@@ -101,7 +101,6 @@ void captureThread(std::shared_ptr<safe::queue_t<capture_ctx_t>> capture_ctx_que
for(auto &capture_ctx : capture_ctx_queue->unsafe()) { for(auto &capture_ctx : capture_ctx_queue->unsafe()) {
capture_ctx.images->stop(); capture_ctx.images->stop();
} }
}); });
auto disp = platf::display(); auto disp = platf::display();