Files
SSH-Tunnel/main.go
2026-04-17 12:29:00 +00:00

1207 lines
34 KiB
Go

// main.go (Cloudflare + UPnP + per-user keepalive + passwords + graceful shutdown)
package main
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/binary"
"encoding/json"
"encoding/pem"
"flag"
"fmt"
"io"
"log"
mrand "math/rand"
"net"
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
"runtime"
"strings"
"sync"
"syscall"
"time"
"errors"
"strconv"
"github.com/huin/goupnp/dcps/internetgateway1"
"golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/ssh"
)
var (
	// serverRunning is set to false once a shutdown signal has been received.
	serverRunning = true

	// UPnP state. upnpEnabled is read and written under upnpMu.
	upnpEnabled bool
	upnpMu      sync.Mutex

	// Location of ports.json (assigned in main) and the lock guarding file access.
	portFilePath string
	portFileMu   sync.Mutex

	// httpClient is shared by all Cloudflare API calls; Timeout bounds each request.
	httpClient = &http.Client{Timeout: 15 * time.Second}

	// Live SSH sessions, keyed by session ID and guarded by sessionsMu, so a
	// graceful shutdown can cancel their contexts and force-close connections.
	sessionsMu     sync.Mutex
	sessionCancels = map[string]context.CancelFunc{}
	sessionConns   = map[string]*ssh.ServerConn{}

	// Running Cloudflare tunnels keyed by tunnel name, guarded by cfTunnelsMu.
	cfTunnelsMu sync.Mutex
	cfTunnels   = map[string]cfTunnelInfo{}
)
// cfTunnelInfo holds what is needed to tear a Cloudflare tunnel down again:
// the tunnel name and public hostname, the generated config and credentials
// file paths, and the cloudflared process that serves it (may be nil).
type cfTunnelInfo struct {
	Name      string
	Hostname  string
	CfgPath   string
	CredsPath string
	Cmd       *exec.Cmd
}
// ---------------- USER KEEPALIVE + PASSWORD ----------------
// UserKeepalive is one entry of the users file, keyed there by SSH username.
//
// Keepalive is the tunnel lifetime in seconds (0 means no limit), Password is
// a bcrypt hash checked at login, and CustomTunnel allows the user to choose
// their own public port and Cloudflare tunnel name.
type UserKeepalive struct {
	Keepalive    int    `json:"keepalive"`
	Password     string `json:"password,omitempty"`
	CustomTunnel bool   `json:"custom_tunnel,omitempty"`
}
// loadUserKeepalives reads the users file at path and returns its entries
// keyed by SSH username.
//
// A missing file is created (holding an empty JSON object) and an empty map is
// returned. Every entry whose password is blank gets a bcrypt hash of the
// default password "password". The file is then rewritten so those hashes
// persist. Failures to write the file are logged, not returned. The returned
// map is non-nil whenever err is nil.
func loadUserKeepalives(path string) (map[string]UserKeepalive, error) {
	data, err := os.ReadFile(path)
	if os.IsNotExist(err) {
		empty := map[string]UserKeepalive{}
		if werr := saveUserKeepalives(path, empty); werr != nil {
			log.Printf("[USERS] could not create %s: %v", path, werr)
		}
		return empty, nil
	}
	if err != nil {
		return nil, err
	}

	var users map[string]UserKeepalive
	if err := json.Unmarshal(data, &users); err != nil {
		return nil, err
	}
	if users == nil {
		// A file containing JSON null decodes to a nil map.
		users = map[string]UserKeepalive{}
	}

	changed := false
	for name, u := range users {
		if strings.TrimSpace(u.Password) != "" {
			continue
		}
		// Hash per user so each entry gets its own salt.
		hash, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
		if err != nil {
			return nil, fmt.Errorf("hashing default password for %q: %w", name, err)
		}
		u.Password = string(hash)
		users[name] = u
		changed = true
	}
	if changed {
		if err := saveUserKeepalives(path, users); err != nil {
			log.Printf("[USERS] could not save default password hashes to %s: %v", path, err)
		}
	}
	return users, nil
}
// saveUserKeepalives writes users to path as indented JSON.
//
// The file contains password hashes, so it is written with owner-only (0600)
// permissions. The data goes to a temporary file in the same directory, which
// is then renamed over path. A concurrent reader therefore never sees a
// half-written file.
func saveUserKeepalives(path string, users map[string]UserKeepalive) error {
	data, err := json.MarshalIndent(users, "", " ")
	if err != nil {
		return fmt.Errorf("encoding users: %w", err)
	}
	// os.CreateTemp creates the file with mode 0600.
	tmp, err := os.CreateTemp(filepath.Dir(path), filepath.Base(path)+".tmp-*")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	// Clean up on failure. After a successful rename this Remove fails harmlessly.
	defer os.Remove(tmpName)

	if _, err := tmp.Write(data); err != nil {
		tmp.Close()
		return err
	}
	if err := tmp.Close(); err != nil {
		return err
	}
	return os.Rename(tmpName, path)
}
// ---------------- PORT JSON MANAGEMENT ----------------
// loadPorts returns the public ports recorded in the ports file.
//
// A missing file gives an empty (non-nil) list. Content that is not valid JSON
// is also treated as an empty list rather than an error. File access is
// serialized through portFileMu.
func loadPorts() ([]int, error) {
	portFileMu.Lock()
	defer portFileMu.Unlock()

	raw, readErr := os.ReadFile(portFilePath)
	switch {
	case os.IsNotExist(readErr):
		return []int{}, nil
	case readErr != nil:
		return nil, readErr
	}

	var recorded []int
	if json.Unmarshal(raw, &recorded) != nil {
		return []int{}, nil
	}
	return recorded, nil
}
// savePorts replaces the contents of the ports file with ports encoded as
// indented JSON. The write happens while holding portFileMu.
func savePorts(ports []int) error {
	encoded, _ := json.MarshalIndent(ports, "", " ") // a []int always encodes
	portFileMu.Lock()
	defer portFileMu.Unlock()
	return os.WriteFile(portFilePath, encoded, 0644)
}
// portsUpdateMu serializes whole read-modify-write cycles on the ports file.
// loadPorts and savePorts each hold portFileMu only for their own step. Without
// this lock, two concurrent updates could both load the same list, and the
// second save would discard the first one's change.
var portsUpdateMu sync.Mutex

// updatePorts loads the recorded ports, applies mutate, and saves the result,
// all under portsUpdateMu. If the file cannot be read, it is left untouched
// rather than overwritten with a list built from nothing.
func updatePorts(mutate func([]int) []int) {
	portsUpdateMu.Lock()
	defer portsUpdateMu.Unlock()
	ports, err := loadPorts()
	if err != nil {
		log.Printf("[PORTS] cannot read port file; skipping update: %v", err)
		return
	}
	if err := savePorts(mutate(ports)); err != nil {
		log.Printf("[PORTS] cannot write port file: %v", err)
	}
}

// addPort records p in the ports file unless it is already listed.
func addPort(p int) {
	updatePorts(func(ports []int) []int {
		for _, existing := range ports {
			if existing == p {
				return ports
			}
		}
		return append(ports, p)
	})
}

// removePort deletes every occurrence of p from the ports file.
func removePort(p int) {
	updatePorts(func(ports []int) []int {
		kept := []int{} // non-nil so an empty list is stored as [] rather than null
		for _, existing := range ports {
			if existing != p {
				kept = append(kept, existing)
			}
		}
		return kept
	})
}

