sftp support?
yuzhichang opened this issue · comments
Zhichang Yu commented
As far as I know, lftp is the only tool which downloads a file from an SFTP server concurrently. I wish this pget tool supported this feature, based on the sftp lib.
Zhichang Yu commented
I've made one by myself - ycp copies a local file from/to an SFTP server with multiple threads.
package main
import (
"bufio"
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"os/user"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/pkg/errors"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
const (
	// BUFSIZE is the per-worker transfer buffer size.
	// There's a cold start cost inside pkg/sftp client Read/Write operation.
	// So BUFSIZE shall be large enough in order not to hurt the throughput.
	BUFSIZE int = 1 << 21 //2MB
	// ONE_CONN_LIMIT is the file size at or below which a single connection
	// is used. Parallel transfer makes sense when `file-size` is much larger
	// than `bps-of-one-connection * sftp-establish-time`.
	// However it's hard to estimate `bps-of-one-connection` before transfer.
	ONE_CONN_LIMIT int = BUFSIZE
)
// Account holds the SSH login information for one remote host, parsed
// from one line of $HOME/.accounts.txt (see parseAccounts).
type Account struct {
	Ip       string // IP address of the host
	Port     string // SSH port, kept as a string for address formatting
	Username string // SSH user name
	Password string // literal password, or "publickey" to use ~/.ssh/id_rsa
}
// parseAccounts reads $HOME/.accounts.txt into a map keyed by both the IP
// and every comma-separated hostname of each entry. Lines are
// tab-separated: ip, hostnames, port, username, password, [note].
// Blank lines and lines starting with '#' are skipped.
func parseAccounts() (accounts map[string]Account, err error) {
	var usr *user.User
	if usr, err = user.Current(); err != nil {
		err = errors.Wrap(err, "")
		return
	}
	fp := filepath.Join(usr.HomeDir, ".accounts.txt")
	var file *os.File
	if file, err = os.Open(fp); err != nil {
		err = errors.Wrap(err, "")
		return
	}
	defer file.Close()
	scanner := bufio.NewScanner(file)
	accounts = make(map[string]Account)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		// Skip blank lines and comments.
		if line == "" || line[0] == '#' {
			continue
		}
		//ip, hostnames, port, username, password, note
		delm := func(c rune) bool {
			return (c == '\t')
		}
		fields := strings.FieldsFunc(line, delm)
		if len(fields) < 5 {
			err = errors.Errorf("invalid line: %s", line)
			return
		}
		ip, hostnames, port, username, password := strings.TrimSpace(fields[0]), strings.TrimSpace(fields[1]), strings.TrimSpace(fields[2]), strings.TrimSpace(fields[3]), strings.TrimSpace(fields[4])
		account := Account{Ip: ip, Port: port, Username: username, Password: password}
		accounts[ip] = account
		// Register the same account under each hostname alias as well.
		fields2 := strings.Split(hostnames, ",")
		for _, hostname := range fields2 {
			hostname = strings.TrimSpace(hostname)
			accounts[hostname] = account
		}
	}
	// BUG FIX: a scanner read error used to be silently ignored; surface it.
	if err = scanner.Err(); err != nil {
		err = errors.Wrap(err, "")
		return
	}
	return
}
// getAccount looks up the login credentials for host (an IP or a
// hostname) in the accounts database parsed from $HOME/.accounts.txt.
func getAccount(host string) (account Account, err error) {
	accounts, err := parseAccounts()
	if err != nil {
		return
	}
	account, found := accounts[host]
	if !found {
		err = errors.Errorf("host %s is not in the accounts database!", host)
	}
	return
}
// min returns the smaller of a and b.
func min(a, b int64) int64 {
	if a < b {
		return a
	}
	return b
}
// WorkerResult is sent on the worker's channel when its transfer
// goroutine finishes, successfully or not.
type WorkerResult struct {
	Seq int   // sequence number of the reporting worker
	Err error // nil on success
}
// Worker transfers one byte range [begin, end) of the file over its own
// sftp connection (or one shared via Share). begin and end are guarded
// by lock because ChangeEnd and the copy goroutine access them
// concurrently.
type Worker struct {
	seq   int                 // worker sequence number, used in logs and results
	ch    chan<- WorkerResult // where the final result is reported
	acc   Account             // login credentials for the remote host
	rmtfp string              // remote file path
	lclfp string              // local file path
	put   bool                // true = upload local -> remote, false = download
	buf   []byte              // transfer buffer of BUFSIZE bytes
	c     *sftp.Client        // sftp client, nil until connected or shared
	rmt   *sftp.File          // remote file handle
	lcl   *os.File            // local file handle
	lock  sync.Mutex          //protect begin, end
	begin int64               // next offset to transfer
	end   int64               // one past the last offset to transfer
}
// NewWorker creates a transfer worker with sequence number seq that
// reports its final result on ch. The worker dials its own sftp
// connection on first Start unless one is handed over via Share.
func NewWorker(seq int, ch chan<- WorkerResult, account Account, rmtfp, lclfp string, put bool) (w *Worker) {
	w = &Worker{
		seq:   seq,
		ch:    ch,
		acc:   account,
		rmtfp: rmtfp,
		lclfp: lclfp,
		put:   put,
	}
	w.buf = make([]byte, BUFSIZE)
	return w
}
// Share hands an already-established sftp client and open file handles
// to the worker so Start can skip dialing a connection of its own.
func (w *Worker) Share(c *sftp.Client, rmt *sftp.File, lcl *os.File) {
	w.c = c
	w.rmt = rmt
	w.lcl = lcl
}
// Close releases the worker's sftp client, remote file handle and local
// file handle if present, then clears the references so a repeated
// Close is a no-op.
func (w *Worker) Close() {
	if c := w.c; c != nil {
		c.Close()
	}
	if f := w.rmt; f != nil {
		f.Close()
	}
	if f := w.lcl; f != nil {
		f.Close()
	}
	w.c, w.rmt, w.lcl = nil, nil, nil
}
// Stat returns a consistent snapshot of the worker's current transfer
// range, taken under the lock.
func (w *Worker) Stat() (begin, end int64) {
	w.lock.Lock()
	defer w.lock.Unlock()
	return w.begin, w.end
}
// Start records the byte range [begin, end) and launches a goroutine
// that transfers it, reporting the final WorkerResult on w.ch. If the
// worker has no sftp connection yet (Share was not called), it dials one
// first. The end offset may be shrunk concurrently via ChangeEnd; the
// copy loop observes that through the lock.
func (w *Worker) Start(begin, end int64) {
	w.lock.Lock()
	w.begin = begin
	w.end = end
	w.lock.Unlock()
	go func() {
		log.Infof("worker %d started, begin: %d, end: %d", w.seq, begin, end)
		var err error
		res := WorkerResult{Seq: w.seq, Err: nil}
		if w.c == nil {
			if w.c, w.rmt, w.lcl, err = sftpConnect(w.acc, w.rmtfp, w.lclfp, w.put); err != nil {
				// BUG FIX: the connect error used to be dropped, so a worker
				// that failed to connect reported success and its range was
				// silently never transferred. Propagate the error instead.
				res.Err = err
				w.ch <- res
				return
			}
			log.Infof("worker %d established sftp connection with %s", w.seq, w.acc.Ip)
		}
		if _, err = w.rmt.Seek(begin, io.SeekStart); err != nil {
			res.Err = errors.Wrap(err, "")
			w.ch <- res
			return
		}
		if _, err = w.lcl.Seek(begin, io.SeekStart); err != nil {
			res.Err = errors.Wrap(err, "")
			w.ch <- res
			return
		}
		var n int
		begin, end = w.Stat()
		toread := end - begin
		// Pick the copy direction: put uploads local -> remote, otherwise
		// download remote -> local.
		var reader io.Reader
		var writer io.Writer
		if w.put {
			reader = w.lcl
			writer = w.rmt
		} else {
			reader = w.rmt
			writer = w.lcl
		}
		for toread > 0 {
			if n, err = reader.Read(w.buf[:min(toread, int64(BUFSIZE))]); n > 0 {
				log.Debugf("worker %d got %d bytes", w.seq, n)
				if _, err = writer.Write(w.buf[:n]); err != nil {
					res.Err = errors.Wrap(err, "")
					w.ch <- res
					return
				}
				// Advance begin under the lock and re-read end, which
				// ChangeEnd may have shrunk in the meantime.
				w.lock.Lock()
				begin += int64(n)
				w.begin = begin
				end = w.end
				w.lock.Unlock()
				toread = end - begin
			} else if err != nil {
				res.Err = errors.Wrap(err, "")
				w.ch <- res
				return
			}
		}
		log.Infof("worker %d done", w.seq)
		w.ch <- res
		return
	}()
}
// ChangeEnd shrinks the worker's end offset so another worker can take
// over the tail of its range. An attempt to grow the range is refused
// with an error.
func (w *Worker) ChangeEnd(end int64) (err error) {
	w.lock.Lock()
	defer w.lock.Unlock()
	if end > w.end {
		return errors.Errorf("worker %d change end from %d to %d", w.seq, w.end, end)
	}
	log.Infof("worker %d change end from %d to %d", w.seq, w.end, end)
	w.end = end
	return nil
}
// Status is the resume state of a multi-connection transfer: one
// [begin, end) range per worker, persisted to the file at Fp.
type Status struct {
	Fp      string  // path of the status file
	MaxConn int     // number of workers the ranges belong to
	Begins  []int64 // current begin offset per worker
	Ends    []int64 // end offset per worker
}
// NewStatus derives the path of the resume-status file for a transfer.
// For uploads the status lives under $HOME, keyed by the remote base
// name; for downloads it sits next to the local file.
func NewStatus(rmtfp, lclfp string, put bool) (status *Status) {
	status = &Status{}
	if !put {
		status.Fp = lclfp + ".ycp-get-status"
		return status
	}
	usr, err := user.Current()
	if err != nil {
		log.Fatalf("got error: %+v", err)
	}
	status.Fp = filepath.Join(usr.HomeDir, filepath.Base(rmtfp)+".ycp-put-status")
	return status
}
// Load reads the status file into status.Begins/Ends and sets MaxConn to
// the number of ranges found. Each line must be "begin<TAB>end". The
// caller treats a does-not-exist error as "no previous transfer".
func (status *Status) Load() (err error) {
	var file *os.File
	if file, err = os.Open(status.Fp); err != nil {
		err = errors.Wrapf(err, "")
		return
	}
	defer file.Close()
	scanner := bufio.NewScanner(file)
	var begin, end int64
	for scanner.Scan() {
		fields := strings.Split(scanner.Text(), "\t")
		if len(fields) != 2 {
			err = errors.Errorf("invalid line %s", scanner.Text())
			return
		}
		if begin, err = strconv.ParseInt(strings.TrimSpace(fields[0]), 10, 64); err != nil {
			err = errors.Errorf("invalid line %s", scanner.Text())
			return
		}
		if end, err = strconv.ParseInt(strings.TrimSpace(fields[1]), 10, 64); err != nil {
			err = errors.Errorf("invalid line %s", scanner.Text())
			return
		}
		status.Begins = append(status.Begins, begin)
		status.Ends = append(status.Ends, end)
	}
	// BUG FIX: a scanner read error used to be silently ignored; surface it.
	if err = scanner.Err(); err != nil {
		err = errors.Wrap(err, "")
		return
	}
	status.MaxConn = len(status.Begins)
	return
}
// Store writes the per-worker begin/end offsets to the status file,
// truncating any previous content. Each line is "begin<TAB>end".
func (status *Status) Store() (err error) {
	file, err := os.OpenFile(status.Fp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return errors.Wrap(err, "")
	}
	defer file.Close()
	for i := 0; i < status.MaxConn; i++ {
		line := fmt.Sprintf("%d\t%d\n", status.Begins[i], status.Ends[i])
		if _, err = file.WriteString(line); err != nil {
			return errors.Wrap(err, "")
		}
	}
	return nil
}
// publicKey loads the private key at path and wraps it as an ssh
// authentication method. It panics on failure, since a missing or
// unparsable key makes any transfer impossible.
func publicKey(path string) ssh.AuthMethod {
	pem, err := ioutil.ReadFile(path)
	if err != nil {
		panic(err)
	}
	signer, err := ssh.ParsePrivateKey(pem)
	if err != nil {
		panic(err)
	}
	return ssh.PublicKeys(signer)
}
// sftpConnect dials the SSH server described by account, creates an sftp
// client on top of it, and opens both the remote and the local file with
// direction-appropriate modes. On success the caller owns all three
// handles; on error everything opened so far is closed again. Password
// authentication is used unless the configured password is the literal
// string "publickey", in which case $HOME/.ssh/id_rsa is used.
func sftpConnect(account Account, rmtfp, lclfp string, put bool) (c *sftp.Client, rmt *sftp.File, lcl *os.File, err error) {
	var auths []ssh.AuthMethod
	if account.Password != "publickey" {
		auths = append(auths, ssh.Password(account.Password))
	} else {
		usr, err := user.Current()
		if err != nil {
			log.Fatal(err)
		}
		auths = append(auths, publicKey(filepath.Join(usr.HomeDir, ".ssh/id_rsa")))
	}
	// NOTE(review): host keys are not verified; acceptable on a trusted
	// network, unsafe otherwise.
	config := ssh.ClientConfig{
		User:            account.Username,
		Auth:            auths,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}
	addr := fmt.Sprintf("%s:%s", account.Ip, account.Port)
	var conn *ssh.Client
	if conn, err = ssh.Dial("tcp", addr, &config); err != nil {
		err = errors.Wrap(err, "")
		return
	}
	// BUG FIX: release whatever was opened so far on any later failure
	// instead of leaking the ssh connection / sftp client / remote handle.
	defer func() {
		if err != nil {
			if rmt != nil {
				rmt.Close()
				rmt = nil
			}
			if c != nil {
				c.Close()
				c = nil
			}
			conn.Close()
		}
	}()
	if c, err = sftp.NewClient(conn); err != nil {
		err = errors.Wrap(err, "")
		return
	}
	// No O_TRUNC: the destination is opened in place so an interrupted
	// transfer can be resumed.
	flag := os.O_WRONLY | os.O_CREATE
	if put {
		if rmt, err = c.OpenFile(rmtfp, flag); err != nil {
			err = errors.Wrap(err, "")
			return
		}
		if lcl, err = os.Open(lclfp); err != nil {
			err = errors.Wrap(err, "")
			return
		}
	} else {
		if rmt, err = c.Open(rmtfp); err != nil {
			err = errors.Wrap(err, "")
			return
		}
		if lcl, err = os.OpenFile(lclfp, flag, 0600); err != nil {
			err = errors.Wrap(err, "")
			return
		}
	}
	return
}
// parseUrls interprets the two command line arguments. Exactly one of
// them must be remote ("host:path"); the other is a local path. put
// reports the transfer direction (true = upload). A destination that is
// a directory gets the source's base name appended.
func parseUrls(arg1, arg2 string) (host, rmtfp, lclfp string, put bool, err error) {
	// BUG FIX: split on the first colon only, so remote paths containing
	// ':' are no longer rejected.
	fields1 := strings.SplitN(arg1, ":", 2)
	fields2 := strings.SplitN(arg2, ":", 2)
	if len(fields1) == 1 && len(fields2) == 2 {
		put = true
		host = fields2[0]
		rmtfp = fields2[1]
		lclfp = arg1
		// BUG FIX: filepath.Base strips trailing slashes, so the old check
		// on the last byte of the base name never saw "dir/" and wrongly
		// matched any name ending in '.'. Inspect the raw path instead.
		base := filepath.Base(rmtfp)
		if strings.HasSuffix(rmtfp, "/") || base == "." || base == ".." {
			rmtfp = filepath.Join(rmtfp, filepath.Base(lclfp))
		}
	} else if len(fields1) == 2 && len(fields2) == 1 {
		put = false
		host = fields1[0]
		rmtfp = fields1[1]
		lclfp = arg2
		var fi os.FileInfo
		if fi, err = os.Stat(lclfp); err != nil {
			if os.IsNotExist(err) {
				// BUG FIX: a not-yet-existing local destination is a plain
				// file path (sftpConnect will create it), not an error.
				err = nil
			} else {
				err = errors.Wrap(err, "")
				return
			}
		} else if fi.IsDir() {
			lclfp = filepath.Join(lclfp, filepath.Base(rmtfp))
		}
	} else {
		err = errors.Errorf("one url shall be remote, and the other one shall be local")
		return
	}
	return
}
// transfer parses flags and arguments, then copies one file between the
// local host and a SFTP server, by default with multiple concurrent
// connections. While running it checkpoints a resume-status file (unless
// in test mode) and logs periodic progress. Flags, arguments and the
// status file format are unchanged.
func transfer() (err error) {
	var verbose bool
	var test bool
	var maxconn int
	var txmsg, srcurl string
	flag.BoolVar(&verbose, "v", false, "verbose mode")
	flag.BoolVar(&test, "t", false, "test mode, no continue")
	flag.IntVar(&maxconn, "n", 10, "set maximum number of connections")
	rationale := `This util copy a local file from/to a SFTP server with multiple threads. It's inspired by and a bit faster than lftp's pget under the same setup(maxconn=10). Why I invent another pget-like tool?
- lftp pget parallel("-n") is not good. Threads perform variously. Some done their job much earlier then others without helping them.
- lftp pget continue("-c") is buggy. Rerunning "pget -c" often begins a fresh download.
- lftp doesn't support upload a file parallelly.
- lftp is 20+ years old C++ code.`
	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "%s\n%s [-h] [-v] [-c] [-n maxconn] [host:]file1 [host:]file2\nThis util assumes $HOME/.accounts.txt stores accounts.\n", rationale, os.Args[0])
		flag.PrintDefaults()
	}
	flag.Parse()
	if verbose {
		log.SetLevel(log.DebugLevel)
	}
	if flag.NArg() != 2 {
		return errors.Errorf("got %d arguments while expect two", flag.NArg())
	}
	var host, rmtfp, lclfp string
	var put bool
	if host, rmtfp, lclfp, put, err = parseUrls(flag.Args()[0], flag.Args()[1]); err != nil {
		return
	}
	if put {
		srcurl = lclfp
		txmsg = fmt.Sprintf("%s -> %s:%s", lclfp, host, rmtfp)
	} else {
		srcurl = fmt.Sprintf("%s:%s", host, rmtfp)
		txmsg = fmt.Sprintf("%s:%s -> %s", host, rmtfp, lclfp)
	}
	var account Account
	if account, err = getAccount(host); err != nil {
		return
	}
	var xfer_total, xfer_remained, xfer_remained_prev, xfer_bps int64
	xfer_begin := time.Now()
	cont := !test
	if test && maxconn <= 1 {
		// Single-connection test mode: plain sequential copy.
		var c *sftp.Client
		var rmt *sftp.File
		var lcl *os.File
		if c, rmt, lcl, err = sftpConnect(account, rmtfp, lclfp, put); err != nil {
			return
		}
		log.Infof("established sftp connection with %s", account.Ip)
		defer c.Close()
		defer rmt.Close()
		defer lcl.Close()
		// BUG FIX: the copy direction was inverted. For an upload (put)
		// the remote file is the destination; for a download the local
		// file is. The old code wrote into the read-only handle.
		if put {
			xfer_total, err = io.Copy(rmt, lcl)
		} else {
			xfer_total, err = io.Copy(lcl, rmt)
		}
		if err != nil {
			err = errors.Wrap(err, "")
			return
		}
	} else {
		status := NewStatus(rmtfp, lclfp, put)
		if cont {
			// A missing status file just means there is nothing to resume.
			if err = status.Load(); err != nil {
				if os.IsNotExist(errors.Cause(err)) {
					err = nil
				} else {
					return
				}
			}
			log.Infof("status file: %s", status.Fp)
		}
		var workers []*Worker
		ch := make(chan WorkerResult)
		if cont && status.MaxConn > 0 {
			// Resume: recreate one worker per recorded range.
			maxconn = status.MaxConn
			workers = make([]*Worker, maxconn)
			for i := 0; i < maxconn; i++ {
				workers[i] = NewWorker(i, ch, account, rmtfp, lclfp, put)
				workers[i].Start(status.Begins[i], status.Ends[i])
				xfer_total += status.Ends[i] - status.Begins[i]
			}
		} else {
			// Fresh transfer: probe the file size, then split it evenly.
			var c *sftp.Client
			var rmt *sftp.File
			var lcl *os.File
			var fi os.FileInfo
			if c, rmt, lcl, err = sftpConnect(account, rmtfp, lclfp, put); err != nil {
				return
			}
			log.Infof("worker %d established sftp connection with %s", maxconn-1, account.Ip)
			if put {
				fi, err = lcl.Stat()
			} else {
				fi, err = rmt.Stat()
			}
			if err != nil {
				err = errors.Wrap(err, "")
				return
			}
			fileSize := fi.Size()
			xfer_total = fileSize
			msg := fmt.Sprintf("file size of %s: %d", srcurl, fileSize)
			if fileSize <= int64(ONE_CONN_LIMIT) {
				maxconn = 1
				msg += ", will transfer with one thread"
			}
			log.Infof(msg)
			if cont {
				status.MaxConn = maxconn
				status.Begins = make([]int64, maxconn)
				status.Ends = make([]int64, maxconn)
			} else {
				// Test mode starts the destination from scratch.
				if put {
					if err = c.Remove(rmtfp); err != nil {
						err = errors.Wrap(err, "")
						return
					}
				} else {
					if err = os.Remove(lclfp); err != nil {
						err = errors.Wrap(err, "")
						return
					}
				}
			}
			chunkSize := fileSize / int64(maxconn)
			workers = make([]*Worker, maxconn)
			for i := 0; i < maxconn-1; i++ {
				workers[i] = NewWorker(i, ch, account, rmtfp, lclfp, put)
				begin := chunkSize * int64(i)
				end := chunkSize * int64((i + 1))
				workers[i].Start(begin, end)
			}
			// The last worker reuses the probing connection and takes the
			// remainder of the file.
			last_worker := NewWorker(maxconn-1, ch, account, rmtfp, lclfp, put)
			workers[maxconn-1] = last_worker
			last_worker.Share(c, rmt, lcl)
			begin := chunkSize * int64(maxconn-1)
			end := fileSize
			last_worker.Start(begin, end)
		}
		defer func() {
			for i := 0; i < maxconn; i++ {
				workers[i].Close()
			}
		}()
		xfer_remained_prev = xfer_total
		interval := time.Duration(10) * time.Second
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		running := maxconn
		for running != 0 {
			select {
			case res := <-ch:
				if res.Err != nil {
					return res.Err
				}
				running--
				// Work stealing: find the worker with the largest remaining
				// range and split it with the one that just finished.
				var lastSeq int
				var lastBegin, lastEnd int64
				for i := 0; i < maxconn; i++ {
					if i == res.Seq {
						continue
					}
					begin, end := workers[i].Stat()
					if begin >= end {
						continue
					}
					log.Infof("[%d] worker %d remained %d bytes", res.Seq, i, end-begin)
					if end-begin > int64(BUFSIZE) && end-begin > lastEnd-lastBegin {
						lastSeq = i
						lastBegin = begin
						lastEnd = end
					}
				}
				if lastBegin < lastEnd {
					split := (lastEnd + lastBegin + int64(BUFSIZE)) / int64(2)
					if err = workers[lastSeq].ChangeEnd(split); err != nil {
						return
					}
					workers[res.Seq].Start(split, lastEnd)
					running++
				}
			case <-ticker.C:
				// Periodic progress report plus a status checkpoint.
				xfer_remained = 0
				for i := 0; i < maxconn; i++ {
					begin, end := workers[i].Stat()
					xfer_remained += end - begin
					if cont {
						status.Begins[i] = begin
						status.Ends[i] = end
					}
				}
				// BUG FIX: guard against division by zero for empty files.
				xfer_percent := int64(100)
				if xfer_total > 0 {
					xfer_percent = 100 - xfer_remained*100/xfer_total
				}
				xfer_bps = (xfer_remained_prev-xfer_remained)/int64(2*interval.Seconds()) + xfer_bps/2
				xfer_eta := "--"
				if xfer_bps != 0 {
					xfer_eta = fmt.Sprintf("%d", xfer_remained/xfer_bps)
				}
				xfer_remained_prev = xfer_remained
				log.Infof("%d%% %dB %dB/s %ss ETA", xfer_percent, xfer_total-xfer_remained, xfer_bps, xfer_eta)
				if cont {
					if err = status.Store(); err != nil {
						return
					}
				}
			}
		}
		if cont {
			if err = os.Remove(status.Fp); err != nil {
				if os.IsNotExist(err) {
					err = nil
				} else {
					err = errors.Wrap(err, "")
					return
				}
			}
		}
	}
	// BUG FIX: a transfer finishing in under one second used to divide by
	// zero when computing the average rate.
	xfer_duration := int64(time.Since(xfer_begin).Seconds())
	if xfer_duration <= 0 {
		xfer_duration = 1
	}
	xfer_bps = xfer_total / xfer_duration
	log.Infof("%s transfer completed(%dB %dB/s %ds)", txmsg, xfer_total, xfer_bps, xfer_duration)
	return
}
// main runs the transfer and terminates the process with a non-zero
// exit code (via log.Fatalf) on any error.
func main() {
	if err := transfer(); err != nil {
		log.Fatalf("got error %+v\n", err)
	}
}