| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287 |
- #!/usr/bin/env python3
- """
- 安全的端口转发 - 只允许特定进程访问
- """
- import subprocess
- import psutil
- import time
- import os
- import signal
- import sys
- import socket
- import threading
- import json
- from pathlib import Path
- from http.server import HTTPServer, BaseHTTPRequestHandler
- import requests
class SecureProxyHandler(BaseHTTPRequestHandler):
    """HTTP handler that forwards allowed requests to a target host/port.

    A request is forwarded only if it comes from a loopback address and at
    least one *other* local process whose name or command line contains one
    of the configured names (case-insensitive substring match) is running.

    NOTE(review): the check does not identify the process that opened the
    connection. Requests arriving through an SSH ``-R`` forward are opened
    locally by the ``ssh`` client, so this only gates on "an allowed process
    is running on this machine", not "the caller is an allowed process".
    """

    # Methods proxy_request() forwards; anything else is answered with 405.
    SUPPORTED_METHODS = ('GET', 'POST', 'PUT', 'PATCH', 'DELETE')

    # Request headers not copied upstream. Host and Content-Length are
    # recomputed by requests. Hop-by-hop headers only apply to the client
    # connection. Accept-Encoding is dropped so the upstream only sees
    # requests' own Accept-Encoding, because the body is relayed after
    # requests has decoded it (see proxy_request()).
    REQUEST_SKIP_HEADERS = frozenset({
        'host', 'connection', 'keep-alive', 'proxy-connection', 'te',
        'trailer', 'transfer-encoding', 'upgrade', 'content-length',
        'accept-encoding',
    })

    # Response headers not copied back. This covers hop-by-hop headers and
    # the encoding/length of the upstream body, which is relayed decoded.
    # Date and Server are dropped because send_response() already emits them.
    RESPONSE_SKIP_HEADERS = frozenset({
        'connection', 'keep-alive', 'proxy-connection', 'te', 'trailer',
        'transfer-encoding', 'upgrade', 'content-encoding', 'content-length',
        'date', 'server',
    })

    # Seconds to wait for the upstream. This is deliberately not named
    # ``timeout``: that attribute is StreamRequestHandler's client-socket
    # timeout.
    upstream_timeout = 30

    def __init__(self, allowed_pids, target_host, target_port, *args, **kwargs):
        """Store the proxy settings, then let the base class serve the request.

        Args:
            allowed_pids: iterable of allowed process *names* (despite the
                parameter name). This is what SecureTunnel passes from its
                ``allowed_processes`` config entry.
            target_host: host of the upstream server/proxy.
            target_port: port of the upstream server/proxy.
            *args, **kwargs: forwarded to BaseHTTPRequestHandler.

        The attributes must be set before ``super().__init__`` because the
        base constructor handles the request immediately.
        """
        self.allowed_pids = allowed_pids
        self.target_host = target_host
        self.target_port = target_port
        super().__init__(*args, **kwargs)

    @staticmethod
    def _matches_allowed(name, cmdline, allowed_names):
        """Return True if *name* or *cmdline* contains an allowed name.

        Matching is a case-insensitive substring test. ``None`` values (psutil
        reports them for fields it cannot read) count as empty. Empty allowed
        names are ignored so they cannot match every process.
        """
        name = (name or '').lower()
        cmdline = (cmdline or '').lower()
        wanted_names = [str(a).lower() for a in (allowed_names or ()) if a]
        return any(w in name or w in cmdline for w in wanted_names)

    def is_request_allowed(self):
        """Decide whether the current request may be forwarded.

        Returns:
            tuple[bool, str]: whether the request is allowed, and a
            human-readable reason.
        """
        try:
            client_ip = self.client_address[0]

            # Only loopback clients are served. This includes the local end of
            # the SSH reverse forward.
            if client_ip not in ('127.0.0.1', '::1'):
                return False, f"Non-local connection from {client_ip}"

            own_pid = os.getpid()
            for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
                try:
                    info = proc.info
                    # This script is itself running; it must not count as an
                    # allowed process.
                    if info['pid'] == own_pid:
                        continue
                    cmdline = ' '.join(info['cmdline'] or [])
                    if self._matches_allowed(info['name'], cmdline, self.allowed_pids):
                        return True, "Request allowed"
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    continue

            return False, "No allowed processes running"

        except Exception as e:
            return False, f"Error checking request: {e}"

    def _upstream_target(self):
        """Return ``(url, proxies)`` for forwarding ``self.path``.

        An origin-form path (``/v1/...``) becomes
        ``http://target_host:target_port<path>`` with no explicit proxies. An
        absolute-form URL (``http://host/...``, as sent by clients using this
        server as an HTTP proxy) is kept as-is and routed through the target,
        which is then used as a proxy.
        """
        upstream = f"http://{self.target_host}:{self.target_port}"
        if self.path.lower().startswith(('http://', 'https://')):
            return self.path, {'http': upstream, 'https': upstream}
        return upstream + self.path, None

    def proxy_request(self, method='GET', data=None):
        """Forward the current request upstream and relay the response.

        CONNECT (HTTPS tunnelling) is not supported.

        Args:
            method: HTTP method. It must be in SUPPORTED_METHODS, otherwise
                the client gets a 405.
            data: request body bytes, or None.
        """
        allowed, reason = self.is_request_allowed()
        if not allowed:
            self.log_message("Denied %s %s: %s", method, self.path, reason)
            # The reason goes into the (escaped) HTML body rather than the
            # status line, because it can contain arbitrary exception text.
            self.send_error(403, "Access denied", reason)
            return

        if method not in self.SUPPORTED_METHODS:
            self.send_error(405, "Method not allowed")
            return

        url, proxies = self._upstream_target()
        headers = {
            name: value
            for name, value in self.headers.items()
            if name.lower() not in self.REQUEST_SKIP_HEADERS
        }

        try:
            response = requests.request(
                method, url,
                headers=headers,
                data=data,
                proxies=proxies,
                timeout=self.upstream_timeout,
                # Relay 3xx responses to the client instead of following them.
                allow_redirects=False,
            )
            # .content is the body after requests decoded any Content-Encoding.
            body = response.content
        except requests.RequestException as e:
            self.send_error(502, "Proxy error", str(e))
            return
        except Exception as e:
            self.send_error(500, "Internal error", str(e))
            return

        # Nothing has been sent to the client before this point, so the error
        # responses above can never follow a partially written reply.
        try:
            self.send_response(response.status_code)
            for name, value in response.headers.items():
                if name.lower() not in self.RESPONSE_SKIP_HEADERS:
                    self.send_header(name, value)
            self.send_header('Content-Length', str(len(body)))
            self.end_headers()
            self.wfile.write(body)
        except (BrokenPipeError, ConnectionResetError) as e:
            self.log_message("Client went away during %s %s: %s", method, self.path, e)
            return

        # Values are passed as args because self.path may contain '%' (URL
        # escapes), which would break ``format % args`` in log_message.
        self.log_message("Proxied %s %s -> %d", method, self.path, response.status_code)

    def _proxy_with_body(self, method):
        """Read a Content-Length-delimited request body, then forward it.

        Chunked request bodies are not supported. Without a Content-Length
        the body is treated as empty.
        """
        try:
            length = int(self.headers.get('Content-Length') or 0)
        except ValueError:
            self.send_error(400, "Invalid Content-Length")
            return
        # A negative length would make rfile.read() block until EOF.
        body = self.rfile.read(length) if length > 0 else b''
        self.proxy_request(method, body)

    def do_GET(self):
        self.proxy_request('GET')

    def do_POST(self):
        self._proxy_with_body('POST')

    def do_PUT(self):
        self._proxy_with_body('PUT')

    def do_PATCH(self):
        self._proxy_with_body('PATCH')

    def do_DELETE(self):
        self._proxy_with_body('DELETE')

    def log_message(self, format, *args):
        """Write a timestamped ``format % args`` line to stderr."""
        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
        sys.stderr.write(f"[{timestamp}] {format % args}\n")
