org.daisy.pipeline.tts.calabash.impl.TextToPcmThread Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of tts-common Show documentation
Common API for TTS functionality
The newest version!
package org.daisy.pipeline.tts.calabash.impl;
import java.math.BigDecimal;
import java.math.MathContext;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentLinkedQueue;
import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioInputStream;
import net.sf.saxon.s9api.Axis;
import net.sf.saxon.s9api.QName;
import net.sf.saxon.s9api.XdmNode;
import net.sf.saxon.s9api.XdmNodeKind;
import net.sf.saxon.s9api.XdmSequenceIterator;
import org.daisy.common.messaging.MessageAppender;
import org.daisy.common.messaging.MessageBuilder;
import org.daisy.pipeline.audio.AudioUtils;
import org.daisy.pipeline.tts.AudioFootprintMonitor;
import org.daisy.pipeline.tts.AudioFootprintMonitor.MemoryException;
import org.daisy.pipeline.tts.SSMLMarkSplitter;
import org.daisy.pipeline.tts.Sentence;
import org.daisy.pipeline.tts.SSMLMarkSplitter.Chunk;
import org.daisy.pipeline.tts.TTSEngine;
import org.daisy.pipeline.tts.TTSLog;
import org.daisy.pipeline.tts.TTSEngine.SynthesisResult;
import org.daisy.pipeline.tts.TTSLog.ErrorCode;
import org.daisy.pipeline.tts.TTSRegistry;
import org.daisy.pipeline.tts.TTSRegistry.TTSResource;
import org.daisy.pipeline.tts.TTSService.SynthesisException;
import org.daisy.pipeline.tts.TTSTimeout;
import org.daisy.pipeline.tts.TimedTTSExecutor;
import org.daisy.pipeline.tts.TTSTimeout.ThreadFreeInterrupter;
import org.daisy.pipeline.tts.TimedTTSExecutor.TimeoutException;
import org.daisy.pipeline.tts.Voice;
import org.daisy.pipeline.tts.Voice.MarkSupport;
import org.daisy.pipeline.tts.VoiceManager;
import org.slf4j.Logger;
import com.google.common.collect.Iterables;
/**
 * TextToPcmThread consumes text from a shared queue. It produces PCM data as
 * output, which in turn are pushed to another shared queue consumed by the
 * EncodingThreads. PCM is produced by calling TTS processors.
 *
 * TTS processors may fail for some reasons, e.g. after a timeout, or if marks
 * are missing. In such cases, TextToPcmThread will clean up the resources and
 * attempt to synthesize the current sentence with another TTS processor chosen
 * by the TTSRegistry, unless the error is a MemoryException, in which case the
 * thread gives up on the guilty sentence.
 *
 * The resources of the TTS processors (e.g. sockets) are allocated on-the-fly
 * and are all released at the end of the thread execution, including when the
 * thread ends because of an unexpected error.
 */
public class TextToPcmThread implements FormatSpecifications {
    private Logger mLogger;
    // TTS resources allocated by this thread, keyed by engine
    private Map<TTSEngine, TTSResource> mResources = new HashMap<>();
    private int mFileNrInSection; //usually = 0, but incremented when a flush occurs within a section
    private List<SoundFileLink> mSoundFileLinks; //result provided back to the SynthesizeStep caller
    private List<SoundFileLink> mLinksOfCurrentFile; //links under construction
    private Iterable<AudioInputStream> mAudioOfCurrentFile; // audio file under construction
    private int mOffsetInFile; // reset after every flush
    private int mMemFootprint; //reset after every flush
    private Thread mThread;
    private TimedTTSExecutor mExecutor;
    private TTSRegistry mTTSRegistry;
    private AudioFormat mLastFormat; //used for knowing if a flush is necessary
    private AudioFootprintMonitor mAudioFootprintMonitor;
    private SSMLMarkSplitter mSSMLSplitter;
    private VoiceManager mVoiceManager;
    private TTSLog mTTSLog;
    private int mErrorCounter;
    private Throwable uncaughtException;

