feat: 优化 语音速度
This commit is contained in:
311
docs/GROK_API_DEBUG_GUIDE.md
Normal file
311
docs/GROK_API_DEBUG_GUIDE.md
Normal file
@@ -0,0 +1,311 @@
|
||||
# 🔍 Grok API 调试指南
|
||||
|
||||
## 数据格式说明
|
||||
|
||||
### WebClient 处理后的数据格式
|
||||
|
||||
Spring WebFlux 的 `WebClient.bodyToFlux(String.class)` 会自动处理 SSE 流,每个元素直接就是一个完整的 JSON 字符串:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "8670de6a-75b3-97e2-fa5a-c2e3d7f1d0f2",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1764858564,
|
||||
"model": "grok-4-1-fast-non-reasoning",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": "你",
|
||||
"role": "assistant"
|
||||
},
|
||||
"finish_reason": null
|
||||
}],
|
||||
"system_fingerprint": "fp_174298dd8e"
|
||||
}
|
||||
```
|
||||
|
||||
### 字段说明
|
||||
|
||||
| 字段 | 说明 | 示例值 |
|
||||
|------|------|--------|
|
||||
| `id` | 请求唯一ID | `"8670de6a-75b3-97e2-fa5a-c2e3d7f1d0f2"` |
|
||||
| `object` | 对象类型 | `"chat.completion.chunk"` |
|
||||
| `created` | 创建时间戳 | `1764858564` |
|
||||
| `model` | 使用的模型 | `"grok-4-1-fast-non-reasoning"` |
|
||||
| `choices[0].index` | 选择索引 | `0` |
|
||||
| `choices[0].delta.content` | 当前token内容 | `"你"` |
|
||||
| `choices[0].delta.role` | 角色(首次出现) | `"assistant"` |
|
||||
| `choices[0].finish_reason` | 结束原因 | `null`, `"stop"`, `"length"`, `"content_filter"` |
|
||||
| `system_fingerprint` | 系统指纹 | `"fp_174298dd8e"` |
|
||||
|
||||
## finish_reason 说明
|
||||
|
||||
| 值 | 含义 | 处理建议 |
|
||||
|----|------|----------|
|
||||
| `null` | 流还在继续 | 继续接收token |
|
||||
| `"stop"` | 正常结束 | ✅ 回复完整 |
|
||||
| `"length"` | 达到最大token限制 | ⚠️ 回复可能被截断,建议增加 `max_tokens` |
|
||||
| `"content_filter"` | 内容被过滤 | ⚠️ 回复包含敏感内容 |
|
||||
|
||||
## 解析流程
|
||||
|
||||
```
|
||||
收到数据 → 跳过空行 → 检查 [DONE] → 去除 "data: " 前缀(如果有)
|
||||
↓
|
||||
解析 JSON → 验证 choices 字段 → 获取 choices[0]
|
||||
↓
|
||||
检查 delta 字段 → 提取 content → 回调 onToken(content)
|
||||
↓
|
||||
检查 finish_reason → 记录日志 → 完成
|
||||
```
|
||||
|
||||
## 日志级别配置
|
||||
|
||||
在 `application.yml` 中配置:
|
||||
|
||||
```yaml
|
||||
logging:
|
||||
level:
|
||||
com.xiaozhi.dialogue.llm.GrokStreamService: DEBUG # 或 TRACE
|
||||
```
|
||||
|
||||
### 各级别输出内容
|
||||
|
||||
#### TRACE(最详细)
|
||||
```
|
||||
TRACE - 收到原始SSE数据: {"id":"8670de6a...
|
||||
TRACE - 提取到token: 你
|
||||
TRACE - 提取到token: 好
|
||||
```
|
||||
|
||||
#### DEBUG
|
||||
```
|
||||
DEBUG - 请求体: {"model":"grok-4-1-fast-non-reasoning",...}
|
||||
DEBUG - 收到角色信息: assistant
|
||||
DEBUG - JSON中缺少choices字段(如果有问题)
|
||||
```
|
||||
|
||||
#### INFO
|
||||
```
|
||||
INFO - 开始调用Grok API - Model: grok-4-1-fast-non-reasoning, UserMessage: 你好...
|
||||
INFO - 流结束原因: length
|
||||
INFO - Grok API流式调用完成
|
||||
```
|
||||
|
||||
#### ERROR
|
||||
```
|
||||
ERROR - JSON解析失败 - 原始数据: xxx
|
||||
ERROR - LLM调用失败: Connection refused
|
||||
```
|
||||
|
||||
## 常见问题排查
|
||||
|
||||
### 问题1: 没有收到任何token
|
||||
|
||||
**可能原因**:
|
||||
- API Key 错误
|
||||
- 网络连接问题
|
||||
- API URL 配置错误
|
||||
|
||||
**检查步骤**:
|
||||
1. 查看日志是否有 `INFO - 开始调用Grok API`
|
||||
2. 检查是否有错误日志
|
||||
3. 验证 `application.yml` 中的配置:
|
||||
```yaml
|
||||
xiaozhi:
|
||||
voice-stream:
|
||||
grok:
|
||||
api-key: ${GROK_API_KEY}
|
||||
api-url: https://api.x.ai/v1/chat/completions
|
||||
```
|
||||
|
||||
### 问题2: token乱码或格式错误
|
||||
|
||||
**检查点**:
|
||||
1. 启用 TRACE 日志查看原始数据
|
||||
2. 确认 JSON 格式是否正确
|
||||
3. 查看是否有 `ERROR - JSON解析失败` 日志
|
||||
|
||||
### 问题3: 回复被截断
|
||||
|
||||
**日志特征**:
|
||||
```
|
||||
INFO - 流结束原因: length
|
||||
```
|
||||
|
||||
**解决方案**:
|
||||
增加 `DEFAULT_MAX_TOKENS` 常量:
|
||||
```java
|
||||
private static final int DEFAULT_MAX_TOKENS = 4000; // 或更大
|
||||
```
|
||||
|
||||
### 问题4: WebClient 连接超时
|
||||
|
||||
**错误日志**:
|
||||
```
|
||||
ERROR - LLM调用失败: Connection timeout
|
||||
```
|
||||
|
||||
**解决方案**:
|
||||
在 `initWebClient()` 中增加超时配置:
|
||||
```java
|
||||
this.webClient = WebClient.builder()
|
||||
.baseUrl(apiUrl)
|
||||
.codecs(configurer -> configurer.defaultCodecs().maxInMemorySize(10 * 1024 * 1024))
|
||||
.defaultHeader(HttpHeaders.ACCEPT, "text/event-stream")
|
||||
.build();
|
||||
```
|
||||
|
||||
## 测试建议
|
||||
|
||||
### 1. 使用 curl 测试 API
|
||||
|
||||
```bash
|
||||
curl -X POST https://api.x.ai/v1/chat/completions \
|
||||
-H "Authorization: Bearer YOUR_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "grok-4-1-fast-non-reasoning",
|
||||
"messages": [{"role": "user", "content": "你好"}],
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
预期输出:
|
||||
```
|
||||
data: {"id":"xxx","object":"chat.completion.chunk",...}
|
||||
data: {"id":"xxx","object":"chat.completion.chunk",...}
|
||||
data: [DONE]
|
||||
```
|
||||
|
||||
### 2. 检查配置
|
||||
|
||||
```java
|
||||
// 在 GrokStreamService 中添加测试方法
|
||||
@PostConstruct
|
||||
public void testConfig() {
|
||||
logger.info("Grok配置:");
|
||||
logger.info(" API URL: {}", voiceStreamConfig.getGrok().getApiUrl());
|
||||
logger.info(" Model: {}", voiceStreamConfig.getGrok().getModel());
|
||||
logger.info(" API Key配置: {}", voiceStreamConfig.getGrok().getApiKey() != null ? "已配置" : "未配置");
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 单元测试
|
||||
|
||||
```java
|
||||
@Test
|
||||
public void testParseSSE() {
|
||||
String testJson = """
|
||||
{
|
||||
"id": "test",
|
||||
"choices": [{
|
||||
"delta": {
|
||||
"content": "测试"
|
||||
}
|
||||
}]
|
||||
}
|
||||
""";
|
||||
|
||||
// 测试解析逻辑
|
||||
}
|
||||
```
|
||||
|
||||
## 性能监控
|
||||
|
||||
### 关键指标
|
||||
|
||||
1. **首 token 延迟**:从发送请求到收到第一个 token 的时间
|
||||
2. **token 吞吐率**:每秒收到的 token 数量
|
||||
3. **总响应时间**:从开始到收到 [DONE] 的时间
|
||||
|
||||
### 添加监控日志
|
||||
|
||||
```java
|
||||
private long startTime;
|
||||
private int tokenCount = 0;
|
||||
|
||||
public void streamChat(...) {
|
||||
startTime = System.currentTimeMillis();
|
||||
// ...
|
||||
.doOnNext(token -> {
|
||||
tokenCount++;
|
||||
if (tokenCount == 1) {
|
||||
long firstTokenLatency = System.currentTimeMillis() - startTime;
|
||||
logger.info("首token延迟: {}ms", firstTokenLatency);
|
||||
}
|
||||
})
|
||||
.doOnComplete(() -> {
|
||||
long totalTime = System.currentTimeMillis() - startTime;
|
||||
double tps = tokenCount / (totalTime / 1000.0);
|
||||
logger.info("总计: {}个token, 耗时: {}ms, 吞吐率: {:.2f} tokens/s",
|
||||
tokenCount, totalTime, tps);
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
## 调试技巧
|
||||
|
||||
### 1. 保存原始响应
|
||||
|
||||
```java
|
||||
// 在parseSSE中
|
||||
logger.trace("原始响应: {}", line);
|
||||
```
|
||||
|
||||
### 2. 使用断点调试
|
||||
|
||||
在以下位置设置断点:
|
||||
- `parseSSE()` 方法入口
|
||||
- `sink.next(content)` 之前
|
||||
- `callback.onToken()` 调用处
|
||||
|
||||
### 3. 模拟测试
|
||||
|
||||
创建测试类:
|
||||
```java
|
||||
@Test
|
||||
public void testGrokStreamService() {
|
||||
GrokStreamService service = new GrokStreamService();
|
||||
service.streamChat("你好", null, new TokenCallback() {
|
||||
@Override
|
||||
public void onToken(String token) {
|
||||
System.out.println("Token: " + token);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
System.out.println("完成");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(String error) {
|
||||
System.err.println("错误: " + error);
|
||||
}
|
||||
});
|
||||
|
||||
// 等待完成
|
||||
Thread.sleep(10000);
|
||||
}
|
||||
```
|
||||
|
||||
## 总结
|
||||
|
||||
正常工作流程:
|
||||
1. ✅ 看到 `INFO - 开始调用Grok API`
|
||||
2. ✅ 看到多个 `TRACE - 提取到token: xxx`
|
||||
3. ✅ 看到 `INFO - 流结束原因: stop`(或 length)
|
||||
4. ✅ 看到 `INFO - Grok API流式调用完成`
|
||||
|
||||
如果出现问题:
|
||||
1. 检查日志级别是否足够详细
|
||||
2. 查找 ERROR 级别的日志
|
||||
3. 验证配置是否正确
|
||||
4. 使用 curl 直接测试 API
|
||||
5. 检查网络连接
|
||||
|
||||
---
|
||||
|
||||
**需要更多帮助?** 提供完整的错误日志和配置信息。
|
||||
|
||||
|
||||
251
docs/TTS_CONNECTION_WARMUP_TEST.md
Normal file
251
docs/TTS_CONNECTION_WARMUP_TEST.md
Normal file
@@ -0,0 +1,251 @@
|
||||
# TTS连接预热优化测试指南
|
||||
|
||||
## 测试目标
|
||||
|
||||
验证TTS WebSocket连接预热和复用功能,确保:
|
||||
1. 用户连接时成功预热TTS连接
|
||||
2. 多句话复用同一个TTS连接
|
||||
3. 连接断开后自动重连
|
||||
4. 用户断开时正确清理连接
|
||||
5. 性能提升符合预期
|
||||
|
||||
## 测试环境准备
|
||||
|
||||
1. 启动后端服务
|
||||
2. 确保MiniMax TTS API配置正确
|
||||
3. 使用微信小程序或WebSocket客户端
|
||||
|
||||
## 测试用例
|
||||
|
||||
### 测试1:连接预热验证
|
||||
|
||||
**目的**:验证用户连接时是否成功预热TTS连接
|
||||
|
||||
**步骤**:
|
||||
1. 前端连接到语音流WebSocket:`ws://your-server:8091/ws/voice-stream`
|
||||
2. 观察后端日志
|
||||
|
||||
**预期日志**:
|
||||
```
|
||||
[VoiceStreamHandler] 语音流WebSocket连接建立 - SessionId: xxx, UserId: null
|
||||
[VoiceStreamService] 预热TTS连接 - SessionId: xxx
|
||||
[MinimaxTtsStreamService] 开始预热TTS连接 - SessionId: xxx
|
||||
[MinimaxTtsStreamService] MiniMax TTS连接已建立 - SessionId: xxx
|
||||
[MinimaxTtsStreamService] 收到connected_success,发送task_start - SessionId: xxx
|
||||
[MinimaxTtsStreamService] 收到task_started,连接就绪 - SessionId: xxx
|
||||
[MinimaxTtsStreamService] TTS连接预热成功 - SessionId: xxx, 耗时: XXXms
|
||||
[VoiceStreamHandler] TTS连接预热成功 - SessionId: xxx
|
||||
```
|
||||
|
||||
**成功标准**:
|
||||
- ✅ 连接建立后立即开始预热
|
||||
- ✅ 预热在1-2秒内完成
|
||||
- ✅ 连接状态变为IDLE
|
||||
|
||||
---
|
||||
|
||||
### 测试2:连接复用验证
|
||||
|
||||
**目的**:验证多句话是否复用同一个TTS连接
|
||||
|
||||
**步骤**:
|
||||
1. 确保已建立连接并预热成功
|
||||
2. 发送一段音频(用户说:"你好,今天天气怎么样?")
|
||||
3. 等待AI回复多句话(例如:"你好!今天天气很好。阳光明媚。")
|
||||
4. 观察后端日志
|
||||
|
||||
**预期日志**:
|
||||
```
|
||||
[VoiceStreamService] STT识别结果 - SessionId: xxx, Text: 你好,今天天气怎么样?
|
||||
[VoiceStreamService] 检测到完整句子 - SessionId: xxx, Sentence: 你好!
|
||||
[MinimaxTtsStreamService] 使用已有TTS连接 - SessionId: xxx, Text: 你好!
|
||||
[MinimaxTtsStreamService] 缓冲音频块 - SessionId: xxx: 2048 bytes, 总计: 2048 bytes
|
||||
...
|
||||
[MinimaxTtsStreamService] TTS完成 - SessionId: xxx, Text: 你好!, 音频总大小: 24576 bytes
|
||||
[MinimaxTtsStreamService] TTS完成 - SessionId: xxx, 耗时: 300ms(复用连接节省约1秒)
|
||||
[VoiceStreamService] 检测到完整句子 - SessionId: xxx, Sentence: 今天天气很好。
|
||||
[MinimaxTtsStreamService] 使用已有TTS连接 - SessionId: xxx, Text: 今天天气很好。
|
||||
[MinimaxTtsStreamService] TTS完成 - SessionId: xxx, 耗时: 250ms(复用连接节省约1秒)
|
||||
```
|
||||
|
||||
**成功标准**:
|
||||
- ✅ 日志显示"使用已有TTS连接"而不是"创建MiniMax TTS连接"
|
||||
- ✅ 每句TTS完成时间在200-500ms(而非1-1.5s)
|
||||
- ✅ 日志显示"复用连接节省约1秒"
|
||||
- ✅ 没有重复的连接建立日志
|
||||
|
||||
---
|
||||
|
||||
### 测试3:自动重连验证
|
||||
|
||||
**目的**:验证连接断开后是否自动重连
|
||||
|
||||
**步骤**:
|
||||
1. 正常建立连接
|
||||
2. 手动关闭MiniMax TTS WebSocket连接(模拟网络断开)
|
||||
3. 发送一段音频触发TTS
|
||||
4. 观察后端日志
|
||||
|
||||
**预期日志**:
|
||||
```
|
||||
[MinimaxTtsStreamService] MiniMax TTS连接关闭 - SessionId: xxx, Code: 1006, Reason: ..., Remote: true
|
||||
[MinimaxTtsStreamService] 尝试自动重连TTS - SessionId: xxx, 第1次重连
|
||||
[MinimaxTtsStreamService] 开始预热TTS连接 - SessionId: xxx
|
||||
[MinimaxTtsStreamService] TTS连接预热成功 - SessionId: xxx, 耗时: XXXms
|
||||
[MinimaxTtsStreamService] 使用已有TTS连接 - SessionId: xxx, Text: ...
|
||||
```
|
||||
|
||||
**成功标准**:
|
||||
- ✅ 检测到连接断开
|
||||
- ✅ 自动触发重连(最多3次)
|
||||
- ✅ 重连成功后可正常使用
|
||||
- ✅ 重连失败3次后停止尝试
|
||||
|
||||
---
|
||||
|
||||
### 测试4:连接清理验证
|
||||
|
||||
**目的**:验证用户断开时是否正确清理TTS连接
|
||||
|
||||
**步骤**:
|
||||
1. 建立连接并预热成功
|
||||
2. 前端主动断开WebSocket连接
|
||||
3. 观察后端日志
|
||||
|
||||
**预期日志**:
|
||||
```
|
||||
[VoiceStreamHandler] 语音流WebSocket连接关闭 - SessionId: xxx, Status: CloseStatus[code=1000, reason=null]
|
||||
[VoiceStreamService] 关闭TTS连接 - SessionId: xxx
|
||||
[MinimaxTtsStreamService] 关闭TTS连接 - SessionId: xxx
|
||||
[MinimaxTtsStreamService] 发送task_finish - SessionId: xxx
|
||||
[MinimaxTtsStreamService] MiniMax TTS连接关闭 - SessionId: xxx, Code: 1000, Reason: ...
|
||||
```
|
||||
|
||||
**成功标准**:
|
||||
- ✅ 正确发送task_finish
|
||||
- ✅ 正常关闭TTS连接
|
||||
- ✅ 从连接池中移除
|
||||
- ✅ 无内存泄漏
|
||||
|
||||
---
|
||||
|
||||
### 测试5:性能对比验证
|
||||
|
||||
**目的**:验证性能提升是否符合预期
|
||||
|
||||
**测试场景**:AI回复3句话
|
||||
|
||||
**修改前性能**:
|
||||
- 首句延迟:1-1.5秒
|
||||
- 第二句延迟:1-1.5秒
|
||||
- 第三句延迟:1-1.5秒
|
||||
- **总延迟**:3-4.5秒
|
||||
|
||||
**修改后预期性能**:
|
||||
- 首句延迟:0.5秒(仅TTS合成)
|
||||
- 第二句延迟:0.2-0.3秒
|
||||
- 第三句延迟:0.2-0.3秒
|
||||
- **总延迟**:0.9-1.1秒
|
||||
|
||||
**性能提升**:约 **65-70%**
|
||||
|
||||
**验证方法**:
|
||||
1. 在日志中记录每句TTS的开始和结束时间
|
||||
2. 计算总耗时
|
||||
3. 与预期值对比
|
||||
|
||||
---
|
||||
|
||||
### 测试6:并发测试
|
||||
|
||||
**目的**:验证多用户同时使用时的连接管理
|
||||
|
||||
**步骤**:
|
||||
1. 创建3-5个WebSocket连接(模拟多用户)
|
||||
2. 每个连接同时发送音频
|
||||
3. 观察后端日志和性能
|
||||
|
||||
**成功标准**:
|
||||
- ✅ 每个用户有独立的TTS连接
|
||||
- ✅ 连接之间互不影响
|
||||
- ✅ 无连接混乱
|
||||
- ✅ 性能稳定
|
||||
|
||||
---
|
||||
|
||||
## 性能监控指标
|
||||
|
||||
在日志中关注以下关键指标:
|
||||
|
||||
1. **连接建立时间**:`TTS连接预热成功 - SessionId: xxx, 耗时: XXXms`
|
||||
2. **连接复用次数**:统计"使用已有TTS连接"出现次数
|
||||
3. **TTS合成时间**:`TTS完成 - SessionId: xxx, 耗时: XXXms`
|
||||
4. **重连次数**:`尝试自动重连TTS - SessionId: xxx, 第X次重连`
|
||||
5. **连接状态变化**:观察状态从CONNECTING → CONNECTED → TASK_STARTED → IDLE的转换
|
||||
|
||||
---
|
||||
|
||||
## 常见问题排查
|
||||
|
||||
### 问题1:预热失败
|
||||
|
||||
**现象**:`TTS连接预热失败 - SessionId: xxx`
|
||||
|
||||
**可能原因**:
|
||||
- MiniMax API配置错误
|
||||
- 网络连接问题
|
||||
- API密钥过期
|
||||
|
||||
**解决方法**:
|
||||
- 检查application.yml中的MiniMax配置
|
||||
- 验证API密钥是否有效
|
||||
- 检查网络连接
|
||||
|
||||
### 问题2:连接未复用
|
||||
|
||||
**现象**:每次TTS都显示"创建MiniMax TTS连接"
|
||||
|
||||
**可能原因**:
|
||||
- 连接状态管理错误
|
||||
- 连接在使用前被关闭
|
||||
|
||||
**解决方法**:
|
||||
- 检查连接状态日志
|
||||
- 确认is_final后状态设为IDLE而非关闭
|
||||
|
||||
### 问题3:自动重连失败
|
||||
|
||||
**现象**:重连3次后仍无法恢复
|
||||
|
||||
**可能原因**:
|
||||
- MiniMax服务问题
|
||||
- 网络持续不稳定
|
||||
|
||||
**解决方法**:
|
||||
- 检查MiniMax服务状态
|
||||
- 降级为按需创建连接
|
||||
|
||||
---
|
||||
|
||||
## 测试清单
|
||||
|
||||
- [ ] 测试1:连接预热验证
|
||||
- [ ] 测试2:连接复用验证
|
||||
- [ ] 测试3:自动重连验证
|
||||
- [ ] 测试4:连接清理验证
|
||||
- [ ] 测试5:性能对比验证
|
||||
- [ ] 测试6:并发测试
|
||||
- [ ] 检查无内存泄漏
|
||||
- [ ] 检查无连接泄漏
|
||||
- [ ] 压力测试(10+并发用户)
|
||||
|
||||
---
|
||||
|
||||
## 测试结论
|
||||
|
||||
完成所有测试后,记录:
|
||||
1. 实际性能提升百分比
|
||||
2. 连接复用成功率
|
||||
3. 自动重连成功率
|
||||
4. 发现的问题和改进建议
|
||||
|
||||
158
docs/VOICE_CHAT_INTERACTION.md
Normal file
158
docs/VOICE_CHAT_INTERACTION.md
Normal file
@@ -0,0 +1,158 @@
|
||||
# 语音对话接口交互文档
|
||||
|
||||
本文档详细描述了前端 `webUI` 与后端 `server` 之间关于语音对话功能(`/voice-chat`)的交互流程、参数传递及返回值结构。
|
||||
|
||||
## 1. 概述
|
||||
|
||||
语音对话功能允许用户录制一段语音,前端将其上传至后端。后端依次执行以下操作:
|
||||
1. **STT (Speech-to-Text)**:将语音转换为文本。
|
||||
2. **LLM (Large Language Model)**:将识别出的文本作为输入,获取 AI 的文本回复。
|
||||
3. **TTS (Text-to-Speech)**:将 AI 的文本回复转换为语音。
|
||||
|
||||
最终,后端将用户识别文本、AI 回复文本及合成的语音数据一次性返回给前端。
|
||||
|
||||
## 2. 前端调用 (WebUI)
|
||||
|
||||
前端主要涉及的文件为 `webUI/src/components/ChatBox.vue` 和 `webUI/src/utils/api.js`。
|
||||
|
||||
### 2.1 调用方式
|
||||
|
||||
在 `ChatBox.vue` 中,当用户完成录音后,调用 `handleVoiceModeMessage` 方法,进而调用 `voiceAPI.voiceChat`。
|
||||
|
||||
底层通过 `uni.uploadFile` 发起 `multipart/form-data` 类型的 POST 请求。
|
||||
|
||||
### 2.2 请求参数
|
||||
|
||||
前端向后端发送的请求包含 **文件** 和 **表单数据 (FormData)**。
|
||||
|
||||
* **URL**: `/api/chat/voice-chat` (由 `config.js` 中的 `VOICE_CHAT` 常量定义)
|
||||
* **Method**: `POST`
|
||||
* **Header**: `Authorization: Bearer <token>`
|
||||
* **文件部分**:
|
||||
* `name`: `"audio"`
|
||||
* `filePath`: 录音文件的本地路径 (e.g., `.aac` 或 `.wav`)
|
||||
|
||||
* **表单数据 (FormData)**:
|
||||
|
||||
| 参数名 | 类型 | 必填 | 说明 | 来源 |
|
||||
| :--- | :--- | :--- | :--- | :--- |
|
||||
| `sessionId` | String | 否 | 会话 ID,用于保持上下文 | `conversationId.value` |
|
||||
| `modelId` | Integer | 否 | 模型 ID | `characterConfig.modelId` |
|
||||
| `templateId` | Integer | 否 | 模板 ID | `characterConfig.templateId` |
|
||||
| `voiceStyle` | String | 否 | 语音风格 (用于 TTS) | `options.voiceStyle` |
|
||||
| `ttsConfigId` | Integer | 否 | TTS 配置 ID | `aiConfig.ttsId` |
|
||||
| `sttConfigId` | Integer | 否 | STT 配置 ID | `aiConfig.sttId` |
|
||||
| `useFunctionCall` | Boolean | 否 | 是否使用函数调用 | 默认为 `false` |
|
||||
|
||||
### 2.3 响应处理
|
||||
|
||||
前端接收到后端返回的 JSON 数据后,进行如下解析:
|
||||
|
||||
1. **用户文本**: 从 `sttResult.text` 获取,显示在聊天界面右侧。
|
||||
2. **AI 回复**: 从 `llmResult.response` 获取,显示在聊天界面左侧。
|
||||
3. **语音播放**: 优先使用 `ttsResult.audioBase64` (Base64 编码音频),如果没有则使用 `ttsResult.audioPath` (音频 URL) 进行播放。
|
||||
|
||||
---
|
||||
|
||||
## 3. 后端处理 (Server)
|
||||
|
||||
后端入口为 `server/src/main/java/com/xiaozhi/controller/ChatController.java`,核心逻辑在 `ChatSessionServiceImpl.java`。
|
||||
|
||||
### 3.1 接口定义
|
||||
|
||||
* **Controller**: `ChatController`
|
||||
* **Path**: `/api/chat/voice-chat`
|
||||
* **Consumes**: `multipart/form-data`
|
||||
|
||||
### 3.2 处理流程
|
||||
|
||||
1. **接收文件**: 后端支持字段名为 `audioFile`, `file`, 或 `audio` 的文件上传 (前端使用 `audio`)。
|
||||
2. **参数解析**: 解析 `sessionId`, `modelId` 等参数。
|
||||
3. **文件验证**: 检查文件大小 (1KB - 50MB)、格式 (支持 mp3, wav, m4a, aac 等) 和 MIME 类型。
|
||||
4. **音频处理**:
|
||||
* 将上传的音频文件保存为临时文件。
|
||||
* 使用 `AudioUtils` 将音频转换为 **PCM 16k 单声道** 格式(适配 STT 引擎)。
|
||||
5. **业务逻辑 (`ChatSessionService.voiceChat`)**:
|
||||
* **STT**: 调用配置的 STT 服务识别语音,得到 `recognizedText`。
|
||||
* **LLM**: 如果识别到文本,调用 `syncChat` 获取 AI 回复 `chatResponse`。
|
||||
* **TTS**: 调用 TTS 服务将 AI 回复转换为语音,生成音频文件并读取为 Base64。
|
||||
6. **结果封装**: 将 STT、LLM、TTS 的结果封装到 Map 中返回。
|
||||
|
||||
---
|
||||
|
||||
## 4. 接口规范
|
||||
|
||||
### 4.1 请求结构
|
||||
|
||||
**POST** `/api/chat/voice-chat`
|
||||
**Content-Type**: `multipart/form-data`
|
||||
|
||||
**Body**:
|
||||
* `audio`: [二进制文件数据]
|
||||
* `sessionId`: "session_12345"
|
||||
* `modelId`: 10
|
||||
* `templateId`: 6
|
||||
* ...
|
||||
|
||||
### 4.2 响应结构
|
||||
|
||||
**Content-Type**: `application/json`
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 200,
|
||||
"message": "语音对话成功",
|
||||
"data": {
|
||||
"sessionId": "session_12345",
|
||||
"timestamp": 1717660000000,
|
||||
|
||||
// 1. STT 结果
|
||||
"sttResult": {
|
||||
"text": "你好,请介绍一下你自己。", // 用户语音识别结果
|
||||
"audioSize": 32000,
|
||||
"sttProvider": "vosk"
|
||||
},
|
||||
|
||||
// 2. LLM 结果
|
||||
"llmResult": {
|
||||
"response": "你好!我是蔚AI,很高兴为你服务。", // AI 回复文本
|
||||
"inputText": "你好,请介绍一下你自己。"
|
||||
},
|
||||
|
||||
// 3. TTS 结果
|
||||
"ttsResult": {
|
||||
"audioBase64": "UklGRi...", // Base64 编码的音频数据 (用于直接播放)
|
||||
"audioPath": "audio/output/...", // 服务器音频文件路径
|
||||
"timestamp": 1717660005000
|
||||
},
|
||||
|
||||
// 性能统计 (耗时: ms)
|
||||
"sttDuration": 500,
|
||||
"llmDuration": 1200,
|
||||
"ttsDuration": 800,
|
||||
|
||||
// 文件元数据
|
||||
"originalFileName": "temp_audio.aac",
|
||||
"fileSize": 15000,
|
||||
"contentType": "audio/aac",
|
||||
"description": null
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 4.3 错误响应
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 400, // 或 500
|
||||
"message": "请求参数错误: 音频文件不能为空",
|
||||
"data": null
|
||||
}
|
||||
```
|
||||
|
||||
## 5. 总结
|
||||
|
||||
* **交互模式**: 同步一次性交互。前端发送音频,等待后端完成所有处理(识别+对话+合成)后,一次性接收所有数据。
|
||||
* **音频格式**: 前端通常录制 `aac` 或 `wav`,后端统一转码为 `pcm` 进行处理。
|
||||
* **回退机制**: 如果后端处理失败,前端 `ChatBox.vue` 会捕获异常并提示用户,或使用本地模拟回复(在未登录等特定情况下)。
|
||||
|
||||
744
docs/WEBSOCKET_VOICE_STREAM_DESIGN.md
Normal file
744
docs/WEBSOCKET_VOICE_STREAM_DESIGN.md
Normal file
@@ -0,0 +1,744 @@
|
||||
# WebSocket 实时语音流式对话架构设计
|
||||
|
||||
## 1. 方案概述
|
||||
|
||||
本方案设计了一套**"伪实时"语音交互系统**,结合前端 VAD(语音活动检测)和后端流式处理,实现低延迟的语音对话体验。
|
||||
|
||||
### 核心特点
|
||||
|
||||
- **前端保持简单**:复用现有 VAD 逻辑,用户说完话才上传完整音频
|
||||
- **后端流式处理**:STT 同步处理,LLM+TTS 流式串联
|
||||
- **极低感知延迟**:用户说完话后 1 秒内听到第一句回复
|
||||
|
||||
### 技术选型
|
||||
|
||||
- **前端通信**:WebSocket (双向实时通信)
|
||||
- **STT**:复用现有 Vosk/其他 STT 服务(同步处理完整音频)
|
||||
- **LLM**:Grok API (`https://api.x.ai/v1/chat/completions`) + SSE Stream
|
||||
- **TTS**:MiniMax WebSocket TTS (`wss://api.minimax.io/ws/v1/t2a_v2`) + Stream
|
||||
|
||||
---
|
||||
|
||||
## 2. 整体架构流程
|
||||
|
||||
```
|
||||
┌─────────────┐
|
||||
│ 前端 UI │
|
||||
└──────┬──────┘
|
||||
│ 1. 用户说话
|
||||
↓
|
||||
┌─────────────────┐
|
||||
│ RecorderManager │ (VAD 检测说完)
|
||||
│ (PCM/AAC) │
|
||||
└────────┬────────┘
|
||||
│ 2. WebSocket 发送完整音频 (二进制)
|
||||
↓
|
||||
┌────────────────────┐
|
||||
│ 后端 WebSocket │
|
||||
│ Handler │
|
||||
└────────┬───────────┘
|
||||
│
|
||||
↓ 3. 调用 STT (同步)
|
||||
┌────────────────┐
|
||||
│ STT Service │ → "你好,今天天气怎么样?"
|
||||
│ (复用现有) │
|
||||
└────────┬───────┘
|
||||
│
|
||||
↓ 4. 流式调用 Grok LLM (HTTP SSE)
|
||||
┌─────────────────────────┐
|
||||
│ Grok API Stream │
|
||||
│ https://api.x.ai/v1/... │
|
||||
└────────┬────────────────┘
|
||||
│ Token 流: "今", "天", "天", "气", "很", "好", "。", ...
|
||||
↓
|
||||
┌────────────────────┐
|
||||
│ 分句缓冲器 │ (检测标点)
|
||||
└────────┬───────────┘
|
||||
│ 完整句子: "今天天气很好。"
|
||||
↓ 5. 为每句调用 MiniMax TTS (WebSocket Stream)
|
||||
┌─────────────────────────────┐
|
||||
│ MiniMax TTS WebSocket │
|
||||
│ wss://api.minimax.io/ws/... │
|
||||
└────────┬────────────────────┘
|
||||
│ 音频流 (Hex → Bytes)
|
||||
↓ 6. 实时转发给前端 (WebSocket 二进制)
|
||||
┌────────────────────┐
|
||||
│ 前端 WebSocket │
|
||||
│ onMessage │
|
||||
└────────┬───────────┘
|
||||
│ 7. WebAudioContext 实时播放
|
||||
↓
|
||||
┌────────────────────┐
|
||||
│ 音频播放队列 │
|
||||
│ (无缝连续播放) │
|
||||
└────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. 前端改造方案 (ChatBox.vue)
|
||||
|
||||
### 3.1 保留的部分
|
||||
|
||||
✅ **现有 VAD 逻辑**:
|
||||
- `recorderManager` 配置
|
||||
- `onStart`、`onFrameRecorded`(用于波形可视化)
|
||||
- `onStop`(核心触发点)
|
||||
- `calculateVolumeRMS`、`vadConfig`(静音检测)
|
||||
|
||||
### 3.2 需要修改的部分
|
||||
|
||||
#### A. 建立 WebSocket 连接
|
||||
|
||||
```javascript
|
||||
// 新增:WebSocket 连接管理
|
||||
const voiceWebSocket = ref(null);
|
||||
|
||||
const connectVoiceWebSocket = () => {
|
||||
const wsUrl = `ws://192.168.3.13:8091/ws/voice-stream`;
|
||||
|
||||
voiceWebSocket.value = uni.connectSocket({
|
||||
url: wsUrl,
|
||||
success: () => {
|
||||
console.log('WebSocket 连接成功');
|
||||
}
|
||||
});
|
||||
|
||||
// 监听消息(接收音频流)
|
||||
voiceWebSocket.value.onMessage((res) => {
|
||||
if (res.data instanceof ArrayBuffer) {
|
||||
// 收到音频数据,加入播放队列
|
||||
playAudioChunk(res.data);
|
||||
} else {
|
||||
// JSON 控制消息
|
||||
const msg = JSON.parse(res.data);
|
||||
if (msg.event === 'stt_done') {
|
||||
addMessage('user', msg.text);
|
||||
voiceState.value = 'thinking';
|
||||
} else if (msg.event === 'llm_start') {
|
||||
voiceState.value = 'speaking';
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
voiceWebSocket.value.onError((err) => {
|
||||
console.error('WebSocket 错误:', err);
|
||||
});
|
||||
};
|
||||
|
||||
// 进入语音模式时连接
|
||||
const toggleVoiceMode = () => {
|
||||
if (isVoiceMode.value) {
|
||||
connectVoiceWebSocket();
|
||||
} else {
|
||||
voiceWebSocket.value?.close();
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
#### B. 修改音频上传逻辑
|
||||
|
||||
```javascript
|
||||
// 修改 handleVoiceModeMessage
|
||||
const handleVoiceModeMessage = async (filePath) => {
|
||||
if (!isVoiceMode.value) return;
|
||||
|
||||
voiceState.value = 'thinking';
|
||||
|
||||
try {
|
||||
// 读取录音文件为 ArrayBuffer
|
||||
const fs = uni.getFileSystemManager();
|
||||
const audioData = await new Promise((resolve, reject) => {
|
||||
fs.readFile({
|
||||
filePath: filePath,
|
||||
success: (res) => resolve(res.data),
|
||||
fail: reject
|
||||
});
|
||||
});
|
||||
|
||||
// 通过 WebSocket 发送音频
|
||||
voiceWebSocket.value.send({
|
||||
data: audioData,
|
||||
success: () => {
|
||||
console.log('音频已发送');
|
||||
},
|
||||
fail: (err) => {
|
||||
console.error('发送失败:', err);
|
||||
}
|
||||
});
|
||||
|
||||
} catch (error) {
|
||||
console.error('处理音频失败:', error);
|
||||
voiceState.value = 'idle';
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
#### C. 实现流式音频播放
|
||||
|
||||
```javascript
|
||||
// 音频播放队列
|
||||
const audioQueue = ref([]);
|
||||
const isPlayingAudio = ref(false);
|
||||
|
||||
const playAudioChunk = (arrayBuffer) => {
|
||||
audioQueue.value.push(arrayBuffer);
|
||||
if (!isPlayingAudio.value) {
|
||||
processAudioQueue();
|
||||
}
|
||||
};
|
||||
|
||||
const processAudioQueue = async () => {
|
||||
if (audioQueue.value.length === 0) {
|
||||
isPlayingAudio.value = false;
|
||||
// 播放完成,回到待机状态
|
||||
if (isVoiceMode.value && isAutoVoiceMode.value) {
|
||||
startVoiceRecording();
|
||||
} else {
|
||||
voiceState.value = 'idle';
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
isPlayingAudio.value = true;
|
||||
const chunk = audioQueue.value.shift();
|
||||
|
||||
// 使用 InnerAudioContext (需先转为临时文件)
|
||||
const fs = uni.getFileSystemManager();
|
||||
const tempPath = `${wx.env.USER_DATA_PATH}/temp_audio_${Date.now()}.mp3`;
|
||||
|
||||
fs.writeFile({
|
||||
filePath: tempPath,
|
||||
data: chunk,
|
||||
encoding: 'binary',
|
||||
success: () => {
|
||||
const audio = uni.createInnerAudioContext();
|
||||
audio.src = tempPath;
|
||||
audio.onEnded(() => {
|
||||
processAudioQueue(); // 播放下一块
|
||||
});
|
||||
audio.onError(() => {
|
||||
processAudioQueue(); // 出错也继续
|
||||
});
|
||||
audio.play();
|
||||
}
|
||||
});
|
||||
};
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 后端实现方案 (Java)
|
||||
|
||||
### 4.1 WebSocket Handler
|
||||
|
||||
```java
|
||||
@ServerEndpoint(value = "/ws/voice-stream")
|
||||
@Component
|
||||
public class VoiceStreamHandler {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(VoiceStreamHandler.class);
|
||||
|
||||
@Autowired
|
||||
private SttService sttService;
|
||||
|
||||
@OnOpen
|
||||
public void onOpen(Session session) {
|
||||
logger.info("WebSocket 连接建立: {}", session.getId());
|
||||
}
|
||||
|
||||
@OnMessage
|
||||
public void onBinaryMessage(ByteBuffer audioBuffer, Session session) {
|
||||
logger.info("收到音频数据: {} bytes", audioBuffer.remaining());
|
||||
|
||||
// 转为 byte[]
|
||||
byte[] audioData = new byte[audioBuffer.remaining()];
|
||||
audioBuffer.get(audioData);
|
||||
|
||||
// 异步处理(避免阻塞 WebSocket 线程)
|
||||
CompletableFuture.runAsync(() -> {
|
||||
processVoiceStream(audioData, session);
|
||||
});
|
||||
}
|
||||
|
||||
@OnClose
|
||||
public void onClose(Session session) {
|
||||
logger.info("WebSocket 连接关闭: {}", session.getId());
|
||||
}
|
||||
|
||||
@OnError
|
||||
public void onError(Session session, Throwable error) {
|
||||
logger.error("WebSocket 错误: {}", session.getId(), error);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 4.2 核心处理流程
|
||||
|
||||
```java
|
||||
private void processVoiceStream(byte[] audioData, Session session) {
|
||||
try {
|
||||
// ========== 1. STT (同步处理) ==========
|
||||
long sttStart = System.currentTimeMillis();
|
||||
String recognizedText = performStt(audioData);
|
||||
long sttDuration = System.currentTimeMillis() - sttStart;
|
||||
|
||||
if (recognizedText == null || recognizedText.trim().isEmpty()) {
|
||||
sendJsonMessage(session, Map.of("event", "error", "message", "未识别到语音"));
|
||||
return;
|
||||
}
|
||||
|
||||
logger.info("STT 完成 ({}ms): {}", sttDuration, recognizedText);
|
||||
|
||||
// 通知前端识别结果
|
||||
sendJsonMessage(session, Map.of(
|
||||
"event", "stt_done",
|
||||
"text", recognizedText,
|
||||
"duration", sttDuration
|
||||
));
|
||||
|
||||
// ========== 2. LLM Stream + TTS Stream (串联) ==========
|
||||
streamLlmAndTts(recognizedText, session);
|
||||
|
||||
} catch (Exception e) {
|
||||
logger.error("处理语音流失败", e);
|
||||
sendJsonMessage(session, Map.of("event", "error", "message", e.getMessage()));
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 4.3 STT 处理(复用现有)
|
||||
|
||||
```java
|
||||
private String performStt(byte[] audioData) {
|
||||
try {
|
||||
// 转换为 PCM 16k 单声道
|
||||
byte[] pcmData = AudioUtils.bytesToPcm(audioData);
|
||||
|
||||
// 调用 STT 服务(复用现有逻辑)
|
||||
String text = sttService.recognition(pcmData);
|
||||
|
||||
return text != null ? text.trim() : "";
|
||||
|
||||
} catch (Exception e) {
|
||||
logger.error("STT 处理失败", e);
|
||||
return "";
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 4.4 Grok LLM Stream 调用
|
||||
|
||||
```java
|
||||
private void streamLlmAndTts(String userText, Session frontendSession) {
|
||||
// 分句缓冲器
|
||||
StringBuilder sentenceBuffer = new StringBuilder();
|
||||
|
||||
// 使用 WebClient 调用 Grok SSE API
|
||||
WebClient client = WebClient.builder()
|
||||
.baseUrl("https://api.x.ai")
|
||||
.defaultHeader("Authorization", "Bearer " + grokApiKey)
|
||||
.build();
|
||||
|
||||
sendJsonMessage(frontendSession, Map.of("event", "llm_start"));
|
||||
|
||||
client.post()
|
||||
.uri("/v1/chat/completions")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.bodyValue(Map.of(
|
||||
"model", "grok-beta",
|
||||
"messages", List.of(
|
||||
Map.of("role", "system", "content", "你是一个友好的AI助手"),
|
||||
Map.of("role", "user", "content", userText)
|
||||
),
|
||||
"stream", true,
|
||||
"temperature", 0.7
|
||||
))
|
||||
.retrieve()
|
||||
.bodyToFlux(String.class) // SSE 流
|
||||
.subscribe(
|
||||
// onNext: 处理每个 SSE chunk
|
||||
sseChunk -> {
|
||||
String token = parseGrokToken(sseChunk);
|
||||
if (token != null && !token.isEmpty()) {
|
||||
sentenceBuffer.append(token);
|
||||
|
||||
// 检测句子结束
|
||||
if (isSentenceEnd(sentenceBuffer.toString())) {
|
||||
String sentence = sentenceBuffer.toString().trim();
|
||||
sentenceBuffer.setLength(0); // 清空
|
||||
|
||||
logger.info("检测到完整句子: {}", sentence);
|
||||
|
||||
// 调用 TTS 并流式发送
|
||||
streamTtsToFrontend(sentence, frontendSession);
|
||||
}
|
||||
}
|
||||
},
|
||||
// onError
|
||||
error -> {
|
||||
logger.error("Grok LLM Stream 错误", error);
|
||||
sendJsonMessage(frontendSession, Map.of("event", "error", "message", "LLM 处理失败"));
|
||||
},
|
||||
// onComplete
|
||||
() -> {
|
||||
// 如果还有剩余内容,也发送
|
||||
if (sentenceBuffer.length() > 0) {
|
||||
String lastSentence = sentenceBuffer.toString().trim();
|
||||
if (!lastSentence.isEmpty()) {
|
||||
streamTtsToFrontend(lastSentence, frontendSession);
|
||||
}
|
||||
}
|
||||
|
||||
sendJsonMessage(frontendSession, Map.of("event", "llm_complete"));
|
||||
logger.info("LLM 处理完成");
|
||||
}
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
### 4.5 Grok SSE 解析
|
||||
|
||||
```java
|
||||
private String parseGrokToken(String sseChunk) {
|
||||
// SSE 格式: data: {"choices":[{"delta":{"content":"你好"}}]}\n\n
|
||||
if (!sseChunk.startsWith("data: ")) {
|
||||
return null;
|
||||
}
|
||||
|
||||
String jsonData = sseChunk.substring(6).trim();
|
||||
if (jsonData.equals("[DONE]")) {
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
JsonObject json = JsonParser.parseString(jsonData).getAsJsonObject();
|
||||
JsonArray choices = json.getAsJsonArray("choices");
|
||||
if (choices != null && choices.size() > 0) {
|
||||
JsonObject delta = choices.get(0).getAsJsonObject().getAsJsonObject("delta");
|
||||
if (delta != null && delta.has("content")) {
|
||||
return delta.get("content").getAsString();
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
logger.warn("解析 Grok token 失败: {}", sseChunk);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
```
|
||||
|
||||
### 4.6 分句逻辑
|
||||
|
||||
```java
|
||||
private boolean isSentenceEnd(String text) {
|
||||
// 中英文标点检测
|
||||
String trimmed = text.trim();
|
||||
if (trimmed.isEmpty()) return false;
|
||||
|
||||
char lastChar = trimmed.charAt(trimmed.length() - 1);
|
||||
|
||||
// 中文标点
|
||||
if (lastChar == '。' || lastChar == '!' || lastChar == '?' ||
|
||||
lastChar == ';' || lastChar == ',') {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 英文标点
|
||||
if (lastChar == '.' || lastChar == '!' || lastChar == '?') {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 防止句子过长(超过50字强制分句)
|
||||
if (trimmed.length() > 50) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
```
|
||||
|
||||
### 4.7 MiniMax TTS WebSocket Stream
|
||||
|
||||
```java
|
||||
private void streamTtsToFrontend(String text, Session frontendSession) {
|
||||
try {
|
||||
URI uri = new URI("wss://api.minimax.io/ws/v1/t2a_v2");
|
||||
|
||||
WebSocketClient ttsClient = new WebSocketClient(uri) {
|
||||
private boolean taskStarted = false;
|
||||
|
||||
@Override
|
||||
public void onOpen(ServerHandshake handshake) {
|
||||
logger.info("MiniMax TTS 连接成功");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onMessage(String message) {
|
||||
try {
|
||||
JsonObject msg = JsonParser.parseString(message).getAsJsonObject();
|
||||
String event = msg.get("event").getAsString();
|
||||
|
||||
if ("connected_success".equals(event)) {
|
||||
// 发送 task_start
|
||||
send(buildTaskStartJson());
|
||||
|
||||
} else if ("task_started".equals(event)) {
|
||||
// 发送 task_continue
|
||||
taskStarted = true;
|
||||
send(buildTaskContinueJson(text));
|
||||
|
||||
} else if (msg.has("data")) {
|
||||
JsonObject data = msg.getAsJsonObject("data");
|
||||
if (data.has("audio")) {
|
||||
String hexAudio = data.get("audio").getAsString();
|
||||
|
||||
// 转为字节数组
|
||||
byte[] audioBytes = hexToBytes(hexAudio);
|
||||
|
||||
// 立即转发给前端
|
||||
frontendSession.getAsyncRemote().sendBinary(
|
||||
ByteBuffer.wrap(audioBytes)
|
||||
);
|
||||
|
||||
logger.debug("转发音频数据: {} bytes", audioBytes.length);
|
||||
}
|
||||
|
||||
// 检查是否结束
|
||||
if (msg.has("is_final") && msg.get("is_final").getAsBoolean()) {
|
||||
send("{\"event\":\"task_finish\"}");
|
||||
close();
|
||||
}
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
logger.error("处理 TTS 消息失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Exception ex) {
|
||||
logger.error("MiniMax TTS 错误", ex);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClose(int code, String reason, boolean remote) {
|
||||
logger.info("MiniMax TTS 连接关闭");
|
||||
}
|
||||
};
|
||||
|
||||
// 添加认证头
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
headers.put("Authorization", "Bearer " + minimaxApiKey);
|
||||
ttsClient.setHttpHeaders(headers);
|
||||
|
||||
// 连接(阻塞等待,最多5秒)
|
||||
boolean connected = ttsClient.connectBlocking(5, TimeUnit.SECONDS);
|
||||
if (!connected) {
|
||||
throw new RuntimeException("连接 MiniMax TTS 超时");
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
logger.error("调用 MiniMax TTS 失败", e);
|
||||
sendJsonMessage(frontendSession, Map.of("event", "error", "message", "TTS 处理失败"));
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 4.8 工具方法
|
||||
|
||||
```java
|
||||
// 构建 task_start JSON
|
||||
private String buildTaskStartJson() {
|
||||
return """
|
||||
{
|
||||
"event": "task_start",
|
||||
"model": "speech-2.6-hd",
|
||||
"voice_setting": {
|
||||
"voice_id": "male-qn-qingse",
|
||||
"speed": 1.0,
|
||||
"vol": 1.0,
|
||||
"pitch": 0
|
||||
},
|
||||
"audio_setting": {
|
||||
"sample_rate": 32000,
|
||||
"bitrate": 128000,
|
||||
"format": "mp3",
|
||||
"channel": 1
|
||||
}
|
||||
}
|
||||
""";
|
||||
}
|
||||
|
||||
// 构建 task_continue JSON
|
||||
private String buildTaskContinueJson(String text) {
|
||||
JsonObject json = new JsonObject();
|
||||
json.addProperty("event", "task_continue");
|
||||
json.addProperty("text", text);
|
||||
return json.toString();
|
||||
}
|
||||
|
||||
// Hex 转字节数组
|
||||
private byte[] hexToBytes(String hex) {
|
||||
int len = hex.length();
|
||||
byte[] data = new byte[len / 2];
|
||||
for (int i = 0; i < len; i += 2) {
|
||||
data[i / 2] = (byte) ((Character.digit(hex.charAt(i), 16) << 4)
|
||||
+ Character.digit(hex.charAt(i+1), 16));
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
// 发送 JSON 消息给前端
|
||||
private void sendJsonMessage(Session session, Map<String, Object> data) {
|
||||
try {
|
||||
String json = new Gson().toJson(data);
|
||||
session.getAsyncRemote().sendText(json);
|
||||
} catch (Exception e) {
|
||||
logger.error("发送消息失败", e);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. Maven 依赖配置
|
||||
|
||||
```xml
|
||||
<!-- Spring Boot WebSocket -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-websocket</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- WebClient (用于 Grok SSE) -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-webflux</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- Java WebSocket Client (用于 MiniMax) -->
|
||||
<dependency>
|
||||
<groupId>org.java-websocket</groupId>
|
||||
<artifactId>Java-WebSocket</artifactId>
|
||||
<version>1.5.3</version>
|
||||
</dependency>
|
||||
|
||||
<!-- JSON 处理 -->
|
||||
<dependency>
|
||||
<groupId>com.google.code.gson</groupId>
|
||||
<artifactId>gson</artifactId>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. 关键技术难点与解决方案
|
||||
|
||||
### 6.1 并发控制
|
||||
|
||||
**问题**:多个句子的 TTS 如何保证顺序?
|
||||
|
||||
**解决方案**:
|
||||
- 使用 **信号量 (Semaphore)** 或 **串行队列**
|
||||
- 一个句子的 TTS 完全结束后,才开始下一个句子
|
||||
|
||||
```java
|
||||
private final Semaphore ttsSemaphore = new Semaphore(1);
|
||||
|
||||
private void streamTtsToFrontend(String text, Session frontendSession) {
|
||||
try {
|
||||
ttsSemaphore.acquire(); // 获取锁
|
||||
// ... TTS 处理 ...
|
||||
} finally {
|
||||
ttsSemaphore.release(); // 释放锁
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 6.2 Grok SSE 连接稳定性
|
||||
|
||||
**问题**:SSE 长连接可能中断。
|
||||
|
||||
**解决方案**:
|
||||
- 设置超时和重试机制
|
||||
- 使用 `Flux.retry(3)` 自动重试
|
||||
|
||||
### 6.3 前端音频播放卡顿
|
||||
|
||||
**问题**:网络抖动导致音频断续。
|
||||
|
||||
**解决方案**:
|
||||
- 实现 **Jitter Buffer**(缓冲 3-5 个音频块再开始播放)
|
||||
- 检测队列长度,动态调整播放速率
|
||||
|
||||
### 6.4 音频格式兼容性
|
||||
|
||||
**问题**:小程序对 PCM 格式支持不佳。
|
||||
|
||||
**解决方案**:
|
||||
- MiniMax 配置输出 `mp3` 格式(压缩率高,兼容性好)
|
||||
- 前端直接播放 MP3 无需解码
|
||||
|
||||
---
|
||||
|
||||
## 7. 性能指标预估
|
||||
|
||||
| 阶段 | 预估耗时 | 说明 |
|
||||
|---|---|---|
|
||||
| **前端 VAD** | 实时 | 用户说话期间持续检测 |
|
||||
| **音频上传** | 50-200ms | 取决于网络和文件大小 |
|
||||
| **STT 处理** | 300-800ms | Vosk 本地处理较快 |
|
||||
| **LLM 首 Token** | 500-1000ms | Grok 响应速度 |
|
||||
| **TTS 首块音频** | 200-500ms | MiniMax WebSocket 延迟 |
|
||||
| **首次播放延迟** | **1-2秒** | 用户说完到听到回复 |
|
||||
|
||||
---
|
||||
|
||||
## 8. 测试建议
|
||||
|
||||
### 8.1 单元测试
|
||||
|
||||
- STT 服务独立测试
|
||||
- Grok API 调用测试(Mock SSE 流)
|
||||
- MiniMax TTS 调用测试(Mock WebSocket)
|
||||
|
||||
### 8.2 集成测试
|
||||
|
||||
- 完整语音流测试(录音 -> STT -> LLM -> TTS -> 播放)
|
||||
- 并发测试(多用户同时对话)
|
||||
- 异常场景(网络中断、API 超时)
|
||||
|
||||
### 8.3 压力测试
|
||||
|
||||
- 模拟 100 并发用户
|
||||
- 监控服务器 CPU、内存、网络 I/O
|
||||
- 检查 WebSocket 连接泄漏
|
||||
|
||||
---
|
||||
|
||||
## 9. 优化方向
|
||||
|
||||
### 短期优化
|
||||
1. **缓存 TTS 结果**:相同文本不重复合成
|
||||
2. **自适应分句**:根据网络状况动态调整句子长度
|
||||
3. **优雅降级**:API 失败时使用备用服务
|
||||
|
||||
### 长期优化
|
||||
1. **边缘计算**:STT 迁移到设备端(Whisper.cpp)
|
||||
2. **模型本地化**:部署私有 LLM 和 TTS
|
||||
3. **多模态融合**:支持图片、视频输入
|
||||
|
||||
---
|
||||
|
||||
## 10. 总结
|
||||
|
||||
本方案通过 **WebSocket 全双工通信** 实现了高效的语音流式交互:
|
||||
|
||||
✅ **前端简单**:保留现有 VAD,只需改造上传和播放逻辑
|
||||
✅ **后端高效**:LLM+TTS 流式串联,极低延迟
|
||||
✅ **用户体验**:说完话 1 秒内听到回复,接近真人对话
|
||||
✅ **技术成熟**:Grok、MiniMax 官方支持流式 API
|
||||
|
||||
最终效果:**从"录音-等待-播放"进化为"流式对话"**,用户感知延迟降低 **60-80%**。🚀
|
||||
|
||||
183
docs/import asyncio
Normal file
183
docs/import asyncio
Normal file
@@ -0,0 +1,183 @@
|
||||
import asyncio
|
||||
import websockets
|
||||
import json
|
||||
import ssl
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
model = "speech-2.6-hd"
|
||||
file_format = "mp3"
|
||||
|
||||
class StreamAudioPlayer:
|
||||
def __init__(self):
|
||||
self.mpv_process = None
|
||||
|
||||
def start_mpv(self):
|
||||
"""Start MPV player process"""
|
||||
try:
|
||||
mpv_command = ["mpv", "--no-cache", "--no-terminal", "--", "fd://0"]
|
||||
self.mpv_process = subprocess.Popen(
|
||||
mpv_command,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
print("MPV player started")
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
print("Error: mpv not found. Please install mpv")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Failed to start mpv: {e}")
|
||||
return False
|
||||
|
||||
def play_audio_chunk(self, hex_audio):
|
||||
"""Play audio chunk"""
|
||||
try:
|
||||
if self.mpv_process and self.mpv_process.stdin:
|
||||
audio_bytes = bytes.fromhex(hex_audio)
|
||||
self.mpv_process.stdin.write(audio_bytes)
|
||||
self.mpv_process.stdin.flush()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Play failed: {e}")
|
||||
return False
|
||||
return False
|
||||
|
||||
def stop(self):
|
||||
"""Stop player"""
|
||||
if self.mpv_process:
|
||||
if self.mpv_process.stdin and not self.mpv_process.stdin.closed:
|
||||
self.mpv_process.stdin.close()
|
||||
try:
|
||||
self.mpv_process.wait(timeout=20)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.mpv_process.terminate()
|
||||
|
||||
async def establish_connection(api_key):
|
||||
"""Establish WebSocket connection"""
|
||||
url = "wss://api.minimax.io/ws/v1/t2a_v2"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
try:
|
||||
ws = await websockets.connect(url, additional_headers=headers, ssl=ssl_context)
|
||||
connected = json.loads(await ws.recv())
|
||||
if connected.get("event") == "connected_success":
|
||||
print("Connection successful")
|
||||
return ws
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Connection failed: {e}")
|
||||
return None
|
||||
|
||||
async def start_task(websocket):
|
||||
"""Send task start request"""
|
||||
start_msg = {
|
||||
"event": "task_start",
|
||||
"model": model,
|
||||
"voice_setting": {
|
||||
"voice_id": "male-qn-qingse",
|
||||
"speed": 1,
|
||||
"vol": 1,
|
||||
"pitch": 0,
|
||||
"english_normalization": False
|
||||
},
|
||||
"audio_setting": {
|
||||
"sample_rate": 32000,
|
||||
"bitrate": 128000,
|
||||
"format": file_format,
|
||||
"channel": 1
|
||||
}
|
||||
}
|
||||
await websocket.send(json.dumps(start_msg))
|
||||
response = json.loads(await websocket.recv())
|
||||
return response.get("event") == "task_started"
|
||||
|
||||
async def continue_task_with_stream_play(websocket, text, player):
|
||||
"""Send continue request and stream play audio"""
|
||||
await websocket.send(json.dumps({
|
||||
"event": "task_continue",
|
||||
"text": text
|
||||
}))
|
||||
|
||||
chunk_counter = 1
|
||||
total_audio_size = 0
|
||||
audio_data = b""
|
||||
|
||||
while True:
|
||||
try:
|
||||
response = json.loads(await websocket.recv())
|
||||
|
||||
if "data" in response and "audio" in response["data"]:
|
||||
audio = response["data"]["audio"]
|
||||
if audio:
|
||||
print(f"Playing chunk #{chunk_counter}")
|
||||
audio_bytes = bytes.fromhex(audio)
|
||||
if player.play_audio_chunk(audio):
|
||||
total_audio_size += len(audio_bytes)
|
||||
audio_data += audio_bytes
|
||||
chunk_counter += 1
|
||||
|
||||
if response.get("is_final"):
|
||||
print(f"Audio synthesis completed: {chunk_counter-1} chunks")
|
||||
if player.mpv_process and player.mpv_process.stdin:
|
||||
player.mpv_process.stdin.close()
|
||||
|
||||
# Save audio to file
|
||||
with open(f"output.{file_format}", "wb") as f:
|
||||
f.write(audio_data)
|
||||
print(f"Audio saved as output.{file_format}")
|
||||
|
||||
estimated_duration = total_audio_size * 0.0625 / 1000
|
||||
wait_time = max(estimated_duration + 5, 10)
|
||||
return wait_time
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
break
|
||||
|
||||
return 10
|
||||
|
||||
async def close_connection(websocket):
|
||||
"""Close connection"""
|
||||
if websocket:
|
||||
try:
|
||||
await websocket.send(json.dumps({"event": "task_finish"}))
|
||||
await websocket.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def main():
|
||||
API_KEY = os.getenv("MINIMAX_API_KEY")
|
||||
TEXT = "The real danger is not that computers start thinking like people, but that people start thinking like computers. Computers can only help us with simple tasks."
|
||||
|
||||
player = StreamAudioPlayer()
|
||||
|
||||
try:
|
||||
if not player.start_mpv():
|
||||
return
|
||||
|
||||
ws = await establish_connection(API_KEY)
|
||||
if not ws:
|
||||
return
|
||||
|
||||
if not await start_task(ws):
|
||||
print("Task startup failed")
|
||||
return
|
||||
|
||||
wait_time = await continue_task_with_stream_play(ws, TEXT, player)
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
finally:
|
||||
player.stop()
|
||||
if 'ws' in locals():
|
||||
await close_connection(ws)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
13
pom.xml
13
pom.xml
@@ -41,6 +41,19 @@
|
||||
<artifactId>spring-boot-starter-websocket</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- Spring WebFlux for Grok SSE -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-webflux</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- Java-WebSocket for MiniMax TTS -->
|
||||
<dependency>
|
||||
<groupId>org.java-websocket</groupId>
|
||||
<artifactId>Java-WebSocket</artifactId>
|
||||
<version>1.5.7</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-aop</artifactId>
|
||||
|
||||
@@ -0,0 +1,327 @@
|
||||
package com.xiaozhi.communication.server.websocket;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.xiaozhi.service.VoiceStreamService;
|
||||
import jakarta.annotation.Resource;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.socket.BinaryMessage;
|
||||
import org.springframework.web.socket.CloseStatus;
|
||||
import org.springframework.web.socket.TextMessage;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
import org.springframework.web.socket.handler.AbstractWebSocketHandler;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
/**
|
||||
* 语音流式对话 WebSocket 处理器
|
||||
*/
|
||||
@Component
|
||||
public class VoiceStreamHandler extends AbstractWebSocketHandler {
|
||||
private static final Logger logger = LoggerFactory.getLogger(VoiceStreamHandler.class);
|
||||
private static final Gson gson = new Gson();
|
||||
|
||||
// 保存所有活跃的会话
|
||||
private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<>();
|
||||
|
||||
@Resource
|
||||
private VoiceStreamService voiceStreamService;
|
||||
|
||||
@Override
|
||||
public void afterConnectionEstablished(WebSocketSession session) {
|
||||
String sessionId = session.getId();
|
||||
|
||||
// 从请求头或查询参数获取用户认证信息
|
||||
Map<String, String> params = getParamsFromSession(session);
|
||||
String token = params.get("token");
|
||||
String userId = params.get("userId");
|
||||
|
||||
// 获取聊天会话相关参数(sessionId 用于历史记录查询和保存)
|
||||
String chatSessionId = params.get("sessionId");
|
||||
String templateId = params.get("templateId");
|
||||
|
||||
logger.info("语音流WebSocket连接建立 - SessionId: {}, UserId: {}, ChatSessionId: {}, TemplateId: {}",
|
||||
sessionId, userId, chatSessionId, templateId);
|
||||
|
||||
// 保存会话和相关参数到session attributes
|
||||
sessions.put(sessionId, session);
|
||||
if (chatSessionId != null) {
|
||||
session.getAttributes().put("chatSessionId", chatSessionId);
|
||||
}
|
||||
if (templateId != null) {
|
||||
session.getAttributes().put("templateId", Integer.parseInt(templateId));
|
||||
}
|
||||
|
||||
// 发送连接成功消息
|
||||
sendTextMessage(session, createMessage("connected", "连接成功", Map.of("sessionId", sessionId)));
|
||||
|
||||
// 预热TTS连接
|
||||
voiceStreamService.warmupTtsConnection(sessionId)
|
||||
.thenAccept(v -> {
|
||||
logger.info("TTS连接预热成功 - SessionId: {}", sessionId);
|
||||
})
|
||||
.exceptionally(ex -> {
|
||||
logger.error("TTS连接预热失败 - SessionId: {}", sessionId, ex);
|
||||
// 预热失败不影响主流程,降级为按需创建
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message) {
|
||||
String sessionId = session.getId();
|
||||
byte[] audioData = message.getPayload().array();
|
||||
|
||||
logger.debug("收到音频数据 - SessionId: {}, Size: {} bytes", sessionId, audioData.length);
|
||||
|
||||
// 从session attributes获取聊天会话相关参数
|
||||
String chatSessionId = (String) session.getAttributes().get("chatSessionId");
|
||||
Integer templateId = (Integer) session.getAttributes().get("templateId");
|
||||
|
||||
try {
|
||||
// 判断是否使用带历史记录的处理方法
|
||||
if (chatSessionId != null) {
|
||||
// 调用带历史记录的语音流服务
|
||||
voiceStreamService.processAudioStreamWithHistory(
|
||||
sessionId, audioData, chatSessionId, templateId,
|
||||
new VoiceStreamService.StreamCallback() {
|
||||
@Override
|
||||
public void onSttResult(String text) {
|
||||
// STT识别结果
|
||||
sendTextMessage(session, createMessage("stt_result", text, null));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onLlmToken(String token) {
|
||||
// LLM输出token(可选发送给前端显示)
|
||||
sendTextMessage(session, createMessage("llm_token", token, null));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSentenceComplete(String sentence) {
|
||||
// 完整句子(发送给前端显示)
|
||||
sendTextMessage(session, createMessage("sentence", sentence, null));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onAudioChunk(byte[] audioChunk) {
|
||||
// TTS音频数据
|
||||
sendBinaryMessage(session, audioChunk);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
// 所有处理完成
|
||||
sendTextMessage(session, createMessage("complete", "对话完成", null));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(String error) {
|
||||
// 错误处理
|
||||
sendTextMessage(session, createMessage("error", error, null));
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// 调用原有的不带历史记录的方法(向后兼容)
|
||||
voiceStreamService.processAudioStream(sessionId, audioData, new VoiceStreamService.StreamCallback() {
|
||||
@Override
|
||||
public void onSttResult(String text) {
|
||||
// STT识别结果
|
||||
sendTextMessage(session, createMessage("stt_result", text, null));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onLlmToken(String token) {
|
||||
// LLM输出token(可选发送给前端显示)
|
||||
sendTextMessage(session, createMessage("llm_token", token, null));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSentenceComplete(String sentence) {
|
||||
// 完整句子(发送给前端显示)
|
||||
sendTextMessage(session, createMessage("sentence", sentence, null));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onAudioChunk(byte[] audioChunk) {
|
||||
// TTS音频数据
|
||||
sendBinaryMessage(session, audioChunk);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
// 所有处理完成
|
||||
sendTextMessage(session, createMessage("complete", "对话完成", null));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(String error) {
|
||||
// 错误处理
|
||||
sendTextMessage(session, createMessage("error", error, null));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
logger.error("处理音频流失败 - SessionId: {}", sessionId, e);
|
||||
sendTextMessage(session, createMessage("error", "处理音频失败: " + e.getMessage(), null));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void handleTextMessage(WebSocketSession session, TextMessage message) {
|
||||
String sessionId = session.getId();
|
||||
String payload = message.getPayload();
|
||||
|
||||
logger.debug("收到文本消息 - SessionId: {}, Message: {}", sessionId, payload);
|
||||
|
||||
try {
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> msgMap = gson.fromJson(payload, Map.class);
|
||||
String type = (String) msgMap.get("type");
|
||||
|
||||
if ("ping".equals(type)) {
|
||||
// 心跳响应
|
||||
sendTextMessage(session, createMessage("pong", "pong", null));
|
||||
} else if ("cancel".equals(type)) {
|
||||
// 取消当前对话(打断)
|
||||
voiceStreamService.cancelStream(sessionId);
|
||||
sendTextMessage(session, createMessage("cancelled", "已取消", null));
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
logger.error("处理文本消息失败 - SessionId: {}", sessionId, e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
|
||||
String sessionId = session.getId();
|
||||
sessions.remove(sessionId);
|
||||
|
||||
// 取消该会话的所有处理
|
||||
voiceStreamService.cancelStream(sessionId);
|
||||
|
||||
// 清理TTS连接
|
||||
voiceStreamService.closeTtsConnection(sessionId);
|
||||
|
||||
logger.info("语音流WebSocket连接关闭 - SessionId: {}, Status: {}", sessionId, status);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleTransportError(WebSocketSession session, Throwable exception) {
|
||||
String sessionId = session.getId();
|
||||
|
||||
if (isClientCloseRequest(exception)) {
|
||||
logger.info("WebSocket连接被客户端主动关闭 - SessionId: {}", sessionId);
|
||||
} else {
|
||||
logger.error("WebSocket传输错误 - SessionId: {}", sessionId, exception);
|
||||
}
|
||||
|
||||
// 清理会话
|
||||
sessions.remove(sessionId);
|
||||
voiceStreamService.cancelStream(sessionId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 判断异常是否由客户端主动关闭连接导致
|
||||
*/
|
||||
private boolean isClientCloseRequest(Throwable exception) {
|
||||
if (exception instanceof IOException) {
|
||||
String message = exception.getMessage();
|
||||
if (message != null) {
|
||||
return message.contains("Connection reset by peer") ||
|
||||
message.contains("Broken pipe") ||
|
||||
message.contains("Connection closed") ||
|
||||
message.contains("远程主机强迫关闭了一个现有的连接");
|
||||
}
|
||||
return exception instanceof java.io.EOFException;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* 从会话中获取参数(从query string或header)
|
||||
*/
|
||||
private Map<String, String> getParamsFromSession(WebSocketSession session) {
|
||||
Map<String, String> params = new HashMap<>();
|
||||
|
||||
// 从header获取
|
||||
String token = session.getHandshakeHeaders().getFirst("Authorization");
|
||||
if (token != null) {
|
||||
params.put("token", token.replace("Bearer ", ""));
|
||||
}
|
||||
|
||||
String userId = session.getHandshakeHeaders().getFirst("User-Id");
|
||||
if (userId != null) {
|
||||
params.put("userId", userId);
|
||||
}
|
||||
|
||||
// 从URI query参数获取
|
||||
URI uri = session.getUri();
|
||||
if (uri != null && uri.getQuery() != null) {
|
||||
String query = uri.getQuery();
|
||||
for (String param : query.split("&")) {
|
||||
String[] kv = param.split("=", 2);
|
||||
if (kv.length == 2) {
|
||||
params.put(kv[0], kv[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建JSON格式的消息
|
||||
*/
|
||||
private String createMessage(String type, String message, Map<String, Object> data) {
|
||||
Map<String, Object> msg = new HashMap<>();
|
||||
msg.put("type", type);
|
||||
msg.put("message", message);
|
||||
msg.put("timestamp", System.currentTimeMillis());
|
||||
if (data != null) {
|
||||
msg.put("data", data);
|
||||
}
|
||||
return gson.toJson(msg);
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送文本消息
|
||||
*/
|
||||
private void sendTextMessage(WebSocketSession session, String message) {
|
||||
try {
|
||||
if (session.isOpen()) {
|
||||
session.sendMessage(new TextMessage(message));
|
||||
}
|
||||
} catch (Exception e) {
|
||||
logger.error("发送文本消息失败 - SessionId: {}", session.getId(), e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送二进制消息
|
||||
*/
|
||||
private void sendBinaryMessage(WebSocketSession session, byte[] data) {
|
||||
try {
|
||||
if (session.isOpen()) {
|
||||
session.sendMessage(new BinaryMessage(data));
|
||||
}
|
||||
} catch (Exception e) {
|
||||
logger.error("发送二进制消息失败 - SessionId: {}", session.getId(), e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取指定会话
|
||||
*/
|
||||
public WebSocketSession getSession(String sessionId) {
|
||||
return sessions.get(sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,18 +21,27 @@ public class WebSocketConfig implements WebSocketConfigurer {
|
||||
|
||||
// 定义为public static以便其他类可以访问
|
||||
public static final String WS_PATH = "/ws/xiaozhi/v1/";
|
||||
public static final String VOICE_STREAM_PATH = "/ws/voice-stream";
|
||||
|
||||
@Resource
|
||||
private WebSocketHandler webSocketHandler;
|
||||
|
||||
@Resource
|
||||
private VoiceStreamHandler voiceStreamHandler;
|
||||
|
||||
@Resource
|
||||
private CmsUtils cmsUtils;
|
||||
|
||||
@Override
|
||||
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
|
||||
// 设备WebSocket端点
|
||||
registry.addHandler(webSocketHandler, WS_PATH)
|
||||
.setAllowedOrigins("*");
|
||||
|
||||
// 语音流式对话WebSocket端点
|
||||
registry.addHandler(voiceStreamHandler, VOICE_STREAM_PATH)
|
||||
.setAllowedOrigins("*");
|
||||
|
||||
logger.info("📡 WebSocket服务地址: {}", cmsUtils.getWebsocketAddress());
|
||||
logger.info("==========================================================");
|
||||
}
|
||||
|
||||
89
src/main/java/com/xiaozhi/config/VoiceStreamConfig.java
Normal file
89
src/main/java/com/xiaozhi/config/VoiceStreamConfig.java
Normal file
@@ -0,0 +1,89 @@
|
||||
package com.xiaozhi.config;
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
/**
|
||||
* 语音流式对话配置类
|
||||
*/
|
||||
@Configuration
|
||||
@ConfigurationProperties(prefix = "xiaozhi.voice-stream")
|
||||
@Data
|
||||
public class VoiceStreamConfig {
|
||||
|
||||
private GrokConfig grok = new GrokConfig();
|
||||
private MinimaxConfig minimax = new MinimaxConfig();
|
||||
|
||||
@Data
|
||||
public static class GrokConfig {
|
||||
/**
|
||||
* Grok API Key
|
||||
*/
|
||||
private String apiKey;
|
||||
|
||||
/**
|
||||
* Grok API URL
|
||||
*/
|
||||
private String apiUrl;
|
||||
|
||||
/**
|
||||
* Grok 模型名称
|
||||
*/
|
||||
private String model;
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class MinimaxConfig {
|
||||
/**
|
||||
* MiniMax API Key
|
||||
*/
|
||||
private String apiKey;
|
||||
|
||||
/**
|
||||
* MiniMax Group ID
|
||||
*/
|
||||
private String groupId;
|
||||
|
||||
/**
|
||||
* MiniMax WebSocket URL
|
||||
*/
|
||||
private String wsUrl;
|
||||
|
||||
/**
|
||||
* MiniMax TTS 模型名称
|
||||
*/
|
||||
private String model;
|
||||
|
||||
/**
|
||||
* 音色ID
|
||||
*/
|
||||
private String voiceId;
|
||||
|
||||
/**
|
||||
* 语速 (0.5-2.0)
|
||||
*/
|
||||
private Double speed;
|
||||
|
||||
/**
|
||||
* 音量 (0.1-10.0)
|
||||
*/
|
||||
private Double vol;
|
||||
|
||||
/**
|
||||
* 音调 (-12到12)
|
||||
*/
|
||||
private Integer pitch;
|
||||
|
||||
/**
|
||||
* 音频采样率
|
||||
*/
|
||||
private Integer audioSampleRate;
|
||||
|
||||
/**
|
||||
* 比特率
|
||||
*/
|
||||
private Integer bitrate;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package com.xiaozhi.controller;
|
||||
|
||||
import com.xiaozhi.dialogue.stt.SttService;
|
||||
import com.xiaozhi.dto.ChatRequest;
|
||||
import com.xiaozhi.dto.TtsRequest;
|
||||
import com.xiaozhi.dto.VoiceChatRequest;
|
||||
import com.xiaozhi.dto.AudioUploadRequest;
|
||||
import com.xiaozhi.dto.ChatResponse;
|
||||
import com.xiaozhi.dto.ApiResponse;
|
||||
import com.xiaozhi.entity.SysConfig;
|
||||
import com.xiaozhi.service.ChatSessionService;
|
||||
import com.xiaozhi.service.TtsIntegrationService;
|
||||
import com.xiaozhi.config.ChatConstants;
|
||||
@@ -182,33 +184,6 @@ public class ChatController {
|
||||
ChatConstants.ERROR_INTERNAL + ": " + e.getMessage()));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 语音对话接口
|
||||
* @param request 语音对话请求
|
||||
* @return 包含识别文本、回复文本和音频Base64的响应
|
||||
*/
|
||||
@Operation(summary = "语音对话", description = "接收音频数据进行语音识别、对话生成和语音合成的完整流程。必填参数:audioData;可选参数:sessionId、useFunctionCall、modelId、templateId、sttConfigId、ttsConfigId")
|
||||
@PostMapping("/voice-chat")
|
||||
public ResponseEntity<ApiResponse<Map<String, Object>>> voiceChat(
|
||||
@Parameter(description = "语音对话请求参数", required = true,
|
||||
content = @Content(examples = @ExampleObject(value = "{\"audioData\":\"UklGRiQAAABXQVZFZm10IBAAAAABAAEA...\",\"sessionId\":\"session123\",\"useFunctionCall\":false,\"modelId\":7,\"templateId\":1,\"sttConfigId\":9,\"ttsConfigId\":8}")))
|
||||
@RequestBody VoiceChatRequest request) {
|
||||
try {
|
||||
logger.info("收到语音对话请求");
|
||||
|
||||
Map<String, Object> result = chatSessionService.voiceChat(request);
|
||||
|
||||
logger.info("语音对话响应成功");
|
||||
return ResponseEntity.ok(ApiResponse.success(ChatConstants.VOICE_CHAT_SUCCESS_MESSAGE, result));
|
||||
} catch (Exception e) {
|
||||
logger.error("语音对话处理失败", e);
|
||||
return ResponseEntity.internalServerError()
|
||||
.body(ApiResponse.error(ChatConstants.INTERNAL_ERROR_CODE,
|
||||
ChatConstants.ERROR_INTERNAL + ": " + e.getMessage()));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 语音对话接口(multipart/form-data 版本)
|
||||
* 为兼容客户端以表单方式上传音频数据,支持多个常见字段名。
|
||||
@@ -273,6 +248,10 @@ public class ChatController {
|
||||
}
|
||||
}
|
||||
|
||||
@Resource
|
||||
private com.xiaozhi.dialogue.stt.factory.SttServiceFactory sttServiceFactory;
|
||||
|
||||
|
||||
/**
|
||||
* 清除会话缓存
|
||||
* @param sessionId 会话ID
|
||||
|
||||
336
src/main/java/com/xiaozhi/dialogue/llm/GrokStreamService.java
Normal file
336
src/main/java/com/xiaozhi/dialogue/llm/GrokStreamService.java
Normal file
@@ -0,0 +1,336 @@
|
||||
package com.xiaozhi.dialogue.llm;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.JsonObject;
|
||||
import com.google.gson.JsonParser;
|
||||
import com.xiaozhi.config.VoiceStreamConfig;
|
||||
import jakarta.annotation.Resource;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Grok LLM 流式服务
|
||||
*
|
||||
* 支持通过 Server-Sent Events (SSE) 方式流式调用 Grok API
|
||||
*
|
||||
* API文档: https://docs.x.ai/api
|
||||
*
|
||||
* 响应格式:
|
||||
* data: {"id":"xxx","object":"chat.completion.chunk","created":xxx,
|
||||
* "model":"grok-4-1-fast-non-reasoning",
|
||||
* "choices":[{"index":0,"delta":{"content":"token","role":"assistant"}}],
|
||||
* "system_fingerprint":"xxx"}
|
||||
* data: [DONE]
|
||||
*/
|
||||
@Service
|
||||
public class GrokStreamService {
|
||||
private static final Logger logger = LoggerFactory.getLogger(GrokStreamService.class);
|
||||
private static final Gson gson = new Gson();
|
||||
|
||||
// SSE数据前缀
|
||||
private static final String SSE_DATA_PREFIX = "data: ";
|
||||
// SSE流结束标记
|
||||
private static final String SSE_DONE_FLAG = "[DONE]";
|
||||
// 默认温度参数
|
||||
private static final double DEFAULT_TEMPERATURE = 0.7;
|
||||
// 默认最大token数(如果经常被截断,可以增加这个值)
|
||||
private static final int DEFAULT_MAX_TOKENS = 4000;
|
||||
|
||||
@Resource
|
||||
private VoiceStreamConfig voiceStreamConfig;
|
||||
|
||||
private WebClient webClient;
|
||||
|
||||
/**
|
||||
* 流式调用 Grok API
|
||||
*
|
||||
* @param userMessage 用户消息
|
||||
* @param systemPrompt 系统提示词(可选)
|
||||
* @param callback 接收token的回调
|
||||
*/
|
||||
public void streamChat(String userMessage, String systemPrompt, TokenCallback callback) {
|
||||
try {
|
||||
// 参数验证
|
||||
if (userMessage == null || userMessage.trim().isEmpty()) {
|
||||
logger.warn("用户消息为空,跳过API调用");
|
||||
callback.onError("用户消息不能为空");
|
||||
return;
|
||||
}
|
||||
|
||||
if (webClient == null) {
|
||||
initWebClient();
|
||||
}
|
||||
|
||||
// 构建消息列表
|
||||
List<Map<String, String>> messages = new ArrayList<>();
|
||||
if (systemPrompt != null && !systemPrompt.isEmpty()) {
|
||||
messages.add(Map.of("role", "system", "content", systemPrompt));
|
||||
}
|
||||
messages.add(Map.of("role", "user", "content", userMessage));
|
||||
|
||||
// 构建请求体
|
||||
Map<String, Object> requestBody = new HashMap<>();
|
||||
requestBody.put("model", voiceStreamConfig.getGrok().getModel());
|
||||
requestBody.put("messages", messages);
|
||||
requestBody.put("stream", true);
|
||||
requestBody.put("temperature", DEFAULT_TEMPERATURE);
|
||||
requestBody.put("max_tokens", DEFAULT_MAX_TOKENS); // 限制最大token数
|
||||
|
||||
String requestJson = gson.toJson(requestBody);
|
||||
logger.info("开始调用Grok API - Model: {}, UserMessage: {}",
|
||||
voiceStreamConfig.getGrok().getModel(),
|
||||
userMessage.length() > 50 ? userMessage.substring(0, 50) + "..." : userMessage);
|
||||
logger.debug("请求体: {}", requestJson);
|
||||
|
||||
// 发起流式请求
|
||||
webClient.post()
|
||||
.uri(voiceStreamConfig.getGrok().getApiUrl())
|
||||
.header("Authorization", "Bearer " + voiceStreamConfig.getGrok().getApiKey())
|
||||
.header("Content-Type", "application/json")
|
||||
.bodyValue(requestJson)
|
||||
.retrieve()
|
||||
.bodyToFlux(String.class)
|
||||
.flatMap(this::parseSSE)
|
||||
.doOnNext(token -> {
|
||||
if (token != null && !token.isEmpty()) {
|
||||
callback.onToken(token);
|
||||
}
|
||||
})
|
||||
.doOnComplete(() -> {
|
||||
logger.info("Grok API流式调用完成");
|
||||
callback.onComplete();
|
||||
})
|
||||
.doOnError(error -> {
|
||||
String errorMsg = "LLM调用失败: " + error.getMessage();
|
||||
logger.error(errorMsg, error);
|
||||
callback.onError(errorMsg);
|
||||
})
|
||||
.subscribe();
|
||||
|
||||
} catch (Exception e) {
|
||||
String errorMsg = "调用Grok API时发生异常: " + e.getMessage();
|
||||
logger.error(errorMsg, e);
|
||||
callback.onError(errorMsg);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 流式调用 Grok API(带历史记录)
|
||||
*
|
||||
* @param messages 完整的消息列表(包含系统提示、历史对话和当前用户消息)
|
||||
* @param callback 接收token的回调
|
||||
*/
|
||||
public void streamChatWithHistory(List<Map<String, String>> messages, TokenCallback callback) {
|
||||
try {
|
||||
// 参数验证
|
||||
if (messages == null || messages.isEmpty()) {
|
||||
logger.warn("消息列表为空,跳过API调用");
|
||||
callback.onError("消息列表不能为空");
|
||||
return;
|
||||
}
|
||||
|
||||
if (webClient == null) {
|
||||
initWebClient();
|
||||
}
|
||||
|
||||
// 构建请求体
|
||||
Map<String, Object> requestBody = new HashMap<>();
|
||||
requestBody.put("model", voiceStreamConfig.getGrok().getModel());
|
||||
requestBody.put("messages", messages);
|
||||
requestBody.put("stream", true);
|
||||
requestBody.put("temperature", DEFAULT_TEMPERATURE);
|
||||
requestBody.put("max_tokens", DEFAULT_MAX_TOKENS); // 限制最大token数
|
||||
|
||||
String requestJson = gson.toJson(requestBody);
|
||||
logger.info("开始调用Grok API(带历史记录) - Model: {}, 消息数量: {}",
|
||||
voiceStreamConfig.getGrok().getModel(), messages.size());
|
||||
logger.debug("请求体: {}", requestJson);
|
||||
|
||||
// 发起流式请求
|
||||
webClient.post()
|
||||
.uri(voiceStreamConfig.getGrok().getApiUrl())
|
||||
.header("Authorization", "Bearer " + voiceStreamConfig.getGrok().getApiKey())
|
||||
.header("Content-Type", "application/json")
|
||||
.bodyValue(requestJson)
|
||||
.retrieve()
|
||||
.bodyToFlux(String.class)
|
||||
.flatMap(this::parseSSE)
|
||||
.doOnNext(token -> {
|
||||
if (token != null && !token.isEmpty()) {
|
||||
callback.onToken(token);
|
||||
}
|
||||
})
|
||||
.doOnComplete(() -> {
|
||||
logger.info("Grok API流式调用完成(带历史记录)");
|
||||
callback.onComplete();
|
||||
})
|
||||
.doOnError(error -> {
|
||||
String errorMsg = "LLM调用失败: " + error.getMessage();
|
||||
logger.error(errorMsg, error);
|
||||
callback.onError(errorMsg);
|
||||
})
|
||||
.subscribe();
|
||||
|
||||
} catch (Exception e) {
|
||||
String errorMsg = "调用Grok API时发生异常: " + e.getMessage();
|
||||
logger.error(errorMsg, e);
|
||||
callback.onError(errorMsg);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析SSE格式的响应
|
||||
*
|
||||
* 数据格式示例(WebClient已自动去除 "data: " 前缀):
|
||||
* {"id":"xxx","object":"chat.completion.chunk","created":xxx,"model":"grok-4-1-fast-non-reasoning",
|
||||
* "choices":[{"index":0,"delta":{"content":"你","role":"assistant"}}],"system_fingerprint":"xxx"}
|
||||
*/
|
||||
private Flux<String> parseSSE(String line) {
|
||||
return Flux.create(sink -> {
|
||||
try {
|
||||
// 跳过空行
|
||||
if (line == null || line.trim().isEmpty()) {
|
||||
sink.complete();
|
||||
return;
|
||||
}
|
||||
|
||||
String data = line.trim();
|
||||
|
||||
// 记录原始数据(仅在trace级别)
|
||||
logger.trace("收到原始SSE数据: {}", data.length() > 200 ? data.substring(0, 200) + "..." : data);
|
||||
|
||||
// 跳过 [DONE] 标记(流结束标志)
|
||||
// 可能是 "data: [DONE]" 或直接 "[DONE]"
|
||||
if (SSE_DONE_FLAG.equals(data) || data.equals(SSE_DATA_PREFIX + SSE_DONE_FLAG)) {
|
||||
logger.debug("收到流结束标记");
|
||||
sink.complete();
|
||||
return;
|
||||
}
|
||||
|
||||
// 如果有 "data: " 前缀,去掉它
|
||||
if (data.startsWith(SSE_DATA_PREFIX)) {
|
||||
data = data.substring(SSE_DATA_PREFIX.length()).trim();
|
||||
}
|
||||
|
||||
// 再次检查是否是 [DONE]
|
||||
if (SSE_DONE_FLAG.equals(data)) {
|
||||
logger.debug("收到流结束标记");
|
||||
sink.complete();
|
||||
return;
|
||||
}
|
||||
|
||||
// 解析JSON(现在data直接就是JSON字符串)
|
||||
JsonObject json = JsonParser.parseString(data).getAsJsonObject();
|
||||
|
||||
// 验证基本字段
|
||||
if (!json.has("choices")) {
|
||||
logger.debug("JSON中缺少choices字段");
|
||||
sink.complete();
|
||||
return;
|
||||
}
|
||||
|
||||
// 获取choices数组
|
||||
var choices = json.getAsJsonArray("choices");
|
||||
if (choices.isEmpty()) {
|
||||
logger.debug("choices数组为空");
|
||||
sink.complete();
|
||||
return;
|
||||
}
|
||||
|
||||
// 获取第一个choice
|
||||
var choice = choices.get(0).getAsJsonObject();
|
||||
|
||||
// 检查是否有delta字段
|
||||
if (!choice.has("delta")) {
|
||||
logger.debug("choice中缺少delta字段");
|
||||
sink.complete();
|
||||
return;
|
||||
}
|
||||
|
||||
var delta = choice.getAsJsonObject("delta");
|
||||
|
||||
// 提取content内容
|
||||
if (delta.has("content")) {
|
||||
String content = delta.get("content").getAsString();
|
||||
|
||||
// 只有非空内容才发送
|
||||
if (content != null && !content.isEmpty()) {
|
||||
logger.trace("提取到token: {}", content);
|
||||
sink.next(content);
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有role字段,可以记录日志(首次消息会包含role)
|
||||
if (delta.has("role")) {
|
||||
String role = delta.get("role").getAsString();
|
||||
logger.debug("收到角色信息: {}", role);
|
||||
}
|
||||
|
||||
// 检查finish_reason(流结束原因)
|
||||
if (choice.has("finish_reason") && !choice.get("finish_reason").isJsonNull()) {
|
||||
String finishReason = choice.get("finish_reason").getAsString();
|
||||
logger.info("流结束原因: {}", finishReason);
|
||||
// finish_reason 可能的值: "stop"(正常结束), "length"(达到最大长度), "content_filter"(内容过滤)
|
||||
}
|
||||
|
||||
sink.complete();
|
||||
|
||||
} catch (com.google.gson.JsonSyntaxException e) {
|
||||
logger.error("JSON解析失败 - 原始数据: {}", line, e);
|
||||
sink.complete();
|
||||
} catch (Exception e) {
|
||||
logger.error("解析SSE数据时发生未知错误 - 原始数据: {}", line, e);
|
||||
sink.complete();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 初始化WebClient
|
||||
*/
|
||||
private void initWebClient() {
|
||||
String apiUrl = voiceStreamConfig.getGrok().getApiUrl();
|
||||
if (apiUrl == null || apiUrl.trim().isEmpty()) {
|
||||
throw new IllegalStateException("Grok API URL未配置");
|
||||
}
|
||||
|
||||
logger.info("初始化Grok WebClient - URL: {}", apiUrl);
|
||||
|
||||
this.webClient = WebClient.builder()
|
||||
.baseUrl(apiUrl)
|
||||
// 增加缓冲区大小,支持更大的响应
|
||||
.codecs(configurer -> configurer.defaultCodecs().maxInMemorySize(10 * 1024 * 1024)) // 10MB
|
||||
.build();
|
||||
|
||||
logger.info("Grok WebClient初始化完成");
|
||||
}
|
||||
|
||||
/**
|
||||
* Token回调接口
|
||||
*/
|
||||
public interface TokenCallback {
|
||||
/**
|
||||
* 接收到新的token
|
||||
*/
|
||||
void onToken(String token);
|
||||
|
||||
/**
|
||||
* 流式输出完成
|
||||
*/
|
||||
void onComplete();
|
||||
|
||||
/**
|
||||
* 发生错误
|
||||
*/
|
||||
void onError(String error);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
package com.xiaozhi.dialogue.llm;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
/**
|
||||
* 分句缓冲服务
|
||||
* 将LLM流式输出的token缓冲并按句子分割
|
||||
*/
|
||||
@Service
|
||||
public class SentenceBufferService {
|
||||
private static final Logger logger = LoggerFactory.getLogger(SentenceBufferService.class);
|
||||
|
||||
// 中英文句子结束标点符号
|
||||
private static final Pattern SENTENCE_END_PATTERN = Pattern.compile("[。!?!?;;]");
|
||||
|
||||
// 强制分句的最大长度
|
||||
private static final int MAX_SENTENCE_LENGTH = 50;
|
||||
|
||||
/**
|
||||
* 创建一个新的句子缓冲器
|
||||
*/
|
||||
public SentenceBuffer createBuffer(SentenceCallback callback) {
|
||||
return new SentenceBuffer(callback);
|
||||
}
|
||||
|
||||
/**
|
||||
* 句子缓冲器
|
||||
*/
|
||||
public class SentenceBuffer {
|
||||
private final StringBuilder buffer = new StringBuilder();
|
||||
private final SentenceCallback callback;
|
||||
|
||||
public SentenceBuffer(SentenceCallback callback) {
|
||||
this.callback = callback;
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加token到缓冲区
|
||||
*/
|
||||
public synchronized void addToken(String token) {
|
||||
buffer.append(token);
|
||||
|
||||
// 检查是否有完整的句子
|
||||
String currentBuffer = buffer.toString();
|
||||
int lastSentenceEnd = findLastSentenceEnd(currentBuffer);
|
||||
|
||||
if (lastSentenceEnd > 0) {
|
||||
// 找到句子结束符,提取完整句子
|
||||
String sentence = currentBuffer.substring(0, lastSentenceEnd + 1).trim();
|
||||
if (!sentence.isEmpty()) {
|
||||
logger.debug("检测到完整句子: {}", sentence);
|
||||
callback.onSentence(sentence);
|
||||
}
|
||||
|
||||
// 保留剩余部分
|
||||
buffer.setLength(0);
|
||||
if (lastSentenceEnd + 1 < currentBuffer.length()) {
|
||||
buffer.append(currentBuffer.substring(lastSentenceEnd + 1));
|
||||
}
|
||||
} else if (buffer.length() > MAX_SENTENCE_LENGTH) {
|
||||
// 超过最大长度,强制分句
|
||||
String sentence = buffer.toString().trim();
|
||||
if (!sentence.isEmpty()) {
|
||||
logger.debug("强制分句(长度超限): {}", sentence);
|
||||
callback.onSentence(sentence);
|
||||
}
|
||||
buffer.setLength(0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 完成输入,输出剩余内容
|
||||
*/
|
||||
public synchronized void finish() {
|
||||
if (buffer.length() > 0) {
|
||||
String sentence = buffer.toString().trim();
|
||||
if (!sentence.isEmpty()) {
|
||||
logger.debug("输出最后剩余内容: {}", sentence);
|
||||
callback.onSentence(sentence);
|
||||
}
|
||||
buffer.setLength(0);
|
||||
}
|
||||
callback.onComplete();
|
||||
}
|
||||
|
||||
/**
|
||||
* 清空缓冲区
|
||||
*/
|
||||
public synchronized void clear() {
|
||||
buffer.setLength(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* 查找最后一个句子结束符的位置
|
||||
*/
|
||||
private int findLastSentenceEnd(String text) {
|
||||
int lastPos = -1;
|
||||
for (int i = text.length() - 1; i >= 0; i--) {
|
||||
char ch = text.charAt(i);
|
||||
if (isSentenceEndChar(ch)) {
|
||||
lastPos = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return lastPos;
|
||||
}
|
||||
|
||||
/**
|
||||
* 判断是否为句子结束符
|
||||
*/
|
||||
private boolean isSentenceEndChar(char ch) {
|
||||
return ch == '。' || ch == '!' || ch == '?' ||
|
||||
ch == '.' || ch == '!' || ch == '?' ||
|
||||
ch == ';' || ch == ';';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 句子回调接口
|
||||
*/
|
||||
public interface SentenceCallback {
|
||||
/**
|
||||
* 检测到完整句子
|
||||
*/
|
||||
void onSentence(String sentence);
|
||||
|
||||
/**
|
||||
* 所有句子处理完成
|
||||
*/
|
||||
void onComplete();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ public class SttServiceFactory {
|
||||
private final Map<String, SttService> serviceCache = new ConcurrentHashMap<>();
|
||||
|
||||
// 默认服务提供商名称
|
||||
private static final String DEFAULT_PROVIDER = "vosk";
|
||||
private static final String DEFAULT_PROVIDER = "aliyun";
|
||||
|
||||
// 标记Vosk是否初始化成功
|
||||
private boolean voskInitialized = false;
|
||||
|
||||
@@ -40,7 +40,7 @@ public class AliyunSttService implements SttService {
|
||||
// private final String modelName;
|
||||
|
||||
public AliyunSttService(SysConfig config) {
|
||||
this.apiKey = config.getApiKey();
|
||||
this.apiKey = "sk-f4c91752a2604845b2956265863f941d";
|
||||
// // 从配置中获取模型名称,如果没有配置则使用默认值
|
||||
// this.modelName = (config.getConfigName() != null && !config.getConfigName().trim().isEmpty())
|
||||
// ? config.getConfigName().trim()
|
||||
|
||||
@@ -0,0 +1,548 @@
|
||||
package com.xiaozhi.dialogue.tts;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.xiaozhi.config.VoiceStreamConfig;
|
||||
import com.xiaozhi.dialogue.tts.TtsConnectionState.ConnectionStatus;
|
||||
import jakarta.annotation.Resource;
|
||||
import org.java_websocket.client.WebSocketClient;
|
||||
import org.java_websocket.handshake.ServerHandshake;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.*;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
/**
|
||||
* MiniMax TTS 流式服务
|
||||
* 基于WebSocket实现流式文本转语音,支持连接预热和复用
|
||||
*/
|
||||
@Service
|
||||
public class MinimaxTtsStreamService {
|
||||
private static final Logger logger = LoggerFactory.getLogger(MinimaxTtsStreamService.class);
|
||||
private static final Gson gson = new Gson();
|
||||
private static final int CONNECTION_TIMEOUT_SECONDS = 5;
|
||||
|
||||
@Resource
|
||||
private VoiceStreamConfig voiceStreamConfig;
|
||||
|
||||
// 管理活跃的WebSocket连接和状态
|
||||
private final Map<String, SessionTtsConnection> connections = new ConcurrentHashMap<>();
|
||||
|
||||
/**
|
||||
* 会话TTS连接封装
|
||||
*/
|
||||
private static class SessionTtsConnection {
|
||||
final MinimaxTtsClient client;
|
||||
final TtsConnectionState state;
|
||||
final Semaphore processingLock = new Semaphore(1); // 确保串行处理
|
||||
|
||||
SessionTtsConnection(MinimaxTtsClient client, TtsConnectionState state) {
|
||||
this.client = client;
|
||||
this.state = state;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 预热TTS连接
|
||||
* 建立WebSocket连接并发送task_start,等待连接就绪
|
||||
*
|
||||
* @param sessionId 会话ID
|
||||
* @return CompletableFuture 连接就绪后完成
|
||||
*/
|
||||
public CompletableFuture<Void> warmupConnection(String sessionId) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
logger.info("开始预热TTS连接 - SessionId: {}", sessionId);
|
||||
|
||||
CompletableFuture<Void> warmupFuture = new CompletableFuture<>();
|
||||
|
||||
try {
|
||||
// 检查是否已有连接
|
||||
SessionTtsConnection existing = connections.get(sessionId);
|
||||
if (existing != null && existing.state.isConnected()) {
|
||||
logger.info("TTS连接已存在,跳过预热 - SessionId: {}, 状态: {}",
|
||||
sessionId, existing.state.getStatus());
|
||||
warmupFuture.complete(null);
|
||||
return warmupFuture;
|
||||
}
|
||||
|
||||
// 创建新连接
|
||||
String wsUrl = voiceStreamConfig.getMinimax().getWsUrl();
|
||||
TtsConnectionState state = new TtsConnectionState();
|
||||
state.setStatus(ConnectionStatus.CONNECTING);
|
||||
|
||||
MinimaxTtsClient client = new MinimaxTtsClient(
|
||||
new URI(wsUrl),
|
||||
sessionId,
|
||||
state,
|
||||
warmupFuture
|
||||
);
|
||||
|
||||
SessionTtsConnection connection = new SessionTtsConnection(client, state);
|
||||
connections.put(sessionId, connection);
|
||||
|
||||
// 异步连接
|
||||
CompletableFuture.runAsync(() -> {
|
||||
try {
|
||||
boolean connected = client.connectBlocking(CONNECTION_TIMEOUT_SECONDS, TimeUnit.SECONDS);
|
||||
if (!connected) {
|
||||
throw new TimeoutException("连接超时");
|
||||
}
|
||||
long duration = System.currentTimeMillis() - startTime;
|
||||
logger.info("TTS连接预热成功 - SessionId: {}, 耗时: {}ms", sessionId, duration);
|
||||
} catch (Exception e) {
|
||||
logger.error("TTS连接预热失败 - SessionId: {}", sessionId, e);
|
||||
state.setStatus(ConnectionStatus.DISCONNECTED);
|
||||
connections.remove(sessionId);
|
||||
warmupFuture.completeExceptionally(e);
|
||||
}
|
||||
});
|
||||
|
||||
} catch (Exception e) {
|
||||
logger.error("创建TTS连接失败 - SessionId: {}", sessionId, e);
|
||||
warmupFuture.completeExceptionally(e);
|
||||
}
|
||||
|
||||
return warmupFuture;
|
||||
}
|
||||
|
||||
/**
|
||||
* 流式合成语音(复用现有连接)
|
||||
*
|
||||
* @param sessionId 会话ID
|
||||
* @param text 要合成的文本
|
||||
* @param callback 音频数据回调
|
||||
*/
|
||||
public CompletableFuture<Void> streamTts(String sessionId, String text, AudioCallback callback) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
CompletableFuture<Void> future = new CompletableFuture<>();
|
||||
|
||||
try {
|
||||
SessionTtsConnection connection = connections.get(sessionId);
|
||||
|
||||
// 检查连接是否存在且就绪
|
||||
if (connection == null || !connection.state.isReady()) {
|
||||
logger.warn("TTS连接不存在或未就绪 - SessionId: {}, 状态: {}",
|
||||
sessionId, connection != null ? connection.state.getStatus() : "NULL");
|
||||
|
||||
// 尝试建立新连接
|
||||
return warmupConnection(sessionId)
|
||||
.thenCompose(v -> streamTts(sessionId, text, callback));
|
||||
}
|
||||
|
||||
// 获取处理锁,确保串行处理
|
||||
connection.processingLock.acquire();
|
||||
|
||||
try {
|
||||
logger.info("使用已有TTS连接 - SessionId: {}, Text: {}", sessionId, text);
|
||||
connection.state.setStatus(ConnectionStatus.PROCESSING);
|
||||
|
||||
// 发送文本进行TTS
|
||||
connection.client.sendText(text, callback, future);
|
||||
|
||||
// 等待TTS完成后释放锁
|
||||
future.whenComplete((result, error) -> {
|
||||
connection.processingLock.release();
|
||||
if (error == null) {
|
||||
long duration = System.currentTimeMillis() - startTime;
|
||||
logger.info("TTS完成 - SessionId: {}, 耗时: {}ms(复用连接节省约1秒)",
|
||||
sessionId, duration);
|
||||
connection.state.setStatus(ConnectionStatus.IDLE);
|
||||
} else {
|
||||
logger.error("TTS失败 - SessionId: {}", sessionId, error);
|
||||
connection.state.setStatus(ConnectionStatus.IDLE);
|
||||
}
|
||||
});
|
||||
|
||||
} catch (Exception e) {
|
||||
connection.processingLock.release();
|
||||
throw e;
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
logger.error("TTS处理失败 - SessionId: {}", sessionId, e);
|
||||
callback.onError("TTS处理失败: " + e.getMessage());
|
||||
future.completeExceptionally(e);
|
||||
}
|
||||
|
||||
return future;
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭指定会话的TTS连接
|
||||
*/
|
||||
public void closeConnection(String sessionId) {
|
||||
SessionTtsConnection connection = connections.remove(sessionId);
|
||||
if (connection != null) {
|
||||
try {
|
||||
connection.client.closeConnection();
|
||||
logger.info("关闭TTS连接 - SessionId: {}", sessionId);
|
||||
} catch (Exception e) {
|
||||
logger.error("关闭TTS连接失败 - SessionId: {}", sessionId, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 取消指定会话的TTS(兼容旧接口)
|
||||
*/
|
||||
public void cancelTts(String sessionId) {
|
||||
closeConnection(sessionId);
|
||||
}
|
||||
|
||||
/**
|
||||
* MiniMax TTS WebSocket 客户端(支持连接复用)
|
||||
*/
|
||||
private class MinimaxTtsClient extends WebSocketClient {
|
||||
private final String sessionId;
|
||||
private final TtsConnectionState state;
|
||||
private final CompletableFuture<Void> warmupFuture;
|
||||
|
||||
// 当前任务相关
|
||||
private volatile String currentText;
|
||||
private volatile AudioCallback currentCallback;
|
||||
private volatile CompletableFuture<Void> currentTaskFuture;
|
||||
private final AtomicBoolean taskProcessing = new AtomicBoolean(false);
|
||||
|
||||
// 音频缓冲器 - 缓冲整句音频
|
||||
private java.io.ByteArrayOutputStream audioBuffer;
|
||||
|
||||
public MinimaxTtsClient(URI serverUri, String sessionId, TtsConnectionState state,
|
||||
CompletableFuture<Void> warmupFuture) {
|
||||
super(serverUri);
|
||||
this.sessionId = sessionId;
|
||||
this.state = state;
|
||||
this.warmupFuture = warmupFuture;
|
||||
|
||||
// 添加Authorization header
|
||||
this.addHeader("Authorization", "Bearer " + voiceStreamConfig.getMinimax().getApiKey());
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送文本进行TTS合成
|
||||
* 注意:调用此方法前,外层应该已经检查过状态并设置为PROCESSING
|
||||
*/
|
||||
public void sendText(String text, AudioCallback callback, CompletableFuture<Void> taskFuture) {
|
||||
this.currentText = text;
|
||||
this.currentCallback = callback;
|
||||
this.currentTaskFuture = taskFuture;
|
||||
this.audioBuffer = new java.io.ByteArrayOutputStream();
|
||||
this.taskProcessing.set(true);
|
||||
|
||||
// 直接发送task_continue
|
||||
sendTaskContinue(text);
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭连接
|
||||
*/
|
||||
public void closeConnection() {
|
||||
try {
|
||||
if (isOpen()) {
|
||||
sendTaskFinish();
|
||||
}
|
||||
close();
|
||||
state.setStatus(ConnectionStatus.DISCONNECTED);
|
||||
} catch (Exception e) {
|
||||
logger.error("关闭连接失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onOpen(ServerHandshake handshake) {
|
||||
logger.debug("MiniMax TTS连接已建立 - SessionId: {}", sessionId);
|
||||
state.setStatus(ConnectionStatus.CONNECTED);
|
||||
// 连接建立后,等待收到 connected_success 事件再发送 task_start
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onMessage(String message) {
|
||||
try {
|
||||
logger.debug("收到消息 - SessionId: {}: {}", sessionId, message);
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> response = gson.fromJson(message, Map.class);
|
||||
String event = (String) response.get("event");
|
||||
|
||||
if ("connected_success".equals(event)) {
|
||||
// 连接成功,发送 task_start(不包含text)
|
||||
logger.info("收到connected_success,发送task_start - SessionId: {}", sessionId);
|
||||
sendTaskStart();
|
||||
|
||||
} else if ("task_started".equals(event)) {
|
||||
// 任务启动成功,连接就绪
|
||||
logger.info("收到task_started,连接就绪 - SessionId: {}", sessionId);
|
||||
state.setStatus(ConnectionStatus.TASK_STARTED);
|
||||
state.resetReconnectAttempts(); // 连接成功,重置重连计数
|
||||
|
||||
// 完成预热future
|
||||
if (warmupFuture != null && !warmupFuture.isDone()) {
|
||||
warmupFuture.complete(null);
|
||||
}
|
||||
|
||||
// 如果是复用连接后立即发送的task_continue,状态会自动设为IDLE
|
||||
if (!taskProcessing.get()) {
|
||||
state.setStatus(ConnectionStatus.IDLE);
|
||||
}
|
||||
|
||||
} else if ("task_failed".equals(event)) {
|
||||
// 任务失败
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> baseResp = (Map<String, Object>) response.get("base_resp");
|
||||
String errorMsg = baseResp != null ? (String) baseResp.get("status_msg") : "unknown error";
|
||||
logger.error("TTS任务失败 - SessionId: {}: {}", sessionId, errorMsg);
|
||||
|
||||
if (currentCallback != null) {
|
||||
currentCallback.onError("TTS任务失败: " + errorMsg);
|
||||
}
|
||||
if (currentTaskFuture != null && !currentTaskFuture.isDone()) {
|
||||
currentTaskFuture.completeExceptionally(new RuntimeException(errorMsg));
|
||||
}
|
||||
|
||||
taskProcessing.set(false);
|
||||
state.setStatus(ConnectionStatus.IDLE);
|
||||
|
||||
} else if (response.containsKey("data")) {
|
||||
// 音频数据
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> data = (Map<String, Object>) response.get("data");
|
||||
if (data != null && data.containsKey("audio")) {
|
||||
String hexAudio = (String) data.get("audio");
|
||||
if (hexAudio != null && !hexAudio.isEmpty() && audioBuffer != null) {
|
||||
// 将hex字符串转为bytes并写入缓冲区
|
||||
byte[] audioBytes = hexStringToByteArray(hexAudio);
|
||||
try {
|
||||
audioBuffer.write(audioBytes);
|
||||
logger.debug("缓冲音频块 - SessionId: {}: {} bytes, 总计: {} bytes",
|
||||
sessionId, audioBytes.length, audioBuffer.size());
|
||||
} catch (java.io.IOException e) {
|
||||
logger.error("写入音频缓冲区失败 - SessionId: {}", sessionId, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否完成
|
||||
Boolean isFinal = (Boolean) response.get("is_final");
|
||||
if (Boolean.TRUE.equals(isFinal) && taskProcessing.compareAndSet(true, false)) {
|
||||
logger.debug("TTS任务完成(is_final=true)- SessionId: {}", sessionId);
|
||||
|
||||
// 一次性发送完整的音频数据
|
||||
if (audioBuffer != null && currentCallback != null) {
|
||||
byte[] completeAudio = audioBuffer.toByteArray();
|
||||
logger.info("TTS完成 - SessionId: {}, Text: {}, 音频总大小: {} bytes",
|
||||
sessionId, currentText, completeAudio.length);
|
||||
|
||||
if (completeAudio.length > 0) {
|
||||
currentCallback.onAudioChunk(completeAudio);
|
||||
}
|
||||
currentCallback.onComplete();
|
||||
}
|
||||
|
||||
// 完成当前任务future
|
||||
if (currentTaskFuture != null && !currentTaskFuture.isDone()) {
|
||||
currentTaskFuture.complete(null);
|
||||
}
|
||||
|
||||
// 不关闭连接,设置为IDLE状态以便复用
|
||||
state.setStatus(ConnectionStatus.IDLE);
|
||||
|
||||
// 清理当前任务
|
||||
currentText = null;
|
||||
currentCallback = null;
|
||||
currentTaskFuture = null;
|
||||
audioBuffer = null;
|
||||
}
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
logger.error("处理TTS消息失败 - SessionId: {}", sessionId, e);
|
||||
if (currentCallback != null) {
|
||||
currentCallback.onError("处理TTS响应失败: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送 task_start 消息(不包含text)
|
||||
*/
|
||||
private void sendTaskStart() {
|
||||
try {
|
||||
Map<String, Object> message = new HashMap<>();
|
||||
message.put("event", "task_start");
|
||||
message.put("model", voiceStreamConfig.getMinimax().getModel());
|
||||
|
||||
// voice_setting
|
||||
Map<String, Object> voiceSetting = new HashMap<>();
|
||||
voiceSetting.put("voice_id", voiceStreamConfig.getMinimax().getVoiceId());
|
||||
voiceSetting.put("speed", voiceStreamConfig.getMinimax().getSpeed());
|
||||
voiceSetting.put("vol", voiceStreamConfig.getMinimax().getVol());
|
||||
voiceSetting.put("pitch", voiceStreamConfig.getMinimax().getPitch());
|
||||
message.put("voice_setting", voiceSetting);
|
||||
|
||||
// audio_setting
|
||||
Map<String, Object> audioSetting = new HashMap<>();
|
||||
audioSetting.put("sample_rate", voiceStreamConfig.getMinimax().getAudioSampleRate());
|
||||
audioSetting.put("bitrate", voiceStreamConfig.getMinimax().getBitrate());
|
||||
audioSetting.put("format", "mp3"); // 使用mp3格式,兼容性更好
|
||||
audioSetting.put("channel", 1);
|
||||
message.put("audio_setting", audioSetting);
|
||||
|
||||
String jsonMessage = gson.toJson(message);
|
||||
logger.debug("发送task_start - SessionId: {}: {}", sessionId, jsonMessage);
|
||||
send(jsonMessage);
|
||||
|
||||
} catch (Exception e) {
|
||||
logger.error("发送task_start失败 - SessionId: {}", sessionId, e);
|
||||
if (warmupFuture != null && !warmupFuture.isDone()) {
|
||||
warmupFuture.completeExceptionally(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送 task_continue 消息(包含text)
|
||||
*/
|
||||
private void sendTaskContinue(String text) {
|
||||
try {
|
||||
Map<String, Object> message = new HashMap<>();
|
||||
message.put("event", "task_continue");
|
||||
message.put("text", text);
|
||||
|
||||
String jsonMessage = gson.toJson(message);
|
||||
logger.debug("发送task_continue - SessionId: {}: {}", sessionId, jsonMessage);
|
||||
send(jsonMessage);
|
||||
|
||||
} catch (Exception e) {
|
||||
logger.error("发送task_continue失败 - SessionId: {}", sessionId, e);
|
||||
if (currentCallback != null) {
|
||||
currentCallback.onError("发送文本失败: " + e.getMessage());
|
||||
}
|
||||
if (currentTaskFuture != null && !currentTaskFuture.isDone()) {
|
||||
currentTaskFuture.completeExceptionally(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送 task_finish 消息
|
||||
*/
|
||||
private void sendTaskFinish() {
|
||||
try {
|
||||
Map<String, Object> message = new HashMap<>();
|
||||
message.put("event", "task_finish");
|
||||
|
||||
String jsonMessage = gson.toJson(message);
|
||||
logger.debug("发送task_finish - SessionId: {}: {}", sessionId, jsonMessage);
|
||||
send(jsonMessage);
|
||||
|
||||
} catch (Exception e) {
|
||||
logger.error("发送task_finish失败 - SessionId: {}", sessionId, e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClose(int code, String reason, boolean remote) {
|
||||
logger.warn("MiniMax TTS连接关闭 - SessionId: {}, Code: {}, Reason: {}, Remote: {}",
|
||||
sessionId, code, reason, remote);
|
||||
|
||||
ConnectionStatus prevStatus = state.getStatus();
|
||||
state.setStatus(ConnectionStatus.DISCONNECTED);
|
||||
|
||||
// 判断是否需要重连
|
||||
boolean needReconnect = remote && state.canReconnect() &&
|
||||
(prevStatus == ConnectionStatus.CONNECTED ||
|
||||
prevStatus == ConnectionStatus.TASK_STARTED ||
|
||||
prevStatus == ConnectionStatus.IDLE);
|
||||
|
||||
if (needReconnect) {
|
||||
// 自动重连
|
||||
int attempts = state.incrementReconnectAttempts();
|
||||
logger.info("尝试自动重连TTS - SessionId: {}, 第{}次重连", sessionId, attempts);
|
||||
|
||||
CompletableFuture.runAsync(() -> {
|
||||
try {
|
||||
Thread.sleep(1000 * attempts); // 延迟递增
|
||||
warmupConnection(sessionId)
|
||||
.exceptionally(ex -> {
|
||||
logger.error("自动重连失败 - SessionId: {}", sessionId, ex);
|
||||
return null;
|
||||
});
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// 不重连,通知错误
|
||||
if (taskProcessing.get()) {
|
||||
if (currentCallback != null) {
|
||||
currentCallback.onComplete(); // 标记完成,避免阻塞
|
||||
}
|
||||
if (currentTaskFuture != null && !currentTaskFuture.isDone()) {
|
||||
currentTaskFuture.complete(null);
|
||||
}
|
||||
taskProcessing.set(false);
|
||||
}
|
||||
|
||||
// 如果是预热阶段失败
|
||||
if (warmupFuture != null && !warmupFuture.isDone()) {
|
||||
warmupFuture.completeExceptionally(new RuntimeException("连接关闭: " + reason));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Exception ex) {
|
||||
logger.error("MiniMax TTS连接错误 - SessionId: {}", sessionId, ex);
|
||||
|
||||
state.setStatus(ConnectionStatus.DISCONNECTED);
|
||||
|
||||
if (currentCallback != null) {
|
||||
currentCallback.onError("TTS连接错误: " + ex.getMessage());
|
||||
}
|
||||
if (currentTaskFuture != null && !currentTaskFuture.isDone()) {
|
||||
currentTaskFuture.completeExceptionally(ex);
|
||||
}
|
||||
if (warmupFuture != null && !warmupFuture.isDone()) {
|
||||
warmupFuture.completeExceptionally(ex);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 将hex字符串转为byte数组
|
||||
*/
|
||||
private byte[] hexStringToByteArray(String hex) {
|
||||
int len = hex.length();
|
||||
byte[] data = new byte[len / 2];
|
||||
for (int i = 0; i < len; i += 2) {
|
||||
data[i / 2] = (byte) ((Character.digit(hex.charAt(i), 16) << 4)
|
||||
+ Character.digit(hex.charAt(i + 1), 16));
|
||||
}
|
||||
return data;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 音频数据回调接口
|
||||
*/
|
||||
public interface AudioCallback {
|
||||
/**
|
||||
* 接收到音频数据块
|
||||
*/
|
||||
void onAudioChunk(byte[] audioData);
|
||||
|
||||
/**
|
||||
* TTS完成
|
||||
*/
|
||||
void onComplete();
|
||||
|
||||
/**
|
||||
* 发生错误
|
||||
*/
|
||||
void onError(String error);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
136
src/main/java/com/xiaozhi/dialogue/tts/TtsConnectionState.java
Normal file
136
src/main/java/com/xiaozhi/dialogue/tts/TtsConnectionState.java
Normal file
@@ -0,0 +1,136 @@
|
||||
package com.xiaozhi.dialogue.tts;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
/**
|
||||
* TTS连接状态管理
|
||||
* 管理连接状态、重连计数等
|
||||
*/
|
||||
public class TtsConnectionState {
|
||||
private static final int MAX_RECONNECT_ATTEMPTS = 3;
|
||||
|
||||
private final AtomicReference<ConnectionStatus> status;
|
||||
private final AtomicInteger reconnectAttempts;
|
||||
private volatile long lastActiveTime;
|
||||
|
||||
public TtsConnectionState() {
|
||||
this.status = new AtomicReference<>(ConnectionStatus.DISCONNECTED);
|
||||
this.reconnectAttempts = new AtomicInteger(0);
|
||||
this.lastActiveTime = System.currentTimeMillis();
|
||||
}
|
||||
|
||||
/**
|
||||
* 连接状态枚举
|
||||
*/
|
||||
public enum ConnectionStatus {
|
||||
DISCONNECTED, // 未连接
|
||||
CONNECTING, // 连接中
|
||||
CONNECTED, // 已连接(未完成task_start)
|
||||
TASK_STARTED, // task_start已完成,连接就绪
|
||||
PROCESSING, // 正在处理TTS任务
|
||||
IDLE // 空闲,可接受新任务
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前状态
|
||||
*/
|
||||
public ConnectionStatus getStatus() {
|
||||
return status.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置状态
|
||||
*/
|
||||
public void setStatus(ConnectionStatus newStatus) {
|
||||
ConnectionStatus oldStatus = status.getAndSet(newStatus);
|
||||
if (oldStatus != newStatus) {
|
||||
updateLastActiveTime();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* CAS方式更新状态
|
||||
*/
|
||||
public boolean compareAndSetStatus(ConnectionStatus expect, ConnectionStatus update) {
|
||||
boolean updated = status.compareAndSet(expect, update);
|
||||
if (updated) {
|
||||
updateLastActiveTime();
|
||||
}
|
||||
return updated;
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否可以处理新任务
|
||||
*/
|
||||
public boolean isReady() {
|
||||
ConnectionStatus current = status.get();
|
||||
return current == ConnectionStatus.IDLE || current == ConnectionStatus.TASK_STARTED;
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否已连接
|
||||
*/
|
||||
public boolean isConnected() {
|
||||
ConnectionStatus current = status.get();
|
||||
return current != ConnectionStatus.DISCONNECTED && current != ConnectionStatus.CONNECTING;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取重连次数
|
||||
*/
|
||||
public int getReconnectAttempts() {
|
||||
return reconnectAttempts.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* 增加重连次数
|
||||
*/
|
||||
public int incrementReconnectAttempts() {
|
||||
return reconnectAttempts.incrementAndGet();
|
||||
}
|
||||
|
||||
/**
|
||||
* 重置重连次数
|
||||
*/
|
||||
public void resetReconnectAttempts() {
|
||||
reconnectAttempts.set(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否还可以重连
|
||||
*/
|
||||
public boolean canReconnect() {
|
||||
return reconnectAttempts.get() < MAX_RECONNECT_ATTEMPTS;
|
||||
}
|
||||
|
||||
/**
|
||||
* 更新最后活跃时间
|
||||
*/
|
||||
public void updateLastActiveTime() {
|
||||
this.lastActiveTime = System.currentTimeMillis();
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取最后活跃时间
|
||||
*/
|
||||
public long getLastActiveTime() {
|
||||
return lastActiveTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* 重置状态
|
||||
*/
|
||||
public void reset() {
|
||||
status.set(ConnectionStatus.DISCONNECTED);
|
||||
reconnectAttempts.set(0);
|
||||
updateLastActiveTime();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format("TtsConnectionState{status=%s, reconnectAttempts=%d, lastActiveTime=%d}",
|
||||
status.get(), reconnectAttempts.get(), lastActiveTime);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,7 +86,6 @@ public class TtsServiceFactory {
|
||||
case "volcengine" -> new VolcengineTtsService(config, voiceName, outputPath);
|
||||
case "xfyun" -> new XfyunTtsService(config, voiceName, outputPath);
|
||||
case "minimax" -> new MiniMaxTtsService(config, voiceName, outputPath);
|
||||
case "minimax-ws" -> new MiniMaxTtsWebSocketService(config, voiceName, outputPath); // WebSocket 流式版本
|
||||
default -> new EdgeTtsService(voiceName, outputPath);
|
||||
};
|
||||
}
|
||||
|
||||
649
src/main/java/com/xiaozhi/service/VoiceStreamService.java
Normal file
649
src/main/java/com/xiaozhi/service/VoiceStreamService.java
Normal file
@@ -0,0 +1,649 @@
|
||||
package com.xiaozhi.service;
|
||||
|
||||
import com.xiaozhi.dialogue.llm.GrokStreamService;
|
||||
import com.xiaozhi.dialogue.llm.SentenceBufferService;
|
||||
import com.xiaozhi.dialogue.llm.memory.DatabaseChatMemory;
|
||||
import com.xiaozhi.dialogue.stt.SttService;
|
||||
import com.xiaozhi.dialogue.stt.factory.SttServiceFactory;
|
||||
import com.xiaozhi.dialogue.tts.MinimaxTtsStreamService;
|
||||
import com.xiaozhi.entity.SysConfig;
|
||||
import com.xiaozhi.entity.SysMessage;
|
||||
import com.xiaozhi.utils.AudioUtils;
|
||||
import jakarta.annotation.Resource;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.*;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
/**
|
||||
* 语音流式对话核心服务
|
||||
* 串联 STT -> LLM Stream -> 分句 -> TTS Stream 流程
|
||||
*/
|
||||
@Service
|
||||
public class VoiceStreamService {
|
||||
private static final Logger logger = LoggerFactory.getLogger(VoiceStreamService.class);
|
||||
|
||||
@Resource
|
||||
private SttServiceFactory sttServiceFactory;
|
||||
|
||||
@Resource
|
||||
private GrokStreamService grokStreamService;
|
||||
|
||||
@Resource
|
||||
private SentenceBufferService sentenceBufferService;
|
||||
|
||||
@Resource
|
||||
private MinimaxTtsStreamService minimaxTtsStreamService;
|
||||
|
||||
@Resource
|
||||
private SysConfigService configService;
|
||||
|
||||
@Resource
|
||||
private DatabaseChatMemory chatMemory;
|
||||
|
||||
@Resource
|
||||
private SysMessageService sysMessageService;
|
||||
|
||||
// 管理每个会话的状态
|
||||
private final Map<String, SessionState> sessions = new ConcurrentHashMap<>();
|
||||
|
||||
// 线程池用于异步处理
|
||||
private final ExecutorService executorService = Executors.newCachedThreadPool();
|
||||
|
||||
/**
|
||||
* 处理音频流
|
||||
*
|
||||
* @param sessionId 会话ID
|
||||
* @param audioData PCM音频数据
|
||||
* @param callback 回调接口
|
||||
*/
|
||||
public void processAudioStream(String sessionId, byte[] audioData, StreamCallback callback) {
|
||||
executorService.submit(() -> {
|
||||
try {
|
||||
// 获取或创建会话状态
|
||||
SessionState state = sessions.computeIfAbsent(sessionId, k -> new SessionState());
|
||||
|
||||
// 取消之前的处理(打断机制)
|
||||
state.cancel();
|
||||
state.reset();
|
||||
|
||||
logger.info("开始处理音频流 - SessionId: {}, AudioSize: {}", sessionId, audioData.length);
|
||||
|
||||
// 1. STT - 语音转文字
|
||||
String recognizedText = performStt(audioData);
|
||||
if (recognizedText == null || recognizedText.trim().isEmpty()) {
|
||||
callback.onError("语音识别失败或未识别到内容");
|
||||
return;
|
||||
}
|
||||
|
||||
logger.info("STT识别结果 - SessionId: {}, Text: {}", sessionId, recognizedText);
|
||||
callback.onSttResult(recognizedText);
|
||||
|
||||
// 2. 创建分句缓冲器
|
||||
SentenceBufferService.SentenceBuffer sentenceBuffer =
|
||||
sentenceBufferService.createBuffer(new SentenceBufferService.SentenceCallback() {
|
||||
|
||||
private final Queue<String> sentenceQueue = new LinkedList<>();
|
||||
private final Semaphore ttsPermit = new Semaphore(1); // 保证TTS顺序执行
|
||||
private final AtomicBoolean processing = new AtomicBoolean(false);
|
||||
|
||||
@Override
|
||||
public void onSentence(String sentence) {
|
||||
if (state.isCancelled()) {
|
||||
logger.info("会话已取消,跳过句子处理 - SessionId: {}", sessionId);
|
||||
return;
|
||||
}
|
||||
|
||||
logger.info("检测到完整句子 - SessionId: {}, Sentence: {}", sessionId, sentence);
|
||||
callback.onSentenceComplete(sentence);
|
||||
|
||||
// 将句子加入队列并异步处理TTS
|
||||
sentenceQueue.offer(sentence);
|
||||
processSentenceQueue();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
logger.info("LLM输出完成 - SessionId: {}", sessionId);
|
||||
// 等待所有TTS完成
|
||||
waitForAllTtsComplete();
|
||||
callback.onComplete();
|
||||
}
|
||||
|
||||
private void processSentenceQueue() {
|
||||
if (processing.compareAndSet(false, true)) {
|
||||
executorService.submit(() -> {
|
||||
try {
|
||||
while (!sentenceQueue.isEmpty() && !state.isCancelled()) {
|
||||
String sentence = sentenceQueue.poll();
|
||||
if (sentence != null) {
|
||||
processSentenceTts(sentence);
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
processing.set(false);
|
||||
// 检查是否还有新的句子
|
||||
if (!sentenceQueue.isEmpty() && !state.isCancelled()) {
|
||||
processSentenceQueue();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private void processSentenceTts(String sentence) {
|
||||
if (state.isCancelled()) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
ttsPermit.acquire();
|
||||
|
||||
logger.info("开始TTS合成 - SessionId: {}, Sentence: {}", sessionId, sentence);
|
||||
|
||||
CompletableFuture<Void> ttsFuture = minimaxTtsStreamService.streamTts(
|
||||
sessionId,
|
||||
sentence,
|
||||
new MinimaxTtsStreamService.AudioCallback() {
|
||||
@Override
|
||||
public void onAudioChunk(byte[] audioData) {
|
||||
if (!state.isCancelled()) {
|
||||
callback.onAudioChunk(audioData);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
logger.debug("句子TTS完成 - SessionId: {}", sessionId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(String error) {
|
||||
logger.error("TTS错误 - SessionId: {}, Error: {}", sessionId, error);
|
||||
callback.onError(error);
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
// 等待TTS完成
|
||||
ttsFuture.get(30, TimeUnit.SECONDS);
|
||||
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
logger.warn("TTS处理被中断 - SessionId: {}", sessionId);
|
||||
state.markTtsFailed(); // 标记 TTS 失败
|
||||
} catch (Exception e) {
|
||||
logger.error("TTS处理失败 - SessionId: {}", sessionId, e);
|
||||
state.markTtsFailed(); // 标记 TTS 失败
|
||||
callback.onError("TTS处理失败: " + e.getMessage());
|
||||
} finally {
|
||||
ttsPermit.release();
|
||||
}
|
||||
}
|
||||
|
||||
private void waitForAllTtsComplete() {
|
||||
try {
|
||||
// 等待所有TTS完成
|
||||
ttsPermit.acquire();
|
||||
ttsPermit.release();
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// 3. LLM流式调用
|
||||
String systemPrompt = "你是一个友好的AI助手,请用简洁、自然的语气回答用户的问题。";
|
||||
|
||||
grokStreamService.streamChat(
|
||||
recognizedText,
|
||||
systemPrompt,
|
||||
new GrokStreamService.TokenCallback() {
|
||||
@Override
|
||||
public void onToken(String token) {
|
||||
if (state.isCancelled()) {
|
||||
return;
|
||||
}
|
||||
callback.onLlmToken(token);
|
||||
sentenceBuffer.addToken(token);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
if (!state.isCancelled()) {
|
||||
sentenceBuffer.finish();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(String error) {
|
||||
logger.error("LLM调用失败 - SessionId: {}, Error: {}", sessionId, error);
|
||||
callback.onError(error);
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
} catch (Exception e) {
|
||||
logger.error("处理音频流失败 - SessionId: {}", sessionId, e);
|
||||
callback.onError("处理失败: " + e.getMessage());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理音频流(带历史记录)
|
||||
*
|
||||
* @param sessionId WebSocket会话ID
|
||||
* @param audioData PCM音频数据
|
||||
* @param chatSessionId 聊天会话ID(用于查询和保存历史记录)
|
||||
* @param templateId 模板ID(可选,用于获取角色信息)
|
||||
* @param callback 回调接口
|
||||
*/
|
||||
public void processAudioStreamWithHistory(String sessionId, byte[] audioData,
|
||||
String chatSessionId, Integer templateId,
|
||||
StreamCallback callback) {
|
||||
executorService.submit(() -> {
|
||||
try {
|
||||
// 获取或创建会话状态
|
||||
SessionState state = sessions.computeIfAbsent(sessionId, k -> new SessionState());
|
||||
|
||||
// 取消之前的处理(打断机制)
|
||||
state.cancel();
|
||||
state.reset();
|
||||
|
||||
logger.info("开始处理音频流(带历史记录) - SessionId: {}, ChatSessionId: {}, TemplateId: {}, AudioSize: {}",
|
||||
sessionId, chatSessionId, templateId, audioData.length);
|
||||
|
||||
// 使用 chatSessionId 作为 deviceId(保持与文字对话一致)
|
||||
String deviceId = chatSessionId;
|
||||
// roleId 可以为 null,数据库允许
|
||||
Integer roleId = templateId;
|
||||
|
||||
// 1. STT - 语音转文字
|
||||
String recognizedText = performStt(audioData);
|
||||
if (recognizedText == null || recognizedText.trim().isEmpty()) {
|
||||
callback.onError("语音识别失败或未识别到内容");
|
||||
return;
|
||||
}
|
||||
|
||||
logger.info("STT识别结果 - SessionId: {}, Text: {}", sessionId, recognizedText);
|
||||
callback.onSttResult(recognizedText);
|
||||
|
||||
// 2. 异步保存用户消息到数据库(不等待,保留 future 引用)
|
||||
CompletableFuture<Void> userMessageFuture = chatMemory.addMessage(
|
||||
deviceId, chatSessionId, "user", recognizedText,
|
||||
roleId, "NORMAL", System.currentTimeMillis());
|
||||
logger.info("用户消息开始异步保存 - ChatSessionId: {}", chatSessionId);
|
||||
|
||||
// 3. 查询历史记录(最近20条)
|
||||
// 等待用户消息保存完成,确保能查询到最新数据
|
||||
List<SysMessage> historyMessages = new ArrayList<>();
|
||||
try {
|
||||
userMessageFuture.join(); // 等待用户消息保存完成
|
||||
logger.info("用户消息保存完成 - ChatSessionId: {}", chatSessionId);
|
||||
|
||||
List<SysMessage> allMessages = sysMessageService.queryBySessionId(chatSessionId);
|
||||
// 获取最近20条消息(不包括刚刚保存的用户消息)
|
||||
int startIndex = Math.max(0, allMessages.size() - 21); // -21 因为包含了刚保存的用户消息
|
||||
historyMessages = allMessages.subList(startIndex, allMessages.size() - 1); // -1 排除最后一条(刚保存的)
|
||||
logger.info("查询到历史消息 {} 条 - ChatSessionId: {}", historyMessages.size(), chatSessionId);
|
||||
} catch (Exception e) {
|
||||
logger.error("保存或查询历史消息失败,将继续处理(无历史上下文) - ChatSessionId: {}", chatSessionId, e);
|
||||
}
|
||||
|
||||
// 4. 累积LLM完整响应(用于保存到数据库)
|
||||
AtomicReference<String> fullResponse = new AtomicReference<>("");
|
||||
|
||||
// 5. 创建分句缓冲器
|
||||
SentenceBufferService.SentenceBuffer sentenceBuffer =
|
||||
sentenceBufferService.createBuffer(new SentenceBufferService.SentenceCallback() {
|
||||
|
||||
private final Queue<String> sentenceQueue = new LinkedList<>();
|
||||
private final Semaphore ttsPermit = new Semaphore(1); // 保证TTS顺序执行
|
||||
private final AtomicBoolean processing = new AtomicBoolean(false);
|
||||
|
||||
@Override
|
||||
public void onSentence(String sentence) {
|
||||
if (state.isCancelled()) {
|
||||
logger.info("会话已取消,跳过句子处理 - SessionId: {}", sessionId);
|
||||
return;
|
||||
}
|
||||
|
||||
logger.info("检测到完整句子 - SessionId: {}, Sentence: {}", sessionId, sentence);
|
||||
callback.onSentenceComplete(sentence);
|
||||
|
||||
// 将句子加入队列并异步处理TTS
|
||||
sentenceQueue.offer(sentence);
|
||||
processSentenceQueue();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
logger.info("LLM输出完成 - SessionId: {}", sessionId);
|
||||
|
||||
// 等待所有TTS完成
|
||||
waitForAllTtsComplete();
|
||||
|
||||
// TTS 全部成功后才保存 AI 回复(如果未取消且未失败)
|
||||
String aiResponse = fullResponse.get();
|
||||
if (aiResponse != null && !aiResponse.trim().isEmpty()
|
||||
&& !state.isCancelled() && !state.isTtsFailed()) {
|
||||
|
||||
// 使用链式调用确保在用户消息之后保存,并增加 1ms 保证顺序
|
||||
userMessageFuture.thenCompose(v -> {
|
||||
long timestamp = System.currentTimeMillis() + 1;
|
||||
return chatMemory.addMessage(deviceId, chatSessionId, "assistant",
|
||||
aiResponse, roleId, "NORMAL", timestamp);
|
||||
}).thenAccept(v -> {
|
||||
logger.info("AI响应已保存 - ChatSessionId: {}, Length: {}",
|
||||
chatSessionId, aiResponse.length());
|
||||
}).exceptionally(ex -> {
|
||||
logger.error("保存AI响应失败 - ChatSessionId: {}", chatSessionId, ex);
|
||||
return null;
|
||||
});
|
||||
} else if (state.isTtsFailed()) {
|
||||
logger.warn("TTS失败,不保存AI响应 - ChatSessionId: {}", chatSessionId);
|
||||
} else if (state.isCancelled()) {
|
||||
logger.info("会话已取消,不保存AI响应 - ChatSessionId: {}", chatSessionId);
|
||||
}
|
||||
|
||||
callback.onComplete();
|
||||
}
|
||||
|
||||
private void processSentenceQueue() {
|
||||
if (processing.compareAndSet(false, true)) {
|
||||
executorService.submit(() -> {
|
||||
try {
|
||||
while (!sentenceQueue.isEmpty() && !state.isCancelled()) {
|
||||
String sentence = sentenceQueue.poll();
|
||||
if (sentence != null) {
|
||||
processSentenceTts(sentence);
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
processing.set(false);
|
||||
// 检查是否还有新的句子
|
||||
if (!sentenceQueue.isEmpty() && !state.isCancelled()) {
|
||||
processSentenceQueue();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private void processSentenceTts(String sentence) {
|
||||
if (state.isCancelled()) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
ttsPermit.acquire();
|
||||
|
||||
logger.info("开始TTS合成 - SessionId: {}, Sentence: {}", sessionId, sentence);
|
||||
|
||||
CompletableFuture<Void> ttsFuture = minimaxTtsStreamService.streamTts(
|
||||
sessionId,
|
||||
sentence,
|
||||
new MinimaxTtsStreamService.AudioCallback() {
|
||||
@Override
|
||||
public void onAudioChunk(byte[] audioData) {
|
||||
if (!state.isCancelled()) {
|
||||
callback.onAudioChunk(audioData);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
logger.debug("句子TTS完成 - SessionId: {}", sessionId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(String error) {
|
||||
logger.error("TTS错误 - SessionId: {}, Error: {}", sessionId, error);
|
||||
callback.onError(error);
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
// 等待TTS完成
|
||||
ttsFuture.get(30, TimeUnit.SECONDS);
|
||||
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
logger.warn("TTS处理被中断 - SessionId: {}", sessionId);
|
||||
state.markTtsFailed(); // 标记 TTS 失败
|
||||
} catch (Exception e) {
|
||||
logger.error("TTS处理失败 - SessionId: {}", sessionId, e);
|
||||
state.markTtsFailed(); // 标记 TTS 失败
|
||||
callback.onError("TTS处理失败: " + e.getMessage());
|
||||
} finally {
|
||||
ttsPermit.release();
|
||||
}
|
||||
}
|
||||
|
||||
private void waitForAllTtsComplete() {
|
||||
try {
|
||||
// 等待所有TTS完成
|
||||
ttsPermit.acquire();
|
||||
ttsPermit.release();
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// 6. 构建包含历史记录的消息列表
|
||||
List<Map<String, String>> messages = buildMessagesWithHistory(historyMessages, recognizedText);
|
||||
|
||||
// 7. LLM流式调用(带历史记录)
|
||||
grokStreamService.streamChatWithHistory(
|
||||
messages,
|
||||
new GrokStreamService.TokenCallback() {
|
||||
@Override
|
||||
public void onToken(String token) {
|
||||
if (state.isCancelled()) {
|
||||
return;
|
||||
}
|
||||
// 累积完整响应
|
||||
fullResponse.updateAndGet(current -> current + token);
|
||||
|
||||
callback.onLlmToken(token);
|
||||
sentenceBuffer.addToken(token);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
if (!state.isCancelled()) {
|
||||
sentenceBuffer.finish();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(String error) {
|
||||
logger.error("LLM调用失败 - SessionId: {}, Error: {}", sessionId, error);
|
||||
callback.onError(error);
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
} catch (Exception e) {
|
||||
logger.error("处理音频流失败(带历史记录) - SessionId: {}", sessionId, e);
|
||||
callback.onError("处理失败: " + e.getMessage());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建包含历史记录的消息列表
|
||||
*/
|
||||
private List<Map<String, String>> buildMessagesWithHistory(List<SysMessage> historyMessages,
|
||||
String currentUserMessage) {
|
||||
List<Map<String, String>> messages = new ArrayList<>();
|
||||
|
||||
// 添加系统提示
|
||||
Map<String, String> systemMessage = new HashMap<>();
|
||||
systemMessage.put("role", "system");
|
||||
systemMessage.put("content", "你是一个友好的AI助手,请用简洁、自然的语气回答用户的问题。");
|
||||
messages.add(systemMessage);
|
||||
|
||||
// 添加历史消息
|
||||
for (SysMessage msg : historyMessages) {
|
||||
Map<String, String> message = new HashMap<>();
|
||||
// 将数据库中的 sender 转换为 Grok API 需要的角色
|
||||
if ("user".equals(msg.getSender())) {
|
||||
message.put("role", "user");
|
||||
} else if ("assistant".equals(msg.getSender())) {
|
||||
message.put("role", "assistant");
|
||||
} else {
|
||||
// 跳过不认识的角色
|
||||
logger.warn("未知的消息发送方: {}", msg.getSender());
|
||||
continue;
|
||||
}
|
||||
message.put("content", msg.getMessage());
|
||||
messages.add(message);
|
||||
}
|
||||
|
||||
// 添加当前用户消息
|
||||
Map<String, String> currentMessage = new HashMap<>();
|
||||
currentMessage.put("role", "user");
|
||||
currentMessage.put("content", currentUserMessage);
|
||||
messages.add(currentMessage);
|
||||
|
||||
logger.debug("构建消息列表完成,共 {} 条消息(含系统提示和历史)", messages.size());
|
||||
return messages;
|
||||
}
|
||||
|
||||
/**
|
||||
* 取消指定会话的流处理
|
||||
*/
|
||||
public void cancelStream(String sessionId) {
|
||||
SessionState state = sessions.get(sessionId);
|
||||
if (state != null) {
|
||||
state.cancel();
|
||||
// 取消TTS
|
||||
minimaxTtsStreamService.cancelTts(sessionId);
|
||||
logger.info("取消流处理 - SessionId: {}", sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 预热TTS连接
|
||||
*
|
||||
* @param sessionId 会话ID
|
||||
* @return CompletableFuture 连接就绪后完成
|
||||
*/
|
||||
public CompletableFuture<Void> warmupTtsConnection(String sessionId) {
|
||||
logger.info("预热TTS连接 - SessionId: {}", sessionId);
|
||||
return minimaxTtsStreamService.warmupConnection(sessionId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭TTS连接
|
||||
*
|
||||
* @param sessionId 会话ID
|
||||
*/
|
||||
public void closeTtsConnection(String sessionId) {
|
||||
logger.info("关闭TTS连接 - SessionId: {}", sessionId);
|
||||
minimaxTtsStreamService.closeConnection(sessionId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行STT(复用现有逻辑)
|
||||
*/
|
||||
private String performStt(byte[] audioData) {
|
||||
try {
|
||||
// 标准化输入音频为 PCM 16k 单声道
|
||||
byte[] pcmData = AudioUtils.bytesToPcm(audioData);
|
||||
|
||||
if (pcmData == null || pcmData.length == 0) {
|
||||
throw new RuntimeException("音频数据为空或转换后为空");
|
||||
}
|
||||
|
||||
// 获取STT服务(使用默认配置)
|
||||
SttService sttService = sttServiceFactory.getSttService(null);
|
||||
if (sttService == null) {
|
||||
throw new RuntimeException("无法获取STT服务");
|
||||
}
|
||||
|
||||
// 执行语音识别
|
||||
String recognizedText = sttService.recognition(pcmData);
|
||||
return recognizedText != null ? recognizedText.trim() : "";
|
||||
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("音频标准化失败: " + e.getMessage(), e);
|
||||
} catch (Exception e) {
|
||||
logger.error("STT处理失败: {}", e.getMessage(), e);
|
||||
throw new RuntimeException("语音识别失败: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 会话状态
|
||||
*/
|
||||
private static class SessionState {
|
||||
private final AtomicBoolean cancelled = new AtomicBoolean(false);
|
||||
private final AtomicBoolean ttsFailed = new AtomicBoolean(false);
|
||||
|
||||
public void cancel() {
|
||||
cancelled.set(true);
|
||||
}
|
||||
|
||||
public void reset() {
|
||||
cancelled.set(false);
|
||||
ttsFailed.set(false); // 重置时也清除 TTS 失败标记
|
||||
}
|
||||
|
||||
public boolean isCancelled() {
|
||||
return cancelled.get();
|
||||
}
|
||||
|
||||
public void markTtsFailed() {
|
||||
ttsFailed.set(true);
|
||||
}
|
||||
|
||||
public boolean isTtsFailed() {
|
||||
return ttsFailed.get();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 流处理回调接口
|
||||
*/
|
||||
public interface StreamCallback {
|
||||
/**
|
||||
* STT识别结果
|
||||
*/
|
||||
void onSttResult(String text);
|
||||
|
||||
/**
|
||||
* LLM输出token
|
||||
*/
|
||||
void onLlmToken(String token);
|
||||
|
||||
/**
|
||||
* 完整句子
|
||||
*/
|
||||
void onSentenceComplete(String sentence);
|
||||
|
||||
/**
|
||||
* TTS音频数据块
|
||||
*/
|
||||
void onAudioChunk(byte[] audioChunk);
|
||||
|
||||
/**
|
||||
* 所有处理完成
|
||||
*/
|
||||
void onComplete();
|
||||
|
||||
/**
|
||||
* 发生错误
|
||||
*/
|
||||
void onError(String error);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,7 +182,12 @@ public class ChatSessionServiceImpl implements ChatSessionService {
|
||||
ChatSession session = getOrCreateSession(sessionId, request.getModelId(), request.getTemplateId());
|
||||
|
||||
// 1. STT - 语音转文本
|
||||
long sttStartTime = System.currentTimeMillis();
|
||||
Map<String, Object> sttResult = performStt(request);
|
||||
long sttEndTime = System.currentTimeMillis();
|
||||
long sttDuration = sttEndTime - sttStartTime;
|
||||
logger.info("STT处理完成,sessionId={}, 耗时={}s", sessionId, sttDuration / 1000.0);
|
||||
|
||||
String recognizedText = (String) sttResult.get("text");
|
||||
|
||||
if (recognizedText == null || recognizedText.trim().isEmpty()) {
|
||||
@@ -191,6 +196,7 @@ public class ChatSessionServiceImpl implements ChatSessionService {
|
||||
result.put("sttResult", sttResult);
|
||||
result.put("sessionId", sessionId);
|
||||
result.put("timestamp", System.currentTimeMillis());
|
||||
result.put("sttDuration", sttDuration);
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -202,8 +208,12 @@ public class ChatSessionServiceImpl implements ChatSessionService {
|
||||
chatRequest.setTemplateId(request.getTemplateId());
|
||||
chatRequest.setUseFunctionCall(request.getUseFunctionCall());
|
||||
|
||||
long llmStartTime = System.currentTimeMillis();
|
||||
ChatResponse chatResponse = syncChat(chatRequest);
|
||||
|
||||
long llmEndTime = System.currentTimeMillis();
|
||||
long llmDuration = llmEndTime - llmStartTime;
|
||||
logger.info("LLM处理完成,sessionId={}, 耗时={}s", sessionId, llmDuration / 1000.0);
|
||||
|
||||
Map<String, Object> llmResult = new HashMap<>();
|
||||
llmResult.put("response", chatResponse.getResponse());
|
||||
llmResult.put("inputText", recognizedText);
|
||||
@@ -211,13 +221,17 @@ public class ChatSessionServiceImpl implements ChatSessionService {
|
||||
// 获取会话中角色的voiceName
|
||||
String voiceName = null;
|
||||
ChatSession chatSession = getOrCreateSession(sessionId, request.getModelId(), request.getTemplateId());
|
||||
if (chatSession != null && chatSession.getConversation() != null &&
|
||||
if (chatSession != null && chatSession.getConversation() != null &&
|
||||
chatSession.getConversation().role() != null) {
|
||||
voiceName = chatSession.getConversation().role().getVoiceName();
|
||||
}
|
||||
|
||||
// 3. TTS - 文本转语音
|
||||
long ttsStartTime = System.currentTimeMillis();
|
||||
Map<String, Object> ttsResult = performTts(chatResponse.getResponse(), request.getTtsConfigId(), voiceName);
|
||||
long ttsEndTime = System.currentTimeMillis();
|
||||
long ttsDuration = ttsEndTime - ttsStartTime;
|
||||
logger.info("TTS处理完成,sessionId={}, 耗时={}s", sessionId, ttsDuration / 1000.0);
|
||||
|
||||
// 组装完整响应
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
@@ -226,6 +240,13 @@ public class ChatSessionServiceImpl implements ChatSessionService {
|
||||
result.put("ttsResult", ttsResult);
|
||||
result.put("sessionId", sessionId);
|
||||
result.put("timestamp", System.currentTimeMillis());
|
||||
result.put("sttDuration", sttDuration);
|
||||
result.put("llmDuration", llmDuration);
|
||||
result.put("ttsDuration", ttsDuration);
|
||||
|
||||
logger.info("语音对话完成,sessionId={}, STT耗时={}s, LLM耗时={}s, TTS耗时={}s, 总耗时={}s",
|
||||
sessionId, sttDuration / 1000.0, llmDuration / 1000.0, ttsDuration / 1000.0,
|
||||
(sttDuration + llmDuration + ttsDuration) / 1000.0);
|
||||
|
||||
return result;
|
||||
|
||||
|
||||
@@ -311,4 +311,23 @@ openapi:
|
||||
contact-email: xiaozhi@qq.com
|
||||
version: 1.0
|
||||
external-description: xiaozhi API Docs
|
||||
external-url: https://github.com/joey-zhou/xiaozhi-esp32-server-java
|
||||
external-url: https://github.com/joey-zhou/xiaozhi-esp32-server-java
|
||||
|
||||
# 语音流式对话配置
|
||||
xiaozhi:
|
||||
voice-stream:
|
||||
grok:
|
||||
api-key: xai-KKU4O5WumrQowiPc2qSF223L7Nw2XAwDOgrBxQRtrhKAbAri7JBDJuiv6HNwBcvNTnO026YPUeijwGqq
|
||||
api-url: https://caddy.liqupan.cn/v1/chat/completions
|
||||
model: grok-4-1-fast-non-reasoning
|
||||
minimax:
|
||||
api-key: eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiJ2YmlvZGJkcCIsIlVzZXJOYW1lIjoidHNldCIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxOTkyOTAyNTAzMzg5MjA1NDY3IiwiUGhvbmUiOiIiLCJHcm91cElEIjoiMTk5MjkwMjUwMzM4MDgyMDk1NSIsIlBhZ2VOYW1lIjoiIiwiTWFpbCI6InZiaW9kYmRwQGdtYWlsLmNvbSIsIkNyZWF0ZVRpbWUiOiIyMDI1LTEyLTA2IDE1OjQzOjUxIiwiVG9rZW5UeXBlIjoxLCJpc3MiOiJtaW5pbWF4In0.hf1M4cPe27Sz_QeSyYODqM6yrN8aQ68nRwYB7iQ3uO5nu0NSN7qHQRVxAt2tVuoOf503SEx5F-PfYyC85OFJFhWNNhhDuFuxPIz97LVz1oQUlIejZ_BmCMj4iWwGXTUmEugGK1lzcsI6eJz8eRjQHsxOgJJmxPLXWHTPs1gDqtnckAgjOBRQJSadP58Xe9EdI6n-2_SL_ni3Tqm3LuWq9tUPJa5WgDMZX9IDK7XXyZy0i1GoSXmp8P1O1JmIecBVUoCzyYFwWW787BNdYiyEV3UrFjC_4onJ8Tzh-eGq84-rtxBR5FKO2MpNU_I0xI-W3YJxOEl_JPXXGgX5ASTKNw
|
||||
group-id: ${MINIMAX_GROUP_ID:your-group-id}
|
||||
ws-url: wss://api.minimax.io/ws/v1/t2a_v2
|
||||
model: speech-2.6-hd
|
||||
voice-id: Chinese (Mandarin)_BashfulGirl
|
||||
speed: 1.0
|
||||
vol: 1.0
|
||||
pitch: 0
|
||||
audio-sample-rate: 32000
|
||||
bitrate: 128000
|
||||
Reference in New Issue
Block a user