|
|
@@ -1,287 +0,0 @@
|
|
|
-#!/usr/bin/env python3
|
|
|
-"""
|
|
|
-安全的端口转发 - 只允许特定进程访问
|
|
|
-"""
|
|
|
-import subprocess
|
|
|
-import psutil
|
|
|
-import time
|
|
|
-import os
|
|
|
-import signal
|
|
|
-import sys
|
|
|
-import socket
|
|
|
-import threading
|
|
|
-import json
|
|
|
-from pathlib import Path
|
|
|
-from http.server import HTTPServer, BaseHTTPRequestHandler
|
|
|
-import requests
|
|
|
-
|
|
|
-class SecureProxyHandler(BaseHTTPRequestHandler):
|
|
|
- """安全代理处理器 - 只允许特定进程访问"""
|
|
|
-
|
|
|
- def __init__(self, allowed_pids, target_host, target_port, *args, **kwargs):
|
|
|
- self.allowed_pids = allowed_pids
|
|
|
- self.target_host = target_host
|
|
|
- self.target_port = target_port
|
|
|
- super().__init__(*args, **kwargs)
|
|
|
-
|
|
|
- def is_request_allowed(self):
|
|
|
- """检查请求是否来自允许的进程"""
|
|
|
- try:
|
|
|
- # 获取客户端连接信息
|
|
|
- client_ip = self.client_address[0]
|
|
|
-
|
|
|
- # 如果不是本地连接,直接拒绝
|
|
|
- if client_ip not in ['127.0.0.1', '::1']:
|
|
|
- return False, f"Non-local connection from {client_ip}"
|
|
|
-
|
|
|
- # 检查当前运行的进程
|
|
|
- current_pids = set()
|
|
|
- for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
|
|
|
- try:
|
|
|
- proc_info = proc.info
|
|
|
- proc_name = proc_info['name']
|
|
|
- cmdline = ' '.join(proc_info['cmdline'] or [])
|
|
|
-
|
|
|
- # 检查是否是允许的进程
|
|
|
- for allowed_process in ['Code', 'github-copilot', 'vllm', 'python']:
|
|
|
- if (allowed_process.lower() in proc_name.lower() or
|
|
|
- allowed_process.lower() in cmdline.lower()):
|
|
|
- current_pids.add(proc.info['pid'])
|
|
|
-
|
|
|
- except (psutil.NoSuchProcess, psutil.AccessDenied):
|
|
|
- continue
|
|
|
-
|
|
|
- # 检查是否有允许的进程在运行
|
|
|
- if not current_pids:
|
|
|
- return False, "No allowed processes running"
|
|
|
-
|
|
|
- return True, "Request allowed"
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- return False, f"Error checking request: {e}"
|
|
|
-
|
|
|
- def proxy_request(self, method='GET', data=None):
|
|
|
- """代理请求到目标服务器"""
|
|
|
- allowed, reason = self.is_request_allowed()
|
|
|
-
|
|
|
- if not allowed:
|
|
|
- self.send_error(403, f"Access denied: {reason}")
|
|
|
- return
|
|
|
-
|
|
|
- try:
|
|
|
- # 构建目标 URL
|
|
|
- target_url = f"http://{self.target_host}:{self.target_port}{self.path}"
|
|
|
-
|
|
|
- # 准备请求头
|
|
|
- headers = {}
|
|
|
- for header, value in self.headers.items():
|
|
|
- if header.lower() not in ['host', 'connection']:
|
|
|
- headers[header] = value
|
|
|
-
|
|
|
- # 发送请求
|
|
|
- if method == 'GET':
|
|
|
- response = requests.get(target_url, headers=headers, timeout=30)
|
|
|
- elif method == 'POST':
|
|
|
- response = requests.post(target_url, headers=headers, data=data, timeout=30)
|
|
|
- else:
|
|
|
- self.send_error(405, "Method not allowed")
|
|
|
- return
|
|
|
-
|
|
|
- # 返回响应
|
|
|
- self.send_response(response.status_code)
|
|
|
- for header, value in response.headers.items():
|
|
|
- if header.lower() not in ['connection', 'transfer-encoding']:
|
|
|
- self.send_header(header, value)
|
|
|
- self.end_headers()
|
|
|
-
|
|
|
- self.wfile.write(response.content)
|
|
|
-
|
|
|
- # 记录访问日志
|
|
|
- self.log_message(f"Proxied {method} {self.path} -> {response.status_code}")
|
|
|
-
|
|
|
- except requests.RequestException as e:
|
|
|
- self.send_error(502, f"Proxy error: {e}")
|
|
|
- except Exception as e:
|
|
|
- self.send_error(500, f"Internal error: {e}")
|
|
|
-
|
|
|
- def do_GET(self):
|
|
|
- self.proxy_request('GET')
|
|
|
-
|
|
|
- def do_POST(self):
|
|
|
- content_length = int(self.headers.get('Content-Length', 0))
|
|
|
- post_data = self.rfile.read(content_length)
|
|
|
- self.proxy_request('POST', post_data)
|
|
|
-
|
|
|
- def log_message(self, format, *args):
|
|
|
- """自定义日志格式"""
|
|
|
- timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
- sys.stderr.write(f"[{timestamp}] {format % args}\n")
|
|
|
-
|
|
|
-class SecureTunnel:
|
|
|
- def __init__(self):
|
|
|
- self.config_file = Path.home() / '.secure_tunnel_config.json'
|
|
|
- self.proxy_server = None
|
|
|
- self.tunnel_process = None
|
|
|
- self.proxy_thread = None
|
|
|
-
|
|
|
- def load_config(self):
|
|
|
- """加载配置"""
|
|
|
- default_config = {
|
|
|
- "remote_host": "10.192.72.11",
|
|
|
- "remote_user": "ubuntu",
|
|
|
- # "ssh_key": str(Path.home() / ".ssh/id_dotsocr_tunnel"),
|
|
|
- "local_forward_port": 8082,
|
|
|
- "remote_service_port": 8101,
|
|
|
- "secure_proxy_port": 7281, # 安全代理端口
|
|
|
- "target_proxy_host": "127.0.0.1",
|
|
|
- "target_proxy_port": 7890,
|
|
|
- "allowed_processes": ["Code", "github-copilot", "vllm"],
|
|
|
- "check_interval": 30
|
|
|
- }
|
|
|
-
|
|
|
- if self.config_file.exists():
|
|
|
- with open(self.config_file, 'r') as f:
|
|
|
- config = json.load(f)
|
|
|
- for key, value in default_config.items():
|
|
|
- if key not in config:
|
|
|
- config[key] = value
|
|
|
- return config
|
|
|
- else:
|
|
|
- with open(self.config_file, 'w') as f:
|
|
|
- json.dump(default_config, f, indent=2)
|
|
|
- return default_config
|
|
|
-
|
|
|
- def start_secure_proxy(self, config):
|
|
|
- """启动安全代理服务器"""
|
|
|
- def create_handler(*args, **kwargs):
|
|
|
- return SecureProxyHandler(
|
|
|
- config['allowed_processes'],
|
|
|
- config['target_proxy_host'],
|
|
|
- config['target_proxy_port'],
|
|
|
- *args, **kwargs
|
|
|
- )
|
|
|
-
|
|
|
- try:
|
|
|
- self.proxy_server = HTTPServer(('127.0.0.1', config['secure_proxy_port']), create_handler)
|
|
|
- print(f"🔒 Secure proxy started on port {config['secure_proxy_port']}")
|
|
|
-
|
|
|
- def run_server():
|
|
|
- self.proxy_server.serve_forever()
|
|
|
-
|
|
|
- self.proxy_thread = threading.Thread(target=run_server, daemon=True)
|
|
|
- self.proxy_thread.start()
|
|
|
- return True
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- print(f"❌ Failed to start secure proxy: {e}")
|
|
|
- return False
|
|
|
-
|
|
|
- def start_tunnel(self, config):
|
|
|
- """启动 SSH 隧道(使用安全代理端口)"""
|
|
|
- if self.tunnel_process and self.tunnel_process.poll() is None:
|
|
|
- return True
|
|
|
-
|
|
|
- ssh_cmd = [
|
|
|
- 'ssh',
|
|
|
- # '-i', config['ssh_key'],
|
|
|
- '-o', 'ExitOnForwardFailure=yes',
|
|
|
- '-o', 'ServerAliveInterval=30',
|
|
|
- '-o', 'ServerAliveCountMax=3',
|
|
|
- '-L', f"{config['local_forward_port']}:localhost:{config['remote_service_port']}",
|
|
|
- '-R', f"7281:localhost:{config['secure_proxy_port']}", # 转发到安全代理
|
|
|
- '-N',
|
|
|
- f"{config['remote_user']}@{config['remote_host']}"
|
|
|
- ]
|
|
|
-
|
|
|
- try:
|
|
|
- print(f"🚀 Starting secure SSH tunnel...")
|
|
|
- self.tunnel_process = subprocess.Popen(
|
|
|
- ssh_cmd,
|
|
|
- stdout=subprocess.PIPE,
|
|
|
- stderr=subprocess.PIPE,
|
|
|
- preexec_fn=os.setsid
|
|
|
- )
|
|
|
-
|
|
|
- time.sleep(3)
|
|
|
- if self.tunnel_process.poll() is None:
|
|
|
- print("✅ Secure SSH tunnel started")
|
|
|
- return True
|
|
|
- else:
|
|
|
- stdout, stderr = self.tunnel_process.communicate()
|
|
|
- print(f"❌ SSH tunnel failed: {stderr.decode()}")
|
|
|
- return False
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- print(f"❌ Error starting tunnel: {e}")
|
|
|
- return False
|
|
|
-
|
|
|
- def stop_all(self):
|
|
|
- """停止所有服务"""
|
|
|
- # 停止代理服务器
|
|
|
- if self.proxy_server:
|
|
|
- self.proxy_server.shutdown()
|
|
|
- self.proxy_server.server_close()
|
|
|
- print("🔒 Secure proxy stopped")
|
|
|
-
|
|
|
- # 停止 SSH 隧道
|
|
|
- if self.tunnel_process:
|
|
|
- try:
|
|
|
- os.killpg(os.getpgid(self.tunnel_process.pid), signal.SIGTERM)
|
|
|
- self.tunnel_process.wait(timeout=5)
|
|
|
- print("🚀 SSH tunnel stopped")
|
|
|
- except:
|
|
|
- try:
|
|
|
- os.killpg(os.getpgid(self.tunnel_process.pid), signal.SIGKILL)
|
|
|
- except:
|
|
|
- pass
|
|
|
-
|
|
|
- def run(self):
|
|
|
- """运行安全隧道"""
|
|
|
- config = self.load_config()
|
|
|
-
|
|
|
- print("🛡️ Starting Secure DotsOCR Tunnel")
|
|
|
- print("=" * 50)
|
|
|
- print(f"🔒 Secure proxy port: {config['secure_proxy_port']}")
|
|
|
- print(f"🎯 Target proxy: {config['target_proxy_host']}:{config['target_proxy_port']}")
|
|
|
- print(f"✅ Allowed processes: {', '.join(config['allowed_processes'])}")
|
|
|
- print("=" * 50)
|
|
|
-
|
|
|
- # 设置信号处理
|
|
|
- def signal_handler(signum, frame):
|
|
|
- print("\n🛑 Shutting down secure tunnel...")
|
|
|
- self.stop_all()
|
|
|
- sys.exit(0)
|
|
|
-
|
|
|
- signal.signal(signal.SIGTERM, signal_handler)
|
|
|
- signal.signal(signal.SIGINT, signal_handler)
|
|
|
-
|
|
|
- try:
|
|
|
- # 启动安全代理
|
|
|
- if not self.start_secure_proxy(config):
|
|
|
- return 1
|
|
|
-
|
|
|
- # 启动 SSH 隧道
|
|
|
- if not self.start_tunnel(config):
|
|
|
- return 1
|
|
|
-
|
|
|
- print("\n🎉 Secure tunnel is running!")
|
|
|
- print("📊 Access logs will show below:")
|
|
|
- print("-" * 50)
|
|
|
-
|
|
|
- # 保持运行
|
|
|
- while True:
|
|
|
- time.sleep(1)
|
|
|
-
|
|
|
- except KeyboardInterrupt:
|
|
|
- print("\n🛑 Interrupted by user")
|
|
|
- finally:
|
|
|
- self.stop_all()
|
|
|
-
|
|
|
- return 0
|
|
|
-
|
|
|
-def main():
|
|
|
- tunnel = SecureTunnel()
|
|
|
- return tunnel.run()
|
|
|
-
|
|
|
-if __name__ == "__main__":
|
|
|
- sys.exit(main())
|