    /**
     * Java counterpart of SSML's marks
     */
    private static final class Mark {
        /**
         * Name
         */
        final String name;
        /**
         * Offset in bytes, relative to the beginning of the sentence's audio
         */
        final int offsetInAudio;

        Mark(String name, int offset) {
            this.name = name;
            this.offsetInAudio = offset;
        }
    }

    /**
     * Starts a worker thread that synthesizes the sections polled from {@code input} until
     * that queue is empty. Use {@link #getSoundFragments()} to wait for the result.
     *
     * @param input Shared queue of sections to synthesize.
     * @param pcmOutput Shared queue receiving the PCM data to be encoded.
     * @param maxQueueEltSize Memory footprint of the audio under construction above which it is
     *                        flushed to {@code pcmOutput}.
     * @param messageAppender Parent of the progress messages of this thread, or {@code null}.
     * @param totalTextSize Total size of all text (not only the text contained in {@code input})
     * @param portion Estimated portion of the text that this thread will process.
     */
    void start(final ConcurrentLinkedQueue<ContiguousText> input,
               final BlockingQueue<ContiguousPCM> pcmOutput, TimedTTSExecutor executor,
               TTSRegistry ttsregistry, VoiceManager voiceManager, SSMLMarkSplitter ssmlSplitter,
               Logger logger, AudioFootprintMonitor audioFootprintMonitor, final int maxQueueEltSize,
               TTSLog ttsLog, MessageAppender messageAppender,
               long totalTextSize, BigDecimal portion) {
        mSSMLSplitter = ssmlSplitter;
        mSoundFileLinks = new ArrayList<>();
        mExecutor = executor;
        mTTSRegistry = ttsregistry;
        mLogger = logger;
        mAudioFootprintMonitor = audioFootprintMonitor;
        mVoiceManager = voiceManager;
        mTTSLog = ttsLog;
        mErrorCounter = 0;
        flush(null, pcmOutput); // initializes the state of the audio file under construction
        mThread = new Thread(
            () -> work(input, pcmOutput, maxQueueEltSize, messageAppender, totalTextSize, portion));
        mThread.setUncaughtExceptionHandler(
            (thread, throwable) -> { uncaughtException = throwable; }
        );
        mThread.start();
    }

    /**
     * Body of the worker thread: synthesizes every section polled from {@code input}, reports
     * progress, and finally releases the TTS resources allocated along the way.
     */
    private void work(ConcurrentLinkedQueue<ContiguousText> input,
                      BlockingQueue<ContiguousPCM> pcmOutput, int maxQueueEltSize,
                      MessageAppender messageAppender, long totalTextSize, BigDecimal portion) {
        // wrap the messages from this thread in a (empty) block so that there is always an
        // active block for this thread, so that SLF4J log messages always have a destination
        MessageAppender messageThread = messageAppender != null
            ? messageAppender.append(new MessageBuilder().withProgress(portion))
            : null;
        TTSTimeout timeout = new TTSTimeout();
        try {
            /* Main loop */
            for (ContiguousText section = input.poll(); section != null; section = input.poll()) {
                synthesizeSection(section, pcmOutput, timeout, maxQueueEltSize);
                reportProgress(messageThread, section, totalTextSize, portion);
            }
        } finally {
            try {
                // done in 'finally' so that the resources are also released when the main loop
                // ends with an exception
                releaseAllResources(timeout);
            } finally {
                timeout.close();
                if (messageThread != null)
                    messageThread.close(); // sets progress to 100% if not already 100%
            }
        }
    }