// clearAllPorts empties the ports file.
func clearAllPorts() {
	portsUpdateMu.Lock()
	defer portsUpdateMu.Unlock()
	if err := savePorts([]int{}); err != nil {
		log.Printf("[PORTS] cannot clear port file: %v", err)
	}
}
// cleanupOldPorts handles the ports recorded by a previous run and then clears
// the record.
//
// A listener belongs to the process that opened it, so a port still held by
// another process cannot be closed from here. Each recorded port is probed on
// the loopback address, and a warning is logged if something still accepts
// connections there, because this server will not be able to bind it.
func cleanupOldPorts() {
	ports, err := loadPorts()
	if err != nil {
		log.Printf("[CLEANUP] cannot read port file: %v", err)
	}
	if len(ports) == 0 {
		return
	}
	log.Printf("[CLEANUP] Checking %d ports recorded by a previous run...", len(ports))
	for _, port := range ports {
		addr := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
		conn, err := net.DialTimeout("tcp", addr, 200*time.Millisecond)
		if err != nil {
			continue // nothing listening: the port is free again
		}
		conn.Close()
		log.Printf("[CLEANUP] Port %d is still in use by another process", port)
	}
	clearAllPorts()
}
// ---------------- NETWORK HELPERS ----------------
// GetPublicIP reports the local source address the OS picks to reach 8.8.8.8.
//
// "Dialing" UDP only selects a route and source address; no packet is sent.
// Behind NAT the result is this machine's interface address, not the router's
// public IP.
func GetPublicIP() (string, error) {
	probe, err := net.Dial("udp", "8.8.8.8:80")
	if err != nil {
		return "", err
	}
	defer probe.Close()
	return probe.LocalAddr().(*net.UDPAddr).IP.String(), nil
}
// firstNonLoopbackIPv4 returns the first non-loopback IPv4 address assigned to
// any interface, or "" when there is none.
func firstNonLoopbackIPv4() (string, error) {
	addrs, err := net.InterfaceAddrs()
	if err != nil {
		return "", err
	}
	for _, a := range addrs {
		if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() && ipnet.IP.To4() != nil {
			return ipnet.IP.String(), nil
		}
	}
	return "", nil
}

// getLocalIP returns this host's LAN address. It is used as the internal
// client for UPnP port mappings. It falls back to "127.0.0.1" when nothing
// better is found.
//
// On Windows the interface list is scanned directly. Elsewhere the first word
// printed by `hostname -I` is tried first. The pipeline's exit status is awk's,
// so an unsupported `hostname -I` (e.g. BSD/macOS hostname) produces empty
// output rather than an error. Empty or unparsable output therefore also falls
// back to the interface scan.
func getLocalIP() (string, error) {
	if runtime.GOOS == "windows" {
		ip, err := firstNonLoopbackIPv4()
		if err != nil {
			return "", err
		}
		if ip == "" {
			return "127.0.0.1", nil
		}
		return ip, nil
	}
	out, err := exec.Command("sh", "-c", `hostname -I | awk '{print $1}'`).Output()
	if err == nil {
		if ip := strings.TrimSpace(string(out)); net.ParseIP(ip) != nil {
			return ip, nil
		}
	}
	if ip, err := firstNonLoopbackIPv4(); err == nil && ip != "" {
		return ip, nil
	}
	return "127.0.0.1", nil
}
// ---------------- UPNP HELPERS ----------------
// tryInitUPnP reports whether at least one UPnP Internet Gateway Device
// offering a WANIPConnection or WANPPPConnection service can be discovered.
//
// Discovery can succeed with an empty client list, so a gateway only counts as
// present when the returned slice is non-empty. Previously a PPP discovery
// that found nothing was reported as success.
func tryInitUPnP() bool {
	if clients, _, err := internetgateway1.NewWANIPConnection1Clients(); err == nil && len(clients) > 0 {
		return true
	}
	clients, _, err := internetgateway1.NewWANPPPConnection1Clients()
	return err == nil && len(clients) > 0
}
// addUPnPPortForward asks the LAN's UPnP gateway to forward external TCP port
// `port` to localIP:port for the given lease (0 requests no expiry).
//
// It returns a cleanup function that deletes the mapping. When UPnP is
// disabled, or on error, the cleanup is a no-op. If no gateway is discovered,
// or every gateway rejects the mapping, UPnP is disabled for the rest of the
// run so later forwards skip discovery.
func addUPnPPortForward(localIP string, port int, lease time.Duration) (func(), error) {
	noop := func() {}

	upnpMu.Lock()
	enabled := upnpEnabled
	upnpMu.Unlock()
	if !enabled {
		return noop, nil
	}
	if port < 1 || port > 65535 {
		// The gateway API takes uint16 ports; anything else would be silently truncated.
		return noop, fmt.Errorf("UPnP: port %d out of range", port)
	}
	disableUPnP := func() {
		upnpMu.Lock()
		upnpEnabled = false
		upnpMu.Unlock()
	}

	// WANIPConnection1 and WANPPPConnection1 expose these methods with identical
	// signatures, so one loop can drive both kinds of gateway.
	type portMapper interface {
		GetExternalIPAddress() (string, error)
		AddPortMapping(remoteHost string, externalPort uint16, protocol string, internalPort uint16, internalClient string, enabled bool, description string, leaseDuration uint32) error
		DeletePortMapping(remoteHost string, externalPort uint16, protocol string) error
	}

	var mappers []portMapper
	ipClients, _, _ := internetgateway1.NewWANIPConnection1Clients()
	for _, c := range ipClients {
		mappers = append(mappers, c)
	}
	pppClients, _, _ := internetgateway1.NewWANPPPConnection1Clients()
	for _, c := range pppClients {
		mappers = append(mappers, c)
	}
	if len(mappers) == 0 {
		disableUPnP()
		return noop, fmt.Errorf("no IGD clients discovered")
	}

	var leaseSeconds uint32
	if lease > 0 {
		leaseSeconds = uint32(lease.Seconds())
	}
	extPort := uint16(port)

	var lastErr error
	for _, m := range mappers {
		externalIP, _ := m.GetExternalIPAddress() // informational only
		if err := m.AddPortMapping("", extPort, "TCP", extPort, localIP, true, "ssh-tunnel", leaseSeconds); err != nil {
			lastErr = err
			continue
		}
		log.Printf("[UPnP] Mapped external port %d -> %s:%d (external IP %s)", port, localIP, port, externalIP)
		mapper := m
		return func() { _ = mapper.DeletePortMapping("", extPort, "TCP") }, nil
	}
	disableUPnP()
	return noop, fmt.Errorf("UPnP mapping failed on all discovered IGD clients: %w", lastErr)
}
// ---------------- LOG HELPERS ----------------
// writeChannelLine formats msg with format and sends it to channel in a single
// Write, returning any write error.
func writeChannelLine(channel ssh.Channel, format, msg string) error {
	_, err := fmt.Fprintf(channel, format, msg)
	return err
}

// sendLogToClient sends an informational "[LOG]" line to the client.
func sendLogToClient(channel ssh.Channel, msg string) error {
	return writeChannelLine(channel, "[LOG] %s\n", msg)
}

// sendWarnToClient sends a "[WARN]" line to the client.
func sendWarnToClient(channel ssh.Channel, msg string) error {
	return writeChannelLine(channel, "[WARN] %s\n", msg)
}

