本文最后更新于40 天前,其中的信息可能已经过时,如有错误请发送邮件到2371964121@qq.com
环境准备
首先准备一台显存>=12G的服务器,我这里选用的是4090 24G云服务器
接着先拉取官方代码,然后创建一个虚拟环境,再安装其对应的依赖库
git clone https://github.com/stepfun-ai/Step-Audio.git
conda create -n stepaudio python=3.10
conda activate stepaudio
cd Step-Audio
pip install -r requirements.txt
pip install fastapi
pip install loguru
pip install aiohttp
接着下载部署所需的模型。由于我们仅部署Step-Audio-TTS-3B模型,所以并不需要下载全部模型,仅需下载Step-Audio-Tokenizer和Step-Audio-TTS-3B这两个模型即可。
模型下载方面,我们可以从ModelScope上进行下载:首先需要安装其下载器,然后前往对应的模型页,复制下载命令并执行即可。
pip install modelscope
modelscope download --model stepfun-ai/Step-Audio-TTS-3B --local_dir ./model/Step-Audio-TTS-3B
modelscope download --model stepfun-ai/Step-Audio-Tokenizer --local_dir ./model/Step-Audio-Tokenizer
脚本部署
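这里已经准备好了部署脚本,仅需要修改自己的密钥(valid_api_keys)和ip、端口(app.state.base_url 以及 uvicorn.run 的 port),并确认环境变量 MODEL_PATH(默认 /data/coding/model)指向上一步下载模型所在的目录,即可直接部署,快速使用。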
import os
import uuid
import time
import asyncio
from enum import Enum
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, status, Depends, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field, validator, ValidationError, field_validator
from typing import Optional, Dict, List
from loguru import logger
import torchaudio
import aiohttp
import base64
from tts import StepAudioTTS
from tokenizer import StepAudioTokenizer
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
import os
# Make sure the directory mounted under /static exists before the app mounts it.
os.makedirs("static", exist_ok=True)
# Authentication setup.
# auto_error=False makes the dependency return None instead of raising, so
# verify_auth can answer with this service's own JSON error payload.
security = HTTPBearer(auto_error=False)
# Accepted API keys: a comma-separated list taken from the TTS_API_KEYS
# environment variable. When the variable is unset, the original placeholder
# key is used so existing deployments keep working unchanged.
valid_api_keys = {
    key.strip()
    for key in os.getenv("TTS_API_KEYS", "你的密钥").split(",")
    if key.strip()
}
# Task status enumeration
class TaskStatus(str, Enum):
    """Lifecycle states of a TTS task; each member's value is the string sent to clients."""

    # Waiting in the pending queue.
    IN_QUEUE = "InQueue"
    # Picked up by the background worker.
    IN_PROGRESS = "InProgress"
    # Finished and an audio file was produced.
    SUCCEEDED = "Succeed"
    # Finished with an error.
    FAILED = "Failed"
    # Removed from the queue at the user's request.
    CANCELLED = "Cancelled"
# Task storage record
class TaskData(BaseModel):
    """In-memory record of one submitted TTS task."""

    # Current lifecycle state; every new task starts in the queue.
    status: TaskStatus = TaskStatus.IN_QUEUE
    # Which pipeline handles the task: "common", "music" or "clone".
    audio_type: str
    # The validated request body, stored as a plain dict.
    params: dict
    # Unix timestamps in whole seconds.
    created_at: int = Field(default_factory=lambda: int(time.time()))
    started_at: Optional[int] = None
    completed_at: Optional[int] = None
    # URL of the generated audio; filled in once the task succeeds.
    download_url: Optional[str] = None
    # Failure or cancellation message, if any.
    reason: Optional[str] = None
# Application lifecycle management
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the models and start the queue worker on startup; release them on shutdown.

    Startup: builds the tokenizer and the TTS engine from the directory named
    by the MODEL_PATH environment variable (default "/data/coding/model"),
    creates the in-memory task structures on ``app.state`` and launches
    ``task_processor`` as a background task.

    Shutdown (also reached when startup fails): cancels the background task,
    drops the model references and empties the CUDA cache.
    """
    import contextlib

    # The file never imports torch at module level, so the cleanup below used
    # to raise NameError. Importing it here keeps this block self-contained.
    import torch

    processor = None
    try:
        # Initialise the models.
        model_path = os.getenv("MODEL_PATH", "/data/coding/model")
        app.state.encoder = StepAudioTokenizer(
            os.path.join(model_path, "Step-Audio-Tokenizer")
        )
        app.state.tts_engine = StepAudioTTS(
            os.path.join(model_path, "Step-Audio-TTS-3B"), app.state.encoder
        )
        # Initialise the task system.
        app.state.tasks = {}  # request_id -> TaskData
        app.state.pending_queue = []  # request ids in FIFO order
        app.state.task_lock = asyncio.Lock()
        app.state.semaphore = asyncio.Semaphore(5)  # concurrency limit (raised to 5)
        app.state.base_url = "ip+端口"
        # Start the background worker. The reference is kept so the task is
        # not garbage-collected and can be cancelled on shutdown.
        processor = asyncio.create_task(task_processor())
        print("✅ 应用初始化完成")
        yield
    finally:
        # Stop the worker first so it does not touch the engine after release.
        if processor is not None:
            processor.cancel()
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await processor
        # Release resources.
        app.state.encoder = None
        app.state.tts_engine = None
        torch.cuda.empty_cache()
# The ASGI application; model loading and cleanup are delegated to `lifespan`.
app = FastAPI(lifespan=lifespan)
# Serve files from the local "static" directory under the /static URL prefix.
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)
# Request models
class CommonTTSRequest(BaseModel):
    """Body of POST /tts/submit/common: plain TTS with optional style controls."""

    text: str = Field(..., min_length=1, description="需要合成的文本内容")
    speaker: str = Field(default="Tingting", description="仅支持 Tingting")
    emotion: Optional[str] = Field(None, description="可选值: 高兴1, 高兴2, 生气1, 生气2, 悲伤1, 撒娇1")
    language: Optional[str] = Field(None, description="可选值: 中文, 英文, 韩语, 日语, 四川话, 粤语, 广东话")
    speed: Optional[str] = Field(None, description="可选值: 慢速1, 慢速2, 快速1, 快速2")

    @field_validator("speaker")
    @classmethod
    def validate_speaker(cls, v):
        """Accept only the single supported speaker."""
        if v in ("Tingting",):
            return v
        raise ValueError(f"不支持的 speaker: {v}")

    @field_validator("emotion")
    @classmethod
    def validate_emotion(cls, v):
        """Accept an empty/missing value or one of the known emotion tags."""
        allowed = ("高兴1", "高兴2", "生气1", "生气2", "悲伤1", "撒娇1")
        if not v or v in allowed:
            return v
        raise ValueError(f"不支持的 emotion: {v}")

    @field_validator("language")
    @classmethod
    def validate_language(cls, v):
        """Accept an empty/missing value or one of the known language tags."""
        allowed = ("中文", "英文", "韩语", "日语", "四川话", "粤语", "广东话")
        if not v or v in allowed:
            return v
        raise ValueError(f"不支持的 language: {v}")

    @field_validator("speed")
    @classmethod
    def validate_speed(cls, v):
        """Accept an empty/missing value or one of the known speed tags."""
        allowed = ("慢速1", "慢速2", "快速1", "快速2")
        if not v or v in allowed:
            return v
        raise ValueError(f"不支持的 speed: {v}")

    # A friendlier error message could be added here.
    @field_validator("text")
    @classmethod
    def check_text_not_empty(cls, v):
        """Reject text that consists only of whitespace."""
        if v.strip():
            return v
        raise ValueError('文本内容不能为空')
class MusicTTSRequest(BaseModel):
    """Body of POST /tts/submit/music: singing-style synthesis (RAP or humming)."""

    text: str = Field(..., min_length=1, description="需要合成的文本内容")
    speaker: str = Field(default="Tingting", description="仅支持 Tingting")
    mode: str = Field(default="Humming (哼唱)", description="可选值: RAP, Humming (哼唱)")

    @field_validator("speaker")
    @classmethod
    def validate_speaker(cls, v):
        """Accept only the single supported speaker."""
        if v in ("Tingting",):
            return v
        raise ValueError(f"不支持的 speaker: {v}")

    @field_validator("mode")
    @classmethod
    def validate_mode(cls, v):
        """Accept only the two supported music modes."""
        allowed = ("RAP", "Humming (哼唱)")
        if v in allowed:
            return v
        raise ValueError(f"不支持的 mode: {v}")

    # A friendlier error message could be added here.
    @field_validator("text")
    @classmethod
    def check_text_not_empty(cls, v):
        """Reject text that consists only of whitespace."""
        if v.strip():
            return v
        raise ValueError('文本内容不能为空')
class CloneTTSRequest(BaseModel):
    """Body of POST /tts/submit/clone: synthesis in a voice cloned from reference audio."""

    text: str = Field(..., min_length=1, description="需要合成的文本内容")
    # Passed to the engine as the clone speaker's "prompt_text".
    speaker_prompt: str
    emotion: Optional[str] = Field(None, description="可选值: 高兴1, 高兴2, 生气1, 生气2, 悲伤1, 撒娇1")
    language: Optional[str] = Field(None, description="可选值: 中文, 英文, 韩语, 日语, 四川话, 粤语, 广东话")
    speed: Optional[str] = Field(None, description="可选值: 慢速1, 慢速2, 快速1, 快速2")
    # Reference audio: a URL or base64-encoded audio data.
    audio_source: str

    @field_validator("emotion")
    @classmethod
    def validate_emotion(cls, v):
        """Accept an empty/missing value or one of the known emotion tags."""
        allowed = ("高兴1", "高兴2", "生气1", "生气2", "悲伤1", "撒娇1")
        if not v or v in allowed:
            return v
        raise ValueError(f"不支持的 emotion: {v}")

    @field_validator("language")
    @classmethod
    def validate_language(cls, v):
        """Accept an empty/missing value or one of the known language tags."""
        allowed = ("中文", "英文", "韩语", "日语", "四川话", "粤语", "广东话")
        if not v or v in allowed:
            return v
        raise ValueError(f"不支持的 language: {v}")

    @field_validator("speed")
    @classmethod
    def validate_speed(cls, v):
        """Accept an empty/missing value or one of the known speed tags."""
        allowed = ("慢速1", "慢速2", "快速1", "快速2")
        if not v or v in allowed:
            return v
        raise ValueError(f"不支持的 speed: {v}")

    # A friendlier error message could be added here.
    @field_validator("text")
    @classmethod
    def check_text_not_empty(cls, v):
        """Reject text that consists only of whitespace."""
        if v.strip():
            return v
        raise ValueError('文本内容不能为空')
class TaskStatusRequest(BaseModel):
    """Body of POST /tts/status."""

    # Identifier returned by one of the submit endpoints.
    request_id: str
class TaskCancelRequest(BaseModel):
    """Body of POST /tts/cancel."""

    # Identifier of the task to cancel.
    request_id: str
# Response models
class TaskSubmitResponse(BaseModel):
    """Response of the submit endpoints."""

    # Identifier used later to poll or cancel the task.
    request_id: str
class TaskStatusResponse(BaseModel):
    """Response of POST /tts/status."""

    status: TaskStatus
    # Failure or cancellation message, when there is one.
    reason: Optional[str] = None
    # Audio URL list and timings; set only for succeeded tasks.
    results: Optional[dict] = None
    # 1-based position in the pending queue; null once the task left the queue.
    queue_position: Optional[int] = None
# Authentication check
async def verify_auth(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the bearer token of the current request.

    Returns:
        True when the token matches one of ``valid_api_keys``.

    Raises:
        HTTPException: 401 when the Authorization header is missing, uses a
            scheme other than Bearer, or carries an unknown key.
    """
    import hmac

    # Auth scheme names are case-insensitive, so "bearer" must be accepted too.
    if not credentials or credentials.scheme.lower() != "bearer":
        raise HTTPException(401, {"status": "Failed", "reason": "无效的认证凭证"})
    # Constant-time comparison against every key. Both sides are encoded to
    # bytes because compare_digest rejects non-ASCII str arguments.
    supplied = credentials.credentials.encode("utf-8")
    matches = [
        hmac.compare_digest(supplied, key.encode("utf-8")) for key in valid_api_keys
    ]
    if not any(matches):
        raise HTTPException(401, {"status": "Failed", "reason": "无效的API密钥"})
    return True
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc: HTTPException):
    """Return the exception's detail as the JSON body, keeping its HTTP status.

    Without an explicit ``status_code`` JSONResponse answers 200, which turned
    every 400/401/404 raised by the endpoints into an apparent success.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content=exc.detail,
        # Forward any headers attached to the exception.
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Report the first validation error in the service's {status, reason} format with HTTP 422."""
    first_error = exc.errors()[0]
    body = {"status": "Failed", "reason": f"参数校验失败: {first_error['msg']}"}
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content=body,
    )
# Task submission endpoints
@app.post("/tts/submit/common", response_model=TaskSubmitResponse)
async def submit_common(request: CommonTTSRequest, auth: bool = Depends(verify_auth)):
    """Queue a plain TTS task and return its freshly generated request id."""
    request_id = str(uuid.uuid4())
    payload = request.model_dump()
    # Register the task and enqueue it under the lock shared with the worker.
    async with app.state.task_lock:
        record = TaskData(audio_type="common", params=payload)
        app.state.tasks[request_id] = record
        app.state.pending_queue.append(request_id)
    return {"request_id": request_id}
@app.post("/tts/submit/music", response_model=TaskSubmitResponse)
async def submit_music(request: MusicTTSRequest, auth: bool = Depends(verify_auth)):
    """Queue a music-mode TTS task and return its freshly generated request id."""
    request_id = str(uuid.uuid4())
    payload = request.model_dump()
    # Register the task and enqueue it under the lock shared with the worker.
    async with app.state.task_lock:
        record = TaskData(audio_type="music", params=payload)
        app.state.tasks[request_id] = record
        app.state.pending_queue.append(request_id)
    return {"request_id": request_id}
@app.post("/tts/submit/clone", response_model=TaskSubmitResponse)
async def submit_clone(request: CloneTTSRequest, auth: bool = Depends(verify_auth)):
    """Queue a voice-clone TTS task and return its freshly generated request id."""
    request_id = str(uuid.uuid4())
    payload = request.model_dump()
    # Register the task and enqueue it under the lock shared with the worker.
    async with app.state.task_lock:
        record = TaskData(audio_type="clone", params=payload)
        app.state.tasks[request_id] = record
        app.state.pending_queue.append(request_id)
    return {"request_id": request_id}
# Status query endpoint
@app.post("/tts/status", response_model=TaskStatusResponse)
async def get_status(request: TaskStatusRequest, auth: bool = Depends(verify_auth)):
    """Report the state of a task, plus its queue position or results.

    Returns a dict with ``status``, ``reason`` and ``queue_position``;
    ``results`` (audio URL list and inference time in seconds) is added for
    succeeded tasks.

    Raises:
        HTTPException: 404 when the request id is unknown.
    """
    task = app.state.tasks.get(request.request_id)
    if not task:
        raise HTTPException(
            status_code=404,
            detail={"status": "Failed", "reason": "无效的任务ID"}
        )
    # The worker thread mutates the task, so read the status once and use that
    # snapshot for every decision below.
    current = task.status

    # Queue position is reported only while the task is queued; other states get null.
    queue_pos = None
    if current == TaskStatus.IN_QUEUE:
        try:
            queue_pos = app.state.pending_queue.index(request.request_id) + 1
        except ValueError:
            # Still marked InQueue but no longer in the pending list.
            queue_pos = 0

    response = {
        "status": current.value,  # always the plain string form
        "reason": task.reason,
        "queue_position": queue_pos,
    }
    if current == TaskStatus.SUCCEEDED:
        started, finished = task.started_at, task.completed_at
        # Either timestamp may still be unset if the worker is mid-update;
        # report null rather than failing on None arithmetic.
        inference = (
            finished - started
            if started is not None and finished is not None
            else None
        )
        response["results"] = {
            "audio": [{"url": task.download_url}],  # list form
            "timings": {
                "inference": inference
            },
        }
    elif current == TaskStatus.CANCELLED:
        response["reason"] = task.reason or "用户主动取消"  # always give a reason
    return response
# Task cancellation endpoint
@app.post("/tts/cancel", response_model=dict)
async def cancel_task(request: TaskCancelRequest, auth: bool = Depends(verify_auth)):
    """Cancel a task that is still waiting in the queue.

    Raises:
        HTTPException: 404 when the request id is unknown; 400 when the task
            is not in the InQueue state.
    """
    request_id = request.request_id
    async with app.state.task_lock:
        task = app.state.tasks.get(request_id)
        if not task:
            raise HTTPException(404, {"status": "Failed", "reason": "无效的任务ID"})
        current_status = task.status
        if current_status != TaskStatus.IN_QUEUE:
            # Use .value explicitly: formatting a (str, Enum) member directly
            # renders as "TaskStatus.X" on Python 3.11+ instead of the value.
            raise HTTPException(
                400,
                {
                    "status": "Failed",
                    "reason": f"仅允许取消排队任务,当前状态: {current_status.value}",
                },
            )
        # Remove the task from the queue.
        try:
            app.state.pending_queue.remove(request_id)
        except ValueError:
            # Deliberately ignored: the id may already have been dequeued.
            pass
        # Update the task state.
        task.status = TaskStatus.CANCELLED
        task.reason = "用户主动取消"
    return {"status": "Succeed", "reason": "取消排队任务成功"}
# Background task worker
async def task_processor():
    """Forever pull request ids off the pending queue and run them in a worker thread.

    Each task is awaited before the next one is fetched, so this loop handles
    one task at a time. When the queue is empty it polls again after 0.5 s.
    """
    pool = ThreadPoolExecutor(max_workers=5)  # thread pool keeps the event loop unblocked
    while True:
        async with app.state.semaphore:
            next_id = await get_next_task()
            if not next_id:
                await asyncio.sleep(0.5)
                continue
            await asyncio.get_running_loop().run_in_executor(
                pool, process_task_sync, next_id
            )
async def get_next_task():
    """Pop the oldest pending request id, or return None when the queue is empty.

    The task is marked InProgress while the lock is still held. cancel_task
    takes the same lock and only cancels InQueue tasks, so it can no longer
    report a successful cancellation for a task that was already dequeued and
    is about to run.
    """
    async with app.state.task_lock:
        if not app.state.pending_queue:
            return None
        request_id = app.state.pending_queue.pop(0)
        task = app.state.tasks.get(request_id)
        if task is not None:
            task.status = TaskStatus.IN_PROGRESS
        return request_id
def process_task_sync(request_id: str):
    """Run one task synchronously and record the outcome on its TaskData.

    Unknown ids and cancelled tasks are skipped. On success the audio is
    saved and ``download_url`` is set; on any error the task is marked
    Failed with the exception text as ``reason``.
    """
    task = app.state.tasks.get(request_id)
    if not task:
        return
    # Do not resurrect a task that was cancelled before this thread started.
    if task.status == TaskStatus.CANCELLED:
        return
    handlers = {
        "common": process_common_sync,
        "music": process_music_sync,
        "clone": process_clone_sync,
    }
    try:
        task.status = TaskStatus.IN_PROGRESS
        task.started_at = int(time.time())
        handler = handlers.get(task.audio_type)
        if handler is None:
            # Previously surfaced as an unbound-local error on audio_data.
            raise ValueError(f"Unsupported audio_type: {task.audio_type}")
        audio_data = handler(task.params)
        # The two elements are passed on as save_audio's audio_data and sr.
        audio_name = save_audio(task.audio_type, audio_data[0], audio_data[1])
        task.download_url = f"{app.state.base_url}/static/audio/{task.audio_type}/{audio_name}"
        # Set the timestamp before publishing the final status so a reader that
        # sees Succeed also sees completed_at.
        task.completed_at = int(time.time())
        task.status = TaskStatus.SUCCEEDED
    except Exception as e:
        # Boundary of the worker thread: record the failure instead of raising.
        task.reason = str(e)
        task.completed_at = int(time.time())
        task.status = TaskStatus.FAILED
        logger.error(f"Task {request_id} failed: {str(e)}")
# Audio processing functions (synchronous versions)
def process_common_sync(params: dict):
    """Run the TTS engine on the text, prefixed with "(tag)" markers for each set control."""
    prefix = "".join(
        f"({params[key]})"
        for key in ("emotion", "language", "speed")
        if params.get(key)
    )
    return app.state.tts_engine(prefix + params["text"], params["speaker"])
def process_music_sync(params: dict):
    """Run the TTS engine on the text, prefixed with the "(mode)" marker."""
    mode_tag = f"({params['mode']})"
    tagged_text = f"{mode_tag}{params['text']}"
    return app.state.tts_engine(tagged_text, params["speaker"])
def process_clone_sync(params: dict):
    """Synthesize ``params["text"]`` in a voice cloned from ``params["audio_source"]``.

    ``audio_source`` is an http(s) URL, a data URI ("...;base64,<data>") or
    raw base64. The audio is written to a temporary file under ``tmp/`` that
    is removed again whether or not synthesis succeeds.

    Returns:
        Whatever the TTS engine returns for the call.
    """
    # This function runs in a worker thread, so a blocking stdlib download is
    # used. aiohttp's ClientSession only works with "async with"/"await" and
    # failed on every URL input here.
    import urllib.request

    temp_audio_path = None
    try:
        # Resolve the audio input.
        source = params["audio_source"]
        if source.startswith(("http://", "https://")):
            # NOTE(security): the URL is caller-supplied and fetched from the
            # server (SSRF risk); restrict allowed hosts if that matters.
            with urllib.request.urlopen(source, timeout=30) as resp:
                audio_contents = resp.read()
        else:
            # Accept both a data URI and bare base64 with no comma.
            _, sep, payload = source.partition(",")
            audio_contents = base64.b64decode(payload if sep else source)
        # Write the temporary file.
        os.makedirs("tmp", exist_ok=True)
        temp_audio_path = f"tmp/clone_{uuid.uuid4()}.wav"
        with open(temp_audio_path, "wb") as f:
            f.write(audio_contents)
        # Build the clone parameters.
        clone_speaker = {
            "wav_path": temp_audio_path,
            "speaker": "custom_voice",
            "prompt_text": params["speaker_prompt"],
        }
        # Prefix the text with "(tag)" markers for each control that is set.
        control_tags = "".join(
            f"({params[name]})"
            for name in ("emotion", "language", "speed")
            if params.get(name)
        )
        formatted_text = control_tags + params["text"]
        # Generate the audio.
        return app.state.tts_engine(formatted_text, "", clone_speaker)
    finally:
        if temp_audio_path and os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)
# Audio saving function
def save_audio(audio_type: str, audio_data, sr: int) -> str:
    """Save the audio as a WAV file in the per-type directory and return the file name.

    The directory is ``<TMP_DIR>/<audio_type>``, where TMP_DIR defaults to
    "/data/coding/Step-Audio/static/audio". The name is a second-resolution
    timestamp plus a short random suffix, so two files saved within the same
    second no longer overwrite each other.

    Args:
        audio_type: Sub-directory name ("common", "music" or "clone").
        audio_data: Audio passed unchanged to ``torchaudio.save``.
        sr: Sample rate passed to ``torchaudio.save``.
    """
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    file_name = f"{current_time}_{uuid.uuid4().hex[:8]}.wav"
    save_dir = os.path.join(os.getenv("TMP_DIR", "/data/coding/Step-Audio/static/audio"), audio_type)
    os.makedirs(save_dir, exist_ok=True)
    torchaudio.save(os.path.join(save_dir, file_name), audio_data, sr)
    return file_name
# Automatic cleanup task
async def auto_cleanup(file_path: str, delay: int = 3600):
    """Wait ``delay`` seconds, then delete ``file_path`` if it exists; errors are logged, not raised."""
    await asyncio.sleep(delay)
    try:
        if not os.path.exists(file_path):
            return
        os.remove(file_path)
        logger.info(f"已清理文件: {file_path}")
    except Exception as e:
        logger.error(f"文件清理失败: {str(e)}")
if __name__ == "__main__":
    # Imported here so the module can be loaded without uvicorn when not run as a script.
    import uvicorn

    # Listen on all interfaces on port 10341.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=10341,
    )
脚本说明
全局配置
{
"Authorization": "Bearer 你的密钥",
"Content-Type": "application/json"
}
接口列表
1、提交普通TTS任务
- URL:
/tts/submit/common
- Method: POST
- Body:{
"text": "你好世界",
"speaker": "Tingting",
"emotion": "高兴1",
"language": "中文",
"speed": "慢速1"
}
成功则为如下响应
{
"request_id": "550e8400-e29b-41d4-a716-446655440000"
}
2、提交音乐TTS任务
- URL:
/tts/submit/music
- Method: POST
- Body:{
"text": "音乐合成测试",
"speaker": "Tingting",
"mode": "Humming (哼唱)"
}
成功则为如下响应
{
"request_id": "550e8400-e29b-41d4-a716-446655440001"
}
3、提交克隆TTS任务
- URL:
/tts/submit/clone
- Method: POST
- Body:{
"text": "语音克隆测试",
"speaker_prompt": "示例说话人",
"audio_source": "base64编码音频数据或URL",
"language": "粤语"
}
成功则为如下响应
{
"request_id": "550e8400-e29b-41d4-a716-446655440002"
}
4、查询任务状态
- URL:
/tts/status
- Method: POST
- Body:{
"request_id": "550e8400-e29b-41d4-a716-446655440000"
}
成功则为如下响应
{
"status": "Succeed",
"reason": null,
"results": {
"audio": [
{
"url": "https://585os96gouen3i3talk4090.funhpc.com:30499/static/audio/clone/2025-04-18_10-43-55.wav"
}
],
"timings": {
"inference": 7
}
},
"queue_position": null
}
5、取消任务
- URL:
/tts/cancel
- Method: POST
- Body:{
"request_id": "550e8400-e29b-41d4-a716-446655440000"
}
成功则为如下响应
{
"status": "Succeed",
"reason": "取消排队任务成功"
}
参数说明
通用参数
参数名 | 允许值 | 说明 |
---|---|---|
speaker | Tingting | 固定值 |
emotion | 高兴1/2, 生气1/2, 悲伤1, 撒娇1 | 情感参数 |
language | 中文/英文/韩语/日语/四川话/粤语/广东话 | 语言选择 |
speed | 慢速1/2, 快速1/2 | 语速控制 |
特殊参数
接口类型 | 特殊参数 | 允许值 |
---|---|---|
音乐合成 | mode | RAP, Humming (哼唱) |
语音克隆 | audio_source | 音频URL或Base64编码数据 |
错误码说明
HTTP状态码 | 原因 |
---|---|
401 | 无效的API密钥 |
404 | 无效的任务ID |
422 | 参数校验失败 |