-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathsftp.py
172 lines (143 loc) · 5.73 KB
/
sftp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import functools
import logging
import os
import posixpath
import tempfile
import warnings
from io import BytesIO

import pysftp
from paramiko import SSHException

from assemblyline.common.exceptions import ChainAll
from assemblyline.common.uid import get_random_id
from assemblyline.filestore.transport.base import Transport, TransportException, normalize_srl_path
def reconnect_retry_on_fail(func):
    """Wrap an SFTP transport method so it connects lazily and retries once on SSH errors.

    The wrapped method's ``self`` must provide ``sftp``, ``host``, ``user``, ``password``,
    ``private_key``, ``private_key_pass``, ``validate_host`` and ``log``.

    On each call:
      * if ``self.sftp`` is not set, a new ``pysftp.Connection`` is opened first;
      * if the connect or the call raises ``paramiko.SSHException``, the current
        connection is closed (best effort), a fresh one is opened and the method is
        called one more time. Exceptions from that second attempt propagate.

    All warnings are suppressed for the duration of the call.
    """

    def _connect(self, cnopts):
        # Open a fresh connection using the credentials stored on the transport.
        return pysftp.Connection(self.host,
                                 username=self.user,
                                 password=self.password,
                                 private_key=self.private_key,
                                 private_key_pass=self.private_key_pass,
                                 cnopts=cnopts)

    @functools.wraps(func)
    def new_func(self, *args, **kwargs):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            if self.validate_host:
                # None lets pysftp use its default host-key checking.
                cnopts = None
            else:
                # Host-key validation disabled: no known hosts are checked.
                cnopts = pysftp.CnOpts()
                cnopts.hostkeys = None

            try:
                if not self.sftp:
                    self.sftp = _connect(self, cnopts)
                return func(self, *args, **kwargs)
            except SSHException:
                pass

            # The first attempt failed: drop the current connection and retry (one time).
            if self.sftp:
                try:
                    self.sftp.close()
                except Exception:
                    # Closing a broken connection is best effort; it must not prevent the retry.
                    self.log.debug("Failed to close broken SFTP connection", exc_info=True)
                # Forget the dead connection so a failed reconnect below is retried on the next call.
                self.sftp = None

            self.sftp = _connect(self, cnopts)
            return func(self, *args, **kwargs)

    return new_func
