# 版本1 (version 1)
# -*- coding: utf-8 -*-
import time
import paramiko
import sys
import re
from tenacity import retry, stop_after_attempt, wait_fixed
import functools
from concurrent import futures
from threading import Lock
import socket
from paramiko.ssh_exception import SSHException
from paramiko.ssh_exception import AuthenticationException
class NetmikoTimeoutException(SSHException):
    """Raised when an SSH operation times out (e.g. waiting on channel data or the channel lock)."""
class NetmikoAuthenticationException(AuthenticationException):
    """SSH authentication failure, derived from paramiko's AuthenticationException."""
# Largest number of bytes requested per channel recv() call.
MAX_BUFFER = 65535
# Default multiplier for polling delays; copied onto each BaseSSHSession.
global_delay_factor = 1
# Legacy capitalised aliases kept for backward compatibility.
NetMikoTimeoutException = NetmikoTimeoutException
NetMikoAuthenticationException = NetmikoAuthenticationException
# Single shared worker thread used by the @timeout decorator.
executor = futures.ThreadPoolExecutor(max_workers=1)
def timeout(timeout):
    """Decorator factory: give up waiting on a call after *timeout* seconds.

    The decorated function is submitted to the module-level single-worker
    ``executor``; the caller waits at most *timeout* seconds for the result
    and otherwise gets ``concurrent.futures.TimeoutError``.

    NOTE: the worker thread is NOT interrupted on timeout. The call keeps
    running in the background, and because the executor has a single worker,
    every later decorated call queues behind it.

    :param timeout: maximum seconds to wait for the result.
    :return: a decorator that wraps a function accordingly.
    """
    def decorator(func):
        # Must be applied as a decorator; calling functools.wraps(func) as a
        # plain statement discards the result and loses func's metadata.
        @functools.wraps(func)
        def wrapper(*args, **kw):
            return executor.submit(func, *args, **kw).result(timeout=timeout)
        return wrapper
    return decorator
class ParaSession(object):
    """Thin wrapper around a password-authenticated paramiko SSH connection.

    On construction it connects and opens an interactive shell channel
    (``self.channel``). Errors are printed rather than raised; if the
    connection or the shell cannot be opened, ``self.channel`` is left
    unset. An SFTP client on a separate transport can be opened with
    :meth:`sftp_client`. :meth:`close` releases everything and is idempotent.
    """

    def __init__(self, hostname, password, port=22, username='root', timeout=60):
        self.t = None  # paramiko.Transport backing the SFTP client
        self.sftp = None
        self._closed = True
        self._channel_closed = True
        self._sftp_closed = True
        self.hostname = hostname
        self.password = password
        self.port = port
        self.username = username
        self.timeout = timeout
        print('- start to create SSH connection -')
        self.client = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy trusts unknown host keys without
        # verification; acceptable for a lab tool, not for untrusted networks.
        self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            self.client.connect(hostname=hostname,
                                port=port,
                                username=username,
                                password=password,
                                timeout=timeout)
            self._closed = False
        except Exception as e:
            print(f'Connection Error And please check the ENV: {str(e)}')
            return
        try:
            self.channel = self.client.invoke_shell()
            self._channel_closed = False
            print('- Connection created successfully -')
        except Exception as e:
            self._best_effort_close()
            print(f'Open channel failed And please check the ENV: {str(e)}')

    def sftp_client(self):
        """Open an SFTP client on a new transport.

        Any SFTP client/transport previously opened by this object is closed
        first, so repeated calls do not leak transports. On failure, only the
        SFTP resources are released; the interactive shell stays open.

        :return: the ``paramiko.SFTPClient``, or None if it could not be opened.
        """
        self._close_sftp()
        try:
            self.t = paramiko.Transport((self.hostname, self.port))
            self.t.connect(username=self.username, password=self.password)
            self.sftp = paramiko.SFTPClient.from_transport(self.t)
            # Only mark SFTP as open once the client actually exists.
            self._sftp_closed = False
            print('- SFTP created successfully -')
            return self.sftp
        except Exception as e:
            try:
                self._close_sftp()
            except Exception as close_err:
                print(f'Cleanup after failure also failed: {str(close_err)}')
            print(f'Open sftp failed And please check the ENV: {str(e)}')
            return None

    def _close_sftp(self):
        """Close the SFTP client and its transport, if present (idempotent)."""
        sftp = getattr(self, 'sftp', None)
        transport = getattr(self, 't', None)
        self.sftp = None
        self.t = None
        self._sftp_closed = True
        if sftp is not None:
            sftp.close()
        if transport is not None:
            # SFTPClient.close() does not close the transport it was built on.
            transport.close()

    def _best_effort_close(self):
        """Call close(), reporting (not raising) any cleanup failure."""
        try:
            self.close()
        except Exception as close_err:
            print(f'Cleanup after failure also failed: {str(close_err)}')

    def close(self):
        """Close the shell channel, the SSH client and any SFTP resources.

        Safe to call repeatedly and on a partially initialised object.
        """
        if not getattr(self, '_closed', True):
            if not self._channel_closed:
                self.channel.close()
                self._channel_closed = True
            self.client.close()
            self._closed = True
        # SFTP uses its own transport, so clean it up even when the main
        # SSH connection never came up.
        self._close_sftp()

    def __del__(self):
        self.close()
