feat: optimize voice speed

This commit is contained in:
liqupan
2025-12-06 22:41:44 +08:00
parent c20aca3da0
commit c82d24ddae
20 changed files with 3942 additions and 33 deletions


@@ -0,0 +1,311 @@
# 🔍 Grok API Debugging Guide
## Data Format
### Data Format After WebClient Processing
Spring WebFlux's `WebClient.bodyToFlux(String.class)` handles the SSE stream automatically; each element is already a complete JSON string:
```json
{
  "id": "8670de6a-75b3-97e2-fa5a-c2e3d7f1d0f2",
  "object": "chat.completion.chunk",
  "created": 1764858564,
  "model": "grok-4-1-fast-non-reasoning",
  "choices": [{
    "index": 0,
    "delta": {
      "content": "你",
      "role": "assistant"
    },
    "finish_reason": null
  }],
  "system_fingerprint": "fp_174298dd8e"
}
```
### Field Reference
| Field | Description | Example |
|------|------|--------|
| `id` | Unique request ID | `"8670de6a-75b3-97e2-fa5a-c2e3d7f1d0f2"` |
| `object` | Object type | `"chat.completion.chunk"` |
| `created` | Creation timestamp | `1764858564` |
| `model` | Model used | `"grok-4-1-fast-non-reasoning"` |
| `choices[0].index` | Choice index | `0` |
| `choices[0].delta.content` | Current token content | `"你"` |
| `choices[0].delta.role` | Role (first chunk only) | `"assistant"` |
| `choices[0].finish_reason` | Finish reason | `null`, `"stop"`, `"length"`, `"content_filter"` |
| `system_fingerprint` | System fingerprint | `"fp_174298dd8e"` |
## finish_reason Values
| Value | Meaning | Suggested Handling |
|----|------|----------|
| `null` | Stream still in progress | Keep receiving tokens |
| `"stop"` | Normal completion | ✅ Reply is complete |
| `"length"` | Max token limit reached | ⚠️ Reply may be truncated; consider raising `max_tokens` |
| `"content_filter"` | Content was filtered | ⚠️ Reply contained sensitive content |
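The handling column above can be sketched as a small dispatcher. The class name, method name, and return labels below are illustrative, not part of the project's actual code:

```java
public class FinishReasonHandler {
    // Maps a finish_reason value to the action suggested in the table above.
    public static String handle(String finishReason) {
        if (finishReason == null) {
            return "continue"; // stream still in progress, keep receiving tokens
        }
        switch (finishReason) {
            case "stop":           return "complete";  // reply is complete
            case "length":         return "truncated"; // consider raising max_tokens
            case "content_filter": return "filtered";  // reply contained sensitive content
            default:               return "unknown";
        }
    }
}
```

In practice a service would branch on these labels to decide whether to retry with a larger `max_tokens` or surface a content warning.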
## Parsing Flow
```
Receive data → skip empty lines → check for [DONE] → strip "data: " prefix (if present)
Parse JSON → validate choices field → take choices[0]
Check delta field → extract content → invoke onToken(content)
Check finish_reason → log it → done
```
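The pre-JSON steps of this flow (skip empty lines, strip the optional `data: ` prefix, detect `[DONE]`) can be sketched as a small helper. The class name is hypothetical; the real parsing lives in `GrokStreamService`:

```java
public class SsePreprocessor {
    /** Returns the JSON payload of an SSE line, or null if the line should be skipped. */
    public static String extractJson(String line) {
        if (line == null || line.isBlank()) {
            return null; // skip empty keep-alive lines
        }
        String data = line.trim();
        if (data.startsWith("data: ")) {
            data = data.substring(6); // strip the SSE prefix if present
        }
        if (data.equals("[DONE]")) {
            return null; // end-of-stream marker, nothing left to parse
        }
        return data; // hand the raw JSON string to the parser
    }
}
```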
## Log Level Configuration
Configure in `application.yml`:
```yaml
logging:
  level:
    com.xiaozhi.dialogue.llm.GrokStreamService: DEBUG  # or TRACE
```
### Output at Each Level
#### TRACE (most verbose)
```
TRACE - Raw SSE data received: {"id":"8670de6a...
TRACE - Extracted token: 你
TRACE - Extracted token: 好
```
#### DEBUG
```
DEBUG - Request body: {"model":"grok-4-1-fast-non-reasoning",...}
DEBUG - Role info received: assistant
DEBUG - JSON missing choices field (logged when something is wrong)
```
#### INFO
```
INFO - Starting Grok API call - Model: grok-4-1-fast-non-reasoning, UserMessage: 你好...
INFO - Stream finish reason: length
INFO - Grok API streaming call completed
```
#### ERROR
```
ERROR - JSON parse failure - raw data: xxx
ERROR - LLM call failed: Connection refused
```
## Troubleshooting Common Problems
### Problem 1: No tokens received
**Possible causes**
- Wrong API key
- Network connectivity issues
- Misconfigured API URL
**Steps to check**
1. Look for `INFO - Starting Grok API call` in the logs
2. Check for error log entries
3. Verify the configuration in `application.yml`:
```yaml
xiaozhi:
  voice-stream:
    grok:
      api-key: ${GROK_API_KEY}
      api-url: https://api.x.ai/v1/chat/completions
```
### Problem 2: Garbled or malformed tokens
**Check**
1. Enable TRACE logging and inspect the raw data
2. Confirm the JSON is well-formed
3. Look for `ERROR - JSON parse failure` entries
### Problem 3: Replies get truncated
**Log signature**
```
INFO - Stream finish reason: length
```
**Fix**
Increase the `DEFAULT_MAX_TOKENS` constant:
```java
private static final int DEFAULT_MAX_TOKENS = 4000; // or larger
```
### Problem 4: WebClient connection timeouts
**Error log**
```
ERROR - LLM call failed: Connection timeout
```
**Fix**
Add timeout settings in `initWebClient()`, for example via a Reactor Netty connector:
```java
HttpClient httpClient = HttpClient.create()
        .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000) // connect timeout
        .responseTimeout(Duration.ofSeconds(30));           // response timeout
this.webClient = WebClient.builder()
        .baseUrl(apiUrl)
        .clientConnector(new ReactorClientHttpConnector(httpClient))
        .codecs(configurer -> configurer.defaultCodecs().maxInMemorySize(10 * 1024 * 1024))
        .defaultHeader(HttpHeaders.ACCEPT, "text/event-stream")
        .build();
```
## Testing Tips
### 1. Test the API with curl
```bash
curl -X POST https://api.x.ai/v1/chat/completions \
-H "Authorization: Bearer YOUR_API_KEY" \
-H "Content-Type: application/json" \
-d '{
"model": "grok-4-1-fast-non-reasoning",
"messages": [{"role": "user", "content": "你好"}],
"stream": true
}'
```
Expected output:
```
data: {"id":"xxx","object":"chat.completion.chunk",...}
data: {"id":"xxx","object":"chat.completion.chunk",...}
data: [DONE]
```
### 2. Verify the configuration
```java
// Add a test method to GrokStreamService
@PostConstruct
public void testConfig() {
    logger.info("Grok configuration:");
    logger.info("  API URL: {}", voiceStreamConfig.getGrok().getApiUrl());
    logger.info("  Model: {}", voiceStreamConfig.getGrok().getModel());
    logger.info("  API key: {}", voiceStreamConfig.getGrok().getApiKey() != null ? "configured" : "missing");
}
```
### 3. Unit test
```java
@Test
public void testParseSSE() {
    String testJson = """
            {
              "id": "test",
              "choices": [{
                "delta": {
                  "content": "测试"
                }
              }]
            }
            """;
    // exercise the parsing logic
}
```
## Performance Monitoring
### Key Metrics
1. **Time to first token**: from sending the request to receiving the first token
2. **Token throughput**: tokens received per second
3. **Total response time**: from start until [DONE] is received
### Adding Monitoring Logs
```java
private long startTime;
private int tokenCount = 0;

public void streamChat(...) {
    startTime = System.currentTimeMillis();
    // ...
    .doOnNext(token -> {
        tokenCount++;
        if (tokenCount == 1) {
            long firstTokenLatency = System.currentTimeMillis() - startTime;
            logger.info("First-token latency: {}ms", firstTokenLatency);
        }
    })
    .doOnComplete(() -> {
        long totalTime = System.currentTimeMillis() - startTime;
        double tps = tokenCount / (totalTime / 1000.0);
        // SLF4J placeholders do not support format specifiers like {:.2f},
        // so format the number explicitly:
        logger.info("Total: {} tokens, elapsed: {}ms, throughput: {} tokens/s",
                tokenCount, totalTime, String.format("%.2f", tps));
    })
}
```
## Debugging Techniques
### 1. Log the raw response
```java
// inside parseSSE
logger.trace("Raw response: {}", line);
```
### 2. Use breakpoints
Set breakpoints at:
- the entry of `parseSSE()`
- just before `sink.next(content)`
- the `callback.onToken()` call site
### 3. Simulated test
Create a test class:
```java
@Test
public void testGrokStreamService() throws InterruptedException {
    GrokStreamService service = new GrokStreamService();
    service.streamChat("你好", null, new TokenCallback() {
        @Override
        public void onToken(String token) {
            System.out.println("Token: " + token);
        }
        @Override
        public void onComplete() {
            System.out.println("Done");
        }
        @Override
        public void onError(String error) {
            System.err.println("Error: " + error);
        }
    });
    // wait for completion
    Thread.sleep(10000);
}
```
## Summary
Normal workflow:
1. ✅ `INFO - Starting Grok API call` appears
2. ✅ Multiple `TRACE - Extracted token: xxx` entries appear
3. ✅ `INFO - Stream finish reason: stop` (or `length`) appears
4. ✅ `INFO - Grok API streaming call completed` appears
If something goes wrong:
1. Check that the log level is verbose enough
2. Search for ERROR-level entries
3. Verify the configuration
4. Test the API directly with curl
5. Check network connectivity
---
**Need more help?** Provide the full error logs and configuration.


@@ -0,0 +1,251 @@
# TTS Connection Warm-up Optimization Test Guide
## Test Goals
Verify TTS WebSocket connection warm-up and reuse, ensuring that:
1. A TTS connection is warmed up successfully when a user connects
2. Multiple sentences reuse the same TTS connection
3. The connection reconnects automatically after a drop
4. The connection is cleaned up correctly when the user disconnects
5. The performance gain matches expectations
## Test Environment
1. Start the backend service
2. Make sure the MiniMax TTS API is configured correctly
3. Use the WeChat mini-program or a WebSocket client
## Test Cases
### Test 1: Connection Warm-up
**Goal**: verify that a TTS connection is warmed up when the user connects.
**Steps**
1. Connect the front end to the voice-stream WebSocket: `ws://your-server:8091/ws/voice-stream`
2. Watch the backend logs
**Expected logs**
```
[VoiceStreamHandler] Voice-stream WebSocket connected - SessionId: xxx, UserId: null
[VoiceStreamService] Warming up TTS connection - SessionId: xxx
[MinimaxTtsStreamService] Starting TTS connection warm-up - SessionId: xxx
[MinimaxTtsStreamService] MiniMax TTS connection established - SessionId: xxx
[MinimaxTtsStreamService] Received connected_success, sending task_start - SessionId: xxx
[MinimaxTtsStreamService] Received task_started, connection ready - SessionId: xxx
[MinimaxTtsStreamService] TTS connection warm-up succeeded - SessionId: xxx, elapsed: XXXms
[VoiceStreamHandler] TTS connection warm-up succeeded - SessionId: xxx
```
**Pass criteria**
- ✅ Warm-up starts immediately after the connection is established
- ✅ Warm-up completes within 1-2 seconds
- ✅ The connection state becomes IDLE
---
### Test 2: Connection Reuse
**Goal**: verify that multiple sentences reuse the same TTS connection.
**Steps**
1. Make sure a connection is established and warmed up
2. Send one audio clip (user says: "你好,今天天气怎么样?")
3. Wait for the AI to reply with several sentences (e.g. "你好!今天天气很好。阳光明媚。")
4. Watch the backend logs
**Expected logs**
```
[VoiceStreamService] STT result - SessionId: xxx, Text: 你好,今天天气怎么样?
[VoiceStreamService] Complete sentence detected - SessionId: xxx, Sentence: 你好!
[MinimaxTtsStreamService] Reusing existing TTS connection - SessionId: xxx, Text: 你好!
[MinimaxTtsStreamService] Buffering audio chunk - SessionId: xxx: 2048 bytes, total: 2048 bytes
...
[MinimaxTtsStreamService] TTS done - SessionId: xxx, Text: 你好!, total audio size: 24576 bytes
[MinimaxTtsStreamService] TTS done - SessionId: xxx, elapsed: 300ms (connection reuse saved about 1s)
[VoiceStreamService] Complete sentence detected - SessionId: xxx, Sentence: 今天天气很好。
[MinimaxTtsStreamService] Reusing existing TTS connection - SessionId: xxx, Text: 今天天气很好。
[MinimaxTtsStreamService] TTS done - SessionId: xxx, elapsed: 250ms (connection reuse saved about 1s)
```
**Pass criteria**
- ✅ Logs show "Reusing existing TTS connection" rather than "Creating MiniMax TTS connection"
- ✅ Each sentence's TTS completes in 200-500ms instead of 1-1.5s
- ✅ Logs show "connection reuse saved about 1s"
- ✅ No duplicate connection-establishment logs
---
### Test 3: Automatic Reconnection
**Goal**: verify automatic reconnection after the connection drops.
**Steps**
1. Establish a connection normally
2. Manually close the MiniMax TTS WebSocket connection (simulating a network drop)
3. Send an audio clip to trigger TTS
4. Watch the backend logs
**Expected logs**
```
[MinimaxTtsStreamService] MiniMax TTS connection closed - SessionId: xxx, Code: 1006, Reason: ..., Remote: true
[MinimaxTtsStreamService] Attempting TTS auto-reconnect - SessionId: xxx, attempt 1
[MinimaxTtsStreamService] Starting TTS connection warm-up - SessionId: xxx
[MinimaxTtsStreamService] TTS connection warm-up succeeded - SessionId: xxx, elapsed: XXXms
[MinimaxTtsStreamService] Reusing existing TTS connection - SessionId: xxx, Text: ...
```
**Pass criteria**
- ✅ The dropped connection is detected
- ✅ Reconnection is triggered automatically (at most 3 attempts)
- ✅ The connection works normally after reconnecting
- ✅ After 3 failed attempts, no further retries
---
### Test 4: Connection Cleanup
**Goal**: verify the TTS connection is cleaned up when the user disconnects.
**Steps**
1. Establish and warm up a connection
2. Have the front end actively close the WebSocket connection
3. Watch the backend logs
**Expected logs**
```
[VoiceStreamHandler] Voice-stream WebSocket closed - SessionId: xxx, Status: CloseStatus[code=1000, reason=null]
[VoiceStreamService] Closing TTS connection - SessionId: xxx
[MinimaxTtsStreamService] Closing TTS connection - SessionId: xxx
[MinimaxTtsStreamService] Sending task_finish - SessionId: xxx
[MinimaxTtsStreamService] MiniMax TTS connection closed - SessionId: xxx, Code: 1000, Reason: ...
```
**Pass criteria**
- ✅ task_finish is sent correctly
- ✅ The TTS connection closes cleanly
- ✅ The connection is removed from the pool
- ✅ No memory leaks
---
### Test 5: Performance Comparison
**Goal**: verify the performance gain matches expectations.
**Scenario**: the AI replies with 3 sentences.
**Before the change**
- First sentence latency: 1-1.5s
- Second sentence latency: 1-1.5s
- Third sentence latency: 1-1.5s
- **Total latency**: 3-4.5s
**Expected after the change**
- First sentence latency: 0.5s (TTS synthesis only)
- Second sentence latency: 0.2-0.3s
- Third sentence latency: 0.2-0.3s
- **Total latency**: 0.9-1.1s
**Improvement**: roughly **65-70%**
**How to verify**
1. Record each sentence's TTS start and end time in the logs
2. Compute the total elapsed time
3. Compare against the expected values
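A minimal helper for steps 1-2 could look like the sketch below (illustrative only; the real service presumably just logs timestamps and the numbers are read off the logs):

```java
import java.util.ArrayList;
import java.util.List;

public class TtsTimer {
    private final List<Long> durationsMs = new ArrayList<>();
    private long startNanos;

    public void begin() { startNanos = System.nanoTime(); }

    public void end() {
        // record one sentence's synthesis duration in milliseconds
        durationsMs.add((System.nanoTime() - startNanos) / 1_000_000);
    }

    public long totalMs() { // total latency, compared against the expected values above
        long sum = 0;
        for (long d : durationsMs) sum += d;
        return sum;
    }

    public int sentenceCount() { return durationsMs.size(); }
}
```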
---
### Test 6: Concurrency
**Goal**: verify connection management with multiple simultaneous users.
**Steps**
1. Open 3-5 WebSocket connections (simulating multiple users)
2. Send audio on each connection at the same time
3. Watch backend logs and performance
**Pass criteria**
- ✅ Each user gets its own TTS connection
- ✅ Connections do not interfere with each other
- ✅ No connection mix-ups
- ✅ Stable performance
---
## Performance Monitoring Metrics
Watch these key indicators in the logs:
1. **Connection setup time**: `TTS connection warm-up succeeded - SessionId: xxx, elapsed: XXXms`
2. **Connection reuse count**: count occurrences of "Reusing existing TTS connection"
3. **TTS synthesis time**: `TTS done - SessionId: xxx, elapsed: XXXms`
4. **Reconnect count**: `Attempting TTS auto-reconnect - SessionId: xxx, attempt X`
5. **Connection state changes**: watch the transitions CONNECTING → CONNECTED → TASK_STARTED → IDLE
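The state names in item 5 come straight from the logs; the implied state machine might be sketched as below. The transition rules are an assumption for illustration, and the real service may allow more transitions (e.g. error states):

```java
public enum TtsConnectionState {
    CONNECTING, CONNECTED, TASK_STARTED, IDLE;

    /** True if moving to `next` matches the warm-up/reuse cycle described above. */
    public boolean canTransitionTo(TtsConnectionState next) {
        switch (this) {
            case CONNECTING:   return next == CONNECTED;
            case CONNECTED:    return next == TASK_STARTED;
            case TASK_STARTED: return next == IDLE;         // is_final keeps the socket, marks it idle
            case IDLE:         return next == TASK_STARTED; // reuse: the next sentence starts a new task
            default:           return false;
        }
    }
}
```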
---
## Troubleshooting
### Problem 1: Warm-up fails
**Symptom**: `TTS connection warm-up failed - SessionId: xxx`
**Possible causes**
- Wrong MiniMax API configuration
- Network connectivity issues
- Expired API key
**Fixes**
- Check the MiniMax settings in application.yml
- Verify the API key is valid
- Check network connectivity
### Problem 2: Connections are not reused
**Symptom**: every TTS call logs "Creating MiniMax TTS connection"
**Possible causes**
- Broken connection state management
- The connection is closed before it can be reused
**Fixes**
- Inspect the connection-state logs
- Make sure the state is set to IDLE (not closed) after is_final
### Problem 3: Auto-reconnect fails
**Symptom**: the connection cannot be restored after 3 reconnect attempts
**Possible causes**
- A MiniMax service outage
- Persistently unstable network
**Fixes**
- Check the MiniMax service status
- Fall back to creating connections on demand
---
## Test Checklist
- [ ] Test 1: connection warm-up
- [ ] Test 2: connection reuse
- [ ] Test 3: automatic reconnection
- [ ] Test 4: connection cleanup
- [ ] Test 5: performance comparison
- [ ] Test 6: concurrency
- [ ] No memory leaks
- [ ] No connection leaks
- [ ] Stress test (10+ concurrent users)
---
## Test Conclusions
After completing all tests, record:
1. The actual performance improvement percentage
2. The connection-reuse success rate
3. The auto-reconnect success rate
4. Problems found and suggested improvements


@@ -0,0 +1,158 @@
# Voice Chat API Interaction Guide
This document describes the interaction flow, parameters, and response structure of the voice chat feature (`/voice-chat`) between the `webUI` front end and the `server` back end.
## 1. Overview
The voice chat feature lets the user record a voice clip, which the front end uploads to the back end. The back end then performs, in order:
1. **STT (Speech-to-Text)**: convert the speech to text.
2. **LLM (Large Language Model)**: feed the recognized text to the model and get the AI's text reply.
3. **TTS (Text-to-Speech)**: convert the AI's text reply back to speech.
Finally, the back end returns the recognized user text, the AI reply text, and the synthesized audio to the front end in a single response.
## 2. Front-end Call (WebUI)
The relevant front-end files are `webUI/src/components/ChatBox.vue` and `webUI/src/utils/api.js`.
### 2.1 Invocation
In `ChatBox.vue`, when the user finishes recording, the `handleVoiceModeMessage` method is called, which in turn calls `voiceAPI.voiceChat`.
Under the hood, a `multipart/form-data` POST request is issued via `uni.uploadFile`.
### 2.2 Request Parameters
The request sent to the back end contains a **file** plus **form data (FormData)**.
* **URL**: `/api/chat/voice-chat` (defined by the `VOICE_CHAT` constant in `config.js`)
* **Method**: `POST`
* **Header**: `Authorization: Bearer <token>`
* **File part**:
  * `name`: `"audio"`
  * `filePath`: local path of the recording (e.g., `.aac` or `.wav`)
* **Form data (FormData)**:
| Parameter | Type | Required | Description | Source |
| :--- | :--- | :--- | :--- | :--- |
| `sessionId` | String | No | Session ID, used to keep context | `conversationId.value` |
| `modelId` | Integer | No | Model ID | `characterConfig.modelId` |
| `templateId` | Integer | No | Template ID | `characterConfig.templateId` |
| `voiceStyle` | String | No | Voice style (for TTS) | `options.voiceStyle` |
| `ttsConfigId` | Integer | No | TTS configuration ID | `aiConfig.ttsId` |
| `sttConfigId` | Integer | No | STT configuration ID | `aiConfig.sttId` |
| `useFunctionCall` | Boolean | No | Whether to use function calling | defaults to `false` |
### 2.3 Response Handling
After receiving the JSON response from the back end, the front end parses it as follows:
1. **User text**: taken from `sttResult.text` and shown on the right side of the chat.
2. **AI reply**: taken from `llmResult.response` and shown on the left side of the chat.
3. **Audio playback**: `ttsResult.audioBase64` (Base64-encoded audio) is preferred; if absent, `ttsResult.audioPath` (an audio URL) is used.
---
## 3. Back-end Processing (Server)
The entry point is `server/src/main/java/com/xiaozhi/controller/ChatController.java`; the core logic lives in `ChatSessionServiceImpl.java`.
### 3.1 Endpoint
* **Controller**: `ChatController`
* **Path**: `/api/chat/voice-chat`
* **Consumes**: `multipart/form-data`
### 3.2 Processing Flow
1. **Receive the file**: the back end accepts uploads under the field names `audioFile`, `file`, or `audio` (the front end uses `audio`).
2. **Parse parameters**: `sessionId`, `modelId`, and so on.
3. **Validate the file**: size (1KB - 50MB), format (mp3, wav, m4a, aac, etc.), and MIME type.
4. **Audio processing**:
   * Save the uploaded audio to a temporary file.
   * Use `AudioUtils` to convert the audio to **PCM 16k mono** (as required by the STT engine).
5. **Business logic (`ChatSessionService.voiceChat`)**:
   * **STT**: call the configured STT service to recognize the speech, yielding `recognizedText`.
   * **LLM**: if text was recognized, call `syncChat` to get the AI reply `chatResponse`.
   * **TTS**: call the TTS service to synthesize the reply, producing an audio file that is read as Base64.
6. **Assemble the result**: pack the STT, LLM, and TTS results into a Map and return it.
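Step 6 might be assembled roughly as sketched below. The key names follow the response schema in section 4.2; the helper class itself is illustrative, not the project's actual code:

```java
import java.util.HashMap;
import java.util.Map;

public class VoiceChatResult {
    // Packs the three stage results into the "data" map returned by the controller.
    public static Map<String, Object> assemble(Map<String, Object> sttResult,
                                               Map<String, Object> llmResult,
                                               Map<String, Object> ttsResult) {
        Map<String, Object> data = new HashMap<>();
        data.put("sttResult", sttResult);
        data.put("llmResult", llmResult);
        data.put("ttsResult", ttsResult);
        data.put("timestamp", System.currentTimeMillis());
        return data;
    }
}
```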
---
## 4. API Specification
### 4.1 Request
**POST** `/api/chat/voice-chat`
**Content-Type**: `multipart/form-data`
**Body**:
* `audio`: [binary file data]
* `sessionId`: "session_12345"
* `modelId`: 10
* `templateId`: 6
* ...
### 4.2 Response
**Content-Type**: `application/json`
```json
{
  "code": 200,
  "message": "语音对话成功",
  "data": {
    "sessionId": "session_12345",
    "timestamp": 1717660000000,
    // 1. STT result
    "sttResult": {
      "text": "你好,请介绍一下你自己。", // recognized user speech
      "audioSize": 32000,
      "sttProvider": "vosk"
    },
    // 2. LLM result
    "llmResult": {
      "response": "你好我是蔚AI很高兴为你服务。", // AI reply text
      "inputText": "你好,请介绍一下你自己。"
    },
    // 3. TTS result
    "ttsResult": {
      "audioBase64": "UklGRi...", // Base64-encoded audio (for direct playback)
      "audioPath": "audio/output/...", // audio file path on the server
      "timestamp": 1717660005000
    },
    // performance stats (elapsed time, ms)
    "sttDuration": 500,
    "llmDuration": 1200,
    "ttsDuration": 800,
    // file metadata
    "originalFileName": "temp_audio.aac",
    "fileSize": 15000,
    "contentType": "audio/aac",
    "description": null
  }
}
```
### 4.3 Error Response
```json
{
"code": 400, // 或 500
"message": "请求参数错误: 音频文件不能为空",
"data": null
}
```
## 5. Summary
* **Interaction model**: one-shot synchronous. The front end sends the audio, waits for the back end to finish all processing (recognition + chat + synthesis), then receives all data at once.
* **Audio formats**: the front end typically records `aac` or `wav`; the back end transcodes everything to `pcm` for processing.
* **Fallback**: if back-end processing fails, `ChatBox.vue` catches the error and notifies the user, or uses a locally simulated reply (in specific cases such as not being logged in).


@@ -0,0 +1,744 @@
# WebSocket Real-time Streaming Voice Chat Architecture
## 1. Overview
This design describes a **"pseudo real-time" voice interaction system** that combines front-end VAD (voice activity detection) with back-end streaming to deliver a low-latency voice chat experience.
### Key Properties
- **Simple front end**: reuse the existing VAD logic; the complete audio is uploaded only after the user finishes speaking
- **Streaming back end**: STT runs synchronously; LLM and TTS are chained as streams
- **Very low perceived latency**: the user hears the first sentence of the reply within about 1 second of finishing speaking
### Technology Choices
- **Front-end transport**: WebSocket (bidirectional real-time communication)
- **STT**: reuse the existing Vosk/other STT services (synchronous processing of the complete audio)
- **LLM**: Grok API (`https://api.x.ai/v1/chat/completions`) + SSE stream
- **TTS**: MiniMax WebSocket TTS (`wss://api.minimax.io/ws/v1/t2a_v2`) + stream
---
## 2. End-to-end Flow
```
┌──────────────┐
│ Front-end UI │
└──────┬───────┘
       │ 1. user speaks
┌─────────────────┐
│ RecorderManager │ (VAD detects end of speech)
│    (PCM/AAC)    │
└────────┬────────┘
         │ 2. send complete audio over WebSocket (binary)
┌────────────────────┐
│ Back-end WebSocket │
│      Handler       │
└────────┬───────────┘
         ↓ 3. call STT (synchronous)
┌──────────────────┐
│   STT Service    │ → "你好,今天天气怎么样?"
│ (reuse existing) │
└────────┬─────────┘
         ↓ 4. streaming call to Grok LLM (HTTP SSE)
┌─────────────────────────┐
│     Grok API Stream     │
│ https://api.x.ai/v1/... │
└────────┬────────────────┘
         │ token stream: "今", "天", "天", "气", "很", "好", "。", ...
┌────────────────────┐
│  Sentence buffer   │ (detects punctuation)
└────────┬───────────┘
         │ complete sentence: "今天天气很好。"
         ↓ 5. call MiniMax TTS per sentence (WebSocket stream)
┌─────────────────────────────┐
│    MiniMax TTS WebSocket    │
│ wss://api.minimax.io/ws/... │
└────────┬────────────────────┘
         │ audio stream (hex → bytes)
         ↓ 6. forward to the front end in real time (WebSocket binary)
┌─────────────────────┐
│ Front-end WebSocket │
│      onMessage      │
└────────┬────────────┘
         │ 7. real-time playback via WebAudioContext
┌──────────────────────┐
│   Audio play queue   │
│ (seamless playback)  │
└──────────────────────┘
```
---
## 3. Front-end Changes (ChatBox.vue)
### 3.1 What Stays
**Existing VAD logic**:
- the `recorderManager` configuration
- `onStart` and `onFrameRecorded` (used for the waveform visualization)
- `onStop` (the key trigger point)
- `calculateVolumeRMS` and `vadConfig` (silence detection)
### 3.2 What Changes
#### A. Establish the WebSocket connection
```javascript
// New: WebSocket connection management
const voiceWebSocket = ref(null);

const connectVoiceWebSocket = () => {
  const wsUrl = `ws://192.168.3.13:8091/ws/voice-stream`;
  voiceWebSocket.value = uni.connectSocket({
    url: wsUrl,
    success: () => {
      console.log('WebSocket connected');
    }
  });
  // Listen for messages (receive the audio stream)
  voiceWebSocket.value.onMessage((res) => {
    if (res.data instanceof ArrayBuffer) {
      // Audio data received; enqueue it for playback
      playAudioChunk(res.data);
    } else {
      // JSON control message
      const msg = JSON.parse(res.data);
      if (msg.event === 'stt_done') {
        addMessage('user', msg.text);
        voiceState.value = 'thinking';
      } else if (msg.event === 'llm_start') {
        voiceState.value = 'speaking';
      }
    }
  });
  voiceWebSocket.value.onError((err) => {
    console.error('WebSocket error:', err);
  });
};

// Connect when entering voice mode
const toggleVoiceMode = () => {
  if (isVoiceMode.value) {
    connectVoiceWebSocket();
  } else {
    voiceWebSocket.value?.close();
  }
};
```
#### B. Modify the audio upload logic
```javascript
// Modified handleVoiceModeMessage
const handleVoiceModeMessage = async (filePath) => {
  if (!isVoiceMode.value) return;
  voiceState.value = 'thinking';
  try {
    // Read the recording file as an ArrayBuffer
    const fs = uni.getFileSystemManager();
    const audioData = await new Promise((resolve, reject) => {
      fs.readFile({
        filePath: filePath,
        success: (res) => resolve(res.data),
        fail: reject
      });
    });
    // Send the audio over the WebSocket
    voiceWebSocket.value.send({
      data: audioData,
      success: () => {
        console.log('audio sent');
      },
      fail: (err) => {
        console.error('send failed:', err);
      }
    });
  } catch (error) {
    console.error('failed to process audio:', error);
    voiceState.value = 'idle';
  }
};
```
#### C. Implement streaming audio playback
```javascript
// Audio playback queue
const audioQueue = ref([]);
const isPlayingAudio = ref(false);

const playAudioChunk = (arrayBuffer) => {
  audioQueue.value.push(arrayBuffer);
  if (!isPlayingAudio.value) {
    processAudioQueue();
  }
};

const processAudioQueue = async () => {
  if (audioQueue.value.length === 0) {
    isPlayingAudio.value = false;
    // Playback finished; return to standby
    if (isVoiceMode.value && isAutoVoiceMode.value) {
      startVoiceRecording();
    } else {
      voiceState.value = 'idle';
    }
    return;
  }
  isPlayingAudio.value = true;
  const chunk = audioQueue.value.shift();
  // Use InnerAudioContext (the chunk must first be written to a temp file)
  const fs = uni.getFileSystemManager();
  const tempPath = `${wx.env.USER_DATA_PATH}/temp_audio_${Date.now()}.mp3`;
  fs.writeFile({
    filePath: tempPath,
    data: chunk,
    encoding: 'binary',
    success: () => {
      const audio = uni.createInnerAudioContext();
      audio.src = tempPath;
      audio.onEnded(() => {
        processAudioQueue(); // play the next chunk
      });
      audio.onError(() => {
        processAudioQueue(); // keep going even on error
      });
      audio.play();
    }
  });
};
```
---
## 4. Back-end Implementation (Java)
### 4.1 WebSocket Handler
```java
@ServerEndpoint(value = "/ws/voice-stream")
@Component
public class VoiceStreamHandler {
    private static final Logger logger = LoggerFactory.getLogger(VoiceStreamHandler.class);

    @Autowired
    private SttService sttService;

    @OnOpen
    public void onOpen(Session session) {
        logger.info("WebSocket connection established: {}", session.getId());
    }

    @OnMessage
    public void onBinaryMessage(ByteBuffer audioBuffer, Session session) {
        logger.info("Audio data received: {} bytes", audioBuffer.remaining());
        // Copy into a byte[]
        byte[] audioData = new byte[audioBuffer.remaining()];
        audioBuffer.get(audioData);
        // Process asynchronously (avoid blocking the WebSocket thread)
        CompletableFuture.runAsync(() -> {
            processVoiceStream(audioData, session);
        });
    }

    @OnClose
    public void onClose(Session session) {
        logger.info("WebSocket connection closed: {}", session.getId());
    }

    @OnError
    public void onError(Session session, Throwable error) {
        logger.error("WebSocket error: {}", session.getId(), error);
    }
}
```
### 4.2 Core Processing Flow
```java
private void processVoiceStream(byte[] audioData, Session session) {
    try {
        // ========== 1. STT (synchronous) ==========
        long sttStart = System.currentTimeMillis();
        String recognizedText = performStt(audioData);
        long sttDuration = System.currentTimeMillis() - sttStart;
        if (recognizedText == null || recognizedText.trim().isEmpty()) {
            sendJsonMessage(session, Map.of("event", "error", "message", "no speech recognized"));
            return;
        }
        logger.info("STT done ({}ms): {}", sttDuration, recognizedText);
        // Notify the front end of the recognition result
        sendJsonMessage(session, Map.of(
                "event", "stt_done",
                "text", recognizedText,
                "duration", sttDuration
        ));
        // ========== 2. LLM stream + TTS stream (chained) ==========
        streamLlmAndTts(recognizedText, session);
    } catch (Exception e) {
        logger.error("Failed to process voice stream", e);
        sendJsonMessage(session, Map.of("event", "error", "message", e.getMessage()));
    }
}
```
### 4.3 STT Processing (reusing existing code)
```java
private String performStt(byte[] audioData) {
    try {
        // Convert to PCM 16k mono
        byte[] pcmData = AudioUtils.bytesToPcm(audioData);
        // Call the STT service (reusing existing logic)
        String text = sttService.recognition(pcmData);
        return text != null ? text.trim() : "";
    } catch (Exception e) {
        logger.error("STT processing failed", e);
        return "";
    }
}
```
### 4.4 Calling the Grok LLM Stream
```java
private void streamLlmAndTts(String userText, Session frontendSession) {
    // Sentence buffer
    StringBuilder sentenceBuffer = new StringBuilder();
    // Call the Grok SSE API with WebClient
    WebClient client = WebClient.builder()
            .baseUrl("https://api.x.ai")
            .defaultHeader("Authorization", "Bearer " + grokApiKey)
            .build();
    sendJsonMessage(frontendSession, Map.of("event", "llm_start"));
    client.post()
            .uri("/v1/chat/completions")
            .contentType(MediaType.APPLICATION_JSON)
            .bodyValue(Map.of(
                    "model", "grok-beta",
                    "messages", List.of(
                            Map.of("role", "system", "content", "你是一个友好的AI助手"),
                            Map.of("role", "user", "content", userText)
                    ),
                    "stream", true,
                    "temperature", 0.7
            ))
            .retrieve()
            .bodyToFlux(String.class) // SSE stream
            .subscribe(
                    // onNext: handle each SSE chunk
                    sseChunk -> {
                        String token = parseGrokToken(sseChunk);
                        if (token != null && !token.isEmpty()) {
                            sentenceBuffer.append(token);
                            // Detect end of sentence
                            if (isSentenceEnd(sentenceBuffer.toString())) {
                                String sentence = sentenceBuffer.toString().trim();
                                sentenceBuffer.setLength(0); // reset the buffer
                                logger.info("Complete sentence detected: {}", sentence);
                                // Call TTS and stream the audio out
                                streamTtsToFrontend(sentence, frontendSession);
                            }
                        }
                    },
                    // onError
                    error -> {
                        logger.error("Grok LLM stream error", error);
                        sendJsonMessage(frontendSession, Map.of("event", "error", "message", "LLM processing failed"));
                    },
                    // onComplete
                    () -> {
                        // Flush any remaining buffered text
                        if (sentenceBuffer.length() > 0) {
                            String lastSentence = sentenceBuffer.toString().trim();
                            if (!lastSentence.isEmpty()) {
                                streamTtsToFrontend(lastSentence, frontendSession);
                            }
                        }
                        sendJsonMessage(frontendSession, Map.of("event", "llm_complete"));
                        logger.info("LLM processing complete");
                    }
            );
}
```
### 4.5 Parsing the Grok SSE Payload
```java
private String parseGrokToken(String sseChunk) {
    // SSE format: data: {"choices":[{"delta":{"content":"你好"}}]}\n\n
    if (!sseChunk.startsWith("data: ")) {
        return null;
    }
    String jsonData = sseChunk.substring(6).trim();
    if (jsonData.equals("[DONE]")) {
        return null;
    }
    try {
        JsonObject json = JsonParser.parseString(jsonData).getAsJsonObject();
        JsonArray choices = json.getAsJsonArray("choices");
        if (choices != null && choices.size() > 0) {
            JsonObject delta = choices.get(0).getAsJsonObject().getAsJsonObject("delta");
            if (delta != null && delta.has("content")) {
                return delta.get("content").getAsString();
            }
        }
    } catch (Exception e) {
        logger.warn("Failed to parse Grok token: {}", sseChunk);
    }
    return null;
}
```
### 4.6 Sentence Splitting
```java
private boolean isSentenceEnd(String text) {
    // Detect Chinese and English sentence-ending punctuation
    String trimmed = text.trim();
    if (trimmed.isEmpty()) return false;
    char lastChar = trimmed.charAt(trimmed.length() - 1);
    // Chinese punctuation
    if (lastChar == '。' || lastChar == '!' || lastChar == '?' ||
            lastChar == ';' || lastChar == ',') {
        return true;
    }
    // English punctuation
    if (lastChar == '.' || lastChar == '!' || lastChar == '?') {
        return true;
    }
    // Avoid over-long sentences (force a split beyond 50 characters)
    if (trimmed.length() > 50) {
        return true;
    }
    return false;
}
```
### 4.7 MiniMax TTS WebSocket Stream
```java
private void streamTtsToFrontend(String text, Session frontendSession) {
    try {
        URI uri = new URI("wss://api.minimax.io/ws/v1/t2a_v2");
        WebSocketClient ttsClient = new WebSocketClient(uri) {
            private boolean taskStarted = false;

            @Override
            public void onOpen(ServerHandshake handshake) {
                logger.info("MiniMax TTS connected");
            }

            @Override
            public void onMessage(String message) {
                try {
                    JsonObject msg = JsonParser.parseString(message).getAsJsonObject();
                    String event = msg.get("event").getAsString();
                    if ("connected_success".equals(event)) {
                        // Send task_start
                        send(buildTaskStartJson());
                    } else if ("task_started".equals(event)) {
                        // Send task_continue
                        taskStarted = true;
                        send(buildTaskContinueJson(text));
                    } else if (msg.has("data")) {
                        JsonObject data = msg.getAsJsonObject("data");
                        if (data.has("audio")) {
                            String hexAudio = data.get("audio").getAsString();
                            // Convert the hex string to bytes
                            byte[] audioBytes = hexToBytes(hexAudio);
                            // Forward to the front end immediately
                            frontendSession.getAsyncRemote().sendBinary(
                                    ByteBuffer.wrap(audioBytes)
                            );
                            logger.debug("Forwarded audio data: {} bytes", audioBytes.length);
                        }
                        // Check for the end of the stream
                        if (msg.has("is_final") && msg.get("is_final").getAsBoolean()) {
                            send("{\"event\":\"task_finish\"}");
                            close();
                        }
                    }
                } catch (Exception e) {
                    logger.error("Failed to handle TTS message", e);
                }
            }

            @Override
            public void onError(Exception ex) {
                logger.error("MiniMax TTS error", ex);
            }

            @Override
            public void onClose(int code, String reason, boolean remote) {
                logger.info("MiniMax TTS connection closed");
            }
        };
        // Add the auth header (Java-WebSocket exposes addHeader for this)
        ttsClient.addHeader("Authorization", "Bearer " + minimaxApiKey);
        // Connect (block for up to 5 seconds)
        boolean connected = ttsClient.connectBlocking(5, TimeUnit.SECONDS);
        if (!connected) {
            throw new RuntimeException("Timed out connecting to MiniMax TTS");
        }
    } catch (Exception e) {
        logger.error("MiniMax TTS call failed", e);
        sendJsonMessage(frontendSession, Map.of("event", "error", "message", "TTS processing failed"));
    }
}
```
### 4.8 Utility Methods
```java
// Build the task_start JSON
private String buildTaskStartJson() {
    return """
            {
              "event": "task_start",
              "model": "speech-2.6-hd",
              "voice_setting": {
                "voice_id": "male-qn-qingse",
                "speed": 1.0,
                "vol": 1.0,
                "pitch": 0
              },
              "audio_setting": {
                "sample_rate": 32000,
                "bitrate": 128000,
                "format": "mp3",
                "channel": 1
              }
            }
            """;
}

// Build the task_continue JSON
private String buildTaskContinueJson(String text) {
    JsonObject json = new JsonObject();
    json.addProperty("event", "task_continue");
    json.addProperty("text", text);
    return json.toString();
}

// Hex string to byte array
private byte[] hexToBytes(String hex) {
    int len = hex.length();
    byte[] data = new byte[len / 2];
    for (int i = 0; i < len; i += 2) {
        data[i / 2] = (byte) ((Character.digit(hex.charAt(i), 16) << 4)
                + Character.digit(hex.charAt(i + 1), 16));
    }
    return data;
}

// Send a JSON message to the front end
private void sendJsonMessage(Session session, Map<String, Object> data) {
    try {
        String json = new Gson().toJson(data);
        session.getAsyncRemote().sendText(json);
    } catch (Exception e) {
        logger.error("Failed to send message", e);
    }
}
```
---
## 5. Maven Dependencies
```xml
<!-- Spring Boot WebSocket -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
<!-- WebClient (for Grok SSE) -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-webflux</artifactId>
</dependency>
<!-- Java WebSocket client (for MiniMax) -->
<dependency>
    <groupId>org.java-websocket</groupId>
    <artifactId>Java-WebSocket</artifactId>
    <version>1.5.3</version>
</dependency>
<!-- JSON handling -->
<dependency>
    <groupId>com.google.code.gson</groupId>
    <artifactId>gson</artifactId>
</dependency>
```
---
## 6. Key Technical Challenges and Solutions
### 6.1 Concurrency Control
**Problem**: how do we keep the TTS output of multiple sentences in order?
**Solution**
- Use a **semaphore** or a **serial queue**
- Start the next sentence only after the previous sentence's TTS has fully finished
```java
private final Semaphore ttsSemaphore = new Semaphore(1);

private void streamTtsToFrontend(String text, Session frontendSession) {
    try {
        ttsSemaphore.acquire(); // take the permit
        // ... TTS processing ...
        // NB: with an asynchronous TTS client, the permit must be held until the
        // TTS socket closes (e.g. via a CountDownLatch), or ordering is lost.
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
    } finally {
        ttsSemaphore.release(); // release the permit
    }
}
```
### 6.2 Grok SSE Connection Stability
**Problem**: the long-lived SSE connection may drop.
**Solution**
- Configure timeouts and a retry policy
- Use `Flux.retry(3)` to retry automatically
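Stripped of Reactor specifics, the retry idea is just re-invoking the call a bounded number of times. A plain-Java sketch of what `Flux.retry(3)` provides (hypothetical helper, not project code):

```java
import java.util.concurrent.Callable;

public class Retry {
    /** Runs task, retrying up to maxRetries times on failure (maxRetries + 1 attempts total). */
    public static <T> T withRetries(Callable<T> task, int maxRetries) throws Exception {
        Exception last = null;
        for (int attempt = 0; attempt <= maxRetries; attempt++) {
            try {
                return task.call();
            } catch (Exception e) {
                last = e; // remember the failure and try again
            }
        }
        throw last; // all attempts failed
    }
}
```

Note that `Flux.retry` re-subscribes to the whole stream, so tokens already emitted would be replayed; in practice the sentence buffer would need to be reset on retry.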
### 6.3 Front-end Audio Stutter
**Problem**: network jitter causes choppy playback.
**Solution**
- Implement a **jitter buffer** (buffer 3-5 audio chunks before starting playback)
- Watch the queue length and adjust the playback rate dynamically
### 6.4 Audio Format Compatibility
**Problem**: mini-programs have poor support for raw PCM.
**Solution**
- Configure MiniMax to output `mp3` (good compression, broad compatibility)
- The front end plays MP3 directly, with no decoding step
---
## 7. Estimated Performance
| Stage | Estimated Time | Notes |
|---|---|---|
| **Front-end VAD** | real time | runs continuously while the user speaks |
| **Audio upload** | 50-200ms | depends on network and file size |
| **STT** | 300-800ms | Vosk runs locally and is fast |
| **LLM first token** | 500-1000ms | Grok response speed |
| **TTS first audio chunk** | 200-500ms | MiniMax WebSocket latency |
| **First playback latency** | **1-2s** | from end of speech to hearing the reply |
---
## 8. Testing Suggestions
### 8.1 Unit Tests
- Test the STT service in isolation
- Test the Grok API call (mock the SSE stream)
- Test the MiniMax TTS call (mock the WebSocket)
### 8.2 Integration Tests
- Full voice pipeline test (recording -> STT -> LLM -> TTS -> playback)
- Concurrency test (multiple users chatting at once)
- Failure scenarios (network interruptions, API timeouts)
### 8.3 Stress Tests
- Simulate 100 concurrent users
- Monitor server CPU, memory, and network I/O
- Check for WebSocket connection leaks
---
## 9. Future Work
### Short Term
1. **Cache TTS results**: do not re-synthesize identical text
2. **Adaptive sentence splitting**: adjust sentence length to network conditions
3. **Graceful degradation**: fall back to backup services when an API fails
### Long Term
1. **Edge computing**: move STT onto the device (Whisper.cpp)
2. **Local models**: deploy private LLM and TTS
3. **Multimodal input**: support images and video
---
## 10. Conclusion
This design achieves efficient streaming voice interaction over **full-duplex WebSocket communication**:
- **Simple front end**: keep the existing VAD; only the upload and playback logic changes
- **Efficient back end**: LLM and TTS are chained as streams for very low latency
- **User experience**: the reply starts within about 1 second of finishing speaking, close to a human conversation
- **Mature technology**: Grok and MiniMax both officially support streaming APIs
Net effect: the interaction evolves **from "record, wait, play" to a streaming conversation**, cutting perceived latency by **60-80%**. 🚀

docs/import asyncio

@@ -0,0 +1,183 @@
import asyncio
import websockets
import json
import ssl
import subprocess
import os
model = "speech-2.6-hd"
file_format = "mp3"
class StreamAudioPlayer:
def __init__(self):
self.mpv_process = None
def start_mpv(self):
"""Start MPV player process"""
try:
mpv_command = ["mpv", "--no-cache", "--no-terminal", "--", "fd://0"]
self.mpv_process = subprocess.Popen(
mpv_command,
stdin=subprocess.PIPE,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
print("MPV player started")
return True
except FileNotFoundError:
print("Error: mpv not found. Please install mpv")
return False
except Exception as e:
print(f"Failed to start mpv: {e}")
return False
def play_audio_chunk(self, hex_audio):
"""Play audio chunk"""
try:
if self.mpv_process and self.mpv_process.stdin:
audio_bytes = bytes.fromhex(hex_audio)
self.mpv_process.stdin.write(audio_bytes)
self.mpv_process.stdin.flush()
return True
except Exception as e:
print(f"Play failed: {e}")
return False
return False
def stop(self):
"""Stop player"""
if self.mpv_process:
if self.mpv_process.stdin and not self.mpv_process.stdin.closed:
self.mpv_process.stdin.close()
try:
self.mpv_process.wait(timeout=20)
except subprocess.TimeoutExpired:
self.mpv_process.terminate()
async def establish_connection(api_key):
"""Establish WebSocket connection"""
url = "wss://api.minimax.io/ws/v1/t2a_v2"
headers = {"Authorization": f"Bearer {api_key}"}
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
try:
ws = await websockets.connect(url, additional_headers=headers, ssl=ssl_context)
connected = json.loads(await ws.recv())
if connected.get("event") == "connected_success":
print("Connection successful")
return ws
return None
except Exception as e:
print(f"Connection failed: {e}")
return None
async def start_task(websocket):
"""Send task start request"""
start_msg = {
"event": "task_start",
"model": model,
"voice_setting": {
"voice_id": "male-qn-qingse",
"speed": 1,
"vol": 1,
"pitch": 0,
"english_normalization": False
},
"audio_setting": {
"sample_rate": 32000,
"bitrate": 128000,
"format": file_format,
"channel": 1
}
}
await websocket.send(json.dumps(start_msg))
    response = json.loads(await websocket.recv())
    return response.get("event") == "task_started"


async def continue_task_with_stream_play(websocket, text, player):
    """Send continue request and stream play audio"""
    await websocket.send(json.dumps({
        "event": "task_continue",
        "text": text
    }))

    chunk_counter = 1
    total_audio_size = 0
    audio_data = b""

    while True:
        try:
            response = json.loads(await websocket.recv())
            if "data" in response and "audio" in response["data"]:
                audio = response["data"]["audio"]
                if audio:
                    print(f"Playing chunk #{chunk_counter}")
                    audio_bytes = bytes.fromhex(audio)
                    if player.play_audio_chunk(audio):
                        total_audio_size += len(audio_bytes)
                        audio_data += audio_bytes
                        chunk_counter += 1
            if response.get("is_final"):
                print(f"Audio synthesis completed: {chunk_counter-1} chunks")
                if player.mpv_process and player.mpv_process.stdin:
                    player.mpv_process.stdin.close()
                # Save audio to file (format assumed to match the task_start audio settings)
                file_format = "mp3"
                with open(f"output.{file_format}", "wb") as f:
                    f.write(audio_data)
                print(f"Audio saved as output.{file_format}")
                estimated_duration = total_audio_size * 0.0625 / 1000
                wait_time = max(estimated_duration + 5, 10)
                return wait_time
        except Exception as e:
            print(f"Error: {e}")
            break
    return 10


async def close_connection(websocket):
    """Close connection"""
    if websocket:
        try:
            await websocket.send(json.dumps({"event": "task_finish"}))
            await websocket.close()
        except Exception:
            pass


async def main():
    API_KEY = os.getenv("MINIMAX_API_KEY")
    TEXT = "The real danger is not that computers start thinking like people, but that people start thinking like computers. Computers can only help us with simple tasks."
    player = StreamAudioPlayer()
    ws = None
    try:
        if not player.start_mpv():
            return
        ws = await establish_connection(API_KEY)
        if not ws:
            return
        if not await start_task(ws):
            print("Task startup failed")
            return
        wait_time = await continue_task_with_stream_play(ws, TEXT, player)
        await asyncio.sleep(wait_time)
    except Exception as e:
        print(f"Error: {e}")
    finally:
        player.stop()
        if ws:
            await close_connection(ws)


if __name__ == "__main__":
    asyncio.run(main())
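The wait-time estimate in the script multiplies the byte count by 0.0625 ms per byte, i.e. it assumes a 16,000 bytes-per-second audio stream. A small helper makes that assumption explicit (the constants are illustrative, not taken from the MiniMax API):

```python
def estimate_duration_seconds(num_bytes: int, bytes_per_second: int = 16000) -> float:
    """Estimate playback duration of raw audio from its byte count.

    16000 bytes/s matches the 0.0625 ms-per-byte factor used above;
    for other formats use sample_rate * sample_width * channels.
    """
    return num_bytes / bytes_per_second


def wait_time_for(num_bytes: int, padding: float = 5.0, minimum: float = 10.0) -> float:
    """Mirror the script's wait_time = max(duration + 5, 10) rule."""
    return max(estimate_duration_seconds(num_bytes) + padding, minimum)
```

If chunks are frequently cut off before playback ends, raising `padding` is simpler than tuning the byte-rate constant.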

pom.xml

@@ -41,6 +41,19 @@
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
<!-- Spring WebFlux for Grok SSE -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-webflux</artifactId>
</dependency>
<!-- Java-WebSocket for MiniMax TTS -->
<dependency>
<groupId>org.java-websocket</groupId>
<artifactId>Java-WebSocket</artifactId>
<version>1.5.7</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>


@@ -0,0 +1,327 @@
package com.xiaozhi.communication.server.websocket;
import com.google.gson.Gson;
import com.xiaozhi.service.VoiceStreamService;
import jakarta.annotation.Resource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.AbstractWebSocketHandler;
import java.io.IOException;
import java.net.URI;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* 语音流式对话 WebSocket 处理器
*/
@Component
public class VoiceStreamHandler extends AbstractWebSocketHandler {
private static final Logger logger = LoggerFactory.getLogger(VoiceStreamHandler.class);
private static final Gson gson = new Gson();
// 保存所有活跃的会话
private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<>();
@Resource
private VoiceStreamService voiceStreamService;
@Override
public void afterConnectionEstablished(WebSocketSession session) {
String sessionId = session.getId();
// 从请求头或查询参数获取用户认证信息
Map<String, String> params = getParamsFromSession(session);
String token = params.get("token");
String userId = params.get("userId");
// 获取聊天会话相关参数sessionId 用于历史记录查询和保存)
String chatSessionId = params.get("sessionId");
String templateId = params.get("templateId");
logger.info("语音流WebSocket连接建立 - SessionId: {}, UserId: {}, ChatSessionId: {}, TemplateId: {}",
sessionId, userId, chatSessionId, templateId);
// 保存会话和相关参数到session attributes
sessions.put(sessionId, session);
if (chatSessionId != null) {
session.getAttributes().put("chatSessionId", chatSessionId);
}
if (templateId != null) {
try {
session.getAttributes().put("templateId", Integer.parseInt(templateId));
} catch (NumberFormatException e) {
logger.warn("templateId参数格式非法已忽略 - SessionId: {}, templateId: {}", sessionId, templateId);
}
}
// 发送连接成功消息
sendTextMessage(session, createMessage("connected", "连接成功", Map.of("sessionId", sessionId)));
// 预热TTS连接
voiceStreamService.warmupTtsConnection(sessionId)
.thenAccept(v -> {
logger.info("TTS连接预热成功 - SessionId: {}", sessionId);
})
.exceptionally(ex -> {
logger.error("TTS连接预热失败 - SessionId: {}", sessionId, ex);
// 预热失败不影响主流程,降级为按需创建
return null;
});
}
@Override
protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message) {
String sessionId = session.getId();
// 通过 remaining/get 拷贝负载,避免 ByteBuffer.array() 在只读或带偏移的缓冲区上出错
java.nio.ByteBuffer payload = message.getPayload();
byte[] audioData = new byte[payload.remaining()];
payload.get(audioData);
logger.debug("收到音频数据 - SessionId: {}, Size: {} bytes", sessionId, audioData.length);
// 从session attributes获取聊天会话相关参数
String chatSessionId = (String) session.getAttributes().get("chatSessionId");
Integer templateId = (Integer) session.getAttributes().get("templateId");
try {
// 判断是否使用带历史记录的处理方法
if (chatSessionId != null) {
// 调用带历史记录的语音流服务
voiceStreamService.processAudioStreamWithHistory(
sessionId, audioData, chatSessionId, templateId,
new VoiceStreamService.StreamCallback() {
@Override
public void onSttResult(String text) {
// STT识别结果
sendTextMessage(session, createMessage("stt_result", text, null));
}
@Override
public void onLlmToken(String token) {
// LLM输出token可选发送给前端显示
sendTextMessage(session, createMessage("llm_token", token, null));
}
@Override
public void onSentenceComplete(String sentence) {
// 完整句子(发送给前端显示)
sendTextMessage(session, createMessage("sentence", sentence, null));
}
@Override
public void onAudioChunk(byte[] audioChunk) {
// TTS音频数据
sendBinaryMessage(session, audioChunk);
}
@Override
public void onComplete() {
// 所有处理完成
sendTextMessage(session, createMessage("complete", "对话完成", null));
}
@Override
public void onError(String error) {
// 错误处理
sendTextMessage(session, createMessage("error", error, null));
}
});
} else {
// 调用原有的不带历史记录的方法(向后兼容)
voiceStreamService.processAudioStream(sessionId, audioData, new VoiceStreamService.StreamCallback() {
@Override
public void onSttResult(String text) {
// STT识别结果
sendTextMessage(session, createMessage("stt_result", text, null));
}
@Override
public void onLlmToken(String token) {
// LLM输出token可选发送给前端显示
sendTextMessage(session, createMessage("llm_token", token, null));
}
@Override
public void onSentenceComplete(String sentence) {
// 完整句子(发送给前端显示)
sendTextMessage(session, createMessage("sentence", sentence, null));
}
@Override
public void onAudioChunk(byte[] audioChunk) {
// TTS音频数据
sendBinaryMessage(session, audioChunk);
}
@Override
public void onComplete() {
// 所有处理完成
sendTextMessage(session, createMessage("complete", "对话完成", null));
}
@Override
public void onError(String error) {
// 错误处理
sendTextMessage(session, createMessage("error", error, null));
}
});
}
} catch (Exception e) {
logger.error("处理音频流失败 - SessionId: {}", sessionId, e);
sendTextMessage(session, createMessage("error", "处理音频失败: " + e.getMessage(), null));
}
}
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) {
String sessionId = session.getId();
String payload = message.getPayload();
logger.debug("收到文本消息 - SessionId: {}, Message: {}", sessionId, payload);
try {
@SuppressWarnings("unchecked")
Map<String, Object> msgMap = gson.fromJson(payload, Map.class);
String type = (String) msgMap.get("type");
if ("ping".equals(type)) {
// 心跳响应
sendTextMessage(session, createMessage("pong", "pong", null));
} else if ("cancel".equals(type)) {
// 取消当前对话(打断)
voiceStreamService.cancelStream(sessionId);
sendTextMessage(session, createMessage("cancelled", "已取消", null));
}
} catch (Exception e) {
logger.error("处理文本消息失败 - SessionId: {}", sessionId, e);
}
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
String sessionId = session.getId();
sessions.remove(sessionId);
// 取消该会话的所有处理
voiceStreamService.cancelStream(sessionId);
// 清理TTS连接
voiceStreamService.closeTtsConnection(sessionId);
logger.info("语音流WebSocket连接关闭 - SessionId: {}, Status: {}", sessionId, status);
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) {
String sessionId = session.getId();
if (isClientCloseRequest(exception)) {
logger.info("WebSocket连接被客户端主动关闭 - SessionId: {}", sessionId);
} else {
logger.error("WebSocket传输错误 - SessionId: {}", sessionId, exception);
}
// 清理会话
sessions.remove(sessionId);
voiceStreamService.cancelStream(sessionId);
}
/**
* 判断异常是否由客户端主动关闭连接导致
*/
private boolean isClientCloseRequest(Throwable exception) {
if (exception instanceof java.io.EOFException) {
return true;
}
if (exception instanceof IOException) {
String message = exception.getMessage();
if (message != null) {
return message.contains("Connection reset by peer") ||
message.contains("Broken pipe") ||
message.contains("Connection closed") ||
message.contains("远程主机强迫关闭了一个现有的连接");
}
}
return false;
}
/**
* 从会话中获取参数从query string或header
*/
private Map<String, String> getParamsFromSession(WebSocketSession session) {
Map<String, String> params = new HashMap<>();
// 从header获取
String token = session.getHandshakeHeaders().getFirst("Authorization");
if (token != null) {
params.put("token", token.replace("Bearer ", ""));
}
String userId = session.getHandshakeHeaders().getFirst("User-Id");
if (userId != null) {
params.put("userId", userId);
}
// 从URI query参数获取
URI uri = session.getUri();
if (uri != null && uri.getQuery() != null) {
String query = uri.getQuery();
for (String param : query.split("&")) {
String[] kv = param.split("=", 2);
if (kv.length == 2) {
params.put(kv[0], kv[1]);
}
}
}
return params;
}
/**
* 创建JSON格式的消息
*/
private String createMessage(String type, String message, Map<String, Object> data) {
Map<String, Object> msg = new HashMap<>();
msg.put("type", type);
msg.put("message", message);
msg.put("timestamp", System.currentTimeMillis());
if (data != null) {
msg.put("data", data);
}
return gson.toJson(msg);
}
/**
* 发送文本消息
*/
private void sendTextMessage(WebSocketSession session, String message) {
try {
if (session.isOpen()) {
session.sendMessage(new TextMessage(message));
}
} catch (Exception e) {
logger.error("发送文本消息失败 - SessionId: {}", session.getId(), e);
}
}
/**
* 发送二进制消息
*/
private void sendBinaryMessage(WebSocketSession session, byte[] data) {
try {
if (session.isOpen()) {
session.sendMessage(new BinaryMessage(data));
}
} catch (Exception e) {
logger.error("发送二进制消息失败 - SessionId: {}", session.getId(), e);
}
}
/**
* 获取指定会话
*/
public WebSocketSession getSession(String sessionId) {
return sessions.get(sessionId);
}
}
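For client implementers, the text frames produced by `createMessage` above follow a simple JSON envelope: `type`, `message`, `timestamp`, and an optional `data` object. A minimal Python sketch of the same format (field names taken from the handler; nothing else is assumed):

```python
import json
import time


def create_message(msg_type, message, data=None):
    """Build the {type, message, timestamp[, data]} envelope used by VoiceStreamHandler."""
    msg = {
        "type": msg_type,
        "message": message,
        "timestamp": int(time.time() * 1000),
    }
    if data is not None:
        msg["data"] = data
    return json.dumps(msg, ensure_ascii=False)
```

A client can dispatch on the `type` field (`connected`, `stt_result`, `llm_token`, `sentence`, `complete`, `error`, `pong`, `cancelled`) while treating binary frames as TTS audio.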


@@ -21,18 +21,27 @@ public class WebSocketConfig implements WebSocketConfigurer {
// 定义为public static以便其他类可以访问
public static final String WS_PATH = "/ws/xiaozhi/v1/";
public static final String VOICE_STREAM_PATH = "/ws/voice-stream";
@Resource
private WebSocketHandler webSocketHandler;
@Resource
private VoiceStreamHandler voiceStreamHandler;
@Resource
private CmsUtils cmsUtils;
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
// 设备WebSocket端点
registry.addHandler(webSocketHandler, WS_PATH)
.setAllowedOrigins("*");
// 语音流式对话WebSocket端点
registry.addHandler(voiceStreamHandler, VOICE_STREAM_PATH)
.setAllowedOrigins("*");
logger.info("📡 WebSocket服务地址: {}", cmsUtils.getWebsocketAddress());
logger.info("==========================================================");
}


@@ -0,0 +1,89 @@
package com.xiaozhi.config;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;
/**
* 语音流式对话配置类
*/
@Configuration
@ConfigurationProperties(prefix = "xiaozhi.voice-stream")
@Data
public class VoiceStreamConfig {
private GrokConfig grok = new GrokConfig();
private MinimaxConfig minimax = new MinimaxConfig();
@Data
public static class GrokConfig {
/**
* Grok API Key
*/
private String apiKey;
/**
* Grok API URL
*/
private String apiUrl;
/**
* Grok 模型名称
*/
private String model;
}
@Data
public static class MinimaxConfig {
/**
* MiniMax API Key
*/
private String apiKey;
/**
* MiniMax Group ID
*/
private String groupId;
/**
* MiniMax WebSocket URL
*/
private String wsUrl;
/**
* MiniMax TTS 模型名称
*/
private String model;
/**
* 音色ID
*/
private String voiceId;
/**
* 语速 (0.5-2.0)
*/
private Double speed;
/**
* 音量 (0.1-10.0)
*/
private Double vol;
/**
* 音调 (-12到12)
*/
private Integer pitch;
/**
* 音频采样率
*/
private Integer audioSampleRate;
/**
* 比特率
*/
private Integer bitrate;
}
}
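A matching `application.yml` fragment for these properties might look like the following. All values are placeholders or assumptions (the kebab-case keys follow Spring's relaxed binding of the field names above; URLs and model/voice IDs must come from the respective providers' documentation):

```yaml
xiaozhi:
  voice-stream:
    grok:
      api-key: ${GROK_API_KEY}
      api-url: https://api.x.ai/v1/chat/completions  # 示例,以 x.ai 官方文档为准
      model: grok-4-1-fast-non-reasoning
    minimax:
      api-key: ${MINIMAX_API_KEY}
      group-id: ${MINIMAX_GROUP_ID}
      ws-url: wss://example.invalid/tts   # 占位符,替换为 MiniMax 实际的 WebSocket 地址
      model: speech-01-turbo              # 占位符
      voice-id: your-voice-id             # 占位符
      speed: 1.0
      vol: 1.0
      pitch: 0
      audio-sample-rate: 32000
      bitrate: 128000
```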


@@ -1,11 +1,13 @@
package com.xiaozhi.controller;
import com.xiaozhi.dialogue.stt.SttService;
import com.xiaozhi.dto.ChatRequest;
import com.xiaozhi.dto.TtsRequest;
import com.xiaozhi.dto.VoiceChatRequest;
import com.xiaozhi.dto.AudioUploadRequest;
import com.xiaozhi.dto.ChatResponse;
import com.xiaozhi.dto.ApiResponse;
import com.xiaozhi.entity.SysConfig;
import com.xiaozhi.service.ChatSessionService;
import com.xiaozhi.service.TtsIntegrationService;
import com.xiaozhi.config.ChatConstants;
@@ -182,33 +184,6 @@ public class ChatController {
ChatConstants.ERROR_INTERNAL + ": " + e.getMessage()));
}
}
/**
* 语音对话接口
* @param request 语音对话请求
* @return 包含识别文本、回复文本和音频Base64的响应
*/
@Operation(summary = "语音对话", description = "接收音频数据进行语音识别、对话生成和语音合成的完整流程。必填参数audioData可选参数sessionId、useFunctionCall、modelId、templateId、sttConfigId、ttsConfigId")
@PostMapping("/voice-chat")
public ResponseEntity<ApiResponse<Map<String, Object>>> voiceChat(
@Parameter(description = "语音对话请求参数", required = true,
content = @Content(examples = @ExampleObject(value = "{\"audioData\":\"UklGRiQAAABXQVZFZm10IBAAAAABAAEA...\",\"sessionId\":\"session123\",\"useFunctionCall\":false,\"modelId\":7,\"templateId\":1,\"sttConfigId\":9,\"ttsConfigId\":8}")))
@RequestBody VoiceChatRequest request) {
try {
logger.info("收到语音对话请求");
Map<String, Object> result = chatSessionService.voiceChat(request);
logger.info("语音对话响应成功");
return ResponseEntity.ok(ApiResponse.success(ChatConstants.VOICE_CHAT_SUCCESS_MESSAGE, result));
} catch (Exception e) {
logger.error("语音对话处理失败", e);
return ResponseEntity.internalServerError()
.body(ApiResponse.error(ChatConstants.INTERNAL_ERROR_CODE,
ChatConstants.ERROR_INTERNAL + ": " + e.getMessage()));
}
}
/**
* 语音对话接口multipart/form-data 版本)
* 为兼容客户端以表单方式上传音频数据,支持多个常见字段名。
@@ -273,6 +248,10 @@ public class ChatController {
}
}
@Resource
private com.xiaozhi.dialogue.stt.factory.SttServiceFactory sttServiceFactory;
/**
* 清除会话缓存
* @param sessionId 会话ID


@@ -0,0 +1,336 @@
package com.xiaozhi.dialogue.llm;
import com.google.gson.Gson;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.xiaozhi.config.VoiceStreamConfig;
import jakarta.annotation.Resource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Grok LLM 流式服务
*
* 支持通过 Server-Sent Events (SSE) 方式流式调用 Grok API
*
* API文档: https://docs.x.ai/api
*
* 响应格式:
* data: {"id":"xxx","object":"chat.completion.chunk","created":xxx,
* "model":"grok-4-1-fast-non-reasoning",
* "choices":[{"index":0,"delta":{"content":"token","role":"assistant"}}],
* "system_fingerprint":"xxx"}
* data: [DONE]
*/
@Service
public class GrokStreamService {
private static final Logger logger = LoggerFactory.getLogger(GrokStreamService.class);
private static final Gson gson = new Gson();
// SSE数据前缀
private static final String SSE_DATA_PREFIX = "data: ";
// SSE流结束标记
private static final String SSE_DONE_FLAG = "[DONE]";
// 默认温度参数
private static final double DEFAULT_TEMPERATURE = 0.7;
// 默认最大token数如果经常被截断可以增加这个值
private static final int DEFAULT_MAX_TOKENS = 4000;
@Resource
private VoiceStreamConfig voiceStreamConfig;
private WebClient webClient;
/**
* 流式调用 Grok API
*
* @param userMessage 用户消息
* @param systemPrompt 系统提示词(可选)
* @param callback 接收token的回调
*/
public void streamChat(String userMessage, String systemPrompt, TokenCallback callback) {
try {
// 参数验证
if (userMessage == null || userMessage.trim().isEmpty()) {
logger.warn("用户消息为空跳过API调用");
callback.onError("用户消息不能为空");
return;
}
if (webClient == null) {
initWebClient();
}
// 构建消息列表
List<Map<String, String>> messages = new ArrayList<>();
if (systemPrompt != null && !systemPrompt.isEmpty()) {
messages.add(Map.of("role", "system", "content", systemPrompt));
}
messages.add(Map.of("role", "user", "content", userMessage));
// 构建请求体
Map<String, Object> requestBody = new HashMap<>();
requestBody.put("model", voiceStreamConfig.getGrok().getModel());
requestBody.put("messages", messages);
requestBody.put("stream", true);
requestBody.put("temperature", DEFAULT_TEMPERATURE);
requestBody.put("max_tokens", DEFAULT_MAX_TOKENS); // 限制最大token数
String requestJson = gson.toJson(requestBody);
logger.info("开始调用Grok API - Model: {}, UserMessage: {}",
voiceStreamConfig.getGrok().getModel(),
userMessage.length() > 50 ? userMessage.substring(0, 50) + "..." : userMessage);
logger.debug("请求体: {}", requestJson);
// 发起流式请求
webClient.post()
.uri(voiceStreamConfig.getGrok().getApiUrl())
.header("Authorization", "Bearer " + voiceStreamConfig.getGrok().getApiKey())
.header("Content-Type", "application/json")
.bodyValue(requestJson)
.retrieve()
.bodyToFlux(String.class)
.flatMap(this::parseSSE)
.doOnNext(token -> {
if (token != null && !token.isEmpty()) {
callback.onToken(token);
}
})
.doOnComplete(() -> {
logger.info("Grok API流式调用完成");
callback.onComplete();
})
.doOnError(error -> {
String errorMsg = "LLM调用失败: " + error.getMessage();
logger.error(errorMsg, error);
callback.onError(errorMsg);
})
.subscribe();
} catch (Exception e) {
String errorMsg = "调用Grok API时发生异常: " + e.getMessage();
logger.error(errorMsg, e);
callback.onError(errorMsg);
}
}
/**
* 流式调用 Grok API带历史记录
*
* @param messages 完整的消息列表(包含系统提示、历史对话和当前用户消息)
* @param callback 接收token的回调
*/
public void streamChatWithHistory(List<Map<String, String>> messages, TokenCallback callback) {
try {
// 参数验证
if (messages == null || messages.isEmpty()) {
logger.warn("消息列表为空跳过API调用");
callback.onError("消息列表不能为空");
return;
}
if (webClient == null) {
initWebClient();
}
// 构建请求体
Map<String, Object> requestBody = new HashMap<>();
requestBody.put("model", voiceStreamConfig.getGrok().getModel());
requestBody.put("messages", messages);
requestBody.put("stream", true);
requestBody.put("temperature", DEFAULT_TEMPERATURE);
requestBody.put("max_tokens", DEFAULT_MAX_TOKENS); // 限制最大token数
String requestJson = gson.toJson(requestBody);
logger.info("开始调用Grok API带历史记录 - Model: {}, 消息数量: {}",
voiceStreamConfig.getGrok().getModel(), messages.size());
logger.debug("请求体: {}", requestJson);
// 发起流式请求
webClient.post()
.uri(voiceStreamConfig.getGrok().getApiUrl())
.header("Authorization", "Bearer " + voiceStreamConfig.getGrok().getApiKey())
.header("Content-Type", "application/json")
.bodyValue(requestJson)
.retrieve()
.bodyToFlux(String.class)
.flatMap(this::parseSSE)
.doOnNext(token -> {
if (token != null && !token.isEmpty()) {
callback.onToken(token);
}
})
.doOnComplete(() -> {
logger.info("Grok API流式调用完成带历史记录");
callback.onComplete();
})
.doOnError(error -> {
String errorMsg = "LLM调用失败: " + error.getMessage();
logger.error(errorMsg, error);
callback.onError(errorMsg);
})
.subscribe();
} catch (Exception e) {
String errorMsg = "调用Grok API时发生异常: " + e.getMessage();
logger.error(errorMsg, e);
callback.onError(errorMsg);
}
}
/**
* 解析SSE格式的响应
*
* 数据格式示例WebClient已自动去除 "data: " 前缀):
* {"id":"xxx","object":"chat.completion.chunk","created":xxx,"model":"grok-4-1-fast-non-reasoning",
* "choices":[{"index":0,"delta":{"content":"你","role":"assistant"}}],"system_fingerprint":"xxx"}
*/
private Flux<String> parseSSE(String line) {
return Flux.create(sink -> {
try {
// 跳过空行
if (line == null || line.trim().isEmpty()) {
sink.complete();
return;
}
String data = line.trim();
// 记录原始数据仅在trace级别
logger.trace("收到原始SSE数据: {}", data.length() > 200 ? data.substring(0, 200) + "..." : data);
// 跳过 [DONE] 标记(流结束标志)
// 可能是 "data: [DONE]" 或直接 "[DONE]"
if (SSE_DONE_FLAG.equals(data) || data.equals(SSE_DATA_PREFIX + SSE_DONE_FLAG)) {
logger.debug("收到流结束标记");
sink.complete();
return;
}
// 如果有 "data: " 前缀,去掉它
if (data.startsWith(SSE_DATA_PREFIX)) {
data = data.substring(SSE_DATA_PREFIX.length()).trim();
}
// 再次检查是否是 [DONE]
if (SSE_DONE_FLAG.equals(data)) {
logger.debug("收到流结束标记");
sink.complete();
return;
}
// 解析JSON现在data直接就是JSON字符串
JsonObject json = JsonParser.parseString(data).getAsJsonObject();
// 验证基本字段
if (!json.has("choices")) {
logger.debug("JSON中缺少choices字段");
sink.complete();
return;
}
// 获取choices数组
var choices = json.getAsJsonArray("choices");
if (choices.isEmpty()) {
logger.debug("choices数组为空");
sink.complete();
return;
}
// 获取第一个choice
var choice = choices.get(0).getAsJsonObject();
// 检查是否有delta字段
if (!choice.has("delta")) {
logger.debug("choice中缺少delta字段");
sink.complete();
return;
}
var delta = choice.getAsJsonObject("delta");
// 提取content内容
if (delta.has("content")) {
String content = delta.get("content").getAsString();
// 只有非空内容才发送
if (content != null && !content.isEmpty()) {
logger.trace("提取到token: {}", content);
sink.next(content);
}
}
// 如果有role字段可以记录日志首次消息会包含role
if (delta.has("role")) {
String role = delta.get("role").getAsString();
logger.debug("收到角色信息: {}", role);
}
// 检查finish_reason流结束原因
if (choice.has("finish_reason") && !choice.get("finish_reason").isJsonNull()) {
String finishReason = choice.get("finish_reason").getAsString();
logger.info("流结束原因: {}", finishReason);
// finish_reason 可能的值: "stop"(正常结束), "length"(达到最大长度), "content_filter"(内容过滤)
}
sink.complete();
} catch (com.google.gson.JsonSyntaxException e) {
logger.error("JSON解析失败 - 原始数据: {}", line, e);
sink.complete();
} catch (Exception e) {
logger.error("解析SSE数据时发生未知错误 - 原始数据: {}", line, e);
sink.complete();
}
});
}
/**
* 初始化WebClient
*/
private void initWebClient() {
String apiUrl = voiceStreamConfig.getGrok().getApiUrl();
if (apiUrl == null || apiUrl.trim().isEmpty()) {
throw new IllegalStateException("Grok API URL未配置");
}
logger.info("初始化Grok WebClient - URL: {}", apiUrl);
this.webClient = WebClient.builder()
.baseUrl(apiUrl)
// 增加缓冲区大小,支持更大的响应
.codecs(configurer -> configurer.defaultCodecs().maxInMemorySize(10 * 1024 * 1024)) // 10MB
.build();
logger.info("Grok WebClient初始化完成");
}
/**
* Token回调接口
*/
public interface TokenCallback {
/**
* 接收到新的token
*/
void onToken(String token);
/**
* 流式输出完成
*/
void onComplete();
/**
* 发生错误
*/
void onError(String error);
}
}
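The parsing flow in `parseSSE` (skip blanks → handle `[DONE]` → strip an optional `data: ` prefix → pull `choices[0].delta.content`) can be sketched in a few lines of Python; the field names come from the chunk format documented above:

```python
import json

SSE_DATA_PREFIX = "data: "
SSE_DONE_FLAG = "[DONE]"


def extract_token(line):
    """Return the delta content from one SSE line, or None if there is none."""
    data = (line or "").strip()
    if not data:
        return None
    # WebClient usually strips the prefix already; tolerate both forms
    if data.startswith(SSE_DATA_PREFIX):
        data = data[len(SSE_DATA_PREFIX):].strip()
    if data == SSE_DONE_FLAG:
        return None
    try:
        obj = json.loads(data)
    except json.JSONDecodeError:
        return None
    choices = obj.get("choices") or []
    if not choices:
        return None
    return choices[0].get("delta", {}).get("content") or None
```

The tolerant prefix handling mirrors the Java code's double `[DONE]` check: it costs nothing and keeps the parser working whether or not the HTTP client pre-strips SSE framing.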


@@ -0,0 +1,137 @@
package com.xiaozhi.dialogue.llm;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import java.util.regex.Pattern;
/**
* 分句缓冲服务
* 将LLM流式输出的token缓冲并按句子分割
*/
@Service
public class SentenceBufferService {
private static final Logger logger = LoggerFactory.getLogger(SentenceBufferService.class);
// 中英文句子结束标点符号
private static final Pattern SENTENCE_END_PATTERN = Pattern.compile("[。!?!?;]");
// 强制分句的最大长度
private static final int MAX_SENTENCE_LENGTH = 50;
/**
* 创建一个新的句子缓冲器
*/
public SentenceBuffer createBuffer(SentenceCallback callback) {
return new SentenceBuffer(callback);
}
/**
* 句子缓冲器
*/
public class SentenceBuffer {
private final StringBuilder buffer = new StringBuilder();
private final SentenceCallback callback;
public SentenceBuffer(SentenceCallback callback) {
this.callback = callback;
}
/**
* 添加token到缓冲区
*/
public synchronized void addToken(String token) {
buffer.append(token);
// 检查是否有完整的句子
String currentBuffer = buffer.toString();
int lastSentenceEnd = findLastSentenceEnd(currentBuffer);
if (lastSentenceEnd > 0) {
// 找到句子结束符,提取完整句子
String sentence = currentBuffer.substring(0, lastSentenceEnd + 1).trim();
if (!sentence.isEmpty()) {
logger.debug("检测到完整句子: {}", sentence);
callback.onSentence(sentence);
}
// 保留剩余部分
buffer.setLength(0);
if (lastSentenceEnd + 1 < currentBuffer.length()) {
buffer.append(currentBuffer.substring(lastSentenceEnd + 1));
}
} else if (buffer.length() > MAX_SENTENCE_LENGTH) {
// 超过最大长度,强制分句
String sentence = buffer.toString().trim();
if (!sentence.isEmpty()) {
logger.debug("强制分句(长度超限): {}", sentence);
callback.onSentence(sentence);
}
buffer.setLength(0);
}
}
/**
* 完成输入,输出剩余内容
*/
public synchronized void finish() {
if (buffer.length() > 0) {
String sentence = buffer.toString().trim();
if (!sentence.isEmpty()) {
logger.debug("输出最后剩余内容: {}", sentence);
callback.onSentence(sentence);
}
buffer.setLength(0);
}
callback.onComplete();
}
/**
* 清空缓冲区
*/
public synchronized void clear() {
buffer.setLength(0);
}
/**
* 查找最后一个句子结束符的位置
*/
private int findLastSentenceEnd(String text) {
int lastPos = -1;
for (int i = text.length() - 1; i >= 0; i--) {
char ch = text.charAt(i);
if (isSentenceEndChar(ch)) {
lastPos = i;
break;
}
}
return lastPos;
}
/**
* 判断是否为句子结束符
*/
private boolean isSentenceEndChar(char ch) {
return ch == '。' || ch == '!' || ch == '?' ||
ch == '.' || ch == '!' || ch == '?' ||
ch == ';' || ch == ';';
}
}
/**
* 句子回调接口
*/
public interface SentenceCallback {
/**
* 检测到完整句子
*/
void onSentence(String sentence);
/**
* 所有句子处理完成
*/
void onComplete();
}
}
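The buffering rule above (cut at the last sentence-ending punctuation, or force a cut once the buffer exceeds 50 characters, then flush whatever remains) can be illustrated with a small generator; the punctuation set mirrors `isSentenceEndChar`:

```python
# Full-width and ASCII sentence terminators, matching isSentenceEndChar
SENTENCE_END_CHARS = set("。!?.!?;;")
MAX_SENTENCE_LENGTH = 50


def split_sentences(tokens):
    """Yield sentences from a token stream, mirroring SentenceBuffer's rules."""
    buf = ""
    for token in tokens:
        buf += token
        # position of the last sentence-ending character, -1 if none
        last = max((buf.rfind(ch) for ch in SENTENCE_END_CHARS), default=-1)
        if last > 0:  # same "> 0" boundary condition as the Java code
            sentence = buf[: last + 1].strip()
            if sentence:
                yield sentence
            buf = buf[last + 1:]
        elif len(buf) > MAX_SENTENCE_LENGTH:
            sentence = buf.strip()
            if sentence:
                yield sentence
            buf = ""
    tail = buf.strip()
    if tail:
        yield tail
```

Note that, like the Java version, a terminator at index 0 does not trigger a split; this avoids emitting a lone punctuation mark when a token stream starts with one.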


@@ -30,7 +30,7 @@ public class SttServiceFactory {
private final Map<String, SttService> serviceCache = new ConcurrentHashMap<>();
// 默认服务提供商名称
- private static final String DEFAULT_PROVIDER = "vosk";
+ private static final String DEFAULT_PROVIDER = "aliyun";
// 标记Vosk是否初始化成功
private boolean voskInitialized = false;


@@ -40,7 +40,7 @@ public class AliyunSttService implements SttService {
// private final String modelName;
public AliyunSttService(SysConfig config) {
- this.apiKey = config.getApiKey();
+ this.apiKey = "sk-f4c91752a2604845b2956265863f941d"; // 警告API Key 被硬编码,存在泄露风险,建议改回从配置读取
// // 从配置中获取模型名称,如果没有配置则使用默认值
// this.modelName = (config.getConfigName() != null && !config.getConfigName().trim().isEmpty())
// ? config.getConfigName().trim()


@@ -0,0 +1,548 @@
package com.xiaozhi.dialogue.tts;
import com.google.gson.Gson;
import com.xiaozhi.config.VoiceStreamConfig;
import com.xiaozhi.dialogue.tts.TtsConnectionState.ConnectionStatus;
import jakarta.annotation.Resource;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.handshake.ServerHandshake;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import java.net.URI;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* MiniMax TTS 流式服务
* 基于WebSocket实现流式文本转语音支持连接预热和复用
*/
@Service
public class MinimaxTtsStreamService {
private static final Logger logger = LoggerFactory.getLogger(MinimaxTtsStreamService.class);
private static final Gson gson = new Gson();
private static final int CONNECTION_TIMEOUT_SECONDS = 5;
@Resource
private VoiceStreamConfig voiceStreamConfig;
// 管理活跃的WebSocket连接和状态
private final Map<String, SessionTtsConnection> connections = new ConcurrentHashMap<>();
/**
* 会话TTS连接封装
*/
private static class SessionTtsConnection {
final MinimaxTtsClient client;
final TtsConnectionState state;
final Semaphore processingLock = new Semaphore(1); // 确保串行处理
SessionTtsConnection(MinimaxTtsClient client, TtsConnectionState state) {
this.client = client;
this.state = state;
}
}
/**
* 预热TTS连接
* 建立WebSocket连接并发送task_start等待连接就绪
*
* @param sessionId 会话ID
* @return CompletableFuture 连接就绪后完成
*/
public CompletableFuture<Void> warmupConnection(String sessionId) {
long startTime = System.currentTimeMillis();
logger.info("开始预热TTS连接 - SessionId: {}", sessionId);
CompletableFuture<Void> warmupFuture = new CompletableFuture<>();
try {
// 检查是否已有连接
SessionTtsConnection existing = connections.get(sessionId);
if (existing != null && existing.state.isConnected()) {
logger.info("TTS连接已存在跳过预热 - SessionId: {}, 状态: {}",
sessionId, existing.state.getStatus());
warmupFuture.complete(null);
return warmupFuture;
}
// 创建新连接
String wsUrl = voiceStreamConfig.getMinimax().getWsUrl();
TtsConnectionState state = new TtsConnectionState();
state.setStatus(ConnectionStatus.CONNECTING);
MinimaxTtsClient client = new MinimaxTtsClient(
new URI(wsUrl),
sessionId,
state,
warmupFuture
);
SessionTtsConnection connection = new SessionTtsConnection(client, state);
connections.put(sessionId, connection);
// 异步连接
CompletableFuture.runAsync(() -> {
try {
boolean connected = client.connectBlocking(CONNECTION_TIMEOUT_SECONDS, TimeUnit.SECONDS);
if (!connected) {
throw new TimeoutException("连接超时");
}
long duration = System.currentTimeMillis() - startTime;
logger.info("TTS连接预热成功 - SessionId: {}, 耗时: {}ms", sessionId, duration);
} catch (Exception e) {
logger.error("TTS连接预热失败 - SessionId: {}", sessionId, e);
state.setStatus(ConnectionStatus.DISCONNECTED);
connections.remove(sessionId);
warmupFuture.completeExceptionally(e);
}
});
} catch (Exception e) {
logger.error("创建TTS连接失败 - SessionId: {}", sessionId, e);
warmupFuture.completeExceptionally(e);
}
return warmupFuture;
}
/**
* 流式合成语音(复用现有连接)
*
* @param sessionId 会话ID
* @param text 要合成的文本
* @param callback 音频数据回调
*/
public CompletableFuture<Void> streamTts(String sessionId, String text, AudioCallback callback) {
long startTime = System.currentTimeMillis();
CompletableFuture<Void> future = new CompletableFuture<>();
try {
SessionTtsConnection connection = connections.get(sessionId);
// 检查连接是否存在且就绪
if (connection == null || !connection.state.isReady()) {
logger.warn("TTS连接不存在或未就绪 - SessionId: {}, 状态: {}",
sessionId, connection != null ? connection.state.getStatus() : "NULL");
// 尝试建立新连接
return warmupConnection(sessionId)
.thenCompose(v -> streamTts(sessionId, text, callback));
}
// 获取处理锁,确保串行处理
connection.processingLock.acquire();
try {
logger.info("使用已有TTS连接 - SessionId: {}, Text: {}", sessionId, text);
connection.state.setStatus(ConnectionStatus.PROCESSING);
// 发送文本进行TTS
connection.client.sendText(text, callback, future);
// 等待TTS完成后释放锁
future.whenComplete((result, error) -> {
connection.processingLock.release();
if (error == null) {
long duration = System.currentTimeMillis() - startTime;
logger.info("TTS完成 - SessionId: {}, 耗时: {}ms复用连接节省约1秒",
sessionId, duration);
connection.state.setStatus(ConnectionStatus.IDLE);
} else {
logger.error("TTS失败 - SessionId: {}", sessionId, error);
connection.state.setStatus(ConnectionStatus.IDLE);
}
});
} catch (Exception e) {
connection.processingLock.release();
throw e;
}
} catch (Exception e) {
logger.error("TTS处理失败 - SessionId: {}", sessionId, e);
callback.onError("TTS处理失败: " + e.getMessage());
future.completeExceptionally(e);
}
return future;
}
/**
* 关闭指定会话的TTS连接
*/
public void closeConnection(String sessionId) {
SessionTtsConnection connection = connections.remove(sessionId);
if (connection != null) {
try {
connection.client.closeConnection();
logger.info("关闭TTS连接 - SessionId: {}", sessionId);
} catch (Exception e) {
logger.error("关闭TTS连接失败 - SessionId: {}", sessionId, e);
}
}
}
/**
* 取消指定会话的TTS兼容旧接口
*/
public void cancelTts(String sessionId) {
closeConnection(sessionId);
}
/**
* MiniMax TTS WebSocket 客户端(支持连接复用)
*/
private class MinimaxTtsClient extends WebSocketClient {
private final String sessionId;
private final TtsConnectionState state;
private final CompletableFuture<Void> warmupFuture;
// 当前任务相关
private volatile String currentText;
private volatile AudioCallback currentCallback;
private volatile CompletableFuture<Void> currentTaskFuture;
private final AtomicBoolean taskProcessing = new AtomicBoolean(false);
// 音频缓冲器 - 缓冲整句音频
private java.io.ByteArrayOutputStream audioBuffer;
public MinimaxTtsClient(URI serverUri, String sessionId, TtsConnectionState state,
CompletableFuture<Void> warmupFuture) {
super(serverUri);
this.sessionId = sessionId;
this.state = state;
this.warmupFuture = warmupFuture;
// 添加Authorization header
this.addHeader("Authorization", "Bearer " + voiceStreamConfig.getMinimax().getApiKey());
}
/**
* 发送文本进行TTS合成
* 注意调用此方法前外层应该已经检查过状态并设置为PROCESSING
*/
public void sendText(String text, AudioCallback callback, CompletableFuture<Void> taskFuture) {
this.currentText = text;
this.currentCallback = callback;
this.currentTaskFuture = taskFuture;
this.audioBuffer = new java.io.ByteArrayOutputStream();
this.taskProcessing.set(true);
// 直接发送task_continue
sendTaskContinue(text);
}
/**
* 关闭连接
*/
public void closeConnection() {
try {
if (isOpen()) {
sendTaskFinish();
}
close();
state.setStatus(ConnectionStatus.DISCONNECTED);
} catch (Exception e) {
logger.error("关闭连接失败", e);
}
}
@Override
public void onOpen(ServerHandshake handshake) {
logger.debug("MiniMax TTS连接已建立 - SessionId: {}", sessionId);
state.setStatus(ConnectionStatus.CONNECTED);
// 连接建立后,等待收到 connected_success 事件再发送 task_start
}
@Override
public void onMessage(String message) {
try {
logger.debug("收到消息 - SessionId: {}: {}", sessionId, message);
@SuppressWarnings("unchecked")
Map<String, Object> response = gson.fromJson(message, Map.class);
String event = (String) response.get("event");
if ("connected_success".equals(event)) {
// 连接成功,发送 task_start不包含text
logger.info("收到connected_success发送task_start - SessionId: {}", sessionId);
sendTaskStart();
} else if ("task_started".equals(event)) {
// 任务启动成功,连接就绪
logger.info("收到task_started连接就绪 - SessionId: {}", sessionId);
state.setStatus(ConnectionStatus.TASK_STARTED);
state.resetReconnectAttempts(); // 连接成功,重置重连计数
// 完成预热future
if (warmupFuture != null && !warmupFuture.isDone()) {
warmupFuture.complete(null);
}
// 如果是复用连接后立即发送的task_continue状态会自动设为IDLE
if (!taskProcessing.get()) {
state.setStatus(ConnectionStatus.IDLE);
}
} else if ("task_failed".equals(event)) {
// 任务失败
@SuppressWarnings("unchecked")
Map<String, Object> baseResp = (Map<String, Object>) response.get("base_resp");
String errorMsg = baseResp != null ? (String) baseResp.get("status_msg") : "unknown error";
logger.error("TTS任务失败 - SessionId: {}: {}", sessionId, errorMsg);
if (currentCallback != null) {
currentCallback.onError("TTS任务失败: " + errorMsg);
}
if (currentTaskFuture != null && !currentTaskFuture.isDone()) {
currentTaskFuture.completeExceptionally(new RuntimeException(errorMsg));
}
taskProcessing.set(false);
state.setStatus(ConnectionStatus.IDLE);
} else if (response.containsKey("data")) {
// 音频数据
@SuppressWarnings("unchecked")
Map<String, Object> data = (Map<String, Object>) response.get("data");
if (data != null && data.containsKey("audio")) {
String hexAudio = (String) data.get("audio");
if (hexAudio != null && !hexAudio.isEmpty() && audioBuffer != null) {
// 将hex字符串转为bytes并写入缓冲区
byte[] audioBytes = hexStringToByteArray(hexAudio);
try {
audioBuffer.write(audioBytes);
logger.debug("缓冲音频块 - SessionId: {}: {} bytes, 总计: {} bytes",
sessionId, audioBytes.length, audioBuffer.size());
} catch (java.io.IOException e) {
logger.error("写入音频缓冲区失败 - SessionId: {}", sessionId, e);
}
}
}
// 检查是否完成
Boolean isFinal = (Boolean) response.get("is_final");
if (Boolean.TRUE.equals(isFinal) && taskProcessing.compareAndSet(true, false)) {
logger.debug("TTS任务完成is_final=true- SessionId: {}", sessionId);
// 一次性发送完整的音频数据
if (audioBuffer != null && currentCallback != null) {
byte[] completeAudio = audioBuffer.toByteArray();
logger.info("TTS完成 - SessionId: {}, Text: {}, 音频总大小: {} bytes",
sessionId, currentText, completeAudio.length);
if (completeAudio.length > 0) {
currentCallback.onAudioChunk(completeAudio);
}
currentCallback.onComplete();
}
// 完成当前任务future
if (currentTaskFuture != null && !currentTaskFuture.isDone()) {
currentTaskFuture.complete(null);
}
// 不关闭连接设置为IDLE状态以便复用
state.setStatus(ConnectionStatus.IDLE);
// 清理当前任务
currentText = null;
currentCallback = null;
currentTaskFuture = null;
audioBuffer = null;
}
}
} catch (Exception e) {
logger.error("处理TTS消息失败 - SessionId: {}", sessionId, e);
if (currentCallback != null) {
currentCallback.onError("处理TTS响应失败: " + e.getMessage());
}
}
}
/**
* 发送 task_start 消息不包含text
*/
private void sendTaskStart() {
try {
Map<String, Object> message = new HashMap<>();
message.put("event", "task_start");
message.put("model", voiceStreamConfig.getMinimax().getModel());
// voice_setting
Map<String, Object> voiceSetting = new HashMap<>();
voiceSetting.put("voice_id", voiceStreamConfig.getMinimax().getVoiceId());
voiceSetting.put("speed", voiceStreamConfig.getMinimax().getSpeed());
voiceSetting.put("vol", voiceStreamConfig.getMinimax().getVol());
voiceSetting.put("pitch", voiceStreamConfig.getMinimax().getPitch());
message.put("voice_setting", voiceSetting);
// audio_setting
Map<String, Object> audioSetting = new HashMap<>();
audioSetting.put("sample_rate", voiceStreamConfig.getMinimax().getAudioSampleRate());
audioSetting.put("bitrate", voiceStreamConfig.getMinimax().getBitrate());
audioSetting.put("format", "mp3"); // 使用mp3格式兼容性更好
audioSetting.put("channel", 1);
message.put("audio_setting", audioSetting);
String jsonMessage = gson.toJson(message);
logger.debug("发送task_start - SessionId: {}: {}", sessionId, jsonMessage);
send(jsonMessage);
} catch (Exception e) {
logger.error("发送task_start失败 - SessionId: {}", sessionId, e);
if (warmupFuture != null && !warmupFuture.isDone()) {
warmupFuture.completeExceptionally(e);
}
}
}
/**
* 发送 task_continue 消息包含text
*/
private void sendTaskContinue(String text) {
try {
Map<String, Object> message = new HashMap<>();
message.put("event", "task_continue");
message.put("text", text);
String jsonMessage = gson.toJson(message);
logger.debug("发送task_continue - SessionId: {}: {}", sessionId, jsonMessage);
send(jsonMessage);
} catch (Exception e) {
logger.error("发送task_continue失败 - SessionId: {}", sessionId, e);
if (currentCallback != null) {
currentCallback.onError("发送文本失败: " + e.getMessage());
}
if (currentTaskFuture != null && !currentTaskFuture.isDone()) {
currentTaskFuture.completeExceptionally(e);
}
}
}
/**
* 发送 task_finish 消息
*/
private void sendTaskFinish() {
try {
Map<String, Object> message = new HashMap<>();
message.put("event", "task_finish");
String jsonMessage = gson.toJson(message);
logger.debug("发送task_finish - SessionId: {}: {}", sessionId, jsonMessage);
send(jsonMessage);
} catch (Exception e) {
logger.error("发送task_finish失败 - SessionId: {}", sessionId, e);
}
}
@Override
public void onClose(int code, String reason, boolean remote) {
logger.warn("MiniMax TTS连接关闭 - SessionId: {}, Code: {}, Reason: {}, Remote: {}",
sessionId, code, reason, remote);
ConnectionStatus prevStatus = state.getStatus();
state.setStatus(ConnectionStatus.DISCONNECTED);
// 判断是否需要重连
boolean needReconnect = remote && state.canReconnect() &&
(prevStatus == ConnectionStatus.CONNECTED ||
prevStatus == ConnectionStatus.TASK_STARTED ||
prevStatus == ConnectionStatus.IDLE);
if (needReconnect) {
// 自动重连
int attempts = state.incrementReconnectAttempts();
logger.info("尝试自动重连TTS - SessionId: {}, 第{}次重连", sessionId, attempts);
CompletableFuture.runAsync(() -> {
try {
Thread.sleep(1000 * attempts); // 延迟递增
warmupConnection(sessionId)
.exceptionally(ex -> {
logger.error("自动重连失败 - SessionId: {}", sessionId, ex);
return null;
});
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
});
} else {
// 不重连,通知错误
if (taskProcessing.get()) {
if (currentCallback != null) {
currentCallback.onComplete(); // 标记完成,避免阻塞
}
if (currentTaskFuture != null && !currentTaskFuture.isDone()) {
currentTaskFuture.complete(null);
}
taskProcessing.set(false);
}
// 如果是预热阶段失败
if (warmupFuture != null && !warmupFuture.isDone()) {
warmupFuture.completeExceptionally(new RuntimeException("连接关闭: " + reason));
}
}
}
@Override
public void onError(Exception ex) {
logger.error("MiniMax TTS连接错误 - SessionId: {}", sessionId, ex);
state.setStatus(ConnectionStatus.DISCONNECTED);
if (currentCallback != null) {
currentCallback.onError("TTS连接错误: " + ex.getMessage());
}
if (currentTaskFuture != null && !currentTaskFuture.isDone()) {
currentTaskFuture.completeExceptionally(ex);
}
if (warmupFuture != null && !warmupFuture.isDone()) {
warmupFuture.completeExceptionally(ex);
}
}
/**
 * Convert a hex string into a byte array (MiniMax returns audio chunks hex-encoded).
 */
private byte[] hexStringToByteArray(String hex) {
int len = hex.length();
if (len % 2 != 0) {
throw new IllegalArgumentException("Hex string length must be even: " + len);
}
byte[] data = new byte[len / 2];
for (int i = 0; i < len; i += 2) {
data[i / 2] = (byte) ((Character.digit(hex.charAt(i), 16) << 4)
+ Character.digit(hex.charAt(i + 1), 16));
}
return data;
}
}
/**
* 音频数据回调接口
*/
public interface AudioCallback {
/**
* 接收到音频数据块
*/
void onAudioChunk(byte[] audioData);
/**
* TTS完成
*/
void onComplete();
/**
* 发生错误
*/
void onError(String error);
}
}
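MiniMax delivers each audio chunk as a hex-encoded string, which `hexStringToByteArray` decodes before buffering. The decoding step can be sketched in isolation as follows (class name and the sample input are illustrative only, not part of the project; real payloads would be MP3 frames):

```java
import java.nio.charset.StandardCharsets;

public class HexDecodeSketch {
    // Mirrors the client's hexStringToByteArray: two hex characters per output byte.
    static byte[] hexStringToByteArray(String hex) {
        byte[] data = new byte[hex.length() / 2];
        for (int i = 0; i < hex.length(); i += 2) {
            data[i / 2] = (byte) ((Character.digit(hex.charAt(i), 16) << 4)
                    + Character.digit(hex.charAt(i + 1), 16));
        }
        return data;
    }

    public static void main(String[] args) {
        // "48656c6c6f" is the ASCII encoding of "Hello" in hex.
        byte[] out = hexStringToByteArray("48656c6c6f");
        System.out.println(new String(out, StandardCharsets.US_ASCII)); // Hello
    }
}
```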


@@ -0,0 +1,136 @@
package com.xiaozhi.dialogue.tts;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
/**
 * TTS connection state management.
 * Tracks connection status, reconnect attempts, and last-active time.
 */
public class TtsConnectionState {
private static final int MAX_RECONNECT_ATTEMPTS = 3;
private final AtomicReference<ConnectionStatus> status;
private final AtomicInteger reconnectAttempts;
private volatile long lastActiveTime;
public TtsConnectionState() {
this.status = new AtomicReference<>(ConnectionStatus.DISCONNECTED);
this.reconnectAttempts = new AtomicInteger(0);
this.lastActiveTime = System.currentTimeMillis();
}
/**
* 连接状态枚举
*/
public enum ConnectionStatus {
DISCONNECTED, // 未连接
CONNECTING, // 连接中
CONNECTED, // 已连接未完成task_start
TASK_STARTED, // task_start已完成连接就绪
PROCESSING, // 正在处理TTS任务
IDLE // 空闲,可接受新任务
}
/**
* 获取当前状态
*/
public ConnectionStatus getStatus() {
return status.get();
}
/**
* 设置状态
*/
public void setStatus(ConnectionStatus newStatus) {
ConnectionStatus oldStatus = status.getAndSet(newStatus);
if (oldStatus != newStatus) {
updateLastActiveTime();
}
}
/**
* CAS方式更新状态
*/
public boolean compareAndSetStatus(ConnectionStatus expect, ConnectionStatus update) {
boolean updated = status.compareAndSet(expect, update);
if (updated) {
updateLastActiveTime();
}
return updated;
}
/**
* 检查是否可以处理新任务
*/
public boolean isReady() {
ConnectionStatus current = status.get();
return current == ConnectionStatus.IDLE || current == ConnectionStatus.TASK_STARTED;
}
/**
* 检查是否已连接
*/
public boolean isConnected() {
ConnectionStatus current = status.get();
return current != ConnectionStatus.DISCONNECTED && current != ConnectionStatus.CONNECTING;
}
/**
* 获取重连次数
*/
public int getReconnectAttempts() {
return reconnectAttempts.get();
}
/**
* 增加重连次数
*/
public int incrementReconnectAttempts() {
return reconnectAttempts.incrementAndGet();
}
/**
* 重置重连次数
*/
public void resetReconnectAttempts() {
reconnectAttempts.set(0);
}
/**
* 检查是否还可以重连
*/
public boolean canReconnect() {
return reconnectAttempts.get() < MAX_RECONNECT_ATTEMPTS;
}
/**
* 更新最后活跃时间
*/
public void updateLastActiveTime() {
this.lastActiveTime = System.currentTimeMillis();
}
/**
* 获取最后活跃时间
*/
public long getLastActiveTime() {
return lastActiveTime;
}
/**
* 重置状态
*/
public void reset() {
status.set(ConnectionStatus.DISCONNECTED);
reconnectAttempts.set(0);
updateLastActiveTime();
}
@Override
public String toString() {
return String.format("TtsConnectionState{status=%s, reconnectAttempts=%d, lastActiveTime=%d}",
status.get(), reconnectAttempts.get(), lastActiveTime);
}
}
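The `compareAndSetStatus` guard above is what lets concurrent callers race safely for a connection: exactly one wins the transition, the loser sees stale state and backs off. The pattern in isolation (a minimal self-contained sketch, not project code):

```java
import java.util.concurrent.atomic.AtomicReference;

public class StateCasSketch {
    // Same status values TtsConnectionState defines.
    enum Status { DISCONNECTED, CONNECTING, CONNECTED, TASK_STARTED, PROCESSING, IDLE }

    public static void main(String[] args) {
        AtomicReference<Status> status = new AtomicReference<>(Status.IDLE);

        // First caller wins the IDLE -> PROCESSING transition; the second CAS fails.
        boolean first = status.compareAndSet(Status.IDLE, Status.PROCESSING);
        boolean second = status.compareAndSet(Status.IDLE, Status.PROCESSING);

        System.out.println(first + " " + second + " " + status.get()); // true false PROCESSING
    }
}
```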


@@ -86,7 +86,6 @@ public class TtsServiceFactory {
case "volcengine" -> new VolcengineTtsService(config, voiceName, outputPath);
case "xfyun" -> new XfyunTtsService(config, voiceName, outputPath);
case "minimax" -> new MiniMaxTtsService(config, voiceName, outputPath);
case "minimax-ws" -> new MiniMaxTtsWebSocketService(config, voiceName, outputPath); // WebSocket 流式版本
default -> new EdgeTtsService(voiceName, outputPath);
};
}


@@ -0,0 +1,649 @@
package com.xiaozhi.service;
import com.xiaozhi.dialogue.llm.GrokStreamService;
import com.xiaozhi.dialogue.llm.SentenceBufferService;
import com.xiaozhi.dialogue.llm.memory.DatabaseChatMemory;
import com.xiaozhi.dialogue.stt.SttService;
import com.xiaozhi.dialogue.stt.factory.SttServiceFactory;
import com.xiaozhi.dialogue.tts.MinimaxTtsStreamService;
import com.xiaozhi.entity.SysConfig;
import com.xiaozhi.entity.SysMessage;
import com.xiaozhi.utils.AudioUtils;
import jakarta.annotation.Resource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
/**
 * Core service for streaming voice dialogue.
 * Chains STT -> LLM stream -> sentence segmentation -> TTS stream.
 */
@Service
public class VoiceStreamService {
private static final Logger logger = LoggerFactory.getLogger(VoiceStreamService.class);
@Resource
private SttServiceFactory sttServiceFactory;
@Resource
private GrokStreamService grokStreamService;
@Resource
private SentenceBufferService sentenceBufferService;
@Resource
private MinimaxTtsStreamService minimaxTtsStreamService;
@Resource
private SysConfigService configService;
@Resource
private DatabaseChatMemory chatMemory;
@Resource
private SysMessageService sysMessageService;
// 管理每个会话的状态
private final Map<String, SessionState> sessions = new ConcurrentHashMap<>();
// 线程池用于异步处理
private final ExecutorService executorService = Executors.newCachedThreadPool();
/**
* 处理音频流
*
* @param sessionId 会话ID
* @param audioData PCM音频数据
* @param callback 回调接口
*/
public void processAudioStream(String sessionId, byte[] audioData, StreamCallback callback) {
executorService.submit(() -> {
try {
// 获取或创建会话状态
SessionState state = sessions.computeIfAbsent(sessionId, k -> new SessionState());
// 取消之前的处理(打断机制)
state.cancel();
state.reset();
logger.info("开始处理音频流 - SessionId: {}, AudioSize: {}", sessionId, audioData.length);
// 1. STT - 语音转文字
String recognizedText = performStt(audioData);
if (recognizedText == null || recognizedText.trim().isEmpty()) {
callback.onError("语音识别失败或未识别到内容");
return;
}
logger.info("STT识别结果 - SessionId: {}, Text: {}", sessionId, recognizedText);
callback.onSttResult(recognizedText);
// 2. 创建分句缓冲器
SentenceBufferService.SentenceBuffer sentenceBuffer =
sentenceBufferService.createBuffer(new SentenceBufferService.SentenceCallback() {
private final Queue<String> sentenceQueue = new LinkedList<>();
private final Semaphore ttsPermit = new Semaphore(1); // 保证TTS顺序执行
private final AtomicBoolean processing = new AtomicBoolean(false);
@Override
public void onSentence(String sentence) {
if (state.isCancelled()) {
logger.info("会话已取消,跳过句子处理 - SessionId: {}", sessionId);
return;
}
logger.info("检测到完整句子 - SessionId: {}, Sentence: {}", sessionId, sentence);
callback.onSentenceComplete(sentence);
// 将句子加入队列并异步处理TTS
sentenceQueue.offer(sentence);
processSentenceQueue();
}
@Override
public void onComplete() {
logger.info("LLM输出完成 - SessionId: {}", sessionId);
// 等待所有TTS完成
waitForAllTtsComplete();
callback.onComplete();
}
private void processSentenceQueue() {
if (processing.compareAndSet(false, true)) {
executorService.submit(() -> {
try {
while (!sentenceQueue.isEmpty() && !state.isCancelled()) {
String sentence = sentenceQueue.poll();
if (sentence != null) {
processSentenceTts(sentence);
}
}
} finally {
processing.set(false);
// 检查是否还有新的句子
if (!sentenceQueue.isEmpty() && !state.isCancelled()) {
processSentenceQueue();
}
}
});
}
}
private void processSentenceTts(String sentence) {
if (state.isCancelled()) {
return;
}
try {
ttsPermit.acquire();
logger.info("开始TTS合成 - SessionId: {}, Sentence: {}", sessionId, sentence);
CompletableFuture<Void> ttsFuture = minimaxTtsStreamService.streamTts(
sessionId,
sentence,
new MinimaxTtsStreamService.AudioCallback() {
@Override
public void onAudioChunk(byte[] audioData) {
if (!state.isCancelled()) {
callback.onAudioChunk(audioData);
}
}
@Override
public void onComplete() {
logger.debug("句子TTS完成 - SessionId: {}", sessionId);
}
@Override
public void onError(String error) {
logger.error("TTS错误 - SessionId: {}, Error: {}", sessionId, error);
callback.onError(error);
}
}
);
// 等待TTS完成
ttsFuture.get(30, TimeUnit.SECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
logger.warn("TTS处理被中断 - SessionId: {}", sessionId);
state.markTtsFailed(); // 标记 TTS 失败
} catch (Exception e) {
logger.error("TTS处理失败 - SessionId: {}", sessionId, e);
state.markTtsFailed(); // 标记 TTS 失败
callback.onError("TTS处理失败: " + e.getMessage());
} finally {
ttsPermit.release();
}
}
private void waitForAllTtsComplete() {
try {
// 等待所有TTS完成
ttsPermit.acquire();
ttsPermit.release();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
});
// 3. LLM流式调用
String systemPrompt = "你是一个友好的AI助手请用简洁、自然的语气回答用户的问题。";
grokStreamService.streamChat(
recognizedText,
systemPrompt,
new GrokStreamService.TokenCallback() {
@Override
public void onToken(String token) {
if (state.isCancelled()) {
return;
}
callback.onLlmToken(token);
sentenceBuffer.addToken(token);
}
@Override
public void onComplete() {
if (!state.isCancelled()) {
sentenceBuffer.finish();
}
}
@Override
public void onError(String error) {
logger.error("LLM调用失败 - SessionId: {}, Error: {}", sessionId, error);
callback.onError(error);
}
}
);
} catch (Exception e) {
logger.error("处理音频流失败 - SessionId: {}", sessionId, e);
callback.onError("处理失败: " + e.getMessage());
}
});
}
/**
* 处理音频流(带历史记录)
*
* @param sessionId WebSocket会话ID
* @param audioData PCM音频数据
* @param chatSessionId 聊天会话ID用于查询和保存历史记录
* @param templateId 模板ID可选用于获取角色信息
* @param callback 回调接口
*/
public void processAudioStreamWithHistory(String sessionId, byte[] audioData,
String chatSessionId, Integer templateId,
StreamCallback callback) {
executorService.submit(() -> {
try {
// 获取或创建会话状态
SessionState state = sessions.computeIfAbsent(sessionId, k -> new SessionState());
// 取消之前的处理(打断机制)
state.cancel();
state.reset();
logger.info("开始处理音频流(带历史记录) - SessionId: {}, ChatSessionId: {}, TemplateId: {}, AudioSize: {}",
sessionId, chatSessionId, templateId, audioData.length);
// 使用 chatSessionId 作为 deviceId保持与文字对话一致
String deviceId = chatSessionId;
// roleId 可以为 null数据库允许
Integer roleId = templateId;
// 1. STT - 语音转文字
String recognizedText = performStt(audioData);
if (recognizedText == null || recognizedText.trim().isEmpty()) {
callback.onError("语音识别失败或未识别到内容");
return;
}
logger.info("STT识别结果 - SessionId: {}, Text: {}", sessionId, recognizedText);
callback.onSttResult(recognizedText);
// 2. 异步保存用户消息到数据库(不等待,保留 future 引用)
CompletableFuture<Void> userMessageFuture = chatMemory.addMessage(
deviceId, chatSessionId, "user", recognizedText,
roleId, "NORMAL", System.currentTimeMillis());
logger.info("用户消息开始异步保存 - ChatSessionId: {}", chatSessionId);
// 3. 查询历史记录最近20条
// 等待用户消息保存完成,确保能查询到最新数据
List<SysMessage> historyMessages = new ArrayList<>();
try {
userMessageFuture.join(); // 等待用户消息保存完成
logger.info("用户消息保存完成 - ChatSessionId: {}", chatSessionId);
List<SysMessage> allMessages = sysMessageService.queryBySessionId(chatSessionId);
// 获取最近20条消息不包括刚刚保存的用户消息
int startIndex = Math.max(0, allMessages.size() - 21); // -21 因为包含了刚保存的用户消息
historyMessages = allMessages.subList(startIndex, allMessages.size() - 1); // -1 排除最后一条(刚保存的)
logger.info("查询到历史消息 {} 条 - ChatSessionId: {}", historyMessages.size(), chatSessionId);
} catch (Exception e) {
logger.error("保存或查询历史消息失败,将继续处理(无历史上下文) - ChatSessionId: {}", chatSessionId, e);
}
// 4. 累积LLM完整响应用于保存到数据库
AtomicReference<String> fullResponse = new AtomicReference<>("");
// 5. 创建分句缓冲器
SentenceBufferService.SentenceBuffer sentenceBuffer =
sentenceBufferService.createBuffer(new SentenceBufferService.SentenceCallback() {
private final Queue<String> sentenceQueue = new LinkedList<>();
private final Semaphore ttsPermit = new Semaphore(1); // 保证TTS顺序执行
private final AtomicBoolean processing = new AtomicBoolean(false);
@Override
public void onSentence(String sentence) {
if (state.isCancelled()) {
logger.info("会话已取消,跳过句子处理 - SessionId: {}", sessionId);
return;
}
logger.info("检测到完整句子 - SessionId: {}, Sentence: {}", sessionId, sentence);
callback.onSentenceComplete(sentence);
// 将句子加入队列并异步处理TTS
sentenceQueue.offer(sentence);
processSentenceQueue();
}
@Override
public void onComplete() {
logger.info("LLM输出完成 - SessionId: {}", sessionId);
// 等待所有TTS完成
waitForAllTtsComplete();
// TTS 全部成功后才保存 AI 回复(如果未取消且未失败)
String aiResponse = fullResponse.get();
if (aiResponse != null && !aiResponse.trim().isEmpty()
&& !state.isCancelled() && !state.isTtsFailed()) {
// 使用链式调用确保在用户消息之后保存,并增加 1ms 保证顺序
userMessageFuture.thenCompose(v -> {
long timestamp = System.currentTimeMillis() + 1;
return chatMemory.addMessage(deviceId, chatSessionId, "assistant",
aiResponse, roleId, "NORMAL", timestamp);
}).thenAccept(v -> {
logger.info("AI响应已保存 - ChatSessionId: {}, Length: {}",
chatSessionId, aiResponse.length());
}).exceptionally(ex -> {
logger.error("保存AI响应失败 - ChatSessionId: {}", chatSessionId, ex);
return null;
});
} else if (state.isTtsFailed()) {
logger.warn("TTS失败不保存AI响应 - ChatSessionId: {}", chatSessionId);
} else if (state.isCancelled()) {
logger.info("会话已取消不保存AI响应 - ChatSessionId: {}", chatSessionId);
}
callback.onComplete();
}
private void processSentenceQueue() {
if (processing.compareAndSet(false, true)) {
executorService.submit(() -> {
try {
while (!sentenceQueue.isEmpty() && !state.isCancelled()) {
String sentence = sentenceQueue.poll();
if (sentence != null) {
processSentenceTts(sentence);
}
}
} finally {
processing.set(false);
// 检查是否还有新的句子
if (!sentenceQueue.isEmpty() && !state.isCancelled()) {
processSentenceQueue();
}
}
});
}
}
private void processSentenceTts(String sentence) {
if (state.isCancelled()) {
return;
}
try {
ttsPermit.acquire();
logger.info("开始TTS合成 - SessionId: {}, Sentence: {}", sessionId, sentence);
CompletableFuture<Void> ttsFuture = minimaxTtsStreamService.streamTts(
sessionId,
sentence,
new MinimaxTtsStreamService.AudioCallback() {
@Override
public void onAudioChunk(byte[] audioData) {
if (!state.isCancelled()) {
callback.onAudioChunk(audioData);
}
}
@Override
public void onComplete() {
logger.debug("句子TTS完成 - SessionId: {}", sessionId);
}
@Override
public void onError(String error) {
logger.error("TTS错误 - SessionId: {}, Error: {}", sessionId, error);
callback.onError(error);
}
}
);
// 等待TTS完成
ttsFuture.get(30, TimeUnit.SECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
logger.warn("TTS处理被中断 - SessionId: {}", sessionId);
state.markTtsFailed(); // 标记 TTS 失败
} catch (Exception e) {
logger.error("TTS处理失败 - SessionId: {}", sessionId, e);
state.markTtsFailed(); // 标记 TTS 失败
callback.onError("TTS处理失败: " + e.getMessage());
} finally {
ttsPermit.release();
}
}
private void waitForAllTtsComplete() {
try {
// 等待所有TTS完成
ttsPermit.acquire();
ttsPermit.release();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
});
// 6. 构建包含历史记录的消息列表
List<Map<String, String>> messages = buildMessagesWithHistory(historyMessages, recognizedText);
// 7. LLM流式调用带历史记录
grokStreamService.streamChatWithHistory(
messages,
new GrokStreamService.TokenCallback() {
@Override
public void onToken(String token) {
if (state.isCancelled()) {
return;
}
// 累积完整响应
fullResponse.updateAndGet(current -> current + token);
callback.onLlmToken(token);
sentenceBuffer.addToken(token);
}
@Override
public void onComplete() {
if (!state.isCancelled()) {
sentenceBuffer.finish();
}
}
@Override
public void onError(String error) {
logger.error("LLM调用失败 - SessionId: {}, Error: {}", sessionId, error);
callback.onError(error);
}
}
);
} catch (Exception e) {
logger.error("处理音频流失败(带历史记录) - SessionId: {}", sessionId, e);
callback.onError("处理失败: " + e.getMessage());
}
});
}
/**
* 构建包含历史记录的消息列表
*/
private List<Map<String, String>> buildMessagesWithHistory(List<SysMessage> historyMessages,
String currentUserMessage) {
List<Map<String, String>> messages = new ArrayList<>();
// 添加系统提示
Map<String, String> systemMessage = new HashMap<>();
systemMessage.put("role", "system");
systemMessage.put("content", "你是一个友好的AI助手请用简洁、自然的语气回答用户的问题。");
messages.add(systemMessage);
// 添加历史消息
for (SysMessage msg : historyMessages) {
Map<String, String> message = new HashMap<>();
// 将数据库中的 sender 转换为 Grok API 需要的角色
if ("user".equals(msg.getSender())) {
message.put("role", "user");
} else if ("assistant".equals(msg.getSender())) {
message.put("role", "assistant");
} else {
// 跳过不认识的角色
logger.warn("未知的消息发送方: {}", msg.getSender());
continue;
}
message.put("content", msg.getMessage());
messages.add(message);
}
// 添加当前用户消息
Map<String, String> currentMessage = new HashMap<>();
currentMessage.put("role", "user");
currentMessage.put("content", currentUserMessage);
messages.add(currentMessage);
logger.debug("构建消息列表完成,共 {} 条消息(含系统提示和历史)", messages.size());
return messages;
}
/**
* 取消指定会话的流处理
*/
public void cancelStream(String sessionId) {
SessionState state = sessions.get(sessionId);
if (state != null) {
state.cancel();
// 取消TTS
minimaxTtsStreamService.cancelTts(sessionId);
logger.info("取消流处理 - SessionId: {}", sessionId);
}
}
/**
* 预热TTS连接
*
* @param sessionId 会话ID
* @return CompletableFuture 连接就绪后完成
*/
public CompletableFuture<Void> warmupTtsConnection(String sessionId) {
logger.info("预热TTS连接 - SessionId: {}", sessionId);
return minimaxTtsStreamService.warmupConnection(sessionId);
}
/**
* 关闭TTS连接
*
* @param sessionId 会话ID
*/
public void closeTtsConnection(String sessionId) {
logger.info("关闭TTS连接 - SessionId: {}", sessionId);
minimaxTtsStreamService.closeConnection(sessionId);
}
/**
* 执行STT复用现有逻辑
*/
private String performStt(byte[] audioData) {
try {
// 标准化输入音频为 PCM 16k 单声道
byte[] pcmData = AudioUtils.bytesToPcm(audioData);
if (pcmData == null || pcmData.length == 0) {
throw new RuntimeException("音频数据为空或转换后为空");
}
// 获取STT服务使用默认配置
SttService sttService = sttServiceFactory.getSttService(null);
if (sttService == null) {
throw new RuntimeException("无法获取STT服务");
}
// 执行语音识别
String recognizedText = sttService.recognition(pcmData);
return recognizedText != null ? recognizedText.trim() : "";
} catch (IOException e) {
throw new RuntimeException("音频标准化失败: " + e.getMessage(), e);
} catch (Exception e) {
logger.error("STT处理失败: {}", e.getMessage(), e);
throw new RuntimeException("语音识别失败: " + e.getMessage(), e);
}
}
/**
* 会话状态
*/
private static class SessionState {
private final AtomicBoolean cancelled = new AtomicBoolean(false);
private final AtomicBoolean ttsFailed = new AtomicBoolean(false);
public void cancel() {
cancelled.set(true);
}
public void reset() {
cancelled.set(false);
ttsFailed.set(false); // 重置时也清除 TTS 失败标记
}
public boolean isCancelled() {
return cancelled.get();
}
public void markTtsFailed() {
ttsFailed.set(true);
}
public boolean isTtsFailed() {
return ttsFailed.get();
}
}
/**
* 流处理回调接口
*/
public interface StreamCallback {
/**
* STT识别结果
*/
void onSttResult(String text);
/**
* LLM输出token
*/
void onLlmToken(String token);
/**
* 完整句子
*/
void onSentenceComplete(String sentence);
/**
* TTS音频数据块
*/
void onAudioChunk(byte[] audioChunk);
/**
* 所有处理完成
*/
void onComplete();
/**
* 发生错误
*/
void onError(String error);
}
}
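Both processing paths above serialize per-sentence TTS with a single-permit `Semaphore` plus a `compareAndSet` drain flag, so sentences reach the TTS service in the order the LLM produced them even though work is submitted to a pool. The ordering guarantee can be seen in this stripped-down sketch (class, queue contents, and the `played` list standing in for the blocking TTS call are all illustrative):

```java
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

public class OrderedTtsSketch {
    static final Queue<String> queue = new ConcurrentLinkedQueue<>();
    static final Semaphore permit = new Semaphore(1);          // one TTS call in flight
    static final AtomicBoolean processing = new AtomicBoolean(false);
    static final List<String> played = new CopyOnWriteArrayList<>();
    static final ExecutorService pool = Executors.newCachedThreadPool();

    static void drain() {
        // compareAndSet ensures only one drainer runs at a time.
        if (processing.compareAndSet(false, true)) {
            pool.submit(() -> {
                try {
                    String s;
                    while ((s = queue.poll()) != null) {
                        permit.acquire();
                        try { played.add(s); }                 // stand-in for blocking TTS
                        finally { permit.release(); }
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    processing.set(false);
                    if (!queue.isEmpty()) drain();             // pick up late arrivals
                }
            });
        }
    }

    public static void main(String[] args) throws Exception {
        for (String s : new String[]{"s1", "s2", "s3"}) { queue.offer(s); drain(); }
        pool.shutdown();
        pool.awaitTermination(5, TimeUnit.SECONDS);
        System.out.println(played); // [s1, s2, s3]
    }
}
```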


@@ -182,7 +182,12 @@ public class ChatSessionServiceImpl implements ChatSessionService {
ChatSession session = getOrCreateSession(sessionId, request.getModelId(), request.getTemplateId());
// 1. STT - 语音转文本
long sttStartTime = System.currentTimeMillis();
Map<String, Object> sttResult = performStt(request);
long sttEndTime = System.currentTimeMillis();
long sttDuration = sttEndTime - sttStartTime;
logger.info("STT处理完成, sessionId={}, 耗时={}s", sessionId, sttDuration / 1000.0);
String recognizedText = (String) sttResult.get("text");
if (recognizedText == null || recognizedText.trim().isEmpty()) {
@@ -191,6 +196,7 @@ public class ChatSessionServiceImpl implements ChatSessionService {
result.put("sttResult", sttResult);
result.put("sessionId", sessionId);
result.put("timestamp", System.currentTimeMillis());
result.put("sttDuration", sttDuration);
return result;
}
@@ -202,8 +208,12 @@ public class ChatSessionServiceImpl implements ChatSessionService {
chatRequest.setTemplateId(request.getTemplateId());
chatRequest.setUseFunctionCall(request.getUseFunctionCall());
long llmStartTime = System.currentTimeMillis();
ChatResponse chatResponse = syncChat(chatRequest);
long llmEndTime = System.currentTimeMillis();
long llmDuration = llmEndTime - llmStartTime;
logger.info("LLM处理完成, sessionId={}, 耗时={}s", sessionId, llmDuration / 1000.0);
Map<String, Object> llmResult = new HashMap<>();
llmResult.put("response", chatResponse.getResponse());
llmResult.put("inputText", recognizedText);
@@ -211,13 +221,17 @@ public class ChatSessionServiceImpl implements ChatSessionService {
// 获取会话中角色的voiceName
String voiceName = null;
ChatSession chatSession = getOrCreateSession(sessionId, request.getModelId(), request.getTemplateId());
if (chatSession != null && chatSession.getConversation() != null &&
chatSession.getConversation().role() != null) {
voiceName = chatSession.getConversation().role().getVoiceName();
}
// 3. TTS - 文本转语音
long ttsStartTime = System.currentTimeMillis();
Map<String, Object> ttsResult = performTts(chatResponse.getResponse(), request.getTtsConfigId(), voiceName);
long ttsEndTime = System.currentTimeMillis();
long ttsDuration = ttsEndTime - ttsStartTime;
logger.info("TTS处理完成, sessionId={}, 耗时={}s", sessionId, ttsDuration / 1000.0);
// 组装完整响应
Map<String, Object> result = new HashMap<>();
@@ -226,6 +240,13 @@ public class ChatSessionServiceImpl implements ChatSessionService {
result.put("ttsResult", ttsResult);
result.put("sessionId", sessionId);
result.put("timestamp", System.currentTimeMillis());
result.put("sttDuration", sttDuration);
result.put("llmDuration", llmDuration);
result.put("ttsDuration", ttsDuration);
logger.info("语音对话完成, sessionId={}, STT耗时={}s, LLM耗时={}s, TTS耗时={}s, 总耗时={}s",
sessionId, sttDuration / 1000.0, llmDuration / 1000.0, ttsDuration / 1000.0,
(sttDuration + llmDuration + ttsDuration) / 1000.0);
return result;


@@ -311,4 +311,23 @@ openapi:
contact-email: xiaozhi@qq.com
version: 1.0
external-description: xiaozhi API Docs
external-url: https://github.com/joey-zhou/xiaozhi-esp32-server-java
# 语音流式对话配置
xiaozhi:
voice-stream:
grok:
# Do not commit real API keys; read them from environment variables instead
api-key: ${GROK_API_KEY:your-grok-api-key}
api-url: https://caddy.liqupan.cn/v1/chat/completions
model: grok-4-1-fast-non-reasoning
minimax:
api-key: ${MINIMAX_API_KEY:your-minimax-api-key}
group-id: ${MINIMAX_GROUP_ID:your-group-id}
ws-url: wss://api.minimax.io/ws/v1/t2a_v2
model: speech-2.6-hd
voice-id: Chinese (Mandarin)_BashfulGirl
speed: 1.0
vol: 1.0
pitch: 0
audio-sample-rate: 32000
bitrate: 128000
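The `xiaozhi.voice-stream.minimax` block is read through `voiceStreamConfig` getters in the WebSocket client (`getVoiceId()`, `getSpeed()`, `getAudioSampleRate()`, and so on). A plain-POJO sketch of the nested binding shape those getters imply; in the real project this would presumably be a Spring `@ConfigurationProperties(prefix = "xiaozhi.voice-stream")` bean, and the field names here are assumptions inferred from the getters used in the code:

```java
// Sketch only: plain POJO mirroring the YAML structure above. Relaxed binding maps
// kebab-case keys (ws-url, voice-id, audio-sample-rate) onto the camelCase fields.
public class VoiceStreamConfigSketch {
    private final Minimax minimax = new Minimax();
    public Minimax getMinimax() { return minimax; }

    public static class Minimax {
        private String apiKey;                   // api-key
        private String wsUrl;                    // ws-url
        private String model;                    // model
        private String voiceId;                  // voice-id
        private double speed = 1.0;
        private double vol = 1.0;
        private int pitch = 0;
        private int audioSampleRate = 32000;     // audio-sample-rate
        private int bitrate = 128000;

        // Getters/setters for the fields the TTS client actually reads.
        public String getVoiceId() { return voiceId; }
        public void setVoiceId(String voiceId) { this.voiceId = voiceId; }
        public double getSpeed() { return speed; }
        public void setSpeed(double speed) { this.speed = speed; }
        public int getAudioSampleRate() { return audioSampleRate; }
        public int getBitrate() { return bitrate; }
    }
}
```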