class SecureTunnel:
    """Run the secure proxy plus an SSH tunnel, restarting the tunnel if it dies.

    Configuration is read from ``~/.secure_tunnel_config.json``, which is
    created with defaults on first run. The SSH command does two things:

    * forwards local ``local_forward_port`` to ``localhost:remote_service_port``
      on the remote host (-L);
    * exposes the local secure proxy there as ``remote_proxy_port`` (-R).

    An optional ``ssh_key`` config entry adds ``-i <key>``.
    """

    def __init__(self):
        self.config_file = Path.home() / '.secure_tunnel_config.json'
        self.proxy_server = None
        self.tunnel_process = None
        self.proxy_thread = None
        # Temporary file that receives ssh's stderr (see start_tunnel()).
        self._tunnel_log = None

    def load_config(self):
        """Load the JSON config, filling in defaults for missing keys.

        If the file does not exist it is created with the defaults. This is
        best effort: a write failure only prints a warning.

        Returns:
            dict: the effective configuration.

        Raises:
            ValueError: the file is not valid JSON or not a JSON object.
        """
        default_config = {
            "remote_host": "10.192.72.11",
            "remote_user": "ubuntu",
            "local_forward_port": 8082,
            "remote_service_port": 8101,
            "remote_proxy_port": 7280,  # remote end of the -R forward
            "secure_proxy_port": 7281,  # local secure proxy port
            "target_proxy_host": "127.0.0.1",
            "target_proxy_port": 7890,
            "allowed_processes": ["Code", "github-copilot", "vllm"],
            "check_interval": 30,  # seconds between tunnel health checks
        }

        if not self.config_file.exists():
            try:
                with open(self.config_file, 'w', encoding='utf-8') as f:
                    json.dump(default_config, f, indent=2)
            except OSError as e:
                print(f"⚠️ Could not write default config {self.config_file}: {e}")
            return default_config

        try:
            with open(self.config_file, 'r', encoding='utf-8') as f:
                user_config = json.load(f)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in {self.config_file}: {e}") from e
        if not isinstance(user_config, dict):
            raise ValueError(f"{self.config_file} must contain a JSON object")

        # User values win; keys absent from the file fall back to the defaults.
        return {**default_config, **user_config}

    def start_secure_proxy(self, config):
        """Start the SecureProxyHandler HTTP server on a daemon thread.

        Binds to ``127.0.0.1:secure_proxy_port``.

        Returns:
            bool: True on success. On failure the error is printed and False
            is returned.
        """
        try:
            allowed = config['allowed_processes']
            target_host = config['target_proxy_host']
            target_port = config['target_proxy_port']

            def create_handler(*args, **kwargs):
                # HTTPServer builds one handler per request; prepend our settings.
                return SecureProxyHandler(allowed, target_host, target_port, *args, **kwargs)

            self.proxy_server = HTTPServer(('127.0.0.1', config['secure_proxy_port']), create_handler)
            print(f"🔒 Secure proxy started on port {config['secure_proxy_port']}")

            # Bind serve_forever now, so a later stop_all() that clears
            # self.proxy_server cannot break the thread.
            self.proxy_thread = threading.Thread(
                target=self.proxy_server.serve_forever,
                name='secure-proxy',
                daemon=True,
            )
            self.proxy_thread.start()
            return True

        except Exception as e:
            print(f"❌ Failed to start secure proxy: {e}")
            return False

    def start_tunnel(self, config):
        """Start the SSH tunnel unless it is already running.

        Returns:
            bool: True if ssh is running, either already or still alive 3
            seconds after launch. False if it exited or could not be started;
            ssh's stderr is printed in that case.
        """
        import tempfile  # only needed here

        if self.tunnel_process is not None and self.tunnel_process.poll() is None:
            return True

        try:
            ssh_cmd = ['ssh']
            if config.get('ssh_key'):
                ssh_cmd += ['-i', str(config['ssh_key'])]
            ssh_cmd += [
                '-o', 'ExitOnForwardFailure=yes',
                '-o', 'ServerAliveInterval=30',
                '-o', 'ServerAliveCountMax=3',
                '-L', f"{config['local_forward_port']}:localhost:{config['remote_service_port']}",
                # Expose the local secure proxy (not the raw target) remotely.
                '-R', f"{config.get('remote_proxy_port', 7280)}:localhost:{config['secure_proxy_port']}",
                '-N',
                f"{config['remote_user']}@{config['remote_host']}",
            ]

            print("🚀 Starting secure SSH tunnel...")
            # ssh lives as long as the tunnel. An undrained PIPE would
            # eventually fill up and block it, so stderr goes to a temp file
            # that is only read if startup fails.
            self._close_tunnel_log()
            self._tunnel_log = tempfile.TemporaryFile()
            self.tunnel_process = subprocess.Popen(
                ssh_cmd,
                stdout=subprocess.DEVNULL,
                stderr=self._tunnel_log,
                # Own session/process group, like os.setsid, but safe while
                # the proxy thread is running (unlike preexec_fn).
                start_new_session=True,
            )

            time.sleep(3)
            if self.tunnel_process.poll() is None:
                print("✅ Secure SSH tunnel started")
                return True

            self._tunnel_log.seek(0)
            error_text = self._tunnel_log.read().decode(errors='replace')
            print(f"❌ SSH tunnel failed: {error_text}")
            return False

        except Exception as e:
            print(f"❌ Error starting tunnel: {e}")
            return False

    def _close_tunnel_log(self):
        """Close and forget the ssh stderr capture file, if any."""
        log = getattr(self, '_tunnel_log', None)
        self._tunnel_log = None
        if log is not None:
            log.close()

    @staticmethod
    def _terminate_tunnel(proc):
        """Send SIGTERM to ssh's process group, escalating to SIGKILL after 5 seconds."""
        if proc.poll() is not None:
            # Already exited and reaped; its pid may since have been reused,
            # so it must not be signalled.
            return
        try:
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
            proc.wait(timeout=5)
            print("🚀 SSH tunnel stopped")
        except OSError:
            pass  # exited between poll() and killpg()
        except subprocess.TimeoutExpired:
            try:
                os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
                proc.wait(timeout=5)
            except (OSError, subprocess.TimeoutExpired):
                pass

    def stop_all(self):
        """Stop the proxy server and the SSH tunnel.

        This is safe to call more than once; run() calls it from both the
        signal handler and its ``finally`` block.
        """
        server, self.proxy_server = self.proxy_server, None
        if server is not None:
            server.shutdown()
            server.server_close()
            print("🔒 Secure proxy stopped")

        proc, self.tunnel_process = self.tunnel_process, None
        if proc is not None:
            self._terminate_tunnel(proc)
        self._close_tunnel_log()

    def run(self):
        """Run the secure tunnel until interrupted.

        Every ``check_interval`` seconds the ssh process is checked, and it is
        restarted if it has exited.

        Returns:
            int: 0 on normal shutdown, 1 if the proxy or tunnel failed to start.
        """
        config = self.load_config()
        # Validate up front, before anything is started.
        check_interval = max(1.0, float(config.get('check_interval', 30)))

        print("🛡️ Starting Secure DotsOCR Tunnel")
        print("=" * 50)
        print(f"🔒 Secure proxy port: {config['secure_proxy_port']}")
        print(f"🎯 Target proxy: {config['target_proxy_host']}:{config['target_proxy_port']}")
        print(f"✅ Allowed processes: {', '.join(config['allowed_processes'])}")
        print("=" * 50)

        def signal_handler(signum, frame):
            print("\n🛑 Shutting down secure tunnel...")
            self.stop_all()
            sys.exit(0)

        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        try:
            if not self.start_secure_proxy(config):
                return 1

            if not self.start_tunnel(config):
                return 1

            print("\n🎉 Secure tunnel is running!")
            print("📊 Access logs will show below:")
            print("-" * 50)

            next_check = time.monotonic() + check_interval
            while True:
                time.sleep(1)
                if time.monotonic() < next_check:
                    continue
                next_check = time.monotonic() + check_interval

                proc = self.tunnel_process
                if proc is not None and proc.poll() is not None:
                    print(f"⚠️ SSH tunnel exited (code {proc.returncode}), restarting...")
                    self.start_tunnel(config)

        except KeyboardInterrupt:
            print("\n🛑 Interrupted by user")
        finally:
            self.stop_all()

        return 0
def main():
    """Program entry point: run a SecureTunnel and hand back its exit status."""
    return SecureTunnel().run()
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
|