// sendErrorToClient sends an "[ERR]" line to the client.
func sendErrorToClient(channel ssh.Channel, msg string) error {
	return writeChannelLine(channel, "[ERR] %s\n", msg)
}
// ---------------- SSH ENV ARGS PARSING ----------------
// parseSSHString decodes one SSH wire-format "string" (RFC 4251 §5): a
// big-endian uint32 length followed by that many bytes. It returns the decoded
// string and the bytes that follow it.
//
// The length is compared against the buffer as an unsigned 64-bit value. A
// client-chosen length near 2^32 therefore cannot wrap to a negative int on
// 32-bit platforms and cause a slice-bounds panic.
func parseSSHString(b []byte) (string, []byte, error) {
	if len(b) < 4 {
		return "", nil, fmt.Errorf("payload too short")
	}
	n := uint64(binary.BigEndian.Uint32(b[:4]))
	body := b[4:]
	if uint64(len(body)) < n {
		return "", nil, fmt.Errorf("payload shorter than length")
	}
	return string(body[:n]), body[n:], nil
}
// parseArgsEnv interprets the payload of an SSH "env" request.
//
// Only a variable named ARGS (case-insensitive) is recognized. Its value is a
// semicolon-separated list of key=value pairs, optionally wrapped in quotes,
// for example:
//
//	ARGS="port=12345; protocol=http; tunnelname=myapp"
//
// Keys are lower-cased, and keys and values are trimmed. Items without '=' or
// with an empty key or value are skipped. It returns nil for any other
// variable or for a malformed payload.
func parseArgsEnv(payload []byte) map[string]string {
	varName, afterName, err := parseSSHString(payload)
	if err != nil {
		return nil
	}
	if !strings.EqualFold(varName, "ARGS") {
		return nil
	}
	raw, _, err := parseSSHString(afterName)
	if err != nil {
		return nil
	}

	result := make(map[string]string)
	for _, item := range strings.Split(strings.Trim(raw, `"'`), ";") {
		item = strings.TrimSpace(item)
		eq := strings.Index(item, "=")
		if eq < 0 {
			continue // also covers empty items
		}
		k := strings.ToLower(strings.TrimSpace(item[:eq]))
		v := strings.TrimSpace(item[eq+1:])
		if k == "" || v == "" {
			continue
		}
		result[k] = v
	}
	return result
}
// ---------------- KEY HELPERS ----------------
// generateKey creates a new 2048-bit RSA private key and returns it
// PEM-encoded as a PKCS#1 "RSA PRIVATE KEY" block.
func generateKey() ([]byte, error) {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, err
	}
	block := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}
	return pem.EncodeToMemory(block), nil
}
// loadOrCreateKey returns PEM bytes for the SSH host key. When pemPath is
// given it returns that file's contents; otherwise it generates a new key.
// A generated key is not written anywhere, so clients see a different host
// key after every restart.
func loadOrCreateKey(pemPath string) ([]byte, error) {
	if pemPath == "" {
		return generateKey()
	}
	return os.ReadFile(pemPath)
}
// ---------------- Cloudflared helpers ----------------
// randomAlphaNum12 returns a 12-character string drawn uniformly from a-z0-9
// using crypto/rand.
//
// Every character consumes its own random byte. Bytes >= 252 (the largest
// multiple of 36 not exceeding 256) are rejected, so all 36 symbols are
// equally likely and no character repeats another by construction.
func randomAlphaNum12() (string, error) {
	const chars = "abcdefghijklmnopqrstuvwxyz0123456789"
	const limit = 256 - 256%len(chars) // 252
	out := make([]byte, 0, 12)
	buf := make([]byte, 16)
	for len(out) < 12 {
		if _, err := rand.Read(buf); err != nil {
			return "", err
		}
		for _, b := range buf {
			if int(b) >= limit {
				continue // rejection sampling avoids modulo bias
			}
			out = append(out, chars[int(b)%len(chars)])
			if len(out) == 12 {
				break
			}
		}
	}
	return string(out), nil
}
// fileExists reports whether filePath exists. Only a definite "does not exist"
// result counts as missing. Other Stat failures (such as permission denied)
// are reported as existing.
func fileExists(filePath string) bool {
	_, statErr := os.Stat(filePath)
	if errors.Is(statErr, os.ErrNotExist) {
		return false
	}
	return true
}
// startCloudflaredForPort creates a named Cloudflare tunnel that exposes
// localhost:localPort and starts `cloudflared tunnel run` for it.
//
// Naming and files:
//   - The tunnel is named "ssh-<id>" and served at "<id>.<cfHost>".
//   - An empty id is replaced with a random 12-character one; ids are lower-cased.
//   - The credentials (<tunnelName>.json) and config (<tunnelName>.yml) files
//     are written under baseDir, which is created if needed.
//   - A non-empty cfCertPath is passed to cloudflared as --origincert.
//   - protocol is the ingress service scheme (e.g. "tcp", "http").
//
// id and protocol can come from the client (the ARGS env variable), so both
// are validated:
//   - id must be a DNS label, so it can neither escape baseDir nor break the
//     hostname.
//   - protocol must be a plain URL scheme, so it cannot inject YAML.
//
// Lifecycle:
//   - The run process and DNS-route command are bound to ctx.
//   - The DNS route is best effort.
//   - If any step after tunnel creation fails, the tunnel object and any files
//     written so far are removed before the error is returned.
func startCloudflaredForPort(ctx context.Context, baseDir, cfCertPath, cfHost string, localPort int, protocol string, id string) (tunnelName, hostname string, cmdRun *exec.Cmd, cfgPath string, credsPath string, err error) {
	fail := func(e error) (string, string, *exec.Cmd, string, string, error) {
		return "", "", nil, "", "", e
	}
	if id == "" {
		newID, genErr := randomAlphaNum12()
		if genErr != nil {
			return fail(genErr)
		}
		id = newID
	}
	id = strings.ToLower(id)
	if !isDNSLabel(id) {
		return fail(fmt.Errorf("invalid tunnel name %q: use 1-63 letters, digits or inner hyphens", id))
	}
	if !isURLScheme(protocol) {
		return fail(fmt.Errorf("invalid protocol %q", protocol))
	}
	tunnelName = "ssh-" + id
	hostname = fmt.Sprintf("%s.%s", id, cfHost)

	if err := os.MkdirAll(baseDir, 0700); err != nil {
		return fail(fmt.Errorf("failed to create base dir: %v", err))
	}

	// tunnelArgs builds `tunnel [--origincert CERT] <rest...>`.
	tunnelArgs := func(rest ...string) []string {
		args := []string{"tunnel"}
		if cfCertPath != "" {
			args = append(args, "--origincert", cfCertPath)
		}
		return append(args, rest...)
	}

	// 1) Create the tunnel, writing its credentials into baseDir.
	credsPath = filepath.Join(baseDir, tunnelName+".json")
	out, err := exec.CommandContext(ctx, "cloudflared", tunnelArgs("--credentials-file", credsPath, "create", tunnelName)...).CombinedOutput()
	if err != nil {
		return fail(fmt.Errorf("cloudflared tunnel create failed: %v: %s", err, string(out)))
	}

	// undo removes what was created so far; used when a later step fails.
	undo := func() {
		if cfgPath != "" {
			_ = os.Remove(cfgPath)
		}
		_ = os.Remove(credsPath)
		// Use a fresh context: ctx may already be cancelled.
		delCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		if delOut, delErr := exec.CommandContext(delCtx, "cloudflared", tunnelArgs("delete", "-f", tunnelName)...).CombinedOutput(); delErr != nil {
			log.Printf("[CF] failed to delete tunnel %s after setup error: %v: %s", tunnelName, delErr, string(delOut))
		}
	}

	// 2) Write the config. It is built with explicit indentation so the
	// ingress rules form a proper YAML list of mappings. The credentials path
	// is double-quoted because file paths may contain YAML-significant
	// characters; Go %q escapes are valid YAML escapes.
	cfgPath = filepath.Join(baseDir, tunnelName+".yml")
	service := fmt.Sprintf("%s://localhost:%d", protocol, localPort)
	config := fmt.Sprintf("tunnel: %s\ncredentials-file: %q\ningress:\n  - hostname: %s\n    service: %s\n  - service: http_status:404\n",
		tunnelName, credsPath, hostname, service)
	if err := os.WriteFile(cfgPath, []byte(config), 0600); err != nil {
		undo()
		return fail(fmt.Errorf("failed writing cfg: %v", err))
	}

	// 3) Run the tunnel; its output is discarded (nil Stdout/Stderr).
	cmdRun = exec.CommandContext(ctx, "cloudflared", "--config", cfgPath, "tunnel", "run", tunnelName)
	if err := cmdRun.Start(); err != nil {
		undo()
		return fail(fmt.Errorf("cloudflared run failed: %v", err))
	}

	// 4) Create the DNS route (best effort).
	if dnsOut, dnsErr := exec.CommandContext(ctx, "cloudflared", tunnelArgs("route", "dns", tunnelName, hostname)...).CombinedOutput(); dnsErr != nil {
		log.Printf("[CF] route dns returned error (continuing): %v — output: %s", dnsErr, string(dnsOut))
	}
	return tunnelName, hostname, cmdRun, cfgPath, credsPath, nil
}

