diff --git a/common/netutil/netutil.go b/common/netutil/netutil.go index 9cc274f..eb0a087 100644 --- a/common/netutil/netutil.go +++ b/common/netutil/netutil.go @@ -20,20 +20,30 @@ import ( // ConnectServer connects to the server with the given address. func ConnectServer(config config.Config) net.Conn { scheme := "ws" + host := config.ServerAddr if config.Protocol == "wss" { scheme = "wss" + host = config.TLSSni } - u := url.URL{Scheme: scheme, Host: config.ServerAddr, Path: config.WebSocketPath} + u := url.URL{Scheme: scheme, Host: host, Path: config.WebSocketPath} header := make(http.Header) header.Set("user-agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.182 Safari/537.36") - header.Set("key", config.Key) - tlsconfig := &tls.Config{ + if config.Key != "" { + header.Set("key", config.Key) + } + tlsConfig := &tls.Config{ InsecureSkipVerify: config.TLSInsecureSkipVerify, } + if config.TLSSni != "" { + tlsConfig.ServerName = config.TLSSni + } dialer := ws.Dialer{ Header: ws.HandshakeHeaderHTTP(header), Timeout: time.Duration(config.Timeout) * time.Second, - TLSConfig: tlsconfig, + TLSConfig: tlsConfig, + NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial(network, config.ServerAddr) + }, } c, _, _, err := dialer.Dial(context.Background(), u.String()) if err != nil { diff --git a/ws/wsserver.go b/ws/wsserver.go index 77d5a9f..5ddcfea 100644 --- a/ws/wsserver.go +++ b/ws/wsserver.go @@ -133,6 +133,9 @@ func StartServer(iface *water.Interface, config config.Config) { // checkPermission checks the permission of the request func checkPermission(w http.ResponseWriter, req *http.Request, config config.Config) bool { + if config.Key == "" { + return true + } key := req.Header.Get("key") if key != config.Key { w.WriteHeader(http.StatusForbidden)