diff --git a/go.mod b/go.mod index c5908df..b9a01d8 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,12 @@ module tcp-proxy go 1.19 require ( - github.com/integrii/flaggy v1.5.2 // indirect + github.com/integrii/flaggy v1.5.2 + github.com/rs/zerolog v1.29.0 +) + +require ( github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.17 // indirect - github.com/rs/zerolog v1.29.0 // indirect golang.org/x/sys v0.4.0 // indirect ) diff --git a/go.sum b/go.sum index 7c90c60..54f80ef 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,6 @@ github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/integrii/flaggy v1.5.2 h1:bWV20MQEngo4hWhno3i5Z9ISPxLPKj9NOGNwTWb/8IQ= github.com/integrii/flaggy v1.5.2/go.mod h1:dO13u7SYuhk910nayCJ+s1DeAAGC1THCMj1uSFmwtQ8= diff --git a/main.go b/main.go index e98089a..74d1744 100644 --- a/main.go +++ b/main.go @@ -17,35 +17,113 @@ func main() { var ( listenAddr = ":8000" forwardAddr string + mode = "tcp" ) flaggy.String(&listenAddr, "l", "listen", "Address to listen on") flaggy.String(&forwardAddr, "f", "forward", "Address to forward to") + flaggy.String(&mode, "m", "mode", "Mode to use (tcp, udp)") flaggy.Parse() if forwardAddr == "" { log.Fatal().Msg("Forward address is required") } - ln, err := net.Listen("tcp", listenAddr) - if err != nil { - log.Panic().Err(err).Msg("Failed to listen") + if mode == "tcp" { + tcpListen(listenAddr, forwardAddr) + } else if mode == "udp" { + udpListen(listenAddr, forwardAddr) + } else { + log.Fatal().Msg("Invalid mode") } - defer ln.Close() - log.Info().Str("Listen Address", listenAddr).Msg("Listening") +} + +func tcpListen(listenAddr, forwardAddr string) { + var ( + ln net.Listener + err error + ) + + if ln, err = net.Listen("tcp", listenAddr); err != nil { + log.Panic().Err(err).Str("Mode", "TCP").Msg("Failed to listen") + } + + log.Info().Str("Mode", "TCP").Str("Listen Address", listenAddr).Msg("Listening") + + defer ln.Close() for { conn, err := ln.Accept() if err != nil { - log.Err(err).Msg("Failed to accept connection") + log.Err(err).Str("Mode", "TCP").Msg("Failed to accept connection") continue } - log.Info().Str("Incoming Address", conn.RemoteAddr().String()).Msg("Connection accepted") + log.Info().Str("Mode", "TCP").Str("Incoming Address", conn.RemoteAddr().String()).Msg("Connection accepted") go handleConnection(conn, forwardAddr) } +} +func udpListen(la, forwardAddr string) { + // Listen for incoming UDP packets on port 3000 + listenAddr, err := net.ResolveUDPAddr("udp", la) + if err != nil { + log.Err(err).Str("Mode", "UDP").Msg("Failed to resolve UDP listen address") + return + } + + listener, err := net.ListenUDP("udp", listenAddr) + if err != nil { + log.Err(err).Str("Mode", "UDP").Msg("Failed to listen on UDP port") + return + } + defer listener.Close() + + log.Info().Str("Mode", "UDP").Str("Listen Address", listenAddr.String()).Msg("Listening") + + // Forward incoming UDP packets to another host and port + remoteAddr, err := net.ResolveUDPAddr("udp", forwardAddr) + if err != nil { + log.Err(err).Str("Mode", "UDP").Msg("Failed to resolve remote address") + return + } + + var ( + callerAddr *net.UDPAddr + sendAddr *net.UDPAddr + ) + + buf := make([]byte, 2048) + + for { + + n, addr, err := listener.ReadFromUDP(buf) + if err != nil { + log.Err(err).Str("Mode", "UDP").Msg("Failed to read from UDP connection") + continue + } + + if callerAddr == nil { + callerAddr = addr + } + + if addr.String() == remoteAddr.String() { + sendAddr = callerAddr + } else { + sendAddr = remoteAddr + callerAddr = addr + } + + //log.Info().Str("Mode", "UDP").Str("Incoming Address", addr.String()).Str("Forward Address", sendAddr.String()).Msg("Packet received") + + _, err = listener.WriteToUDP(buf[:n], sendAddr) + if err != nil { + log.Err(err).Str("Mode", "udp").Msg("Failed to write to UDP connection") + continue + } + + } } func handleConnection(conn net.Conn, forwardAddr string) { @@ -53,7 +131,7 @@ func handleConnection(conn net.Conn, forwardAddr string) { fwrd, err := net.Dial("tcp", forwardAddr) if err != nil { - log.Err(err).Msg("Failed to connect to forward address") + log.Err(err).Str("Mode", "TCP").Msg("Failed to connect to forward address") return } @@ -67,7 +145,7 @@ func handleConnection(conn net.Conn, forwardAddr string) { _, err := io.Copy(fwrd, conn) if err != nil { - log.Err(err).Msg("Failed to copy from connection to forward address") + log.Err(err).Str("Mode", "TCP").Msg("Failed to copy from connection to forward address") } }() @@ -77,10 +155,10 @@ func handleConnection(conn net.Conn, forwardAddr string) { _, err := io.Copy(conn, fwrd) if err != nil { - log.Err(err).Msg("Failed to copy from forward address to connection") + log.Err(err).Str("Mode", "TCP").Msg("Failed to copy from forward address to connection") } }() wg.Wait() - log.Info().Str("Incoming Address", conn.RemoteAddr().String()).Msg("Connection closed") + log.Info().Str("Mode", "TCP").Str("Incoming Address", conn.RemoteAddr().String()).Msg("Connection closed") }