diff --git a/common/xproto/xproto.go b/common/xproto/xproto.go new file mode 100644 index 0000000..2d3873b --- /dev/null +++ b/common/xproto/xproto.go @@ -0,0 +1,114 @@ +package xproto + +import ( + "bytes" + "crypto/md5" + "errors" + "fmt" +) + +const ProtocolVersion = 1 + +type ClientSendPacketHeader struct { + ProtocolVersion uint8 //1 byte + Key *AuthKey //16 byte + Length int //2 byte, convert to [2]byte +} + +func (p *ClientSendPacketHeader) Bytes() []byte { + data := make([]byte, ClientSendPacketHeaderLength) + data[0] = p.ProtocolVersion + copy(data[1:17], p.Key[:]) + data[17] = byte(p.Length >> 8 & 0xff) + data[18] = byte(p.Length & 0xff) + return data +} + +func ParseClientSendPacketHeader(data []byte) *ClientSendPacketHeader { + var obj = &ClientSendPacketHeader{} + var authKey AuthKey + if len(data) != ClientSendPacketHeaderLength { + return nil + } + obj.ProtocolVersion = data[0] + copy(authKey[:], data[1:17]) + obj.Key = &authKey + obj.Length = ((obj.Length & 0x00) | int(data[17])) << 8 + obj.Length = obj.Length | int(data[18]) + return obj +} + +type ServerSendPacketHeader struct { + ProtocolVersion uint8 //1 byte + Length int //2 byte, convert to [2]byte +} + +func (p *ServerSendPacketHeader) Bytes() []byte { + data := make([]byte, ServerSendPacketHeaderLength) + data[0] = p.ProtocolVersion + data[1] = byte(p.Length >> 8 & 0xff) + data[2] = byte(p.Length & 0xff) + return data +} + +func ParseServerSendPacketHeader(data []byte) *ServerSendPacketHeader { + var obj = &ServerSendPacketHeader{} + if len(data) != ServerSendPacketHeaderLength { + return nil + } + obj.ProtocolVersion = data[0] + obj.Length = ((obj.Length & 0x00) | int(data[1])) << 8 + obj.Length = obj.Length | int(data[2]) + return obj +} + +const ClientSendPacketHeaderLength = 19 +const ServerSendPacketHeaderLength = 3 + +// ConvertLength []byte length to int length +func ConvertLength(header []byte) int { + length := 0 + if len(header) >= 2 { + length = ((length & 0x00) | int(header[0])) << 8 + length = length | int(header[1]) + } + return length +} + +type AuthKey [16]byte + +// Bytes returns the bytes representation of this AuthKey. +func (u *AuthKey) Bytes() []byte { + return u[:] +} + +// Equals returns true if this AuthKey equals another AuthKey by value. +func (u *AuthKey) Equals(another *AuthKey) bool { + if u == nil && another == nil { + return true + } + if u == nil || another == nil { + return false + } + return bytes.Equal(u.Bytes(), another.Bytes()) +} + +// ParseBytes converts a AuthKey in byte form to object. +func ParseBytes(b []byte) (AuthKey, error) { + var authKey AuthKey + if len(b) != 16 { + return authKey, errors.New(fmt.Sprintf("invalid AuthKey: %v", b)) + } + copy(authKey[:], b) + return authKey, nil +} + +// ParseAuthKeyFromString converts a AuthKey in string form to object. +func ParseAuthKeyFromString(str string) *AuthKey { + var authKey AuthKey + m := md5.New() + m.Write([]byte(str)) + r := m.Sum(nil) + copy(authKey[:], r[:16]) + return &authKey +} diff --git a/tls/tlsclient.go b/tls/tlsclient.go index 64e1cda..fc75154 100644 --- a/tls/tlsclient.go +++ b/tls/tlsclient.go @@ -2,6 +2,9 @@ package tls import ( "crypto/tls" + "errors" + "fmt" + "github.com/net-byte/vtun/common/xproto" "log" "net" "time" @@ -16,9 +19,9 @@ import ( ) // StartClient starts the tls client -func StartClient(iface *water.Interface, config config.Config) { +func StartClient(iFace *water.Interface, config config.Config) { log.Println("vtun tls client started") - go tunToTLS(config, iface) + go tunToTLS(config, iFace) tlsConfig := &tls.Config{ InsecureSkipVerify: config.TLSInsecureSkipVerify, MinVersion: tls.VersionTLS13, @@ -43,31 +46,42 @@ func StartClient(iface *water.Interface, config config.Config) { netutil.PrintErr(err, config.Verbose) continue } - cache.GetCache().Set("tlsconn", conn, 24*time.Hour) - tlsToTun(config, conn, iface) - cache.GetCache().Delete("tlsconn") + cache.GetCache().Set("tlsConn", conn, 24*time.Hour) + tlsToTun(config, conn, iFace) + cache.GetCache().Delete("tlsConn") } } // tunToTLS sends packets from tun to tls -func tunToTLS(config config.Config, iface *water.Interface) { - packet := make([]byte, config.BufferSize) +func tunToTLS(config config.Config, iFace *water.Interface) { + authKey := xproto.ParseAuthKeyFromString(config.Key) + buffer := make([]byte, config.BufferSize) for { - n, err := iface.Read(packet) + n, err := iFace.Read(buffer) if err != nil { netutil.PrintErr(err, config.Verbose) - break + continue } - if v, ok := cache.GetCache().Get("tlsconn"); ok { - b := packet[:n] + b := buffer[:n] + if v, ok := cache.GetCache().Get("tlsConn"); ok { if config.Obfs { b = cipher.XOR(b) } if config.Compress { b = snappy.Encode(nil, b) } - tlsconn := v.(net.Conn) - _, err = tlsconn.Write(b) + tlsConn := v.(net.Conn) + ph := &xproto.ClientSendPacketHeader{ + ProtocolVersion: xproto.ProtocolVersion, + Key: authKey, + Length: len(b), + } + _, err := tlsConn.Write(ph.Bytes()) + if err != nil { + netutil.PrintErr(err, config.Verbose) + continue + } + n, err := tlsConn.Write(b[:]) if err != nil { netutil.PrintErr(err, config.Verbose) continue @@ -78,13 +92,37 @@ func tunToTLS(config config.Config, iface *water.Interface) { } // tlsToTun sends packets from tls to tun -func tlsToTun(config config.Config, tlsconn net.Conn, iface *water.Interface) { - defer tlsconn.Close() - packet := make([]byte, config.BufferSize) - for { - n, err := tlsconn.Read(packet) +func tlsToTun(config config.Config, tlsConn net.Conn, iFace *water.Interface) { + defer func(tlsConn net.Conn) { + err := tlsConn.Close() if err != nil { netutil.PrintErr(err, config.Verbose) + } + }(tlsConn) + header := make([]byte, xproto.ServerSendPacketHeaderLength) + packet := make([]byte, config.BufferSize) + for { + n, err := tlsConn.Read(header) + if err != nil { + netutil.PrintErr(err, config.Verbose) + break + } + if n != xproto.ServerSendPacketHeaderLength { + netutil.PrintErr(errors.New(fmt.Sprintf("received length <%d> not equals <%d>!", n, xproto.ServerSendPacketHeaderLength)), config.Verbose) + break + } + ph := xproto.ParseServerSendPacketHeader(header[:n]) + if ph == nil { + netutil.PrintErr(errors.New("ph == nil"), config.Verbose) + break + } + n, err = tlsConn.Read(packet[:ph.Length]) + if err != nil { + netutil.PrintErr(err, config.Verbose) + break + } + if n != ph.Length { + netutil.PrintErr(errors.New(fmt.Sprintf("received length <%d> not equals <%d>!", n, ph.Length)), config.Verbose) break } b := packet[:n] @@ -98,7 +136,7 @@ func tlsToTun(config config.Config, tlsconn net.Conn, iface *water.Interface) { if config.Obfs { b = cipher.XOR(b) } - _, err = iface.Write(b) + n, err = iFace.Write(b) if err != nil { netutil.PrintErr(err, config.Verbose) break diff --git a/tls/tlsserver.go b/tls/tlsserver.go index 6d2ae2b..6c3730a 100644 --- a/tls/tlsserver.go +++ b/tls/tlsserver.go @@ -2,6 +2,9 @@ package tls import ( "crypto/tls" + "errors" + "fmt" + "github.com/net-byte/vtun/common/xproto" "log" "net" "time" @@ -16,7 +19,7 @@ import ( ) // StartServer starts the tls server -func StartServer(iface *water.Interface, config config.Config) { +func StartServer(iFace *water.Interface, config config.Config) { log.Printf("vtun tls server started on %v", config.LocalAddr) cert, err := tls.LoadX509KeyPair(config.TLSCertificateFilePath, config.TLSCertificateKeyFilePath) if err != nil { @@ -41,7 +44,7 @@ func StartServer(iface *water.Interface, config config.Config) { log.Panic(err) } // server -> client - go toClient(config, iface) + go toClient(config, iFace) // client -> server for { conn, err := ln.Accept() @@ -59,20 +62,20 @@ func StartServer(iface *water.Interface, config config.Config) { continue } } - go toServer(config, sniffConn, iface) + go toServer(config, sniffConn, iFace) } } -// toClient sends packets from iface to tlsconn -func toClient(config config.Config, iface *water.Interface) { - packet := make([]byte, config.BufferSize) +// toClient sends packets from iFace to tlsConn +func toClient(config config.Config, iFace *water.Interface) { + buffer := make([]byte, config.BufferSize) for { - n, err := iface.Read(packet) + n, err := iFace.Read(buffer) if err != nil { netutil.PrintErr(err, config.Verbose) continue } - b := packet[:n] + b := buffer[:n] if key := netutil.GetDstKey(b); key != "" { if v, ok := cache.GetCache().Get(key); ok { if config.Obfs { @@ -81,7 +84,16 @@ func toClient(config config.Config, iface *water.Interface) { if config.Compress { b = snappy.Encode(nil, b) } - _, err := v.(net.Conn).Write(b) + ph := &xproto.ServerSendPacketHeader{ + ProtocolVersion: xproto.ProtocolVersion, + Length: len(b), + } + _, err := v.(net.Conn).Write(ph.Bytes()) + if err != nil { + cache.GetCache().Delete(key) + continue + } + n, err := v.(net.Conn).Write(b[:]) if err != nil { cache.GetCache().Delete(key) continue @@ -92,14 +104,43 @@ func toClient(config config.Config, iface *water.Interface) { } } -// toServer sends packets from tlsconn to iface -func toServer(config config.Config, tlsconn net.Conn, iface *water.Interface) { - defer tlsconn.Close() - packet := make([]byte, config.BufferSize) - for { - n, err := tlsconn.Read(packet) +// toServer sends packets from tlsConn to iFace +func toServer(config config.Config, tlsConn net.Conn, iFace *water.Interface) { + defer func(tlsConn net.Conn) { + err := tlsConn.Close() if err != nil { netutil.PrintErr(err, config.Verbose) + } + }(tlsConn) + header := make([]byte, xproto.ClientSendPacketHeaderLength) + packet := make([]byte, config.BufferSize) + authKey := xproto.ParseAuthKeyFromString(config.Key) + for { + n, err := tlsConn.Read(header) + if err != nil { + netutil.PrintErr(err, config.Verbose) + break + } + if n != xproto.ClientSendPacketHeaderLength { + netutil.PrintErr(errors.New(fmt.Sprintf("received length <%d> not equals <%d>!", n, xproto.ClientSendPacketHeaderLength)), config.Verbose) + break + } + ph := xproto.ParseClientSendPacketHeader(header[:n]) + if ph == nil { + netutil.PrintErr(errors.New("ph == nil"), config.Verbose) + break + } + if !ph.Key.Equals(authKey) { + netutil.PrintErr(errors.New("authentication failed"), config.Verbose) + break + } + n, err = tlsConn.Read(packet[:ph.Length]) + if err != nil { + netutil.PrintErr(err, config.Verbose) + break + } + if n != ph.Length { + netutil.PrintErr(errors.New(fmt.Sprintf("received length <%d> not equals <%d>!", n, ph.Length)), config.Verbose) break } b := packet[:n] @@ -114,8 +155,12 @@ func toServer(config config.Config, tlsconn net.Conn, iface *water.Interface) { b = cipher.XOR(b) } if key := netutil.GetSrcKey(b); key != "" { - cache.GetCache().Set(key, tlsconn, 24*time.Hour) - iface.Write(b) + cache.GetCache().Set(key, tlsConn, 24*time.Hour) + n, err := iFace.Write(b) + if err != nil { + netutil.PrintErr(err, config.Verbose) + break + } counter.IncrReadBytes(n) } }