first commit
This commit is contained in:
113
asr_api.py
Normal file
113
asr_api.py
Normal file
@@ -0,0 +1,113 @@
# server.py

# Third-party imports, grouped per PEP 8.
import numpy as np
import torch
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
|
||||
# 初始化FastAPI应用
|
||||
app = FastAPI(title="ASR API", description="语音识别API", version="1.0.0")
|
||||
|
||||
# 添加跨域支持
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
model_dir = "iic/SenseVoiceSmall"
|
||||
|
||||
# 检查是否有可用的CUDA设备
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
model = AutoModel(
|
||||
model=model_dir,
|
||||
trust_remote_code=True,
|
||||
vad_model="fsmn-vad",
|
||||
vad_kwargs={"max_single_segment_time": 30000},
|
||||
device=device, # 根据实际环境选择设备
|
||||
language="auto", # 自动检测语言
|
||||
use_itn=True, # 启用逆文本归一化
|
||||
)
|
||||
|
||||
|
||||
def render_template(template_name: str, context: dict = None):
    """Render a named template without depending on Jinja2.

    Only the static page ``3.html`` is supported; it is read from the
    working directory and returned verbatim.

    Args:
        template_name: Name of the template to serve (only "3.html" exists).
        context: Accepted for call-site compatibility; currently unused.

    Returns:
        HTMLResponse with the file contents, or a "Template not found"
        HTMLResponse for unknown names or a missing file.
    """
    context = context or {}  # kept for compatibility; not consumed yet
    if template_name == "3.html":
        try:
            # Serve the static HTML page directly from disk.
            with open("3.html", "r", encoding="utf-8") as f:
                content = f.read()
        except FileNotFoundError:
            # A missing file previously escaped as an unhandled 500;
            # degrade to the same fallback used for unknown templates.
            return HTMLResponse(content="Template not found")
        return HTMLResponse(content=content)
    return HTMLResponse(content="Template not found")
def bytes_to_float_array(audio_bytes: bytes) -> np.ndarray:
    """Decode little-endian 16-bit PCM bytes into float32 samples in [-1, 1].

    Args:
        audio_bytes: Raw PCM payload; length is expected to be a multiple of 2.

    Returns:
        np.ndarray of dtype float32 with one entry per 16-bit sample.
    """
    # View the buffer as signed 16-bit samples, then scale by the int16
    # magnitude (32768) to land in the [-1, 1] range the model expects.
    samples = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32)
    return samples / 32768.0
def convert_audio_data(audio_data: bytes) -> np.ndarray:
    """Interpret raw bytes as 16-bit PCM and return normalized float32 samples.

    Non-conforming payloads (empty, or an odd byte count that cannot be
    16-bit samples) and decode failures yield an empty float32 array so the
    caller can simply skip the chunk.
    """
    try:
        # A well-formed 16-bit stream has a non-zero, even byte count.
        if audio_data and len(audio_data) % 2 == 0:
            pcm = np.frombuffer(audio_data, dtype=np.int16)
            floats = pcm.astype(np.float32) / 32768.0
            # Ensure the dtype contract (float32) holds for the caller.
            return floats.astype(np.float32)
    except Exception as e:
        print(f"⚠️ 16-bit PCM处理失败: {e}")

    return np.array([], dtype=np.float32)
# chatbot = QwenChatbot()
@app.websocket("/ws/asr")
async def websocket_asr(websocket: WebSocket):
    """Streaming ASR endpoint: receive PCM audio chunks, reply with text.

    Protocol: the client sends binary frames of 16-bit PCM audio
    (presumably 16 kHz mono — TODO confirm against the web client); each
    frame is transcribed with SenseVoice and the recognized text is sent
    back as a text frame.
    """
    # Accept unconditionally (no Origin check) to avoid 403 Forbidden
    # responses from browser clients.
    await websocket.accept()

    try:
        while True:
            # 🔁 Receive one audio chunk from the front end.
            audio_data = await websocket.receive_bytes()

            # Decode raw bytes into the float32 array the model consumes.
            audio_array = convert_audio_data(audio_data)

            # Skip undecodable/empty chunks instead of handing the model
            # a zero-length array.
            if audio_array.size == 0:
                continue

            # Defensive dtype check (convert_audio_data already returns float32).
            if audio_array.dtype != np.float32:
                audio_array = audio_array.astype(np.float32)

            # ✅ Run SenseVoice inference on the chunk.
            result = model.generate(
                audio_array,
                sampling_rate=16000,
                language="auto",  # auto-detect language
                use_itn=True,  # enable inverse text normalization
                batch_size_s=60,
            )
            text = rich_transcription_postprocess(result[0]["text"])
            # print(chatbot.generate_response(text))
            print(f"🔍 Raw result type: {result}, content: {text}")
            await websocket.send_text(text)
    except WebSocketDisconnect:
        # Client closed the socket; previously this escaped receive_bytes()
        # as an unhandled exception on every disconnect. Exit quietly.
        pass
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Home page: serve the web UI via the template helper."""
    return render_template("3.html", {"request": request})
||||
if __name__ == '__main__':
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
|
||||
Reference in New Issue
Block a user