// isDNSLabel reports whether s is a lowercase DNS label: 1-63 characters of
// a-z, 0-9 and '-', neither starting nor ending with '-'.
func isDNSLabel(s string) bool {
	if len(s) == 0 || len(s) > 63 || s[0] == '-' || s[len(s)-1] == '-' {
		return false
	}
	for i := 0; i < len(s); i++ {
		c := s[i]
		if !(c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '-') {
			return false
		}
	}
	return true
}

// isURLScheme reports whether s is a URL scheme as defined by RFC 3986: a
// letter followed by letters, digits, '+', '-' or '.'.
func isURLScheme(s string) bool {
	if s == "" {
		return false
	}
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z':
		case i > 0 && (c >= '0' && c <= '9' || c == '+' || c == '-' || c == '.'):
		default:
			return false
		}
	}
	return true
}
// ---------------- Cloudflare API helpers ----------------
// cfAPIBase is the root of the Cloudflare v4 API used by the helpers below.
// It is a variable so tests can point it at a local server.
var cfAPIBase = "https://api.cloudflare.com/client/v4"

// cfAPIResponse is the common envelope of Cloudflare v4 API responses. Result
// is kept raw so each caller can decode the shape it expects.
type cfAPIResponse struct {
	Success bool            `json:"success"`
	Errors  json.RawMessage `json:"errors"`
	Result  json.RawMessage `json:"result"`
}

// cfAPICall performs one Cloudflare API request authenticated with a Global
// API Key (X-Auth-Email + X-Auth-Key) and decodes the response envelope.
// Query values are URL-encoded, and op prefixes every error message. A
// non-2xx status, an undecodable body, or "success": false is returned as an
// error.
func cfAPICall(ctx context.Context, op, method, path string, query map[string]string, authEmail, authKey string) (*cfAPIResponse, error) {
	req, err := http.NewRequestWithContext(ctx, method, cfAPIBase+path, nil)
	if err != nil {
		return nil, err
	}
	if len(query) > 0 {
		q := req.URL.Query()
		for k, v := range query {
			q.Set(k, v)
		}
		req.URL.RawQuery = q.Encode()
	}
	req.Header.Set("X-Auth-Email", strings.TrimSpace(authEmail))
	req.Header.Set("X-Auth-Key", strings.TrimSpace(authKey))
	req.Header.Set("Content-Type", "application/json")

	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("%s: reading response: %w", op, err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("%s: non-2xx status %d: %s", op, resp.StatusCode, string(body))
	}
	var parsed cfAPIResponse
	if err := json.Unmarshal(body, &parsed); err != nil {
		return nil, fmt.Errorf("%s: decode error: %w (body: %s)", op, err, string(body))
	}
	if !parsed.Success {
		return nil, fmt.Errorf("%s: api error: %s", op, string(parsed.Errors))
	}
	return &parsed, nil
}

// getZoneID returns the ID of the Cloudflare zone named zoneName.
func getZoneID(ctx context.Context, authEmail, authKey, zoneName string) (string, error) {
	resp, err := cfAPICall(ctx, "getZoneID", http.MethodGet, "/zones", map[string]string{"name": zoneName}, authEmail, authKey)
	if err != nil {
		return "", err
	}
	var zones []struct {
		ID string `json:"id"`
	}
	if len(resp.Result) > 0 {
		if err := json.Unmarshal(resp.Result, &zones); err != nil {
			return "", fmt.Errorf("getZoneID: decode error: %w", err)
		}
	}
	if len(zones) == 0 {
		return "", fmt.Errorf("getZoneID: no zone found or api failure: %s", string(resp.Errors))
	}
	return zones[0].ID, nil
}

// getDNSRecordID returns the ID of the CNAME record named recordName in the
// zone zoneID.
func getDNSRecordID(ctx context.Context, authEmail, authKey, zoneID, recordName string) (string, error) {
	query := map[string]string{"type": "CNAME", "name": recordName}
	resp, err := cfAPICall(ctx, "getDNSRecordID", http.MethodGet, "/zones/"+zoneID+"/dns_records", query, authEmail, authKey)
	if err != nil {
		return "", err
	}
	var records []struct {
		ID   string `json:"id"`
		Name string `json:"name"`
	}
	if len(resp.Result) > 0 {
		if err := json.Unmarshal(resp.Result, &records); err != nil {
			return "", fmt.Errorf("getDNSRecordID: decode error: %w", err)
		}
	}
	if len(records) == 0 {
		return "", fmt.Errorf("getDNSRecordID: record %s not found in zone %s", recordName, zoneID)
	}
	return records[0].ID, nil
}

// deleteDNSRecord deletes the DNS record recordID from the zone zoneID.
func deleteDNSRecord(ctx context.Context, authEmail, authKey, zoneID, recordID string) error {
	_, err := cfAPICall(ctx, "deleteDNSRecord", http.MethodDelete, "/zones/"+zoneID+"/dns_records/"+recordID, nil, authEmail, authKey)
	return err
}
// ---------------- stopCloudflared (clean-up) ----------------
// stopCloudflared tears down a tunnel created by startCloudflaredForPort.
// Every step is best effort; failures are logged and later steps still run:
//
//  1. kill the cloudflared process (if any) and wait up to 5s for it to exit;
//  2. delete the generated config and credentials files;
//  3. if cfEmail, cfGlobalKey and cfHost are set, delete hostname's CNAME record
//     through the Cloudflare API;
//  4. delete the tunnel object with `cloudflared tunnel delete -f` (limited to 30s).
func stopCloudflared(tunnelName, hostname, cfgPath, credsPath string, cmdRun *exec.Cmd, cfEmail, cfGlobalKey, cfHost string) {
	log.Printf("[CLEANUP] Stopping Cloudflare tunnel %s (%s)...", tunnelName, hostname)

	// 1) Kill the process if still running.
	if cmdRun != nil && cmdRun.Process != nil {
		_ = cmdRun.Process.Kill()
		exited := make(chan error, 1)
		go func() { exited <- cmdRun.Wait() }()
		select {
		case <-exited:
			log.Printf("[CLEANUP] Process for %s exited", tunnelName)
		case <-time.After(5 * time.Second):
			log.Printf("[CLEANUP] Timeout waiting for process %s to exit", tunnelName)
		}
	} else {
		log.Printf("[CLEANUP] No running process for %s", tunnelName)
	}

	// 2) Remove the generated files.
	for _, f := range []struct{ path, label string }{{cfgPath, "config"}, {credsPath, "credentials"}} {
		if f.path == "" {
			continue
		}
		err := os.Remove(f.path)
		switch {
		case err == nil:
			log.Printf("[CLEANUP] Deleted %s file: %s", f.label, f.path)
		case !os.IsNotExist(err):
			log.Printf("[CLEANUP] Failed to remove %s (%s): %v", f.label, f.path, err)
		}
	}

	// 3) Delete the DNS record via the Cloudflare API.
	if cfEmail != "" && cfGlobalKey != "" && cfHost != "" {
		deleteTunnelDNSRecord(hostname, cfEmail, cfGlobalKey, cfHost)
	} else {
		log.Printf("[CF-API] CF credentials or zone not provided; skipping DNS cleanup for %s", hostname)
	}

	// 4) Delete the tunnel object, bounded so a stuck CLI cannot block shutdown.
	if tunnelName != "" {
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		out, err := exec.CommandContext(ctx, "cloudflared", "tunnel", "delete", "-f", tunnelName).CombinedOutput()
		cancel()
		if err != nil {
			log.Printf("[CLEANUP] cloudflared tunnel delete failed for %s: %v (output: %s)", tunnelName, err, strings.TrimSpace(string(out)))
		}
	}
	log.Printf("[CLEANUP] Tunnel %s fully cleaned up.", tunnelName)
}

