Code-Hex / pget

The fastest, resumable file download client

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

sftp support?

yuzhichang opened this issue · comments

As far as I know, lftp is the only tool that downloads a file from an SFTP server concurrently. I wish this pget tool supported this feature, based on the sftp lib.

I've made one myself - ycp copies a local file from/to an SFTP server with multiple threads.

package main

import (
	"bufio"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"os/user"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"

	log "github.com/sirupsen/logrus"

	"github.com/pkg/errors"
	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
)

const (
	// There's a cold start cost inside pkg/sftp client Read/Write operation.
	// So BUFSIZE shall be large enough in order not to hurt the throughput.
	BUFSIZE int = 1 << 21 //2MB
	// Parallel transfer makes sense when `file-size` is much larger than
	// `bps-of-one-connection * sftp-establish-time`. However it's hard to
	// estimate `bps-of-one-connection` before transfer, so files at or below
	// this size are moved with a single connection (see transfer()).
	ONE_CONN_LIMIT int = BUFSIZE
)

// Account holds the endpoint and credentials of one SSH/SFTP server,
// as parsed from a line of $HOME/.accounts.txt.
type Account struct {
	Ip       string // dial target address
	Port     string // ssh port, kept as a string so it can be joined with ":" directly
	Username string
	Password string // literal password, or the sentinel "publickey" to use ~/.ssh/id_rsa
}

// parseAccounts reads $HOME/.accounts.txt and returns a map from both
// the IP and every listed hostname to the matching Account. Each
// non-blank, non-comment line is tab-separated:
// ip, comma-separated hostnames, port, username, password, note.
func parseAccounts() (accounts map[string]Account, err error) {
	var usr *user.User
	if usr, err = user.Current(); err != nil {
		err = errors.Wrap(err, "")
		return
	}
	fp := filepath.Join(usr.HomeDir, ".accounts.txt")
	var file *os.File
	if file, err = os.Open(fp); err != nil {
		err = errors.Wrap(err, "")
		return
	}
	defer file.Close()
	scanner := bufio.NewScanner(file)
	accounts = make(map[string]Account)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" || line[0] == '#' {
			// skip blanks and comments
			continue
		}
		//ip, hostnames, port, username, password, note
		fields := strings.FieldsFunc(line, func(c rune) bool { return c == '\t' })
		if len(fields) < 5 {
			err = errors.Errorf("invalid line: %s", line)
			return
		}
		ip, hostnames, port, username, password := strings.TrimSpace(fields[0]), strings.TrimSpace(fields[1]), strings.TrimSpace(fields[2]), strings.TrimSpace(fields[3]), strings.TrimSpace(fields[4])
		account := Account{Ip: ip, Port: port, Username: username, Password: password}
		accounts[ip] = account
		// Register every alias hostname for the same account.
		for _, hostname := range strings.Split(hostnames, ",") {
			accounts[strings.TrimSpace(hostname)] = account
		}
	}
	// BUG FIX: surface read errors from the scanner instead of silently
	// returning a partial map.
	if err = scanner.Err(); err != nil {
		err = errors.Wrap(err, "")
	}
	return
}

// getAccount resolves host (an IP or hostname) against the accounts
// database in $HOME/.accounts.txt.
func getAccount(host string) (account Account, err error) {
	var accounts map[string]Account
	accounts, err = parseAccounts()
	if err != nil {
		return
	}
	var ok bool
	if account, ok = accounts[host]; ok {
		return
	}
	err = errors.Errorf("host %s is not in the accounts database!", host)
	return
}

// min returns the smaller of the two int64 values.
func min(a, b int64) int64 {
	if a < b {
		return a
	}
	return b
}

// WorkerResult is sent on the shared channel when a worker goroutine
// stops: Err is nil when its byte range completed, non-nil on failure.
type WorkerResult struct {
	Seq int   // index of the reporting worker
	Err error // nil on success
}

// Worker transfers one byte range [begin, end) of the file over its own
// sftp connection. The range may shrink concurrently via ChangeEnd
// (work stealing), so begin/end are guarded by lock.
type Worker struct {
	seq   int                 // worker index, reported in WorkerResult.Seq
	ch    chan<- WorkerResult // completion/failure channel shared by all workers
	acc   Account
	rmtfp string // remote file path
	lclfp string // local file path
	put   bool   // true: upload local->remote; false: download
	buf   []byte // transfer buffer of BUFSIZE bytes
	c     *sftp.Client
	rmt   *sftp.File
	lcl   *os.File
	lock  sync.Mutex //protect begin, end
	begin int64
	end   int64
}

