本文将介绍如何使用 FastAPI 实现一个高效的 WebSocket 服务端,用于处理任务并推送结果,同时支持 MySQL 数据库持久化、Redis 广播与心跳检测等功能。我们将详细讲解数据库操作、WebSocket 服务端的实现、Redis 广播与订阅、以及心跳检测和断开连接后的重连机制。
1. 安装必要的包
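首先,我们需要安装以下 Python 包。注意:Redis 订阅部分使用的是 aioredis 2.x 的 API(`redis.pubsub()`、`get_message()`),因此需要安装 2.x 版本。另外,aioredis 已停止维护并入 redis-py,且 aioredis 2.0.1 在 Python 3.11 及以上版本无法导入,新项目建议改用 `redis.asyncio`(API 基本一致)。
pip install fastapi uvicorn sqlalchemy "aioredis>=2.0" mysql-connector-python pydantic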
2. 数据库操作
我们使用 SQLAlchemy 来与 MySQL 数据库进行交互。任务的执行结果将被持久化到数据库,以便后续查询与管理。
文件名:models.py
from sqlalchemy import Column, Integer, String, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
Base = declarative_base()


class TaskResult(Base):
    """ORM row recording the outcome of one task.

    String columns carry explicit lengths because the MySQL dialect refuses to
    emit ``VARCHAR`` without one when creating the table. ``result`` is
    ``Text`` so large results are not cut off by a VARCHAR limit.
    """

    __tablename__ = "task_results"

    id = Column(Integer, primary_key=True, index=True)
    # Identifier of the task this result belongs to; indexed for lookups by task.
    task_id = Column(String(255), index=True)
    result = Column(Text)
    # save_task_result writes "completed"; no other value is written in this file.
    status = Column(String(32))
def save_task_result(db: Session, task_id: str, result: str) -> TaskResult:
    """Insert a completed task result and return the persisted row.

    Args:
        db: Open SQLAlchemy session; this function commits it.
        task_id: Identifier of the finished task.
        result: Result payload to store.

    Returns:
        The new ``TaskResult`` row, refreshed after commit so database-generated
        columns such as ``id`` are populated.
    """
    row = TaskResult(status="completed", task_id=task_id, result=result)
    db.add(row)
    db.commit()
    # Reload the row so the autoincrement primary key is available to callers.
    db.refresh(row)
    return row
3. FastAPI 服务端代码
FastAPI 服务端通过 WebSocket 接收任务提交方发来的任务,并把任务分发给空闲的 worker;worker 处理完成后,通过 REST 接口 `/submit-task-result/` 提交结果。任务结果会先被持久化到 MySQL,然后发布到 Redis,再由订阅 Redis 频道的协程推送给对应的 WebSocket 客户端显示。
文件名:main.py
import json
import asyncio
import aioredis
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Depends
from sqlalchemy.orm import Session
from datetime import datetime, timedelta
from typing import Dict
from pydantic import BaseModel
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from models import TaskResult, save_task_result
# FastAPI and Redis configuration
app = FastAPI()
REDIS_HOST = "localhost"
REDIS_PORT = 6379
REDIS_TASK_RESULT_CHANNEL = "task_result_channel"  # Redis channel used to broadcast task results
HEARTBEAT_TIMEOUT = 60  # a client with no heartbeat for this many seconds is treated as disconnected


async def get_redis_connection():
    """Create the shared Redis client.

    Uses the aioredis 2.x API (``from_url``), which is the API the rest of this
    module relies on (``redis.pubsub()``, ``pubsub.get_message()``,
    ``await redis.close()``). The 1.x ``create_redis_pool`` object has no
    ``pubsub()`` method, so mixing the two APIs cannot work.
    """
    return aioredis.from_url(f"redis://{REDIS_HOST}:{REDIS_PORT}")


redis = None
# Strong references to the long-running background tasks. The event loop only
# keeps weak references to tasks, so an unreferenced task can be
# garbage-collected mid-flight; keeping them here also lets shutdown cancel them.
_background_tasks = []


@app.on_event("startup")
async def startup():
    """Connect to Redis and start the result relay and the heartbeat checker."""
    global redis
    redis = await get_redis_connection()
    _background_tasks.append(asyncio.create_task(subscribe_to_task_results()))
    _background_tasks.append(asyncio.create_task(check_heartbeats()))


@app.on_event("shutdown")
async def shutdown():
    """Cancel the background tasks, then close the Redis client."""
    for task in _background_tasks:
        task.cancel()
    # Wait until cancellation completes; CancelledError is collected, not raised.
    await asyncio.gather(*_background_tasks, return_exceptions=True)
    _background_tasks.clear()
    if redis is not None:
        await redis.close()
# MySQL configuration. DATABASE_URL (environment) overrides the default URL so
# credentials do not have to be hard-coded in source.
import os

SQLALCHEMY_DATABASE_URL = os.getenv(
    "DATABASE_URL", "mysql+mysqlconnector://root:password@localhost/task_db"
)
# pool_pre_ping tests a pooled connection before handing it out, so connections
# the MySQL server has already closed are replaced instead of failing a request.
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Registry of live WebSocket connections: connection id -> per-connection state
# dict (holds at least "ws", "last_heartbeat" and "status").
ws_connections: Dict[str, Dict] = {}
# Database session dependency
def get_db():
    """Yield a new SQLAlchemy session for one request and always close it.

    Meant to be used as ``Depends(get_db)``. The ``finally`` clause runs once
    the request is done, whether it succeeded or raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# Persist a task result, then broadcast it via Redis
async def send_task_result(task_id: str, result: str, db: Session):
    """Save a task result to MySQL, then publish it on the Redis result channel.

    Args:
        task_id: Identifier of the task the result belongs to.
        result: Result payload, stored and published unchanged.
        db: Request-scoped SQLAlchemy session (from ``get_db``).

    Raises:
        HTTPException: status 500 if the database write fails. The session is
            rolled back and nothing is published in that case.
    """
    from fastapi.concurrency import run_in_threadpool
    from sqlalchemy.exc import SQLAlchemyError

    try:
        # save_task_result does blocking DB I/O (commit + refresh). Run it in the
        # threadpool so it does not stall every WebSocket on the event loop.
        task_result = await run_in_threadpool(save_task_result, db, task_id, result)
    except SQLAlchemyError as exc:
        db.rollback()
        raise HTTPException(
            status_code=500, detail="Failed to save task result to MySQL"
        ) from exc
    print(f"Task result saved to MySQL with ID {task_result.id}")

    # Publish on Redis; subscribe_to_task_results relays it to WebSocket clients.
    message = json.dumps({"task_id": task_id, "result": result})
    await redis.publish(REDIS_TASK_RESULT_CHANNEL, message)
    print(f"Task result published to Redis: {message}")
# Redis subscriber: relay task results to the WebSocket client that submitted the task
async def subscribe_to_task_results():
    """Forward results published on ``REDIS_TASK_RESULT_CHANNEL`` to their submitters.

    Runs for the lifetime of the app. Each message is JSON with ``task_id`` and
    ``result``. The owning connection is the one whose ``"task_ids"`` set
    (filled in by ``websocket_endpoint`` when the task was submitted) contains
    the task id. A failure while handling one message is logged and the loop
    continues, so one bad payload or one closed socket cannot stop result
    delivery for every other client.
    """
    pubsub = redis.pubsub()
    await pubsub.subscribe(REDIS_TASK_RESULT_CHANNEL)
    print(f"Subscribed to Redis channel: {REDIS_TASK_RESULT_CHANNEL}")
    while True:
        message = await pubsub.get_message(ignore_subscribe_messages=True)
        if message is None:
            # Nothing pending. Back off only while idle so a burst of results
            # is not throttled to one message per sleep interval.
            await asyncio.sleep(0.1)
            continue
        try:
            await _relay_task_result(message["data"])
        except Exception as exc:  # keep the relay alive (see docstring)
            print(f"Failed to relay task result: {exc!r}")


async def _relay_task_result(raw) -> None:
    """Decode one Redis payload and push it to the connection that submitted the task."""
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8")
    task_data = json.loads(raw)
    task_id = str(task_data["task_id"])
    result = task_data["result"]
    owner = next(
        (info for info in ws_connections.values() if task_id in info.get("task_ids", ())),
        None,
    )
    if owner is None:
        print(f"No WebSocket client found for task {task_id}")
        return
    await owner["ws"].send_text(json.dumps({"task_id": task_id, "result": result}))
    # Stop tracking the task once its result is delivered so the set stays bounded.
    owner["task_ids"].discard(task_id)
    print(f"Sent result to client for task {task_id}")


# WebSocket endpoint
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one WebSocket client.

    The first text frame names the client type: ``"worker"`` or ``"task_submitter"``.

    * Workers send ``"ping"`` (answered with ``"pong"``) and ``"idle"`` / ``"busy"``
      status reports; tasks are pushed to them as text frames.
    * Submitters send ``"ping"`` heartbeats (no reply) and task payloads. When a
      payload is a JSON object with a ``task_id``, the id is recorded on this
      connection so the task's result can be routed back here.

    For any other client type the handler returns without further processing.
    """
    import uuid

    await websocket.accept()
    try:
        client_type = await websocket.receive_text()  # "worker" or "task_submitter"
    except WebSocketDisconnect:
        return
    # uuid4 instead of a timestamp: two clients connecting within the same
    # clock tick would otherwise overwrite each other's registry entry.
    connection_id = uuid.uuid4().hex
    conn_info = {
        "ws": websocket,
        "type": client_type,
        "last_heartbeat": datetime.now(),
        "status": "idle",
        "task_ids": set(),
    }
    ws_connections[connection_id] = conn_info
    try:
        if client_type == "worker":
            print("Worker connected.")
            await _serve_worker(websocket, conn_info)
        elif client_type == "task_submitter":
            print("Task submitter connected.")
            await _serve_submitter(websocket, conn_info)
    except WebSocketDisconnect:
        print("Client disconnected.")
    finally:
        # pop, not del: check_heartbeats may already have removed this entry.
        ws_connections.pop(connection_id, None)


async def _serve_worker(websocket: WebSocket, conn_info: Dict) -> None:
    """Process heartbeats and status reports from a worker until it disconnects."""
    while True:
        message = await websocket.receive_text()
        if message == "ping":
            conn_info["last_heartbeat"] = datetime.now()
            await websocket.send_text("pong")
        elif message in ("idle", "busy"):
            conn_info["status"] = message
        # Other frames are ignored: workers submit results via POST /submit-task-result/.


async def _serve_submitter(websocket: WebSocket, conn_info: Dict) -> None:
    """Accept tasks from a submitter and hand each one to an idle worker."""
    while True:
        task_data = await websocket.receive_text()
        conn_info["last_heartbeat"] = datetime.now()
        if task_data == "ping":
            # Heartbeat only; it must not be dispatched to a worker as a task.
            continue
        task_id = _extract_task_id(task_data)
        # Record ownership before dispatching so an early result can be routed.
        if task_id is not None:
            conn_info["task_ids"].add(task_id)
        if await _dispatch_to_idle_worker(task_data):
            await websocket.send_text("Task submitted successfully.")
        else:
            if task_id is not None:
                conn_info["task_ids"].discard(task_id)
            await websocket.send_text("No idle worker available; task was not submitted.")


def _extract_task_id(task_data: str):
    """Return the ``task_id`` of a JSON-object payload as a string, or None."""
    try:
        payload = json.loads(task_data)
    except ValueError:
        return None
    if isinstance(payload, dict) and payload.get("task_id") is not None:
        return str(payload["task_id"])
    return None


async def _dispatch_to_idle_worker(task_data: str) -> bool:
    """Send ``task_data`` to the first idle worker; return True if one accepted it."""
    # Iterate over a snapshot: the await below lets other coroutines add or
    # remove connections, which would break iteration over the live dict.
    for info in list(ws_connections.values()):
        if info.get("type") != "worker" or info["status"] != "idle":
            continue
        # Claim the worker before awaiting so a concurrent submitter cannot pick it too.
        info["status"] = "busy"
        try:
            await info["ws"].send_text(task_data)
        except Exception:
            # This worker's socket is unusable; leave it "busy" for heartbeat
            # cleanup and try the next idle worker.
            continue
        return True
    return False
# Heartbeat checker: close connections whose heartbeat has timed out
async def check_heartbeats():
    """Periodically close and unregister clients that stopped sending heartbeats.

    A client times out once more than ``HEARTBEAT_TIMEOUT`` seconds pass without
    its ``last_heartbeat`` being refreshed. The registry is scanned every half
    timeout, so a dead client is dropped within about 1.5x the timeout rather
    than up to 2x.
    """
    timeout = timedelta(seconds=HEARTBEAT_TIMEOUT)
    while True:
        now = datetime.now()
        # Snapshot: close() awaits, during which the registry can change.
        for connection_id, conn_info in list(ws_connections.items()):
            if now - conn_info["last_heartbeat"] <= timeout:
                continue
            # Unregister first. pop tolerates the endpoint's own cleanup having
            # already removed the entry (del would raise KeyError and kill this task).
            ws_connections.pop(connection_id, None)
            try:
                await conn_info["ws"].close()
            except Exception as exc:
                # Socket already closed or broken; it is unregistered either way.
                print(f"Error closing connection {connection_id}: {exc!r}")
            print(f"Disconnected client for connection {connection_id} due to heartbeat timeout")
        await asyncio.sleep(HEARTBEAT_TIMEOUT / 2)
# REST endpoint through which workers submit task results
class TaskResultPayload(BaseModel):
    """Request body for ``POST /submit-task-result/``.

    Deliberately not named ``TaskResult`` so that the ORM model imported from
    ``models`` at the top of this module is not shadowed.
    """

    task_id: str
    result: str


@app.post("/submit-task-result/")
async def submit_task_result(task_result: TaskResultPayload, db: Session = Depends(get_db)):
    """Persist a worker's task result and publish it for delivery to the submitter.

    The response is returned once the result is stored and published to Redis.
    Delivery over WebSocket happens separately, in the Redis subscriber.
    """
    await send_task_result(task_result.task_id, task_result.result, db)
    return {"message": "Result sent to client successfully."}
4. 油猴脚本实现的 worker 客户端
文件名:worker.user.js
// ==UserScript==
// @name WebSocket Worker
// @namespace http://tampermonkey.net/
// @version 0.2
// @description WebSocket worker client
// @author You
// @match *://*/*
// @grant none
// ==/UserScript==
(function() {
    'use strict';

    const WS_URL = "ws://localhost:8000/ws";
    const RESULT_URL = "http://localhost:8000/submit-task-result/";
    const HEARTBEAT_INTERVAL_MS = 30000; // the server-side HEARTBEAT_TIMEOUT is 60 s
    const RECONNECT_DELAY_MS = 5000;
    const TASK_DURATION_MS = 5000; // simulated processing time

    let ws;
    let status = "idle";
    let heartbeatTimer = null;

    // send() throws while the socket is CONNECTING and silently drops data once
    // it is CLOSING/CLOSED, so only send on an OPEN socket.
    function safeSend(data) {
        if (ws && ws.readyState === WebSocket.OPEN) {
            ws.send(data);
        }
    }

    // Update local status and report it right away instead of waiting for the next heartbeat.
    function setStatus(newStatus) {
        status = newStatus;
        safeSend(status);
    }

    // Simulate processing a task, then POST its result to the server.
    function handleTask(taskData) {
        setStatus("busy");
        setTimeout(() => {
            const taskResult = {
                task_id: taskData.task_id,
                result: "example_result"
            };
            // NOTE(review): this is a cross-origin request from the host page;
            // main.py configures no CORS middleware, so the JSON preflight may be
            // rejected. Enable CORS on the server or use GM_xmlhttpRequest.
            fetch(RESULT_URL, {
                method: "POST",
                headers: {
                    "Content-Type": "application/json"
                },
                body: JSON.stringify(taskResult)
            })
                .then(response => {
                    if (!response.ok) {
                        throw new Error("HTTP " + response.status);
                    }
                    return response.json();
                })
                .then(data => console.log(data))
                .catch(error => console.error("Failed to submit task result: ", error))
                // Return to idle even on failure; otherwise the worker stays busy forever.
                .finally(() => setStatus("idle"));
        }, TASK_DURATION_MS);
    }

    function connectWebSocket() {
        const socket = new WebSocket(WS_URL);
        ws = socket;

        socket.onopen = function() {
            socket.send("worker");
            socket.send(status); // report current status immediately
            // Clear any timer left from a previous connection so reconnects do
            // not stack up multiple heartbeat intervals.
            clearInterval(heartbeatTimer);
            heartbeatTimer = setInterval(() => {
                safeSend("ping");
                safeSend(status); // periodic status report
            }, HEARTBEAT_INTERVAL_MS);
        };

        socket.onmessage = function(event) {
            console.log("Message from server: ", event.data);
            // The server also sends non-JSON frames (e.g. "pong"); only JSON
            // objects that carry a task_id are tasks.
            let taskData;
            try {
                taskData = JSON.parse(event.data);
            } catch (e) {
                return;
            }
            if (taskData && taskData.task_id) {
                handleTask(taskData);
            }
        };

        socket.onclose = function() {
            clearInterval(heartbeatTimer);
            heartbeatTimer = null;
            console.log("WebSocket connection closed, reconnecting...");
            setTimeout(connectWebSocket, RECONNECT_DELAY_MS);
        };

        socket.onerror = function(error) {
            console.error("WebSocket error: ", error);
            socket.close(); // close so onclose schedules a reconnect
        };
    }

    connectWebSocket();
})();
5. 网页客户端
文件名:index.html
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>WebSocket Task Submitter</title>
</head>
<body>
    <h1>Task Submitter</h1>
    <input type="text" id="taskInput" placeholder="Enter task data">
    <button onclick="submitTask()">Submit Task</button>
    <div id="result"></div>
    <script>
        const WS_URL = "ws://localhost:8000/ws";
        const HEARTBEAT_INTERVAL_MS = 30000;
        const RECONNECT_DELAY_MS = 5000;

        let ws;
        let heartbeatTimer = null;

        function connectWebSocket() {
            const socket = new WebSocket(WS_URL);
            ws = socket;

            socket.onopen = function() {
                socket.send("task_submitter");
                // Clear any timer from a previous connection so reconnects do
                // not stack up multiple heartbeat intervals.
                clearInterval(heartbeatTimer);
                heartbeatTimer = setInterval(() => {
                    if (socket.readyState === WebSocket.OPEN) {
                        socket.send("ping");
                    }
                }, HEARTBEAT_INTERVAL_MS); // ping every 30 s
            };

            socket.onmessage = function(event) {
                console.log("Message from server: ", event.data);
                document.getElementById("result").innerText = event.data;
            };

            socket.onclose = function() {
                clearInterval(heartbeatTimer);
                heartbeatTimer = null;
                console.log("WebSocket connection closed, reconnecting...");
                setTimeout(connectWebSocket, RECONNECT_DELAY_MS);
            };

            socket.onerror = function(error) {
                console.error("WebSocket error: ", error);
                socket.close(); // close so onclose schedules a reconnect
            };
        }

        connectWebSocket();

        // Reasonably unique id for one submitted task.
        function newTaskId() {
            return Date.now().toString(36) + "-" + Math.random().toString(36).slice(2, 10);
        }

        // Send the input wrapped in a JSON envelope with a task_id: the worker
        // only treats JSON objects carrying a task_id as tasks, and results are
        // published keyed by that task_id.
        function submitTask() {
            if (!ws || ws.readyState !== WebSocket.OPEN) {
                document.getElementById("result").innerText = "Not connected to server; please retry shortly.";
                return;
            }
            const taskData = document.getElementById("taskInput").value;
            ws.send(JSON.stringify({ task_id: newTaskId(), data: taskData }));
        }
    </script>
</body>
</html>
总结
通过上述代码,我们实现了一个 WebSocket 服务端,能够分发任务并将结果推送回任务提交方,同时支持 MySQL 数据库持久化、Redis 广播与心跳检测等功能。WebSocket 客户端在断开连接后会自动重连;发生错误时也会主动关闭连接,从而触发重连。