    /**
     * Synthesizes all the sentences of {@code section}, then flushes the resulting audio.
     * After an unexpected error, the remaining sentences of the section are skipped and counted
     * as errors; the caller then moves on to the next section.
     */
    private void synthesizeSection(ContiguousText section, BlockingQueue<ContiguousPCM> pcmOutput,
                                   TTSTimeout timeout, int maxQueueEltSize) {
        mFileNrInSection = 0;
        boolean skipRemaining = false;
        for (Sentence sentence : section.sentences) {
            if (skipRemaining) {
                mErrorCounter++;
                continue;
            }
            try {
                if (!speak(section, sentence, pcmOutput, timeout, maxQueueEltSize))
                    mErrorCounter++;
            } catch (Throwable t) {
                mErrorCounter++;
                mTTSLog.getWritableEntry(sentence.getID()).addError(
                    new TTSLog.Error(
                        TTSLog.ErrorCode.CRITICAL_ERROR, "the current thread is stopping because of an error", t));
                skipRemaining = true;
            }
        }
        flush(section, pcmOutput);
    }

    /**
     * Reports the progress made by synthesizing {@code section}, relative to this thread's
     * {@code portion} of the total text. Does nothing when there is no message block, or when
     * the sizes do not allow a meaningful ratio (which would otherwise divide by zero).
     */
    private static void reportProgress(MessageAppender messageThread, ContiguousText section,
                                       long totalTextSize, BigDecimal portion) {
        if (messageThread == null || totalTextSize <= 0 || portion.compareTo(BigDecimal.ZERO) <= 0)
            return;
        BigDecimal shareOfTotal = new BigDecimal(section.getStringSize())
            .divide(new BigDecimal(totalTextSize), MathContext.DECIMAL128);
        MessageBuilder m = new MessageBuilder()
            .withProgress(shareOfTotal.divide(portion, MathContext.DECIMAL128).min(BigDecimal.ONE));
        messageThread.append(m).close();
    }

    /**
     * Releases every TTS resource allocated by this thread, each one under a timeout. Failures
     * are logged as warnings.
     */
    private void releaseAllResources(TTSTimeout timeout) {
        for (Map.Entry<TTSEngine, TTSResource> e : mResources.entrySet()) {
            timeout.enableForCurrentThread(2);
            try {
                releaseResource(e.getKey(), e.getValue());
            } catch (Exception ex) {
                mTTSLog.addGeneralError(
                    ErrorCode.WARNING,
                    "Error while releasing resource of " + e.getKey().getProvider().getName()
                        + ": " + ex.getMessage(),
                    ex);
            } finally {
                timeout.disable();
            }
        }
    }

    /**
     * Waits for the worker thread to finish and returns the links to the produced sound files.
     *
     * @throws RuntimeException if the worker thread died because of an uncaught exception
     */
    Collection<SoundFileLink> getSoundFragments() {
        if (mThread != null) {
            try {
                mThread.join();
            } catch (InterruptedException e) {
                //should not happen
                mLogger.warn("TextToPCMThread interruption");
                Thread.currentThread().interrupt(); // preserve the interruption for the caller
            }
            mThread = null;
            if (uncaughtException != null) {
                throw new RuntimeException("unexpected error", uncaughtException); // should not happen
            }
        }
        return mSoundFileLinks;
    }

    int getErrorCount() {
        return mErrorCounter;
    }

    private void releaseResource(TTSEngine tts, TTSResource r) {
        if (r == null) {
            return;
        }
        synchronized (r) {
            try {
                tts.releaseThreadResources(r);
            } catch (Throwable t) {
                mTTSLog.addGeneralError(
                    ErrorCode.WARNING, "error while releasing resources of " + tts.getProvider().getName(), t);
            }
        }
    }

    /**
     * Gives back to the footprint monitor the TTS memory of audio streams that will not be
     * transferred to the encoders.
     */
    private void releaseMemoryOf(Iterable<AudioInputStream> streams) {
        for (AudioInputStream s : streams)
            mAudioFootprintMonitor.releaseTTSMemory(s);
    }

