package main

import (
	"strings"
	"os"
	"fmt"
	"io"
	"net"
	"log"
  "encoding/base64"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
)

// Endpoint is a host/port pair identifying one hop of the tunnel. Its String
// method formats it as the address string handed to net.Listen and to the
// SSH dial calls.
type Endpoint struct {
	// Host is a host name or IP address.
	Host string
	// Port is the TCP port number.
	Port int
}

func KeyPrint(dialAddr string, addr net.Addr, key ssh.PublicKey) error {
        fmt.Printf("%s %s %s\n", strings.Split(dialAddr, ":")[0], key.Type(), base64.StdEncoding.EncodeToString(key.Marshal()))
        return nil
}

// String returns the endpoint as a dialable "host:port" address.
// net.JoinHostPort brackets IPv6 literals ("[::1]:22"), which a plain
// "%s:%d" format would leave ambiguous and undialable. Other hosts produce
// exactly "host:port".
func (endpoint *Endpoint) String() string {
	return net.JoinHostPort(endpoint.Host, fmt.Sprint(endpoint.Port))
}

// SSHtunnel forwards TCP connections accepted on Local to Remote, relaying
// each one through an SSH connection to Server.
type SSHtunnel struct {
	// Local is the address Start listens on for incoming connections.
	Local  *Endpoint
	// Server is the SSH server that every forwarded connection is dialed through.
	Server *Endpoint
	// Remote is the address dialed via the SSH server for each connection.
	Remote *Endpoint

	// Config is the SSH client configuration used for every dial to Server.
	Config *ssh.ClientConfig
}

// Start listens on the tunnel's Local endpoint and hands each accepted
// connection to forward on its own goroutine. It blocks until Listen or
// Accept fails, returns that error, and closes the listener on the way out.
func (tunnel *SSHtunnel) Start() error {
	ln, err := net.Listen("tcp", tunnel.Local.String())
	if err != nil {
		return err
	}
	defer ln.Close()

	// Keep accepting until the first Accept failure; that error is returned.
	for err == nil {
		var client net.Conn
		if client, err = ln.Accept(); err == nil {
			go tunnel.forward(client)
		}
	}
	return err
}

func (tunnel *SSHtunnel) forward(localConn net.Conn) {
	serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config)
	if err != nil {
		fmt.Printf("Server dial error: %s\n", err)
		return
	}

	remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String())
	if err != nil {
		fmt.Printf("Remote dial error: %s\n", err)
		return
	}

	copyConn:=func(writer, reader net.Conn) {
		_, err:= io.Copy(writer, reader)
		if err != nil {
			fmt.Printf("io.Copy error: %s", err)
		}
	}

	go copyConn(localConn, remoteConn)
	go copyConn(remoteConn, localConn)
}

func SSHAgent() ssh.AuthMethod {
	if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
		return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers)
	}
	return nil
}

// check panics with e when it is non-nil and does nothing otherwise. It is
// meant for unrecoverable setup failures in main.
func check(e error) {
	if e == nil {
		return
	}
	panic(e)
}

func main() {

	fout, err := os.OpenFile("C:\\Projects\\sshtunnel\\sshtunnel_log.log", os.O_RDWR | os.O_CREATE | os.O_APPEND, 0666)
	check(err)
	log.SetOutput(fout)

	localHostName, err := os.Hostname()
	check(err)

	localEndpoint := &Endpoint{
		Host: localHostName,
		Port: 27013,
	}
	log.Println( "Host:", localEndpoint.Host, "Port:" , localEndpoint.Port)

	serverEndpoint := &Endpoint{
		Host: "ubuntu.cse.unr.edu",
		Port: 22,

	}
	log.Println( "Host:", serverEndpoint.Host, "Port:" , serverEndpoint.Port)

	remoteEndpoint := &Endpoint{
		Host: "motherbrain.unr.edu",
		Port: 27013,
	}
	log.Println( "Host:", remoteEndpoint.Host, "Port:" , remoteEndpoint.Port)

	sshConfig := &ssh.ClientConfig{
		User: "",
		HostKeyCallback: KeyPrint,
		Auth: []ssh.AuthMethod{
		ssh.Password(""),
		},
	}

	tunnel := &SSHtunnel{
		Config: sshConfig,
		Local:  localEndpoint,
		Server: serverEndpoint,
		Remote: remoteEndpoint,
	}

		var currentNetworkHardwareName string

		interfaces, _ := net.Interfaces()
		for _, interf := range interfaces {
		       if addrs, err := interf.Addrs(); err == nil {
		               for index, addr := range addrs {
                   			log.Println("[", index, "]", interf.Name, ">", addr)

			                  log.Println("Use name : ", interf.Name)
			                  currentNetworkHardwareName = interf.Name
		               }
		       }
		}

		// extract the hardware information base on the interface name
		// capture above
		netInterface, err := net.InterfaceByName(currentNetworkHardwareName)
		check(err)
		log.Println(netInterface)

		name := netInterface.Name
		macAddress := netInterface.HardwareAddr

		log.Println("Hardware name : ", name)
		log.Println("MAC address : ", macAddress)

	tunnel.Start()
	fout.Close()
}
