diff --git a/src/textual_webterm/local_server.py b/src/textual_webterm/local_server.py index 4c2048b..110c86d 100644 --- a/src/textual_webterm/local_server.py +++ b/src/textual_webterm/local_server.py @@ -579,39 +579,38 @@ class LocalServer: def _get_ws_url_from_request(self, request: web.Request, route_key: str) -> str: """Build WebSocket URL honoring reverse proxies and port mapping.""" + # Extract forwarded headers (take first value if comma-separated) + def first_header(name: str) -> str: + return request.headers.get(name, "").split(",")[0].strip().lower() - forwarded_proto = request.headers.get("X-Forwarded-Proto", "").split(",")[0].strip().lower() - forwarded_host = request.headers.get("X-Forwarded-Host", "").split(",")[0].strip() - forwarded_port = request.headers.get("X-Forwarded-Port", "").split(",")[0].strip() + forwarded_proto = first_header("X-Forwarded-Proto") + forwarded_host = first_header("X-Forwarded-Host") + forwarded_port = first_header("X-Forwarded-Port") - def _pick_proto() -> str: - if forwarded_proto in ("https", "wss"): - return "wss" - if forwarded_proto in ("http", "ws"): - return "ws" - return "wss" if request.secure else "ws" + # Determine WebSocket protocol + if forwarded_proto in ("https", "wss"): + ws_proto = "wss" + elif forwarded_proto in ("http", "ws"): + ws_proto = "ws" + else: + ws_proto = "wss" if request.secure else "ws" - def _split_host_port(host: str) -> tuple[str, str]: - if not host: - return "", "" - if ":" in host: - return host.rsplit(":", 1) - return host, "" - - ws_proto = _pick_proto() - ws_host, ws_port = _split_host_port(forwarded_host) - - if not ws_host: - host_header = request.headers.get("Host", "") - ws_host, ws_port = _split_host_port(host_header) + # Determine host and port (priority: forwarded > Host header > server config) + ws_host, ws_port = "", "" + for candidate in (forwarded_host, request.headers.get("Host", "")): + if candidate: + ws_host, _, ws_port = candidate.rpartition(":") + if not ws_host: # No colon found, entire string is host + ws_host, ws_port = candidate, "" + break if not ws_host: ws_host = "localhost" if self.host == "0.0.0.0" else self.host ws_port = str(self.port) - if not ws_port and forwarded_port: - ws_port = forwarded_port + ws_port = ws_port or forwarded_port + # Include port in URL only for non-standard ports if ws_port and ws_port not in ("80", "443"): return f"{ws_proto}://{ws_host}:{ws_port}/ws/{route_key}" if not ws_port and self.port not in (80, 443):