    /**
     * Sends the audio under construction to {@code pcmOutput} (if any link refers to it) and
     * resets the state of the file under construction. Audio that no link refers to is dropped
     * and its memory released. With a {@code null} section, only the reset is performed.
     */
    private void flush(ContiguousText section, BlockingQueue<ContiguousPCM> pcmOutput) {
        if (section != null) {
            if (!mLinksOfCurrentFile.isEmpty()) {
                if (mLastFormat == null) {
                    throw new RuntimeException("coding error"); // should not happen
                }
                String filePrefix = String.format("part%04d_%02d_%03d", section
                    .getDocumentPosition(), section.getDocumentSplitPosition(),
                    mFileNrInSection);
                ContiguousPCM pcm = new ContiguousPCM(
                    AudioUtils.concat(mAudioOfCurrentFile), section.getAudioOutputDir(), filePrefix);
                for (SoundFileLink link : mLinksOfCurrentFile) {
                    link.clipBase = pcm.getDestinationFile();
                }
                try {
                    mAudioFootprintMonitor.transferToEncoding(mMemFootprint, pcm.sizeInBytes());
                } catch (InterruptedException e) {
                    // Should never happen since interruptions only occur during calls to TTS processors.
                    mLogger.warn("interruption of memory transfer");
                }
                pcmOutput.add(pcm);
                mSoundFileLinks.addAll(mLinksOfCurrentFile);
                ++mFileNrInSection;
            } else {
                // Nothing refers to this audio (e.g. sentences whose marks refer to no ID), so it
                // is dropped; its TTS memory must be given back, otherwise it would never be freed.
                releaseMemoryOf(mAudioOfCurrentFile);
            }
        }
        mLinksOfCurrentFile = new ArrayList<>();
        mAudioOfCurrentFile = new ArrayList<>();
        mOffsetInFile = 0;
        mMemFootprint = 0;
        mLastFormat = null;
    }

