"""server.py — FastAPI speech-recognition service backed by SenseVoice (funasr)."""
# server.py
import numpy as np
import torch
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
# Initialize the FastAPI application
app = FastAPI(title="ASR API", description="语音识别API", version="1.0.0")

# Enable CORS so browser clients served from other origins can reach the API.
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True per the CORS spec — pin concrete origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

model_dir = "iic/SenseVoiceSmall"

# Prefer CUDA when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load SenseVoice with an FSMN-VAD front-end that splits long audio into
# segments of at most 30 s before recognition.
model = AutoModel(
    model=model_dir,
    trust_remote_code=True,
    vad_model="fsmn-vad",
    vad_kwargs={"max_single_segment_time": 30000},
    device=device,       # pick device based on the actual environment
    language="auto",     # auto-detect language
    use_itn=True,        # enable inverse text normalization
)
def render_template(template_name: str, context: dict = None):
    """Minimal template renderer that avoids a Jinja2 dependency.

    Args:
        template_name: Name of the template to serve; only "3.html" is known.
        context: Optional render context; currently unused, kept so callers
            using a Jinja2-style signature keep working.

    Returns:
        HTMLResponse with the file contents, or a 404 response when the
        template name is unknown.
    """
    context = context or {}
    if template_name == "3.html":
        # Serve the static HTML file directly from the working directory.
        with open("3.html", "r", encoding="utf-8") as f:
            content = f.read()
        return HTMLResponse(content=content)
    # Fix: previously returned HTTP 200 for an unknown template; a missing
    # resource should be a 404 so clients can detect the failure.
    return HTMLResponse(content="Template not found", status_code=404)
def bytes_to_float_array(audio_bytes: bytes) -> np.ndarray:
    """Decode 16-bit PCM bytes into float32 samples normalized to [-1, 1]."""
    # View the raw buffer as signed 16-bit samples, then scale by the
    # int16 magnitude (32768) to map into the [-1.0, 1.0) range.
    samples = np.frombuffer(audio_bytes, dtype=np.int16)
    return samples.astype(np.float32) / 32768.0
def convert_audio_data(audio_data: bytes) -> np.ndarray:
    """Best-effort conversion of raw audio bytes into float32 samples.

    The payload is interpreted as 16-bit PCM. Returns an empty float32
    array when the input is empty, has an odd byte length, or cannot be
    decoded — callers check ``.size`` instead of handling exceptions.
    """
    try:
        # 16-bit samples are 2 bytes each, so a valid payload has an
        # even, non-zero length.
        if len(audio_data) % 2 == 0 and len(audio_data) > 0:
            audio_array = bytes_to_float_array(audio_data)
            # bytes_to_float_array already yields float32; astype is a
            # cheap no-op safeguard kept from the original contract.
            return audio_array.astype(np.float32)
    # Fix: narrowed from a blanket `except Exception` — np.frombuffer
    # signals malformed buffers with ValueError.
    except ValueError as e:
        print(f"⚠️ 16-bit PCM处理失败: {e}")

    # Fallback: an empty array signals "no usable audio" to the caller.
    return np.array([], dtype=np.float32)
@app.websocket("/ws/asr")
async def websocket_asr(websocket: WebSocket):
    """Streaming ASR endpoint: receive PCM chunks, reply with transcripts.

    Protocol: each binary message is a chunk of 16-bit PCM audio; the
    server answers each usable chunk with one text message containing the
    recognized transcript.
    """
    # Accept unconditionally (bypasses origin checks that caused 403 Forbidden).
    await websocket.accept()

    try:
        while True:
            # Receive one audio chunk from the client.
            audio_data = await websocket.receive_bytes()

            # Convert the raw bytes into the float32 array the model expects.
            audio_array = convert_audio_data(audio_data)
            if audio_array.size == 0:
                # Fix: empty/undecodable payloads were previously fed to the
                # model, which would raise mid-connection; skip them instead.
                continue
            if audio_array.dtype != np.float32:
                audio_array = audio_array.astype(np.float32)

            # Run SenseVoice inference on the chunk.
            result = model.generate(
                audio_array,
                sampling_rate=16000,
                language="auto",  # auto-detect language
                use_itn=True,     # enable inverse text normalization
                batch_size_s=60,
            )
            text = rich_transcription_postprocess(result[0]["text"])
            print(f"🔍 Raw result type: {result}, content: {text}")
            await websocket.send_text(text)
    except WebSocketDisconnect:
        # Fix: a client disconnect raises out of receive_bytes(); end the
        # loop quietly instead of surfacing it as a server error.
        pass
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Serve the web UI (the static "3.html" template)."""
    return render_template("3.html", {"request": request})
if __name__ == '__main__':
    import uvicorn

    # Listen on every interface, port 8000, with standard info-level logs.
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")