import logging
import os
import shlex
import socket
import stat
import time
from threading import Thread, Event

import paramiko
# Configure root logging at import time and expose the logger used by this module.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger('SSHManager')
class SSHManager:
    """Manage one SSH/SFTP session to a server, optionally through a jump host.

    ``server_info`` is a dict. The keys read by this class are:

    * ``ip``, ``port``, ``name``, ``password`` -- the target server
    * ``jump_ip``, ``jump_port``, ``jump_name``, ``jump_password`` -- an
      optional jump host; when ``jump_ip`` is set the target is reached
      through a ``direct-tcpip`` channel opened on the jump host.

    Public operations report their outcome as a tuple whose first element is
    a success flag instead of raising.
    """

    CONNECT_TIMEOUT = 5       # seconds allowed for the TCP connect
    BANNER_TIMEOUT = 30       # seconds allowed for the SSH banner
    HEARTBEAT_INTERVAL = 60   # seconds between two heartbeat probes
    HEARTBEAT_TIMEOUT = 10    # seconds a single heartbeat probe may take
    WRITE_RETRIES = 3         # consecutive failures tolerated for one chunk
    UPLOAD_ATTEMPTS = 5       # whole-upload attempts before giving up

    def __init__(self, server_info):
        """Store the connection settings; no network activity happens here."""
        self.server_info = server_info
        self.ssh = None
        self.transport = None            # jump-host transport, when tunnelling
        self.sftp = None
        # The jump-host client is kept so it stays referenced for the life of
        # the tunnel and can be closed explicitly in disconnect().
        self._jump_ssh = None
        self.heartbeat_active = Event()  # set while the heartbeat should run
        self._heartbeat_stop = Event()   # set to wake and stop the worker
        self.heartbeat_thread = None
        self.file_transfer_timeout = 300  # SFTP channel timeout (seconds)
        self.chunk_size = 1024 * 1024     # transfer block size (1 MiB)
        self.cancel_flag = False          # polled by upload_file()

    # ------------------------------------------------------------------
    # Process control
    # ------------------------------------------------------------------
    def kill_processes_by_name(self, process_name, signal='-9'):
        """Kill every remote process whose command line matches a pattern.

        Args:
            process_name (str): pattern handed to ``pkill -f``.
            signal (str): signal option for pkill, ``-9`` by default.

        Returns:
            tuple: ``(success, message)``. "Nothing matched" counts as success.
        """
        if not process_name or not str(process_name).strip():
            # An empty pattern would make ``pkill -f`` match every process.
            return False, "终止进程失败: 进程名称不能为空"
        try:
            # Both values end up in a shell command line, so quote them.
            cmd = (
                f"pkill {shlex.quote(str(signal))} "
                f"-f {shlex.quote(str(process_name))}"
            )
            success, exit_code, stdout, stderr = self.execute_command(cmd)
            if success:
                return True, f"成功终止匹配的进程: {process_name}"
            # pkill exits with 1 when no process matched; that is not an error
            # for this use case.
            if "No matching processes" in stderr or exit_code == 1:
                return True, "没有找到需要终止的进程"
            return False, f"执行pkill命令失败: {stderr}"
        except Exception as e:
            logger.error(f"终止进程失败: {str(e)}")
            return False, f"终止进程失败: {str(e)}"

    def set_cancel_flag(self, cancel=True):
        """Request (or withdraw a request for) cancellation of an upload."""
        self.cancel_flag = cancel

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------
    def connect(self):
        """Open the SSH connection, tunnelling through the jump host if set.

        Returns:
            tuple: ``(success, message)``. On failure every partially opened
            connection is closed and the handles are reset to ``None``.
        """
        info = self.server_info
        try:
            sock = None
            host = info['ip']
            port = info.get('port', 22)
            if info.get('jump_ip'):
                self._jump_ssh = self._open_client(
                    info['jump_ip'],
                    info.get('jump_port', 22),
                    info.get('jump_name', 'root'),
                    info.get('jump_password', ''),
                )
                self.transport = self._jump_ssh.get_transport()
                dest_addr = (info['ip'], int(info.get('port', 22)))
                local_addr = ('127.0.0.1', 22)
                sock = self.transport.open_channel("direct-tcpip", dest_addr, local_addr)
                # The tunnel channel is the socket; host/port are nominal.
                host, port = '127.0.0.1', 22
            self.ssh = self._open_client(
                host,
                port,
                info.get('name', 'root'),
                info.get('password', ''),
                sock=sock,
            )
            return True, "连接成功"
        except Exception as e:
            logger.error(f"连接失败: {str(e)}")
            self._close_connections()
            return False, str(e)

    def _open_client(self, host, port, username, password, sock=None):
        """Create and connect one SSHClient; the client is closed on failure."""
        client = paramiko.SSHClient()
        # NOTE(security): AutoAddPolicy accepts unknown host keys without
        # verification, which leaves the first connection open to MITM.
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            client.connect(
                host,
                port=int(port),
                username=username,
                password=password,
                sock=sock,
                timeout=self.CONNECT_TIMEOUT,
                banner_timeout=self.BANNER_TIMEOUT,
            )
        except Exception:
            client.close()
            raise
        return client

    def disconnect(self):
        """Stop the heartbeat and close SFTP, SSH and the jump-host tunnel."""
        self.stop_heartbeat()
        self._close_connections()

    def _close_connections(self):
        """Close every open handle and reset it to None so it is not reused."""
        handles = (
            ('sftp', "SFTP连接已关闭", "关闭SFTP连接时出错"),
            ('ssh', "SSH连接已关闭", "关闭SSH连接时出错"),
            ('transport', "传输通道已关闭", "关闭传输通道时出错"),
            ('_jump_ssh', "传输通道已关闭", "关闭传输通道时出错"),
        )
        for attr, closed_msg, error_msg in handles:
            handle = getattr(self, attr, None)
            setattr(self, attr, None)
            if handle is None:
                continue
            try:
                handle.close()
                logger.debug(closed_msg)
            except Exception as e:
                logger.warning(f"{error_msg}: {str(e)}")

    def reconnect(self):
        """Drop the current connection, connect again and reopen SFTP.

        The heartbeat is (re)started on success.

        Returns:
            tuple: ``(success, message)``.
        """
        try:
            logger.info("尝试重新连接...")
            self.disconnect()
            time.sleep(1)
            success, message = self.connect()
            if not success:
                return False, message
            self.sftp = self.ssh.open_sftp()
            self.sftp.get_channel().settimeout(30)
            self.start_heartbeat()
            logger.info("重连成功")
            return True, "重连成功"
        except Exception as e:
            logger.error(f"重连失败: {str(e)}")
            return False, f"重连失败: {str(e)}"

    def _connection_alive(self):
        """Return True when an SSH transport exists and reports itself active."""
        transport = self.ssh.get_transport() if self.ssh else None
        return transport is not None and transport.is_active()

    def _ensure_sftp(self, timeout):
        """Open the shared SFTP session if needed and apply ``timeout`` to it."""
        if self.sftp is None:
            self.sftp = self.ssh.open_sftp()
        self.sftp.get_channel().settimeout(timeout)

    @staticmethod
    def _close_quietly(resource):
        """Close ``resource`` and log, rather than raise, any error."""
        if resource is None:
            return
        try:
            resource.close()
        except Exception as e:
            logger.debug(f"关闭资源时出错: {str(e)}")

    # ------------------------------------------------------------------
    # Heartbeat
    # ------------------------------------------------------------------
    def start_heartbeat(self):
        """Start the heartbeat thread unless one is already running."""
        thread = self.heartbeat_thread
        if thread is not None and thread.is_alive():
            if self.heartbeat_active.is_set():
                return
            # A worker that already flagged a failure may still be winding
            # down; stop it before starting a replacement.
            self.stop_heartbeat()
        # Every worker gets its own stop event, so a worker that outlives the
        # join timeout can never be revived by a later start.
        stop_event = Event()
        self._heartbeat_stop = stop_event
        self.heartbeat_active.set()
        self.heartbeat_thread = Thread(
            target=self._heartbeat_worker, args=(stop_event,)
        )
        self.heartbeat_thread.daemon = True
        self.heartbeat_thread.start()
        logger.info("心跳检测已启动")

    def _heartbeat_worker(self, stop_event=None):
        """Probe the connection every HEARTBEAT_INTERVAL seconds.

        The loop ends when ``stop_event`` is set or when a probe fails. After
        a failure ``heartbeat_active`` is cleared so callers can notice.
        """
        if stop_event is None:
            stop_event = self._heartbeat_stop
        logger.debug("心跳线程开始运行")
        # Event.wait() returns True as soon as the event is set and False on
        # timeout, so this sleeps between probes yet reacts to a stop at once.
        while not stop_event.wait(self.HEARTBEAT_INTERVAL):
            try:
                if not self._probe_connection():
                    break
            except Exception as e:
                logger.warning(f"心跳检测失败: {str(e)}")
                break
        # Only the current worker may flag the heartbeat as dead.
        if stop_event is self._heartbeat_stop and not stop_event.is_set():
            self.heartbeat_active.clear()
        logger.debug("心跳线程已停止")

    def _probe_connection(self):
        """Run ``echo`` on a fresh channel; return True if it exits with 0."""
        if not self._connection_alive():
            logger.warning("心跳检测失败: SSH连接不可用")
            return False
        channel = self.ssh.get_transport().open_session()
        try:
            channel.settimeout(self.HEARTBEAT_TIMEOUT)
            channel.exec_command("echo 'heartbeat'")
            deadline = time.monotonic() + self.HEARTBEAT_TIMEOUT
            while not channel.exit_status_ready():
                if time.monotonic() > deadline:
                    raise socket.timeout("心跳命令超时")
                time.sleep(0.1)
            exit_status = channel.recv_exit_status()
        finally:
            channel.close()
        if exit_status == 0:
            logger.debug("心跳检测成功")
            return True
        logger.warning(f"心跳检测失败,退出状态: {exit_status}")
        return False

    def stop_heartbeat(self):
        """Signal the heartbeat thread to stop and wait briefly for it."""
        thread = self.heartbeat_thread
        thread_alive = thread is not None and thread.is_alive()
        was_active = self.heartbeat_active.is_set()
        self.heartbeat_active.clear()
        self._heartbeat_stop.set()
        if not was_active and not thread_alive:
            logger.info("心跳检测未运行")
            return
        logger.info("正在停止心跳检测...")
        if thread_alive:
            thread.join(timeout=2.0)
            if thread.is_alive():
                logger.warning("心跳线程未能正常终止")
            else:
                logger.info("心跳检测已停止")

    # ------------------------------------------------------------------
    # Download / read
    # ------------------------------------------------------------------
    def download_file(self, remote_path, local_path=None, overwrite=False):
        """Download a remote file in ``chunk_size`` blocks.

        :param remote_path: path of the file on the server
        :param local_path: destination; defaults to the remote file name in
            the current directory
        :param overwrite: replace an existing local file when True
        :return: ``(success, message)``
        """
        if not self.ssh:
            return False, "未建立SSH连接"
        local_created = False
        try:
            self._ensure_sftp(self.file_transfer_timeout)
            if not local_path:
                local_path = os.path.basename(remote_path)
            if os.path.exists(local_path) and not overwrite:
                return False, f"本地文件 {local_path} 已存在"
            # Only the remote lookups may report "remote file missing"; a bad
            # local directory must not be mistaken for that.
            try:
                file_size = self.sftp.stat(remote_path).st_size
                remote_file = self.sftp.open(remote_path, 'rb')
            except FileNotFoundError:
                return False, f"远程文件不存在: {remote_path}"
            logger.info(f"开始下载文件: {remote_path} ({file_size} 字节) -> {local_path}")
            with remote_file:
                with open(local_path, 'wb') as local_file:
                    local_created = True
                    downloaded = 0
                    last_report = time.time()
                    while True:
                        chunk = remote_file.read(self.chunk_size)
                        if not chunk:
                            break
                        local_file.write(chunk)
                        downloaded += len(chunk)
                        # Report progress at most once per second.
                        if time.time() - last_report > 1:
                            percent = (downloaded / file_size) * 100 if file_size else 100.0
                            logger.info(f"下载进度: {percent:.1f}% ({downloaded}/{file_size} 字节)")
                            last_report = time.time()
            logger.info(f"文件下载成功: {remote_path} -> {local_path}")
            return True, f"文件下载成功: {remote_path} -> {local_path}"
        except Exception as e:
            logger.error(f"下载文件失败: {str(e)}")
            if local_created:
                # Do not leave a truncated file that looks like a real result.
                try:
                    os.remove(local_path)
                except OSError as cleanup_err:
                    logger.debug(f"删除不完整的本地文件失败: {str(cleanup_err)}")
            return False, f"下载文件失败: {str(e)}"

    def read_file(self, remote_path, encoding='utf-8', retries=2):
        """Read a whole remote file as text, retrying on transient errors.

        :param remote_path: path of the file on the server
        :param encoding: encoding used to decode the file content
        :param retries: number of additional attempts after the first one
        :return: ``(success, content_or_error_message)``
        """
        if not self.ssh:
            return False, "未建立SSH连接"
        for attempt in range(retries + 1):
            try:
                if not self._connection_alive():
                    logger.warning(f"SSH连接不活跃,尝试重连 (尝试 {attempt+1}/{retries+1})")
                    reconnected, reconnect_msg = self.reconnect()
                    if not reconnected:
                        raise ConnectionError(reconnect_msg)
                self._ensure_sftp(30)
                with self.sftp.file(remote_path, 'r') as f:
                    content = f.read().decode(encoding)
                return True, content
            except FileNotFoundError:
                return False, f"远程文件不存在: {remote_path}"
            except UnicodeDecodeError as e:
                # Reading again cannot change the bytes, so do not retry.
                return False, f"读取文件失败: {str(e)}"
            except Exception as e:
                if attempt >= retries:
                    return False, f"读取文件失败: {str(e)}"
                logger.warning(f"读取文件失败,重试中 ({attempt+1}/{retries}): {str(e)}")
                time.sleep(1)
                # Start the next attempt with a fresh SFTP session.
                stale_sftp, self.sftp = self.sftp, None
                self._close_quietly(stale_sftp)
        return False, "未知错误"

    # ------------------------------------------------------------------
    # Upload
    # ------------------------------------------------------------------
    def upload_file(self, local_path, remote_path=None, overwrite=False, progress_callback=None):
        """Upload a local file in chunks, reconnecting and retrying on errors.

        :param local_path: file to upload (``~`` is expanded)
        :param remote_path: destination; ``None`` uses the local file name, a
            trailing ``/`` or an existing directory gets the file name appended
        :param overwrite: replace an existing remote file when True
        :param progress_callback: optional callable receiving a percentage
        :return: ``(success, message)``
        """
        # Reset once, before the retry loop, so a cancel requested while an
        # attempt is backing off is still honoured by the next attempt.
        self.cancel_flag = False
        if not self.ssh:
            return False, "未建立SSH连接"

        expanded_local_path = os.path.expanduser(local_path)
        if not os.path.exists(expanded_local_path):
            return False, f"本地文件不存在: {local_path} -> {expanded_local_path}"
        if os.path.isdir(expanded_local_path):
            return False, f"不支持上传目录: {local_path}"

        file_size = os.path.getsize(expanded_local_path)
        logger.info(f"开始上传文件: {expanded_local_path} ({file_size} 字节)")

        local_name = os.path.basename(local_path)
        if remote_path is None:
            remote_path = local_name
        elif remote_path.endswith('/'):
            remote_path = remote_path.rstrip('/') + '/' + local_name

        sftp_timeout = max(120, min(600, file_size // (1024 * 1024)))
        max_retries = self.UPLOAD_ATTEMPTS
        # Resolved on the first attempt that reaches the server. Later
        # attempts reuse it, so the "already exists" check never trips over
        # this upload's own partial file.
        target_path = None

        for attempt in range(max_retries):
            if self.cancel_flag:
                self._discard_remote_file(target_path)
                return False, "上传已取消"
            try:
                if not self._connection_alive():
                    reconnected, reconnect_msg = self.reconnect()
                    if not reconnected:
                        raise ConnectionError(f"重连失败: {reconnect_msg}")
                self._ensure_sftp(sftp_timeout)
                logger.info(f"设置SFTP超时为 {sftp_timeout} 秒")

                if target_path is None:
                    target_path, conflict = self._resolve_upload_target(
                        remote_path, local_name, overwrite
                    )
                    if conflict:
                        return False, conflict
                logger.info(f"准备上传文件: {expanded_local_path} -> {target_path}")

                uploaded = self._transfer_file(
                    expanded_local_path, target_path, file_size,
                    sftp_timeout, progress_callback,
                )

                if self.cancel_flag:
                    self._discard_remote_file(target_path)
                    return False, "上传已取消"

                if uploaded == file_size:
                    self._report_progress(progress_callback, 100, "最终进度回调失败")

                remote_size = self.sftp.stat(target_path).st_size
                if remote_size != file_size:
                    logger.error(f"文件大小验证失败: 本地 {file_size} 字节, 远程 {remote_size} 字节")
                    if attempt < max_retries - 1:
                        logger.info("文件大小不匹配,尝试重新上传")
                        continue
                    return False, f"文件大小验证失败: 本地 {file_size} 字节, 远程 {remote_size} 字节"

                logger.info(f"文件上传成功: {expanded_local_path} -> {target_path}")
                return True, f"文件上传成功: {expanded_local_path} -> {target_path}"
            except PermissionError as e:
                # Retrying cannot fix a permission problem.
                logger.error(f"上传文件失败 (尝试 {attempt + 1}/{max_retries}): {str(e)}", exc_info=True)
                return False, f"上传文件失败: {str(e)}"
            except Exception as e:
                logger.error(f"上传文件失败 (尝试 {attempt + 1}/{max_retries}): {str(e)}", exc_info=True)
                if attempt == max_retries - 1:
                    return False, f"上传文件失败: {str(e)}"
                reason = str(e).lower()
                if "closed" in reason or "broken" in reason or "timeout" in reason:
                    logger.info("检测到连接问题,尝试重新连接")
                    self.reconnect()
                time.sleep(2 ** attempt)  # exponential back-off
        return False, "上传失败,超过最大尝试次数"

    def _resolve_upload_target(self, remote_path, local_name, overwrite):
        """Create the parent directory and work out the final remote path.

        Returns ``(path, None)`` when the upload may proceed, or
        ``(None, message)`` when the target conflicts with existing content.
        """
        remote_dir = os.path.dirname(remote_path)
        if remote_dir:
            self._create_remote_directory(remote_dir)
        remote_path = self.sftp.normalize(remote_path)
        try:
            remote_attr = self.sftp.stat(remote_path)
            if stat.S_ISDIR(remote_attr.st_mode):
                if not remote_path.endswith('/'):
                    remote_path += '/'
                remote_path += local_name
                logger.info(f"目标为目录,自动修正路径为: {remote_path}")
                remote_attr = self.sftp.stat(remote_path)
                if stat.S_ISDIR(remote_attr.st_mode):
                    return None, f"修正后的路径仍是目录: {remote_path}"
            if stat.S_ISREG(remote_attr.st_mode) and not overwrite:
                return None, f"远程文件已存在: {remote_path}"
        except FileNotFoundError:
            pass  # nothing there yet: the normal case for a new upload
        return remote_path, None

    def _transfer_file(self, local_file_path, remote_path, file_size, sftp_timeout, progress_callback):
        """Stream the local file to ``remote_path``; return the bytes sent.

        Returns early, with fewer bytes than ``file_size``, when
        ``cancel_flag`` is set. Raises when one chunk fails WRITE_RETRIES
        times in a row or when the connection cannot be re-established.
        """
        uploaded = 0   # bytes known to be written; also the resume offset
        failures = 0   # consecutive failed writes of the current chunk
        with open(local_file_path, 'rb') as local_file:
            remote_file = self.sftp.open(remote_path, 'wb')
            try:
                remote_file.set_pipelined(True)
                while not self.cancel_flag:
                    chunk = local_file.read(self.chunk_size)
                    if not chunk:
                        break
                    try:
                        write_timeout = max(30, len(chunk) // (1024 * 10))
                        self.sftp.get_channel().settimeout(write_timeout)
                        remote_file.write(chunk)
                    except (socket.timeout, paramiko.SSHException, OSError) as e:
                        if self.cancel_flag:
                            break
                        failures += 1
                        logger.warning(f"写入操作失败 (尝试 {failures}/{self.WRITE_RETRIES}): {str(e)}")
                        if failures >= self.WRITE_RETRIES:
                            raise Exception("写入操作超时,重试3次失败") from e
                        time.sleep(2 ** (failures - 1))
                        if self._connection_lost(e):
                            logger.info("检测到连接断开,尝试重连")
                            remote_file, uploaded = self._resume_remote_file(
                                remote_file, remote_path, uploaded, sftp_timeout
                            )
                        # Part of the chunk may already have been sent, so
                        # rewind both files to the last confirmed offset and
                        # read the chunk again instead of re-sending it blindly.
                        remote_file.seek(uploaded)
                        local_file.seek(uploaded)
                        continue
                    failures = 0
                    uploaded += len(chunk)
                    self.sftp.get_channel().settimeout(sftp_timeout)
                    percent = (uploaded / file_size) * 100 if file_size else 100.0
                    self._report_progress(progress_callback, percent, "进度回调失败")
            except BaseException:
                self._close_quietly(remote_file)
                raise
            # Closing a pipelined file collects the outstanding write replies,
            # so errors raised here must reach the caller.
            remote_file.close()
        return uploaded

    def _connection_lost(self, error):
        """Guess from ``error`` and the transport state whether the link is down."""
        reason = str(error).lower()
        if "closed" in reason or "broken" in reason:
            return True
        return not self._connection_alive()

    def _resume_remote_file(self, stale_file, remote_path, uploaded, sftp_timeout):
        """Reconnect and reopen ``remote_path`` for an in-place resume.

        Returns ``(remote_file, offset)``. ``offset`` is capped at the size
        the server reports, because pipelined writes counted in ``uploaded``
        may never have been applied before the connection dropped.
        """
        reconnected, reconnect_msg = self.reconnect()
        if not reconnected:
            raise ConnectionError(f"重连失败: {reconnect_msg}")
        self._close_quietly(stale_file)  # belongs to the dead session
        self.sftp.get_channel().settimeout(sftp_timeout)
        confirmed = min(uploaded, self.sftp.stat(remote_path).st_size)
        remote_file = self.sftp.open(remote_path, 'r+b')
        remote_file.set_pipelined(True)
        logger.info(f"重连成功,从 {confirmed} 字节继续上传")
        return remote_file, confirmed

    @staticmethod
    def _report_progress(progress_callback, percent, failure_message):
        """Invoke the optional progress callback; a failing callback is logged."""
        if not progress_callback:
            return
        try:
            progress_callback(percent)
        except Exception:
            logger.error(failure_message)

    def _discard_remote_file(self, remote_path):
        """Best-effort removal of a partially uploaded remote file."""
        if not remote_path or self.sftp is None:
            return
        try:
            self.sftp.remove(remote_path)
        except Exception as e:
            logger.debug(f"删除远程文件失败: {str(e)}")

    def _create_remote_directory(self, remote_dir):
        """Create ``remote_dir`` and any missing parents on the server.

        Raises OSError when a directory cannot be created and does not exist
        afterwards either.
        """
        if not remote_dir or remote_dir == '/':
            return
        try:
            self.sftp.stat(remote_dir)
            logger.debug(f"远程目录已存在: {remote_dir}")
            return
        except FileNotFoundError:
            pass
        parent_dir = os.path.dirname(remote_dir)
        if parent_dir and parent_dir != '/' and parent_dir != remote_dir:
            self._create_remote_directory(parent_dir)
        try:
            self.sftp.mkdir(remote_dir)
            logger.info(f"创建远程目录: {remote_dir}")
        except OSError as e:
            # SFTP servers often report a generic failure without an errno, so
            # check the server state instead: if the directory exists now
            # (e.g. created concurrently), there is nothing left to do.
            if self._remote_is_dir(remote_dir):
                logger.debug(f"远程目录已存在(忽略错误): {remote_dir}")
                return
            logger.error(f"创建远程目录失败: {remote_dir}, 错误: {str(e)}")
            raise

    def _remote_is_dir(self, remote_path):
        """Return True when ``remote_path`` exists on the server as a directory."""
        try:
            return stat.S_ISDIR(self.sftp.stat(remote_path).st_mode)
        except OSError:
            return False

    # ------------------------------------------------------------------
    # Command execution
    # ------------------------------------------------------------------
    def execute_command(self, command, timeout=None, sudo=False):
        """Run a shell command on the server and collect its output.

        :param command: command line to execute
        :param timeout: channel timeout in seconds, passed to paramiko
        :param sudo: run through ``sudo`` unless the login user is root
        :return: ``(success, exit_status, stdout_text, stderr_text)``;
            ``success`` is True only for exit status 0, and a failure to run
            the command at all is reported with exit status -1
        """
        if not self.ssh:
            return False, -1, "", "未建立SSH连接"
        # Same default user as connect(), so a missing 'name' means root.
        needs_sudo = sudo and self.server_info.get('name', 'root') != 'root'
        channel = None
        try:
            if needs_sudo:
                # -S reads the password from stdin, -p '' silences the prompt.
                command = f"sudo -S -p '' {command}"
            stdin, stdout, stderr = self.ssh.exec_command(command, timeout=timeout)
            channel = stdout.channel
            password = self.server_info.get('password')
            if needs_sudo and password:
                stdin.write(password + '\n')
                stdin.flush()
            out_bytes, err_bytes = self._read_streams(stdout, stderr)
            exit_status = channel.recv_exit_status()
            # Command output is not guaranteed to be valid UTF-8; a stray byte
            # must not turn a finished command into a reported failure.
            out = out_bytes.decode('utf-8', errors='replace')
            err = err_bytes.decode('utf-8', errors='replace')
            return exit_status == 0, exit_status, out, err
        except Exception as e:
            return False, -1, "", f"执行命令失败: {str(e)}"
        finally:
            self._close_quietly(channel)

    @staticmethod
    def _read_streams(stdout, stderr):
        """Read stdout and stderr to EOF at the same time.

        Reading one stream to the end before touching the other can stall
        when the command fills the unread stream's flow-control window, so
        stderr is drained on a helper thread while stdout is read here.
        """
        stderr_result = {}

        def drain_stderr():
            try:
                stderr_result['data'] = stderr.read()
            except Exception as exc:
                stderr_result['error'] = exc

        reader = Thread(target=drain_stderr)
        reader.daemon = True
        reader.start()
        try:
            out_bytes = stdout.read()
        finally:
            reader.join()
        if 'error' in stderr_result:
            raise stderr_result['error']
        return out_bytes, stderr_result.get('data', b'')
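# Note appended to the paste: "please optimize the code above"
# (followed by a stray page label, "latest posts").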