    /**
     * Wrapper around {@link TimedTTSExecutor#synthesizeWithTimeout()} to handle
     * marks. Returns a sequence of {@link AudioInputStream} with the same audio
     * format, whose TTS memory has been acquired from the footprint monitor.
     *
     * @param marks Filled with the marks found in the audio (offsets relative to the returned audio).
     * @return null if the engine did not return the expected number of marks
     * @throws SynthesisException also when the engine returns non-PCM audio
     */
    private Iterable<AudioInputStream> synthesize(
            TTSTimeout timeout, TTSTimeout.ThreadFreeInterrupter interrupter, TTSLog.Entry logEntry,
            Sentence sentence, TTSEngine tts, Voice voice, TTSResource threadResources,
            List<Mark> marks, List<String> markNames
    ) throws SynthesisException, TimeoutException, MemoryException {
        if (tts.handlesMarks()
                && voice.getMarkSupport() != MarkSupport.MARK_NOT_SUPPORTED) {
            // The engine handles the marks itself: synthesize the whole sentence at once
            logEntry.setActualVoice(voice);
            SynthesisResult result = mExecutor.synthesizeWithTimeout(
                timeout, interrupter, logEntry, sentence.getText(), sentence.getSize(), tts, voice,
                threadResources);
            AudioFormat format = result.audio.getFormat();
            if (!AudioUtils.isPCM(format))
                // reported as a synthesis failure, like in the chunk-based branch, so that a
                // fallback voice gets a chance
                throw new SynthesisException(
                    new IllegalArgumentException("AudioInputStream must be PCM encoded, but got: " + format));
            List<Integer> markOffsets = result.marks;
            if (markNames.size() != markOffsets.size()) {
                mTTSLog.getWritableEntry(sentence.getID()).addError(
                    new TTSLog.Error(ErrorCode.WARNING, "wrong number of marks with "
                        + tts.getProvider().getName()
                        + ". Number of marks received: " + markOffsets.size()
                        + ", expected number: " + markNames.size()));
                return null;
            }
            for (int i = 0; i < markNames.size(); i++) {
                marks.add(new Mark(markNames.get(i), markOffsets.get(i)));
            }
            mAudioFootprintMonitor.acquireTTSMemory(result.audio);
            return Collections.singletonList(result.audio);
        } else {
            // Split the sentence at the marks and synthesize each chunk separately; the offset of
            // a mark is the total size of the audio of the chunks that precede it.
            Collection<Chunk> chunks = mSSMLSplitter.split(sentence.getText());
            List<AudioInputStream> result = new ArrayList<>();
            int offset = 0;
            AudioFormat format = null;
            for (Chunk chunk : chunks) {
                logEntry.setActualVoice(voice);
                try {
                    AudioInputStream stream = mExecutor.synthesizeWithTimeout(
                        timeout, interrupter, logEntry, chunk.ssml(), Sentence.computeSize(chunk.ssml()),
                        tts, voice, threadResources).audio;
                    if (format == null) {
                        format = stream.getFormat();
                        if (!AudioUtils.isPCM(format))
                            throw new IllegalArgumentException("AudioInputStream must be PCM encoded, but got: "+ format);
                    } else if (!format.matches(stream.getFormat()))
                        throw new IllegalArgumentException("All AudioInputStream must have the same format");
                    if (chunk.leftMark() != null) {
                        marks.add(new Mark(chunk.leftMark(), offset));
                    }
                    int size = AudioFootprintMonitor.getFootprint(stream);
                    offset += size;
                    mAudioFootprintMonitor.acquireTTSMemory(size);
                    result.add(stream);
                } catch (MemoryException | SynthesisException | TimeoutException e) {
                    // TODO: flush here
                    releaseMemoryOf(result);
                    throw e;
                } catch (Throwable t) {
                    // TODO: flush here
                    releaseMemoryOf(result);
                    throw new SynthesisException(t);
                }
            }
            if (markNames.size() != marks.size()) {
                releaseMemoryOf(result);
                throw new RuntimeException(
                    "unexpected number of marks: " + marks.size() + ", expected: " + markNames.size()); // should not happen
            }
            return result;
        }
    }

    /**
     * @return null if something went wrong
     */
    private Iterable<AudioInputStream> speakWithVoice(final Sentence sentence, Voice v,
                                                      final TTSEngine tts, List<Mark> marks,
                                                      List<String> markNames, TTSTimeout timeout)
            throws MemoryException {
        //allocate a TTS resource if necessary
        TTSResource resource = mResources.get(tts);
        if (resource == null) {
            timeout.enableForCurrentThread(3);
            try {
                resource = mTTSRegistry.allocateResourceFor(tts);
            } catch (SynthesisException e) {
                mTTSLog.getWritableEntry(sentence.getID()).addError(
                    new TTSLog.Error(ErrorCode.WARNING,
                        "Error while allocating resources for "
                            + tts.getProvider().getName() + ": "
                            + e));
                return null;
            } catch (InterruptedException e) {
                mTTSLog.getWritableEntry(sentence.getID()).addError(
                    new TTSLog.Error(ErrorCode.WARNING,
                        "Timeout while trying to allocate resources for "
                            + tts.getProvider().getName()));
                return null;
            } finally {
                timeout.disable();
            }
            if (resource == null) {
                //TTS not working anymore?
                mTTSLog.getWritableEntry(sentence.getID()).addError(
                    new TTSLog.Error(ErrorCode.WARNING, "Could not allocate resource for "
                        + tts.getProvider().getName()
                        + " (it has probably been stopped)."));
                return null; //it will try with another TTS
            }
            mResources.put(tts, resource);
        }

        //convert the input sentence into PCM using the TTS processor
        final TTSResource fresource = resource;
        TTSTimeout.ThreadFreeInterrupter interrupter = new ThreadFreeInterrupter() {
            @Override
            public void threadFreeInterrupt() {
                mTTSLog.getWritableEntry(sentence.getID()).addError(
                    new TTSLog.Error(
                        ErrorCode.WARNING,
                        "Forcing interruption of the current work of " + tts.getProvider().getName() + "..."));
                tts.interruptCurrentWork(fresource);
            }
        };
        TTSLog.Entry logEntry = mTTSLog.getWritableEntry(sentence.getID());
        try {
            synchronized (resource) {
                if (resource.invalid) {
                    mTTSLog.getWritableEntry(sentence.getID()).addError(
                        new TTSLog.Error(
                            ErrorCode.WARNING,
                            "Resource of " + tts.getProvider().getName()
                                + " is no longer valid. The corresponding service has probably been stopped."));
                    return null;
                }
                return synthesize(timeout, interrupter, logEntry,
                                  sentence, tts, v, resource, marks, markNames);
            }
        } catch (TimeoutException e) {
            logEntry.addError(
                new TTSLog.Error(ErrorCode.WARNING, "timeout (" + e.getSeconds()
                    + " seconds) fired while speaking with "
                    + tts.getProvider().getName()));
            return null;
        } catch (SynthesisException e) {
            logEntry.addError(
                new TTSLog.Error(
                    ErrorCode.WARNING, "error while speaking with " + tts.getProvider().getName() + ": " + e, e));
            return null;
        }
    }

