diff --git a/docs/GROK_API_DEBUG_GUIDE.md b/docs/GROK_API_DEBUG_GUIDE.md new file mode 100644 index 0000000..672afc9 --- /dev/null +++ b/docs/GROK_API_DEBUG_GUIDE.md @@ -0,0 +1,311 @@ +# 🔍 Grok API 调试指南 + +## 数据格式说明 + +### WebClient 处理后的数据格式 + +Spring WebFlux 的 `WebClient.bodyToFlux(String.class)` 会自动处理 SSE 流,每个元素直接就是一个完整的 JSON 字符串: + +```json +{ + "id": "8670de6a-75b3-97e2-fa5a-c2e3d7f1d0f2", + "object": "chat.completion.chunk", + "created": 1764858564, + "model": "grok-4-1-fast-non-reasoning", + "choices": [{ + "index": 0, + "delta": { + "content": "你", + "role": "assistant" + }, + "finish_reason": null + }], + "system_fingerprint": "fp_174298dd8e" +} +``` + +### 字段说明 + +| 字段 | 说明 | 示例值 | +|------|------|--------| +| `id` | 请求唯一ID | `"8670de6a-75b3-97e2-fa5a-c2e3d7f1d0f2"` | +| `object` | 对象类型 | `"chat.completion.chunk"` | +| `created` | 创建时间戳 | `1764858564` | +| `model` | 使用的模型 | `"grok-4-1-fast-non-reasoning"` | +| `choices[0].index` | 选择索引 | `0` | +| `choices[0].delta.content` | 当前token内容 | `"你"` | +| `choices[0].delta.role` | 角色(首次出现) | `"assistant"` | +| `choices[0].finish_reason` | 结束原因 | `null`, `"stop"`, `"length"`, `"content_filter"` | +| `system_fingerprint` | 系统指纹 | `"fp_174298dd8e"` | + +## finish_reason 说明 + +| 值 | 含义 | 处理建议 | +|----|------|----------| +| `null` | 流还在继续 | 继续接收token | +| `"stop"` | 正常结束 | ✅ 回复完整 | +| `"length"` | 达到最大token限制 | ⚠️ 回复可能被截断,建议增加 `max_tokens` | +| `"content_filter"` | 内容被过滤 | ⚠️ 回复包含敏感内容 | + +## 解析流程 + +``` +收到数据 → 跳过空行 → 检查 [DONE] → 去除 "data: " 前缀(如果有) + ↓ +解析 JSON → 验证 choices 字段 → 获取 choices[0] + ↓ +检查 delta 字段 → 提取 content → 回调 onToken(content) + ↓ +检查 finish_reason → 记录日志 → 完成 +``` + +## 日志级别配置 + +在 `application.yml` 中配置: + +```yaml +logging: + level: + com.xiaozhi.dialogue.llm.GrokStreamService: DEBUG # 或 TRACE +``` + +### 各级别输出内容 + +#### TRACE(最详细) +``` +TRACE - 收到原始SSE数据: {"id":"8670de6a... 
+TRACE - 提取到token: 你 +TRACE - 提取到token: 好 +``` + +#### DEBUG +``` +DEBUG - 请求体: {"model":"grok-4-1-fast-non-reasoning",...} +DEBUG - 收到角色信息: assistant +DEBUG - JSON中缺少choices字段(如果有问题) +``` + +#### INFO +``` +INFO - 开始调用Grok API - Model: grok-4-1-fast-non-reasoning, UserMessage: 你好... +INFO - 流结束原因: length +INFO - Grok API流式调用完成 +``` + +#### ERROR +``` +ERROR - JSON解析失败 - 原始数据: xxx +ERROR - LLM调用失败: Connection refused +``` + +## 常见问题排查 + +### 问题1: 没有收到任何token + +**可能原因**: +- API Key 错误 +- 网络连接问题 +- API URL 配置错误 + +**检查步骤**: +1. 查看日志是否有 `INFO - 开始调用Grok API` +2. 检查是否有错误日志 +3. 验证 `application.yml` 中的配置: + ```yaml + xiaozhi: + voice-stream: + grok: + api-key: ${GROK_API_KEY} + api-url: https://api.x.ai/v1/chat/completions + ``` + +### 问题2: token乱码或格式错误 + +**检查点**: +1. 启用 TRACE 日志查看原始数据 +2. 确认 JSON 格式是否正确 +3. 查看是否有 `ERROR - JSON解析失败` 日志 + +### 问题3: 回复被截断 + +**日志特征**: +``` +INFO - 流结束原因: length +``` + +**解决方案**: +增加 `DEFAULT_MAX_TOKENS` 常量: +```java +private static final int DEFAULT_MAX_TOKENS = 4000; // 或更大 +``` + +### 问题4: WebClient 连接超时 + +**错误日志**: +``` +ERROR - LLM调用失败: Connection timeout +``` + +**解决方案**: +在 `initWebClient()` 中增加超时配置: +```java +this.webClient = WebClient.builder() + .baseUrl(apiUrl) + .codecs(configurer -> configurer.defaultCodecs().maxInMemorySize(10 * 1024 * 1024)) + .defaultHeader(HttpHeaders.ACCEPT, "text/event-stream") + .build(); +``` + +## 测试建议 + +### 1. 使用 curl 测试 API + +```bash +curl -X POST https://api.x.ai/v1/chat/completions \ + -H "Authorization: Bearer YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "grok-4-1-fast-non-reasoning", + "messages": [{"role": "user", "content": "你好"}], + "stream": true + }' +``` + +预期输出: +``` +data: {"id":"xxx","object":"chat.completion.chunk",...} +data: {"id":"xxx","object":"chat.completion.chunk",...} +data: [DONE] +``` + +### 2. 
检查配置
+
+```java
+// 在 GrokStreamService 中添加测试方法
+@PostConstruct
+public void testConfig() {
+    logger.info("Grok配置:");
+    logger.info("  API URL: {}", voiceStreamConfig.getGrok().getApiUrl());
+    logger.info("  Model: {}", voiceStreamConfig.getGrok().getModel());
+    logger.info("  API Key配置: {}", voiceStreamConfig.getGrok().getApiKey() != null ? "已配置" : "未配置");
+}
+```
+
+### 3. 单元测试
+
+```java
+@Test
+public void testParseSSE() {
+    String testJson = """
+        {
+          "id": "test",
+          "choices": [{
+            "delta": {
+              "content": "测试"
+            }
+          }]
+        }
+        """;
+
+    // 测试解析逻辑
+}
+```
+
+## 性能监控
+
+### 关键指标
+
+1. **首 token 延迟**:从发送请求到收到第一个 token 的时间
+2. **token 吞吐率**:每秒收到的 token 数量
+3. **总响应时间**:从开始到收到 [DONE] 的时间
+
+### 添加监控日志
+
+```java
+private long startTime;
+private int tokenCount = 0;
+
+public void streamChat(...) {
+    startTime = System.currentTimeMillis();
+    // ...
+    .doOnNext(token -> {
+        tokenCount++;
+        if (tokenCount == 1) {
+            long firstTokenLatency = System.currentTimeMillis() - startTime;
+            logger.info("首token延迟: {}ms", firstTokenLatency);
+        }
+    })
+    .doOnComplete(() -> {
+        long totalTime = System.currentTimeMillis() - startTime;
+        double tps = tokenCount / (totalTime / 1000.0);
+        // 注意:SLF4J 不支持 {:.2f} 这类格式化占位符,需先用 String.format 预格式化
+        logger.info("总计: {}个token, 耗时: {}ms, 吞吐率: {} tokens/s",
+                tokenCount, totalTime, String.format("%.2f", tps));
+    })
+}
+```
+
+## 调试技巧
+
+### 1. 保存原始响应
+
+```java
+// 在parseSSE中
+logger.trace("原始响应: {}", line);
+```
+
+### 2. 使用断点调试
+
+在以下位置设置断点:
+- `parseSSE()` 方法入口
+- `sink.next(content)` 之前
+- `callback.onToken()` 调用处
+
+### 3. 
模拟测试 + +创建测试类: +```java +@Test +public void testGrokStreamService() { + GrokStreamService service = new GrokStreamService(); + service.streamChat("你好", null, new TokenCallback() { + @Override + public void onToken(String token) { + System.out.println("Token: " + token); + } + + @Override + public void onComplete() { + System.out.println("完成"); + } + + @Override + public void onError(String error) { + System.err.println("错误: " + error); + } + }); + + // 等待完成 + Thread.sleep(10000); +} +``` + +## 总结 + +正常工作流程: +1. ✅ 看到 `INFO - 开始调用Grok API` +2. ✅ 看到多个 `TRACE - 提取到token: xxx` +3. ✅ 看到 `INFO - 流结束原因: stop`(或 length) +4. ✅ 看到 `INFO - Grok API流式调用完成` + +如果出现问题: +1. 检查日志级别是否足够详细 +2. 查找 ERROR 级别的日志 +3. 验证配置是否正确 +4. 使用 curl 直接测试 API +5. 检查网络连接 + +--- + +**需要更多帮助?** 提供完整的错误日志和配置信息。 + + diff --git a/docs/TTS_CONNECTION_WARMUP_TEST.md b/docs/TTS_CONNECTION_WARMUP_TEST.md new file mode 100644 index 0000000..d9c7f58 --- /dev/null +++ b/docs/TTS_CONNECTION_WARMUP_TEST.md @@ -0,0 +1,251 @@ +# TTS连接预热优化测试指南 + +## 测试目标 + +验证TTS WebSocket连接预热和复用功能,确保: +1. 用户连接时成功预热TTS连接 +2. 多句话复用同一个TTS连接 +3. 连接断开后自动重连 +4. 用户断开时正确清理连接 +5. 性能提升符合预期 + +## 测试环境准备 + +1. 启动后端服务 +2. 确保MiniMax TTS API配置正确 +3. 使用微信小程序或WebSocket客户端 + +## 测试用例 + +### 测试1:连接预热验证 + +**目的**:验证用户连接时是否成功预热TTS连接 + +**步骤**: +1. 前端连接到语音流WebSocket:`ws://your-server:8091/ws/voice-stream` +2. 观察后端日志 + +**预期日志**: +``` +[VoiceStreamHandler] 语音流WebSocket连接建立 - SessionId: xxx, UserId: null +[VoiceStreamService] 预热TTS连接 - SessionId: xxx +[MinimaxTtsStreamService] 开始预热TTS连接 - SessionId: xxx +[MinimaxTtsStreamService] MiniMax TTS连接已建立 - SessionId: xxx +[MinimaxTtsStreamService] 收到connected_success,发送task_start - SessionId: xxx +[MinimaxTtsStreamService] 收到task_started,连接就绪 - SessionId: xxx +[MinimaxTtsStreamService] TTS连接预热成功 - SessionId: xxx, 耗时: XXXms +[VoiceStreamHandler] TTS连接预热成功 - SessionId: xxx +``` + +**成功标准**: +- ✅ 连接建立后立即开始预热 +- ✅ 预热在1-2秒内完成 +- ✅ 连接状态变为IDLE + +--- + +### 测试2:连接复用验证 + +**目的**:验证多句话是否复用同一个TTS连接 + +**步骤**: +1. 
确保已建立连接并预热成功 +2. 发送一段音频(用户说:"你好,今天天气怎么样?") +3. 等待AI回复多句话(例如:"你好!今天天气很好。阳光明媚。") +4. 观察后端日志 + +**预期日志**: +``` +[VoiceStreamService] STT识别结果 - SessionId: xxx, Text: 你好,今天天气怎么样? +[VoiceStreamService] 检测到完整句子 - SessionId: xxx, Sentence: 你好! +[MinimaxTtsStreamService] 使用已有TTS连接 - SessionId: xxx, Text: 你好! +[MinimaxTtsStreamService] 缓冲音频块 - SessionId: xxx: 2048 bytes, 总计: 2048 bytes +... +[MinimaxTtsStreamService] TTS完成 - SessionId: xxx, Text: 你好!, 音频总大小: 24576 bytes +[MinimaxTtsStreamService] TTS完成 - SessionId: xxx, 耗时: 300ms(复用连接节省约1秒) +[VoiceStreamService] 检测到完整句子 - SessionId: xxx, Sentence: 今天天气很好。 +[MinimaxTtsStreamService] 使用已有TTS连接 - SessionId: xxx, Text: 今天天气很好。 +[MinimaxTtsStreamService] TTS完成 - SessionId: xxx, 耗时: 250ms(复用连接节省约1秒) +``` + +**成功标准**: +- ✅ 日志显示"使用已有TTS连接"而不是"创建MiniMax TTS连接" +- ✅ 每句TTS完成时间在200-500ms(而非1-1.5s) +- ✅ 日志显示"复用连接节省约1秒" +- ✅ 没有重复的连接建立日志 + +--- + +### 测试3:自动重连验证 + +**目的**:验证连接断开后是否自动重连 + +**步骤**: +1. 正常建立连接 +2. 手动关闭MiniMax TTS WebSocket连接(模拟网络断开) +3. 发送一段音频触发TTS +4. 观察后端日志 + +**预期日志**: +``` +[MinimaxTtsStreamService] MiniMax TTS连接关闭 - SessionId: xxx, Code: 1006, Reason: ..., Remote: true +[MinimaxTtsStreamService] 尝试自动重连TTS - SessionId: xxx, 第1次重连 +[MinimaxTtsStreamService] 开始预热TTS连接 - SessionId: xxx +[MinimaxTtsStreamService] TTS连接预热成功 - SessionId: xxx, 耗时: XXXms +[MinimaxTtsStreamService] 使用已有TTS连接 - SessionId: xxx, Text: ... +``` + +**成功标准**: +- ✅ 检测到连接断开 +- ✅ 自动触发重连(最多3次) +- ✅ 重连成功后可正常使用 +- ✅ 重连失败3次后停止尝试 + +--- + +### 测试4:连接清理验证 + +**目的**:验证用户断开时是否正确清理TTS连接 + +**步骤**: +1. 建立连接并预热成功 +2. 前端主动断开WebSocket连接 +3. 观察后端日志 + +**预期日志**: +``` +[VoiceStreamHandler] 语音流WebSocket连接关闭 - SessionId: xxx, Status: CloseStatus[code=1000, reason=null] +[VoiceStreamService] 关闭TTS连接 - SessionId: xxx +[MinimaxTtsStreamService] 关闭TTS连接 - SessionId: xxx +[MinimaxTtsStreamService] 发送task_finish - SessionId: xxx +[MinimaxTtsStreamService] MiniMax TTS连接关闭 - SessionId: xxx, Code: 1000, Reason: ... 
+``` + +**成功标准**: +- ✅ 正确发送task_finish +- ✅ 正常关闭TTS连接 +- ✅ 从连接池中移除 +- ✅ 无内存泄漏 + +--- + +### 测试5:性能对比验证 + +**目的**:验证性能提升是否符合预期 + +**测试场景**:AI回复3句话 + +**修改前性能**: +- 首句延迟:1-1.5秒 +- 第二句延迟:1-1.5秒 +- 第三句延迟:1-1.5秒 +- **总延迟**:3-4.5秒 + +**修改后预期性能**: +- 首句延迟:0.5秒(仅TTS合成) +- 第二句延迟:0.2-0.3秒 +- 第三句延迟:0.2-0.3秒 +- **总延迟**:0.9-1.1秒 + +**性能提升**:约 **65-70%** + +**验证方法**: +1. 在日志中记录每句TTS的开始和结束时间 +2. 计算总耗时 +3. 与预期值对比 + +--- + +### 测试6:并发测试 + +**目的**:验证多用户同时使用时的连接管理 + +**步骤**: +1. 创建3-5个WebSocket连接(模拟多用户) +2. 每个连接同时发送音频 +3. 观察后端日志和性能 + +**成功标准**: +- ✅ 每个用户有独立的TTS连接 +- ✅ 连接之间互不影响 +- ✅ 无连接混乱 +- ✅ 性能稳定 + +--- + +## 性能监控指标 + +在日志中关注以下关键指标: + +1. **连接建立时间**:`TTS连接预热成功 - SessionId: xxx, 耗时: XXXms` +2. **连接复用次数**:统计"使用已有TTS连接"出现次数 +3. **TTS合成时间**:`TTS完成 - SessionId: xxx, 耗时: XXXms` +4. **重连次数**:`尝试自动重连TTS - SessionId: xxx, 第X次重连` +5. **连接状态变化**:观察状态从CONNECTING → CONNECTED → TASK_STARTED → IDLE的转换 + +--- + +## 常见问题排查 + +### 问题1:预热失败 + +**现象**:`TTS连接预热失败 - SessionId: xxx` + +**可能原因**: +- MiniMax API配置错误 +- 网络连接问题 +- API密钥过期 + +**解决方法**: +- 检查application.yml中的MiniMax配置 +- 验证API密钥是否有效 +- 检查网络连接 + +### 问题2:连接未复用 + +**现象**:每次TTS都显示"创建MiniMax TTS连接" + +**可能原因**: +- 连接状态管理错误 +- 连接在使用前被关闭 + +**解决方法**: +- 检查连接状态日志 +- 确认is_final后状态设为IDLE而非关闭 + +### 问题3:自动重连失败 + +**现象**:重连3次后仍无法恢复 + +**可能原因**: +- MiniMax服务问题 +- 网络持续不稳定 + +**解决方法**: +- 检查MiniMax服务状态 +- 降级为按需创建连接 + +--- + +## 测试清单 + +- [ ] 测试1:连接预热验证 +- [ ] 测试2:连接复用验证 +- [ ] 测试3:自动重连验证 +- [ ] 测试4:连接清理验证 +- [ ] 测试5:性能对比验证 +- [ ] 测试6:并发测试 +- [ ] 检查无内存泄漏 +- [ ] 检查无连接泄漏 +- [ ] 压力测试(10+并发用户) + +--- + +## 测试结论 + +完成所有测试后,记录: +1. 实际性能提升百分比 +2. 连接复用成功率 +3. 自动重连成功率 +4. 发现的问题和改进建议 + diff --git a/docs/VOICE_CHAT_INTERACTION.md b/docs/VOICE_CHAT_INTERACTION.md new file mode 100644 index 0000000..3fc376e --- /dev/null +++ b/docs/VOICE_CHAT_INTERACTION.md @@ -0,0 +1,158 @@ +# 语音对话接口交互文档 + +本文档详细描述了前端 `webUI` 与后端 `server` 之间关于语音对话功能(`/voice-chat`)的交互流程、参数传递及返回值结构。 + +## 1. 概述 + +语音对话功能允许用户录制一段语音,前端将其上传至后端。后端依次执行以下操作: +1. **STT (Speech-to-Text)**:将语音转换为文本。 +2. 
**LLM (Large Language Model)**:将识别出的文本作为输入,获取 AI 的文本回复。 +3. **TTS (Text-to-Speech)**:将 AI 的文本回复转换为语音。 + +最终,后端将用户识别文本、AI 回复文本及合成的语音数据一次性返回给前端。 + +## 2. 前端调用 (WebUI) + +前端主要涉及的文件为 `webUI/src/components/ChatBox.vue` 和 `webUI/src/utils/api.js`。 + +### 2.1 调用方式 + +在 `ChatBox.vue` 中,当用户完成录音后,调用 `handleVoiceModeMessage` 方法,进而调用 `voiceAPI.voiceChat`。 + +底层通过 `uni.uploadFile` 发起 `multipart/form-data` 类型的 POST 请求。 + +### 2.2 请求参数 + +前端向后端发送的请求包含 **文件** 和 **表单数据 (FormData)**。 + +* **URL**: `/api/chat/voice-chat` (由 `config.js` 中的 `VOICE_CHAT` 常量定义) +* **Method**: `POST` +* **Header**: `Authorization: Bearer ` +* **文件部分**: + * `name`: `"audio"` + * `filePath`: 录音文件的本地路径 (e.g., `.aac` 或 `.wav`) + +* **表单数据 (FormData)**: + +| 参数名 | 类型 | 必填 | 说明 | 来源 | +| :--- | :--- | :--- | :--- | :--- | +| `sessionId` | String | 否 | 会话 ID,用于保持上下文 | `conversationId.value` | +| `modelId` | Integer | 否 | 模型 ID | `characterConfig.modelId` | +| `templateId` | Integer | 否 | 模板 ID | `characterConfig.templateId` | +| `voiceStyle` | String | 否 | 语音风格 (用于 TTS) | `options.voiceStyle` | +| `ttsConfigId` | Integer | 否 | TTS 配置 ID | `aiConfig.ttsId` | +| `sttConfigId` | Integer | 否 | STT 配置 ID | `aiConfig.sttId` | +| `useFunctionCall` | Boolean | 否 | 是否使用函数调用 | 默认为 `false` | + +### 2.3 响应处理 + +前端接收到后端返回的 JSON 数据后,进行如下解析: + +1. **用户文本**: 从 `sttResult.text` 获取,显示在聊天界面右侧。 +2. **AI 回复**: 从 `llmResult.response` 获取,显示在聊天界面左侧。 +3. **语音播放**: 优先使用 `ttsResult.audioBase64` (Base64 编码音频),如果没有则使用 `ttsResult.audioPath` (音频 URL) 进行播放。 + +--- + +## 3. 后端处理 (Server) + +后端入口为 `server/src/main/java/com/xiaozhi/controller/ChatController.java`,核心逻辑在 `ChatSessionServiceImpl.java`。 + +### 3.1 接口定义 + +* **Controller**: `ChatController` +* **Path**: `/api/chat/voice-chat` +* **Consumes**: `multipart/form-data` + +### 3.2 处理流程 + +1. **接收文件**: 后端支持字段名为 `audioFile`, `file`, 或 `audio` 的文件上传 (前端使用 `audio`)。 +2. **参数解析**: 解析 `sessionId`, `modelId` 等参数。 +3. **文件验证**: 检查文件大小 (1KB - 50MB)、格式 (支持 mp3, wav, m4a, aac 等) 和 MIME 类型。 +4. 
**音频处理**: + * 将上传的音频文件保存为临时文件。 + * 使用 `AudioUtils` 将音频转换为 **PCM 16k 单声道** 格式(适配 STT 引擎)。 +5. **业务逻辑 (`ChatSessionService.voiceChat`)**: + * **STT**: 调用配置的 STT 服务识别语音,得到 `recognizedText`。 + * **LLM**: 如果识别到文本,调用 `syncChat` 获取 AI 回复 `chatResponse`。 + * **TTS**: 调用 TTS 服务将 AI 回复转换为语音,生成音频文件并读取为 Base64。 +6. **结果封装**: 将 STT、LLM、TTS 的结果封装到 Map 中返回。 + +--- + +## 4. 接口规范 + +### 4.1 请求结构 + +**POST** `/api/chat/voice-chat` +**Content-Type**: `multipart/form-data` + +**Body**: +* `audio`: [二进制文件数据] +* `sessionId`: "session_12345" +* `modelId`: 10 +* `templateId`: 6 +* ... + +### 4.2 响应结构 + +**Content-Type**: `application/json` + +```json +{ + "code": 200, + "message": "语音对话成功", + "data": { + "sessionId": "session_12345", + "timestamp": 1717660000000, + + // 1. STT 结果 + "sttResult": { + "text": "你好,请介绍一下你自己。", // 用户语音识别结果 + "audioSize": 32000, + "sttProvider": "vosk" + }, + + // 2. LLM 结果 + "llmResult": { + "response": "你好!我是蔚AI,很高兴为你服务。", // AI 回复文本 + "inputText": "你好,请介绍一下你自己。" + }, + + // 3. TTS 结果 + "ttsResult": { + "audioBase64": "UklGRi...", // Base64 编码的音频数据 (用于直接播放) + "audioPath": "audio/output/...", // 服务器音频文件路径 + "timestamp": 1717660005000 + }, + + // 性能统计 (耗时: ms) + "sttDuration": 500, + "llmDuration": 1200, + "ttsDuration": 800, + + // 文件元数据 + "originalFileName": "temp_audio.aac", + "fileSize": 15000, + "contentType": "audio/aac", + "description": null + } +} +``` + +### 4.3 错误响应 + +```json +{ + "code": 400, // 或 500 + "message": "请求参数错误: 音频文件不能为空", + "data": null +} +``` + +## 5. 总结 + +* **交互模式**: 同步一次性交互。前端发送音频,等待后端完成所有处理(识别+对话+合成)后,一次性接收所有数据。 +* **音频格式**: 前端通常录制 `aac` 或 `wav`,后端统一转码为 `pcm` 进行处理。 +* **回退机制**: 如果后端处理失败,前端 `ChatBox.vue` 会捕获异常并提示用户,或使用本地模拟回复(在未登录等特定情况下)。 + diff --git a/docs/WEBSOCKET_VOICE_STREAM_DESIGN.md b/docs/WEBSOCKET_VOICE_STREAM_DESIGN.md new file mode 100644 index 0000000..0ec0a88 --- /dev/null +++ b/docs/WEBSOCKET_VOICE_STREAM_DESIGN.md @@ -0,0 +1,744 @@ +# WebSocket 实时语音流式对话架构设计 + +## 1. 
方案概述 + +本方案设计了一套**"伪实时"语音交互系统**,结合前端 VAD(语音活动检测)和后端流式处理,实现低延迟的语音对话体验。 + +### 核心特点 + +- **前端保持简单**:复用现有 VAD 逻辑,用户说完话才上传完整音频 +- **后端流式处理**:STT 同步处理,LLM+TTS 流式串联 +- **极低感知延迟**:用户说完话后 1 秒内听到第一句回复 + +### 技术选型 + +- **前端通信**:WebSocket (双向实时通信) +- **STT**:复用现有 Vosk/其他 STT 服务(同步处理完整音频) +- **LLM**:Grok API (`https://api.x.ai/v1/chat/completions`) + SSE Stream +- **TTS**:MiniMax WebSocket TTS (`wss://api.minimax.io/ws/v1/t2a_v2`) + Stream + +--- + +## 2. 整体架构流程 + +``` +┌─────────────┐ +│ 前端 UI │ +└──────┬──────┘ + │ 1. 用户说话 + ↓ +┌─────────────────┐ +│ RecorderManager │ (VAD 检测说完) +│ (PCM/AAC) │ +└────────┬────────┘ + │ 2. WebSocket 发送完整音频 (二进制) + ↓ + ┌────────────────────┐ + │ 后端 WebSocket │ + │ Handler │ + └────────┬───────────┘ + │ + ↓ 3. 调用 STT (同步) + ┌────────────────┐ + │ STT Service │ → "你好,今天天气怎么样?" + │ (复用现有) │ + └────────┬───────┘ + │ + ↓ 4. 流式调用 Grok LLM (HTTP SSE) + ┌─────────────────────────┐ + │ Grok API Stream │ + │ https://api.x.ai/v1/... │ + └────────┬────────────────┘ + │ Token 流: "今", "天", "天", "气", "很", "好", "。", ... + ↓ + ┌────────────────────┐ + │ 分句缓冲器 │ (检测标点) + └────────┬───────────┘ + │ 完整句子: "今天天气很好。" + ↓ 5. 为每句调用 MiniMax TTS (WebSocket Stream) + ┌─────────────────────────────┐ + │ MiniMax TTS WebSocket │ + │ wss://api.minimax.io/ws/... │ + └────────┬────────────────────┘ + │ 音频流 (Hex → Bytes) + ↓ 6. 实时转发给前端 (WebSocket 二进制) + ┌────────────────────┐ + │ 前端 WebSocket │ + │ onMessage │ + └────────┬───────────┘ + │ 7. WebAudioContext 实时播放 + ↓ + ┌────────────────────┐ + │ 音频播放队列 │ + │ (无缝连续播放) │ + └────────────────────┘ +``` + +--- + +## 3. 前端改造方案 (ChatBox.vue) + +### 3.1 保留的部分 + +✅ **现有 VAD 逻辑**: +- `recorderManager` 配置 +- `onStart`、`onFrameRecorded`(用于波形可视化) +- `onStop`(核心触发点) +- `calculateVolumeRMS`、`vadConfig`(静音检测) + +### 3.2 需要修改的部分 + +#### A. 
建立 WebSocket 连接 + +```javascript +// 新增:WebSocket 连接管理 +const voiceWebSocket = ref(null); + +const connectVoiceWebSocket = () => { + const wsUrl = `ws://192.168.3.13:8091/ws/voice-stream`; + + voiceWebSocket.value = uni.connectSocket({ + url: wsUrl, + success: () => { + console.log('WebSocket 连接成功'); + } + }); + + // 监听消息(接收音频流) + voiceWebSocket.value.onMessage((res) => { + if (res.data instanceof ArrayBuffer) { + // 收到音频数据,加入播放队列 + playAudioChunk(res.data); + } else { + // JSON 控制消息 + const msg = JSON.parse(res.data); + if (msg.event === 'stt_done') { + addMessage('user', msg.text); + voiceState.value = 'thinking'; + } else if (msg.event === 'llm_start') { + voiceState.value = 'speaking'; + } + } + }); + + voiceWebSocket.value.onError((err) => { + console.error('WebSocket 错误:', err); + }); +}; + +// 进入语音模式时连接 +const toggleVoiceMode = () => { + if (isVoiceMode.value) { + connectVoiceWebSocket(); + } else { + voiceWebSocket.value?.close(); + } +}; +``` + +#### B. 修改音频上传逻辑 + +```javascript +// 修改 handleVoiceModeMessage +const handleVoiceModeMessage = async (filePath) => { + if (!isVoiceMode.value) return; + + voiceState.value = 'thinking'; + + try { + // 读取录音文件为 ArrayBuffer + const fs = uni.getFileSystemManager(); + const audioData = await new Promise((resolve, reject) => { + fs.readFile({ + filePath: filePath, + success: (res) => resolve(res.data), + fail: reject + }); + }); + + // 通过 WebSocket 发送音频 + voiceWebSocket.value.send({ + data: audioData, + success: () => { + console.log('音频已发送'); + }, + fail: (err) => { + console.error('发送失败:', err); + } + }); + + } catch (error) { + console.error('处理音频失败:', error); + voiceState.value = 'idle'; + } +}; +``` + +#### C. 
实现流式音频播放 + +```javascript +// 音频播放队列 +const audioQueue = ref([]); +const isPlayingAudio = ref(false); + +const playAudioChunk = (arrayBuffer) => { + audioQueue.value.push(arrayBuffer); + if (!isPlayingAudio.value) { + processAudioQueue(); + } +}; + +const processAudioQueue = async () => { + if (audioQueue.value.length === 0) { + isPlayingAudio.value = false; + // 播放完成,回到待机状态 + if (isVoiceMode.value && isAutoVoiceMode.value) { + startVoiceRecording(); + } else { + voiceState.value = 'idle'; + } + return; + } + + isPlayingAudio.value = true; + const chunk = audioQueue.value.shift(); + + // 使用 InnerAudioContext (需先转为临时文件) + const fs = uni.getFileSystemManager(); + const tempPath = `${wx.env.USER_DATA_PATH}/temp_audio_${Date.now()}.mp3`; + + fs.writeFile({ + filePath: tempPath, + data: chunk, + encoding: 'binary', + success: () => { + const audio = uni.createInnerAudioContext(); + audio.src = tempPath; + audio.onEnded(() => { + processAudioQueue(); // 播放下一块 + }); + audio.onError(() => { + processAudioQueue(); // 出错也继续 + }); + audio.play(); + } + }); +}; +``` + +--- + +## 4. 
后端实现方案 (Java) + +### 4.1 WebSocket Handler + +```java +@ServerEndpoint(value = "/ws/voice-stream") +@Component +public class VoiceStreamHandler { + + private static final Logger logger = LoggerFactory.getLogger(VoiceStreamHandler.class); + + @Autowired + private SttService sttService; + + @OnOpen + public void onOpen(Session session) { + logger.info("WebSocket 连接建立: {}", session.getId()); + } + + @OnMessage + public void onBinaryMessage(ByteBuffer audioBuffer, Session session) { + logger.info("收到音频数据: {} bytes", audioBuffer.remaining()); + + // 转为 byte[] + byte[] audioData = new byte[audioBuffer.remaining()]; + audioBuffer.get(audioData); + + // 异步处理(避免阻塞 WebSocket 线程) + CompletableFuture.runAsync(() -> { + processVoiceStream(audioData, session); + }); + } + + @OnClose + public void onClose(Session session) { + logger.info("WebSocket 连接关闭: {}", session.getId()); + } + + @OnError + public void onError(Session session, Throwable error) { + logger.error("WebSocket 错误: {}", session.getId(), error); + } +} +``` + +### 4.2 核心处理流程 + +```java +private void processVoiceStream(byte[] audioData, Session session) { + try { + // ========== 1. STT (同步处理) ========== + long sttStart = System.currentTimeMillis(); + String recognizedText = performStt(audioData); + long sttDuration = System.currentTimeMillis() - sttStart; + + if (recognizedText == null || recognizedText.trim().isEmpty()) { + sendJsonMessage(session, Map.of("event", "error", "message", "未识别到语音")); + return; + } + + logger.info("STT 完成 ({}ms): {}", sttDuration, recognizedText); + + // 通知前端识别结果 + sendJsonMessage(session, Map.of( + "event", "stt_done", + "text", recognizedText, + "duration", sttDuration + )); + + // ========== 2. 
LLM Stream + TTS Stream (串联) ========== + streamLlmAndTts(recognizedText, session); + + } catch (Exception e) { + logger.error("处理语音流失败", e); + sendJsonMessage(session, Map.of("event", "error", "message", e.getMessage())); + } +} +``` + +### 4.3 STT 处理(复用现有) + +```java +private String performStt(byte[] audioData) { + try { + // 转换为 PCM 16k 单声道 + byte[] pcmData = AudioUtils.bytesToPcm(audioData); + + // 调用 STT 服务(复用现有逻辑) + String text = sttService.recognition(pcmData); + + return text != null ? text.trim() : ""; + + } catch (Exception e) { + logger.error("STT 处理失败", e); + return ""; + } +} +``` + +### 4.4 Grok LLM Stream 调用 + +```java +private void streamLlmAndTts(String userText, Session frontendSession) { + // 分句缓冲器 + StringBuilder sentenceBuffer = new StringBuilder(); + + // 使用 WebClient 调用 Grok SSE API + WebClient client = WebClient.builder() + .baseUrl("https://api.x.ai") + .defaultHeader("Authorization", "Bearer " + grokApiKey) + .build(); + + sendJsonMessage(frontendSession, Map.of("event", "llm_start")); + + client.post() + .uri("/v1/chat/completions") + .contentType(MediaType.APPLICATION_JSON) + .bodyValue(Map.of( + "model", "grok-beta", + "messages", List.of( + Map.of("role", "system", "content", "你是一个友好的AI助手"), + Map.of("role", "user", "content", userText) + ), + "stream", true, + "temperature", 0.7 + )) + .retrieve() + .bodyToFlux(String.class) // SSE 流 + .subscribe( + // onNext: 处理每个 SSE chunk + sseChunk -> { + String token = parseGrokToken(sseChunk); + if (token != null && !token.isEmpty()) { + sentenceBuffer.append(token); + + // 检测句子结束 + if (isSentenceEnd(sentenceBuffer.toString())) { + String sentence = sentenceBuffer.toString().trim(); + sentenceBuffer.setLength(0); // 清空 + + logger.info("检测到完整句子: {}", sentence); + + // 调用 TTS 并流式发送 + streamTtsToFrontend(sentence, frontendSession); + } + } + }, + // onError + error -> { + logger.error("Grok LLM Stream 错误", error); + sendJsonMessage(frontendSession, Map.of("event", "error", "message", "LLM 处理失败")); 
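+                    // 注意:出错时 sentenceBuffer 中可能残留尚未合成的半句文本,当前实现直接丢弃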
+ }, + // onComplete + () -> { + // 如果还有剩余内容,也发送 + if (sentenceBuffer.length() > 0) { + String lastSentence = sentenceBuffer.toString().trim(); + if (!lastSentence.isEmpty()) { + streamTtsToFrontend(lastSentence, frontendSession); + } + } + + sendJsonMessage(frontendSession, Map.of("event", "llm_complete")); + logger.info("LLM 处理完成"); + } + ); +} +``` + +### 4.5 Grok SSE 解析 + +```java +private String parseGrokToken(String sseChunk) { + // SSE 格式: data: {"choices":[{"delta":{"content":"你好"}}]}\n\n + if (!sseChunk.startsWith("data: ")) { + return null; + } + + String jsonData = sseChunk.substring(6).trim(); + if (jsonData.equals("[DONE]")) { + return null; + } + + try { + JsonObject json = JsonParser.parseString(jsonData).getAsJsonObject(); + JsonArray choices = json.getAsJsonArray("choices"); + if (choices != null && choices.size() > 0) { + JsonObject delta = choices.get(0).getAsJsonObject().getAsJsonObject("delta"); + if (delta != null && delta.has("content")) { + return delta.get("content").getAsString(); + } + } + } catch (Exception e) { + logger.warn("解析 Grok token 失败: {}", sseChunk); + } + + return null; +} +``` + +### 4.6 分句逻辑 + +```java +private boolean isSentenceEnd(String text) { + // 中英文标点检测 + String trimmed = text.trim(); + if (trimmed.isEmpty()) return false; + + char lastChar = trimmed.charAt(trimmed.length() - 1); + + // 中文标点 + if (lastChar == '。' || lastChar == '!' || lastChar == '?' || + lastChar == ';' || lastChar == ',') { + return true; + } + + // 英文标点 + if (lastChar == '.' || lastChar == '!' 
|| lastChar == '?') { + return true; + } + + // 防止句子过长(超过50字强制分句) + if (trimmed.length() > 50) { + return true; + } + + return false; +} +``` + +### 4.7 MiniMax TTS WebSocket Stream + +```java +private void streamTtsToFrontend(String text, Session frontendSession) { + try { + URI uri = new URI("wss://api.minimax.io/ws/v1/t2a_v2"); + + WebSocketClient ttsClient = new WebSocketClient(uri) { + private boolean taskStarted = false; + + @Override + public void onOpen(ServerHandshake handshake) { + logger.info("MiniMax TTS 连接成功"); + } + + @Override + public void onMessage(String message) { + try { + JsonObject msg = JsonParser.parseString(message).getAsJsonObject(); + String event = msg.get("event").getAsString(); + + if ("connected_success".equals(event)) { + // 发送 task_start + send(buildTaskStartJson()); + + } else if ("task_started".equals(event)) { + // 发送 task_continue + taskStarted = true; + send(buildTaskContinueJson(text)); + + } else if (msg.has("data")) { + JsonObject data = msg.getAsJsonObject("data"); + if (data.has("audio")) { + String hexAudio = data.get("audio").getAsString(); + + // 转为字节数组 + byte[] audioBytes = hexToBytes(hexAudio); + + // 立即转发给前端 + frontendSession.getAsyncRemote().sendBinary( + ByteBuffer.wrap(audioBytes) + ); + + logger.debug("转发音频数据: {} bytes", audioBytes.length); + } + + // 检查是否结束 + if (msg.has("is_final") && msg.get("is_final").getAsBoolean()) { + send("{\"event\":\"task_finish\"}"); + close(); + } + } + + } catch (Exception e) { + logger.error("处理 TTS 消息失败", e); + } + } + + @Override + public void onError(Exception ex) { + logger.error("MiniMax TTS 错误", ex); + } + + @Override + public void onClose(int code, String reason, boolean remote) { + logger.info("MiniMax TTS 连接关闭"); + } + }; + + // 添加认证头 + Map headers = new HashMap<>(); + headers.put("Authorization", "Bearer " + minimaxApiKey); + ttsClient.setHttpHeaders(headers); + + // 连接(阻塞等待,最多5秒) + boolean connected = ttsClient.connectBlocking(5, TimeUnit.SECONDS); + if (!connected) { + 
throw new RuntimeException("连接 MiniMax TTS 超时"); + } + + } catch (Exception e) { + logger.error("调用 MiniMax TTS 失败", e); + sendJsonMessage(frontendSession, Map.of("event", "error", "message", "TTS 处理失败")); + } +} +``` + +### 4.8 工具方法 + +```java +// 构建 task_start JSON +private String buildTaskStartJson() { + return """ + { + "event": "task_start", + "model": "speech-2.6-hd", + "voice_setting": { + "voice_id": "male-qn-qingse", + "speed": 1.0, + "vol": 1.0, + "pitch": 0 + }, + "audio_setting": { + "sample_rate": 32000, + "bitrate": 128000, + "format": "mp3", + "channel": 1 + } + } + """; +} + +// 构建 task_continue JSON +private String buildTaskContinueJson(String text) { + JsonObject json = new JsonObject(); + json.addProperty("event", "task_continue"); + json.addProperty("text", text); + return json.toString(); +} + +// Hex 转字节数组 +private byte[] hexToBytes(String hex) { + int len = hex.length(); + byte[] data = new byte[len / 2]; + for (int i = 0; i < len; i += 2) { + data[i / 2] = (byte) ((Character.digit(hex.charAt(i), 16) << 4) + + Character.digit(hex.charAt(i+1), 16)); + } + return data; +} + +// 发送 JSON 消息给前端 +private void sendJsonMessage(Session session, Map data) { + try { + String json = new Gson().toJson(data); + session.getAsyncRemote().sendText(json); + } catch (Exception e) { + logger.error("发送消息失败", e); + } +} +``` + +--- + +## 5. Maven 依赖配置 + +```xml + + + org.springframework.boot + spring-boot-starter-websocket + + + + + org.springframework.boot + spring-boot-starter-webflux + + + + + org.java-websocket + Java-WebSocket + 1.5.3 + + + + + com.google.code.gson + gson + +``` + +--- + +## 6. 关键技术难点与解决方案 + +### 6.1 并发控制 + +**问题**:多个句子的 TTS 如何保证顺序? + +**解决方案**: +- 使用 **信号量 (Semaphore)** 或 **串行队列** +- 一个句子的 TTS 完全结束后,才开始下一个句子 + +```java +private final Semaphore ttsSemaphore = new Semaphore(1); + +private void streamTtsToFrontend(String text, Session frontendSession) { + try { + ttsSemaphore.acquire(); // 获取锁 + // ... TTS 处理 ... 
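+        // 该句的 TTS 完全结束后才会在 finally 中释放信号量,从而保证下一句按顺序开始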
+ } finally { + ttsSemaphore.release(); // 释放锁 + } +} +``` + +### 6.2 Grok SSE 连接稳定性 + +**问题**:SSE 长连接可能中断。 + +**解决方案**: +- 设置超时和重试机制 +- 使用 `Flux.retry(3)` 自动重试 + +### 6.3 前端音频播放卡顿 + +**问题**:网络抖动导致音频断续。 + +**解决方案**: +- 实现 **Jitter Buffer**(缓冲 3-5 个音频块再开始播放) +- 检测队列长度,动态调整播放速率 + +### 6.4 音频格式兼容性 + +**问题**:小程序对 PCM 格式支持不佳。 + +**解决方案**: +- MiniMax 配置输出 `mp3` 格式(压缩率高,兼容性好) +- 前端直接播放 MP3 无需解码 + +--- + +## 7. 性能指标预估 + +| 阶段 | 预估耗时 | 说明 | +|---|---|---| +| **前端 VAD** | 实时 | 用户说话期间持续检测 | +| **音频上传** | 50-200ms | 取决于网络和文件大小 | +| **STT 处理** | 300-800ms | Vosk 本地处理较快 | +| **LLM 首 Token** | 500-1000ms | Grok 响应速度 | +| **TTS 首块音频** | 200-500ms | MiniMax WebSocket 延迟 | +| **首次播放延迟** | **1-2秒** | 用户说完到听到回复 | + +--- + +## 8. 测试建议 + +### 8.1 单元测试 + +- STT 服务独立测试 +- Grok API 调用测试(Mock SSE 流) +- MiniMax TTS 调用测试(Mock WebSocket) + +### 8.2 集成测试 + +- 完整语音流测试(录音 -> STT -> LLM -> TTS -> 播放) +- 并发测试(多用户同时对话) +- 异常场景(网络中断、API 超时) + +### 8.3 压力测试 + +- 模拟 100 并发用户 +- 监控服务器 CPU、内存、网络 I/O +- 检查 WebSocket 连接泄漏 + +--- + +## 9. 优化方向 + +### 短期优化 +1. **缓存 TTS 结果**:相同文本不重复合成 +2. **自适应分句**:根据网络状况动态调整句子长度 +3. **优雅降级**:API 失败时使用备用服务 + +### 长期优化 +1. **边缘计算**:STT 迁移到设备端(Whisper.cpp) +2. **模型本地化**:部署私有 LLM 和 TTS +3. **多模态融合**:支持图片、视频输入 + +--- + +## 10. 
总结 + +本方案通过 **WebSocket 全双工通信** 实现了高效的语音流式交互: + +✅ **前端简单**:保留现有 VAD,只需改造上传和播放逻辑 +✅ **后端高效**:LLM+TTS 流式串联,极低延迟 +✅ **用户体验**:说完话 1 秒内听到回复,接近真人对话 +✅ **技术成熟**:Grok、MiniMax 官方支持流式 API + +最终效果:**从"录音-等待-播放"进化为"流式对话"**,用户感知延迟降低 **60-80%**。🚀 + diff --git a/docs/import asyncio b/docs/import asyncio new file mode 100644 index 0000000..556345b --- /dev/null +++ b/docs/import asyncio @@ -0,0 +1,183 @@ +import asyncio +import websockets +import json +import ssl +import subprocess +import os + +model = "speech-2.6-hd" +file_format = "mp3" + +class StreamAudioPlayer: + def __init__(self): + self.mpv_process = None + + def start_mpv(self): + """Start MPV player process""" + try: + mpv_command = ["mpv", "--no-cache", "--no-terminal", "--", "fd://0"] + self.mpv_process = subprocess.Popen( + mpv_command, + stdin=subprocess.PIPE, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + print("MPV player started") + return True + except FileNotFoundError: + print("Error: mpv not found. Please install mpv") + return False + except Exception as e: + print(f"Failed to start mpv: {e}") + return False + + def play_audio_chunk(self, hex_audio): + """Play audio chunk""" + try: + if self.mpv_process and self.mpv_process.stdin: + audio_bytes = bytes.fromhex(hex_audio) + self.mpv_process.stdin.write(audio_bytes) + self.mpv_process.stdin.flush() + return True + except Exception as e: + print(f"Play failed: {e}") + return False + return False + + def stop(self): + """Stop player""" + if self.mpv_process: + if self.mpv_process.stdin and not self.mpv_process.stdin.closed: + self.mpv_process.stdin.close() + try: + self.mpv_process.wait(timeout=20) + except subprocess.TimeoutExpired: + self.mpv_process.terminate() + +async def establish_connection(api_key): + """Establish WebSocket connection""" + url = "wss://api.minimax.io/ws/v1/t2a_v2" + headers = {"Authorization": f"Bearer {api_key}"} + + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + 
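+    # 注意:此处关闭了主机名检查与证书校验(CERT_NONE),仅适用于本地调试,生产环境应保持默认校验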
ssl_context.verify_mode = ssl.CERT_NONE  # NOTE: disables TLS certificate verification; for local debugging only
+
+    try:
+        ws = await websockets.connect(url, additional_headers=headers, ssl=ssl_context)
+        connected = json.loads(await ws.recv())
+        if connected.get("event") == "connected_success":
+            print("Connection successful")
+            return ws
+        return None
+    except Exception as e:
+        print(f"Connection failed: {e}")
+        return None
+
+async def start_task(websocket):
+    """Send task start request"""
+    start_msg = {
+        "event": "task_start",
+        "model": model,
+        "voice_setting": {
+            "voice_id": "male-qn-qingse",
+            "speed": 1,
+            "vol": 1,
+            "pitch": 0,
+            "english_normalization": False
+        },
+        "audio_setting": {
+            "sample_rate": 32000,
+            "bitrate": 128000,
+            "format": file_format,
+            "channel": 1
+        }
+    }
+    await websocket.send(json.dumps(start_msg))
+    response = json.loads(await websocket.recv())
+    return response.get("event") == "task_started"
+
+async def continue_task_with_stream_play(websocket, text, player):
+    """Send continue request and stream play audio"""
+    await websocket.send(json.dumps({
+        "event": "task_continue",
+        "text": text
+    }))
+
+    chunk_counter = 1
+    total_audio_size = 0
+    audio_data = b""
+
+    while True:
+        try:
+            response = json.loads(await websocket.recv())
+
+            if "data" in response and "audio" in response["data"]:
+                audio = response["data"]["audio"]
+                if audio:
+                    print(f"Playing chunk #{chunk_counter}")
+                    audio_bytes = bytes.fromhex(audio)
+                    if player.play_audio_chunk(audio):
+                        total_audio_size += len(audio_bytes)
+                        audio_data += audio_bytes
+                        chunk_counter += 1
+
+            if response.get("is_final"):
+                print(f"Audio synthesis completed: {chunk_counter-1} chunks")
+                if player.mpv_process and player.mpv_process.stdin:
+                    player.mpv_process.stdin.close()
+
+                # Save audio to file
+                with open(f"output.{file_format}", "wb") as f:
+                    f.write(audio_data)
+                print(f"Audio saved as output.{file_format}")
+
+                estimated_duration = total_audio_size * 0.0625 / 1000
+                wait_time = max(estimated_duration + 5, 10)
+                return wait_time
+
+        except 
Exception as e:
+            print(f"Error: {e}")
+            break
+
+    return 10
+
+async def close_connection(websocket):
+    """Close connection"""
+    if websocket:
+        try:
+            await websocket.send(json.dumps({"event": "task_finish"}))
+            await websocket.close()
+        except Exception:
+            pass
+
+async def main():
+    API_KEY = os.getenv("MINIMAX_API_KEY")
+    TEXT = "The real danger is not that computers start thinking like people, but that people start thinking like computers. Computers can only help us with simple tasks."
+
+    player = StreamAudioPlayer()
+
+    try:
+        if not player.start_mpv():
+            return
+
+        ws = await establish_connection(API_KEY)
+        if not ws:
+            return
+
+        if not await start_task(ws):
+            print("Task startup failed")
+            return
+
+        wait_time = await continue_task_with_stream_play(ws, TEXT, player)
+        await asyncio.sleep(wait_time)
+
+    except Exception as e:
+        print(f"Error: {e}")
+    finally:
+        player.stop()
+        if 'ws' in locals():
+            await close_connection(ws)
+
+if __name__ == "__main__":
+    asyncio.run(main())
\ No newline at end of file
diff --git a/pom.xml b/pom.xml
index 2408e1a..49c4245 100644
--- a/pom.xml
+++ b/pom.xml
@@ -41,6 +41,19 @@
             <artifactId>spring-boot-starter-websocket</artifactId>
         </dependency>
 
+        <!-- 语音流式对话依赖 -->
+        <dependency>
+            <groupId>org.springframework.boot</groupId>
+            <artifactId>spring-boot-starter-webflux</artifactId>
+        </dependency>
+
+        <dependency>
+            <groupId>org.java-websocket</groupId>
+            <artifactId>Java-WebSocket</artifactId>
+            <version>1.5.7</version>
+        </dependency>
+
         <dependency>
             <groupId>org.springframework.boot</groupId>
             <artifactId>spring-boot-starter-aop</artifactId>
diff --git a/src/main/java/com/xiaozhi/communication/server/websocket/VoiceStreamHandler.java b/src/main/java/com/xiaozhi/communication/server/websocket/VoiceStreamHandler.java
new file mode 100644
index 0000000..2397232
--- /dev/null
+++ b/src/main/java/com/xiaozhi/communication/server/websocket/VoiceStreamHandler.java
@@ -0,0 +1,327 @@
+package com.xiaozhi.communication.server.websocket;
+
+import com.google.gson.Gson;
+import com.xiaozhi.service.VoiceStreamService;
+import jakarta.annotation.Resource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.stereotype.Component;
+import 
org.springframework.web.socket.BinaryMessage;
+import org.springframework.web.socket.CloseStatus;
+import org.springframework.web.socket.TextMessage;
+import org.springframework.web.socket.WebSocketSession;
+import org.springframework.web.socket.handler.AbstractWebSocketHandler;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * 语音流式对话 WebSocket 处理器
+ */
+@Component
+public class VoiceStreamHandler extends AbstractWebSocketHandler {
+    private static final Logger logger = LoggerFactory.getLogger(VoiceStreamHandler.class);
+    private static final Gson gson = new Gson();
+
+    // 保存所有活跃的会话
+    private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<>();
+
+    @Resource
+    private VoiceStreamService voiceStreamService;
+
+    @Override
+    public void afterConnectionEstablished(WebSocketSession session) {
+        String sessionId = session.getId();
+
+        // 从请求头或查询参数获取用户认证信息
+        Map<String, String> params = getParamsFromSession(session);
+        String token = params.get("token");
+        String userId = params.get("userId");
+
+        // 获取聊天会话相关参数(sessionId 用于历史记录查询和保存)
+        String chatSessionId = params.get("sessionId");
+        String templateId = params.get("templateId");
+
+        logger.info("语音流WebSocket连接建立 - SessionId: {}, UserId: {}, ChatSessionId: {}, TemplateId: {}",
+                sessionId, userId, chatSessionId, templateId);
+
+        // 保存会话和相关参数到session attributes
+        sessions.put(sessionId, session);
+        if (chatSessionId != null) {
+            session.getAttributes().put("chatSessionId", chatSessionId);
+        }
+        if (templateId != null) {
+            session.getAttributes().put("templateId", Integer.parseInt(templateId));
+        }
+
+        // 发送连接成功消息
+        sendTextMessage(session, createMessage("connected", "连接成功", Map.of("sessionId", sessionId)));
+
+        // 预热TTS连接
+        voiceStreamService.warmupTtsConnection(sessionId)
+                .thenAccept(v -> {
+                    logger.info("TTS连接预热成功 - SessionId: {}", sessionId);
+                })
+                .exceptionally(ex -> {
+                    logger.error("TTS连接预热失败 - SessionId: {}", sessionId, 
ex); + // 预热失败不影响主流程,降级为按需创建 + return null; + }); + } + + @Override + protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message) { + String sessionId = session.getId(); + byte[] audioData = message.getPayload().array(); + + logger.debug("收到音频数据 - SessionId: {}, Size: {} bytes", sessionId, audioData.length); + + // 从session attributes获取聊天会话相关参数 + String chatSessionId = (String) session.getAttributes().get("chatSessionId"); + Integer templateId = (Integer) session.getAttributes().get("templateId"); + + try { + // 判断是否使用带历史记录的处理方法 + if (chatSessionId != null) { + // 调用带历史记录的语音流服务 + voiceStreamService.processAudioStreamWithHistory( + sessionId, audioData, chatSessionId, templateId, + new VoiceStreamService.StreamCallback() { + @Override + public void onSttResult(String text) { + // STT识别结果 + sendTextMessage(session, createMessage("stt_result", text, null)); + } + + @Override + public void onLlmToken(String token) { + // LLM输出token(可选发送给前端显示) + sendTextMessage(session, createMessage("llm_token", token, null)); + } + + @Override + public void onSentenceComplete(String sentence) { + // 完整句子(发送给前端显示) + sendTextMessage(session, createMessage("sentence", sentence, null)); + } + + @Override + public void onAudioChunk(byte[] audioChunk) { + // TTS音频数据 + sendBinaryMessage(session, audioChunk); + } + + @Override + public void onComplete() { + // 所有处理完成 + sendTextMessage(session, createMessage("complete", "对话完成", null)); + } + + @Override + public void onError(String error) { + // 错误处理 + sendTextMessage(session, createMessage("error", error, null)); + } + }); + } else { + // 调用原有的不带历史记录的方法(向后兼容) + voiceStreamService.processAudioStream(sessionId, audioData, new VoiceStreamService.StreamCallback() { + @Override + public void onSttResult(String text) { + // STT识别结果 + sendTextMessage(session, createMessage("stt_result", text, null)); + } + + @Override + public void onLlmToken(String token) { + // LLM输出token(可选发送给前端显示) + sendTextMessage(session, 
createMessage("llm_token", token, null));
+                    }
+
+                    @Override
+                    public void onSentenceComplete(String sentence) {
+                        // 完整句子(发送给前端显示)
+                        sendTextMessage(session, createMessage("sentence", sentence, null));
+                    }
+
+                    @Override
+                    public void onAudioChunk(byte[] audioChunk) {
+                        // TTS音频数据
+                        sendBinaryMessage(session, audioChunk);
+                    }
+
+                    @Override
+                    public void onComplete() {
+                        // 所有处理完成
+                        sendTextMessage(session, createMessage("complete", "对话完成", null));
+                    }
+
+                    @Override
+                    public void onError(String error) {
+                        // 错误处理
+                        sendTextMessage(session, createMessage("error", error, null));
+                    }
+                });
+            }
+
+        } catch (Exception e) {
+            logger.error("处理音频流失败 - SessionId: {}", sessionId, e);
+            sendTextMessage(session, createMessage("error", "处理音频失败: " + e.getMessage(), null));
+        }
+    }
+
+    @Override
+    protected void handleTextMessage(WebSocketSession session, TextMessage message) {
+        String sessionId = session.getId();
+        String payload = message.getPayload();
+
+        logger.debug("收到文本消息 - SessionId: {}, Message: {}", sessionId, payload);
+
+        try {
+            @SuppressWarnings("unchecked")
+            Map<String, Object> msgMap = gson.fromJson(payload, Map.class);
+            String type = (String) msgMap.get("type");
+
+            if ("ping".equals(type)) {
+                // 心跳响应
+                sendTextMessage(session, createMessage("pong", "pong", null));
+            } else if ("cancel".equals(type)) {
+                // 取消当前对话(打断)
+                voiceStreamService.cancelStream(sessionId);
+                sendTextMessage(session, createMessage("cancelled", "已取消", null));
+            }
+
+        } catch (Exception e) {
+            logger.error("处理文本消息失败 - SessionId: {}", sessionId, e);
+        }
+    }
+
+    @Override
+    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
+        String sessionId = session.getId();
+        sessions.remove(sessionId);
+
+        // 取消该会话的所有处理
+        voiceStreamService.cancelStream(sessionId);
+
+        // 清理TTS连接
+        voiceStreamService.closeTtsConnection(sessionId);
+
+        logger.info("语音流WebSocket连接关闭 - SessionId: {}, Status: {}", sessionId, status);
+    }
+
+    @Override
+    public void handleTransportError(WebSocketSession 
session, Throwable exception) {
+        String sessionId = session.getId();
+
+        if (isClientCloseRequest(exception)) {
+            logger.info("WebSocket连接被客户端主动关闭 - SessionId: {}", sessionId);
+        } else {
+            logger.error("WebSocket传输错误 - SessionId: {}", sessionId, exception);
+        }
+
+        // 清理会话
+        sessions.remove(sessionId);
+        voiceStreamService.cancelStream(sessionId);
+    }
+
+    /**
+     * 判断异常是否由客户端主动关闭连接导致
+     */
+    private boolean isClientCloseRequest(Throwable exception) {
+        if (exception instanceof IOException) {
+            String message = exception.getMessage();
+            if (message != null) {
+                return message.contains("Connection reset by peer") ||
+                        message.contains("Broken pipe") ||
+                        message.contains("Connection closed") ||
+                        message.contains("远程主机强迫关闭了一个现有的连接");
+            }
+            return exception instanceof java.io.EOFException;
+        }
+        return false;
+    }
+
+    /**
+     * 从会话中获取参数(从query string或header)
+     */
+    private Map<String, String> getParamsFromSession(WebSocketSession session) {
+        Map<String, String> params = new HashMap<>();
+
+        // 从header获取
+        String token = session.getHandshakeHeaders().getFirst("Authorization");
+        if (token != null) {
+            params.put("token", token.replace("Bearer ", ""));
+        }
+
+        String userId = session.getHandshakeHeaders().getFirst("User-Id");
+        if (userId != null) {
+            params.put("userId", userId);
+        }
+
+        // 从URI query参数获取
+        URI uri = session.getUri();
+        if (uri != null && uri.getQuery() != null) {
+            String query = uri.getQuery();
+            for (String param : query.split("&")) {
+                String[] kv = param.split("=", 2);
+                if (kv.length == 2) {
+                    params.put(kv[0], kv[1]);
+                }
+            }
+        }
+
+        return params;
+    }
+
+    /**
+     * 创建JSON格式的消息
+     */
+    private String createMessage(String type, String message, Map<String, Object> data) {
+        Map<String, Object> msg = new HashMap<>();
+        msg.put("type", type);
+        msg.put("message", message);
+        msg.put("timestamp", System.currentTimeMillis());
+        if (data != null) {
+            msg.put("data", data);
+        }
+        return gson.toJson(msg);
+    }
+
+    /**
+     * 发送文本消息
+     */
+    private void sendTextMessage(WebSocketSession session, String message) {
+        try {
+            if 
(session.isOpen()) { + session.sendMessage(new TextMessage(message)); + } + } catch (Exception e) { + logger.error("发送文本消息失败 - SessionId: {}", session.getId(), e); + } + } + + /** + * 发送二进制消息 + */ + private void sendBinaryMessage(WebSocketSession session, byte[] data) { + try { + if (session.isOpen()) { + session.sendMessage(new BinaryMessage(data)); + } + } catch (Exception e) { + logger.error("发送二进制消息失败 - SessionId: {}", session.getId(), e); + } + } + + /** + * 获取指定会话 + */ + public WebSocketSession getSession(String sessionId) { + return sessions.get(sessionId); + } +} + diff --git a/src/main/java/com/xiaozhi/communication/server/websocket/WebSocketConfig.java b/src/main/java/com/xiaozhi/communication/server/websocket/WebSocketConfig.java index 9827ab2..0547998 100644 --- a/src/main/java/com/xiaozhi/communication/server/websocket/WebSocketConfig.java +++ b/src/main/java/com/xiaozhi/communication/server/websocket/WebSocketConfig.java @@ -21,18 +21,27 @@ public class WebSocketConfig implements WebSocketConfigurer { // 定义为public static以便其他类可以访问 public static final String WS_PATH = "/ws/xiaozhi/v1/"; + public static final String VOICE_STREAM_PATH = "/ws/voice-stream"; @Resource private WebSocketHandler webSocketHandler; + @Resource + private VoiceStreamHandler voiceStreamHandler; + @Resource private CmsUtils cmsUtils; @Override public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { + // 设备WebSocket端点 registry.addHandler(webSocketHandler, WS_PATH) .setAllowedOrigins("*"); + // 语音流式对话WebSocket端点 + registry.addHandler(voiceStreamHandler, VOICE_STREAM_PATH) + .setAllowedOrigins("*"); + logger.info("📡 WebSocket服务地址: {}", cmsUtils.getWebsocketAddress()); logger.info("=========================================================="); } diff --git a/src/main/java/com/xiaozhi/config/VoiceStreamConfig.java b/src/main/java/com/xiaozhi/config/VoiceStreamConfig.java new file mode 100644 index 0000000..c0effd7 --- /dev/null +++ 
b/src/main/java/com/xiaozhi/config/VoiceStreamConfig.java @@ -0,0 +1,89 @@ +package com.xiaozhi.config; + +import lombok.Data; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.context.annotation.Configuration; + +/** + * 语音流式对话配置类 + */ +@Configuration +@ConfigurationProperties(prefix = "xiaozhi.voice-stream") +@Data +public class VoiceStreamConfig { + + private GrokConfig grok = new GrokConfig(); + private MinimaxConfig minimax = new MinimaxConfig(); + + @Data + public static class GrokConfig { + /** + * Grok API Key + */ + private String apiKey; + + /** + * Grok API URL + */ + private String apiUrl; + + /** + * Grok 模型名称 + */ + private String model; + } + + @Data + public static class MinimaxConfig { + /** + * MiniMax API Key + */ + private String apiKey; + + /** + * MiniMax Group ID + */ + private String groupId; + + /** + * MiniMax WebSocket URL + */ + private String wsUrl; + + /** + * MiniMax TTS 模型名称 + */ + private String model; + + /** + * 音色ID + */ + private String voiceId; + + /** + * 语速 (0.5-2.0) + */ + private Double speed; + + /** + * 音量 (0.1-10.0) + */ + private Double vol; + + /** + * 音调 (-12到12) + */ + private Integer pitch; + + /** + * 音频采样率 + */ + private Integer audioSampleRate; + + /** + * 比特率 + */ + private Integer bitrate; + } +} + diff --git a/src/main/java/com/xiaozhi/controller/ChatController.java b/src/main/java/com/xiaozhi/controller/ChatController.java index 06a8110..d175e32 100644 --- a/src/main/java/com/xiaozhi/controller/ChatController.java +++ b/src/main/java/com/xiaozhi/controller/ChatController.java @@ -1,11 +1,13 @@ package com.xiaozhi.controller; +import com.xiaozhi.dialogue.stt.SttService; import com.xiaozhi.dto.ChatRequest; import com.xiaozhi.dto.TtsRequest; import com.xiaozhi.dto.VoiceChatRequest; import com.xiaozhi.dto.AudioUploadRequest; import com.xiaozhi.dto.ChatResponse; import com.xiaozhi.dto.ApiResponse; +import com.xiaozhi.entity.SysConfig; import 
com.xiaozhi.service.ChatSessionService;
 import com.xiaozhi.service.TtsIntegrationService;
 import com.xiaozhi.config.ChatConstants;
@@ -182,33 +184,6 @@ public class ChatController {
                             ChatConstants.ERROR_INTERNAL + ": " + e.getMessage()));
         }
     }
-
-    /**
-     * 语音对话接口
-     * @param request 语音对话请求
-     * @return 包含识别文本、回复文本和音频Base64的响应
-     */
-    @Operation(summary = "语音对话", description = "接收音频数据进行语音识别、对话生成和语音合成的完整流程。必填参数:audioData;可选参数:sessionId、useFunctionCall、modelId、templateId、sttConfigId、ttsConfigId")
-    @PostMapping("/voice-chat")
-    public ResponseEntity<ApiResponse<Map<String, Object>>> voiceChat(
-            @Parameter(description = "语音对话请求参数", required = true,
-                    content = @Content(examples = @ExampleObject(value = "{\"audioData\":\"UklGRiQAAABXQVZFZm10IBAAAAABAAEA...\",\"sessionId\":\"session123\",\"useFunctionCall\":false,\"modelId\":7,\"templateId\":1,\"sttConfigId\":9,\"ttsConfigId\":8}")))
-            @RequestBody VoiceChatRequest request) {
-        try {
-            logger.info("收到语音对话请求");
-
-            Map<String, Object> result = chatSessionService.voiceChat(request);
-
-            logger.info("语音对话响应成功");
-            return ResponseEntity.ok(ApiResponse.success(ChatConstants.VOICE_CHAT_SUCCESS_MESSAGE, result));
-        } catch (Exception e) {
-            logger.error("语音对话处理失败", e);
-            return ResponseEntity.internalServerError()
-                    .body(ApiResponse.error(ChatConstants.INTERNAL_ERROR_CODE,
-                            ChatConstants.ERROR_INTERNAL + ": " + e.getMessage()));
-        }
-    }
-
     /**
      * 语音对话接口(multipart/form-data 版本)
      * 为兼容客户端以表单方式上传音频数据,支持多个常见字段名。
@@ -273,6 +248,10 @@ public class ChatController {
         }
     }
 
+    @Resource
+    private com.xiaozhi.dialogue.stt.factory.SttServiceFactory sttServiceFactory;
+
+
     /**
      * 清除会话缓存
      * @param sessionId 会话ID
diff --git a/src/main/java/com/xiaozhi/dialogue/llm/GrokStreamService.java b/src/main/java/com/xiaozhi/dialogue/llm/GrokStreamService.java
new file mode 100644
index 0000000..9bde670
--- /dev/null
+++ b/src/main/java/com/xiaozhi/dialogue/llm/GrokStreamService.java
@@ -0,0 +1,336 @@
+package com.xiaozhi.dialogue.llm;
+
+import com.google.gson.Gson;
+import com.google.gson.JsonObject;
+import 
com.google.gson.JsonParser;
+import com.xiaozhi.config.VoiceStreamConfig;
+import jakarta.annotation.Resource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.stereotype.Service;
+import org.springframework.web.reactive.function.client.WebClient;
+import reactor.core.publisher.Flux;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Grok LLM 流式服务
+ *
+ * 支持通过 Server-Sent Events (SSE) 方式流式调用 Grok API
+ *
+ * API文档: https://docs.x.ai/api
+ *
+ * 响应格式:
+ * data: {"id":"xxx","object":"chat.completion.chunk","created":xxx,
+ *        "model":"grok-4-1-fast-non-reasoning",
+ *        "choices":[{"index":0,"delta":{"content":"token","role":"assistant"}}],
+ *        "system_fingerprint":"xxx"}
+ * data: [DONE]
+ */
+@Service
+public class GrokStreamService {
+    private static final Logger logger = LoggerFactory.getLogger(GrokStreamService.class);
+    private static final Gson gson = new Gson();
+
+    // SSE数据前缀
+    private static final String SSE_DATA_PREFIX = "data: ";
+    // SSE流结束标记
+    private static final String SSE_DONE_FLAG = "[DONE]";
+    // 默认温度参数
+    private static final double DEFAULT_TEMPERATURE = 0.7;
+    // 默认最大token数(如果经常被截断,可以增加这个值)
+    private static final int DEFAULT_MAX_TOKENS = 4000;
+
+    @Resource
+    private VoiceStreamConfig voiceStreamConfig;
+
+    private WebClient webClient;
+
+    /**
+     * 流式调用 Grok API
+     *
+     * @param userMessage  用户消息
+     * @param systemPrompt 系统提示词(可选)
+     * @param callback     接收token的回调
+     */
+    public void streamChat(String userMessage, String systemPrompt, TokenCallback callback) {
+        try {
+            // 参数验证
+            if (userMessage == null || userMessage.trim().isEmpty()) {
+                logger.warn("用户消息为空,跳过API调用");
+                callback.onError("用户消息不能为空");
+                return;
+            }
+
+            if (webClient == null) {
+                initWebClient();
+            }
+
+            // 构建消息列表
+            List<Map<String, String>> messages = new ArrayList<>();
+            if (systemPrompt != null && !systemPrompt.isEmpty()) {
+                messages.add(Map.of("role", "system", "content", systemPrompt));
+            }
+
messages.add(Map.of("role", "user", "content", userMessage));
+
+            // 构建请求体
+            Map<String, Object> requestBody = new HashMap<>();
+            requestBody.put("model", voiceStreamConfig.getGrok().getModel());
+            requestBody.put("messages", messages);
+            requestBody.put("stream", true);
+            requestBody.put("temperature", DEFAULT_TEMPERATURE);
+            requestBody.put("max_tokens", DEFAULT_MAX_TOKENS); // 限制最大token数
+
+            String requestJson = gson.toJson(requestBody);
+            logger.info("开始调用Grok API - Model: {}, UserMessage: {}",
+                    voiceStreamConfig.getGrok().getModel(),
+                    userMessage.length() > 50 ? userMessage.substring(0, 50) + "..." : userMessage);
+            logger.debug("请求体: {}", requestJson);
+
+            // 发起流式请求
+            webClient.post()
+                    .uri(voiceStreamConfig.getGrok().getApiUrl())
+                    .header("Authorization", "Bearer " + voiceStreamConfig.getGrok().getApiKey())
+                    .header("Content-Type", "application/json")
+                    .bodyValue(requestJson)
+                    .retrieve()
+                    .bodyToFlux(String.class)
+                    .flatMap(this::parseSSE)
+                    .doOnNext(token -> {
+                        if (token != null && !token.isEmpty()) {
+                            callback.onToken(token);
+                        }
+                    })
+                    .doOnComplete(() -> {
+                        logger.info("Grok API流式调用完成");
+                        callback.onComplete();
+                    })
+                    .doOnError(error -> {
+                        String errorMsg = "LLM调用失败: " + error.getMessage();
+                        logger.error(errorMsg, error);
+                        callback.onError(errorMsg);
+                    })
+                    .subscribe();
+
+        } catch (Exception e) {
+            String errorMsg = "调用Grok API时发生异常: " + e.getMessage();
+            logger.error(errorMsg, e);
+            callback.onError(errorMsg);
+        }
+    }
+
+    /**
+     * 流式调用 Grok API(带历史记录)
+     *
+     * @param messages 完整的消息列表(包含系统提示、历史对话和当前用户消息)
+     * @param callback 接收token的回调
+     */
+    public void streamChatWithHistory(List<Map<String, String>> messages, TokenCallback callback) {
+        try {
+            // 参数验证
+            if (messages == null || messages.isEmpty()) {
+                logger.warn("消息列表为空,跳过API调用");
+                callback.onError("消息列表不能为空");
+                return;
+            }
+
+            if (webClient == null) {
+                initWebClient();
+            }
+
+            // 构建请求体
+            Map<String, Object> requestBody = new HashMap<>();
+            requestBody.put("model", voiceStreamConfig.getGrok().getModel());
+
requestBody.put("messages", messages);
+            requestBody.put("stream", true);
+            requestBody.put("temperature", DEFAULT_TEMPERATURE);
+            requestBody.put("max_tokens", DEFAULT_MAX_TOKENS); // 限制最大token数
+
+            String requestJson = gson.toJson(requestBody);
+            logger.info("开始调用Grok API(带历史记录) - Model: {}, 消息数量: {}",
+                    voiceStreamConfig.getGrok().getModel(), messages.size());
+            logger.debug("请求体: {}", requestJson);
+
+            // 发起流式请求
+            webClient.post()
+                    .uri(voiceStreamConfig.getGrok().getApiUrl())
+                    .header("Authorization", "Bearer " + voiceStreamConfig.getGrok().getApiKey())
+                    .header("Content-Type", "application/json")
+                    .bodyValue(requestJson)
+                    .retrieve()
+                    .bodyToFlux(String.class)
+                    .flatMap(this::parseSSE)
+                    .doOnNext(token -> {
+                        if (token != null && !token.isEmpty()) {
+                            callback.onToken(token);
+                        }
+                    })
+                    .doOnComplete(() -> {
+                        logger.info("Grok API流式调用完成(带历史记录)");
+                        callback.onComplete();
+                    })
+                    .doOnError(error -> {
+                        String errorMsg = "LLM调用失败: " + error.getMessage();
+                        logger.error(errorMsg, error);
+                        callback.onError(errorMsg);
+                    })
+                    .subscribe();
+
+        } catch (Exception e) {
+            String errorMsg = "调用Grok API时发生异常: " + e.getMessage();
+            logger.error(errorMsg, e);
+            callback.onError(errorMsg);
+        }
+    }
+
+    /**
+     * 解析SSE格式的响应
+     *
+     * 数据格式示例(WebClient已自动去除 "data: " 前缀):
+     * {"id":"xxx","object":"chat.completion.chunk","created":xxx,"model":"grok-4-1-fast-non-reasoning",
+     *  "choices":[{"index":0,"delta":{"content":"你","role":"assistant"}}],"system_fingerprint":"xxx"}
+     */
+    private Flux<String> parseSSE(String line) {
+        return Flux.create(sink -> {
+            try {
+                // 跳过空行
+                if (line == null || line.trim().isEmpty()) {
+                    sink.complete();
+                    return;
+                }
+
+                String data = line.trim();
+
+                // 记录原始数据(仅在trace级别)
+                logger.trace("收到原始SSE数据: {}", data.length() > 200 ? data.substring(0, 200) + "..."
: data); + + // 跳过 [DONE] 标记(流结束标志) + // 可能是 "data: [DONE]" 或直接 "[DONE]" + if (SSE_DONE_FLAG.equals(data) || data.equals(SSE_DATA_PREFIX + SSE_DONE_FLAG)) { + logger.debug("收到流结束标记"); + sink.complete(); + return; + } + + // 如果有 "data: " 前缀,去掉它 + if (data.startsWith(SSE_DATA_PREFIX)) { + data = data.substring(SSE_DATA_PREFIX.length()).trim(); + } + + // 再次检查是否是 [DONE] + if (SSE_DONE_FLAG.equals(data)) { + logger.debug("收到流结束标记"); + sink.complete(); + return; + } + + // 解析JSON(现在data直接就是JSON字符串) + JsonObject json = JsonParser.parseString(data).getAsJsonObject(); + + // 验证基本字段 + if (!json.has("choices")) { + logger.debug("JSON中缺少choices字段"); + sink.complete(); + return; + } + + // 获取choices数组 + var choices = json.getAsJsonArray("choices"); + if (choices.isEmpty()) { + logger.debug("choices数组为空"); + sink.complete(); + return; + } + + // 获取第一个choice + var choice = choices.get(0).getAsJsonObject(); + + // 检查是否有delta字段 + if (!choice.has("delta")) { + logger.debug("choice中缺少delta字段"); + sink.complete(); + return; + } + + var delta = choice.getAsJsonObject("delta"); + + // 提取content内容 + if (delta.has("content")) { + String content = delta.get("content").getAsString(); + + // 只有非空内容才发送 + if (content != null && !content.isEmpty()) { + logger.trace("提取到token: {}", content); + sink.next(content); + } + } + + // 如果有role字段,可以记录日志(首次消息会包含role) + if (delta.has("role")) { + String role = delta.get("role").getAsString(); + logger.debug("收到角色信息: {}", role); + } + + // 检查finish_reason(流结束原因) + if (choice.has("finish_reason") && !choice.get("finish_reason").isJsonNull()) { + String finishReason = choice.get("finish_reason").getAsString(); + logger.info("流结束原因: {}", finishReason); + // finish_reason 可能的值: "stop"(正常结束), "length"(达到最大长度), "content_filter"(内容过滤) + } + + sink.complete(); + + } catch (com.google.gson.JsonSyntaxException e) { + logger.error("JSON解析失败 - 原始数据: {}", line, e); + sink.complete(); + } catch (Exception e) { + logger.error("解析SSE数据时发生未知错误 - 原始数据: {}", line, e); + 
sink.complete(); + } + }); + } + + /** + * 初始化WebClient + */ + private void initWebClient() { + String apiUrl = voiceStreamConfig.getGrok().getApiUrl(); + if (apiUrl == null || apiUrl.trim().isEmpty()) { + throw new IllegalStateException("Grok API URL未配置"); + } + + logger.info("初始化Grok WebClient - URL: {}", apiUrl); + + this.webClient = WebClient.builder() + .baseUrl(apiUrl) + // 增加缓冲区大小,支持更大的响应 + .codecs(configurer -> configurer.defaultCodecs().maxInMemorySize(10 * 1024 * 1024)) // 10MB + .build(); + + logger.info("Grok WebClient初始化完成"); + } + + /** + * Token回调接口 + */ + public interface TokenCallback { + /** + * 接收到新的token + */ + void onToken(String token); + + /** + * 流式输出完成 + */ + void onComplete(); + + /** + * 发生错误 + */ + void onError(String error); + } +} + diff --git a/src/main/java/com/xiaozhi/dialogue/llm/SentenceBufferService.java b/src/main/java/com/xiaozhi/dialogue/llm/SentenceBufferService.java new file mode 100644 index 0000000..401275d --- /dev/null +++ b/src/main/java/com/xiaozhi/dialogue/llm/SentenceBufferService.java @@ -0,0 +1,137 @@ +package com.xiaozhi.dialogue.llm; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Service; + +import java.util.regex.Pattern; + +/** + * 分句缓冲服务 + * 将LLM流式输出的token缓冲并按句子分割 + */ +@Service +public class SentenceBufferService { + private static final Logger logger = LoggerFactory.getLogger(SentenceBufferService.class); + + // 中英文句子结束标点符号 + private static final Pattern SENTENCE_END_PATTERN = Pattern.compile("[。!?!?;;]"); + + // 强制分句的最大长度 + private static final int MAX_SENTENCE_LENGTH = 50; + + /** + * 创建一个新的句子缓冲器 + */ + public SentenceBuffer createBuffer(SentenceCallback callback) { + return new SentenceBuffer(callback); + } + + /** + * 句子缓冲器 + */ + public class SentenceBuffer { + private final StringBuilder buffer = new StringBuilder(); + private final SentenceCallback callback; + + public SentenceBuffer(SentenceCallback callback) { + this.callback = callback; + } + + 
/** + * 添加token到缓冲区 + */ + public synchronized void addToken(String token) { + buffer.append(token); + + // 检查是否有完整的句子 + String currentBuffer = buffer.toString(); + int lastSentenceEnd = findLastSentenceEnd(currentBuffer); + + if (lastSentenceEnd > 0) { + // 找到句子结束符,提取完整句子 + String sentence = currentBuffer.substring(0, lastSentenceEnd + 1).trim(); + if (!sentence.isEmpty()) { + logger.debug("检测到完整句子: {}", sentence); + callback.onSentence(sentence); + } + + // 保留剩余部分 + buffer.setLength(0); + if (lastSentenceEnd + 1 < currentBuffer.length()) { + buffer.append(currentBuffer.substring(lastSentenceEnd + 1)); + } + } else if (buffer.length() > MAX_SENTENCE_LENGTH) { + // 超过最大长度,强制分句 + String sentence = buffer.toString().trim(); + if (!sentence.isEmpty()) { + logger.debug("强制分句(长度超限): {}", sentence); + callback.onSentence(sentence); + } + buffer.setLength(0); + } + } + + /** + * 完成输入,输出剩余内容 + */ + public synchronized void finish() { + if (buffer.length() > 0) { + String sentence = buffer.toString().trim(); + if (!sentence.isEmpty()) { + logger.debug("输出最后剩余内容: {}", sentence); + callback.onSentence(sentence); + } + buffer.setLength(0); + } + callback.onComplete(); + } + + /** + * 清空缓冲区 + */ + public synchronized void clear() { + buffer.setLength(0); + } + + /** + * 查找最后一个句子结束符的位置 + */ + private int findLastSentenceEnd(String text) { + int lastPos = -1; + for (int i = text.length() - 1; i >= 0; i--) { + char ch = text.charAt(i); + if (isSentenceEndChar(ch)) { + lastPos = i; + break; + } + } + return lastPos; + } + + /** + * 判断是否为句子结束符 + */ + private boolean isSentenceEndChar(char ch) { + return ch == '。' || ch == '!' || ch == '?' || + ch == '.' || ch == '!' || ch == '?' 
||
+                   ch == ';' || ch == ';';
+        }
+    }
+
+    /**
+     * 句子回调接口
+     */
+    public interface SentenceCallback {
+        /**
+         * 检测到完整句子
+         */
+        void onSentence(String sentence);
+
+        /**
+         * 所有句子处理完成
+         */
+        void onComplete();
+    }
+}
+
diff --git a/src/main/java/com/xiaozhi/dialogue/stt/factory/SttServiceFactory.java b/src/main/java/com/xiaozhi/dialogue/stt/factory/SttServiceFactory.java
index f994ca3..4edc476 100644
--- a/src/main/java/com/xiaozhi/dialogue/stt/factory/SttServiceFactory.java
+++ b/src/main/java/com/xiaozhi/dialogue/stt/factory/SttServiceFactory.java
@@ -30,7 +30,7 @@ public class SttServiceFactory {
     private final Map<String, SttService> serviceCache = new ConcurrentHashMap<>();
 
     // 默认服务提供商名称
-    private static final String DEFAULT_PROVIDER = "vosk";
+    private static final String DEFAULT_PROVIDER = "aliyun";
 
     // 标记Vosk是否初始化成功
     private boolean voskInitialized = false;
diff --git a/src/main/java/com/xiaozhi/dialogue/stt/providers/AliyunSttService.java b/src/main/java/com/xiaozhi/dialogue/stt/providers/AliyunSttService.java
index 9544c5f..21b537c 100644
--- a/src/main/java/com/xiaozhi/dialogue/stt/providers/AliyunSttService.java
+++ b/src/main/java/com/xiaozhi/dialogue/stt/providers/AliyunSttService.java
@@ -40,7 +40,7 @@ public class AliyunSttService implements SttService {
 //    private final String modelName;
 
     public AliyunSttService(SysConfig config) {
-        this.apiKey = config.getApiKey();
+        this.apiKey = config.getApiKey(); // 注意:API Key 必须从配置读取,严禁硬编码在代码中
 
 //        // 从配置中获取模型名称,如果没有配置则使用默认值
 //        this.modelName = (config.getConfigName() != null && !config.getConfigName().trim().isEmpty())
 //                ? 
config.getConfigName().trim()
diff --git a/src/main/java/com/xiaozhi/dialogue/tts/MinimaxTtsStreamService.java b/src/main/java/com/xiaozhi/dialogue/tts/MinimaxTtsStreamService.java
new file mode 100644
index 0000000..7f8e41b
--- /dev/null
+++ b/src/main/java/com/xiaozhi/dialogue/tts/MinimaxTtsStreamService.java
@@ -0,0 +1,548 @@
+package com.xiaozhi.dialogue.tts;
+
+import com.google.gson.Gson;
+import com.xiaozhi.config.VoiceStreamConfig;
+import com.xiaozhi.dialogue.tts.TtsConnectionState.ConnectionStatus;
+import jakarta.annotation.Resource;
+import org.java_websocket.client.WebSocketClient;
+import org.java_websocket.handshake.ServerHandshake;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.stereotype.Service;
+
+import java.net.URI;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.*;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * MiniMax TTS 流式服务
+ * 基于WebSocket实现流式文本转语音,支持连接预热和复用
+ */
+@Service
+public class MinimaxTtsStreamService {
+    private static final Logger logger = LoggerFactory.getLogger(MinimaxTtsStreamService.class);
+    private static final Gson gson = new Gson();
+    private static final int CONNECTION_TIMEOUT_SECONDS = 5;
+
+    @Resource
+    private VoiceStreamConfig voiceStreamConfig;
+
+    // 管理活跃的WebSocket连接和状态
+    private final Map<String, SessionTtsConnection> connections = new ConcurrentHashMap<>();
+
+    /**
+     * 会话TTS连接封装
+     */
+    private static class SessionTtsConnection {
+        final MinimaxTtsClient client;
+        final TtsConnectionState state;
+        final Semaphore processingLock = new Semaphore(1); // 确保串行处理
+
+        SessionTtsConnection(MinimaxTtsClient client, TtsConnectionState state) {
+            this.client = client;
+            this.state = state;
+        }
+    }
+
+    /**
+     * 预热TTS连接
+     * 建立WebSocket连接并发送task_start,等待连接就绪
+     *
+     * @param sessionId 会话ID
+     * @return CompletableFuture<Void> 连接就绪后完成
+     */
+    public CompletableFuture<Void> warmupConnection(String sessionId) {
+        long startTime = System.currentTimeMillis();
+        logger.info("开始预热TTS连接 - 
SessionId: {}", sessionId); + + CompletableFuture warmupFuture = new CompletableFuture<>(); + + try { + // 检查是否已有连接 + SessionTtsConnection existing = connections.get(sessionId); + if (existing != null && existing.state.isConnected()) { + logger.info("TTS连接已存在,跳过预热 - SessionId: {}, 状态: {}", + sessionId, existing.state.getStatus()); + warmupFuture.complete(null); + return warmupFuture; + } + + // 创建新连接 + String wsUrl = voiceStreamConfig.getMinimax().getWsUrl(); + TtsConnectionState state = new TtsConnectionState(); + state.setStatus(ConnectionStatus.CONNECTING); + + MinimaxTtsClient client = new MinimaxTtsClient( + new URI(wsUrl), + sessionId, + state, + warmupFuture + ); + + SessionTtsConnection connection = new SessionTtsConnection(client, state); + connections.put(sessionId, connection); + + // 异步连接 + CompletableFuture.runAsync(() -> { + try { + boolean connected = client.connectBlocking(CONNECTION_TIMEOUT_SECONDS, TimeUnit.SECONDS); + if (!connected) { + throw new TimeoutException("连接超时"); + } + long duration = System.currentTimeMillis() - startTime; + logger.info("TTS连接预热成功 - SessionId: {}, 耗时: {}ms", sessionId, duration); + } catch (Exception e) { + logger.error("TTS连接预热失败 - SessionId: {}", sessionId, e); + state.setStatus(ConnectionStatus.DISCONNECTED); + connections.remove(sessionId); + warmupFuture.completeExceptionally(e); + } + }); + + } catch (Exception e) { + logger.error("创建TTS连接失败 - SessionId: {}", sessionId, e); + warmupFuture.completeExceptionally(e); + } + + return warmupFuture; + } + + /** + * 流式合成语音(复用现有连接) + * + * @param sessionId 会话ID + * @param text 要合成的文本 + * @param callback 音频数据回调 + */ + public CompletableFuture streamTts(String sessionId, String text, AudioCallback callback) { + long startTime = System.currentTimeMillis(); + CompletableFuture future = new CompletableFuture<>(); + + try { + SessionTtsConnection connection = connections.get(sessionId); + + // 检查连接是否存在且就绪 + if (connection == null || !connection.state.isReady()) { + 
logger.warn("TTS连接不存在或未就绪 - SessionId: {}, 状态: {}", + sessionId, connection != null ? connection.state.getStatus() : "NULL"); + + // 尝试建立新连接 + return warmupConnection(sessionId) + .thenCompose(v -> streamTts(sessionId, text, callback)); + } + + // 获取处理锁,确保串行处理 + connection.processingLock.acquire(); + + try { + logger.info("使用已有TTS连接 - SessionId: {}, Text: {}", sessionId, text); + connection.state.setStatus(ConnectionStatus.PROCESSING); + + // 发送文本进行TTS + connection.client.sendText(text, callback, future); + + // 等待TTS完成后释放锁 + future.whenComplete((result, error) -> { + connection.processingLock.release(); + if (error == null) { + long duration = System.currentTimeMillis() - startTime; + logger.info("TTS完成 - SessionId: {}, 耗时: {}ms(复用连接节省约1秒)", + sessionId, duration); + connection.state.setStatus(ConnectionStatus.IDLE); + } else { + logger.error("TTS失败 - SessionId: {}", sessionId, error); + connection.state.setStatus(ConnectionStatus.IDLE); + } + }); + + } catch (Exception e) { + connection.processingLock.release(); + throw e; + } + + } catch (Exception e) { + logger.error("TTS处理失败 - SessionId: {}", sessionId, e); + callback.onError("TTS处理失败: " + e.getMessage()); + future.completeExceptionally(e); + } + + return future; + } + + /** + * 关闭指定会话的TTS连接 + */ + public void closeConnection(String sessionId) { + SessionTtsConnection connection = connections.remove(sessionId); + if (connection != null) { + try { + connection.client.closeConnection(); + logger.info("关闭TTS连接 - SessionId: {}", sessionId); + } catch (Exception e) { + logger.error("关闭TTS连接失败 - SessionId: {}", sessionId, e); + } + } + } + + /** + * 取消指定会话的TTS(兼容旧接口) + */ + public void cancelTts(String sessionId) { + closeConnection(sessionId); + } + + /** + * MiniMax TTS WebSocket 客户端(支持连接复用) + */ + private class MinimaxTtsClient extends WebSocketClient { + private final String sessionId; + private final TtsConnectionState state; + private final CompletableFuture warmupFuture; + + // 当前任务相关 + private volatile 
String currentText; + private volatile AudioCallback currentCallback; + private volatile CompletableFuture currentTaskFuture; + private final AtomicBoolean taskProcessing = new AtomicBoolean(false); + + // 音频缓冲器 - 缓冲整句音频 + private java.io.ByteArrayOutputStream audioBuffer; + + public MinimaxTtsClient(URI serverUri, String sessionId, TtsConnectionState state, + CompletableFuture warmupFuture) { + super(serverUri); + this.sessionId = sessionId; + this.state = state; + this.warmupFuture = warmupFuture; + + // 添加Authorization header + this.addHeader("Authorization", "Bearer " + voiceStreamConfig.getMinimax().getApiKey()); + } + + /** + * 发送文本进行TTS合成 + * 注意:调用此方法前,外层应该已经检查过状态并设置为PROCESSING + */ + public void sendText(String text, AudioCallback callback, CompletableFuture taskFuture) { + this.currentText = text; + this.currentCallback = callback; + this.currentTaskFuture = taskFuture; + this.audioBuffer = new java.io.ByteArrayOutputStream(); + this.taskProcessing.set(true); + + // 直接发送task_continue + sendTaskContinue(text); + } + + /** + * 关闭连接 + */ + public void closeConnection() { + try { + if (isOpen()) { + sendTaskFinish(); + } + close(); + state.setStatus(ConnectionStatus.DISCONNECTED); + } catch (Exception e) { + logger.error("关闭连接失败", e); + } + } + + @Override + public void onOpen(ServerHandshake handshake) { + logger.debug("MiniMax TTS连接已建立 - SessionId: {}", sessionId); + state.setStatus(ConnectionStatus.CONNECTED); + // 连接建立后,等待收到 connected_success 事件再发送 task_start + } + + @Override + public void onMessage(String message) { + try { + logger.debug("收到消息 - SessionId: {}: {}", sessionId, message); + + @SuppressWarnings("unchecked") + Map response = gson.fromJson(message, Map.class); + String event = (String) response.get("event"); + + if ("connected_success".equals(event)) { + // 连接成功,发送 task_start(不包含text) + logger.info("收到connected_success,发送task_start - SessionId: {}", sessionId); + sendTaskStart(); + + } else if ("task_started".equals(event)) { + // 任务启动成功,连接就绪 
+ logger.info("收到task_started,连接就绪 - SessionId: {}", sessionId); + state.setStatus(ConnectionStatus.TASK_STARTED); + state.resetReconnectAttempts(); // 连接成功,重置重连计数 + + // 完成预热future + if (warmupFuture != null && !warmupFuture.isDone()) { + warmupFuture.complete(null); + } + + // 如果是复用连接后立即发送的task_continue,状态会自动设为IDLE + if (!taskProcessing.get()) { + state.setStatus(ConnectionStatus.IDLE); + } + + } else if ("task_failed".equals(event)) { + // 任务失败 + @SuppressWarnings("unchecked") + Map baseResp = (Map) response.get("base_resp"); + String errorMsg = baseResp != null ? (String) baseResp.get("status_msg") : "unknown error"; + logger.error("TTS任务失败 - SessionId: {}: {}", sessionId, errorMsg); + + if (currentCallback != null) { + currentCallback.onError("TTS任务失败: " + errorMsg); + } + if (currentTaskFuture != null && !currentTaskFuture.isDone()) { + currentTaskFuture.completeExceptionally(new RuntimeException(errorMsg)); + } + + taskProcessing.set(false); + state.setStatus(ConnectionStatus.IDLE); + + } else if (response.containsKey("data")) { + // 音频数据 + @SuppressWarnings("unchecked") + Map data = (Map) response.get("data"); + if (data != null && data.containsKey("audio")) { + String hexAudio = (String) data.get("audio"); + if (hexAudio != null && !hexAudio.isEmpty() && audioBuffer != null) { + // 将hex字符串转为bytes并写入缓冲区 + byte[] audioBytes = hexStringToByteArray(hexAudio); + try { + audioBuffer.write(audioBytes); + logger.debug("缓冲音频块 - SessionId: {}: {} bytes, 总计: {} bytes", + sessionId, audioBytes.length, audioBuffer.size()); + } catch (java.io.IOException e) { + logger.error("写入音频缓冲区失败 - SessionId: {}", sessionId, e); + } + } + } + + // 检查是否完成 + Boolean isFinal = (Boolean) response.get("is_final"); + if (Boolean.TRUE.equals(isFinal) && taskProcessing.compareAndSet(true, false)) { + logger.debug("TTS任务完成(is_final=true)- SessionId: {}", sessionId); + + // 一次性发送完整的音频数据 + if (audioBuffer != null && currentCallback != null) { + byte[] completeAudio = 
audioBuffer.toByteArray(); + logger.info("TTS完成 - SessionId: {}, Text: {}, 音频总大小: {} bytes", + sessionId, currentText, completeAudio.length); + + if (completeAudio.length > 0) { + currentCallback.onAudioChunk(completeAudio); + } + currentCallback.onComplete(); + } + + // 完成当前任务future + if (currentTaskFuture != null && !currentTaskFuture.isDone()) { + currentTaskFuture.complete(null); + } + + // 不关闭连接,设置为IDLE状态以便复用 + state.setStatus(ConnectionStatus.IDLE); + + // 清理当前任务 + currentText = null; + currentCallback = null; + currentTaskFuture = null; + audioBuffer = null; + } + } + + } catch (Exception e) { + logger.error("处理TTS消息失败 - SessionId: {}", sessionId, e); + if (currentCallback != null) { + currentCallback.onError("处理TTS响应失败: " + e.getMessage()); + } + } + } + + /** + * 发送 task_start 消息(不包含text) + */ + private void sendTaskStart() { + try { + Map message = new HashMap<>(); + message.put("event", "task_start"); + message.put("model", voiceStreamConfig.getMinimax().getModel()); + + // voice_setting + Map voiceSetting = new HashMap<>(); + voiceSetting.put("voice_id", voiceStreamConfig.getMinimax().getVoiceId()); + voiceSetting.put("speed", voiceStreamConfig.getMinimax().getSpeed()); + voiceSetting.put("vol", voiceStreamConfig.getMinimax().getVol()); + voiceSetting.put("pitch", voiceStreamConfig.getMinimax().getPitch()); + message.put("voice_setting", voiceSetting); + + // audio_setting + Map audioSetting = new HashMap<>(); + audioSetting.put("sample_rate", voiceStreamConfig.getMinimax().getAudioSampleRate()); + audioSetting.put("bitrate", voiceStreamConfig.getMinimax().getBitrate()); + audioSetting.put("format", "mp3"); // 使用mp3格式,兼容性更好 + audioSetting.put("channel", 1); + message.put("audio_setting", audioSetting); + + String jsonMessage = gson.toJson(message); + logger.debug("发送task_start - SessionId: {}: {}", sessionId, jsonMessage); + send(jsonMessage); + + } catch (Exception e) { + logger.error("发送task_start失败 - SessionId: {}", sessionId, e); + if (warmupFuture 
!= null && !warmupFuture.isDone()) { + warmupFuture.completeExceptionally(e); + } + } + } + + /** + * 发送 task_continue 消息(包含text) + */ + private void sendTaskContinue(String text) { + try { + Map message = new HashMap<>(); + message.put("event", "task_continue"); + message.put("text", text); + + String jsonMessage = gson.toJson(message); + logger.debug("发送task_continue - SessionId: {}: {}", sessionId, jsonMessage); + send(jsonMessage); + + } catch (Exception e) { + logger.error("发送task_continue失败 - SessionId: {}", sessionId, e); + if (currentCallback != null) { + currentCallback.onError("发送文本失败: " + e.getMessage()); + } + if (currentTaskFuture != null && !currentTaskFuture.isDone()) { + currentTaskFuture.completeExceptionally(e); + } + } + } + + /** + * 发送 task_finish 消息 + */ + private void sendTaskFinish() { + try { + Map message = new HashMap<>(); + message.put("event", "task_finish"); + + String jsonMessage = gson.toJson(message); + logger.debug("发送task_finish - SessionId: {}: {}", sessionId, jsonMessage); + send(jsonMessage); + + } catch (Exception e) { + logger.error("发送task_finish失败 - SessionId: {}", sessionId, e); + } + } + + @Override + public void onClose(int code, String reason, boolean remote) { + logger.warn("MiniMax TTS连接关闭 - SessionId: {}, Code: {}, Reason: {}, Remote: {}", + sessionId, code, reason, remote); + + ConnectionStatus prevStatus = state.getStatus(); + state.setStatus(ConnectionStatus.DISCONNECTED); + + // 判断是否需要重连 + boolean needReconnect = remote && state.canReconnect() && + (prevStatus == ConnectionStatus.CONNECTED || + prevStatus == ConnectionStatus.TASK_STARTED || + prevStatus == ConnectionStatus.IDLE); + + if (needReconnect) { + // 自动重连 + int attempts = state.incrementReconnectAttempts(); + logger.info("尝试自动重连TTS - SessionId: {}, 第{}次重连", sessionId, attempts); + + CompletableFuture.runAsync(() -> { + try { + Thread.sleep(1000 * attempts); // 延迟递增 + warmupConnection(sessionId) + .exceptionally(ex -> { + logger.error("自动重连失败 - SessionId: 
{}", sessionId, ex); + return null; + }); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + } else { + // 不重连,通知错误 + if (taskProcessing.get()) { + if (currentCallback != null) { + currentCallback.onComplete(); // 标记完成,避免阻塞 + } + if (currentTaskFuture != null && !currentTaskFuture.isDone()) { + currentTaskFuture.complete(null); + } + taskProcessing.set(false); + } + + // 如果是预热阶段失败 + if (warmupFuture != null && !warmupFuture.isDone()) { + warmupFuture.completeExceptionally(new RuntimeException("连接关闭: " + reason)); + } + } + } + + @Override + public void onError(Exception ex) { + logger.error("MiniMax TTS连接错误 - SessionId: {}", sessionId, ex); + + state.setStatus(ConnectionStatus.DISCONNECTED); + + if (currentCallback != null) { + currentCallback.onError("TTS连接错误: " + ex.getMessage()); + } + if (currentTaskFuture != null && !currentTaskFuture.isDone()) { + currentTaskFuture.completeExceptionally(ex); + } + if (warmupFuture != null && !warmupFuture.isDone()) { + warmupFuture.completeExceptionally(ex); + } + } + + /** + * 将hex字符串转为byte数组 + */ + private byte[] hexStringToByteArray(String hex) { + int len = hex.length(); + byte[] data = new byte[len / 2]; + for (int i = 0; i < len; i += 2) { + data[i / 2] = (byte) ((Character.digit(hex.charAt(i), 16) << 4) + + Character.digit(hex.charAt(i + 1), 16)); + } + return data; + } + } + + /** + * 音频数据回调接口 + */ + public interface AudioCallback { + /** + * 接收到音频数据块 + */ + void onAudioChunk(byte[] audioData); + + /** + * TTS完成 + */ + void onComplete(); + + /** + * 发生错误 + */ + void onError(String error); + } +} + + diff --git a/src/main/java/com/xiaozhi/dialogue/tts/TtsConnectionState.java b/src/main/java/com/xiaozhi/dialogue/tts/TtsConnectionState.java new file mode 100644 index 0000000..96dcd2b --- /dev/null +++ b/src/main/java/com/xiaozhi/dialogue/tts/TtsConnectionState.java @@ -0,0 +1,136 @@ +package com.xiaozhi.dialogue.tts; + +import java.util.concurrent.atomic.AtomicInteger; +import 
java.util.concurrent.atomic.AtomicReference;
+
+/**
+ * TTS连接状态管理
+ * 管理连接状态、重连计数等
+ */
+public class TtsConnectionState {
+    private static final int MAX_RECONNECT_ATTEMPTS = 3;
+
+    private final AtomicReference<ConnectionStatus> status;
+    private final AtomicInteger reconnectAttempts;
+    private volatile long lastActiveTime;
+
+    public TtsConnectionState() {
+        this.status = new AtomicReference<>(ConnectionStatus.DISCONNECTED);
+        this.reconnectAttempts = new AtomicInteger(0);
+        this.lastActiveTime = System.currentTimeMillis();
+    }
+
+    /**
+     * 连接状态枚举
+     */
+    public enum ConnectionStatus {
+        DISCONNECTED,   // 未连接
+        CONNECTING,     // 连接中
+        CONNECTED,      // 已连接(未完成task_start)
+        TASK_STARTED,   // task_start已完成,连接就绪
+        PROCESSING,     // 正在处理TTS任务
+        IDLE            // 空闲,可接受新任务
+    }
+
+    /**
+     * 获取当前状态
+     */
+    public ConnectionStatus getStatus() {
+        return status.get();
+    }
+
+    /**
+     * 设置状态
+     */
+    public void setStatus(ConnectionStatus newStatus) {
+        ConnectionStatus oldStatus = status.getAndSet(newStatus);
+        if (oldStatus != newStatus) {
+            updateLastActiveTime();
+        }
+    }
+
+    /**
+     * CAS方式更新状态
+     */
+    public boolean compareAndSetStatus(ConnectionStatus expect, ConnectionStatus update) {
+        boolean updated = status.compareAndSet(expect, update);
+        if (updated) {
+            updateLastActiveTime();
+        }
+        return updated;
+    }
+
+    /**
+     * 检查是否可以处理新任务
+     */
+    public boolean isReady() {
+        ConnectionStatus current = status.get();
+        return current == ConnectionStatus.IDLE || current == ConnectionStatus.TASK_STARTED;
+    }
+
+    /**
+     * 检查是否已连接
+     */
+    public boolean isConnected() {
+        ConnectionStatus current = status.get();
+        return current != ConnectionStatus.DISCONNECTED && current != ConnectionStatus.CONNECTING;
+    }
+
+    /**
+     * 获取重连次数
+     */
+    public int getReconnectAttempts() {
+        return reconnectAttempts.get();
+    }
+
+    /**
+     * 增加重连次数
+     */
+    public int incrementReconnectAttempts() {
+        return reconnectAttempts.incrementAndGet();
+    }
+
+    /**
+     * 重置重连次数
+     */
+    public void resetReconnectAttempts() {
+        reconnectAttempts.set(0);
+    }
+
+    /**
+     * 
检查是否还可以重连 + */ + public boolean canReconnect() { + return reconnectAttempts.get() < MAX_RECONNECT_ATTEMPTS; + } + + /** + * 更新最后活跃时间 + */ + public void updateLastActiveTime() { + this.lastActiveTime = System.currentTimeMillis(); + } + + /** + * 获取最后活跃时间 + */ + public long getLastActiveTime() { + return lastActiveTime; + } + + /** + * 重置状态 + */ + public void reset() { + status.set(ConnectionStatus.DISCONNECTED); + reconnectAttempts.set(0); + updateLastActiveTime(); + } + + @Override + public String toString() { + return String.format("TtsConnectionState{status=%s, reconnectAttempts=%d, lastActiveTime=%d}", + status.get(), reconnectAttempts.get(), lastActiveTime); + } +} + diff --git a/src/main/java/com/xiaozhi/dialogue/tts/factory/TtsServiceFactory.java b/src/main/java/com/xiaozhi/dialogue/tts/factory/TtsServiceFactory.java index acd1678..359a8e8 100644 --- a/src/main/java/com/xiaozhi/dialogue/tts/factory/TtsServiceFactory.java +++ b/src/main/java/com/xiaozhi/dialogue/tts/factory/TtsServiceFactory.java @@ -86,7 +86,6 @@ public class TtsServiceFactory { case "volcengine" -> new VolcengineTtsService(config, voiceName, outputPath); case "xfyun" -> new XfyunTtsService(config, voiceName, outputPath); case "minimax" -> new MiniMaxTtsService(config, voiceName, outputPath); - case "minimax-ws" -> new MiniMaxTtsWebSocketService(config, voiceName, outputPath); // WebSocket 流式版本 default -> new EdgeTtsService(voiceName, outputPath); }; } diff --git a/src/main/java/com/xiaozhi/service/VoiceStreamService.java b/src/main/java/com/xiaozhi/service/VoiceStreamService.java new file mode 100644 index 0000000..742a713 --- /dev/null +++ b/src/main/java/com/xiaozhi/service/VoiceStreamService.java @@ -0,0 +1,649 @@ +package com.xiaozhi.service; + +import com.xiaozhi.dialogue.llm.GrokStreamService; +import com.xiaozhi.dialogue.llm.SentenceBufferService; +import com.xiaozhi.dialogue.llm.memory.DatabaseChatMemory; +import com.xiaozhi.dialogue.stt.SttService; +import 
com.xiaozhi.dialogue.stt.factory.SttServiceFactory;
+import com.xiaozhi.dialogue.tts.MinimaxTtsStreamService;
+import com.xiaozhi.entity.SysConfig;
+import com.xiaozhi.entity.SysMessage;
+import com.xiaozhi.utils.AudioUtils;
+import jakarta.annotation.Resource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.stereotype.Service;
+
+import java.io.IOException;
+import java.util.*;
+import java.util.concurrent.*;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+
+/**
+ * 语音流式对话核心服务
+ * 串联 STT -> LLM Stream -> 分句 -> TTS Stream 流程
+ */
+@Service
+public class VoiceStreamService {
+    private static final Logger logger = LoggerFactory.getLogger(VoiceStreamService.class);
+
+    @Resource
+    private SttServiceFactory sttServiceFactory;
+
+    @Resource
+    private GrokStreamService grokStreamService;
+
+    @Resource
+    private SentenceBufferService sentenceBufferService;
+
+    @Resource
+    private MinimaxTtsStreamService minimaxTtsStreamService;
+
+    @Resource
+    private SysConfigService configService;
+
+    @Resource
+    private DatabaseChatMemory chatMemory;
+
+    @Resource
+    private SysMessageService sysMessageService;
+
+    // 管理每个会话的状态
+    private final Map<String, SessionState> sessions = new ConcurrentHashMap<>();
+
+    // 线程池用于异步处理
+    private final ExecutorService executorService = Executors.newCachedThreadPool();
+
+    /**
+     * 处理音频流
+     *
+     * @param sessionId 会话ID
+     * @param audioData PCM音频数据
+     * @param callback  回调接口
+     */
+    public void processAudioStream(String sessionId, byte[] audioData, StreamCallback callback) {
+        executorService.submit(() -> {
+            try {
+                // 获取或创建会话状态
+                SessionState state = sessions.computeIfAbsent(sessionId, k -> new SessionState());
+
+                // 取消之前的处理(打断机制)
+                state.cancel();
+                state.reset();
+
+                logger.info("开始处理音频流 - SessionId: {}, AudioSize: {}", sessionId, audioData.length);
+
+                // 1. 
STT - 语音转文字 + String recognizedText = performStt(audioData); + if (recognizedText == null || recognizedText.trim().isEmpty()) { + callback.onError("语音识别失败或未识别到内容"); + return; + } + + logger.info("STT识别结果 - SessionId: {}, Text: {}", sessionId, recognizedText); + callback.onSttResult(recognizedText); + + // 2. 创建分句缓冲器 + SentenceBufferService.SentenceBuffer sentenceBuffer = + sentenceBufferService.createBuffer(new SentenceBufferService.SentenceCallback() { + + private final Queue sentenceQueue = new LinkedList<>(); + private final Semaphore ttsPermit = new Semaphore(1); // 保证TTS顺序执行 + private final AtomicBoolean processing = new AtomicBoolean(false); + + @Override + public void onSentence(String sentence) { + if (state.isCancelled()) { + logger.info("会话已取消,跳过句子处理 - SessionId: {}", sessionId); + return; + } + + logger.info("检测到完整句子 - SessionId: {}, Sentence: {}", sessionId, sentence); + callback.onSentenceComplete(sentence); + + // 将句子加入队列并异步处理TTS + sentenceQueue.offer(sentence); + processSentenceQueue(); + } + + @Override + public void onComplete() { + logger.info("LLM输出完成 - SessionId: {}", sessionId); + // 等待所有TTS完成 + waitForAllTtsComplete(); + callback.onComplete(); + } + + private void processSentenceQueue() { + if (processing.compareAndSet(false, true)) { + executorService.submit(() -> { + try { + while (!sentenceQueue.isEmpty() && !state.isCancelled()) { + String sentence = sentenceQueue.poll(); + if (sentence != null) { + processSentenceTts(sentence); + } + } + } finally { + processing.set(false); + // 检查是否还有新的句子 + if (!sentenceQueue.isEmpty() && !state.isCancelled()) { + processSentenceQueue(); + } + } + }); + } + } + + private void processSentenceTts(String sentence) { + if (state.isCancelled()) { + return; + } + + try { + ttsPermit.acquire(); + + logger.info("开始TTS合成 - SessionId: {}, Sentence: {}", sessionId, sentence); + + CompletableFuture ttsFuture = minimaxTtsStreamService.streamTts( + sessionId, + sentence, + new MinimaxTtsStreamService.AudioCallback() { 
+ @Override + public void onAudioChunk(byte[] audioData) { + if (!state.isCancelled()) { + callback.onAudioChunk(audioData); + } + } + + @Override + public void onComplete() { + logger.debug("句子TTS完成 - SessionId: {}", sessionId); + } + + @Override + public void onError(String error) { + logger.error("TTS错误 - SessionId: {}, Error: {}", sessionId, error); + callback.onError(error); + } + } + ); + + // 等待TTS完成 + ttsFuture.get(30, TimeUnit.SECONDS); + + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + logger.warn("TTS处理被中断 - SessionId: {}", sessionId); + state.markTtsFailed(); // 标记 TTS 失败 + } catch (Exception e) { + logger.error("TTS处理失败 - SessionId: {}", sessionId, e); + state.markTtsFailed(); // 标记 TTS 失败 + callback.onError("TTS处理失败: " + e.getMessage()); + } finally { + ttsPermit.release(); + } + } + + private void waitForAllTtsComplete() { + try { + // 等待所有TTS完成 + ttsPermit.acquire(); + ttsPermit.release(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + }); + + // 3. 
LLM流式调用 + String systemPrompt = "你是一个友好的AI助手,请用简洁、自然的语气回答用户的问题。"; + + grokStreamService.streamChat( + recognizedText, + systemPrompt, + new GrokStreamService.TokenCallback() { + @Override + public void onToken(String token) { + if (state.isCancelled()) { + return; + } + callback.onLlmToken(token); + sentenceBuffer.addToken(token); + } + + @Override + public void onComplete() { + if (!state.isCancelled()) { + sentenceBuffer.finish(); + } + } + + @Override + public void onError(String error) { + logger.error("LLM调用失败 - SessionId: {}, Error: {}", sessionId, error); + callback.onError(error); + } + } + ); + + } catch (Exception e) { + logger.error("处理音频流失败 - SessionId: {}", sessionId, e); + callback.onError("处理失败: " + e.getMessage()); + } + }); + } + + /** + * 处理音频流(带历史记录) + * + * @param sessionId WebSocket会话ID + * @param audioData PCM音频数据 + * @param chatSessionId 聊天会话ID(用于查询和保存历史记录) + * @param templateId 模板ID(可选,用于获取角色信息) + * @param callback 回调接口 + */ + public void processAudioStreamWithHistory(String sessionId, byte[] audioData, + String chatSessionId, Integer templateId, + StreamCallback callback) { + executorService.submit(() -> { + try { + // 获取或创建会话状态 + SessionState state = sessions.computeIfAbsent(sessionId, k -> new SessionState()); + + // 取消之前的处理(打断机制) + state.cancel(); + state.reset(); + + logger.info("开始处理音频流(带历史记录) - SessionId: {}, ChatSessionId: {}, TemplateId: {}, AudioSize: {}", + sessionId, chatSessionId, templateId, audioData.length); + + // 使用 chatSessionId 作为 deviceId(保持与文字对话一致) + String deviceId = chatSessionId; + // roleId 可以为 null,数据库允许 + Integer roleId = templateId; + + // 1. STT - 语音转文字 + String recognizedText = performStt(audioData); + if (recognizedText == null || recognizedText.trim().isEmpty()) { + callback.onError("语音识别失败或未识别到内容"); + return; + } + + logger.info("STT识别结果 - SessionId: {}, Text: {}", sessionId, recognizedText); + callback.onSttResult(recognizedText); + + // 2. 
异步保存用户消息到数据库(不等待,保留 future 引用)
+                CompletableFuture<Void> userMessageFuture = chatMemory.addMessage(
+                        deviceId, chatSessionId, "user", recognizedText,
+                        roleId, "NORMAL", System.currentTimeMillis());
+                logger.info("用户消息开始异步保存 - ChatSessionId: {}", chatSessionId);
+
+                // 3. 查询历史记录(最近20条)
+                // 等待用户消息保存完成,确保能查询到最新数据
+                List<SysMessage> historyMessages = new ArrayList<>();
+                try {
+                    userMessageFuture.join(); // 等待用户消息保存完成
+                    logger.info("用户消息保存完成 - ChatSessionId: {}", chatSessionId);
+
+                    List<SysMessage> allMessages = sysMessageService.queryBySessionId(chatSessionId);
+                    // 获取最近20条消息(不包括刚刚保存的用户消息)
+                    int startIndex = Math.max(0, allMessages.size() - 21); // -21 因为包含了刚保存的用户消息
+                    historyMessages = allMessages.subList(startIndex, allMessages.size() - 1); // -1 排除最后一条(刚保存的)
+                    logger.info("查询到历史消息 {} 条 - ChatSessionId: {}", historyMessages.size(), chatSessionId);
+                } catch (Exception e) {
+                    logger.error("保存或查询历史消息失败,将继续处理(无历史上下文) - ChatSessionId: {}", chatSessionId, e);
+                }
+
+                // 4. 累积LLM完整响应(用于保存到数据库)
+                AtomicReference<String> fullResponse = new AtomicReference<>("");
+
+                // 5. 
创建分句缓冲器 + SentenceBufferService.SentenceBuffer sentenceBuffer = + sentenceBufferService.createBuffer(new SentenceBufferService.SentenceCallback() { + + private final Queue sentenceQueue = new LinkedList<>(); + private final Semaphore ttsPermit = new Semaphore(1); // 保证TTS顺序执行 + private final AtomicBoolean processing = new AtomicBoolean(false); + + @Override + public void onSentence(String sentence) { + if (state.isCancelled()) { + logger.info("会话已取消,跳过句子处理 - SessionId: {}", sessionId); + return; + } + + logger.info("检测到完整句子 - SessionId: {}, Sentence: {}", sessionId, sentence); + callback.onSentenceComplete(sentence); + + // 将句子加入队列并异步处理TTS + sentenceQueue.offer(sentence); + processSentenceQueue(); + } + + @Override + public void onComplete() { + logger.info("LLM输出完成 - SessionId: {}", sessionId); + + // 等待所有TTS完成 + waitForAllTtsComplete(); + + // TTS 全部成功后才保存 AI 回复(如果未取消且未失败) + String aiResponse = fullResponse.get(); + if (aiResponse != null && !aiResponse.trim().isEmpty() + && !state.isCancelled() && !state.isTtsFailed()) { + + // 使用链式调用确保在用户消息之后保存,并增加 1ms 保证顺序 + userMessageFuture.thenCompose(v -> { + long timestamp = System.currentTimeMillis() + 1; + return chatMemory.addMessage(deviceId, chatSessionId, "assistant", + aiResponse, roleId, "NORMAL", timestamp); + }).thenAccept(v -> { + logger.info("AI响应已保存 - ChatSessionId: {}, Length: {}", + chatSessionId, aiResponse.length()); + }).exceptionally(ex -> { + logger.error("保存AI响应失败 - ChatSessionId: {}", chatSessionId, ex); + return null; + }); + } else if (state.isTtsFailed()) { + logger.warn("TTS失败,不保存AI响应 - ChatSessionId: {}", chatSessionId); + } else if (state.isCancelled()) { + logger.info("会话已取消,不保存AI响应 - ChatSessionId: {}", chatSessionId); + } + + callback.onComplete(); + } + + private void processSentenceQueue() { + if (processing.compareAndSet(false, true)) { + executorService.submit(() -> { + try { + while (!sentenceQueue.isEmpty() && !state.isCancelled()) { + String sentence = sentenceQueue.poll(); + if 
(sentence != null) { + processSentenceTts(sentence); + } + } + } finally { + processing.set(false); + // 检查是否还有新的句子 + if (!sentenceQueue.isEmpty() && !state.isCancelled()) { + processSentenceQueue(); + } + } + }); + } + } + + private void processSentenceTts(String sentence) { + if (state.isCancelled()) { + return; + } + + try { + ttsPermit.acquire(); + + logger.info("开始TTS合成 - SessionId: {}, Sentence: {}", sessionId, sentence); + + CompletableFuture ttsFuture = minimaxTtsStreamService.streamTts( + sessionId, + sentence, + new MinimaxTtsStreamService.AudioCallback() { + @Override + public void onAudioChunk(byte[] audioData) { + if (!state.isCancelled()) { + callback.onAudioChunk(audioData); + } + } + + @Override + public void onComplete() { + logger.debug("句子TTS完成 - SessionId: {}", sessionId); + } + + @Override + public void onError(String error) { + logger.error("TTS错误 - SessionId: {}, Error: {}", sessionId, error); + callback.onError(error); + } + } + ); + + // 等待TTS完成 + ttsFuture.get(30, TimeUnit.SECONDS); + + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + logger.warn("TTS处理被中断 - SessionId: {}", sessionId); + state.markTtsFailed(); // 标记 TTS 失败 + } catch (Exception e) { + logger.error("TTS处理失败 - SessionId: {}", sessionId, e); + state.markTtsFailed(); // 标记 TTS 失败 + callback.onError("TTS处理失败: " + e.getMessage()); + } finally { + ttsPermit.release(); + } + } + + private void waitForAllTtsComplete() { + try { + // 等待所有TTS完成 + ttsPermit.acquire(); + ttsPermit.release(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + }); + + // 6. 构建包含历史记录的消息列表 + List> messages = buildMessagesWithHistory(historyMessages, recognizedText); + + // 7. 
LLM流式调用(带历史记录)
+                grokStreamService.streamChatWithHistory(
+                        messages,
+                        new GrokStreamService.TokenCallback() {
+                            @Override
+                            public void onToken(String token) {
+                                if (state.isCancelled()) {
+                                    return;
+                                }
+                                // 累积完整响应
+                                fullResponse.updateAndGet(current -> current + token);
+
+                                callback.onLlmToken(token);
+                                sentenceBuffer.addToken(token);
+                            }
+
+                            @Override
+                            public void onComplete() {
+                                if (!state.isCancelled()) {
+                                    sentenceBuffer.finish();
+                                }
+                            }
+
+                            @Override
+                            public void onError(String error) {
+                                logger.error("LLM调用失败 - SessionId: {}, Error: {}", sessionId, error);
+                                callback.onError(error);
+                            }
+                        }
+                );
+
+            } catch (Exception e) {
+                logger.error("处理音频流失败(带历史记录) - SessionId: {}", sessionId, e);
+                callback.onError("处理失败: " + e.getMessage());
+            }
+        });
+    }
+
+    /**
+     * 构建包含历史记录的消息列表
+     */
+    private List<Map<String, String>> buildMessagesWithHistory(List<SysMessage> historyMessages,
+                                                               String currentUserMessage) {
+        List<Map<String, String>> messages = new ArrayList<>();
+
+        // 添加系统提示
+        Map<String, String> systemMessage = new HashMap<>();
+        systemMessage.put("role", "system");
+        systemMessage.put("content", "你是一个友好的AI助手,请用简洁、自然的语气回答用户的问题。");
+        messages.add(systemMessage);
+
+        // 添加历史消息
+        for (SysMessage msg : historyMessages) {
+            Map<String, String> message = new HashMap<>();
+            // 将数据库中的 sender 转换为 Grok API 需要的角色
+            if ("user".equals(msg.getSender())) {
+                message.put("role", "user");
+            } else if ("assistant".equals(msg.getSender())) {
+                message.put("role", "assistant");
+            } else {
+                // 跳过不认识的角色
+                logger.warn("未知的消息发送方: {}", msg.getSender());
+                continue;
+            }
+            message.put("content", msg.getMessage());
+            messages.add(message);
+        }
+
+        // 添加当前用户消息
+        Map<String, String> currentMessage = new HashMap<>();
+        currentMessage.put("role", "user");
+        currentMessage.put("content", currentUserMessage);
+        messages.add(currentMessage);
+
+        logger.debug("构建消息列表完成,共 {} 条消息(含系统提示和历史)", messages.size());
+        return messages;
+    }
+
+    /**
+     * 取消指定会话的流处理
+     */
+    public void cancelStream(String sessionId) {
+        SessionState state = sessions.get(sessionId);
+        if (state != null) {
+            
state.cancel(); + // 取消TTS + minimaxTtsStreamService.cancelTts(sessionId); + logger.info("取消流处理 - SessionId: {}", sessionId); + } + } + + /** + * 预热TTS连接 + * + * @param sessionId 会话ID + * @return CompletableFuture 连接就绪后完成 + */ + public CompletableFuture warmupTtsConnection(String sessionId) { + logger.info("预热TTS连接 - SessionId: {}", sessionId); + return minimaxTtsStreamService.warmupConnection(sessionId); + } + + /** + * 关闭TTS连接 + * + * @param sessionId 会话ID + */ + public void closeTtsConnection(String sessionId) { + logger.info("关闭TTS连接 - SessionId: {}", sessionId); + minimaxTtsStreamService.closeConnection(sessionId); + } + + /** + * 执行STT(复用现有逻辑) + */ + private String performStt(byte[] audioData) { + try { + // 标准化输入音频为 PCM 16k 单声道 + byte[] pcmData = AudioUtils.bytesToPcm(audioData); + + if (pcmData == null || pcmData.length == 0) { + throw new RuntimeException("音频数据为空或转换后为空"); + } + + // 获取STT服务(使用默认配置) + SttService sttService = sttServiceFactory.getSttService(null); + if (sttService == null) { + throw new RuntimeException("无法获取STT服务"); + } + + // 执行语音识别 + String recognizedText = sttService.recognition(pcmData); + return recognizedText != null ? 
recognizedText.trim() : ""; + + } catch (IOException e) { + throw new RuntimeException("音频标准化失败: " + e.getMessage(), e); + } catch (Exception e) { + logger.error("STT处理失败: {}", e.getMessage(), e); + throw new RuntimeException("语音识别失败: " + e.getMessage(), e); + } + } + + /** + * 会话状态 + */ + private static class SessionState { + private final AtomicBoolean cancelled = new AtomicBoolean(false); + private final AtomicBoolean ttsFailed = new AtomicBoolean(false); + + public void cancel() { + cancelled.set(true); + } + + public void reset() { + cancelled.set(false); + ttsFailed.set(false); // 重置时也清除 TTS 失败标记 + } + + public boolean isCancelled() { + return cancelled.get(); + } + + public void markTtsFailed() { + ttsFailed.set(true); + } + + public boolean isTtsFailed() { + return ttsFailed.get(); + } + } + + /** + * 流处理回调接口 + */ + public interface StreamCallback { + /** + * STT识别结果 + */ + void onSttResult(String text); + + /** + * LLM输出token + */ + void onLlmToken(String token); + + /** + * 完整句子 + */ + void onSentenceComplete(String sentence); + + /** + * TTS音频数据块 + */ + void onAudioChunk(byte[] audioChunk); + + /** + * 所有处理完成 + */ + void onComplete(); + + /** + * 发生错误 + */ + void onError(String error); + } +} + diff --git a/src/main/java/com/xiaozhi/service/impl/ChatSessionServiceImpl.java b/src/main/java/com/xiaozhi/service/impl/ChatSessionServiceImpl.java index 8f26ada..17caa12 100644 --- a/src/main/java/com/xiaozhi/service/impl/ChatSessionServiceImpl.java +++ b/src/main/java/com/xiaozhi/service/impl/ChatSessionServiceImpl.java @@ -182,7 +182,12 @@ public class ChatSessionServiceImpl implements ChatSessionService { ChatSession session = getOrCreateSession(sessionId, request.getModelId(), request.getTemplateId()); // 1. 
STT - 语音转文本 + long sttStartTime = System.currentTimeMillis(); Map sttResult = performStt(request); + long sttEndTime = System.currentTimeMillis(); + long sttDuration = sttEndTime - sttStartTime; + logger.info("STT处理完成,sessionId={}, 耗时={}s", sessionId, sttDuration / 1000.0); + String recognizedText = (String) sttResult.get("text"); if (recognizedText == null || recognizedText.trim().isEmpty()) { @@ -191,6 +196,7 @@ public class ChatSessionServiceImpl implements ChatSessionService { result.put("sttResult", sttResult); result.put("sessionId", sessionId); result.put("timestamp", System.currentTimeMillis()); + result.put("sttDuration", sttDuration); return result; } @@ -202,8 +208,12 @@ public class ChatSessionServiceImpl implements ChatSessionService { chatRequest.setTemplateId(request.getTemplateId()); chatRequest.setUseFunctionCall(request.getUseFunctionCall()); + long llmStartTime = System.currentTimeMillis(); ChatResponse chatResponse = syncChat(chatRequest); - + long llmEndTime = System.currentTimeMillis(); + long llmDuration = llmEndTime - llmStartTime; + logger.info("LLM处理完成,sessionId={}, 耗时={}s", sessionId, llmDuration / 1000.0); + Map llmResult = new HashMap<>(); llmResult.put("response", chatResponse.getResponse()); llmResult.put("inputText", recognizedText); @@ -211,13 +221,17 @@ public class ChatSessionServiceImpl implements ChatSessionService { // 获取会话中角色的voiceName String voiceName = null; ChatSession chatSession = getOrCreateSession(sessionId, request.getModelId(), request.getTemplateId()); - if (chatSession != null && chatSession.getConversation() != null && + if (chatSession != null && chatSession.getConversation() != null && chatSession.getConversation().role() != null) { voiceName = chatSession.getConversation().role().getVoiceName(); } // 3. 
TTS - 文本转语音 + long ttsStartTime = System.currentTimeMillis(); Map ttsResult = performTts(chatResponse.getResponse(), request.getTtsConfigId(), voiceName); + long ttsEndTime = System.currentTimeMillis(); + long ttsDuration = ttsEndTime - ttsStartTime; + logger.info("TTS处理完成,sessionId={}, 耗时={}s", sessionId, ttsDuration / 1000.0); // 组装完整响应 Map result = new HashMap<>(); @@ -226,6 +240,13 @@ public class ChatSessionServiceImpl implements ChatSessionService { result.put("ttsResult", ttsResult); result.put("sessionId", sessionId); result.put("timestamp", System.currentTimeMillis()); + result.put("sttDuration", sttDuration); + result.put("llmDuration", llmDuration); + result.put("ttsDuration", ttsDuration); + + logger.info("语音对话完成,sessionId={}, STT耗时={}s, LLM耗时={}s, TTS耗时={}s, 总耗时={}s", + sessionId, sttDuration / 1000.0, llmDuration / 1000.0, ttsDuration / 1000.0, + (sttDuration + llmDuration + ttsDuration) / 1000.0); return result; diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index 40eaf3f..fb79b9f 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -311,4 +311,23 @@ openapi: contact-email: xiaozhi@qq.com version: 1.0 external-description: xiaozhi API Docs - external-url: https://github.com/joey-zhou/xiaozhi-esp32-server-java \ No newline at end of file + external-url: https://github.com/joey-zhou/xiaozhi-esp32-server-java + +# 语音流式对话配置 +xiaozhi: + voice-stream: + grok: + api-key: xai-KKU4O5WumrQowiPc2qSF223L7Nw2XAwDOgrBxQRtrhKAbAri7JBDJuiv6HNwBcvNTnO026YPUeijwGqq + api-url: https://caddy.liqupan.cn/v1/chat/completions + model: grok-4-1-fast-non-reasoning + minimax: + api-key: 
eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiJ2YmlvZGJkcCIsIlVzZXJOYW1lIjoidHNldCIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxOTkyOTAyNTAzMzg5MjA1NDY3IiwiUGhvbmUiOiIiLCJHcm91cElEIjoiMTk5MjkwMjUwMzM4MDgyMDk1NSIsIlBhZ2VOYW1lIjoiIiwiTWFpbCI6InZiaW9kYmRwQGdtYWlsLmNvbSIsIkNyZWF0ZVRpbWUiOiIyMDI1LTEyLTA2IDE1OjQzOjUxIiwiVG9rZW5UeXBlIjoxLCJpc3MiOiJtaW5pbWF4In0.hf1M4cPe27Sz_QeSyYODqM6yrN8aQ68nRwYB7iQ3uO5nu0NSN7qHQRVxAt2tVuoOf503SEx5F-PfYyC85OFJFhWNNhhDuFuxPIz97LVz1oQUlIejZ_BmCMj4iWwGXTUmEugGK1lzcsI6eJz8eRjQHsxOgJJmxPLXWHTPs1gDqtnckAgjOBRQJSadP58Xe9EdI6n-2_SL_ni3Tqm3LuWq9tUPJa5WgDMZX9IDK7XXyZy0i1GoSXmp8P1O1JmIecBVUoCzyYFwWW787BNdYiyEV3UrFjC_4onJ8Tzh-eGq84-rtxBR5FKO2MpNU_I0xI-W3YJxOEl_JPXXGgX5ASTKNw + group-id: ${MINIMAX_GROUP_ID:your-group-id} + ws-url: wss://api.minimax.io/ws/v1/t2a_v2 + model: speech-2.6-hd + voice-id: Chinese (Mandarin)_BashfulGirl + speed: 1.0 + vol: 1.0 + pitch: 0 + audio-sample-rate: 32000 + bitrate: 128000 \ No newline at end of file
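The `api-key` values above should be supplied through the environment rather than committed to the repository; Spring Boot resolves `${VAR:default}` placeholders from environment variables at startup. A minimal sketch of wiring them up before launching the server (every value shown here is a placeholder, not a real credential, and the Maven run command is an assumption about how the project is started):

```shell
# Export the secrets referenced by the ${...} placeholders in application.yml.
# All values below are placeholders -- substitute your real credentials.
export GROK_API_KEY="xai-your-grok-key"
export MINIMAX_API_KEY="your-minimax-jwt"
export MINIMAX_GROUP_ID="your-group-id"

# Spring Boot reads these at startup; if a variable is unset, the default
# after the ':' in ${GROK_API_KEY:your-api-key} is used instead.
# Then start the server as usual, e.g.: mvn spring-boot:run
```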