|
|
@@ -0,0 +1,287 @@
|
|
|
+#!/usr/bin/env python3
|
|
|
+"""
|
|
|
+安全的端口转发 - 只允许特定进程访问
|
|
|
+"""
|
|
|
+import subprocess
|
|
|
+import psutil
|
|
|
+import time
|
|
|
+import os
|
|
|
+import signal
|
|
|
+import sys
|
|
|
+import socket
|
|
|
+import threading
|
|
|
+import json
|
|
|
+from pathlib import Path
|
|
|
+from http.server import HTTPServer, BaseHTTPRequestHandler
|
|
|
+import requests
|
|
|
+
|
|
|
+class SecureProxyHandler(BaseHTTPRequestHandler):
|
|
|
+ """安全代理处理器 - 只允许特定进程访问"""
|
|
|
+
|
|
|
+ def __init__(self, allowed_pids, target_host, target_port, *args, **kwargs):
|
|
|
+ self.allowed_pids = allowed_pids
|
|
|
+ self.target_host = target_host
|
|
|
+ self.target_port = target_port
|
|
|
+ super().__init__(*args, **kwargs)
|
|
|
+
|
|
|
+ def is_request_allowed(self):
|
|
|
+ """检查请求是否来自允许的进程"""
|
|
|
+ try:
|
|
|
+ # 获取客户端连接信息
|
|
|
+ client_ip = self.client_address[0]
|
|
|
+
|
|
|
+ # 如果不是本地连接,直接拒绝
|
|
|
+ if client_ip not in ['127.0.0.1', '::1']:
|
|
|
+ return False, f"Non-local connection from {client_ip}"
|
|
|
+
|
|
|
+ # 检查当前运行的进程
|
|
|
+ current_pids = set()
|
|
|
+ for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
|
|
|
+ try:
|
|
|
+ proc_info = proc.info
|
|
|
+ proc_name = proc_info['name']
|
|
|
+ cmdline = ' '.join(proc_info['cmdline'] or [])
|
|
|
+
|
|
|
+ # 检查是否是允许的进程
|
|
|
+ for allowed_process in ['Code', 'github-copilot', 'vllm', 'python']:
|
|
|
+ if (allowed_process.lower() in proc_name.lower() or
|
|
|
+ allowed_process.lower() in cmdline.lower()):
|
|
|
+ current_pids.add(proc.info['pid'])
|
|
|
+
|
|
|
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 检查是否有允许的进程在运行
|
|
|
+ if not current_pids:
|
|
|
+ return False, "No allowed processes running"
|
|
|
+
|
|
|
+ return True, "Request allowed"
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ return False, f"Error checking request: {e}"
|
|
|
+
|
|
|
+ def proxy_request(self, method='GET', data=None):
|
|
|
+ """代理请求到目标服务器"""
|
|
|
+ allowed, reason = self.is_request_allowed()
|
|
|
+
|
|
|
+ if not allowed:
|
|
|
+ self.send_error(403, f"Access denied: {reason}")
|
|
|
+ return
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 构建目标 URL
|
|
|
+ target_url = f"http://{self.target_host}:{self.target_port}{self.path}"
|
|
|
+
|
|
|
+ # 准备请求头
|
|
|
+ headers = {}
|
|
|
+ for header, value in self.headers.items():
|
|
|
+ if header.lower() not in ['host', 'connection']:
|
|
|
+ headers[header] = value
|
|
|
+
|
|
|
+ # 发送请求
|
|
|
+ if method == 'GET':
|
|
|
+ response = requests.get(target_url, headers=headers, timeout=30)
|
|
|
+ elif method == 'POST':
|
|
|
+ response = requests.post(target_url, headers=headers, data=data, timeout=30)
|
|
|
+ else:
|
|
|
+ self.send_error(405, "Method not allowed")
|
|
|
+ return
|
|
|
+
|
|
|
+ # 返回响应
|
|
|
+ self.send_response(response.status_code)
|
|
|
+ for header, value in response.headers.items():
|
|
|
+ if header.lower() not in ['connection', 'transfer-encoding']:
|
|
|
+ self.send_header(header, value)
|
|
|
+ self.end_headers()
|
|
|
+
|
|
|
+ self.wfile.write(response.content)
|
|
|
+
|
|
|
+ # 记录访问日志
|
|
|
+ self.log_message(f"Proxied {method} {self.path} -> {response.status_code}")
|
|
|
+
|
|
|
+ except requests.RequestException as e:
|
|
|
+ self.send_error(502, f"Proxy error: {e}")
|
|
|
+ except Exception as e:
|
|
|
+ self.send_error(500, f"Internal error: {e}")
|
|
|
+
|
|
|
+ def do_GET(self):
|
|
|
+ self.proxy_request('GET')
|
|
|
+
|
|
|
+ def do_POST(self):
|
|
|
+ content_length = int(self.headers.get('Content-Length', 0))
|
|
|
+ post_data = self.rfile.read(content_length)
|
|
|
+ self.proxy_request('POST', post_data)
|
|
|
+
|
|
|
+ def log_message(self, format, *args):
|
|
|
+ """自定义日志格式"""
|
|
|
+ timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
+ sys.stderr.write(f"[{timestamp}] {format % args}\n")
|
|
|
+
|
|
|
+class SecureTunnel:
|
|
|
+ def __init__(self):
|
|
|
+ self.config_file = Path.home() / '.secure_tunnel_config.json'
|
|
|
+ self.proxy_server = None
|
|
|
+ self.tunnel_process = None
|
|
|
+ self.proxy_thread = None
|
|
|
+
|
|
|
+ def load_config(self):
|
|
|
+ """加载配置"""
|
|
|
+ default_config = {
|
|
|
+ "remote_host": "10.192.72.11",
|
|
|
+ "remote_user": "ubuntu",
|
|
|
+ # "ssh_key": str(Path.home() / ".ssh/id_dotsocr_tunnel"),
|
|
|
+ "local_forward_port": 8082,
|
|
|
+ "remote_service_port": 8101,
|
|
|
+ "secure_proxy_port": 7281, # 安全代理端口
|
|
|
+ "target_proxy_host": "127.0.0.1",
|
|
|
+ "target_proxy_port": 7890,
|
|
|
+ "allowed_processes": ["Code", "github-copilot", "vllm"],
|
|
|
+ "check_interval": 30
|
|
|
+ }
|
|
|
+
|
|
|
+ if self.config_file.exists():
|
|
|
+ with open(self.config_file, 'r') as f:
|
|
|
+ config = json.load(f)
|
|
|
+ for key, value in default_config.items():
|
|
|
+ if key not in config:
|
|
|
+ config[key] = value
|
|
|
+ return config
|
|
|
+ else:
|
|
|
+ with open(self.config_file, 'w') as f:
|
|
|
+ json.dump(default_config, f, indent=2)
|
|
|
+ return default_config
|
|
|
+
|
|
|
+ def start_secure_proxy(self, config):
|
|
|
+ """启动安全代理服务器"""
|
|
|
+ def create_handler(*args, **kwargs):
|
|
|
+ return SecureProxyHandler(
|
|
|
+ config['allowed_processes'],
|
|
|
+ config['target_proxy_host'],
|
|
|
+ config['target_proxy_port'],
|
|
|
+ *args, **kwargs
|
|
|
+ )
|
|
|
+
|
|
|
+ try:
|
|
|
+ self.proxy_server = HTTPServer(('127.0.0.1', config['secure_proxy_port']), create_handler)
|
|
|
+ print(f"🔒 Secure proxy started on port {config['secure_proxy_port']}")
|
|
|
+
|
|
|
+ def run_server():
|
|
|
+ self.proxy_server.serve_forever()
|
|
|
+
|
|
|
+ self.proxy_thread = threading.Thread(target=run_server, daemon=True)
|
|
|
+ self.proxy_thread.start()
|
|
|
+ return True
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"❌ Failed to start secure proxy: {e}")
|
|
|
+ return False
|
|
|
+
|
|
|
+ def start_tunnel(self, config):
|
|
|
+ """启动 SSH 隧道(使用安全代理端口)"""
|
|
|
+ if self.tunnel_process and self.tunnel_process.poll() is None:
|
|
|
+ return True
|
|
|
+
|
|
|
+ ssh_cmd = [
|
|
|
+ 'ssh',
|
|
|
+ # '-i', config['ssh_key'],
|
|
|
+ '-o', 'ExitOnForwardFailure=yes',
|
|
|
+ '-o', 'ServerAliveInterval=30',
|
|
|
+ '-o', 'ServerAliveCountMax=3',
|
|
|
+ '-L', f"{config['local_forward_port']}:localhost:{config['remote_service_port']}",
|
|
|
+ '-R', f"7280:localhost:{config['secure_proxy_port']}", # 转发到安全代理
|
|
|
+ '-N',
|
|
|
+ f"{config['remote_user']}@{config['remote_host']}"
|
|
|
+ ]
|
|
|
+
|
|
|
+ try:
|
|
|
+ print(f"🚀 Starting secure SSH tunnel...")
|
|
|
+ self.tunnel_process = subprocess.Popen(
|
|
|
+ ssh_cmd,
|
|
|
+ stdout=subprocess.PIPE,
|
|
|
+ stderr=subprocess.PIPE,
|
|
|
+ preexec_fn=os.setsid
|
|
|
+ )
|
|
|
+
|
|
|
+ time.sleep(3)
|
|
|
+ if self.tunnel_process.poll() is None:
|
|
|
+ print("✅ Secure SSH tunnel started")
|
|
|
+ return True
|
|
|
+ else:
|
|
|
+ stdout, stderr = self.tunnel_process.communicate()
|
|
|
+ print(f"❌ SSH tunnel failed: {stderr.decode()}")
|
|
|
+ return False
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"❌ Error starting tunnel: {e}")
|
|
|
+ return False
|
|
|
+
|
|
|
+ def stop_all(self):
|
|
|
+ """停止所有服务"""
|
|
|
+ # 停止代理服务器
|
|
|
+ if self.proxy_server:
|
|
|
+ self.proxy_server.shutdown()
|
|
|
+ self.proxy_server.server_close()
|
|
|
+ print("🔒 Secure proxy stopped")
|
|
|
+
|
|
|
+ # 停止 SSH 隧道
|
|
|
+ if self.tunnel_process:
|
|
|
+ try:
|
|
|
+ os.killpg(os.getpgid(self.tunnel_process.pid), signal.SIGTERM)
|
|
|
+ self.tunnel_process.wait(timeout=5)
|
|
|
+ print("🚀 SSH tunnel stopped")
|
|
|
+ except:
|
|
|
+ try:
|
|
|
+ os.killpg(os.getpgid(self.tunnel_process.pid), signal.SIGKILL)
|
|
|
+ except:
|
|
|
+ pass
|
|
|
+
|
|
|
+ def run(self):
|
|
|
+ """运行安全隧道"""
|
|
|
+ config = self.load_config()
|
|
|
+
|
|
|
+ print("🛡️ Starting Secure DotsOCR Tunnel")
|
|
|
+ print("=" * 50)
|
|
|
+ print(f"🔒 Secure proxy port: {config['secure_proxy_port']}")
|
|
|
+ print(f"🎯 Target proxy: {config['target_proxy_host']}:{config['target_proxy_port']}")
|
|
|
+ print(f"✅ Allowed processes: {', '.join(config['allowed_processes'])}")
|
|
|
+ print("=" * 50)
|
|
|
+
|
|
|
+ # 设置信号处理
|
|
|
+ def signal_handler(signum, frame):
|
|
|
+ print("\n🛑 Shutting down secure tunnel...")
|
|
|
+ self.stop_all()
|
|
|
+ sys.exit(0)
|
|
|
+
|
|
|
+ signal.signal(signal.SIGTERM, signal_handler)
|
|
|
+ signal.signal(signal.SIGINT, signal_handler)
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 启动安全代理
|
|
|
+ if not self.start_secure_proxy(config):
|
|
|
+ return 1
|
|
|
+
|
|
|
+ # 启动 SSH 隧道
|
|
|
+ if not self.start_tunnel(config):
|
|
|
+ return 1
|
|
|
+
|
|
|
+ print("\n🎉 Secure tunnel is running!")
|
|
|
+ print("📊 Access logs will show below:")
|
|
|
+ print("-" * 50)
|
|
|
+
|
|
|
+ # 保持运行
|
|
|
+ while True:
|
|
|
+ time.sleep(1)
|
|
|
+
|
|
|
+ except KeyboardInterrupt:
|
|
|
+ print("\n🛑 Interrupted by user")
|
|
|
+ finally:
|
|
|
+ self.stop_all()
|
|
|
+
|
|
|
+ return 0
|
|
|
+
|
|
|
+def main():
|
|
|
+ tunnel = SecureTunnel()
|
|
|
+ return tunnel.run()
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ sys.exit(main())
|