commit 845b66c535ac0d4d8c7cb83aa57109ff415b2b95 Author: kmainhq Date: Fri Apr 17 12:29:00 2026 +0000 Upload files to "/" diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..6d9d031 --- /dev/null +++ b/go.mod @@ -0,0 +1,13 @@ +module tunnel + +go 1.19 + +require ( + github.com/huin/goupnp v1.3.0 + golang.org/x/crypto v0.43.0 +) + +require ( + golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect + golang.org/x/sys v0.37.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..49dce53 --- /dev/null +++ b/go.sum @@ -0,0 +1,9 @@ +github.com/huin/goupnp v1.3.0 h1:UvLUlWDNpoUdYzb2TCn+MuTWtcjXKSza2n6CBdQ0xXc= +github.com/huin/goupnp v1.3.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8= +golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= +golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= +golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q= diff --git a/main.go b/main.go new file mode 100644 index 0000000..89bf822 --- /dev/null +++ b/main.go @@ -0,0 +1,1206 @@ +// main.go (Cloudflare + UPnP + per-user keepalive + passwords + graceful shutdown) +package main + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/binary" + "encoding/json" + "encoding/pem" + "flag" + "fmt" + "io" + "log" + mrand "math/rand" + "net" + "net/http" + "os" + "os/exec" + "os/signal" + "path/filepath" + "runtime" + "strings" + "sync" + "syscall" + "time" + "errors" + "strconv" + + "github.com/huin/goupnp/dcps/internetgateway1" + "golang.org/x/crypto/bcrypt" + "golang.org/x/crypto/ssh" +) + +var ( + serverRunning = true + upnpEnabled bool + upnpMu sync.Mutex + + portFilePath string + portFileMu sync.Mutex + httpClient = &http.Client{Timeout: 15 * time.Second} + + // session tracking for graceful shutdown + sessionsMu sync.Mutex + sessionCancels = map[string]context.CancelFunc{} // key: sessionID + sessionConns = map[string]*ssh.ServerConn{} // so we can force-close connections + + cfTunnelsMu sync.Mutex + cfTunnels = map[string]cfTunnelInfo{} +) + + +type cfTunnelInfo struct { + Name, Hostname, CfgPath, CredsPath string + Cmd *exec.Cmd +} + +// ---------------- USER KEEPALIVE + PASSWORD ---------------- +type UserKeepalive struct { + Keepalive int `json:"keepalive"` + Password string `json:"password,omitempty"` + CustomTunnel bool `json:"custom_tunnel,omitempty"` +} + +func loadUserKeepalives(path string) (map[string]UserKeepalive, error) { + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + empty := map[string]UserKeepalive{} + data, _ := json.MarshalIndent(empty, "", " ") + _ = os.WriteFile(path, data, 0644) + return empty, nil + } + return nil, err + } + var users map[string]UserKeepalive + if err := json.Unmarshal(data, &users); err != nil { + return nil, err + } + // ensure default password hash for any user with empty password + changed := false + for name, u := range users { + if strings.TrimSpace(u.Password) == "" { + hash, _ := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) + u.Password = string(hash) + users[name] = u + changed = true + } + } + if changed { + _ = saveUserKeepalives(path, users) + } + return users, nil +} + +func saveUserKeepalives(path string, users map[string]UserKeepalive) error { + data, _ := json.MarshalIndent(users, "", " ") + return os.WriteFile(path, data, 0644) +} + +// ---------------- PORT JSON MANAGEMENT ---------------- +func loadPorts() ([]int, error) { + portFileMu.Lock() + defer portFileMu.Unlock() + + data, err := os.ReadFile(portFilePath) + if err != nil { + if os.IsNotExist(err) { + return []int{}, nil + } + return nil, err + } + + var ports []int + if err := json.Unmarshal(data, &ports); err != nil { + return []int{}, nil + } + return ports, nil +} + +func savePorts(ports []int) error { + portFileMu.Lock() + defer portFileMu.Unlock() + + data, _ := json.MarshalIndent(ports, "", " ") + return os.WriteFile(portFilePath, data, 0644) +} + +func addPort(p int) { + ports, _ := loadPorts() + for _, existing := range ports { + if existing == p { + return + } + } + ports = append(ports, p) + _ = savePorts(ports) +} + +func removePort(p int) { + ports, _ := loadPorts() + newPorts := []int{} + for _, existing := range ports { + if existing != p { + newPorts = append(newPorts, existing) + } + } + _ = savePorts(newPorts) +} + +func clearAllPorts() { + _ = savePorts([]int{}) +} + +func cleanupOldPorts() { + ports, _ := loadPorts() + if len(ports) == 0 { + return + } + + log.Printf("[CLEANUP] Removing %d old open ports...", len(ports)) + for _, port := range ports { + addr := fmt.Sprintf("0.0.0.0:%d", port) + conn, err := net.DialTimeout("tcp", addr, 200*time.Millisecond) + if err == nil { + conn.Close() + log.Printf("[CLEANUP] Closed lingering listener on %d", port) + } + } + clearAllPorts() +} + +// ---------------- NETWORK HELPERS ---------------- +func GetPublicIP() (string, error) { + conn, err := net.Dial("udp", "8.8.8.8:80") + if err != nil { + return "", err + } + defer conn.Close() + localAddr := conn.LocalAddr().(*net.UDPAddr) + return localAddr.IP.String(), nil +} + +func getLocalIP() (string, error) { + if runtime.GOOS == "windows" { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "", err + } + for _, a := range addrs { + if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() && ipnet.IP.To4() != nil { + return ipnet.IP.String(), nil + } + } + return "127.0.0.1", nil + } + out, err := exec.Command("sh", "-c", `hostname -I | awk '{print $1}'`).Output() + if err != nil { + addrs, err2 := net.InterfaceAddrs() + if err2 != nil { + return "127.0.0.1", nil + } + for _, a := range addrs { + if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() && ipnet.IP.To4() != nil { + return ipnet.IP.String(), nil + } + } + return "127.0.0.1", nil + } + return strings.TrimSpace(string(out)), nil +} + +// ---------------- UPNP HELPERS ---------------- + +func tryInitUPnP() bool { + clients, _, err := internetgateway1.NewWANIPConnection1Clients() + if err != nil || len(clients) == 0 { + _, _, err2 := internetgateway1.NewWANPPPConnection1Clients() + if err2 != nil { + return false + } + return true + } + return true +} + +func addUPnPPortForward(localIP string, port int, lease time.Duration) (func(), error) { + upnpMu.Lock() + enabled := upnpEnabled + upnpMu.Unlock() + if !enabled { + return func() {}, nil + } + wanIPClients, _, errIP := internetgateway1.NewWANIPConnection1Clients() + wanPPPClients, _, errPPP := internetgateway1.NewWANPPPConnection1Clients() + if (errIP != nil || len(wanIPClients) == 0) && (errPPP != nil || len(wanPPPClients) == 0) { + upnpMu.Lock() + upnpEnabled = false + upnpMu.Unlock() + return func() {}, fmt.Errorf("no IGD clients discovered") + } + leaseSeconds := uint32(0) + if lease > 0 { + leaseSeconds = uint32(lease.Seconds()) + } + tryClient := func(c interface{}) (string, error) { + switch cl := c.(type) { + case *internetgateway1.WANIPConnection1: + ip, _ := cl.GetExternalIPAddress() + err := cl.AddPortMapping("", uint16(port), "TCP", uint16(port), localIP, true, "ssh-tunnel", leaseSeconds) + return ip, err + case *internetgateway1.WANPPPConnection1: + ip, _ := cl.GetExternalIPAddress() + err := cl.AddPortMapping("", uint16(port), "TCP", uint16(port), localIP, true, "ssh-tunnel", leaseSeconds) + return ip, err + default: + return "", fmt.Errorf("unsupported client type") + } + } + for _, c := range wanIPClients { + ip, err := tryClient(c) + if err == nil { + cleanup := func() { _ = c.DeletePortMapping("", uint16(port), "TCP") } + log.Printf("[UPnP] Mapped external port %d -> %s:%d (external IP %s)", port, localIP, port, ip) + return cleanup, nil + } + } + for _, c := range wanPPPClients { + ip, err := tryClient(c) + if err == nil { + cleanup := func() { _ = c.DeletePortMapping("", uint16(port), "TCP") } + log.Printf("[UPnP] Mapped external port %d -> %s:%d (external IP %s)", port, localIP, port, ip) + return cleanup, nil + } + } + upnpMu.Lock() + upnpEnabled = false + upnpMu.Unlock() + return func() {}, fmt.Errorf("UPnP mapping failed on all discovered IGD clients") +} + +// ---------------- LOG HELPERS ---------------- +func sendLogToClient(channel ssh.Channel, msg string) error { + _, err := channel.Write([]byte(fmt.Sprintf("[LOG] %s\n", msg))) + return err +} +func sendWarnToClient(channel ssh.Channel, msg string) error { + _, err := channel.Write([]byte(fmt.Sprintf("[WARN] %s\n", msg))) + return err +} +func sendErrorToClient(channel ssh.Channel, msg string) error { + _, err := channel.Write([]byte(fmt.Sprintf("[ERR] %s\n", msg))) + return err +} + +// ---------------- SSH ENV ARGS PARSING ---------------- +func parseSSHString(b []byte) (string, []byte, error) { + if len(b) < 4 { + return "", nil, fmt.Errorf("payload too short") + } + l := int(binary.BigEndian.Uint32(b[:4])) + if len(b) < 4+l { + return "", nil, fmt.Errorf("payload shorter than length") + } + return string(b[4 : 4+l]), b[4+l:], nil +} + +func parseArgsEnv(payload []byte) map[string]string { + name, rest, err := parseSSHString(payload) + if err != nil || !strings.EqualFold(name, "ARGS") { + return nil + } + + val, _, err := parseSSHString(rest) + if err != nil { + return nil + } + + val = strings.Trim(val, `"'`) + argsMap := make(map[string]string) + pairs := strings.Split(val, ";") + for _, pair := range pairs { + pair = strings.TrimSpace(pair) + if pair == "" { + continue + } + parts := strings.SplitN(pair, "=", 2) + if len(parts) != 2 { + continue + } + key := strings.ToLower(strings.TrimSpace(parts[0])) + value := strings.TrimSpace(parts[1]) + if key != "" && value != "" { + argsMap[key] = value + } + } + return argsMap +} + +// ---------------- KEY HELPERS ---------------- +func generateKey() ([]byte, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + return privateKeyPEM, nil +} +func loadOrCreateKey(pemPath string) ([]byte, error) { + if pemPath != "" { + return os.ReadFile(pemPath) + } + return generateKey() +} + +// ---------------- Cloudflared helpers ---------------- + +// randomAlphaNum12 returns 12-char lowercase a-z0-9 string +func randomAlphaNum12() (string, error) { + b := make([]byte, 9) + if _, err := rand.Read(b); err != nil { + return "", err + } + chars := "abcdefghijklmnopqrstuvwxyz0123456789" + out := make([]byte, 12) + for i := range out { + out[i] = chars[int(b[i%len(b)])%len(chars)] + } + return string(out), nil +} + +func fileExists(filePath string) bool { + _, err := os.Stat(filePath) + return !errors.Is(err, os.ErrNotExist) +} + +// startCloudflaredForPort: now takes baseDir (directory where credentials+cfg+ports.json are stored). +// Creates credentials file under baseDir, temp config under baseDir, starts tunnel process and optionally creates dns record. +func startCloudflaredForPort(ctx context.Context, baseDir, cfCertPath, cfHost string, localPort int, protocol string, id string) (tunnelName, hostname string, cmdRun *exec.Cmd, cfgPath string, credsPath string, err error) { + if id == "" { + newid, err := randomAlphaNum12() + if err != nil { + return "", "", nil, "", "", err + } + id = newid + } + + tunnelName = "ssh-" + id + hostname = fmt.Sprintf("%s.%s", id, cfHost) + + // ensure baseDir exists + if err := os.MkdirAll(baseDir, 0700); err != nil { + return "", "", nil, "", "", fmt.Errorf("failed to create base dir: %v", err) + } + + // credentials path: baseDir/.json + credsPath = filepath.Join(baseDir, fmt.Sprintf("%s.json", tunnelName)) + + //if fileExists(credsPath) { + // return "", "", nil, "", "", fmt.Errorf("Tunnel exists already!") + //} + + // 1) create tunnel with specified credentials file + cmdCreate := exec.CommandContext(ctx, "cloudflared", "tunnel", "--credentials-file", credsPath, "create", tunnelName) + out, err := cmdCreate.CombinedOutput() + if err != nil { + // include output for debugging + return "", "", nil, "", "", fmt.Errorf("cloudflared tunnel create failed: %v: %s", err, string(out)) + } + + // 2) write config YAML into baseDir + cfgPath = filepath.Join(baseDir, fmt.Sprintf("%s.yml", tunnelName)) + service := fmt.Sprintf("%s://localhost:%d", protocol, localPort) + yaml := fmt.Sprintf(`tunnel: %s +credentials-file: %s + +ingress: + - hostname: %s + service: %s + - service: http_status:404 +`, tunnelName, credsPath, hostname, service) + + if err := os.WriteFile(cfgPath, []byte(yaml), 0600); err != nil { + return "", "", nil, "", "", fmt.Errorf("failed writing cfg: %v", err) + } + + // 3) run tunnel + cmdRun = exec.CommandContext(ctx, "cloudflared", "--config", cfgPath, "tunnel", "run", tunnelName) + cmdRun.Stdout = nil + cmdRun.Stderr = nil + if err := cmdRun.Start(); err != nil { + return "", "", nil, "", "", fmt.Errorf("cloudflared run failed: %v", err) + } + + // 4) create DNS route (best-effort) + cmdDNS := exec.Command("cloudflared", "tunnel", "route", "dns", tunnelName, hostname) + if out2, err2 := cmdDNS.CombinedOutput(); err2 != nil { + log.Printf("[CF] route dns returned error (continuing): %v — output: %s", err2, string(out2)) + } + + return tunnelName, hostname, cmdRun, cfgPath, credsPath, nil +} + +// ---------------- Cloudflare API helpers ---------------- + +// getZoneID uses Global API Key auth (X-Auth-Email + X-Auth-Key) to fetch the zone id for zoneName. +func getZoneID(ctx context.Context, authEmail, authKey, zoneName string) (string, error) { + req, err := http.NewRequestWithContext(ctx, "GET", "https://api.cloudflare.com/client/v4/zones?name="+zoneName, nil) + if err != nil { + return "", err + } + req.Header.Set("X-Auth-Email", strings.TrimSpace(authEmail)) + req.Header.Set("X-Auth-Key", strings.TrimSpace(authKey)) + req.Header.Set("Content-Type", "application/json") + + resp, err := httpClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return "", fmt.Errorf("getZoneID: non-2xx status %d: %s", resp.StatusCode, string(body)) + } + + var parsed struct { + Success bool `json:"success"` + Errors json.RawMessage `json:"errors"` + Result []struct { + ID string `json:"id"` + } `json:"result"` + } + if err := json.Unmarshal(body, &parsed); err != nil { + return "", fmt.Errorf("getZoneID: decode error: %w (body: %s)", err, string(body)) + } + if !parsed.Success || len(parsed.Result) == 0 { + return "", fmt.Errorf("getZoneID: no zone found or api failure: %s", string(parsed.Errors)) + } + return parsed.Result[0].ID, nil +} + +// getDNSRecordID finds a DNS record ID by name in a zone +func getDNSRecordID(ctx context.Context, authEmail, authKey, zoneID, recordName string) (string, error) { + url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records?type=CNAME&name=%s", zoneID, recordName) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return "", err + } + req.Header.Set("X-Auth-Email", strings.TrimSpace(authEmail)) + req.Header.Set("X-Auth-Key", strings.TrimSpace(authKey)) + req.Header.Set("Content-Type", "application/json") + + resp, err := httpClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return "", fmt.Errorf("getDNSRecordID: non-2xx status %d: %s", resp.StatusCode, string(body)) + } + + var parsed struct { + Success bool `json:"success"` + Errors json.RawMessage `json:"errors"` + Result []struct { + ID string `json:"id"` + Name string `json:"name"` + } `json:"result"` + } + if err := json.Unmarshal(body, &parsed); err != nil { + return "", fmt.Errorf("getDNSRecordID: decode error: %w (body: %s)", err, string(body)) + } + if !parsed.Success { + return "", fmt.Errorf("getDNSRecordID: api error: %s", string(parsed.Errors)) + } + if len(parsed.Result) == 0 { + return "", fmt.Errorf("getDNSRecordID: record %s not found in zone %s", recordName, zoneID) + } + return parsed.Result[0].ID, nil +} + +// deleteDNSRecord deletes a DNS record by ID +func deleteDNSRecord(ctx context.Context, authEmail, authKey, zoneID, recordID string) error { + url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", zoneID, recordID) + req, err := http.NewRequestWithContext(ctx, "DELETE", url, nil) + if err != nil { + return err + } + req.Header.Set("X-Auth-Email", strings.TrimSpace(authEmail)) + req.Header.Set("X-Auth-Key", strings.TrimSpace(authKey)) + req.Header.Set("Content-Type", "application/json") + resp, err := httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("deleteDNSRecord: non-2xx status %d: %s", resp.StatusCode, string(body)) + } + var parsed struct { + Success bool `json:"success"` + Errors json.RawMessage `json:"errors"` + } + if err := json.Unmarshal(body, &parsed); err != nil { + return fmt.Errorf("deleteDNSRecord: decode error: %w (body: %s)", err, string(body)) + } + if !parsed.Success { + return fmt.Errorf("deleteDNSRecord: api error: %s", string(parsed.Errors)) + } + return nil +} + +// ---------------- stopCloudflared (clean-up) ---------------- + +// stopCloudflared: stops process, deletes DNS via API (if creds provided), deletes tunnel object, and removes files +func stopCloudflared(tunnelName, hostname, cfgPath, credsPath string, cmdRun *exec.Cmd, cfEmail, cfGlobalKey, cfHost string) { + log.Printf("[CLEANUP] Stopping Cloudflare tunnel %s (%s)...", tunnelName, hostname) + + // --- 1) Kill process if still running --- + if cmdRun != nil && cmdRun.Process != nil { + _ = cmdRun.Process.Kill() + done := make(chan error, 1) + go func() { done <- cmdRun.Wait() }() + select { + case <-done: + log.Printf("[CLEANUP] Process for %s exited", tunnelName) + case <-time.After(5 * time.Second): + log.Printf("[CLEANUP] Timeout waiting for process %s to exit", tunnelName) + } + } else { + log.Printf("[CLEANUP] No running process for %s", tunnelName) + } + + // --- 2) Always remove config/creds files first --- + removeFile := func(path string, label string) { + if path == "" { + return + } + if err := os.Remove(path); err != nil { + if !os.IsNotExist(err) { + log.Printf("[CLEANUP] Failed to remove %s (%s): %v", label, path, err) + } + } else { + log.Printf("[CLEANUP] Deleted %s file: %s", label, path) + } + } + + removeFile(cfgPath, "config") + removeFile(credsPath, "credentials") + + // --- 3) Derive root zone (strip subdomains) --- + parts := strings.Split(cfHost, ".") + zoneName := cfHost + if len(parts) > 2 { + zoneName = strings.Join(parts[len(parts)-2:], ".") // e.g. tunnels.linuxdummy.win → linuxdummy.win + } + + // --- 4) Delete DNS record via Cloudflare API (best effort) --- + if cfEmail != "" && cfGlobalKey != "" && zoneName != "" { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + zoneID, err := getZoneID(ctx, cfEmail, cfGlobalKey, zoneName) + if err != nil { + log.Printf("[CF-API] getZoneID failed: %v", err) + } else { + recID, err2 := getDNSRecordID(ctx, cfEmail, cfGlobalKey, zoneID, hostname) + if err2 != nil { + log.Printf("[CF-API] getDNSRecordID failed (likely already removed): %v", err2) + } else if err3 := deleteDNSRecord(ctx, cfEmail, cfGlobalKey, zoneID, recID); err3 != nil { + log.Printf("[CF-API] deleteDNSRecord failed: %v", err3) + } else { + log.Printf("[CF-API] Deleted DNS record %s (id=%s)", hostname, recID) + } + } + } else { + log.Printf("[CF-API] CF credentials or zone not provided; skipping DNS cleanup for %s", hostname) + } + + // --- 5) Delete tunnel object via cloudflared CLI (optional) --- + if tunnelName != "" { + if err := exec.Command("cloudflared", "tunnel", "delete", "-f", tunnelName).Run(); err != nil { + log.Printf("[CLEANUP] cloudflared tunnel delete failed for %s: %v", tunnelName, err) + } + } + + log.Printf("[CLEANUP] Tunnel %s fully cleaned up.", tunnelName) +} + +// ---------------- SSH ENV ARGS PARSING & HANDLERS ---------------- + +// handleForwardRequest now accepts ctx and baseDir so tunnels/files are created inside baseDir and can be canceled by ctx. +func handleForwardRequest(ctx context.Context, req *ssh.Request, conn *ssh.ServerConn, tunnelTimeout time.Duration, portStart, portEnd int, sessionChannel ssh.Channel, host *string, cfEnabled bool, baseDir, cfCertPath, cfHostname, cfEmail, cfGlobalKey string, extraArgs map[string]string, customTunnelN bool) { + type ForwardPayload struct { + Addr string + Port uint32 + } + var payload ForwardPayload + if err := ssh.Unmarshal(req.Payload, &payload); err != nil { + req.Reply(false, nil) + return + } + + var pubListener net.Listener + var publicPort int + + if customTunnelN { + if p, ok := extraArgs["port"]; ok && p != "" { + publicPort, _ = strconv.Atoi(p) + } else { + for i := 0; i < 10; i++ { + if p, ok := extraArgs["port"]; ok && p != "" { + publicPort, _ = strconv.Atoi(p) + break + } + time.Sleep(50 * time.Millisecond) + } + } + } else { + publicPort = mrand.Intn(portEnd-portStart+1) + portStart + } + + for { + addr := fmt.Sprintf("0.0.0.0:%d", publicPort) + l, err := net.Listen("tcp", addr) + if err != nil { + publicPort = mrand.Intn(portEnd-portStart+1) + portStart + continue + } + pubListener = l + break + } + + addPort(publicPort) + + resp := struct{ Port uint32 }{Port: uint32(publicPort)} + if err := req.Reply(true, ssh.Marshal(resp)); err != nil { + pubListener.Close() + removePort(publicPort) + return + } + + var cleanupMu sync.Mutex + cleanup := func() {} + done := make(chan struct{}) + + // ensure cleanup when ctx is done, or when done channel closes + go func() { + localIP, _ := getLocalIP() + c, err := addUPnPPortForward(localIP, publicPort, tunnelTimeout) + if err == nil { + cleanupMu.Lock() + cleanup = c + cleanupMu.Unlock() + } + select { + case <-done: + cleanupMu.Lock() + cleanup() + cleanupMu.Unlock() + case <-ctx.Done(): + // session ended by server or SSH connection closed -> close listener to force teardown + _ = pubListener.Close() + cleanupMu.Lock() + cleanup() + cleanupMu.Unlock() + } + }() + + // Cloudflared per-forward + var cfName, cfHostName, cfgPath, credsPath string + var cfCmd *exec.Cmd + if cfEnabled { + protocol := "tcp" + name := "" + if extraArgs != nil { + if p, ok := extraArgs["protocol"]; ok && p != "" { + protocol = p + } else { + for i := 0; i < 10; i++ { + if p, ok := extraArgs["protocol"]; ok && p != "" { + protocol = p + break + } + time.Sleep(50 * time.Millisecond) + } + } + if customTunnelN { + if n, ok := extraArgs["tunnelname"]; ok && n != "" { + name = n + } else { + for i := 0; i < 10; i++ { + if n, ok := extraArgs["tunnelname"]; ok && n != "" { + name = n + break + } + time.Sleep(50 * time.Millisecond) + } + } + } + } + tn, hn, cmdRun, cfg, creds, err := startCloudflaredForPort(ctx, baseDir, cfCertPath, cfHostname, publicPort, protocol, name) + if err != nil { + msg := fmt.Sprintf("Tunnel error on port %d: %v, exiting...", publicPort, err) + log.Println("[CF]", msg) + _ = sendErrorToClient(sessionChannel, msg) + req.Reply(false, nil) + _ = pubListener.Close() + removePort(publicPort) + close(done) + sessionChannel.Close() + return + } + cfName = tn + cfHostName = hn + cfgPath = cfg + credsPath = creds + cfCmd = cmdRun + _ = sendLogToClient(sessionChannel, fmt.Sprintf("Tunnel ready: %s", cfHostName)) + } + + // Register the new Cloudflare tunnel for cleanup tracking + cfTunnelsMu.Lock() + cfTunnels[cfName] = cfTunnelInfo{ + Name: cfName, Hostname: cfHostName, + CfgPath: cfgPath, CredsPath: credsPath, Cmd: cfCmd, + } + cfTunnelsMu.Unlock() + + log.Printf("Created public listener %s:%d", *host, publicPort) + + if tunnelTimeout > 0 { + go func() { + select { + case <-time.After(tunnelTimeout): + pubListener.Close() + case <-done: + case <-ctx.Done(): + } + }() + } + + // This goroutine watches the session channel to detect session-level EOF and close the listener. + go func() { + buf := make([]byte, 512) + for { + if _, err := sessionChannel.Read(buf); err != nil { + _ = pubListener.Close() + return + } + } + }() + + // Main accept loop. When pubListener closes, Accept returns error and we reach cleanup. + go func() { + defer func() { + select { + case <-done: + default: + close(done) + } + cleanupMu.Lock() + cleanup() + cleanupMu.Unlock() + if cfName != "" { + // attempt to stop the cloudflared tunnel and remove DNS + files + stopCloudflared(cfName, cfHostName, cfgPath, credsPath, cfCmd, cfEmail, cfGlobalKey, cfHostname) + } + removePort(publicPort) + _ = sendWarnToClient(sessionChannel, fmt.Sprintf("Tunnel closed (port %d)", publicPort)) + }() + + for { + publicConn, err := pubListener.Accept() + if err != nil { + return + } + go func(pc net.Conn) { + ch, requests, err := conn.OpenChannel("forwarded-tcpip", ssh.Marshal(struct { + Addr string + Port uint32 + OriginAddr string + OriginPort uint32 + }{ + Addr: payload.Addr, + Port: payload.Port, + OriginAddr: pc.RemoteAddr().(*net.TCPAddr).IP.String(), + OriginPort: uint32(pc.RemoteAddr().(*net.TCPAddr).Port), + })) + if err != nil { + pc.Close() + return + } + go ssh.DiscardRequests(requests) + go io.Copy(ch, pc) + go io.Copy(pc, ch) + }(publicConn) + } + }() +} + +// handleSSH accepts an ssh connection and processes channels/requests +// now takes baseDir so that per-session tunnels use that directory and can be cleaned up on session end. +func handleSSH(nConn net.Conn, config *ssh.ServerConfig, tunnelTimeout time.Duration, portStart, portEnd int, host *string, cfEnabled bool, baseDir, cfCertPath, cfHostname, cfEmail, cfGlobalKey string, usersFile string) { + sshConn, chans, reqs, err := ssh.NewServerConn(nConn, config) + if err != nil { + log.Printf("SSH handshake failed: %v", err) + return + } + // register session for global shutdown handling + sessionID := fmt.Sprintf("%s-%d", sshConn.RemoteAddr().String(), time.Now().UnixNano()) + sessionsMu.Lock() + sessionConns[sessionID] = sshConn + ctx, cancel := context.WithCancel(context.Background()) + sessionCancels[sessionID] = cancel + sessionsMu.Unlock() + + defer func() { + sessionsMu.Lock() + delete(sessionConns, sessionID) + delete(sessionCancels, sessionID) + sessionsMu.Unlock() + sshConn.Close() + }() + + log.Printf("New connection from %s (user: %s)", sshConn.RemoteAddr(), sshConn.User()) + + userKeepalives, _ := loadUserKeepalives(usersFile) + username := sshConn.User() + customTunnelN := false + if username != "" { + if uk, ok := userKeepalives[username]; ok { + tunnelTimeout = time.Duration(uk.Keepalive) * time.Second + customTunnelN = uk.CustomTunnel + } + } + + var sessionChan ssh.Channel + var sessionReqs <-chan *ssh.Request + extraArgs := make(map[string]string) + + for newCh := range chans { + if newCh.ChannelType() != "session" { + newCh.Reject(ssh.UnknownChannelType, "only session supported") + continue + } + ch, reqs2, err := newCh.Accept() + if err != nil { + return + } + sessionChan = ch + sessionReqs = reqs2 + break + } + if sessionChan == nil { + cancel() + return + } + + // --- PASSWORD PROMPT --- + if username != "" { + if uk, ok := userKeepalives[username]; ok { + _, _ = sessionChan.Write([]byte("Password: ")) + + var passBuf []byte + tmp := make([]byte, 1) + for { + n, err := sessionChan.Read(tmp) + if err != nil || n == 0 { + _ = sessionChan.Close() + cancel() + return + } + if tmp[0] == '\n' || tmp[0] == '\r' { + break + } + passBuf = append(passBuf, tmp[0]) + } + passInput := strings.TrimSpace(string(passBuf)) + if bcrypt.CompareHashAndPassword([]byte(uk.Password), []byte(passInput)) != nil { + _, _ = sessionChan.Write([]byte("Invalid password. Disconnecting.\n")) + _ = sessionChan.Close() + cancel() + return + } + sendLogToClient(sessionChan, "Authenticated.") + } + } + + go func() { + for req := range sessionReqs { + switch req.Type { + case "shell", "exec": + req.Reply(true, nil) + case "env": + if args := parseArgsEnv(req.Payload); args != nil { + for k, v := range args { + extraArgs[k] = v + } + } + req.Reply(true, nil) + default: + req.Reply(false, nil) + } + } + }() + + // --- SESSION INPUT LOOP (password change / exit) --- + go func() { + buf := make([]byte, 1) + var lineBuf []byte + for { + n, err := sessionChan.Read(buf) + if err != nil || n == 0 { + cancel() + return + } + b := buf[0] + if b == '\r' || b == '\n' { + line := strings.TrimSpace(string(lineBuf)) + lineBuf = nil + + // .setpass command + if strings.HasPrefix(line, ".setpass ") && username != "" { + newPass := strings.TrimSpace(strings.TrimPrefix(line, ".setpass ")) + if newPass == "" { + _ = sendLogToClient(sessionChan, "Usage: .setpass ") + continue + } + users, err := loadUserKeepalives(usersFile) + if err != nil { + _ = sendWarnToClient(sessionChan, "failed to load users file") + continue + } + if _, ok := users[username]; !ok { + _ = sendWarnToClient(sessionChan, "username not found in users.json; cannot set password") + continue + } + hash, _ := bcrypt.GenerateFromPassword([]byte(newPass), bcrypt.DefaultCost) + u := users[username] + u.Password = string(hash) + users[username] = u + if err := saveUserKeepalives(usersFile, users); err != nil { + _ = sendWarnToClient(sessionChan, "failed to save new password") + continue + } + _ = sendLogToClient(sessionChan, "Password updated.") + } + + if line == ".exit" { + _ = sessionChan.Close() + cancel() + return + } + continue + } + lineBuf = append(lineBuf, b) + } + }() + + if tunnelTimeout > 0 { + h := int(tunnelTimeout.Hours()) + m := int(tunnelTimeout.Minutes()) % 60 + s := int(tunnelTimeout.Seconds()) % 60 + _ = sendLogToClient(sessionChan, fmt.Sprintf("Tunnel will be active for %02d:%02d:%02d (hh:mm:ss).", h, m, s)) + } else { + _ = sendLogToClient(sessionChan, "Tunnel has no time limit and will stay active until disconnect.") + } + + go ssh.DiscardRequests(reqs) + + for req := range reqs { + if req.Type == "tcpip-forward" { + go handleForwardRequest(ctx, req, sshConn, tunnelTimeout, portStart, portEnd, sessionChan, host, cfEnabled, baseDir, cfCertPath, cfHostname, cfEmail, cfGlobalKey, extraArgs, customTunnelN) + } else { + req.Reply(false, nil) + } + } + + cancel() +} + +// ---------------- SERVER ---------------- + +func startServerCloudflared(privateBytes []byte, tunnelTimeout time.Duration, portStart, portEnd int, host *string, cfEnabled bool, baseDir, cfCertPath, cfHostname, cfEmail, cfGlobalKey, usersFile string) error { + signer, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + return err + } + + config := &ssh.ServerConfig{NoClientAuth: true} + config.AddHostKey(signer) + + listener, err := net.Listen("tcp", ":2222") + if err != nil { + return fmt.Errorf("failed to listen on :2222: %v", err) + } + defer listener.Close() + log.Println("SSH server listening on :2222") + + // Handle Ctrl+C / SIGTERM for graceful shutdown + sigc := make(chan os.Signal, 1) + signal.Notify(sigc, os.Interrupt, syscall.SIGTERM) + go func() { + <-sigc + log.Println("Shutting down server...") + serverRunning = false + + // Stop accepting new connections + listener.Close() + + // Cancel all active SSH sessions + sessionsMu.Lock() + for id, cancel := range sessionCancels { + log.Printf("[SHUTDOWN] cancelling session %s", id) + cancel() + } + for id, c := range sessionConns { + log.Printf("[SHUTDOWN] closing ssh connection %s", id) + _ = c.Close() + } + sessionsMu.Unlock() + + // Stop all active Cloudflare tunnels + cfTunnelsMu.Lock() + for _, t := range cfTunnels { + log.Printf("[SHUTDOWN] stopping tunnel %s (%s)...", t.Name, t.Hostname) + stopCloudflared(t.Name, t.Hostname, t.CfgPath, t.CredsPath, t.Cmd, cfEmail, cfGlobalKey, cfHostname) + } + cfTunnels = map[string]cfTunnelInfo{} + cfTunnelsMu.Unlock() + + log.Println("[SHUTDOWN] All tunnels and sessions cleaned up. Exiting...") + os.Exit(0) + }() + + for serverRunning { + conn, err := listener.Accept() + if err != nil { + if serverRunning { + log.Printf("Accept error: %v", err) + } + continue + } + + go handleSSH(conn, config, tunnelTimeout, portStart, portEnd, host, cfEnabled, baseDir, cfCertPath, cfHostname, cfEmail, cfGlobalKey, usersFile) + } + + return nil +} + +// ---------------- MAIN ---------------- + +func main() { + mrand.Seed(time.Now().UnixNano()) + + pemPath := flag.String("k", "", "Path to private key (PEM)") + tunnelTimeoutSec := flag.Int("t", 0, "Tunnel default timeout in seconds (0 = no limit)") + portStart := flag.Int("ps", 10000, "Start of public port range") + portEnd := flag.Int("pe", 60000, "End of public port range") + useUPnP := flag.Bool("upnp", true, "Enable UPnP port mapping (auto-disable if unavailable)") + + cfEnable := flag.Bool("cf", false, "Enable Cloudflare tunneling") + cfCert := flag.String("cf-cert", "", "Path to Cloudflare cert.pem (used as base dir for credentials)") + cfHost := flag.String("cf-hostname", "", "Base hostname for Cloudflare tunnels (example.com)") + + cfGlobalKey := flag.String("cf-gk", "", "Cloudflare Global API key (X-Auth-Key) for DNS deletion") + cfEmail := flag.String("cf-em", "", "Cloudflare account email (X-Auth-Email) for DNS deletion") + + usersFile := flag.String("users", "users.json", "Path to JSON file mapping usernames to keepalive + password") + + ip, err := GetPublicIP() + if err != nil { + ip = "0.0.0.0" + } + host := flag.String("host", ip, "Host to display") + + flag.Parse() + + if *portEnd <= *portStart { + log.Fatalf("invalid port range: pe must be > ps") + } + + // ensure usersFile exists and defaults filled + if _, err := os.Stat(*usersFile); os.IsNotExist(err) { + empty := map[string]UserKeepalive{} + data, _ := json.MarshalIndent(empty, "", " ") + _ = os.WriteFile(*usersFile, data, 0644) + } + // loadUserKeepalives will populate missing password fields with default hash and save + _, _ = loadUserKeepalives(*usersFile) + + // Determine base directory for all created files + var baseDir string + if *cfCert != "" { + // use directory of provided cert path + baseDir = filepath.Dir(*cfCert) + } else { + home := os.Getenv("HOME") + if home == "" { + // fallback to current working directory + cwd, _ := os.Getwd() + baseDir = filepath.Join(cwd, ".ssh-tunnels") + } else { + baseDir = filepath.Join(home, ".ssh-tunnels") + } + } + // ensure baseDir exists and is secure + if err := os.MkdirAll(baseDir, 0700); err != nil { + log.Fatalf("failed to create base dir %s: %v", baseDir, err) + } + + // set the port file path inside baseDir + portFilePath = filepath.Join(baseDir, "ports.json") + + // Clean old ports (uses portFilePath now set) + cleanupOldPorts() + + upnpMu.Lock() + upnpEnabled = *useUPnP + upnpMu.Unlock() + if *cfEnable { + // Cloudflare-only mode: disable UPnP/other external mapping + upnpMu.Lock() + upnpEnabled = false + upnpMu.Unlock() + } else { + if upnpEnabled && !tryInitUPnP() { + upnpMu.Lock() + upnpEnabled = false + upnpMu.Unlock() + log.Println("[UPnP] not available; disabled") + } + } + + if *cfEnable { + if *cfCert == "" || *cfHost == "" { + log.Fatal("When -cf is used you must provide -cf-cert and -cf-hostname") + } + // verify cloudflared exists + if _, err := exec.LookPath("cloudflared"); err != nil { + log.Fatalf("cloudflared binary not found in PATH: %v", err) + } + } + + keyBytes, err := loadOrCreateKey(*pemPath) + if err != nil { + log.Fatalf("Failed to load/generate key: %v", err) + } + tunnelTimeout := time.Duration(*tunnelTimeoutSec) * time.Second + + // start server (passes usersFile for password/keepalive handling) + if err := startServerCloudflared(keyBytes, tunnelTimeout, *portStart, *portEnd, host, *cfEnable, baseDir, *cfCert, *cfHost, *cfEmail, *cfGlobalKey, *usersFile); err != nil { + log.Fatalf("server error: %v", err) + } +}