class BaseSSHSession(object):
    """Interactive (shell-channel) SSH session with netmiko-style prompt handling.

    On construction it connects through :class:`ParaSession`, uses that
    session's shell channel as ``self.remote_conn``, detects the shell
    prompt (``self.base_prompt``) and drains any pending output.
    """

    def __init__(self, hostname, username, password, su_password, port=22, timeout=100):
        self.hostname = hostname
        self.password = password
        self.username = username
        self.su_password = su_password
        self.port = port
        self.global_delay_factor = global_delay_factor
        self.base_prompt = None
        self.root_prompt = None
        # Seconds to wait for a prompt/pattern when reading the channel.
        self.timeout = timeout
        # Upper bound (seconds) on waiting for the channel lock.
        self.session_timeout = timeout
        self.RETURN = "\n"
        # SFTP-only sessions opened by open_sftp_session(). They are kept
        # referenced here because ParaSession.close() (run by its __del__)
        # closes the SFTP client it owns.
        self._sftp_sessions = []
        # Create the lock before any channel I/O so every read path can use it.
        self._session_locker = Lock()
        self.ssh_obj = ParaSession(hostname=hostname, username=username, password=password, port=port)
        # ParaSession prints connection errors instead of raising; in that
        # case it has no `channel` attribute and this raises AttributeError.
        self.remote_conn = self.ssh_obj.channel
        self.session_preparation()

    # NOTE: a failure here does not necessarily mean a timeout happened;
    # debug carefully. (A @retry/@timeout wrapper was tried and is disabled.)
    def session_preparation(self, delay_factor=1):
        """Detect the base prompt and drain pending output.

        :param delay_factor: unused; kept for interface compatibility.
        """
        self.set_base_prompt()
        self.clear_buffer()

    def exec_ssh_cmd(self, cmd, delay_factor=2):
        """Send *cmd* to the shell and return everything read until the prompt.

        Waits ``delay_factor`` seconds after sending. It then reads until
        ``self.base_prompt`` appears or ``self.timeout`` seconds pass; on
        timeout it prints a message and returns what was collected.

        :return: the accumulated output, or None if the session has no shell
            channel or the command could not be sent.
        """
        time.sleep(delay_factor * 0.1)
        if not hasattr(self.ssh_obj, 'channel'):
            self.close_all_session()
            print(f'session error and please check the ENV.')
            return None
        try:
            self.write_channel(cmd + self.RETURN)
        except Exception as e:
            print(f'CMD send errors: {str(e)}')
            return None
        time.sleep(delay_factor)
        print('Exec ssh command: %s' % str(cmd))
        buff = ''
        deadline = time.time() + self.timeout
        try:
            while self.base_prompt not in buff:
                if time.time() > deadline:
                    print(f'CMD receive timeout: prompt {self.base_prompt!r} '
                          f'not seen within {self.timeout}s')
                    break
                resp = self.read_channel()
                if resp:
                    buff += resp
                else:
                    time.sleep(0.1)  # nothing pending yet; avoid a busy spin
        except Exception as e:
            print(f'CMD receive errors: {str(e)}')
        # Return the whole output, not just the last chunk read.
        return buff

    def open_sftp_session(self, hostname, password, port=22, username='root'):
        """Open an SFTP client on a new, dedicated connection.

        The owning ParaSession is retained on this object (and closed by
        close_all_session()). If it were dropped, its __del__ would close
        the returned client immediately.

        :return: the SFTP client, or None if it could not be opened.
        """
        owner = ParaSession(hostname=hostname, password=password, port=port, username=username)
        ssh_sftp = owner.sftp_client()
        if ssh_sftp is None:
            owner.close()
            return None
        if getattr(self, '_sftp_sessions', None) is None:
            self._sftp_sessions = []
        self._sftp_sessions.append(owner)
        return ssh_sftp

    def _ensure_sftp(self):
        """Return the main session's SFTP client, opening it on first use.

        :return: the SFTP client, or None if it could not be opened.
        """
        sftp = getattr(self.ssh_obj, 'sftp', None)
        if sftp is None:
            sftp = self.ssh_obj.sftp_client()
        return sftp

    def get_remotefile(self, remote_path, local_path, session='default'):
        """Download *remote_path* to *local_path* over SFTP.

        :return: True on success, False if the transfer failed, or None if no
            SFTP client could be opened (all sessions are closed in that case).
        """
        sftp = self._ensure_sftp()
        if sftp is None:
            self.close_all_session()
            print(f'session: {session} is error and please check the ENV.')
            return None
        try:
            sftp.get(remote_path, local_path)
            return True
        except Exception as e:
            print(f'Get remote file errors: {str(e)}')
            return False

    def put_localfile(self, local_path, remote_path, session='default'):
        """Upload *local_path* to *remote_path* over SFTP.

        :return: True on success, False if the transfer failed, or None if no
            SFTP client could be opened (all sessions are closed in that case).
        """
        sftp = self._ensure_sftp()
        if sftp is None:
            self.close_all_session()
            print(f'session: {session} is error and please check the ENV.')
            return None
        try:
            sftp.put(local_path, remote_path)
            return True
        except Exception as e:
            print(f'Put local file errors: {str(e)}')
            return False

    def close_session(self):
        """Close the main connection (safe on a partially built object)."""
        ssh_obj = getattr(self, 'ssh_obj', None)
        if ssh_obj is not None:
            ssh_obj.close()

    def close_all_session(self):
        """Close the main connection and any extra SFTP sessions."""
        for extra in getattr(self, '_sftp_sessions', None) or ():
            extra.close()
        self._sftp_sessions = []
        self.close_session()

    def __del__(self):
        self.close_all_session()

    def _read_channel(self):
        """Read and decode everything currently pending on the channel (non-blocking).

        :raises EOFError: if the remote side closed the channel.
        """
        output = ""
        # recv() blocks when nothing is pending, so only call it while
        # recv_ready() reports data.
        while self.remote_conn.recv_ready():
            outbuf = self.remote_conn.recv(MAX_BUFFER)
            if len(outbuf) == 0:
                raise EOFError("Channel stream closed by remote device.")
            output += outbuf.decode("utf-8", "ignore")
        return output

    def normalize_linefeeds(self, a_string):
        """Convert every CR/LF combination (and lone CR) in *a_string* to a single LF."""
        newline = re.compile("(\r\r\r\n|\r\r\n|\r\n|\n\r)")
        a_string = newline.sub("\n", a_string)
        return re.sub("\r", "\n", a_string)

    def read_channel(self):
        """Return whatever output is currently pending on the channel."""
        return self._read_channel()

    def write_bytes(self, out_data, encoding="ascii"):
        """Encode *out_data* for sending.

        ``str`` is encoded as UTF-8 when ``encoding == "utf-8"``; otherwise it
        is encoded as ASCII, silently dropping non-ASCII characters. ``bytes``
        are returned unchanged.

        :raises ValueError: for any other type.
        """
        if isinstance(out_data, str):
            if encoding == "utf-8":
                return out_data.encode("utf-8")
            return out_data.encode("ascii", "ignore")
        if isinstance(out_data, bytes):
            return out_data
        msg = "Invalid value for out_data neither unicode nor byte string: {}".format(
            out_data
        )
        raise ValueError(msg)

    def _write_channel(self, out_data):
        self.remote_conn.sendall(self.write_bytes(out_data))

    def write_channel(self, out_data):
        """Send *out_data* (str or bytes) to the shell."""
        self._write_channel(out_data)

    def find_prompt(self, delay_factor=1):
        """Return the current shell prompt (last line of pending output).

        If nothing but whitespace is pending, RETURN is sent up to 13 times
        with a growing back-off until output appears.

        :raises ValueError: if no prompt could be read.
        """
        sleep_time = delay_factor * 0.1
        time.sleep(sleep_time)
        prompt = self.read_channel().strip()
        count = 0
        # Only newlines/whitespace received so far: poke the shell.
        while count <= 12 and not prompt:
            prompt = self.read_channel().strip()
            if not prompt:
                self.write_channel(self.RETURN)
                time.sleep(sleep_time)
                if sleep_time <= 3:
                    # Double the sleep_time when it is small
                    sleep_time *= 2
                else:
                    sleep_time += 1
            count += 1
        # If multiple lines arrived, the prompt is the last one.
        prompt = self.normalize_linefeeds(prompt)
        prompt = prompt.split("\n")[-1]
        prompt = prompt.strip()
        if not prompt:
            raise ValueError(f"Unable to find prompt: {prompt}")
        time.sleep(delay_factor * 0.1)
        return prompt

    def clear_buffer(self, backoff=True, delay_factor=1):
        """Read and discard pending channel data (up to 10 reads, optional back-off capped at 3 s)."""
        sleep_time = 0.1 * delay_factor
        for _ in range(10):
            time.sleep(sleep_time)
            data = self.read_channel()
            if not data:
                break
            if backoff:
                sleep_time *= 2
                sleep_time = 3 if sleep_time >= 3 else sleep_time

    # Deprecated: use enable() instead.
    def root_su(self, password, delay_factor=1):
        """Switch to root by scripting ``su`` and its password prompt (deprecated).

        :return: the prompt line if it ends with ``#`` after ``su``; otherwise None.
        :raises ValueError: if no prompt was read at all.

        NOTE(review): the loop only runs while no prompt has been read, so if
        the first read already returns data, ``su`` is never sent.
        """
        sleep_time = delay_factor * 0.1
        RETURN = "\n"
        wait_for_password = re.compile("Password:")
        prompt = self.read_channel().strip()
        count = 0
        while count <= 13 and not prompt:
            prompt = self.read_channel().strip()
            if not prompt:
                self.write_channel(RETURN)
                time.sleep(sleep_time)
                if sleep_time <= 3:
                    sleep_time *= 2
                else:
                    sleep_time += 1
            else:
                prompt = self.normalize_linefeeds(prompt)
                prompt = prompt.split("\n")[-1]
                prompt = prompt.strip()
                if prompt.endswith('$'):
                    self.write_channel('su' + RETURN)
                    time.sleep(sleep_time)
                    prompt = self.read_channel().strip()
                    if wait_for_password.search(prompt):
                        self.write_channel(password + RETURN)
                        time.sleep(sleep_time)
                        prompt = self.read_channel().strip()
                        if prompt.endswith('#'):
                            return prompt
            count += 1
        if not prompt:
            raise ValueError(f"Unable to find prompt: {prompt}")

    def set_base_prompt(self, delay_factor=1, prompt_terminator="$", ):
        """Detect the prompt and store it (minus its terminator) as ``base_prompt``.

        :raises ValueError: if the prompt does not end with *prompt_terminator*.
        """
        prompt = self.find_prompt(delay_factor=delay_factor)
        if not prompt[-1] in prompt_terminator:
            raise ValueError(f"Prompt not found: {repr(prompt)}")
        self.base_prompt = prompt[:-1]
        return self.base_prompt

    # Deprecated: relies on root_su().
    def set_root_prompt(self, delay_factor=1, prompt_terminator="#"):
        """su to root and store the root prompt (minus terminator) as ``root_prompt``.

        :raises ValueError: if no root prompt was obtained.
        """
        prompt = self.root_su(self.su_password, delay_factor=delay_factor)
        # root_su() can return None; report that as "not found" rather than TypeError.
        if not prompt or prompt[-1] not in prompt_terminator:
            raise ValueError(f"Router prompt not found: {repr(prompt)}")
        self.root_prompt = prompt[:-1]
        return self.root_prompt

    def check_base_prompt(self, check_sre, prompt_terminator="$"):
        """Return True if ``base_prompt + prompt_terminator`` occurs in *check_sre*."""
        return self.base_prompt + prompt_terminator in check_sre

    def check_root_prompt(self, check_sre, prompt_terminator="#"):
        """Return True if ``root_prompt + prompt_terminator`` occurs in *check_sre*."""
        return self.root_prompt + prompt_terminator in check_sre

    def enable(self, cmd="", pattern="ssword", secret="", re_flags=re.IGNORECASE):
        """Enter privileged mode: send *cmd*, answer the *pattern* prompt with *secret*.

        :raises ValueError: if privileged mode was not reached.
        """
        output = ""
        msg = (
            "Failed to enter su mode. Please ensure you pass "
            "the 'secret' argument to ConnectHandler."
        )
        if not self.check_enable_mode():
            self.write_channel(self.normalize_cmd(cmd))
            try:
                output += self.read_until_prompt_or_pattern(
                    pattern=pattern, re_flags=re_flags
                )
                self.write_channel(self.normalize_cmd(secret))
                output += self.read_until_prompt()
            except NetmikoTimeoutException:
                raise ValueError(msg)
            if not self.check_enable_mode():
                raise ValueError(msg)
        return output

    def exit_enable_mode(self, exit_command=""):
        """Leave privileged mode with *exit_command*.

        :raises ValueError: if still privileged afterwards.
        """
        output = ""
        if self.check_enable_mode():
            self.write_channel(self.normalize_cmd(exit_command))
            output += self.read_until_prompt()
            if self.check_enable_mode():
                raise ValueError("Failed to exit enable mode.")
        return output

    def normalize_cmd(self, command):
        """Strip trailing whitespace from *command* and append exactly one RETURN."""
        command = command.rstrip()
        command += self.RETURN
        return command

    def check_enable_mode(self, check_string=""):
        """Send RETURN, read to the prompt and report whether *check_string* appears."""
        self.write_channel(self.RETURN)
        output = self.read_until_prompt()
        return check_string in output

    def read_until_prompt_or_pattern(self, pattern="", re_flags=0):
        """Read until the base prompt or the regex *pattern* appears."""
        combined_pattern = re.escape(self.base_prompt)
        if pattern:
            combined_pattern = r"({}|{})".format(combined_pattern, pattern)
        return self._read_channel_expect(combined_pattern, re_flags=re_flags)

    def _read_channel_expect(self, pattern="", re_flags=0, max_loops=150):
        """Read until regex *pattern* (default: the escaped base prompt) matches.

        With the default ``max_loops`` the loop budget is derived from
        ``self.timeout``.

        :raises NetmikoTimeoutException: if the pattern never appears.
        :raises EOFError: if the remote side closed the channel.
        """
        output = ""
        if not pattern:
            # A wrong base_prompt makes this hang; subclasses must override
            # set_base_prompt() to produce a prompt that really appears.
            pattern = re.escape(self.base_prompt)
        i = 1
        loop_delay = 0.1
        # Default to making loop time be roughly equivalent to self.timeout
        if max_loops == 150:
            max_loops = int(self.timeout / loop_delay)
        while i < max_loops:
            # Acquire outside the try: if acquisition times out we must not
            # release a lock that another thread holds.
            self._lock_netmiko_session()
            try:
                new_data = self.remote_conn.recv(MAX_BUFFER)
                if len(new_data) == 0:
                    raise EOFError("Channel stream closed by remote device.")
                new_data = new_data.decode("utf-8", "ignore")
                output += new_data
            except socket.timeout:
                raise NetmikoTimeoutException(
                    "Timed-out reading channel, data not available."
                )
            finally:
                self._unlock_netmiko_session()
            if re.search(pattern, output, flags=re_flags):
                return output
            time.sleep(loop_delay * self.global_delay_factor)
            i += 1
        raise NetmikoTimeoutException(
            f"Timed-out reading channel, pattern not found in output: {pattern}"
        )

    def read_until_prompt(self, *args, **kwargs):
        """Read until the base prompt appears (see _read_channel_expect)."""
        return self._read_channel_expect(*args, **kwargs)

    def _lock_netmiko_session(self, start=None):
        """Block until the channel lock is acquired.

        :raises NetmikoTimeoutException: after ``session_timeout`` seconds.
        """
        if not start:
            start = time.time()
        while not self._session_locker.acquire(False) and not self._timeout_exceeded(
            start, "The netmiko channel is not available!"
        ):
            time.sleep(0.1)
        return True

    def _unlock_netmiko_session(self):
        if self._session_locker.locked():
            self._session_locker.release()

    def _timeout_exceeded(self, start, msg="Timeout exceeded!"):
        """Raise NetmikoTimeoutException(msg) if ``session_timeout`` has elapsed since *start*.

        :return: False otherwise (also when *start* is falsy).
        """
        if not start:
            # Must provide a comparison time
            return False
        limit = getattr(self, "session_timeout", self.timeout)
        if time.time() - start > limit:
            raise NetmikoTimeoutException(msg)
        return False
