zet

Caching SSH connections is 68 times faster than not caching them

Caching the dialed TCP/IP connection (*ssh.Client) is 67.72 times faster than creating a new connection every time:

Uncached:   4059  88907623 ns/op  109533 B/op  570 allocs/op
Cached:   277924   1312706 ns/op   70661 B/op  145 allocs/op
scale=2; 88907623 / 1312706
go test -bench . -benchtime=5m -benchmem
goos: linux
goarch: amd64
pkg: github.com/rwxrob/ssh
cpu: Intel(R) Core(TM) i7-9700 CPU @ 3.00GHz
BenchmarkRun-8              4059          88907623 ns/op          109533 B/op        570 allocs/op
PASS
go test -bench . -benchtime=5m -benchmem
goos: linux
goarch: amd64
pkg: github.com/rwxrob/ssh
cpu: Intel(R) Core(TM) i7-9700 CPU @ 3.00GHz
BenchmarkRun-8            277924           1312706 ns/op           70661 B/op        145 allocs/op
PASS

Here’s the benchmark code:

package ssh_test

import (
	"fmt"
	"log"
	"os"
	"testing"
	"time"

	"github.com/rwxrob/ssh"
)

// ukey is the PEM-encoded OpenSSH private key that the benchmark passes
// to ssh.Run. The key body is elided here ("..."); substitute a
// throwaway key that does not grant access to anything that matters.
var ukey = `
-----BEGIN OPENSSH PRIVATE KEY-----
...
-----END OPENSSH PRIVATE KEY-----
`

// BenchmarkRun measures one complete ssh.Run call per iteration: the
// trivial remote command `echo hi there` executed as user@localhost:22
// with the ukey private key and no host key (nil).
//
// The benchmark fails as soon as a call returns an error or unexpected
// output. Returning quietly instead would let the framework report
// timings for a run that never did the work being measured.
func BenchmarkRun(b *testing.B) {
	for i := 0; i < b.N; i++ {
		stdout, stderr, err := ssh.Run(`user@localhost:22`, []byte(ukey), nil, `echo hi there`, ``)
		if err != nil {
			b.Fatalf("ssh.Run: %v (stdout=%q stderr=%q)", err, stdout, stderr)
		}
		if stdout != "hi there\n" || stderr != `` {
			b.Fatalf("unexpected output: stdout=%q stderr=%q", stdout, stderr)
		}
	}
}

(Make sure not to use a private key that actually matters.)

The cache in this case is just a simple map with absolutely minimal recovery from closed connections (a closed connection is detected when creating a new Session on it returns an error).

package ssh

import (
	"fmt"
	"log"
	"math/rand"
	"strings"
	"time"

	"golang.org/x/crypto/ssh"
)

var (
	// TCPTimeout is the default time limit for completing a TCP
	// connection. Run passes it as the Timeout of the ssh.ClientConfig
	// it dials with.
	TCPTimeout = 300 * time.Second

	// clients caches the dialed *ssh.Client for each target, keyed by
	// the full target string passed to Run.
	clients = make(map[string]*ssh.Client)
)

// Run wraps the ssh.Session.Run command with sensible, stand-alone
// defaults. This function has no dependencies on any underlying ssh
// host installation making it ideal for light-weight, remote ssh calls.
//
// Run combines several steps. First, a client secure shell connection
// is Dialed to the target (user@host:PORT) using the private key of the
// local user (ukey) and public host key in authorized_keys format (hkey,
// or nil to accept any host key). Run then attempts to create a Session,
// calling Run on it to execute the passed cmd and feeding it any
// standard input (in) provided. The standard output and standard error
// are buffered and returned as strings. The exit value is captured in
// err for any exit code other than 0. See the ssh.Session.Run method for
// more information.
//
// Dialed clients are cached in the package-level clients map and reused
// by later calls. The cache key is the target string only, so a cached
// connection is reused even when a later call passes a different ukey
// or hkey. If a Session cannot be created on the client (for example
// because a cached connection has since been closed) the client is
// closed and evicted, and Run dials again exactly once before giving
// up. The cache is a plain map without any locking.
//
// Note that there are no limitations on the size of input and output
// meaning Run should only be used when calling remote commands that can
// be trusted not to produce too much output.
func Run(target string, ukey, hkey []byte, cmd, in string) (stdout, stderr string, err error) {

	// target must split into exactly two parts: user and address
	t := strings.Split(target, "@")
	if len(t) != 2 {
		err = fmt.Errorf(`invalid target: %q`, target)
		return
	}
	user := t[0]
	addr := t[1]

	signer, err := ssh.ParsePrivateKey(ukey)
	if err != nil {
		return
	}

	var callback ssh.HostKeyCallback
	var hostkey, hostpub ssh.PublicKey

	if hkey != nil {

		hostkey, _, _, _, err = ssh.ParseAuthorizedKey(hkey)
		if err != nil {
			return
		}

		// re-parse from the marshaled form before handing the key to
		// ssh.FixedHostKey
		hostpub, err = ssh.ParsePublicKey(hostkey.Marshal())
		if err != nil {
			return
		}

		callback = ssh.FixedHostKey(hostpub)

	} else {

		// no host key supplied: every host key is accepted
		callback = ssh.InsecureIgnoreHostKey()

	}

	var client *ssh.Client
	var sess *ssh.Session

	// At most two passes: the first may find a cached client that can no
	// longer open sessions, the second always dials a new connection.
	for tried := false; ; tried = true {

		var cached bool
		client, cached = clients[target]
		if !cached {
			client, err = ssh.Dial(`tcp`, addr, &ssh.ClientConfig{
				User:            user,
				Auth:            []ssh.AuthMethod{ssh.PublicKeys(signer)},
				HostKeyCallback: callback,
				Timeout:         TCPTimeout,
			})
			if err != nil {
				return
			}
			clients[target] = client
		}

		sess, err = client.NewSession()
		if err == nil {
			break
		}

		// The client cannot open sessions. Evict it and release its
		// connection instead of merely dropping the map entry, which
		// would leave the underlying connection open. The Close error is
		// ignored on purpose since the client is being discarded anyway.
		delete(clients, target)
		_ = client.Close()

		if tried {
			return
		}
	}
	defer sess.Close()

	if in != "" {
		sess.Stdin = strings.NewReader(in)
	}

	_out := new(strings.Builder)
	_err := new(strings.Builder)
	sess.Stdout = _out
	sess.Stderr = _err

	// output is returned even when the command itself fails
	err = sess.Run(cmd)
	stdout = _out.String()
	stderr = _err.String()

	return

}