// NewWorker allocates a worker for segment seq that reports on ch. The
// sftp connection is established lazily by Start unless handles are
// injected beforehand via Share.
func NewWorker(seq int, ch chan<- WorkerResult, account Account, rmtfp, lclfp string, put bool) (w *Worker) {
	w = &Worker{
		seq:   seq,
		ch:    ch,
		acc:   account,
		rmtfp: rmtfp,
		lclfp: lclfp,
		put:   put,
	}
	w.buf = make([]byte, BUFSIZE)
	return w
}

// Share hands the worker already-open handles so Start can skip
// establishing its own sftp connection.
func (w *Worker) Share(c *sftp.Client, rmt *sftp.File, lcl *os.File) {
	w.lcl = lcl
	w.rmt = rmt
	w.c = c
}

// Close releases the sftp client and both file handles, skipping any
// that were never opened, and nils the fields so a second Close is a
// no-op.
func (w *Worker) Close() {
	if w.c != nil {
		w.c.Close()
		w.c = nil
	}
	if w.rmt != nil {
		w.rmt.Close()
		w.rmt = nil
	}
	if w.lcl != nil {
		w.lcl.Close()
		w.lcl = nil
	}
}

// Stat returns a consistent snapshot of the worker's current range.
func (w *Worker) Stat() (begin, end int64) {
	w.lock.Lock()
	defer w.lock.Unlock()
	return w.begin, w.end
}

// Start records the worker's byte range [begin, end) and launches a
// goroutine that copies it between the remote and local file, in the
// direction selected by w.put. The goroutine lazily establishes an sftp
// connection unless one was injected via Share, and always sends a
// WorkerResult on w.ch when it stops. The end of the range may be
// shrunk concurrently by ChangeEnd, which is why the loop re-reads it
// under the lock after every buffer.
func (w *Worker) Start(begin, end int64) {
	w.lock.Lock()
	w.begin = begin
	w.end = end
	w.lock.Unlock()
	go func() {
		log.Infof("worker %d started, begin: %d, end: %d", w.seq, begin, end)
		var err error
		res := WorkerResult{Seq: w.seq, Err: nil}
		if w.c == nil {
			if w.c, w.rmt, w.lcl, err = sftpConnect(w.acc, w.rmtfp, w.lclfp, w.put); err != nil {
				// BUG FIX: the connection error used to be dropped here,
				// making a failed worker look successfully finished.
				res.Err = err
				w.ch <- res
				return
			}
			log.Infof("worker %d established sftp connection with %s", w.seq, w.acc.Ip)
		}

		// Position both ends of the copy at the start of this range.
		if _, err = w.rmt.Seek(begin, io.SeekStart); err != nil {
			res.Err = errors.Wrap(err, "")
			w.ch <- res
			return
		}
		if _, err = w.lcl.Seek(begin, io.SeekStart); err != nil {
			res.Err = errors.Wrap(err, "")
			w.ch <- res
			return
		}
		var n int
		begin, end = w.Stat()
		toread := end - begin
		var reader io.Reader
		var writer io.Writer
		if w.put {
			reader = w.lcl
			writer = w.rmt
		} else {
			reader = w.rmt
			writer = w.lcl
		}
		for toread > 0 {
			// Read at most one buffer, never past the (possibly shrunk) end.
			if n, err = reader.Read(w.buf[:min(toread, int64(BUFSIZE))]); n > 0 {
				log.Debugf("worker %d got %d bytes", w.seq, n)
				if _, err = writer.Write(w.buf[:n]); err != nil {
					res.Err = errors.Wrap(err, "")
					w.ch <- res
					return
				}
				// Advance begin and pick up any concurrent ChangeEnd.
				w.lock.Lock()
				begin += int64(n)
				w.begin = begin
				end = w.end
				w.lock.Unlock()
				toread = end - begin
			} else if err != nil {
				res.Err = errors.Wrap(err, "")
				w.ch <- res
				return
			}
		}
		log.Infof("worker %d done", w.seq)
		w.ch <- res
	}()
}

