diff --git a/app.go b/app.go index c8a3ec1..606a942 100644 --- a/app.go +++ b/app.go @@ -1,82 +1,91 @@ package sshtunnel import ( "flag" "log" "golang.org/x/crypto/ssh" "sync" "os/exec" + "os" + "io" ) - -func(c *Client) Start(){ +func (c *Client) Start() { var wg sync.WaitGroup //The intermediary server for port binding serverHostname := flag.String("server", "ubuntu.cse.unr.edu", "a string") execPath := flag.String("exec", "", "a binary to execute after connecting.") - //TODO Add logfile - //TODO Add exec - Executes an application. + logFile := flag.String("log", "log.txt", "logfile location") + + lf, err := os.OpenFile(*logFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) + if err != nil { + log.Fatalf("error opening file: %v", err) + } + defer lf.Close() + + mw := io.MultiWriter(os.Stdout, lf) + log.SetOutput(mw) - var endpoints EndpointsFlag - flag.Var(&endpoints,"endpoints", "Local and remote endpoints") + var endpoints EndpointsFlag + flag.Var(&endpoints, "endpoints", "Local and remote endpoints") flag.Parse() - if len(endpoints) < 1 { + if len(endpoints) < 1 { log.Printf("No endpoints defined: %d\n", len(endpoints)) endpoints = make(EndpointsFlag, 0) - endpoints = append(endpoints, "5555:172.28.128.230:5555") // MATLAB - endpoints = append(endpoints, "25734:172.28.128.230:25734") // SolidWorks - endpoints = append(endpoints, "25735:172.28.128.230:25735") // SolidWorks - endpoints = append(endpoints, "7788:134.197.20.21:7788") // MATHCAD + endpoints = append(endpoints, "5555:172.28.128.230:5555") // MATLAB + endpoints = append(endpoints, "25734:172.28.128.230:25734") // SolidWorks + endpoints = append(endpoints, "25735:172.28.128.230:25735") // SolidWorks + endpoints = append(endpoints, "7788:134.197.20.21:7788") // MATHCAD } - localE, remoteE:= ParseEndpointsFromArray(endpoints) + localE, remoteE := ParseEndpointsFromArray(endpoints) log.Printf("localE: %v remoteE: %v", localE, remoteE) // 27013 serverEndpoint := &Endpoint{ Host: *serverHostname, Port: 22, } - log.Printf("Endpoints: %v \nserverEndpoint: %v\n", endpoints, serverEndpoint) + log.Printf("Endpoints: %v serverEndpoint: %v\n", endpoints, serverEndpoint) //log.Printf("%v %v", serverEndpoint) //credChan := make(chan Credentials) - var passwordForm= NewPasswordForm() + var passwordForm = NewPasswordForm() //passwordForm.SetChan(credChan) passwordForm.Show(); - credentials := passwordForm.Credentials log.Printf("Connecting to %s, User: %s ", *serverHostname, credentials.Username) sshConfig := &ssh.ClientConfig{ - User: credentials.Username, - HostKeyCallback: KeyPrint, - Auth: []ssh.AuthMethod{ - ssh.Password(credentials.Password), - }, - } + User: credentials.Username, + HostKeyCallback: KeyPrint, + Auth: []ssh.AuthMethod{ + ssh.Password(credentials.Password), + }, + } tunnel := &SSHtunnel{ - Config: sshConfig, - Local: localE, - Server: serverEndpoint, - Remote: remoteE, - } + Config: sshConfig, + Local: localE, + Server: serverEndpoint, + Remote: remoteE, + } wg.Add(1) - go func(){ + go func() { defer wg.Done() tunnel.Start() }(); - - cmd := exec.Command(*execPath) - err := cmd.Start() - if err != nil { - log.Fatal(err) + if len(*execPath) > 0 { + cmd := exec.Command(*execPath) + err := cmd.Start() + if err != nil { + log.Fatal(err) + } } wg.Wait() } diff --git a/bin/main.go b/bin/main.go index 4de2ca6..cb0ea05 100644 --- a/bin/main.go +++ b/bin/main.go @@ -1,15 +1,13 @@ package main import ( "sshtunnel" "fmt" ) - func main() { var ep = &sshtunnel.Endpoint{}; var sshClient = &sshtunnel.Client{}; - sshClient.Start() fmt.Printf("%v", ep) } diff --git a/endpoint.go b/endpoint.go index 1e2acd8..16b5388 100644 --- a/endpoint.go +++ b/endpoint.go @@ -1,55 +1,53 @@ package sshtunnel import ( - "strings" - "net" - "golang.org/x/crypto/ssh" - "fmt" - "strconv" - "encoding/base64" + "strings" + "net" + "golang.org/x/crypto/ssh" + "fmt" + "strconv" + "encoding/base64" ) type Client struct { - } type Endpoint struct { Host string Port int } func KeyPrint(dialAddr string, addr net.Addr, key ssh.PublicKey) error { fmt.Printf("%s %s %s\n", strings.Split(dialAddr, ":")[0], key.Type(), base64.StdEncoding.EncodeToString(key.Marshal())) return nil } func (endpoint *Endpoint) String() string { return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port) } -func ParseEndpointString(endpointString string) (localEp Endpoint, remoteEp Endpoint){ +func ParseEndpointString(endpointString string) (localEp Endpoint, remoteEp Endpoint) { localEp.Host = "localhost" buffer := strings.Split(endpointString, ":") localEp.Port, _ = strconv.Atoi(buffer[0]) remoteEp.Host = buffer[1] remoteEp.Port, _ = strconv.Atoi(buffer[2]) return localEp, remoteEp } - -func ParseEndpointsFromArray(endpointStruct []string) ([]Endpoint, []Endpoint) { +func ParseEndpointsFromArray(endpointStruct []string) ([]Endpoint, []Endpoint) { localEndpoints := make([]Endpoint, 0) remoteEndpoints := make([]Endpoint, 0) - for _, e := range endpointStruct { - localE, remoteE := ParseEndpointString(e) - localEndpoints = append(localEndpoints, localE) - remoteEndpoints = append(remoteEndpoints, remoteE) - } + for _, e := range endpointStruct { + localE, remoteE := ParseEndpointString(e) + localEndpoints = append(localEndpoints, localE) + remoteEndpoints = append(remoteEndpoints, remoteE) + } - return localEndpoints, remoteEndpoints + return localEndpoints, remoteEndpoints } diff --git a/tunnel.go b/tunnel.go index e512f46..5e384eb 100644 --- a/tunnel.go +++ b/tunnel.go @@ -1,138 +1,158 @@ package sshtunnel import ( "net" - "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/agent" - "fmt" - "io" - "os" - "log" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + "fmt" + "io" + "os" + "log" "sync" "time" ) type SSHtunnel struct { - Local []Endpoint - Server *Endpoint - Remote []Endpoint - ErrorChans []chan error - ServerClient *ssh.Client - Config *ssh.ClientConfig + Local []Endpoint + Server *Endpoint + Remote []Endpoint + ErrorChans []chan error + ServerClient *ssh.Client + Config *ssh.ClientConfig ServerConnection *ssh.ServerConfig - ExecPath string + ExecPath string sync.Mutex } +func (tunnel *SSHtunnel) Listen(localE Endpoint, remoteE Endpoint, errorChan chan error) { + log.Printf("LocalEndpoint: %v, RemoteEndpoint: %v", localE, remoteE) + listener, err := net.Listen("tcp", localE.String()) + if err != nil { + return + } -func(tunnel *SSHtunnel) Listen( localE Endpoint, remoteE Endpoint, errorChan chan error) { - - - log.Printf("LocalEndpoint: %v, RemoteEndpoint: %v", localE, remoteE) - listener, err := net.Listen("tcp", localE.String()) - if err != nil { - return - } - - for { - conn, err := listener.Accept() + for { + conn, err := listener.Accept() - if err != nil { - log.Println("Cannot listen.") - errorChan <- err - } + if err != nil { + log.Println("Cannot listen.") + errorChan <- err + } - defer listener.Close() + defer listener.Close() - log.Printf("Localhost listen %s, Remote Connection: %s \n", localE.String(), remoteE.String()) - go tunnel.forward(conn, remoteE) - } + log.Printf("Localhost listen %s, Remote Connection: %s \n", localE.String(), remoteE.String()) + go tunnel.forward(conn, remoteE) + } } func (tunnel *SSHtunnel) Start() error { var err error - tunnel.ErrorChans = make([]chan error, len(tunnel.Local)) + tunnel.ErrorChans = make([]chan error, len(tunnel.Local)) // tunnel.ServerClient, err = ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config) if err != nil { fmt.Printf("Server dial error: %s\n", err) //log.Println("Sever dial error occurred. No Internet connection. Program exited.") - // if err == "ssh: handshake failed: ssh: unable to authenticate, attempted methods [password none], no supported methods remain" { - // fmt.Println(" Incorrect Password entered.") + // if err == "ssh: handshake failed: ssh: unable to authenticate, attempted methods [password none], no supported methods remain" { + // fmt.Println(" Incorrect Password entered.") os.Exit(1) return err } - - + //Keep alive + done := make(chan bool) + go tunnel.keepalive(done) // - for i, _ := range tunnel.Local { + for i, _ := range tunnel.Local { - //defer listener.Close() - log.Printf("Main - LocalEndpoint: %v, RemoteEndpoint: %v", tunnel.Local[i], tunnel.Remote[i]) - go tunnel.Listen(tunnel.Local[i], tunnel.Remote[i], tunnel.ErrorChans[i]) - } + //defer listener.Close() + log.Printf("Main - LocalEndpoint: %v, RemoteEndpoint: %v", tunnel.Local[i], tunnel.Remote[i]) + go tunnel.Listen(tunnel.Local[i], tunnel.Remote[i], tunnel.ErrorChans[i]) + } - for { - for _, c := range tunnel.ErrorChans { - select { + for { + for _, c := range tunnel.ErrorChans { + select { case msg := <-c: log.Printf("\n in tunnel.errorschans : %v", msg) - // case msg2:= <- done: - // log.Printf("%v", msg2) - // os.Exit(1) - default: - //fmt.Println("No message received") + // case msg2:= <- done: + // log.Printf("%v", msg2) + // os.Exit(1) + default: + //fmt.Println("No message received") } } - time.Sleep(time.Second * 1) + time.Sleep(time.Second * 1) //TODO run ExecPath } - return err + return err + +} + +func (tunnel *SSHtunnel) keepalive(done chan bool) { + for { + // Create a session. It is one session per command. + session, err := tunnel.ServerClient.NewSession() + if err != nil { + log.Print("Unable to open keep alive session.") + } else { + log.Print("Keep alive session opened.") + } + + session.Close() + + select { + case message := <-done: + if message == true { + break + } + default: + } + time.Sleep(20000 * time.Millisecond) + } } //TODO refactor to accept remoteEndpoint func (tunnel *SSHtunnel) forward(localConn net.Conn, remoteEndpoint Endpoint) { var err error + log.Printf("Connect with client. %s \n", remoteEndpoint.String()) + remoteConn, err := tunnel.ServerClient.Dial("tcp", remoteEndpoint.String()) + if err != nil { + fmt.Printf("Remote dial error: %s\n", err) + log.Println("Remote dial error occurred. Trying to establish a new server connection.") + tunnel.ServerClient, err = ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config) + return + } + log.Printf("Remote connection: %v", remoteConn) - - log.Printf("Connect with client. %s \n", remoteEndpoint.String()) - remoteConn, err := tunnel.ServerClient.Dial("tcp", remoteEndpoint.String()) + copyConn := func(writer, reader net.Conn) { + _, err := io.Copy(writer, reader) if err != nil { - fmt.Printf("Remote dial error: %s\n", err) - log.Println("Remote dial error occurred. Trying to establish a new server connection.") - tunnel.ServerClient, err = ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config) - return - } - log.Printf("Remote connection: %v", remoteConn) - - copyConn:=func(writer, reader net.Conn) { - _, err:= io.Copy(writer, reader) - if err != nil { - log.Printf("%s io.Copy error: %s \n", remoteEndpoint.String(), err) - } + log.Printf("%s io.Copy error: %s \n", remoteEndpoint.String(), err) } + } - go copyConn(localConn, remoteConn) - go copyConn(remoteConn, localConn) + go copyConn(localConn, remoteConn) + go copyConn(remoteConn, localConn) } func SSHAgent() ssh.AuthMethod { if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers) } return nil } func check(e error) { if e != nil { panic(e) } } diff --git a/windows_ui.go b/windows_ui.go index 0037dcb..742969d 100644 --- a/windows_ui.go +++ b/windows_ui.go @@ -1,71 +1,67 @@ // +build windows,!linux package sshtunnel import ( . "github.com/lxn/walk/declarative" "github.com/lxn/walk" ) type PasswordForm struct { Credentials *Credentials - } func NewPasswordForm() *PasswordForm { return &PasswordForm{} } -func (p *PasswordForm) Show(){ +func (p *PasswordForm) Show() { var mainWindow *walk.MainWindow var usernameLE, passwordLE *walk.LineEdit - p.Credentials = &Credentials{Username:"Test"} + p.Credentials = &Credentials{Username: "Test"} MainWindow{ AssignTo: &mainWindow, - Title: "Login", - MinSize: Size{250, 150}, - Layout: VBox{}, + Title: "Login", + MinSize: Size{250, 150}, + Layout: VBox{}, Children: []Widget{ HSplitter{ Children: []Widget{ - Label{ Text: "Username"}, - LineEdit{ AssignTo: &usernameLE }, - + Label{Text: "Username"}, + LineEdit{AssignTo: &usernameLE}, }, }, HSplitter{ Children: []Widget{ - Label{ Text: "Password"}, + Label{Text: "Password"}, LineEdit{PasswordMode: true, AssignTo: &passwordLE, OnKeyDown: func(key walk.Key) { if key == walk.KeyReturn { p.Credentials.Username = usernameLE.Text() p.Credentials.Password = passwordLE.Text() mainWindow.Close() } }, }, - }, }, PushButton{ Text: "Login", OnClicked: func() { p.Credentials.Username = usernameLE.Text() p.Credentials.Password = passwordLE.Text() mainWindow.Close() }, }, PushButton{ Text: "Cancel", OnClicked: func() { mainWindow.Close() }, }, - }, }.Run() }