diff --git a/src/crypto.cpp b/src/crypto.cpp index b16ab08a..7f4417da 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -11,11 +11,11 @@ namespace crypto { cert_chain_t::cert_chain_t(): _certs {}, _cert_ctx { X509_STORE_CTX_new() } {} void - cert_chain_t::add(x509_t &&cert) { + cert_chain_t::add(p_named_cert_t& named_cert_p) { x509_store_t x509_store { X509_STORE_new() }; - X509_STORE_add_cert(x509_store.get(), cert.get()); - _certs.emplace_back(std::make_pair(std::move(cert), std::move(x509_store))); + X509_STORE_add_cert(x509_store.get(), x509(named_cert_p->cert).get()); + _certs.emplace_back(std::make_pair(named_cert_p, std::move(x509_store))); } void cert_chain_t::clear() { @@ -52,9 +52,9 @@ namespace crypto { * @return nullptr if the certificate is valid, otherwise an error string. */ const char * - cert_chain_t::verify(x509_t::element_type *cert) { + cert_chain_t::verify(x509_t::element_type *cert, p_named_cert_t& named_cert_out) { int err_code = 0; - for (auto &[_, x509_store] : _certs) { + for (auto &[named_cert_p, x509_store] : _certs) { auto fg = util::fail_guard([this]() { X509_STORE_CTX_cleanup(_cert_ctx.get()); }); @@ -70,6 +70,7 @@ namespace crypto { auto err = X509_verify_cert(_cert_ctx.get()); if (err == 1) { + named_cert_out = named_cert_p; return nullptr; } diff --git a/src/crypto.h b/src/crypto.h index 859c6675..45c876f6 100644 --- a/src/crypto.h +++ b/src/crypto.h @@ -34,6 +34,14 @@ namespace crypto { using pkey_ctx_t = util::safe_ptr; using bignum_t = util::safe_ptr; + struct named_cert_t { + std::string name; + std::string uuid; + std::string cert; + }; + + using p_named_cert_t = std::shared_ptr; + /** * @brief Hashes the given plaintext using SHA-256. * @param plaintext @@ -76,16 +84,16 @@ namespace crypto { KITTY_DECL_CONSTR(cert_chain_t) void - add(x509_t &&cert); + add(p_named_cert_t& named_cert_p); void clear(); const char * - verify(x509_t::element_type *cert); + verify(x509_t::element_type *cert, p_named_cert_t& named_cert_out); private: - std::vector> _certs; + std::vector> _certs; x509_store_ctx_t _cert_ctx; }; diff --git a/src/nvhttp.cpp b/src/nvhttp.cpp index d27d66af..3e5d29e4 100644 --- a/src/nvhttp.cpp +++ b/src/nvhttp.cpp @@ -46,6 +46,14 @@ namespace nvhttp { namespace fs = std::filesystem; namespace pt = boost::property_tree; + using p_named_cert_t = crypto::p_named_cert_t; + + struct client_t { + std::vector named_devices; + }; + + struct pair_session_t; + crypto::cert_chain_t cert_chain; static std::string one_time_pin; static std::string otp_passphrase; @@ -66,6 +74,15 @@ namespace nvhttp { class SunshineHTTPSServer: public SimpleWeb::ServerBase { public: + class ApolloSession: public Session { + public: + bool verified = false; + crypto::named_cert_t* named_cert = nullptr; + void* userp = nullptr; + + template + ApolloSession(Args&&... args): Session(std::forward(args)...) {} + }; SunshineHTTPSServer(const std::string &certification_file, const std::string &private_key_file): ServerBase::ServerBase(443), context(boost::asio::ssl::context::tls_server) { @@ -76,7 +93,7 @@ namespace nvhttp { context.use_private_key_file(private_key_file, boost::asio::ssl::context::pem); } - std::function verify; + std::function verify; std::function, std::shared_ptr)> on_verify_failed; protected: @@ -106,7 +123,7 @@ namespace nvhttp { if (ec != SimpleWeb::error::operation_aborted) this->accept(); - auto session = std::make_shared(config.max_request_streambuf_size, connection); + auto session = std::make_shared(config.max_request_streambuf_size, connection); if (!ec) { boost::asio::ip::tcp::no_delay option(true); @@ -120,7 +137,7 @@ namespace nvhttp { if (!lock) return; if (!ec) { - if (verify && !verify(session->connection->socket->native_handle())) + if (verify && !verify(session.get(), session->connection->socket->native_handle())) this->write(session, on_verify_failed); else this->read(session); @@ -137,22 +154,13 @@ namespace nvhttp { using https_server_t = SunshineHTTPSServer; using http_server_t = SimpleWeb::Server; + using https_session_t = SunshineHTTPSServer::ApolloSession; struct conf_intern_t { std::string servercert; std::string pkey; } conf_intern; - struct named_cert_t { - std::string name; - std::string uuid; - std::string cert; - }; - - struct client_t { - std::vector named_devices; - }; - struct pair_session_t { struct { std::string uniqueID; @@ -225,12 +233,27 @@ namespace nvhttp { pt::ptree node; pt::ptree named_cert_nodes; - for (auto &named_cert : client.named_devices) { - pt::ptree named_cert_node; - named_cert_node.put("name"s, named_cert.name); - named_cert_node.put("cert"s, named_cert.cert); - named_cert_node.put("uuid"s, named_cert.uuid); - named_cert_nodes.push_back(std::make_pair(""s, named_cert_node)); + std::unordered_set unique_certs; + std::unordered_map name_counts; + for (auto &named_cert_p : client.named_devices) { + if (unique_certs.insert(named_cert_p->cert).second) { + pt::ptree named_cert_node; + std::string base_name = named_cert_p->name; + // Remove existing pending id if present + size_t pos = base_name.find(" ("); + if (pos != std::string::npos) { + base_name = base_name.substr(0, pos); + } + int count = name_counts[base_name]++; + std::string final_name = base_name; + if (count > 0) { + final_name += " (" + std::to_string(count + 1) + ")"; + } + named_cert_node.put("name"s, final_name); + named_cert_node.put("cert"s, named_cert_p->cert); + named_cert_node.put("uuid"s, named_cert_p->uuid); + named_cert_nodes.push_back(std::make_pair(""s, named_cert_node)); + } } root.add_child("root.named_devices"s, named_cert_nodes); @@ -282,11 +305,11 @@ namespace nvhttp { if (device_node.count("certs")) { for (auto &[_, el] : device_node.get_child("certs")) { - named_cert_t named_cert; - named_cert.name = ""s; - named_cert.cert = el.get_value(); - named_cert.uuid = uuid_util::uuid_t::generate().string(); - client.named_devices.emplace_back(named_cert); + auto named_cert_p = std::make_shared(); + named_cert_p->name = ""s; + named_cert_p->cert = el.get_value(); + named_cert_p->uuid = uuid_util::uuid_t::generate().string(); + client.named_devices.emplace_back(named_cert_p); } } } @@ -294,34 +317,31 @@ namespace nvhttp { if (root.count("named_devices")) { for (auto &[_, el] : root.get_child("named_devices")) { - named_cert_t named_cert; - named_cert.name = el.get_child("name").get_value(); - named_cert.cert = el.get_child("cert").get_value(); - named_cert.uuid = el.get_child("uuid").get_value(); - client.named_devices.emplace_back(named_cert); + auto named_cert_p = std::make_shared(); + named_cert_p->name = el.get_child("name").get_value(); + named_cert_p->cert = el.get_child("cert").get_value(); + named_cert_p->uuid = el.get_child("uuid").get_value(); + client.named_devices.emplace_back(named_cert_p); } } // Empty certificate chain and import certs from file cert_chain.clear(); for (auto &named_cert : client.named_devices) { - cert_chain.add(crypto::x509(named_cert.cert)); + cert_chain.add(named_cert); } client_root = client; } void - add_authorized_client(const std::string &name, std::string &&cert) { + add_authorized_client(const p_named_cert_t& named_cert_p) { client_t &client = client_root; - named_cert_t named_cert; - named_cert.name = name; - named_cert.cert = std::move(cert); - named_cert.uuid = uuid_util::uuid_t::generate().string(); - client.named_devices.emplace_back(named_cert); + client.named_devices.emplace_back(named_cert_p); if (!config::sunshine.flags[config::flag::FRESH_STATE]) { save_state(); + load_state(); } } @@ -458,7 +478,7 @@ namespace nvhttp { } void - clientpairingsecret(std::shared_ptr> &add_cert, pair_session_t &sess, pt::ptree &tree, const args_t &args) { + clientpairingsecret(pair_session_t &sess, pt::ptree &tree, const args_t &args) { auto &client = sess.client; auto pairingsecret = util::from_hex_vec(get_arg(args, "clientpairingsecret"), true); @@ -487,11 +507,20 @@ namespace nvhttp { // if hash not correct, probably MITM if (!std::memcmp(hash.data(), sess.clienthash.data(), hash.size()) && crypto::verify256(crypto::x509(client.cert), secret, sign)) { tree.put("root.paired", 1); - add_cert->raise(crypto::x509(client.cert)); + + auto named_cert_p = std::make_shared(); + named_cert_p->name = client.name; + for (char& c : named_cert_p->name) { + if (c == '(') c = '['; + else if (c == ')') c = ']'; + } + named_cert_p->cert = std::move(client.cert); + named_cert_p->uuid = uuid_util::uuid_t::generate().string(); auto it = map_id_sess.find(client.uniqueID); - add_authorized_client(client.name, std::move(client.cert)); map_id_sess.erase(it); + + add_authorized_client(named_cert_p); } else { map_id_sess.erase(client.uniqueID); @@ -557,7 +586,7 @@ namespace nvhttp { template void - pair(std::shared_ptr> &add_cert, std::shared_ptr::Response> response, std::shared_ptr::Request> request) { + pair(std::shared_ptr::Response> response, std::shared_ptr::Request> request) { print_req(request); pt::ptree tree; @@ -673,7 +702,7 @@ namespace nvhttp { serverchallengeresp(sess_it->second, tree, args); } else if (it = args.find("clientpairingsecret"); it != std::end(args)) { - clientpairingsecret(add_cert, sess_it->second, tree, args); + clientpairingsecret(sess_it->second, tree, args); } else { tree.put("root..status_code", 404); @@ -849,10 +878,10 @@ namespace nvhttp { get_all_clients() { pt::ptree named_cert_nodes; client_t &client = client_root; - for (auto &named_cert : client.named_devices) { + for (auto &named_cert_p : client.named_devices) { pt::ptree named_cert_node; - named_cert_node.put("name"s, named_cert.name); - named_cert_node.put("uuid"s, named_cert.uuid); + named_cert_node.put("name"s, named_cert_p->name); + named_cert_node.put("uuid"s, named_cert_p->uuid); named_cert_nodes.push_back(std::make_pair(""s, named_cert_node)); } @@ -1133,8 +1162,6 @@ namespace nvhttp { conf_intern.pkey = file_handler::read_file(config::nvhttp.pkey.c_str()); conf_intern.servercert = file_handler::read_file(config::nvhttp.cert.c_str()); - auto add_cert = std::make_shared>(30); - // resume doesn't always get the parameter "localAudioPlayMode" // launch will store it in host_audio bool host_audio {}; @@ -1143,43 +1170,34 @@ namespace nvhttp { http_server_t http_server; // Verify certificates after establishing connection - https_server.verify = [add_cert](SSL *ssl) { + https_server.verify = [](https_session_t* session, SSL *ssl) { crypto::x509_t x509 { SSL_get_peer_certificate(ssl) }; if (!x509) { BOOST_LOG(info) << "unknown -- denied"sv; - return 0; + return false; } - int verified = 0; - auto fg = util::fail_guard([&]() { char subject_name[256]; X509_NAME_oneline(X509_get_subject_name(x509.get()), subject_name, sizeof(subject_name)); - BOOST_LOG(debug) << subject_name << " -- "sv << (verified ? "verified"sv : "denied"sv); + BOOST_LOG(debug) << subject_name << " -- "sv << (session->verified ? "verified"sv : "denied"sv); }); - while (add_cert->peek()) { - char subject_name[256]; - - auto cert = add_cert->pop(); - X509_NAME_oneline(X509_get_subject_name(cert.get()), subject_name, sizeof(subject_name)); - - BOOST_LOG(debug) << "Added cert ["sv << subject_name << ']'; - cert_chain.add(std::move(cert)); - } - - auto err_str = cert_chain.verify(x509.get()); + p_named_cert_t named_cert_p; + auto err_str = cert_chain.verify(x509.get(), named_cert_p); if (err_str) { BOOST_LOG(warning) << "SSL Verification error :: "sv << err_str; - return verified; + return session->verified; } - verified = 1; + session->verified = true; - return verified; + BOOST_LOG(info) << "Device " << named_cert_p->name << " verified!"; + + return session->verified; }; https_server.on_verify_failed = [](resp_https_t resp, req_https_t req) { @@ -1199,7 +1217,7 @@ namespace nvhttp { https_server.default_resource["GET"] = not_found; https_server.resource["^/serverinfo$"]["GET"] = serverinfo; - https_server.resource["^/pair$"]["GET"] = [&add_cert](auto resp, auto req) { pair(add_cert, resp, req); }; + https_server.resource["^/pair$"]["GET"] = pair; https_server.resource["^/applist$"]["GET"] = applist; https_server.resource["^/appasset$"]["GET"] = appasset; https_server.resource["^/launch$"]["GET"] = [&host_audio](auto resp, auto req) { launch(host_audio, resp, req); }; @@ -1212,7 +1230,7 @@ namespace nvhttp { http_server.default_resource["GET"] = not_found; http_server.resource["^/serverinfo$"]["GET"] = serverinfo; - http_server.resource["^/pair$"]["GET"] = [&add_cert](auto resp, auto req) { pair(add_cert, resp, req); }; + http_server.resource["^/pair$"]["GET"] = pair; http_server.config.reuse_address = true; http_server.config.address = net::af_to_any_address_string(address_family); @@ -1265,6 +1283,7 @@ namespace nvhttp { client_root = client; cert_chain.clear(); save_state(); + load_state(); } int @@ -1272,7 +1291,7 @@ namespace nvhttp { int removed = 0; client_t &client = client_root; for (auto it = client.named_devices.begin(); it != client.named_devices.end();) { - if ((*it).uuid == uuid) { + if ((*it)->uuid == uuid) { it = client.named_devices.erase(it); removed++; }