// ChangeEnd shrinks the worker's range so another worker can take over
// [end, oldEnd). Growing the range is rejected with an error.
func (w *Worker) ChangeEnd(end int64) (err error) {
	w.lock.Lock()
	defer w.lock.Unlock()
	if end > w.end {
		return errors.Errorf("worker %d change end from %d to %d", w.seq, w.end, end)
	}
	log.Infof("worker %d change end from %d to %d", w.seq, w.end, end)
	w.end = end
	return nil
}

// Status is the on-disk resume state: one (begin, end) pair per worker,
// persisted periodically so an interrupted transfer can continue.
type Status struct {
	Fp      string  // path of the status file
	MaxConn int     // number of workers recorded (len of Begins/Ends)
	Begins  []int64 // current begin offset of each worker
	Ends    []int64 // current end offset of each worker
}

// NewStatus derives the resume-status file path: next to the download
// target for a get, or under $HOME (named after the remote file) for a
// put.
func NewStatus(rmtfp, lclfp string, put bool) (status *Status) {
	status = &Status{}
	if !put {
		status.Fp = lclfp + ".ycp-get-status"
		return status
	}
	usr, err := user.Current()
	if err != nil {
		log.Fatalf("got error: %+v", err)
	}
	status.Fp = filepath.Join(usr.HomeDir, filepath.Base(rmtfp)+".ycp-put-status")
	return status
}