@ChainAll(TransportException)
class TransportSFTP(Transport):
    """
    SFTP Transport class.

    Stores files on a remote host through pysftp. The connection is opened lazily (and
    re-opened once on ``SSHException``) by the ``reconnect_retry_on_fail`` decorator.
    """

    def __init__(self, base=None, host=None, password=None, user=None, private_key=None, private_key_pass=None,
                 validate_host=False):
        self.log = logging.getLogger('assemblyline.transport.sftp')
        self.base = base
        self.sftp = None
        self.host = host
        self.password = password
        self.user = user
        self.private_key = private_key
        self.private_key_pass = private_key_pass
        self.validate_host = validate_host

        def sftp_normalize(path):
            # If they've provided an absolute path, leave it as is.
            if path.startswith('/'):
                s = path
            # Relative paths (anything with a '/' or not exactly 64 characters) go under base.
            elif '/' in path or len(path) != 64:
                s = posixpath.join(self.base, path)
            # A bare 64-character name is expanded by normalize_srl_path before joining.
            else:
                s = posixpath.join(self.base, normalize_srl_path(path))
            self.log.debug('sftp normalized: %s -> %s', path, s)
            return s

        super(TransportSFTP, self).__init__(normalize=sftp_normalize)

    def __str__(self):
        return 'sftp://{}@{}{}'.format(self.user, self.host, self.base)

    def close(self):
        """Close the SFTP connection, if any. The next decorated call reconnects."""
        if self.sftp:
            try:
                self.sftp.close()
            finally:
                # Never keep a reference to a closed connection around.
                self.sftp = None

    @reconnect_retry_on_fail
    def delete(self, path):
        """Remove the remote file at *path* (normalized first)."""
        path = self.normalize(path)
        self.sftp.remove(path)

    @reconnect_retry_on_fail
    def exists(self, path):
        """Return whether the remote *path* (normalized first) exists."""
        path = self.normalize(path)
        return self.sftp.exists(path)

    @reconnect_retry_on_fail
    def makedirs(self, path):
        """Create the remote directory *path* (normalized first), including parents."""
        path = self.normalize(path)
        self.sftp.makedirs(path)

    # File based functions
    @reconnect_retry_on_fail
    def download(self, src_path, dst_path):
        """Copy remote *src_path* to the local file *dst_path*, creating local directories as needed."""
        dir_path = os.path.dirname(dst_path)
        # dirname() is '' for a bare filename (current directory): nothing to create then.
        # exist_ok avoids a race with concurrent downloads creating the same directory.
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)
        src_path = self.normalize(src_path)
        self.sftp.get(src_path, dst_path)

    def _upload_local_file(self, src_path, dst_path):
        """Upload local *src_path* to the already-normalized remote *dst_path* via a temporary name.

        The file is written under a random name in the destination directory and then
        renamed into place. Raises TransportException if the destination cannot be
        split/rejoined cleanly or is missing after the rename.
        """
        dirname = posixpath.dirname(dst_path)
        filename = posixpath.basename(dst_path)
        temppath = posixpath.join(dirname, get_random_id())
        finalpath = posixpath.join(dirname, filename)
        if finalpath != dst_path:
            raise TransportException(f"Unsupported destination path: {dst_path}")

        # dst_path is already normalized, so use the raw connection: self.makedirs/self.exists
        # would run it through normalize() a second time.
        self.sftp.makedirs(dirname)
        try:
            self.sftp.put(src_path, temppath)
            # NOTE(review): a plain SFTP rename may refuse to overwrite an existing target on
            # some servers; confirm the expected overwrite semantics.
            self.sftp.rename(temppath, finalpath)
        except Exception:
            # Best effort: don't leave the temporary file behind on the remote host.
            try:
                if self.sftp.exists(temppath):
                    self.sftp.remove(temppath)
            except Exception:
                self.log.debug("Failed to remove temporary upload %s", temppath, exc_info=True)
            raise

        if not self.sftp.exists(finalpath):
            raise TransportException(f"File not found after upload: {dst_path}")

    @reconnect_retry_on_fail
    def upload(self, src_path, dst_path):
        """Upload the local file *src_path* to remote *dst_path* (normalized first)."""
        self._upload_local_file(src_path, self.normalize(dst_path))

    @reconnect_retry_on_fail
    def upload_batch(self, local_remote_tuples):
        return super(TransportSFTP, self).upload_batch(local_remote_tuples)

    # Buffer based functions
    @reconnect_retry_on_fail
    def get(self, path: str) -> bytes:
        """Return the content of remote *path* (normalized first)."""
        path = self.normalize(path)
        bio = BytesIO()
        with self.sftp.open(path) as sftp_handle:
            bio.write(sftp_handle.read())
        return bio.getvalue()

    @reconnect_retry_on_fail
    def put(self, dst_path, content):
        """Write *content* (bytes, or str encoded as UTF-8) to remote *dst_path* (normalized first)."""
        if isinstance(content, str):
            content = content.encode('utf-8')
        dst_path = self.normalize(dst_path)

        # Stage the content in a local temporary file, then upload it.
        fd, src_path = tempfile.mkstemp(prefix="filestore.local_path")
        try:
            with open(fd, "wb") as f:
                f.write(content)
            self._upload_local_file(src_path, dst_path)
        finally:
            # Always remove the local staging file, including when the upload fails
            # (the retry decorator would otherwise leak one file per attempt).
            os.unlink(src_path)