// ------------------------------- User -------------------------------

// User holds the credentials of the user to connect as. The Signer field
// is an interface (ssh.Signer) for flexibility, and keeping it on the
// struct allows the work of creating a signer to be needed only once,
// upon creation (see NewUser).
type User struct {
	Name   string     // user name
	Key    []byte     // original pemkey
	Signer ssh.Signer // signer parsed from Key
}

// NewUser builds a User from a name and a PEM-encoded private key,
// parsing the key into a Signer. The populated *User is returned in
// every case; when the key cannot be parsed the parse error accompanies
// it.
func NewUser(name string, pemkey []byte) (*User, error) {
	signer, err := ssh.ParsePrivateKey(pemkey)
	user := &User{
		Name:   name,
		Key:    pemkey,
		Signer: signer,
	}
	return user, err
}

// ------------------------------- Host -------------------------------

// Host describes a host to dial. The key fields are interfaces
// (ssh.PublicKey) for flexibility, and keeping them on the struct allows
// the work of parsing the host key to be needed only once, upon creation
// (see NewHost).
type Host struct {
	Addr    string        // name or IP; passed to ssh.Dial as the address
	Auth    []byte        // host key in authorized_keys format (nil to accept any host key)
	Netkey  ssh.PublicKey // key parsed from Auth; RFC 4253, section 6.6
	Pubkey  ssh.PublicKey // suitable for ssh.FixedHostKey
	Comment string        // authorized_keys comment parsed from Auth
	Options []string      // authorized_keys options parsed from Auth
	Client  *ssh.Client   // last (cached) client connection; NOTE(review): not set or read by the code shown here
}

// NewHost builds a Host for addr. With a nil authkey only Addr and Auth
// are set. Otherwise authkey is parsed as an authorized_keys entry to
// fill Netkey, Comment and Options, and Pubkey is derived from Netkey.
// The *Host is returned in every case; on a parse failure the error
// accompanies it.
func NewHost(addr string, authkey []byte) (*Host, error) {
	h := &Host{Addr: addr, Auth: authkey}

	if authkey == nil {
		return h, nil
	}

	var err error

	if h.Netkey, h.Comment, h.Options, _, err = ssh.ParseAuthorizedKey(authkey); err != nil {
		return h, err
	}

	// Per the original author's note, the key as returned by
	// ParseAuthorizedKey is in RFC format and fails when used with
	// ssh.FixedHostKey, so it is marshaled and parsed again.
	if h.Pubkey, err = ssh.ParsePublicKey(h.Netkey.Marshal()); err != nil {
		return h, err
	}

	return h, nil
}

// -------------------------- MultiHostClient -------------------------

// MultiHostClient dials one of several Hosts as User and runs commands
// on whichever host responds (see Dial, DialUntil, Run and RunSafe).
type MultiHostClient struct {
	User      *User         // user credentials to use
	Hosts     []*Host       // hosts to Dial
	Timeout   time.Duration // TCP/IP timeout (not session); Dial panics if 0
	Sleep     time.Duration // time to sleep between Dial calls
	Attempts  int           // number of attempts; Dial panics if 0
	SafeDelim string        // RunSafe delimiter (default pkg SafeDelim)

	last int // index into Hosts of the host Dial selected most recently
}

// assert panics with a descriptive message when c lacks anything that
// dialing requires: a User with a Name and a Signer, at least one Host,
// a non-zero Timeout and a non-zero Attempts.
func (c MultiHostClient) assert() {
	switch {
	case c.User == nil:
		panic(`undefined User`)
	case c.User.Name == "":
		panic(`undefined User.Name`)
	case c.User.Signer == nil:
		panic(`undefined User.Signer`)
	case len(c.Hosts) == 0:
		// Checking the length also rejects an empty, non-nil slice, which
		// would otherwise get past this check and panic later, with an
		// unrelated message, when a random host index is picked.
		panic(`undefined Hosts`)
	case c.Timeout == 0:
		panic(`Timeout cannot be 0`)
	case c.Attempts == 0:
		panic(`Attempts cannot be 0`)
	}
}

