// com.alibaba.nls.client.protocol.tts.StreamInputTts (Maven / Gradle / Ivy)
/*
* Copyright 2015 Alibaba Group Holding Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.nls.client.protocol.tts;
import com.alibaba.nls.client.protocol.*;
import com.alibaba.nls.client.transport.Connection;
import com.alibaba.nls.client.util.IdGen;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static com.alibaba.nls.client.protocol.SpeechReqProtocol.State.*;
/**
* @author lengjiayi.ljy
* @date 2024/04/22
*
* 实时语音合成器,用于设置及发送合成请求,处理合成结果回调
* 非线程安全
*/
public class StreamInputTts extends SpeechReqProtocol {
private static final Integer DEFAULT_SAMPLE_RATE = 24000;
private static final Integer DEFAULT_VOICE_VOLUME = 50;
private static final String DEFAULT_FORMAT = "pcm";
static Logger logger = LoggerFactory.getLogger(StreamInputTts.class);
protected String currentSessionId;
protected StreamInputTtsListener StreamTTSListener;
protected long lastSendTime = -1;
protected long minSendIntervalMS = 100;
private CountDownLatch completeLatch;
private CountDownLatch readyLatch;
public StreamInputTts(NlsClient client, StreamInputTtsListener listener) throws Exception {
this.conn = client.connect(listener);
afterConnection(listener);
}
public StreamInputTts(NlsClient client, String token, StreamInputTtsListener listener) throws Exception {
Connection conn = client.connect(token, listener);
this.conn = conn;
afterConnection(listener);
}
public StreamInputTtsListener getStreamTTSListener() {
return this.StreamTTSListener;
}
public String getCurrentSessionId() {
return this.currentSessionId;
}
protected void afterConnection(StreamInputTtsListener listener) {
payload = new HashMap();
header.put(Constant.PROP_NAMESPACE, TTSConstant.VALUE_NAMESPACE_STREAM_REALTIME_TTS);
header.put(Constant.PROP_NAME, TTSConstant.VALUE_NAME_TTS_START);
payload.put(TTSConstant.PROP_TTS_FORMAT, DEFAULT_FORMAT);
payload.put(TTSConstant.PROP_TTS_SAMPLE_RATE, DEFAULT_SAMPLE_RATE);
payload.put(TTSConstant.PROP_TTS_VOLUME, DEFAULT_VOICE_VOLUME);
listener.setStreamInputTts(this);
StreamTTSListener = listener;
state = STATE_CONNECTED;
}
/**
* 发音人
*
* @param voice
*/
public void setVoice(String voice) {
payload.put(TTSConstant.PROP_TTS_VOICE, voice);
}
/**
* 合成语音的编码格式,支持的格式:pcm,wav,mp3 默认是pcm
*
* @param format
*/
public void setFormat(OutputFormatEnum format) {
payload.put(Constant.PROP_ASR_FORMAT, format.getName());
}
/**
* 合成语音的采样率,可选默认16000
*
* @param sampleRate
*/
public void setSampleRate(SampleRateEnum sampleRate) {
payload.put(TTSConstant.PROP_TTS_SAMPLE_RATE, sampleRate.value);
}
/**
* 个别场景需要设置除'SampleRateEnum'之外的采样率,如32000,采用此方法进行设置
*
* @param sampleRate
*/
public void setSampleRate(int sampleRate) {
payload.put(TTSConstant.PROP_TTS_SAMPLE_RATE, sampleRate);
}
/**
* 音量,范围是0~100,可选,默认50
*
* @param volume
*/
public void setVolume(int volume) {
payload.put(TTSConstant.PROP_TTS_VOLUME, volume);
}
/**
* 语速,范围是-500~500,可选,默认是0
*
* @param speechRate
*/
public void setSpeechRate(int speechRate) {
payload.put(TTSConstant.PROP_TTS_SPEECH_RATE, speechRate);
}
/**
* 语调,范围是-500~500,可选,默认是0
*
* @param pitchRate
*/
public void setPitchRate(int pitchRate) {
payload.put(TTSConstant.PROP_TTS_PITCH_RATE, pitchRate);
}
/**
* 设置连续两次发送文本(调用send函数)的最小时间间隔(毫秒)
*
* @param sendInterval 单位毫秒
*/
public void setMinSendIntervalMS(long sendInterval) {
this.minSendIntervalMS = sendInterval;
}
/**
* 发送文本。
* 如果当前调用send时距离上次调用时间小于minSendIntervalMS,则会阻塞并等待直到满足条件再发送文本
*
* @param text 需要合成的文本
*/
public void sendStreamInputTts(String text) {
long msNeedToSleep = minSendIntervalMS - (System.currentTimeMillis() - lastSendTime);
if (lastSendTime != -1 && msNeedToSleep > 0) {
logger.info("too short send interval, sleep {} million second", msNeedToSleep);
try {
Thread.sleep(msNeedToSleep);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
state.checkSend();
try {
SpeechReqProtocol req = new SpeechReqProtocol();
req.header.put(Constant.PROP_TASK_ID, currentTaskId);
req.header.put(Constant.PROP_NAMESPACE, TTSConstant.VALUE_NAMESPACE_STREAM_REALTIME_TTS);
req.header.put(Constant.PROP_NAME, TTSConstant.VALUE_NAME_RUN_SYNTHESIS);
req.setAppKey(getAppKey());
req.payload = new HashMap();
req.payload.put(TTSConstant.PROP_TTS_TEXT, text);
conn.sendText(req.serialize());
lastSendTime = System.currentTimeMillis();
} catch (Exception e) {
logger.error("fail to send text, current_task_id:{},state:{}", currentTaskId, state, e);
throw new RuntimeException(e);
}
}
/**
* 服务端准备好了进行语音合成
*/
void markSynthesisReady() {
state = STATE_REQUEST_CONFIRMED;
if (readyLatch != null) {
readyLatch.countDown();
}
}
/**
* 服务端停止了语音转写
*/
void markSynthesisComplete() {
state = STATE_COMPLETE;
if (completeLatch != null) {
completeLatch.countDown();
}
}
/**
* 服务端返回错误
*/
void markFail() {
state = STATE_FAIL;
if (readyLatch != null) {
readyLatch.countDown();
}
if (completeLatch != null) {
completeLatch.countDown();
}
}
/**
* 内部调用方法
*/
void markClosed() {
state = STATE_CLOSED;
if (readyLatch != null) {
readyLatch.countDown();
}
if (completeLatch != null) {
completeLatch.countDown();
}
}
/**
* 开始语音转写:发送语音转写请求,同步接收服务端确认
*
* @throws Exception
*/
public void startStreamInputTts() throws Exception {
startStreamInputTts(Constant.DEFAULT_START_TIMEOUT_MILLISECONDS);
}
/**
* 开始语音合成:发送语音合成请求,同步接收服务端确认, 超时未返回则抛出异常
*
* @throws Exception
*/
public void startStreamInputTts(long milliSeconds) throws Exception {
String sessionId = IdGen.genId();
logger.info("start, gen session id: {}", sessionId);
currentSessionId = sessionId;
header.put(TTSConstant.PROP_SESSION_ID, currentSessionId);
super.start();
completeLatch = new CountDownLatch(1);
readyLatch = new CountDownLatch(1);
boolean result = readyLatch.await(milliSeconds, TimeUnit.MILLISECONDS);
if (!result) {
String msg = String.format("timeout after %d ms waiting for start confirmation.task_id:%s,state:%s",
milliSeconds, currentTaskId, state);
logger.error(msg);
throw new Exception(msg);
}
}
/**
* 结束语音合成: 发送结束识别通知,接收服务端确认
*
* @throws Exception
*/
/**
* 结束语音合成: 发送结束识别通知,接收服务端确认, 超时未返回则抛出异常
*
* @throws Exception
*/
public void stopStreamInputTts() throws Exception {
state.checkStop();
SpeechReqProtocol req = new SpeechReqProtocol();
req.header.put(Constant.PROP_TASK_ID, currentTaskId);
req.header.put(Constant.PROP_NAMESPACE, TTSConstant.VALUE_NAMESPACE_STREAM_REALTIME_TTS);
req.header.put(Constant.PROP_NAME, TTSConstant.VALUE_NAME_TTS_STOP);
req.setAppKey(getAppKey());
conn.sendText(req.serialize());
state = STATE_STOP_SENT;
completeLatch.await();
if (state == STATE_FAIL) {
String msg = String.format("timeout after %d ms waiting for complete confirmation.task_id:%s,state:%s",
10000, currentTaskId, state);
logger.error(msg);
throw new Exception(msg);
}
}
/**
* 关闭连接
*/
public void close() {
conn.close();
}
}