secure_tunnel.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. #!/usr/bin/env python3
  2. """
  3. 安全的端口转发 - 只允许特定进程访问
  4. """
import json
import os
import signal
import socket
import socketserver
import subprocess
import sys
import tempfile
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path

import psutil
import requests
  17. class SecureProxyHandler(BaseHTTPRequestHandler):
  18. """安全代理处理器 - 只允许特定进程访问"""
  19. def __init__(self, allowed_pids, target_host, target_port, *args, **kwargs):
  20. self.allowed_pids = allowed_pids
  21. self.target_host = target_host
  22. self.target_port = target_port
  23. super().__init__(*args, **kwargs)
  24. def is_request_allowed(self):
  25. """检查请求是否来自允许的进程"""
  26. try:
  27. # 获取客户端连接信息
  28. client_ip = self.client_address[0]
  29. # 如果不是本地连接,直接拒绝
  30. if client_ip not in ['127.0.0.1', '::1']:
  31. return False, f"Non-local connection from {client_ip}"
  32. # 检查当前运行的进程
  33. current_pids = set()
  34. for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
  35. try:
  36. proc_info = proc.info
  37. proc_name = proc_info['name']
  38. cmdline = ' '.join(proc_info['cmdline'] or [])
  39. # 检查是否是允许的进程
  40. for allowed_process in ['Code', 'github-copilot', 'vllm', 'python']:
  41. if (allowed_process.lower() in proc_name.lower() or
  42. allowed_process.lower() in cmdline.lower()):
  43. current_pids.add(proc.info['pid'])
  44. except (psutil.NoSuchProcess, psutil.AccessDenied):
  45. continue
  46. # 检查是否有允许的进程在运行
  47. if not current_pids:
  48. return False, "No allowed processes running"
  49. return True, "Request allowed"
  50. except Exception as e:
  51. return False, f"Error checking request: {e}"
  52. def proxy_request(self, method='GET', data=None):
  53. """代理请求到目标服务器"""
  54. allowed, reason = self.is_request_allowed()
  55. if not allowed:
  56. self.send_error(403, f"Access denied: {reason}")
  57. return
  58. try:
  59. # 构建目标 URL
  60. target_url = f"http://{self.target_host}:{self.target_port}{self.path}"
  61. # 准备请求头
  62. headers = {}
  63. for header, value in self.headers.items():
  64. if header.lower() not in ['host', 'connection']:
  65. headers[header] = value
  66. # 发送请求
  67. if method == 'GET':
  68. response = requests.get(target_url, headers=headers, timeout=30)
  69. elif method == 'POST':
  70. response = requests.post(target_url, headers=headers, data=data, timeout=30)
  71. else:
  72. self.send_error(405, "Method not allowed")
  73. return
  74. # 返回响应
  75. self.send_response(response.status_code)
  76. for header, value in response.headers.items():
  77. if header.lower() not in ['connection', 'transfer-encoding']:
  78. self.send_header(header, value)
  79. self.end_headers()
  80. self.wfile.write(response.content)
  81. # 记录访问日志
  82. self.log_message(f"Proxied {method} {self.path} -> {response.status_code}")
  83. except requests.RequestException as e:
  84. self.send_error(502, f"Proxy error: {e}")
  85. except Exception as e:
  86. self.send_error(500, f"Internal error: {e}")
  87. def do_GET(self):
  88. self.proxy_request('GET')
  89. def do_POST(self):
  90. content_length = int(self.headers.get('Content-Length', 0))
  91. post_data = self.rfile.read(content_length)
  92. self.proxy_request('POST', post_data)
  93. def log_message(self, format, *args):
  94. """自定义日志格式"""
  95. timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
  96. sys.stderr.write(f"[{timestamp}] {format % args}\n")
  97. class SecureTunnel:
  98. def __init__(self):
  99. self.config_file = Path.home() / '.secure_tunnel_config.json'
  100. self.proxy_server = None
  101. self.tunnel_process = None
  102. self.proxy_thread = None
  103. def load_config(self):
  104. """加载配置"""
  105. default_config = {
  106. "remote_host": "10.192.72.11",
  107. "remote_user": "ubuntu",
  108. # "ssh_key": str(Path.home() / ".ssh/id_dotsocr_tunnel"),
  109. "local_forward_port": 8082,
  110. "remote_service_port": 8101,
  111. "secure_proxy_port": 7281, # 安全代理端口
  112. "target_proxy_host": "127.0.0.1",
  113. "target_proxy_port": 7890,
  114. "allowed_processes": ["Code", "github-copilot", "vllm"],
  115. "check_interval": 30
  116. }
  117. if self.config_file.exists():
  118. with open(self.config_file, 'r') as f:
  119. config = json.load(f)
  120. for key, value in default_config.items():
  121. if key not in config:
  122. config[key] = value
  123. return config
  124. else:
  125. with open(self.config_file, 'w') as f:
  126. json.dump(default_config, f, indent=2)
  127. return default_config
  128. def start_secure_proxy(self, config):
  129. """启动安全代理服务器"""
  130. def create_handler(*args, **kwargs):
  131. return SecureProxyHandler(
  132. config['allowed_processes'],
  133. config['target_proxy_host'],
  134. config['target_proxy_port'],
  135. *args, **kwargs
  136. )
  137. try:
  138. self.proxy_server = HTTPServer(('127.0.0.1', config['secure_proxy_port']), create_handler)
  139. print(f"🔒 Secure proxy started on port {config['secure_proxy_port']}")
  140. def run_server():
  141. self.proxy_server.serve_forever()
  142. self.proxy_thread = threading.Thread(target=run_server, daemon=True)
  143. self.proxy_thread.start()
  144. return True
  145. except Exception as e:
  146. print(f"❌ Failed to start secure proxy: {e}")
  147. return False
  148. def start_tunnel(self, config):
  149. """启动 SSH 隧道(使用安全代理端口)"""
  150. if self.tunnel_process and self.tunnel_process.poll() is None:
  151. return True
  152. ssh_cmd = [
  153. 'ssh',
  154. # '-i', config['ssh_key'],
  155. '-o', 'ExitOnForwardFailure=yes',
  156. '-o', 'ServerAliveInterval=30',
  157. '-o', 'ServerAliveCountMax=3',
  158. '-L', f"{config['local_forward_port']}:localhost:{config['remote_service_port']}",
  159. '-R', f"7280:localhost:{config['secure_proxy_port']}", # 转发到安全代理
  160. '-N',
  161. f"{config['remote_user']}@{config['remote_host']}"
  162. ]
  163. try:
  164. print(f"🚀 Starting secure SSH tunnel...")
  165. self.tunnel_process = subprocess.Popen(
  166. ssh_cmd,
  167. stdout=subprocess.PIPE,
  168. stderr=subprocess.PIPE,
  169. preexec_fn=os.setsid
  170. )
  171. time.sleep(3)
  172. if self.tunnel_process.poll() is None:
  173. print("✅ Secure SSH tunnel started")
  174. return True
  175. else:
  176. stdout, stderr = self.tunnel_process.communicate()
  177. print(f"❌ SSH tunnel failed: {stderr.decode()}")
  178. return False
  179. except Exception as e:
  180. print(f"❌ Error starting tunnel: {e}")
  181. return False
  182. def stop_all(self):
  183. """停止所有服务"""
  184. # 停止代理服务器
  185. if self.proxy_server:
  186. self.proxy_server.shutdown()
  187. self.proxy_server.server_close()
  188. print("🔒 Secure proxy stopped")
  189. # 停止 SSH 隧道
  190. if self.tunnel_process:
  191. try:
  192. os.killpg(os.getpgid(self.tunnel_process.pid), signal.SIGTERM)
  193. self.tunnel_process.wait(timeout=5)
  194. print("🚀 SSH tunnel stopped")
  195. except:
  196. try:
  197. os.killpg(os.getpgid(self.tunnel_process.pid), signal.SIGKILL)
  198. except:
  199. pass
  200. def run(self):
  201. """运行安全隧道"""
  202. config = self.load_config()
  203. print("🛡️ Starting Secure DotsOCR Tunnel")
  204. print("=" * 50)
  205. print(f"🔒 Secure proxy port: {config['secure_proxy_port']}")
  206. print(f"🎯 Target proxy: {config['target_proxy_host']}:{config['target_proxy_port']}")
  207. print(f"✅ Allowed processes: {', '.join(config['allowed_processes'])}")
  208. print("=" * 50)
  209. # 设置信号处理
  210. def signal_handler(signum, frame):
  211. print("\n🛑 Shutting down secure tunnel...")
  212. self.stop_all()
  213. sys.exit(0)
  214. signal.signal(signal.SIGTERM, signal_handler)
  215. signal.signal(signal.SIGINT, signal_handler)
  216. try:
  217. # 启动安全代理
  218. if not self.start_secure_proxy(config):
  219. return 1
  220. # 启动 SSH 隧道
  221. if not self.start_tunnel(config):
  222. return 1
  223. print("\n🎉 Secure tunnel is running!")
  224. print("📊 Access logs will show below:")
  225. print("-" * 50)
  226. # 保持运行
  227. while True:
  228. time.sleep(1)
  229. except KeyboardInterrupt:
  230. print("\n🛑 Interrupted by user")
  231. finally:
  232. self.stop_all()
  233. return 0
  234. def main():
  235. tunnel = SecureTunnel()
  236. return tunnel.run()
  237. if __name__ == "__main__":
  238. sys.exit(main())