// deleteTunnelDNSRecord removes hostname's CNAME record from the Cloudflare
// zone that contains cfHost, logging the outcome.
//
// The zone is found by trying cfHost itself and then each parent domain with
// at least two labels (t.example.co.uk → example.co.uk → co.uk) until the API
// returns a zone. Bases under multi-label suffixes such as co.uk therefore
// resolve correctly.
func deleteTunnelDNSRecord(hostname, cfEmail, cfGlobalKey, cfHost string) {
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()

	labels := strings.Split(strings.Trim(cfHost, "."), ".")
	candidates := []string{cfHost}
	for i := 1; i+2 <= len(labels); i++ {
		candidates = append(candidates, strings.Join(labels[i:], "."))
	}

	zoneID := ""
	var zoneErr error
	for _, zone := range candidates {
		id, err := getZoneID(ctx, cfEmail, cfGlobalKey, zone)
		if err == nil {
			zoneID = id
			break
		}
		zoneErr = err
		if ctx.Err() != nil {
			break
		}
	}
	if zoneID == "" {
		log.Printf("[CF-API] getZoneID failed: %v", zoneErr)
		return
	}

	recID, err := getDNSRecordID(ctx, cfEmail, cfGlobalKey, zoneID, hostname)
	if err != nil {
		log.Printf("[CF-API] getDNSRecordID failed (likely already removed): %v", err)
		return
	}
	if err := deleteDNSRecord(ctx, cfEmail, cfGlobalKey, zoneID, recID); err != nil {
		log.Printf("[CF-API] deleteDNSRecord failed: %v", err)
		return
	}
	log.Printf("[CF-API] Deleted DNS record %s (id=%s)", hostname, recID)
}
// ---------------- SSH ENV ARGS PARSING & HANDLERS ----------------
// handleForwardRequest serves one "tcpip-forward" global request on conn.
//
// It binds a public TCP listener, replies to the client with the bound port,
// and relays each accepted connection to the client over a "forwarded-tcpip"
// channel. The port is also offered to UPnP (a no-op when UPnP is disabled).
// When cfEnabled is set, the port is exposed through a Cloudflare tunnel whose
// files live under baseDir.
//
// Port choice: a customTunnelN user may request a port with the "port" ARGS
// key. Otherwise, or if that port cannot be bound, a random port in
// [portStart, portEnd] is used. extraArgs may also supply "protocol" (the
// Cloudflare service scheme, default "tcp") and, for customTunnelN users,
// "tunnelname".
//
// The forward is torn down when any of these happens:
//   - ctx is cancelled;
//   - tunnelTimeout (if > 0) elapses;
//   - the listener stops accepting.
//
// Teardown closes the listener, removes the UPnP mapping and Cloudflare
// tunnel, and deletes the port record.
//
// Session EOF is not watched here. Reading sessionChannel would steal bytes
// from the session's command reader, which already cancels ctx when the
// channel ends.
//
// NOTE(review): extraArgs is only read here; callers must not write it
// concurrently.
func handleForwardRequest(ctx context.Context, req *ssh.Request, conn *ssh.ServerConn, tunnelTimeout time.Duration, portStart, portEnd int, sessionChannel ssh.Channel, host *string, cfEnabled bool, baseDir, cfCertPath, cfHostname, cfEmail, cfGlobalKey string, extraArgs map[string]string, customTunnelN bool) {
	var payload struct {
		Addr string
		Port uint32
	}
	if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
		req.Reply(false, nil)
		return
	}

	// A custom port that is unparsable or out of range falls back to a random one.
	preferred := 0
	if customTunnelN {
		if p, err := strconv.Atoi(waitForArg(extraArgs, "port")); err == nil && p >= 1 && p <= 65535 {
			preferred = p
		}
	}
	pubListener, publicPort, err := listenPublic(preferred, portStart, portEnd)
	if err != nil {
		log.Printf("[FWD] %v", err)
		req.Reply(false, nil)
		return
	}
	addPort(publicPort)
	if err := req.Reply(true, ssh.Marshal(struct{ Port uint32 }{Port: uint32(publicPort)})); err != nil {
		pubListener.Close()
		removePort(publicPort)
		return
	}

	// done is closed exactly once, when the forward starts tearing down.
	done := make(chan struct{})
	var doneOnce sync.Once
	closeDone := func() { doneOnce.Do(func() { close(done) }) }

	// UPnP mapping. runRemoveMapping deletes it at most once, whichever of the
	// goroutine below or the accept loop gets there first.
	var mappingMu sync.Mutex
	removeMapping := func() {}
	runRemoveMapping := func() {
		mappingMu.Lock()
		f := removeMapping
		removeMapping = func() {}
		mappingMu.Unlock()
		f()
	}
	go func() {
		localIP, _ := getLocalIP()
		if f, err := addUPnPPortForward(localIP, publicPort, tunnelTimeout); err == nil {
			mappingMu.Lock()
			removeMapping = f
			mappingMu.Unlock()
		}
		select {
		case <-done:
		case <-ctx.Done():
			// Session ended: closing the listener makes the accept loop tear down the rest.
			_ = pubListener.Close()
		}
		runRemoveMapping()
	}()

	// Cloudflare tunnel for this forward.
	var cfInfo cfTunnelInfo
	if cfEnabled {
		protocol := "tcp"
		name := ""
		if extraArgs != nil {
			if p := waitForArg(extraArgs, "protocol"); p != "" {
				protocol = p
			}
			if customTunnelN {
				name = waitForArg(extraArgs, "tunnelname")
			}
		}
		tn, hn, cmdRun, cfg, creds, err := startCloudflaredForPort(ctx, baseDir, cfCertPath, cfHostname, publicPort, protocol, name)
		if err != nil {
			msg := fmt.Sprintf("Tunnel error on port %d: %v, exiting...", publicPort, err)
			log.Println("[CF]", msg)
			_ = sendErrorToClient(sessionChannel, msg)
			// The request was already answered above; SSH allows only one reply per request.
			_ = pubListener.Close()
			removePort(publicPort)
			closeDone()
			sessionChannel.Close()
			return
		}
		cfInfo = cfTunnelInfo{Name: tn, Hostname: hn, CfgPath: cfg, CredsPath: creds, Cmd: cmdRun}
		// Register so a server shutdown can stop this tunnel too.
		cfTunnelsMu.Lock()
		cfTunnels[tn] = cfInfo
		cfTunnelsMu.Unlock()
		_ = sendLogToClient(sessionChannel, fmt.Sprintf("Tunnel ready: %s", hn))
	}

	log.Printf("Created public listener %s:%d", *host, publicPort)

	if tunnelTimeout > 0 {
		go func() {
			timer := time.NewTimer(tunnelTimeout)
			defer timer.Stop()
			select {
			case <-timer.C:
				_ = pubListener.Close()
			case <-done:
			case <-ctx.Done():
			}
		}()
	}

	// Accept loop. When the listener closes, Accept fails and the deferred teardown runs.
	go func() {
		defer func() {
			closeDone()
			runRemoveMapping()
			if cfInfo.Name != "" {
				// Claim the tunnel. If a server shutdown already took it out of the
				// map, the shutdown is stopping it and it must not be stopped twice.
				cfTunnelsMu.Lock()
				_, owned := cfTunnels[cfInfo.Name]
				delete(cfTunnels, cfInfo.Name)
				cfTunnelsMu.Unlock()
				if owned {
					stopCloudflared(cfInfo.Name, cfInfo.Hostname, cfInfo.CfgPath, cfInfo.CredsPath, cfInfo.Cmd, cfEmail, cfGlobalKey, cfHostname)
				}
			}
			removePort(publicPort)
			_ = sendWarnToClient(sessionChannel, fmt.Sprintf("Tunnel closed (port %d)", publicPort))
		}()
		for {
			publicConn, err := pubListener.Accept()
			if err != nil {
				return
			}
			go relayToClient(conn, payload.Addr, payload.Port, publicConn)
		}
	}()
}

