diff --git a/.gitignore b/.gitignore index 5b405c6..d1e4909 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .idea +*~ *.exe diff --git a/app.go b/app.go index ba79c19..35fd1af 100644 --- a/app.go +++ b/app.go @@ -1,194 +1,68 @@ package sshtunnel import ( "flag" - "strings" - "os" - "fmt" - "io" - "net" - "encoding/base64" - "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/agent" "log" - "strconv" -) - - -type Client struct { - -} - -type Endpoint struct { - Host string - Port int -} - -func KeyPrint(dialAddr string, addr net.Addr, key ssh.PublicKey) error { - fmt.Printf("%s %s %s\n", strings.Split(dialAddr, ":")[0], key.Type(), base64.StdEncoding.EncodeToString(key.Marshal())) - return nil -} - -func (endpoint *Endpoint) String() string { - return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port) -} - -func ParseEndpointString(endpointString string) (localEp Endpoint, remoteEp Endpoint){ - var err error - localEp.Host = "localhost" - - buffer := strings.Split(endpointString, ":") - localEp.Port, err = strconv.Atoi(buffer[0]) - _ = err - - - remoteEp.Host = buffer[1] - remoteEp.Port, err = strconv.Atoi(buffer[2]) - _ = err - return localEp, remoteEp -} - - -type SSHtunnel struct { - Local *Endpoint - Server *Endpoint - Remote *Endpoint - - Config *ssh.ClientConfig - ServerConnection *ssh.ServerConfig -} - -func (tunnel *SSHtunnel) Start() error { - listener, err := net.Listen("tcp", tunnel.Local.String()) - if err != nil { - return err - } - defer listener.Close() - - for { - conn, err := listener.Accept() - if err != nil { - return err - } - go tunnel.forward(conn) - } -} - -func (tunnel *SSHtunnel) forward(localConn net.Conn) { - serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config) - if err != nil { - fmt.Printf("Server dial error: %s\n", err) - return - } - - remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String()) - if err != nil { - fmt.Printf("Remote dial error: %s\n", err) - return - } - - copyConn:=func(writer, reader net.Conn) { - _, err:= io.Copy(writer, reader) - if err != nil { - fmt.Printf("io.Copy error: %s", err) - } - } - - go copyConn(localConn, remoteConn) - go copyConn(remoteConn, localConn) -} - -func SSHAgent() ssh.AuthMethod { - if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { - return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers) - } - return nil -} - -func check(e error) { - if e != nil { - panic(e) - } -} - + "golang.org/x/crypto/ssh" +) func(c *Client) Start(){ - //Get the command line arguments - remoteHostname := flag.String("remote-hostname", "172.28.128.230", - "Remote hostname.") - remotePort := flag.Int("remote-port", 27013, - "The remote port bound through the server.") //The intermediary server for port binding serverHostname := flag.String("server", "ubuntu.cse.unr.edu", "a string") + //TODO Add logfile + //TODO Add exec - Executes an application. + var endpoints EndpointsFlag + flag.Var(&endpoints,"endpoints", "Local and remote endpoints") flag.Parse() - // 27013 - serverEndpoint := &Endpoint{ - Host: *serverHostname, - Port: 22, + if len(endpoints) < 1 { + log.Printf("No endpoints defined: %d\n", len(endpoints)) - } + endpoints = make(EndpointsFlag, 0) + endpoints = append(endpoints, "27013:172.28.128.230:27013") + endpoints = append(endpoints, "8080:134.197.20.220:80") + endpoints = append(endpoints, "5555:172.28.128.230:5555") - localEndpoint := &Endpoint{ - Host: "localhost", - Port: 27013, } - remoteEndpoint := &Endpoint{ - Host: *remoteHostname, - Port: *remotePort, - } - - //5555 + localE, remoteE:= ParseEndpointsFromArray(endpoints) - localEndpoint2 := &Endpoint{ - Host: "localhost", - Port: 5555, - } + log.Printf("localE: %v remoteE: %v", localE, remoteE) - remoteEndpoint2 := &Endpoint{ - Host: *remoteHostname, - Port: 5555, + // 27013 + serverEndpoint := &Endpoint{ + Host: *serverHostname, + Port: 22, } + log.Printf("Endpoints: %v \nserverEndpoint: %v\n", endpoints, serverEndpoint) + //log.Printf("%v %v", serverEndpoint) //credChan := make(chan Credentials) - var passwordForm= NewPasswordForm() //passwordForm.SetChan(credChan) passwordForm.Show(); credentials := passwordForm.Credentials - log.Printf("Connecting to %s, User: %s ", *serverHostname, credentials.Username) - sshConfig := &ssh.ClientConfig{ - User: credentials.Username, - HostKeyCallback: KeyPrint, - Auth: []ssh.AuthMethod{ - ssh.Password(credentials.Password), - }, - } + User: credentials.Username, + HostKeyCallback: KeyPrint, + Auth: []ssh.AuthMethod{ + ssh.Password(credentials.Password), + }, + } tunnel := &SSHtunnel{ - Config: sshConfig, - Local: localEndpoint, - Server: serverEndpoint, - Remote: remoteEndpoint, - } - - tunnel2 := &SSHtunnel{ - Config: sshConfig, - Local: localEndpoint2, - Server: serverEndpoint, - Remote: remoteEndpoint2, - } - - go tunnel2.Start() - tunnel.Start() + Config: sshConfig, + Local: localE, + Server: serverEndpoint, + Remote: remoteE, + } + tunnel.Start() } diff --git a/app_test.go b/app_test.go index 051b2ed..6932a97 100644 --- a/app_test.go +++ b/app_test.go @@ -1,75 +1,110 @@ package sshtunnel import ( "testing" "fmt" ) //function func TestEndPointString(t *testing.T) { var testend Endpoint testend.Host = "localhost" testend.Port = 5555 // make a string and test what the string should be stringWant := "localhost:5555" returnedString := testend.String() fmt.Printf("\nstringWant: %s\n\n", stringWant) fmt.Printf("returnedString : %s\n\n", returnedString) // return string == stringwant // success } //function func TestEndPointParseString(t *testing.T) { var testRemote, testRemote1 Endpoint var testLocal, testLocal1 Endpoint // make a string and test what the string should be testLocal, testRemote = ParseEndpointString("5555:motherbrain.unr.edu:5555") testLocal1, testRemote1 = ParseEndpointString("27013:motherbrain.unr.edu:27013") fmt.Printf("Local: %s %d\n\n", testLocal.Host, testLocal.Port) fmt.Printf("Remote: %s %d \n\n", testRemote.Host, testRemote.Port) fmt.Printf("Local: %s %d\n\n", testLocal1.Host, testLocal1.Port) fmt.Printf("Remote: %s %d \n\n", testRemote1.Host, testRemote1.Port) if testLocal.Host != "localhost"{ t.Error("Invalid Local EP Hostname") } if testLocal.Port != 5555 { t.Error("Invalid Local EP Port") } if testRemote.Host != "motherbrain.unr.edu"{ t.Error("Invalid remote EP Hostname") } if testRemote.Port != 5555 { t.Error("Invalid remote EP Port") } //////////////////////////////////////// if testLocal1.Host != "localhost"{ t.Error("Invalid Local EP Hostname") } if testLocal1.Port != 27013 { t.Error("Invalid Local EP Port") } if testRemote1.Host != "motherbrain.unr.edu"{ t.Error("Invalid remote EP Hostname") } if testRemote1.Port != 27013 { t.Error("Invalid remote EP Port") } // return string == stringwant // success } + +func TestEndPointsArray(t *testing.T) { + + //Input + testInput := []string{ + "5555:motherbrain.unr.edu:5555", + "27013:motherbrain.unr.edu:27013", + } + fmt.Printf("Input %v \n\n", testInput) + + //Anticipated output + testLocalEndpoints := [2]Endpoint{ + {Host: "localhost", Port: 5555}, + {Host: "localhost", Port: 27013}, + }; + testRemoteEndpoints := [2]Endpoint{ + {Host: "motherbrain.unr.edu", Port: 5555}, + {Host: "motherbrain.unr.edu", Port: 27013}, + }; + + for r,_ := range(testLocalEndpoints){ + fmt.Printf("Test Local: %v\n\n", testLocalEndpoints[r]) + fmt.Printf("Test Remote: %v\n\n", testRemoteEndpoints[r]) + } + + // try to match the testInput with the anticipated output by passing the test, + // append the input to array + myLocalEndpoints, myRemoteEndpoints := ParseEndpointsFromArray(testInput); + + //Real output + for r,_ := range(myLocalEndpoints){ + fmt.Printf("Local: %v\n\n", myLocalEndpoints[r]) + fmt.Printf("Remote: %v\n\n", myRemoteEndpoints[r]) + } +} diff --git a/endpoint.go b/endpoint.go new file mode 100644 index 0000000..1e2acd8 --- /dev/null +++ b/endpoint.go @@ -0,0 +1,55 @@ +package sshtunnel + +import ( + "strings" + "net" + "golang.org/x/crypto/ssh" + "fmt" + "strconv" + "encoding/base64" +) + +type Client struct { + +} + +type Endpoint struct { + Host string + Port int +} + +func KeyPrint(dialAddr string, addr net.Addr, key ssh.PublicKey) error { + fmt.Printf("%s %s %s\n", strings.Split(dialAddr, ":")[0], key.Type(), base64.StdEncoding.EncodeToString(key.Marshal())) + return nil +} + +func (endpoint *Endpoint) String() string { + return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port) +} + +func ParseEndpointString(endpointString string) (localEp Endpoint, remoteEp Endpoint){ + + localEp.Host = "localhost" + + buffer := strings.Split(endpointString, ":") + localEp.Port, _ = strconv.Atoi(buffer[0]) + + remoteEp.Host = buffer[1] + remoteEp.Port, _ = strconv.Atoi(buffer[2]) + + return localEp, remoteEp +} + + +func ParseEndpointsFromArray(endpointStruct []string) ([]Endpoint, []Endpoint) { + localEndpoints := make([]Endpoint, 0) + remoteEndpoints := make([]Endpoint, 0) + + for _, e := range endpointStruct { + localE, remoteE := ParseEndpointString(e) + localEndpoints = append(localEndpoints, localE) + remoteEndpoints = append(remoteEndpoints, remoteE) + } + + return localEndpoints, remoteEndpoints +} diff --git a/endpoints_flag.go b/endpoints_flag.go new file mode 100644 index 0000000..7b6628f --- /dev/null +++ b/endpoints_flag.go @@ -0,0 +1,12 @@ +package sshtunnel + +type EndpointsFlag []string + +func (i *EndpointsFlag) String() string { + return "my string representation" +} + +func (i *EndpointsFlag) Set(value string) error { + *i = append(*i, value) + return nil +} diff --git a/tunnel.go b/tunnel.go new file mode 100644 index 0000000..e512f46 --- /dev/null +++ b/tunnel.go @@ -0,0 +1,138 @@ +package sshtunnel + +import ( + "net" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + "fmt" + "io" + "os" + "log" + "sync" + "time" +) + +type SSHtunnel struct { + Local []Endpoint + Server *Endpoint + Remote []Endpoint + ErrorChans []chan error + ServerClient *ssh.Client + Config *ssh.ClientConfig + ServerConnection *ssh.ServerConfig + ExecPath string + sync.Mutex +} + + +func(tunnel *SSHtunnel) Listen( localE Endpoint, remoteE Endpoint, errorChan chan error) { + + + log.Printf("LocalEndpoint: %v, RemoteEndpoint: %v", localE, remoteE) + listener, err := net.Listen("tcp", localE.String()) + if err != nil { + return + } + + for { + conn, err := listener.Accept() + + if err != nil { + log.Println("Cannot listen.") + errorChan <- err + } + + defer listener.Close() + + log.Printf("Localhost listen %s, Remote Connection: %s \n", localE.String(), remoteE.String()) + go tunnel.forward(conn, remoteE) + } + +} + +func (tunnel *SSHtunnel) Start() error { + var err error + + tunnel.ErrorChans = make([]chan error, len(tunnel.Local)) + // + + tunnel.ServerClient, err = ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config) + if err != nil { + fmt.Printf("Server dial error: %s\n", err) + //log.Println("Sever dial error occurred. No Internet connection. Program exited.") + // if err == "ssh: handshake failed: ssh: unable to authenticate, attempted methods [password none], no supported methods remain" { + // fmt.Println(" Incorrect Password entered.") + + os.Exit(1) + return err + } + + + + // + for i, _ := range tunnel.Local { + + //defer listener.Close() + log.Printf("Main - LocalEndpoint: %v, RemoteEndpoint: %v", tunnel.Local[i], tunnel.Remote[i]) + go tunnel.Listen(tunnel.Local[i], tunnel.Remote[i], tunnel.ErrorChans[i]) + } + + for { + for _, c := range tunnel.ErrorChans { + select { + case msg := <-c: + log.Printf("\n in tunnel.errorschans : %v", msg) + // case msg2:= <- done: + // log.Printf("%v", msg2) + // os.Exit(1) + default: + //fmt.Println("No message received") + } + } + time.Sleep(time.Second * 1) + //TODO run ExecPath + } + return err + +} + +//TODO refactor to accept remoteEndpoint +func (tunnel *SSHtunnel) forward(localConn net.Conn, remoteEndpoint Endpoint) { + var err error + + + + log.Printf("Connect with client. %s \n", remoteEndpoint.String()) + remoteConn, err := tunnel.ServerClient.Dial("tcp", remoteEndpoint.String()) + if err != nil { + fmt.Printf("Remote dial error: %s\n", err) + log.Println("Remote dial error occurred. Trying to establish a new server connection.") + tunnel.ServerClient, err = ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config) + return + } + log.Printf("Remote connection: %v", remoteConn) + + copyConn:=func(writer, reader net.Conn) { + _, err:= io.Copy(writer, reader) + if err != nil { + log.Printf("%s io.Copy error: %s \n", remoteEndpoint.String(), err) + } + } + + go copyConn(localConn, remoteConn) + go copyConn(remoteConn, localConn) + +} + +func SSHAgent() ssh.AuthMethod { + if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { + return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers) + } + return nil +} + +func check(e error) { + if e != nil { + panic(e) + } +}