    /**
     * Returns the {@code name} attributes of all the {@code mark} elements descending from
     * {@code ssml}, in document order.
     */
    static List<String> getMarkNames(XdmNode ssml) {
        XdmSequenceIterator iter = ssml.axisIterator(Axis.DESCENDANT);
        List<String> markNames = new ArrayList<>();
        while (iter.hasNext()) {
            XdmNode elt = (XdmNode) iter.next();
            if (elt.getNodeKind() == XdmNodeKind.ELEMENT && "mark".equals(elt.getNodeName().getLocalName())) {
                markNames.add(elt.getAttributeValue(new QName("name")));
            }
        }
        return markNames;
    }

    /**
     * Synthesizes one sentence (with a fallback voice if the preferred one fails) and appends the
     * audio and the corresponding links to the file under construction.
     *
     * @return true when the sentence was successfully converted to speech, false when there was an error
     */
    private boolean speak(ContiguousText section, Sentence sentence,
                          BlockingQueue<ContiguousPCM> pcmOutput, TTSTimeout timeout, int maxQueueEltSize) {
        List<String> markNames = getMarkNames(sentence.getText());
        TTSEngine tts = sentence.getTTSproc();
        Voice originalVoice = sentence.getPreferredVoice();
        List<Mark> marks = new ArrayList<>();
        Iterable<AudioInputStream> pcm;
        try {
            pcm = speakWithVoice(sentence, originalVoice, tts, marks, markNames, timeout);
        } catch (MemoryException e) {
            flush(section, pcmOutput);
            printMemError(sentence, e);
            return false;
        }
        if (pcm == null) {
            //release the resource to make it more likely for the next try to succeed
            releaseResource(tts, mResources.remove(tts));

            //Find another voice for this sentence
            Iterator<Voice> fallbackVoices = sentence.getVoices().iterator();
            if (fallbackVoices.hasNext())
                fallbackVoices.next(); // this returns the original voice
            if (!fallbackVoices.hasNext()) {
                mTTSLog.getWritableEntry(sentence.getID()).addError(
                    new TTSLog.Error(
                        TTSLog.ErrorCode.AUDIO_MISSING,
                        "something went wrong but no fallback voice can be found for " + originalVoice));
                return false;
            }
            Voice fallbackVoice = fallbackVoices.next();
            tts = mVoiceManager.getTTS(fallbackVoice); //cannot return null in this case

            //Try with the new engine
            marks.clear();
            try {
                pcm = speakWithVoice(sentence, fallbackVoice, tts, marks, markNames, timeout);
            } catch (MemoryException e) {
                flush(section, pcmOutput);
                printMemError(sentence, e);
                return false;
            }
            if (pcm == null) {
                mTTSLog.getWritableEntry(sentence.getID()).addError(
                    new TTSLog.Error(
                        TTSLog.ErrorCode.AUDIO_MISSING,
                        "something went wrong with " + originalVoice
                            + " and fallback voice " + fallbackVoice + " didn't work either"));
                return false;
            }
            mLogger.info("something went wrong with " + originalVoice + ". Voice " + fallbackVoice
                         + " used instead to synthesize sentence");
        }

        // An audio file must have a single format. The format can change from one sentence to the
        // next (e.g. when a fallback voice was used for this sentence or for the previous one),
        // so this check is needed for every sentence, not only after a fallback.
        AudioFormat format = pcm.iterator().next().getFormat();
        if (mLastFormat != null && !format.matches(mLastFormat))
            flush(section, pcmOutput);
        mLastFormat = format;

        int begin = mOffsetInFile;
        addAudio(pcm);

        // keep track of where the sound begins and where it ends within the audio file
        if (marks.isEmpty()) {
            mLinksOfCurrentFile.add(
                new SoundFileLink(
                    sentence.getBaseURI().resolve("#" + sentence.getID()),
                    AudioUtils.getDuration(mLastFormat, begin),
                    AudioUtils.getDuration(mLastFormat, mOffsetInFile)));
        } else {
            Map<String, Integer> starts = new HashMap<>();
            Map<String, Integer> ends = new HashMap<>();
            Set<String> all = new HashSet<>();
            for (Mark m : marks) {
                // The part before the delimiter is the ID of a fragment ending at this mark, the
                // part after it the ID of a fragment starting at this mark; either may be empty.
                String[] mark = m.name.split(FormatSpecifications.MarkDelimiter, -1);
                String endingId = mark[0];
                String startingId = mark.length > 1 ? mark[1] : "";
                if (!endingId.isEmpty()) {
                    ends.put(endingId, m.offsetInAudio);
                    all.add(endingId);
                }
                if (!startingId.isEmpty()) {
                    starts.put(startingId, m.offsetInAudio);
                    all.add(startingId);
                }
            }
            for (String id : all) {
                Integer end = ends.get(id);
                mLinksOfCurrentFile.add(
                    new SoundFileLink(
                        sentence.getBaseURI().resolve("#" + id),
                        AudioUtils.getDuration(mLastFormat, begin + starts.getOrDefault(id, 0)),
                        AudioUtils.getDuration(mLastFormat, end != null ? begin + end : mOffsetInFile)));
            }
            /*
             * note: if marks.size() > 0 but all.size() == 0, it means that no
             * mark refers to any ID. It should imply that the sentence contains
             * skippable elements but no text. In such a case, it is important
             * to let the script NOT add any fragment, not even the sentence's
             * parent.
             */
        }

        if (mMemFootprint > maxQueueEltSize) {
            /*
             * This flush prevents the TTS processors from raising too many
             * out-of-memory errors and smoothes the transfers of PCM data to
             * the encoders.
             */
            flush(section, pcmOutput);
        }

        return true;
    }

    private void printMemError(Sentence sentence, MemoryException e) {
        String msg = "out of memory";
        mTTSLog.getWritableEntry(sentence.getID()).addError(
            new TTSLog.Error(ErrorCode.AUDIO_MISSING, msg, e));
    }

    /**
     * Appends {@code toadd} to the audio file under construction and updates the offset and the
     * memory footprint of that file.
     */
    private void addAudio(Iterable<AudioInputStream> toadd) {
        for (AudioInputStream b : toadd) {
            int size = AudioFootprintMonitor.getFootprint(b);
            mOffsetInFile += size;
            mMemFootprint += size;
        }
        mAudioOfCurrentFile = Iterables.concat(mAudioOfCurrentFile, toadd);
    }
}