// Dial attempts to dial a random host from Hosts and, on failure,
// rotates through the remaining hosts in order until one responds or
// all hosts have been tried once. The first host to respond to Dial is
// used and its client returned; otherwise the error from the last host
// tried is returned (every failed dial is also logged). Panics if any
// User, Hosts, Timeout, or Attempts is undefined.
//
// A new connection is opened on every call: Host.Client is neither
// consulted nor updated here, so the caller owns the returned client
// and is responsible for closing it.
func (c *MultiHostClient) Dial() (*ssh.Client, error) {
	c.assert()

	// Pick the starting host with a generator private to this call so
	// that the process-wide math/rand state, which other code may rely
	// on, is not reseeded as a side effect of dialing.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	c.last = rng.Intn(len(c.Hosts))

	var err error

	for n := 0; n < len(c.Hosts); n++ {

		host := c.Hosts[c.last]

		// hosts without a key accept any host key
		callback := ssh.InsecureIgnoreHostKey()
		if host.Auth != nil {
			callback = ssh.FixedHostKey(host.Pubkey)
		}

		var client *ssh.Client
		client, err = ssh.Dial(`tcp`, host.Addr, &ssh.ClientConfig{
			User:            c.User.Name,
			Auth:            []ssh.AuthMethod{ssh.PublicKeys(c.User.Signer)},
			HostKeyCallback: callback,
			Timeout:         c.Timeout,
		})

		if err == nil {
			return client, nil
		}

		// error during dial; logged because only the last one is returned
		log.Print(err)

		// advance to the next host, wrapping around at the end
		c.last = (c.last + 1) % len(c.Hosts)
	}

	return nil, err

}

// DialUntil calls Dial up to Attempts times, sleeping for the Sleep
// duration between attempts, and returns as soon as a client is
// obtained. When every attempt fails, the result of the final Dial is
// returned. An Attempts value below 1 still results in one Dial call.
func (c *MultiHostClient) DialUntil() (client *ssh.Client, err error) {
	for attempt := 1; ; attempt++ {
		client, err = c.Dial()
		if client != nil || attempt >= c.Attempts {
			return client, err
		}
		time.Sleep(c.Sleep)
	}
}

// Run gets a client connection with DialUntil and then runs the command
// (with optional stdin) as a Session on the remote host capturing the
// stdout, stderr as strings and returning the exit value in the error
// (see ssh.Session.Run). The session and the client connection are
// closed before Run returns. If no client could be obtained, the
// returned error wraps the error from DialUntil.
//
// Note that no sanity checking is performed on the command passed and
// that most SSH servers will pass the cmd to the shell assigned to the
// remote user. This means that if a semicolon were passed within the
// cmd string unexpected behavior could result. Therefore, it is
// critical that the cmd string passed be rigorously validated (usually
// through a very strict regular expression match) to prevent shell
// injection vulnerabilities. Another preventative measure is to provide
// a rudimentary shell for the remote user that disallows any shell
// expansion of any kind (effectively limiting all remote commands to
// their exec syscall equivalents).
func (c *MultiHostClient) Run(cmd, stdin string) (stdout, stderr string, err error) {

	client, err := c.DialUntil()
	if client == nil {
		if err != nil {
			// keep the cause instead of discarding it
			err = fmt.Errorf(`failed to get client connection: %w`, err)
		} else {
			err = fmt.Errorf(`failed to get client connection`)
		}
		return
	}

	// The client is handed back by DialUntil for this call and is not
	// stored anywhere in this method, so it must be closed here or its
	// connection would stay open after Run returns.
	defer client.Close()

	var sess *ssh.Session
	sess, err = client.NewSession()
	if err != nil {
		return
	}
	defer sess.Close()

	if stdin != "" {
		sess.Stdin = strings.NewReader(stdin)
	}

	_out := new(strings.Builder)
	_err := new(strings.Builder)
	sess.Stdout = _out
	sess.Stderr = _err

	// output is returned even when the command itself fails
	err = sess.Run(cmd)
	stdout = _out.String()
	stderr = _err.String()

	return
}

// RunSafe joins cmd and args into a single string, separated by the
// SafeDelim field (or the package-level SafeDelim when the field is
// empty), and passes the result to Run with no standard input.
//
// NOTE(review): how the delimiter makes the call "safe" is not visible
// here; presumably the remote side splits on it. Confirm against the
// receiving end.
func (c MultiHostClient) RunSafe(cmd string, args ...string) (stdout, stderr string, err error) {
	delim := c.SafeDelim
	if delim == "" {
		delim = SafeDelim
	}

	parts := make([]string, 0, len(args)+1)
	parts = append(parts, cmd)
	parts = append(parts, args...)

	return c.Run(strings.Join(parts, delim), ``)
}