// waitForArg returns args[key]. While the value is empty it re-checks every
// 50ms, for about 500ms in total, before giving up and returning "".
// NOTE(review): the wait presumably lets late "env" requests populate args.
func waitForArg(args map[string]string, key string) string {
	for attempt := 0; ; attempt++ {
		if v := args[key]; v != "" {
			return v
		}
		if attempt == 10 {
			return ""
		}
		time.Sleep(50 * time.Millisecond)
	}
}

// listenPublic binds a TCP listener on all interfaces. It tries preferred first
// (when non-zero), then up to 100 random ports in [portStart, portEnd]. It
// returns the listener together with the port it is actually bound to.
func listenPublic(preferred, portStart, portEnd int) (net.Listener, int, error) {
	const maxAttempts = 100
	if preferred != 0 {
		if l, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", preferred)); err == nil {
			return l, l.Addr().(*net.TCPAddr).Port, nil
		}
	}
	if portEnd < portStart {
		return nil, 0, fmt.Errorf("invalid port range %d-%d", portStart, portEnd)
	}
	var lastErr error
	for i := 0; i < maxAttempts; i++ {
		p := mrand.Intn(portEnd-portStart+1) + portStart
		l, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", p))
		if err == nil {
			return l, l.Addr().(*net.TCPAddr).Port, nil
		}
		lastErr = err
	}
	return nil, 0, fmt.Errorf("no free public port in %d-%d after %d attempts: %w", portStart, portEnd, maxAttempts, lastErr)
}

// relayToClient opens a "forwarded-tcpip" channel on conn for the accepted
// connection pc and copies data in both directions. The channel carries the
// forwarded address/port, then the originator's address/port (RFC 4254 §7.2).
//
// When one direction reaches EOF, the half-close is propagated to the other
// side. A copy error closes both sides at once. Everything is closed once both
// directions have finished, so no connection or goroutine is leaked.
func relayToClient(conn *ssh.ServerConn, addr string, port uint32, pc net.Conn) {
	originAddr, originPort := "", uint32(0)
	if tcpAddr, ok := pc.RemoteAddr().(*net.TCPAddr); ok {
		originAddr, originPort = tcpAddr.IP.String(), uint32(tcpAddr.Port)
	}
	ch, chReqs, err := conn.OpenChannel("forwarded-tcpip", ssh.Marshal(struct {
		Addr       string
		Port       uint32
		OriginAddr string
		OriginPort uint32
	}{addr, port, originAddr, originPort}))
	if err != nil {
		pc.Close()
		return
	}
	go ssh.DiscardRequests(chReqs)

	var closeOnce sync.Once
	closeBoth := func() {
		closeOnce.Do(func() {
			ch.Close()
			pc.Close()
		})
	}
	var wg sync.WaitGroup
	pipe := func(dst io.Writer, src io.Reader, closeWrite func() error) {
		defer wg.Done()
		if _, err := io.Copy(dst, src); err != nil {
			closeBoth()
			return
		}
		_ = closeWrite()
	}
	wg.Add(2)
	go pipe(ch, pc, ch.CloseWrite)
	go pipe(pc, ch, func() error {
		if hc, ok := pc.(interface{ CloseWrite() error }); ok {
			return hc.CloseWrite()
		}
		return nil
	})
	wg.Wait()
	closeBoth()
}
// handleSSH runs one client connection from handshake to disconnect.
//
// Flow:
//  1. After the handshake, the connection is registered for graceful shutdown
//     and the user's settings are read from usersFile.
//  2. The first "session" channel is accepted; other channels are rejected.
//  3. Users listed in usersFile must type their password on that channel.
//  4. The session then accepts:
//     - "env" requests carrying ARGS (see parseArgsEnv);
//     - the commands ".setpass <new>" and ".exit" typed on the channel;
//     - "tcpip-forward" global requests, each served by handleForwardRequest.
//
// The session context is cancelled on ".exit", on channel EOF, or when the
// connection ends, and that tears down all of its forwards.
//
// NOTE(review): users not listed in usersFile are never prompted for a password
// (the server uses NoClientAuth). Confirm this open access is intended.
func handleSSH(nConn net.Conn, config *ssh.ServerConfig, tunnelTimeout time.Duration, portStart, portEnd int, host *string, cfEnabled bool, baseDir, cfCertPath, cfHostname, cfEmail, cfGlobalKey string, usersFile string) {
	sshConn, chans, reqs, err := ssh.NewServerConn(nConn, config)
	if err != nil {
		log.Printf("SSH handshake failed: %v", err)
		return
	}

	// Register the session so a server shutdown can cancel and close it.
	sessionID := fmt.Sprintf("%s-%d", sshConn.RemoteAddr().String(), time.Now().UnixNano())
	ctx, cancel := context.WithCancel(context.Background())
	sessionsMu.Lock()
	sessionConns[sessionID] = sshConn
	sessionCancels[sessionID] = cancel
	sessionsMu.Unlock()
	defer func() {
		cancel()
		sessionsMu.Lock()
		delete(sessionConns, sessionID)
		delete(sessionCancels, sessionID)
		sessionsMu.Unlock()
		sshConn.Close()
	}()

	log.Printf("New connection from %s (user: %s)", sshConn.RemoteAddr(), sshConn.User())
	username := sshConn.User()
	users, err := loadUserKeepalives(usersFile)
	if err != nil {
		// Fail closed: without the users file we cannot tell who needs a password.
		log.Printf("[AUTH] cannot load users file %s (%v); rejecting %s", usersFile, err, sshConn.RemoteAddr())
		return
	}
	account, listed := users[username]
	listed = listed && username != ""
	customTunnelN := false
	if listed {
		tunnelTimeout = time.Duration(account.Keepalive) * time.Second
		customTunnelN = account.CustomTunnel
	}

	// Accept the first "session" channel, rejecting anything opened before it.
	var sessionChan ssh.Channel
	var sessionReqs <-chan *ssh.Request
	for newCh := range chans {
		if newCh.ChannelType() != "session" {
			newCh.Reject(ssh.UnknownChannelType, "only session supported")
			continue
		}
		ch, chReqs, err := newCh.Accept()
		if err != nil {
			return
		}
		sessionChan, sessionReqs = ch, chReqs
		break
	}
	if sessionChan == nil {
		return
	}
	// x/crypto/ssh requires the NewChannel stream to be serviced for the life
	// of the connection, so keep rejecting any further channel opens.
	go func() {
		for newCh := range chans {
			if newCh.ChannelType() == "session" {
				newCh.Reject(ssh.Prohibited, "only one session per connection")
			} else {
				newCh.Reject(ssh.UnknownChannelType, "only session supported")
			}
		}
	}()

	if listed {
		if !checkSessionPassword(sessionChan, account.Password) {
			return
		}
		_ = sendLogToClient(sessionChan, "Authenticated.")
	}

	// ARGS from "env" requests. The map is written by the request goroutine and
	// read when forwards start, so it is guarded by argsMu. argsReady closes
	// once a shell/exec request arrives (env is sent before it, RFC 4254 §6.4)
	// or the request stream ends.
	var argsMu sync.Mutex
	extraArgs := make(map[string]string)
	argsReady := make(chan struct{})
	var argsReadyOnce sync.Once
	markArgsReady := func() { argsReadyOnce.Do(func() { close(argsReady) }) }

	go func() {
		defer markArgsReady()
		for req := range sessionReqs {
			switch req.Type {
			case "shell", "exec":
				req.Reply(true, nil)
				markArgsReady()
			case "env":
				if args := parseArgsEnv(req.Payload); args != nil {
					argsMu.Lock()
					for k, v := range args {
						extraArgs[k] = v
					}
					argsMu.Unlock()
				}
				req.Reply(true, nil)
			default:
				req.Reply(false, nil)
			}
		}
	}()

	// snapshotArgs waits briefly for the ARGS to be complete, then returns a
	// private copy that handleForwardRequest can read without locking.
	snapshotArgs := func() map[string]string {
		select {
		case <-argsReady:
		case <-time.After(500 * time.Millisecond):
		case <-ctx.Done():
		}
		argsMu.Lock()
		defer argsMu.Unlock()
		snapshot := make(map[string]string, len(extraArgs))
		for k, v := range extraArgs {
			snapshot[k] = v
		}
		return snapshot
	}

	go runSessionCommands(sessionChan, username, usersFile, cancel)

	if tunnelTimeout > 0 {
		h := int(tunnelTimeout.Hours())
		m := int(tunnelTimeout.Minutes()) % 60
		s := int(tunnelTimeout.Seconds()) % 60
		_ = sendLogToClient(sessionChan, fmt.Sprintf("Tunnel will be active for %02d:%02d:%02d (hh:mm:ss).", h, m, s))
	} else {
		_ = sendLogToClient(sessionChan, "Tunnel has no time limit and will stay active until disconnect.")
	}

	// Global requests. This loop must be the only reader of reqs; a second
	// consumer would race it and reject forward requests at random.
	for req := range reqs {
		if req.Type != "tcpip-forward" {
			req.Reply(false, nil)
			continue
		}
		go func(req *ssh.Request) {
			handleForwardRequest(ctx, req, sshConn, tunnelTimeout, portStart, portEnd, sessionChan, host, cfEnabled, baseDir, cfCertPath, cfHostname, cfEmail, cfGlobalKey, snapshotArgs(), customTunnelN)
		}(req)
	}
}