// Load parses the status file: one "begin\tend" pair per line, one line
// per worker. On success MaxConn is set to the number of lines read.
func (status *Status) Load() (err error) {
	var file *os.File
	if file, err = os.Open(status.Fp); err != nil {
		err = errors.Wrapf(err, "")
		return
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	var begin, end int64
	for scanner.Scan() {
		fields := strings.Split(scanner.Text(), "\t")
		if len(fields) != 2 {
			err = errors.Errorf("invalid line %s", scanner.Text())
			return
		}
		if begin, err = strconv.ParseInt(strings.TrimSpace(fields[0]), 10, 64); err != nil {
			err = errors.Errorf("invalid line %s", scanner.Text())
			return
		}
		if end, err = strconv.ParseInt(strings.TrimSpace(fields[1]), 10, 64); err != nil {
			err = errors.Errorf("invalid line %s", scanner.Text())
			return
		}
		status.Begins = append(status.Begins, begin)
		status.Ends = append(status.Ends, end)
	}
	// BUG FIX: report scanner read errors instead of treating a
	// truncated status file as complete.
	if err = scanner.Err(); err != nil {
		err = errors.Wrap(err, "")
		return
	}
	status.MaxConn = len(status.Begins)
	return
}

// Store rewrites the status file from scratch, emitting one
// "begin\tend" line per worker.
func (status *Status) Store() (err error) {
	var file *os.File
	if file, err = os.OpenFile(status.Fp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600); err != nil {
		return errors.Wrap(err, "")
	}
	defer file.Close()
	for i := 0; i < status.MaxConn; i++ {
		if _, err = fmt.Fprintf(file, "%d\t%d\n", status.Begins[i], status.Ends[i]); err != nil {
			return errors.Wrap(err, "")
		}
	}
	return nil
}

// publicKey loads and parses the private key at path and wraps it as an
// ssh auth method; any failure aborts via panic (callers don't handle
// key errors).
func publicKey(path string) ssh.AuthMethod {
	pemBytes, err := ioutil.ReadFile(path)
	if err != nil {
		panic(err)
	}
	signer, perr := ssh.ParsePrivateKey(pemBytes)
	if perr != nil {
		panic(perr)
	}
	return ssh.PublicKeys(signer)
}

// sftpConnect dials the account's ssh endpoint, starts an sftp session
// and opens both the remote and local file in the direction implied by
// put (put: read local / write remote; get: read remote / write local).
// On success the caller owns all three handles and must close them.
func sftpConnect(account Account, rmtfp, lclfp string, put bool) (c *sftp.Client, rmt *sftp.File, lcl *os.File, err error) {
	var auths []ssh.AuthMethod
	if account.Password != "publickey" {
		auths = append(auths, ssh.Password(account.Password))
	} else {
		usr, uerr := user.Current()
		if uerr != nil {
			log.Fatal(uerr)
		}
		auths = append(auths, publicKey(filepath.Join(usr.HomeDir, ".ssh/id_rsa")))
	}
	config := ssh.ClientConfig{
		User: account.Username,
		Auth: auths,
		// NOTE(review): host keys are not verified; fine for a personal
		// tool, unsafe on hostile networks.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}
	addr := fmt.Sprintf("%s:%s", account.Ip, account.Port)
	var conn *ssh.Client
	if conn, err = ssh.Dial("tcp", addr, &config); err != nil {
		err = errors.Wrap(err, "")
		return
	}
	// BUG FIX: on any failure after a successful dial, close whatever was
	// opened so far instead of leaking the ssh connection.
	defer func() {
		if err != nil {
			if rmt != nil {
				rmt.Close()
				rmt = nil
			}
			if c != nil {
				c.Close()
				c = nil
			}
			conn.Close()
		}
	}()

	if c, err = sftp.NewClient(conn); err != nil {
		err = errors.Wrap(err, "")
		return
	}
	flag := os.O_WRONLY | os.O_CREATE
	if put {
		if rmt, err = c.OpenFile(rmtfp, flag); err != nil {
			err = errors.Wrap(err, "")
			return
		}
		if lcl, err = os.Open(lclfp); err != nil {
			err = errors.Wrap(err, "")
			return
		}
	} else {
		if rmt, err = c.Open(rmtfp); err != nil {
			err = errors.Wrap(err, "")
			return
		}
		if lcl, err = os.OpenFile(lclfp, flag, 0600); err != nil {
			err = errors.Wrap(err, "")
			return
		}
	}
	return
}

// parseUrls classifies the two positional arguments: exactly one must
// be "host:path" (remote) and the other a plain local path. put is true
// when the remote side is the destination. Directory-like targets are
// completed with the source file's base name.
func parseUrls(arg1, arg2 string) (host, rmtfp, lclfp string, put bool, err error) {
	fields1 := strings.Split(arg1, ":")
	fields2 := strings.Split(arg2, ":")
	if len(fields1) == 1 && len(fields2) == 2 {
		put = true
		host = fields2[0]
		rmtfp = fields2[1]
		lclfp = arg1
		// A remote path ending in "." or "/" denotes a directory; append
		// the local file's base name.
		base := filepath.Base(rmtfp)
		baseCh := base[len(base)-1]
		if baseCh == '.' || baseCh == '/' {
			rmtfp = filepath.Join(rmtfp, filepath.Base(lclfp))
		}
	} else if len(fields1) == 2 && len(fields2) == 1 {
		put = false
		host = fields1[0]
		rmtfp = fields1[1]
		lclfp = arg2
		var fi os.FileInfo
		if fi, err = os.Stat(lclfp); err != nil {
			// BUG FIX: a non-existent local target is fine for a
			// download — it is created later with O_CREATE. Only other
			// stat failures are fatal.
			if os.IsNotExist(err) {
				err = nil
			} else {
				err = errors.Wrap(err, "")
				return
			}
		} else if fi.IsDir() {
			lclfp = filepath.Join(lclfp, filepath.Base(rmtfp))
		}
	} else {
		err = errors.Errorf("one url shall be remote, and the other one shall be local")
		return
	}
	return
}

// transfer parses the command line, resolves the account for the remote
// host, and copies the file either with a single connection (test mode
// with -n<=1) or with maxconn parallel workers that each own a byte
// range. Finished workers steal half of the largest remaining range,
// and progress is checkpointed to a status file so an interrupted run
// can resume.
func transfer() (err error) {
	var verbose bool
	var test bool
	var maxconn int
	var txmsg, srcurl string
	flag.BoolVar(&verbose, "v", false, "verbose mode")
	flag.BoolVar(&test, "t", false, "test mode, no continue")
	flag.IntVar(&maxconn, "n", 10, "set maximum number of connections")
	rationale := `This util copy a local file from/to a SFTP server with multiple threads. It's inspired by and a bit faster than lftp's pget under the same setup(maxconn=10). Why I invent another pget-like tool?
 - lftp pget parallel("-n") is not good. Threads perform variously. Some done their job much earlier then others without helping them.
 - lftp pget continue("-c") is buggy. Rerunning "pget -c" often begins a fresh download.
 - lftp doesn't support upload a file parallelly.
 - lftp is 20+ years old C++ code.`
	flag.Usage = func() {
		// BUG FIX: the usage line advertised a non-existent -c flag; the
		// actual flag is -t.
		fmt.Fprintf(os.Stderr, "%s\n%s [-h] [-v] [-t] [-n maxconn] [host:]file1 [host:]file2\nThis util assumes $HOME/.accounts.txt stores accounts.\n", rationale, os.Args[0])
		flag.PrintDefaults()
	}
	flag.Parse()
	if verbose {
		log.SetLevel(log.DebugLevel)
	}
	if flag.NArg() != 2 {
		return errors.Errorf("got %d arguments while expect two", flag.NArg())
	}
	var host, rmtfp, lclfp string
	var put bool
	if host, rmtfp, lclfp, put, err = parseUrls(flag.Args()[0], flag.Args()[1]); err != nil {
		return
	}
	if put {
		srcurl = lclfp
		txmsg = fmt.Sprintf("%s -> %s:%s", lclfp, host, rmtfp)
	} else {
		srcurl = fmt.Sprintf("%s:%s", host, rmtfp)
		txmsg = fmt.Sprintf("%s:%s -> %s", host, rmtfp, lclfp)
	}
	var account Account
	if account, err = getAccount(host); err != nil {
		return
	}

	var xfer_total, xfer_remained, xfer_remained_prev, xfer_bps int64
	xfer_begin := time.Now()
	cont := !test
	if test && maxconn <= 1 {
		// Single-connection mode: plain streaming copy, no status file.
		var c *sftp.Client
		var rmt *sftp.File
		var lcl *os.File
		if c, rmt, lcl, err = sftpConnect(account, rmtfp, lclfp, put); err != nil {
			return
		}
		log.Infof("established sftp connection with %s", account.Ip)
		defer c.Close()
		defer rmt.Close()
		defer lcl.Close()
		// BUG FIX: the copy directions were swapped. A put uploads the
		// local file to the remote side (dst=rmt, src=lcl); a get is the
		// opposite.
		if put {
			xfer_total, err = io.Copy(rmt, lcl)
		} else {
			xfer_total, err = io.Copy(lcl, rmt)
		}
		if err != nil {
			err = errors.Wrap(err, "")
			return
		}
	} else {
		status := NewStatus(rmtfp, lclfp, put)
		if cont {
			// A missing status file just means this is a fresh transfer.
			if err = status.Load(); err != nil {
				if os.IsNotExist(errors.Cause(err)) {
					err = nil
				} else {
					return
				}
			}
			log.Infof("status file: %s", status.Fp)
		}

		var workers []*Worker
		ch := make(chan WorkerResult)
		if cont && status.MaxConn > 0 {
			// Resume: recreate one worker per recorded range.
			maxconn = status.MaxConn
			workers = make([]*Worker, maxconn)
			for i := 0; i < maxconn; i++ {
				workers[i] = NewWorker(i, ch, account, rmtfp, lclfp, put)
				workers[i].Start(status.Begins[i], status.Ends[i])
				xfer_total += status.Ends[i] - status.Begins[i]
			}
		} else {
			// Fresh transfer: probe the file size, split it into
			// near-equal chunks, and let the probing connection serve as
			// the last worker.
			var c *sftp.Client
			var rmt *sftp.File
			var lcl *os.File
			var fi os.FileInfo
			if c, rmt, lcl, err = sftpConnect(account, rmtfp, lclfp, put); err != nil {
				return
			}
			log.Infof("worker %d established sftp connection with %s", maxconn-1, account.Ip)
			if put {
				fi, err = lcl.Stat()
			} else {
				fi, err = rmt.Stat()
			}
			if err != nil {
				err = errors.Wrap(err, "")
				return
			}
			fileSize := fi.Size()
			xfer_total = fileSize
			msg := fmt.Sprintf("file size of %s: %d", srcurl, fileSize)
			if fileSize <= int64(ONE_CONN_LIMIT) {
				// Small file: parallelism wouldn't pay for the extra
				// connection setup.
				maxconn = 1
				msg += ", will transfer with one thread"
			}
			log.Infof(msg)

			if cont {
				status.MaxConn = maxconn
				status.Begins = make([]int64, maxconn)
				status.Ends = make([]int64, maxconn)
			} else {
				// Test mode: start from scratch by removing the target.
				if put {
					if err = c.Remove(rmtfp); err != nil {
						err = errors.Wrap(err, "")
						return
					}
				} else {
					if err = os.Remove(lclfp); err != nil {
						err = errors.Wrap(err, "")
						return
					}
				}
			}

			chunkSize := fileSize / int64(maxconn)
			workers = make([]*Worker, maxconn)
			for i := 0; i < maxconn-1; i++ {
				workers[i] = NewWorker(i, ch, account, rmtfp, lclfp, put)
				begin := chunkSize * int64(i)
				end := chunkSize * int64((i + 1))
				workers[i].Start(begin, end)
			}
			last_worker := NewWorker(maxconn-1, ch, account, rmtfp, lclfp, put)
			workers[maxconn-1] = last_worker
			last_worker.Share(c, rmt, lcl)
			begin := chunkSize * int64(maxconn-1)
			end := fileSize
			last_worker.Start(begin, end)
		}

		defer func() {
			for i := 0; i < maxconn; i++ {
				workers[i].Close()
			}
		}()

		xfer_remained_prev = xfer_total
		interval := time.Duration(10) * time.Second
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		running := maxconn
		for running != 0 {
			select {
			case res := <-ch:
				if res.Err != nil {
					return res.Err
				}
				running--
				// Work stealing: hand the idle worker the second half of
				// the largest remaining range worth splitting.
				var lastSeq int
				var lastBegin, lastEnd int64
				for i := 0; i < maxconn; i++ {
					if i == res.Seq {
						continue
					}
					begin, end := workers[i].Stat()
					if begin >= end {
						continue
					}
					log.Infof("[%d] worker %d remained %d bytes", res.Seq, i, end-begin)
					if end-begin > int64(BUFSIZE) && end-begin > lastEnd-lastBegin {
						lastSeq = i
						lastBegin = begin
						lastEnd = end
					}
				}
				if lastBegin < lastEnd {
					split := (lastEnd + lastBegin + int64(BUFSIZE)) / int64(2)
					if err = workers[lastSeq].ChangeEnd(split); err != nil {
						return
					}
					workers[res.Seq].Start(split, lastEnd)
					running++
				}
			case <-ticker.C:
				// Periodic progress report and resume checkpoint.
				xfer_remained = 0
				for i := 0; i < maxconn; i++ {
					begin, end := workers[i].Stat()
					xfer_remained += end - begin
					if cont {
						status.Begins[i] = begin
						status.Ends[i] = end
					}
				}
				// BUG FIX: guard against division by zero for a
				// zero-byte file.
				xfer_percent := int64(100)
				if xfer_total > 0 {
					xfer_percent = 100 - xfer_remained*100/xfer_total
				}
				xfer_bps = (xfer_remained_prev-xfer_remained)/int64(2*interval.Seconds()) + xfer_bps/2
				xfer_eta := "--"
				if xfer_bps != 0 {
					xfer_eta = fmt.Sprintf("%d", xfer_remained/xfer_bps)
				}
				xfer_remained_prev = xfer_remained
				log.Infof("%d%% %dB %dB/s %ss ETA", xfer_percent, xfer_total-xfer_remained, xfer_bps, xfer_eta)
				if cont {
					if err = status.Store(); err != nil {
						return
					}
				}
			}
		}
		if cont {
			if err = os.Remove(status.Fp); err != nil {
				if os.IsNotExist(err) {
					err = nil
				} else {
					err = errors.Wrap(err, "")
					return
				}
			}
		}
	}
	// BUG FIX: a sub-second transfer truncates to 0 seconds; clamp to 1
	// to avoid a division by zero in the bps summary.
	xfer_duration := int64(time.Since(xfer_begin).Seconds())
	if xfer_duration <= 0 {
		xfer_duration = 1
	}
	xfer_bps = xfer_total / xfer_duration
	log.Infof("%s transfer completed(%dB %dB/s %ds)", txmsg, xfer_total, xfer_bps, xfer_duration)
	return
}

// main runs the transfer and aborts the process on any error.
func main() {
	err := transfer()
	if err == nil {
		return
	}
	log.Fatalf("got error %+v\n", err)
}