From 6c7564062ffe9cd73dd74335b900a6b55a066568 Mon Sep 17 00:00:00 2001 From: Nils Date: Tue, 7 Feb 2023 19:17:30 +0100 Subject: [PATCH] make things a bit less ugly --- README.md | 1 + main.go | 113 ++++++++++++++++++++++++++---------------------------- 2 files changed, 56 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index a217257..bb2dd28 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,7 @@ https://[address]/[command]/[host](_[port]),[host].../[options] - ping - mtr - traceroute + - nping - [host] = can be one or more hosts query, seperated by a comma - [port] = port to be queried, optional - [options] = options to run the command with, seperated by a comma diff --git a/main.go b/main.go index e7e307a..0f32766 100644 --- a/main.go +++ b/main.go @@ -13,78 +13,70 @@ import ( flag "github.com/spf13/pflag" ) -var logstdout = log.New() -var logfile = log.New() +var logStdout = log.New() +var logFile = log.New() -var listenport int -var disablexforwardedfor bool -var allowprivate bool +var listenPort = 8080 // port to listen on +var disableXForwardedFor bool // whether to disable parsing the X-Forwarded-For header or not +var allowPrivate bool // whether to allow private IP ranges or not func init() { - logstdout.SetFormatter(&log.TextFormatter{ + logStdout.SetFormatter(&log.TextFormatter{ FullTimestamp: true}) - logstdout.SetOutput(os.Stdout) - logstdout.SetLevel(log.InfoLevel) - var logfilepath string + logStdout.SetOutput(os.Stdout) + logStdout.SetLevel(log.InfoLevel) - if _, exists := os.LookupEnv("PROBEHOST_LOGPATH"); exists { - logfilepath, _ = os.LookupEnv("PROBEHOST_LOGPATH") - } else { - logfilepath = "probehost2.log" + logFilePath := "probehost2.log" + if val, exists := os.LookupEnv("PROBEHOST_LOGPATH"); exists { + logFilePath = val } - if exists, _ := os.LookupEnv("PROBEHOST_ALLOW_PRIVATE"); exists == "true" { - allowprivate = true - } else { - allowprivate = false - } - if envvalue, exists := os.LookupEnv("PROBEHOST_LISTEN_PORT"); exists { + + _, allowPrivate = os.LookupEnv("PROBEHOST_ALLOW_PRIVATE") + _, disableXForwardedFor = os.LookupEnv("PROBEHOST_DISABLE_X_FORWARDED_FOR") + + if val, exists := os.LookupEnv("PROBEHOST_LISTEN_PORT"); exists { var err error - listenport, err = strconv.Atoi(envvalue) + listenPort, err = strconv.Atoi(val) if err != nil { - logstdout.Fatal("Failed to read PROBEHOST_LISTEN_PORT: ", err.Error()) + logStdout.Fatal("Failed to read PROBEHOST_LISTEN_PORT: ", err.Error()) } - } else { - listenport = 8000 } - if exists, _ := os.LookupEnv("PROBEHOST_DISABLE_X_FORWARDED_FOR"); exists == "true" { - disablexforwardedfor = true - } else { - disablexforwardedfor = false - } - flag.StringVarP(&logfilepath, "logfilepath", "o", logfilepath, "sets the output file for the log") - flag.IntVarP(&listenport, "port", "p", listenport, "sets the port to listen on") - flag.BoolVarP(&disablexforwardedfor, "disable-x-forwarded-for", "x", disablexforwardedfor, "whether to show x-forwarded-for or the requesting IP") - flag.BoolVarP(&allowprivate, "allow-private", "l", allowprivate, "whether to show lookups of private IP ranges") + + flag.StringVarP(&logFilePath, "logFilePath", "o", logFilePath, "sets the output file for the log") + flag.IntVarP(&listenPort, "port", "p", listenPort, "sets the port to listen on") + flag.BoolVarP(&disableXForwardedFor, "disable-x-forwarded-for", "x", disableXForwardedFor, "whether to show x-forwarded-for or the requesting IP") + flag.BoolVarP(&allowPrivate, "allow-private", "l", allowPrivate, "whether to show lookups of private IP ranges") flag.Parse() - logpath, err := os.OpenFile(logfilepath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0660) + logpath, err := os.OpenFile(logFilePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0660) if err != nil { - logstdout.Fatal("Failed to initialize the logfile: ", err.Error()) + logStdout.Fatal("Failed to initialize the logFile: ", err.Error()) } - logfile.SetLevel(log.InfoLevel) - logfile.SetOutput(logpath) - logfile.Info("probehost2 initialized") + logFile.SetLevel(log.InfoLevel) + logFile.SetOutput(logpath) + logFile.Info("probehost2 initialized") } +// runner runs the given command with the given args and returns stdout as string. Also logs all executed commands and their exit state. func runner(remoteip string, command string, args ...string) string { - logfile.WithFields(log.Fields{ + logFile.WithFields(log.Fields{ "remote_ip": remoteip, "command": fmt.Sprint(command, args), }).Info("request initiated:") cmd, err := exec.Command(command, args...).Output() if err != nil { - logstdout.WithFields(log.Fields{ + logStdout.WithFields(log.Fields{ "remote_ip": remoteip, "command": fmt.Sprint(command, args), "error": err.Error(), }).Warn("request failed:") - logfile.WithFields(log.Fields{ + logFile.WithFields(log.Fields{ "remote_ip": remoteip, "command": fmt.Sprint(command, args), "error": err.Error(), }).Warn("request failed:") } else { - logfile.WithFields(log.Fields{ + logFile.WithFields(log.Fields{ "remote_ip": remoteip, "command": fmt.Sprint(command, args), }).Info("request succeeded:") @@ -92,20 +84,21 @@ func runner(remoteip string, command string, args ...string) string { return string(cmd) } +// validatehosts checks the given host+port combinations for validity and returns valid hosts + valid ports seperately. func validatehosts(hosts []string) ([]string, []string) { - var validhosts []string - var validports []string + var validHosts []string + var validPorts []string for _, host := range hosts { split := strings.Split(host, "_") host = split[0] if hostparse := net.ParseIP(host); hostparse != nil { - if (net.IP.IsPrivate(hostparse) || net.IP.IsLoopback(hostparse)) && allowprivate { - validhosts = append(validhosts, host) + if (net.IP.IsPrivate(hostparse) || net.IP.IsLoopback(hostparse)) && allowPrivate { + validHosts = append(validHosts, host) } else if !(net.IP.IsPrivate(hostparse) || net.IP.IsLoopback(hostparse)) { - validhosts = append(validhosts, host) + validHosts = append(validHosts, host) } } else if _, err := net.LookupIP(host); err == nil { - validhosts = append(validhosts, host) + validHosts = append(validHosts, host) } else { continue } @@ -115,17 +108,18 @@ func validatehosts(hosts []string) ([]string, []string) { port = split[1] _, err := strconv.Atoi(port) // validate if port is just an int if err == nil { - validports = append(validports, port) + validPorts = append(validPorts, port) } else { - validports = append(validports, "0") + validPorts = append(validPorts, "0") } } else { - validports = append(validports, "0") + validPorts = append(validPorts, "0") } } - return validhosts, validports + return validHosts, validPorts } +// parseopts matches the given user options to the valid optionmap. func parseopts(options []string, cmdopts map[string]string) []string { var opts []string for _, opt := range options { @@ -134,6 +128,7 @@ func parseopts(options []string, cmdopts map[string]string) []string { return opts } +// prerunner processes the incoming request to send it to runner. func prerunner(req *http.Request, cmd string, cmdopts map[string]string, defaultopts []string) string { geturl := strings.Split(req.URL.String(), "/") targets := strings.Split(geturl[2], ",") @@ -146,11 +141,9 @@ func prerunner(req *http.Request, cmd string, cmdopts map[string]string, default } var res string var args []string - var remoteaddr string - if req.Header.Get("X-Forwarded-For") != "" && !disablexforwardedfor { + remoteaddr := req.RemoteAddr + if req.Header.Get("X-Forwarded-For") != "" && !disableXForwardedFor { remoteaddr = req.Header.Get("X-Forwarded-For") - } else { - remoteaddr = req.RemoteAddr } for i, host := range hosts { runargs := append(args, opts...) @@ -163,6 +156,7 @@ func prerunner(req *http.Request, cmd string, cmdopts map[string]string, default return res } +// ping is the response handler for the ping command. It defines the allowed options. func ping(w http.ResponseWriter, req *http.Request) { cmd := "ping" cmdopts := map[string]string{ @@ -179,6 +173,7 @@ func ping(w http.ResponseWriter, req *http.Request) { } } +// mtr is the response handler for the mtr command. It defines the allowed options. func mtr(w http.ResponseWriter, req *http.Request) { cmd := "mtr" cmdopts := map[string]string{ @@ -195,6 +190,7 @@ func mtr(w http.ResponseWriter, req *http.Request) { } } +// traceroute is the response handler for the traceroute command. It defines the allowed options. func traceroute(w http.ResponseWriter, req *http.Request) { cmd := "traceroute" cmdopts := map[string]string{ @@ -211,6 +207,7 @@ func traceroute(w http.ResponseWriter, req *http.Request) { } } +// nping is the response handler for the nping command. It defines the allowed options. func nping(w http.ResponseWriter, req *http.Request) { cmd := "nping" cmdopts := map[string]string{ @@ -233,7 +230,7 @@ func main() { http.HandleFunc("/tracert/", traceroute) http.HandleFunc("/traceroute/", traceroute) http.HandleFunc("/nping/", nping) - logstdout.Info("Serving on :", listenport) - logfile.Info("Serving on :", listenport) - _ = http.ListenAndServe(fmt.Sprint(":", listenport), nil) + logStdout.Info("Serving on :", listenPort) + logFile.Info("Serving on :", listenPort) + _ = http.ListenAndServe(fmt.Sprint(":", listenPort), nil) }