// checkSessionPassword writes a "Password: " prompt to ch, reads one line and
// compares it (trimmed) with the bcrypt hash. On a read error or a wrong
// password the channel is closed and false is returned.
func checkSessionPassword(ch ssh.Channel, hash string) bool {
	_, _ = ch.Write([]byte("Password: "))
	line, err := readChannelLine(ch)
	if err != nil {
		_ = ch.Close()
		return false
	}
	if bcrypt.CompareHashAndPassword([]byte(hash), []byte(strings.TrimSpace(line))) != nil {
		_, _ = ch.Write([]byte("Invalid password. Disconnecting.\n"))
		_ = ch.Close()
		return false
	}
	return true
}

// readChannelLine reads from ch one byte at a time until CR or LF and returns
// the line without its terminator. Reading byte by byte leaves the next line
// unread. Lines longer than 4096 bytes are rejected so a client cannot make
// the buffer grow without bound.
func readChannelLine(ch ssh.Channel) (string, error) {
	const maxLine = 4096
	var line []byte
	b := make([]byte, 1)
	for {
		n, err := ch.Read(b)
		if err != nil {
			return "", err
		}
		if n == 0 {
			return "", io.ErrUnexpectedEOF
		}
		if b[0] == '\n' || b[0] == '\r' {
			return string(line), nil
		}
		if len(line) >= maxLine {
			return "", fmt.Errorf("input line longer than %d bytes", maxLine)
		}
		line = append(line, b[0])
	}
}

// runSessionCommands reads lines typed on the session channel. It handles
// ".setpass <newpassword>" (for named users) and ".exit". Any read failure
// ends the session through cancel.
func runSessionCommands(ch ssh.Channel, username, usersFile string, cancel context.CancelFunc) {
	for {
		line, err := readChannelLine(ch)
		if err != nil {
			cancel()
			return
		}
		line = strings.TrimSpace(line)
		switch {
		case strings.HasPrefix(line, ".setpass ") && username != "":
			setSessionPassword(ch, username, usersFile, strings.TrimSpace(strings.TrimPrefix(line, ".setpass ")))
		case line == ".exit":
			_ = ch.Close()
			cancel()
			return
		}
	}
}

