From 1dd5789a72a654ab39f12525d3c3501e8ce706a1 Mon Sep 17 00:00:00 2001 From: database64128 Date: Wed, 4 Sep 2024 17:33:13 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=B9=20all:=20use=20github.com/database?= =?UTF-8?q?64128/netx-go=20to=20reduce=20linkname=20usage?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 2 + go.sum | 2 + netpoll_windows_checklinkname0.go | 15 ++--- netpoll_windows_go121.go | 14 ++-- netpoll_windows_go123_checklinkname0.go | 14 ++-- tfo.go | 11 ---- tfo_bsd+linux.go | 61 ++++++------------ tfo_connect_generic.go | 86 ++++++------------------- tfo_darwin.go | 4 +- tfo_linux.go | 7 +- tfo_listen_generic.go | 3 +- tfo_windows_checklinkname0.go | 83 +++++++++++++----------- 12 files changed, 116 insertions(+), 186 deletions(-) diff --git a/go.mod b/go.mod index dd79044..38fe3f1 100644 --- a/go.mod +++ b/go.mod @@ -3,3 +3,5 @@ module github.com/database64128/tfo-go/v2 go 1.21.0 require golang.org/x/sys v0.24.1-0.20240828075529-ed67b1566aaf + +require github.com/database64128/netx-go v0.0.0-20240904075656-1efc34d35e1a diff --git a/go.sum b/go.sum index 09b5eab..1d99d00 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,4 @@ +github.com/database64128/netx-go v0.0.0-20240904075656-1efc34d35e1a h1:EFUgNuSxsGOE6zP49HlhhXx+w0SZs6ADuqVEMaKSFFU= +github.com/database64128/netx-go v0.0.0-20240904075656-1efc34d35e1a/go.mod h1:uMBPfZT3hyBlp6X8qIToro7wX+zymQTMe1bxfqUsbIs= golang.org/x/sys v0.24.1-0.20240828075529-ed67b1566aaf h1:q2Cx0keWwW5HecyZeIyA3DCuupo8A/zjDqsOQK0+Z80= golang.org/x/sys v0.24.1-0.20240828075529-ed67b1566aaf/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/netpoll_windows_checklinkname0.go b/netpoll_windows_checklinkname0.go index be5f6ae..165eec5 100644 --- a/netpoll_windows_checklinkname0.go +++ b/netpoll_windows_checklinkname0.go @@ -5,13 +5,11 @@ package tfo import ( "net" "sync" - "syscall" "time" _ "unsafe" -) -//go:linkname sockaddrToTCP net.sockaddrToTCP -func sockaddrToTCP(sa syscall.Sockaddr) net.Addr + "golang.org/x/sys/windows" +) //go:linkname execIO internal/poll.execIO func execIO(o *operation, submit func(o *operation) error) (int, error) @@ -26,7 +24,7 @@ type pFD struct { fdmuW uint32 // System file descriptor. Immutable until Close. - Sysfd syscall.Handle + Sysfd windows.Handle // Read operation. rop operation @@ -65,10 +63,9 @@ type pFD struct { kind byte } -func (fd *pFD) ConnectEx(ra syscall.Sockaddr, b []byte) (n int, err error) { - fd.wop.sa = ra +func (fd *pFD) ConnectEx(ra windows.Sockaddr, b []byte) (n int, err error) { n, err = execIO(&fd.wop, func(o *operation) error { - return syscall.ConnectEx(o.fd.Sysfd, o.sa, &b[0], uint32(len(b)), &o.qty, &o.o) + return windows.ConnectEx(o.fd.Sysfd, ra, &b[0], uint32(len(b)), &o.qty, &o.o) }) return } @@ -89,7 +86,7 @@ type netFD struct { } //go:linkname newFD net.newFD -func newFD(sysfd syscall.Handle, family, sotype int, net string) (*netFD, error) +func newFD(sysfd windows.Handle, family, sotype int, net string) (*netFD, error) //go:linkname netFDInit net.(*netFD).init func netFDInit(fd *netFD) error diff --git a/netpoll_windows_go121.go b/netpoll_windows_go121.go index c7d392c..0ca50a4 100644 --- a/netpoll_windows_go121.go +++ b/netpoll_windows_go121.go @@ -3,8 +3,6 @@ package tfo import ( - "syscall" - "golang.org/x/sys/windows" ) @@ -14,7 +12,7 @@ import ( type operation struct { // Used by IOCP interface, it must be first field // of the struct, as our code rely on it. - o syscall.Overlapped + o windows.Overlapped // fields used by runtime.netpoll runtimeCtx uintptr @@ -24,12 +22,12 @@ type operation struct { // fields used only by net package fd *pFD - buf syscall.WSABuf + buf windows.WSABuf msg windows.WSAMsg - sa syscall.Sockaddr - rsa *syscall.RawSockaddrAny + sa windows.Sockaddr + rsa *windows.RawSockaddrAny rsan int32 - handle syscall.Handle + handle windows.Handle flags uint32 - bufs []syscall.WSABuf + bufs []windows.WSABuf } diff --git a/netpoll_windows_go123_checklinkname0.go b/netpoll_windows_go123_checklinkname0.go index 20d8fbe..b74f381 100644 --- a/netpoll_windows_go123_checklinkname0.go +++ b/netpoll_windows_go123_checklinkname0.go @@ -3,8 +3,6 @@ package tfo import ( - "syscall" - "golang.org/x/sys/windows" ) @@ -14,7 +12,7 @@ import ( type operation struct { // Used by IOCP interface, it must be first field // of the struct, as our code rely on it. - o syscall.Overlapped + o windows.Overlapped // fields used by runtime.netpoll runtimeCtx uintptr @@ -22,13 +20,13 @@ type operation struct { // fields used only by net package fd *pFD - buf syscall.WSABuf + buf windows.WSABuf msg windows.WSAMsg - sa syscall.Sockaddr - rsa *syscall.RawSockaddrAny + sa windows.Sockaddr + rsa *windows.RawSockaddrAny rsan int32 - handle syscall.Handle + handle windows.Handle flags uint32 qty uint32 - bufs []syscall.WSABuf + bufs []windows.WSABuf } diff --git a/tfo.go b/tfo.go index f420c18..793c6a0 100644 --- a/tfo.go +++ b/tfo.go @@ -13,9 +13,7 @@ import ( "context" "errors" "net" - "os" "sync/atomic" - "syscall" "time" ) @@ -228,15 +226,6 @@ func opAddr(a *net.TCPAddr) net.Addr { return a } -// wrapSyscallError takes an error and a syscall name. If the error is -// a syscall.Errno, it wraps it in a os.SyscallError using the syscall name. -func wrapSyscallError(name string, err error) error { - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError(name, err) - } - return err -} - // aLongTimeAgo is a non-zero time, far in the past, used for immediate deadlines. var aLongTimeAgo = time.Unix(0, 0) diff --git a/tfo_bsd+linux.go b/tfo_bsd+linux.go index 11a2241..6038129 100644 --- a/tfo_bsd+linux.go +++ b/tfo_bsd+linux.go @@ -9,6 +9,7 @@ import ( "os" "syscall" + "github.com/database64128/netx-go" "golang.org/x/sys/unix" ) @@ -34,19 +35,7 @@ func ctrlNetwork(network string, family int) string { } func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *net.TCPAddr, b []byte, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (*net.TCPConn, error) { - ltsa := (*tcpSockaddr)(laddr) - rtsa := (*tcpSockaddr)(raddr) - family, ipv6only := favoriteAddrFamily(network, ltsa, rtsa, "dial") - - lsa, err := ltsa.sockaddr(family) - if err != nil { - return nil, err - } - - rsa, err := rtsa.sockaddr(family) - if err != nil { - return nil, err - } + family, ipv6only := favoriteDialAddrFamily(network, laddr, raddr) fd, err := d.socket(family) if err != nil { @@ -55,18 +44,18 @@ func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *n if err = d.setIPv6Only(fd, family, ipv6only); err != nil { unix.Close(fd) - return nil, wrapSyscallError("setsockopt(IPV6_V6ONLY)", err) + return nil, os.NewSyscallError("setsockopt(IPV6_V6ONLY)", err) } if err = setNoDelay(fd, 1); err != nil { unix.Close(fd) - return nil, wrapSyscallError("setsockopt(TCP_NODELAY)", err) + return nil, os.NewSyscallError("setsockopt(TCP_NODELAY)", err) } if err = setTFODialerFromSocket(uintptr(fd)); err != nil { if !d.Fallback || !errors.Is(err, errors.ErrUnsupported) { unix.Close(fd) - return nil, wrapSyscallError("setsockopt("+setTFODialerFromSocketSockoptName+")", err) + return nil, os.NewSyscallError("setsockopt("+setTFODialerFromSocketSockoptName+")", err) } runtimeDialTFOSupport.storeNone() } @@ -87,7 +76,7 @@ func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *n if laddr != nil { if cErr := rawConn.Control(func(fd uintptr) { - err = syscall.Bind(int(fd), lsa) + err = unix.Bind(int(fd), unixSockaddrFromTCPAddr(laddr)) }); cErr != nil { return nil, cErr } @@ -96,18 +85,13 @@ func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *n } } - rusa, err := unixSockaddrFromSyscallSockaddr(rsa) - if err != nil { - return nil, err - } - var ( n int canFallback bool ) if err = connWriteFunc(ctx, f, func(f *os.File) (err error) { - n, canFallback, err = connect(rawConn, rusa, b) + n, canFallback, err = connect(rawConn, unixSockaddrFromTCPAddr(raddr), b) return err }); err != nil { if d.Fallback && canFallback { @@ -132,24 +116,21 @@ func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *n return c.(*net.TCPConn), err } -func unixSockaddrFromSyscallSockaddr(sa syscall.Sockaddr) (unix.Sockaddr, error) { - if sa == nil { - return nil, nil +func unixSockaddrFromTCPAddr(a *net.TCPAddr) unix.Sockaddr { + if a == nil { + return nil } - switch sa := sa.(type) { - case *syscall.SockaddrInet4: + if ip4 := a.IP.To4(); ip4 != nil { return &unix.SockaddrInet4{ - Port: sa.Port, - Addr: sa.Addr, - }, nil - case *syscall.SockaddrInet6: - return &unix.SockaddrInet6{ - Port: sa.Port, - ZoneId: sa.ZoneId, - Addr: sa.Addr, - }, nil - } - return nil, errors.New("unsupported sockaddr type") + Port: a.Port, + Addr: [4]byte(ip4), + } + } + return &unix.SockaddrInet6{ + Port: a.Port, + ZoneId: uint32(netx.ZoneCache.Index(a.Zone)), + Addr: [16]byte(a.IP), + } } func connect(rawConn syscall.RawConn, rsa unix.Sockaddr, b []byte) (n int, canFallback bool, err error) { @@ -187,7 +168,7 @@ func connect(rawConn syscall.RawConn, rsa unix.Sockaddr, b []byte) (n int, canFa func getSocketError(fd int, call string) error { nerr, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_ERROR) if err != nil { - return wrapSyscallError("getsockopt", err) + return os.NewSyscallError("getsockopt", err) } if nerr != 0 { return os.NewSyscallError(call, syscall.Errno(nerr)) diff --git a/tfo_connect_generic.go b/tfo_connect_generic.go index 16ddece..2fb18e5 100644 --- a/tfo_connect_generic.go +++ b/tfo_connect_generic.go @@ -26,84 +26,34 @@ func boolint(b bool) int { return 0 } -// A sockaddr represents a TCP, UDP, IP or Unix network endpoint -// address that can be converted into a syscall.Sockaddr. -// -// Copied from src/net/sockaddr_posix.go -type sockaddr interface { - net.Addr - - // family returns the platform-dependent address family - // identifier. - family() int - - // isWildcard reports whether the address is a wildcard - // address. - isWildcard() bool - - // sockaddr returns the address converted into a syscall - // sockaddr type that implements syscall.Sockaddr - // interface. It returns a nil interface when the address is - // nil. - sockaddr(family int) (syscall.Sockaddr, error) - - // toLocal maps the zero address to a local system address (127.0.0.1 or ::1) - toLocal(net string) sockaddr -} - -type tcpSockaddr net.TCPAddr - -func (a *tcpSockaddr) Network() string { - return "tcp" -} - -func (a *tcpSockaddr) String() string { - return (*net.TCPAddr)(a).String() -} - -// Copied from src/net/tcpsock_posix.go -func (a *tcpSockaddr) family() int { - if a == nil || len(a.IP) <= net.IPv4len { - return syscall.AF_INET - } - if a.IP.To4() != nil { - return syscall.AF_INET +// wrapSyscallError takes an error and a syscall name. If the error is +// a syscall.Errno, it wraps it in a os.SyscallError using the syscall name. +func wrapSyscallError(name string, err error) error { + if _, ok := err.(syscall.Errno); ok { + err = os.NewSyscallError(name, err) } - return syscall.AF_INET6 + return err } -// Copied from src/net/tcpsock_posix.go -func (a *tcpSockaddr) isWildcard() bool { - if a == nil || a.IP == nil { - return true +// Modified from favoriteAddrFamily in src/net/ipsock_posix.go +func favoriteDialAddrFamily(network string, laddr, raddr *net.TCPAddr) (family int, ipv6only bool) { + switch network { + case "tcp4": + return syscall.AF_INET, false + case "tcp6": + return syscall.AF_INET6, true } - return a.IP.IsUnspecified() -} -//go:linkname ipToSockaddr net.ipToSockaddr -func ipToSockaddr(family int, ip net.IP, port int, zone string) (syscall.Sockaddr, error) - -// Copied from src/net/tcpsock_posix.go -func (a *tcpSockaddr) sockaddr(family int) (syscall.Sockaddr, error) { - if a == nil { - return nil, nil + if tcpAddrIs4(laddr) || tcpAddrIs4(raddr) { + return syscall.AF_INET, false } - return ipToSockaddr(family, a.IP, a.Port, a.Zone) + return syscall.AF_INET6, false } -//go:linkname loopbackIP net.loopbackIP -func loopbackIP(net string) net.IP - -// Modified from src/net/tcpsock_posix.go -func (a *tcpSockaddr) toLocal(net string) sockaddr { - la := *a - la.IP = loopbackIP(net) - return &la +func tcpAddrIs4(a *net.TCPAddr) bool { + return a != nil && a.IP.To4() != nil } -//go:linkname favoriteAddrFamily net.favoriteAddrFamily -func favoriteAddrFamily(network string, laddr, raddr sockaddr, mode string) (family int, ipv6only bool) - func (d *Dialer) dialTFOFromSocket(ctx context.Context, network, address string, b []byte) (*net.TCPConn, error) { if ctx == nil { panic("nil context") diff --git a/tfo_darwin.go b/tfo_darwin.go index 7584d13..718ee1f 100644 --- a/tfo_darwin.go +++ b/tfo_darwin.go @@ -34,7 +34,7 @@ func (lc *ListenConfig) listenTFO(ctx context.Context, network, address string) if err != nil { if !lc.Fallback || !errors.Is(err, errors.ErrUnsupported) { - return wrapSyscallError("setsockopt(TCP_FASTOPEN_FORCE_ENABLE)", err) + return os.NewSyscallError("setsockopt(TCP_FASTOPEN_FORCE_ENABLE)", err) } runtimeListenNoTFO.Store(true) } @@ -62,7 +62,7 @@ func (lc *ListenConfig) listenTFO(ctx context.Context, network, address string) if err != nil { ln.Close() if !lc.Fallback || !errors.Is(err, errors.ErrUnsupported) { - return nil, wrapSyscallError("setsockopt(TCP_FASTOPEN)", err) + return nil, os.NewSyscallError("setsockopt(TCP_FASTOPEN)", err) } runtimeListenNoTFO.Store(true) } diff --git a/tfo_linux.go b/tfo_linux.go index fe4fff0..2d60eb6 100644 --- a/tfo_linux.go +++ b/tfo_linux.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "os" "syscall" "golang.org/x/sys/unix" @@ -11,7 +12,7 @@ import ( const setTFODialerFromSocketSockoptName = "unreachable" -func setTFODialerFromSocket(fd uintptr) error { +func setTFODialerFromSocket(_ uintptr) error { return nil } @@ -23,7 +24,7 @@ func doConnectCanFallback(err error) bool { // returns -EPIPE. This indicates that the MSG_FASTOPEN flag is not recognized by the kernel. // // -EOPNOTSUPP is returned if the kernel recognizes the flag, but TFO is disabled via sysctl. - return err == syscall.EPIPE || err == syscall.EOPNOTSUPP + return err == unix.EPIPE || err == unix.EOPNOTSUPP } func (a *atomicDialTFOSupport) casLinuxSendto() bool { @@ -66,7 +67,7 @@ func (d *Dialer) dialTFO(ctx context.Context, network, address string, b []byte) if d.Fallback && errors.Is(err, errors.ErrUnsupported) { canFallback = true } - return wrapSyscallError("setsockopt(TCP_FASTOPEN_CONNECT)", err) + return os.NewSyscallError("setsockopt(TCP_FASTOPEN_CONNECT)", err) } return nil } diff --git a/tfo_listen_generic.go b/tfo_listen_generic.go index 8277e5a..f695d3e 100644 --- a/tfo_listen_generic.go +++ b/tfo_listen_generic.go @@ -6,6 +6,7 @@ import ( "context" "errors" "net" + "os" "syscall" ) @@ -28,7 +29,7 @@ func (lc *ListenConfig) listenTFO(ctx context.Context, network, address string) if err != nil { if !lc.Fallback || !errors.Is(err, errors.ErrUnsupported) { - return wrapSyscallError("setsockopt(TCP_FASTOPEN)", err) + return os.NewSyscallError("setsockopt(TCP_FASTOPEN)", err) } runtimeListenNoTFO.Store(true) } diff --git a/tfo_windows_checklinkname0.go b/tfo_windows_checklinkname0.go index 1427aee..8857cae 100644 --- a/tfo_windows_checklinkname0.go +++ b/tfo_windows_checklinkname0.go @@ -11,6 +11,7 @@ import ( "syscall" "unsafe" + "github.com/database64128/netx-go" "golang.org/x/sys/windows" ) @@ -33,38 +34,14 @@ func setUpdateConnectContext(fd windows.Handle) error { } func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *net.TCPAddr, b []byte, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (*net.TCPConn, error) { - ltsa := (*tcpSockaddr)(laddr) - rtsa := (*tcpSockaddr)(raddr) - family, ipv6only := favoriteAddrFamily(network, ltsa, rtsa, "dial") - - var ( - ip net.IP - port int - zone string - ) - - if laddr != nil { - ip = laddr.IP - port = laddr.Port - zone = laddr.Zone - } - - lsa, err := ipToSockaddr(family, ip, port, zone) - if err != nil { - return nil, err - } - - rsa, err := rtsa.sockaddr(family) - if err != nil { - return nil, err - } + family, ipv6only := favoriteDialAddrFamily(network, laddr, raddr) handle, err := windows.WSASocket(int32(family), windows.SOCK_STREAM, windows.IPPROTO_TCP, nil, 0, windows.WSA_FLAG_OVERLAPPED|windows.WSA_FLAG_NO_HANDLE_INHERIT) if err != nil { return nil, os.NewSyscallError("WSASocket", err) } - fd, err := newFD(syscall.Handle(handle), family, windows.SOCK_STREAM, network) + fd, err := newFD(handle, family, windows.SOCK_STREAM, network) if err != nil { windows.Closesocket(handle) return nil, err @@ -72,18 +49,18 @@ func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *n if err = setIPv6Only(handle, family, ipv6only); err != nil { fd.Close() - return nil, wrapSyscallError("setsockopt(IPV6_V6ONLY)", err) + return nil, os.NewSyscallError("setsockopt(IPV6_V6ONLY)", err) } if err = setNoDelay(handle, 1); err != nil { fd.Close() - return nil, wrapSyscallError("setsockopt(TCP_NODELAY)", err) + return nil, os.NewSyscallError("setsockopt(TCP_NODELAY)", err) } if err = setTFODialer(uintptr(handle)); err != nil { if !d.Fallback || !errors.Is(err, errors.ErrUnsupported) { fd.Close() - return nil, wrapSyscallError("setsockopt(TCP_FASTOPEN)", err) + return nil, os.NewSyscallError("setsockopt(TCP_FASTOPEN)", err) } runtimeDialTFOSupport.storeNone() } @@ -95,7 +72,12 @@ func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *n } } - if err = syscall.Bind(syscall.Handle(handle), lsa); err != nil { + lsa := windowsSockaddrFromTCPAddr(laddr) + if lsa == nil { + lsa = &windows.SockaddrInet6{} + } + + if err = windows.Bind(handle, lsa); err != nil { fd.Close() return nil, wrapSyscallError("bind", err) } @@ -106,26 +88,28 @@ func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *n } if err = connWriteFunc(ctx, fd, func(fd *netFD) error { + rsa := windowsSockaddrFromTCPAddr(raddr) + n, err := fd.pfd.ConnectEx(rsa, b) if err != nil { - return os.NewSyscallError("connectex", err) + return wrapSyscallError("connectex", err) } if err = setUpdateConnectContext(handle); err != nil { - return wrapSyscallError("setsockopt(SO_UPDATE_CONNECT_CONTEXT)", err) + return os.NewSyscallError("setsockopt(SO_UPDATE_CONNECT_CONTEXT)", err) } - lsa, err = syscall.Getsockname(syscall.Handle(handle)) + lsa, err = windows.Getsockname(handle) if err != nil { return wrapSyscallError("getsockname", err) } - fd.laddr = sockaddrToTCP(lsa) + fd.laddr = tcpAddrFromWindowsSockaddr(lsa) - rsa, err = syscall.Getpeername(syscall.Handle(handle)) + rsa, err = windows.Getpeername(handle) if err != nil { return wrapSyscallError("getpeername", err) } - fd.raddr = sockaddrToTCP(rsa) + fd.raddr = tcpAddrFromWindowsSockaddr(rsa) if n < len(b) { if _, err = fd.Write(b[n:]); err != nil { @@ -142,3 +126,30 @@ func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *n runtime.SetFinalizer(fd, netFDClose) return (*net.TCPConn)(unsafe.Pointer(&fd)), nil } + +func windowsSockaddrFromTCPAddr(a *net.TCPAddr) windows.Sockaddr { + if a == nil { + return nil + } + if ip4 := a.IP.To4(); ip4 != nil { + return &windows.SockaddrInet4{ + Port: a.Port, + Addr: [4]byte(ip4), + } + } + return &windows.SockaddrInet6{ + Port: a.Port, + ZoneId: uint32(netx.ZoneCache.Index(a.Zone)), + Addr: [16]byte(a.IP), + } +} + +func tcpAddrFromWindowsSockaddr(sa windows.Sockaddr) *net.TCPAddr { + switch sa := sa.(type) { + case *windows.SockaddrInet4: + return &net.TCPAddr{IP: sa.Addr[0:], Port: sa.Port} + case *windows.SockaddrInet6: + return &net.TCPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: netx.ZoneCache.Name(int(sa.ZoneId))} + } + return nil +}