class LinuxBaseConnection(BaseSSHSession):
    """Linux shell session: ``su`` plays the role of "enable mode" and ``#`` marks root."""

    def check_enable_mode(self, check_string="#"):
        """Return True if the prompt read after RETURN contains *check_string* (root)."""
        return super().check_enable_mode(check_string=check_string)

    def enable(self, cmd="su", pattern="ssword", secret=None, re_flags=re.IGNORECASE):
        """Become root by running *cmd* and answering the password prompt.

        :param secret: password for *cmd*; defaults to the ``su_password``
            given to the constructor (no hard-coded credential).
        """
        if secret is None:
            secret = self.su_password
        return super().enable(cmd=cmd, pattern=pattern, secret=secret, re_flags=re_flags)

    def exit_enable_mode(self, exit_command="exit"):
        """Leave the root shell (``exit`` ends the ``su`` sub-shell on Linux)."""
        return super().exit_enable_mode(exit_command=exit_command)

    def set_base_prompt(self, delay_factor=1, prompt_terminator="$"):
        """Use the host part of a ``user@host:path$`` prompt as ``base_prompt``.

        The host name also appears in the root prompt, so reads keyed on
        ``base_prompt`` keep working after ``su``.

        :raises ValueError: if the prompt has no ``@host:`` part.
        """
        prompt = super().set_base_prompt(delay_factor=delay_factor,
                                         prompt_terminator=prompt_terminator)
        match = re.search(r"@(\w.*):", prompt)
        if match is None:
            raise ValueError(f"Host name not found in prompt: {prompt!r}")
        self.base_prompt = match.group(1)
        return self.base_prompt
if __name__ == '__main__':
    import os

    # Connection settings: override via environment variables instead of
    # editing the source. Defaults keep the original lab values.
    host_ip = os.environ.get('SSH_HOST', '10.101.35.249')
    user_name = os.environ.get('SSH_USER', 'int4')
    pass_word = os.environ.get('SSH_PASSWORD', 'nokia123')
    su_password = os.environ.get('SSH_SU_PASSWORD', pass_word)
    ssh_session_obj = LinuxBaseConnection(host_ip, user_name, pass_word, su_password)
    try:
        result_su = ssh_session_obj.exec_ssh_cmd('ifconfig')
        print(result_su)
        # Method 1 (deprecated): ssh_session_obj.root_su(password=su_password)
        # Method 2: su via enable()
        ssh_session_obj.enable()
        result_su = ssh_session_obj.exec_ssh_cmd('whoami')
        print(result_su)
        result_su = ssh_session_obj.exec_ssh_cmd('ifconfig')
        print(result_su)
    finally:
        # Always release the SSH connection, even if a command fails.
        ssh_session_obj.close_all_session()