// setSessionPassword stores a bcrypt hash of newPass as username's password in
// usersFile and reports the result on ch. Users missing from the file are not
// added.
func setSessionPassword(ch ssh.Channel, username, usersFile, newPass string) {
	if newPass == "" {
		_ = sendLogToClient(ch, "Usage: .setpass <newpassword>")
		return
	}
	users, err := loadUserKeepalives(usersFile)
	if err != nil {
		_ = sendWarnToClient(ch, "failed to load users file")
		return
	}
	u, ok := users[username]
	if !ok {
		_ = sendWarnToClient(ch, "username not found in users.json; cannot set password")
		return
	}
	hash, err := bcrypt.GenerateFromPassword([]byte(newPass), bcrypt.DefaultCost)
	if err != nil {
		// Storing an empty hash here would silently reset the password to the default.
		_ = sendWarnToClient(ch, "could not hash new password (bcrypt accepts at most 72 bytes)")
		return
	}
	u.Password = string(hash)
	users[username] = u
	if err := saveUserKeepalives(usersFile, users); err != nil {
		_ = sendWarnToClient(ch, "failed to save new password")
		return
	}
	_ = sendLogToClient(ch, "Password updated.")
}
// ---------------- SERVER ----------------
// startServerCloudflared runs the SSH server on :2222 until SIGINT/SIGTERM.
//
// SSH-level client authentication is off (NoClientAuth); each accepted
// connection is served by handleSSH in its own goroutine.
//
// On SIGINT/SIGTERM it:
//  1. stops accepting connections;
//  2. cancels and closes every session;
//  3. stops every registered Cloudflare tunnel;
//  4. exits the process with status 0.
//
// The accept loop waits for that cleanup to finish instead of returning early
// and letting main exit first. It returns an error if the host key cannot be
// parsed, :2222 cannot be bound, or the listener is closed for any other
// reason.
func startServerCloudflared(privateBytes []byte, tunnelTimeout time.Duration, portStart, portEnd int, host *string, cfEnabled bool, baseDir, cfCertPath, cfHostname, cfEmail, cfGlobalKey, usersFile string) error {
	signer, err := ssh.ParsePrivateKey(privateBytes)
	if err != nil {
		return err
	}
	config := &ssh.ServerConfig{NoClientAuth: true}
	config.AddHostKey(signer)

	listener, err := net.Listen("tcp", ":2222")
	if err != nil {
		return fmt.Errorf("failed to listen on :2222: %v", err)
	}
	defer listener.Close()
	log.Println("SSH server listening on :2222")

	// shuttingDown closes when a signal arrives. shutdownDone closes once
	// cleanup is complete. The accept loop uses these channels (rather than the
	// unsynchronized serverRunning flag) to tell a shutdown apart from real
	// accept errors.
	shuttingDown := make(chan struct{})
	shutdownDone := make(chan struct{})
	sigc := make(chan os.Signal, 1)
	signal.Notify(sigc, os.Interrupt, syscall.SIGTERM)
	go func() {
		<-sigc
		log.Println("Shutting down server...")
		serverRunning = false
		close(shuttingDown)
		// Stop accepting new connections.
		listener.Close()

		// Cancel and close all active SSH sessions.
		sessionsMu.Lock()
		for id, cancel := range sessionCancels {
			log.Printf("[SHUTDOWN] cancelling session %s", id)
			cancel()
		}
		for id, c := range sessionConns {
			log.Printf("[SHUTDOWN] closing ssh connection %s", id)
			_ = c.Close()
		}
		sessionsMu.Unlock()

		// Stop all active Cloudflare tunnels.
		cfTunnelsMu.Lock()
		for _, t := range cfTunnels {
			log.Printf("[SHUTDOWN] stopping tunnel %s (%s)...", t.Name, t.Hostname)
			stopCloudflared(t.Name, t.Hostname, t.CfgPath, t.CredsPath, t.Cmd, cfEmail, cfGlobalKey, cfHostname)
		}
		cfTunnels = map[string]cfTunnelInfo{}
		cfTunnelsMu.Unlock()

		log.Println("[SHUTDOWN] All tunnels and sessions cleaned up. Exiting...")
		close(shutdownDone)
		os.Exit(0)
	}()

	var delay time.Duration
	for {
		conn, err := listener.Accept()
		if err == nil {
			delay = 0
			go handleSSH(conn, config, tunnelTimeout, portStart, portEnd, host, cfEnabled, baseDir, cfCertPath, cfHostname, cfEmail, cfGlobalKey, usersFile)
			continue
		}
		select {
		case <-shuttingDown:
			// Returning now would let main exit before the tunnels are cleaned up.
			<-shutdownDone
			return nil
		default:
		}
		log.Printf("Accept error: %v", err)
		if errors.Is(err, net.ErrClosed) {
			return err
		}
		// Back off on transient errors (e.g. too many open files) instead of spinning.
		if delay == 0 {
			delay = 5 * time.Millisecond
		} else {
			delay *= 2
		}
		if delay > time.Second {
			delay = time.Second
		}
		time.Sleep(delay)
	}
}
// ---------------- MAIN ----------------
// main parses the command-line flags and validates them, then:
//   - loads (or creates) the users file;
//   - prepares the base directory and the port record;
//   - configures UPnP or Cloudflare mode;
//   - loads or generates the SSH host key;
//   - runs the server until it shuts down.
func main() {
	// Needed only before Go 1.20, which seeds the global source automatically.
	mrand.Seed(time.Now().UnixNano())
	pemPath := flag.String("k", "", "Path to private key (PEM)")
	tunnelTimeoutSec := flag.Int("t", 0, "Tunnel default timeout in seconds (0 = no limit)")
	portStart := flag.Int("ps", 10000, "Start of public port range")
	portEnd := flag.Int("pe", 60000, "End of public port range")
	useUPnP := flag.Bool("upnp", true, "Enable UPnP port mapping (auto-disable if unavailable)")
	cfEnable := flag.Bool("cf", false, "Enable Cloudflare tunneling")
	cfCert := flag.String("cf-cert", "", "Path to Cloudflare cert.pem (used as base dir for credentials)")
	cfHost := flag.String("cf-hostname", "", "Base hostname for Cloudflare tunnels (example.com)")
	cfGlobalKey := flag.String("cf-gk", "", "Cloudflare Global API key (X-Auth-Key) for DNS deletion")
	cfEmail := flag.String("cf-em", "", "Cloudflare account email (X-Auth-Email) for DNS deletion")
	usersFile := flag.String("users", "users.json", "Path to JSON file mapping usernames to keepalive + password")
	// The -host default is the outbound interface address reported by GetPublicIP.
	ip, err := GetPublicIP()
	if err != nil {
		ip = "0.0.0.0"
	}
	host := flag.String("host", ip, "Host to display")
	flag.Parse()

	// Public ports are converted to uint16 for UPnP, so they must be real TCP ports.
	if *portStart < 1 || *portEnd > 65535 || *portEnd <= *portStart {
		log.Fatalf("invalid port range %d-%d: need 1 <= ps < pe <= 65535", *portStart, *portEnd)
	}
	if *cfEnable {
		if *cfCert == "" || *cfHost == "" {
			log.Fatal("When -cf is used you must provide -cf-cert and -cf-hostname")
		}
		if _, err := exec.LookPath("cloudflared"); err != nil {
			log.Fatalf("cloudflared binary not found in PATH: %v", err)
		}
	}

	// loadUserKeepalives creates a missing users file and stores default
	// password hashes for entries lacking one. Every connection re-reads this
	// file to decide who must enter a password, so a broken file is fatal here.
	if _, err := loadUserKeepalives(*usersFile); err != nil {
		log.Fatalf("cannot load users file %s: %v", *usersFile, err)
	}

	// Generated files (tunnel credentials/configs, ports.json) go under baseDir:
	// the -cf-cert directory, else ~/.ssh-tunnels, else ./.ssh-tunnels.
	var baseDir string
	if *cfCert != "" {
		baseDir = filepath.Dir(*cfCert)
	} else if home, err := os.UserHomeDir(); err == nil && home != "" {
		baseDir = filepath.Join(home, ".ssh-tunnels")
	} else {
		cwd, _ := os.Getwd()
		baseDir = filepath.Join(cwd, ".ssh-tunnels")
	}
	if err := os.MkdirAll(baseDir, 0700); err != nil {
		log.Fatalf("failed to create base dir %s: %v", baseDir, err)
	}
	portFilePath = filepath.Join(baseDir, "ports.json")
	cleanupOldPorts()

	// Cloudflare mode never uses UPnP. Otherwise UPnP stays on only if a gateway answers.
	upnp := *useUPnP && !*cfEnable
	if upnp && !tryInitUPnP() {
		upnp = false
		log.Println("[UPnP] not available; disabled")
	}
	upnpMu.Lock()
	upnpEnabled = upnp
	upnpMu.Unlock()

	keyBytes, err := loadOrCreateKey(*pemPath)
	if err != nil {
		log.Fatalf("Failed to load/generate key: %v", err)
	}
	tunnelTimeout := time.Duration(*tunnelTimeoutSec) * time.Second
	if err := startServerCloudflared(keyBytes, tunnelTimeout, *portStart, *portEnd, host, *cfEnable, baseDir, *cfCert, *cfHost, *cfEmail, *cfGlobalKey, *usersFile); err != nil {
		log.Fatalf("server error: %v", err)
	}
}