浏览代码

feat: 添加安全的端口转发功能,限制特定进程访问

zhch158_admin 3 月之前
父节点
当前提交
09bff33352
共有 1 个文件被更改,包括 287 次插入0 次删除
  1. 287 0
      zhch/secure_tunnel.py

+ 287 - 0
zhch/secure_tunnel.py

@@ -0,0 +1,287 @@
+#!/usr/bin/env python3
+"""
+安全的端口转发 - 只允许特定进程访问
+"""
+import subprocess
+import psutil
+import time
+import os
+import signal
+import sys
+import socket
+import threading
+import json
+from pathlib import Path
+from http.server import HTTPServer, BaseHTTPRequestHandler
+import requests
+
+class SecureProxyHandler(BaseHTTPRequestHandler):
+    """安全代理处理器 - 只允许特定进程访问"""
+    
+    def __init__(self, allowed_pids, target_host, target_port, *args, **kwargs):
+        self.allowed_pids = allowed_pids
+        self.target_host = target_host
+        self.target_port = target_port
+        super().__init__(*args, **kwargs)
+    
+    def is_request_allowed(self):
+        """检查请求是否来自允许的进程"""
+        try:
+            # 获取客户端连接信息
+            client_ip = self.client_address[0]
+            
+            # 如果不是本地连接,直接拒绝
+            if client_ip not in ['127.0.0.1', '::1']:
+                return False, f"Non-local connection from {client_ip}"
+            
+            # 检查当前运行的进程
+            current_pids = set()
+            for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
+                try:
+                    proc_info = proc.info
+                    proc_name = proc_info['name']
+                    cmdline = ' '.join(proc_info['cmdline'] or [])
+                    
+                    # 检查是否是允许的进程
+                    for allowed_process in ['Code', 'github-copilot', 'vllm', 'python']:
+                        if (allowed_process.lower() in proc_name.lower() or 
+                            allowed_process.lower() in cmdline.lower()):
+                            current_pids.add(proc.info['pid'])
+                            
+                except (psutil.NoSuchProcess, psutil.AccessDenied):
+                    continue
+            
+            # 检查是否有允许的进程在运行
+            if not current_pids:
+                return False, "No allowed processes running"
+            
+            return True, "Request allowed"
+            
+        except Exception as e:
+            return False, f"Error checking request: {e}"
+    
+    def proxy_request(self, method='GET', data=None):
+        """代理请求到目标服务器"""
+        allowed, reason = self.is_request_allowed()
+        
+        if not allowed:
+            self.send_error(403, f"Access denied: {reason}")
+            return
+        
+        try:
+            # 构建目标 URL
+            target_url = f"http://{self.target_host}:{self.target_port}{self.path}"
+            
+            # 准备请求头
+            headers = {}
+            for header, value in self.headers.items():
+                if header.lower() not in ['host', 'connection']:
+                    headers[header] = value
+            
+            # 发送请求
+            if method == 'GET':
+                response = requests.get(target_url, headers=headers, timeout=30)
+            elif method == 'POST':
+                response = requests.post(target_url, headers=headers, data=data, timeout=30)
+            else:
+                self.send_error(405, "Method not allowed")
+                return
+            
+            # 返回响应
+            self.send_response(response.status_code)
+            for header, value in response.headers.items():
+                if header.lower() not in ['connection', 'transfer-encoding']:
+                    self.send_header(header, value)
+            self.end_headers()
+            
+            self.wfile.write(response.content)
+            
+            # 记录访问日志
+            self.log_message(f"Proxied {method} {self.path} -> {response.status_code}")
+            
+        except requests.RequestException as e:
+            self.send_error(502, f"Proxy error: {e}")
+        except Exception as e:
+            self.send_error(500, f"Internal error: {e}")
+    
+    def do_GET(self):
+        self.proxy_request('GET')
+    
+    def do_POST(self):
+        content_length = int(self.headers.get('Content-Length', 0))
+        post_data = self.rfile.read(content_length)
+        self.proxy_request('POST', post_data)
+    
+    def log_message(self, format, *args):
+        """自定义日志格式"""
+        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
+        sys.stderr.write(f"[{timestamp}] {format % args}\n")
+
+class SecureTunnel:
+    def __init__(self):
+        self.config_file = Path.home() / '.secure_tunnel_config.json'
+        self.proxy_server = None
+        self.tunnel_process = None
+        self.proxy_thread = None
+        
+    def load_config(self):
+        """加载配置"""
+        default_config = {
+            "remote_host": "10.192.72.11",
+            "remote_user": "ubuntu",
+            # "ssh_key": str(Path.home() / ".ssh/id_dotsocr_tunnel"),
+            "local_forward_port": 8082,
+            "remote_service_port": 8101,
+            "secure_proxy_port": 7281,  # 安全代理端口
+            "target_proxy_host": "127.0.0.1",
+            "target_proxy_port": 7890,
+            "allowed_processes": ["Code", "github-copilot", "vllm"],
+            "check_interval": 30
+        }
+        
+        if self.config_file.exists():
+            with open(self.config_file, 'r') as f:
+                config = json.load(f)
+                for key, value in default_config.items():
+                    if key not in config:
+                        config[key] = value
+                return config
+        else:
+            with open(self.config_file, 'w') as f:
+                json.dump(default_config, f, indent=2)
+            return default_config
+    
+    def start_secure_proxy(self, config):
+        """启动安全代理服务器"""
+        def create_handler(*args, **kwargs):
+            return SecureProxyHandler(
+                config['allowed_processes'],
+                config['target_proxy_host'],
+                config['target_proxy_port'],
+                *args, **kwargs
+            )
+        
+        try:
+            self.proxy_server = HTTPServer(('127.0.0.1', config['secure_proxy_port']), create_handler)
+            print(f"🔒 Secure proxy started on port {config['secure_proxy_port']}")
+            
+            def run_server():
+                self.proxy_server.serve_forever()
+            
+            self.proxy_thread = threading.Thread(target=run_server, daemon=True)
+            self.proxy_thread.start()
+            return True
+            
+        except Exception as e:
+            print(f"❌ Failed to start secure proxy: {e}")
+            return False
+    
+    def start_tunnel(self, config):
+        """启动 SSH 隧道(使用安全代理端口)"""
+        if self.tunnel_process and self.tunnel_process.poll() is None:
+            return True
+        
+        ssh_cmd = [
+            'ssh',
+            # '-i', config['ssh_key'],
+            '-o', 'ExitOnForwardFailure=yes',
+            '-o', 'ServerAliveInterval=30',
+            '-o', 'ServerAliveCountMax=3',
+            '-L', f"{config['local_forward_port']}:localhost:{config['remote_service_port']}",
+            '-R', f"7280:localhost:{config['secure_proxy_port']}",  # 转发到安全代理
+            '-N',
+            f"{config['remote_user']}@{config['remote_host']}"
+        ]
+        
+        try:
+            print(f"🚀 Starting secure SSH tunnel...")
+            self.tunnel_process = subprocess.Popen(
+                ssh_cmd,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                preexec_fn=os.setsid
+            )
+            
+            time.sleep(3)
+            if self.tunnel_process.poll() is None:
+                print("✅ Secure SSH tunnel started")
+                return True
+            else:
+                stdout, stderr = self.tunnel_process.communicate()
+                print(f"❌ SSH tunnel failed: {stderr.decode()}")
+                return False
+                
+        except Exception as e:
+            print(f"❌ Error starting tunnel: {e}")
+            return False
+    
+    def stop_all(self):
+        """停止所有服务"""
+        # 停止代理服务器
+        if self.proxy_server:
+            self.proxy_server.shutdown()
+            self.proxy_server.server_close()
+            print("🔒 Secure proxy stopped")
+        
+        # 停止 SSH 隧道
+        if self.tunnel_process:
+            try:
+                os.killpg(os.getpgid(self.tunnel_process.pid), signal.SIGTERM)
+                self.tunnel_process.wait(timeout=5)
+                print("🚀 SSH tunnel stopped")
+            except:
+                try:
+                    os.killpg(os.getpgid(self.tunnel_process.pid), signal.SIGKILL)
+                except:
+                    pass
+    
+    def run(self):
+        """运行安全隧道"""
+        config = self.load_config()
+        
+        print("🛡️ Starting Secure DotsOCR Tunnel")
+        print("=" * 50)
+        print(f"🔒 Secure proxy port: {config['secure_proxy_port']}")
+        print(f"🎯 Target proxy: {config['target_proxy_host']}:{config['target_proxy_port']}")
+        print(f"✅ Allowed processes: {', '.join(config['allowed_processes'])}")
+        print("=" * 50)
+        
+        # 设置信号处理
+        def signal_handler(signum, frame):
+            print("\n🛑 Shutting down secure tunnel...")
+            self.stop_all()
+            sys.exit(0)
+        
+        signal.signal(signal.SIGTERM, signal_handler)
+        signal.signal(signal.SIGINT, signal_handler)
+        
+        try:
+            # 启动安全代理
+            if not self.start_secure_proxy(config):
+                return 1
+            
+            # 启动 SSH 隧道
+            if not self.start_tunnel(config):
+                return 1
+            
+            print("\n🎉 Secure tunnel is running!")
+            print("📊 Access logs will show below:")
+            print("-" * 50)
+            
+            # 保持运行
+            while True:
+                time.sleep(1)
+                
+        except KeyboardInterrupt:
+            print("\n🛑 Interrupted by user")
+        finally:
+            self.stop_all()
+        
+        return 0
+
+def main():
+    tunnel = SecureTunnel()
+    return tunnel.run()
+
+if __name__ == "__main__":
+    sys.exit(main())