io.cdap.plugin.gcp.speech.SpeechToTextTransform Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of google-cloud Show documentation
Plugins for Google BigQuery
/*
* Copyright © 2015 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package io.cdap.plugin.gcp.speech;
import com.google.cloud.speech.v1.RecognitionAudio;
import com.google.cloud.speech.v1.RecognitionConfig;
import com.google.cloud.speech.v1.RecognizeResponse;
import com.google.cloud.speech.v1.SpeechClient;
import com.google.cloud.speech.v1.SpeechRecognitionAlternative;
import com.google.cloud.speech.v1.SpeechRecognitionResult;
import com.google.cloud.speech.v1.SpeechSettings;
import com.google.common.base.Strings;
import com.google.protobuf.ByteString;
import io.cdap.cdap.api.annotation.Description;
import io.cdap.cdap.api.annotation.Macro;
import io.cdap.cdap.api.annotation.Name;
import io.cdap.cdap.api.annotation.Plugin;
import io.cdap.cdap.api.data.format.StructuredRecord;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.etl.api.Emitter;
import io.cdap.cdap.etl.api.FailureCollector;
import io.cdap.cdap.etl.api.PipelineConfigurer;
import io.cdap.cdap.etl.api.StageSubmitterContext;
import io.cdap.cdap.etl.api.Transform;
import io.cdap.cdap.etl.api.TransformContext;
import io.cdap.plugin.gcp.common.GCPConfig;
import io.cdap.plugin.gcp.common.GCPUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
/**
 * Google Cloud Speech-To-Text Transform which transcribes binary audio data carried in an input
 * record field and appends the transcription (the full text and/or the per-chunk alternatives
 * with their confidence scores) to the emitted record.
 */
@Plugin(type = Transform.PLUGIN_TYPE)
@Name(SpeechToTextTransform.NAME)
@Description(SpeechToTextTransform.DESCRIPTION)
public class SpeechToTextTransform extends Transform<StructuredRecord, StructuredRecord> {
  public static final String NAME = "SpeechToText";
  public static final String DESCRIPTION = "Converts audio files to text by applying powerful neural network models.";

  private SpeechTransformConfig config;
  private Schema outputSchema = null;
  // Record schema for a single transcription alternative: the transcript text plus the
  // confidence score reported by the Speech-to-Text service.
  private static final Schema SPEECH =
    Schema.recordOf("speech",
                    Schema.Field.of("transcript", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
                    Schema.Field.of("confidence", Schema.nullableOf(Schema.of(Schema.Type.FLOAT)))
    );
  private RecognitionConfig recognitionConfig;
  private SpeechClient speech;

  @Override
  public void configurePipeline(PipelineConfigurer configurer) throws IllegalArgumentException {
    super.configurePipeline(configurer);
    Schema inputSchema = configurer.getStageConfigurer().getInputSchema();
    FailureCollector collector = configurer.getStageConfigurer().getFailureCollector();
    config.validate(inputSchema, collector);
    configurer.getStageConfigurer().setOutputSchema(getSchema(inputSchema));
  }

  @Override
  public void prepareRun(StageSubmitterContext context) throws Exception {
    super.prepareRun(context);
    // Re-validate at submission time since macros are resolved by now.
    config.validate(context.getInputSchema(), context.getFailureCollector());
  }

  @Override
  public void initialize(TransformContext context) throws Exception {
    super.initialize(context);
    outputSchema = context.getOutputSchema();
    setRecognitionConfig();
    // The service account type must be resolved by now; fail fast before creating the client.
    if (config.isServiceAccountFilePath() == null) {
      context.getFailureCollector().addFailure("Service account type is undefined.",
                                               "Must be `filePath` or `JSON`");
      context.getFailureCollector().getOrThrowException();
    }
    speech = SpeechClient.create(getSettings());
  }

  @Override
  public void transform(StructuredRecord input, Emitter<StructuredRecord> emitter) {
    ByteString audioBytes = ByteString.copyFrom((byte[]) input.get(config.audioField));
    RecognitionAudio audio = RecognitionAudio.newBuilder()
      .setContent(audioBytes)
      .build();
    RecognizeResponse response = speech.recognize(recognitionConfig, audio);
    List<SpeechRecognitionResult> results = response.getResultsList();
    // If an output schema is available then use it, or else derive one from the input record.
    Schema currentSchema = outputSchema != null ? outputSchema : getSchema(input.getSchema());
    StructuredRecord.Builder outputBuilder = StructuredRecord.builder(currentSchema);
    List<StructuredRecord> transcriptParts = new ArrayList<>();
    StringBuilder completeTranscript = new StringBuilder();
    for (SpeechRecognitionResult result : results) {
      if (config.getPartsField() != null) {
        addTranscriptWithConfidence(transcriptParts, result.getAlternativesList());
      }
      // The first alternative is the most likely one; guard against a result that carries
      // no alternatives at all.
      if (config.getTextField() != null && !result.getAlternativesList().isEmpty()) {
        completeTranscript.append(result.getAlternativesList().get(0).getTranscript());
      }
    }
    if (config.getPartsField() != null) {
      outputBuilder.set(config.getPartsField(), transcriptParts);
    }
    if (config.getTextField() != null) {
      outputBuilder.set(config.getTextField(), completeTranscript.toString());
    }
    copyFields(input, outputBuilder);
    emitter.emit(outputBuilder.build());
  }

  @Override
  public void destroy() {
    super.destroy();
    // The client may be null if initialize() failed before creating it.
    if (speech != null) {
      try {
        speech.close();
      } catch (Exception e) {
        // best-effort close during teardown; nothing useful can be done on failure
      }
    }
  }

  /**
   * Builds the output schema by appending the configured transcription field(s) to the fields of
   * the given input schema.
   *
   * @param inputSchema the input schema, or {@code null} if it is not known yet
   * @return the output schema, or {@code null} when the input schema is {@code null}
   * @throws IllegalArgumentException if neither transcription field is configured and neither is
   *   still an unresolved macro
   */
  @Nullable
  private Schema getSchema(@Nullable Schema inputSchema) {
    if (inputSchema == null) {
      return null;
    }
    List<Schema.Field> fields = new ArrayList<>();
    if (inputSchema.getFields() != null) {
      fields.addAll(inputSchema.getFields());
    }
    boolean hasTranscriptField = false;
    // at least one output field in which the transcript will be stored must be specified
    if (config.transcriptionPartsField != null) {
      fields.add(Schema.Field.of(config.transcriptionPartsField, Schema.arrayOf(SPEECH)));
      hasTranscriptField = true;
    }
    if (config.transcriptionTextField != null) {
      fields.add(Schema.Field.of(config.transcriptionTextField, Schema.nullableOf(Schema.of(Schema.Type.STRING))));
      hasTranscriptField = true;
    }
    if (!hasTranscriptField && !config.containsMacro(SpeechTransformConfig.NAME_TRANS_PART)
      && !config.containsMacro(SpeechTransformConfig.NAME_TRANS_TEXT)) {
      throw new IllegalArgumentException("Either 'Transcript Parts Field' or 'Transcript Text Field' or both must be " +
                                           "specified.");
    }
    return Schema.recordOf("record", fields);
  }

  /**
   * Builds client settings, wiring in explicit service-account credentials when one is configured;
   * otherwise the client falls back to application-default credentials.
   */
  private SpeechSettings getSettings() throws IOException {
    SpeechSettings.Builder builder = SpeechSettings.newBuilder();
    if (!Strings.isNullOrEmpty(config.getServiceAccount())) {
      builder.setCredentialsProvider(() -> GCPUtils.loadServiceAccountCredentials(config.getServiceAccount(),
                                                                                  config.isServiceAccountFilePath()));
    }
    return builder.build();
  }

  /**
   * Initializes the recognition request configuration from the plugin configuration.
   */
  private void setRecognitionConfig() {
    recognitionConfig = RecognitionConfig.newBuilder()
      .setEncoding(getAudioEncoding())
      .setSampleRateHertz(config.getSampleRate())
      .setProfanityFilter(config.shouldFilterProfanity())
      .setLanguageCode(config.getLanguage())
      .build();
  }

  /**
   * Maps the configured encoding name to the service enum; defaults to LINEAR16 for any
   * unrecognized value.
   */
  private RecognitionConfig.AudioEncoding getAudioEncoding() {
    RecognitionConfig.AudioEncoding encoding = RecognitionConfig.AudioEncoding.LINEAR16;
    if (config.encoding.equalsIgnoreCase("amr")) {
      encoding = RecognitionConfig.AudioEncoding.AMR;
    } else if (config.encoding.equalsIgnoreCase("amr_wb")) {
      encoding = RecognitionConfig.AudioEncoding.AMR_WB;
    } else if (config.encoding.equalsIgnoreCase("flac")) {
      encoding = RecognitionConfig.AudioEncoding.FLAC;
    } else if (config.encoding.equalsIgnoreCase("mulaw")) {
      encoding = RecognitionConfig.AudioEncoding.MULAW;
    } else if (config.encoding.equalsIgnoreCase("ogg_opus")) {
      encoding = RecognitionConfig.AudioEncoding.OGG_OPUS;
    }
    return encoding;
  }

  /**
   * Copies every field value of the input record into the output builder.
   */
  private void copyFields(StructuredRecord input, StructuredRecord.Builder outputBuilder) {
    List<Schema.Field> fields = input.getSchema().getFields();
    if (fields != null) {
      for (Schema.Field field : fields) {
        outputBuilder.set(field.getName(), input.get(field.getName()));
      }
    }
  }

  /**
   * Converts each recognition alternative into a {@link #SPEECH} record and appends it to the
   * given list.
   */
  private void addTranscriptWithConfidence(List<StructuredRecord> speechArray,
                                           List<SpeechRecognitionAlternative> alternatives) {
    for (SpeechRecognitionAlternative alternative : alternatives) {
      float confidence = alternative.getConfidence();
      String transcript = alternative.getTranscript();
      StructuredRecord.Builder speech = StructuredRecord.builder(SPEECH);
      speech.set("confidence", confidence);
      speech.set("transcript", transcript);
      speechArray.add(speech.build());
    }
  }

  /**
   * Configuration for the Speech-to-Text transform.
   */
  private static class SpeechTransformConfig extends GCPConfig {
    private static final String NAME_AUDIOFIELD = "audiofield";
    private static final String NAME_TRANS_PART = "transcriptionPartsField";
    private static final String NAME_TRANS_TEXT = "transcriptionTextField";
    private static final String NAME_RATE = "samplerate";

    @Macro
    @Name(NAME_AUDIOFIELD)
    @Description("Name of field containing binary audio file data")
    private String audioField;

    @Macro
    @Description("Audio encoding of the data sent in the audio message. All encodings support\n" +
      "only 1 channel (mono) audio. Only `FLAC` and `WAV` include a header that\n" +
      "describes the bytes of audio that follow the header. The other encodings\n" +
      "are raw audio bytes with no header.")
    private String encoding;

    @Macro
    @Name(NAME_RATE)
    @Description("Sample rate in Hertz of the audio data sent in all `RecognitionAudio` messages. Valid values are: " +
      "8000-48000. 16000 is optimal. For best results, set the sampling rate of the audio source to 16000 Hz. If " +
      "that's not possible, use the native sample rate of the audio source (instead of re-sampling).")
    private String sampleRate;

    @Macro
    @Description("Whether to mask profanity, replacing all but the initial character in each masked word with " +
      "asterisks. For example, 'f***'.")
    private String profanity;

    @Macro
    @Description("The language of the supplied audio as a [BCP-47](https://www.rfc-editor.org/rfc/bcp/bcp47.txt) " +
      "language tag. Example: \"en-US\". See [Language Support](https://cloud.google.com/speech/docs/languages) for a" +
      " list of the currently supported language codes.")
    private String language;

    @Macro
    @Nullable
    @Description("The name of the field to store all the different chunks of transcription with all the " +
      "different possibility and their confidence score. Defaults to 'parts'")
    private String transcriptionPartsField;

    @Macro
    @Nullable
    @Description("If a field name is specified then the transcription with highest confidence score will be stored " +
      "as text.")
    private String transcriptionTextField;

    /**
     * Returns the parts field name, or {@code null} if it is unset, empty, or an unresolved macro.
     */
    @Nullable
    public String getPartsField() {
      if (containsMacro(NAME_TRANS_PART) || transcriptionPartsField == null ||
        transcriptionPartsField.isEmpty()) {
        return null;
      }
      return transcriptionPartsField;
    }

    /**
     * Returns the text field name, or {@code null} if it is unset, empty, or an unresolved macro.
     */
    @Nullable
    public String getTextField() {
      if (containsMacro(NAME_TRANS_TEXT) || transcriptionTextField == null ||
        transcriptionTextField.isEmpty()) {
        return null;
      }
      return transcriptionTextField;
    }

    /**
     * Returns the audio field name, or {@code null} if it is an unresolved macro.
     */
    @Nullable
    public String getAudioField() {
      if (containsMacro(NAME_AUDIOFIELD)) {
        return null;
      }
      return audioField;
    }

    public boolean shouldFilterProfanity() {
      // null-safe: an unset value means no profanity filtering
      return "true".equalsIgnoreCase(profanity);
    }

    public String getLanguage() {
      return language == null ? "en-US" : language;
    }

    /**
     * Returns the sample rate as an integer, or {@code null} when it is an unresolved macro.
     *
     * @throws IllegalArgumentException if sample rate is an invalid number
     */
    @Nullable
    public Integer getSampleRate() {
      if (containsMacro(NAME_RATE)) {
        return null;
      }
      try {
        return Integer.parseInt(sampleRate);
      } catch (NumberFormatException e) {
        throw new IllegalArgumentException("Sample rate should be a valid number");
      }
    }

    /**
     * Validates the configuration against the input schema, collecting all failures before
     * throwing.
     */
    private void validate(@Nullable Schema inputSchema, FailureCollector collector) {
      if (inputSchema != null) {
        // If an inputSchema and audioField are available then verify that the schema does contain
        // the given audioField and that its type is byte array. This allows failing fast, i.e.
        // at deployment time.
        String audioFieldName = getAudioField();
        if (audioFieldName != null) {
          Schema.Field field = inputSchema.getField(audioFieldName);
          if (field == null) {
            collector.addFailure(String.format("Field '%s' does not exist in the input schema.", audioFieldName),
                                 "Change audio field to be one of the schema fields.")
              .withConfigProperty(NAME_AUDIOFIELD);
          } else {
            Schema fieldSchema = field.getSchema();
            fieldSchema = fieldSchema.isNullable() ? fieldSchema.getNonNullable() : fieldSchema;
            if (fieldSchema.getLogicalType() != null || fieldSchema.getType() != Schema.Type.BYTES) {
              collector.addFailure(
                String.format("Field '%s' is of unsupported type '%s'.",
                              audioFieldName, fieldSchema.getDisplayName()), "Ensure it is of type 'bytes'.")
                .withConfigProperty(NAME_AUDIOFIELD).withInputSchemaField(audioFieldName);
            }
          }
        }
        // at least one of the two transcription output fields must be configured
        if ((!containsMacro(NAME_TRANS_TEXT) && getTextField() == null) &&
          (!containsMacro(NAME_TRANS_PART) && getPartsField() == null)) {
          collector.addFailure(
            "'Transcript Parts Field' or 'Transcript Text Field' are not provided.",
            "Provide at least one of them.")
            .withConfigProperty(NAME_TRANS_PART).withConfigProperty(NAME_TRANS_TEXT);
        }
        // the transcription fields must not collide with existing input fields
        if (inputSchema.getFields() != null) {
          Set<String> fields = inputSchema.getFields().stream()
            .map(Schema.Field::getName)
            .collect(Collectors.toSet());
          if (getTextField() != null && fields.contains(getTextField())) {
            collector.addFailure(
              String.format("Transcript text field '%s' already exists in the input schema.", getTextField()),
              "Change the field name.")
              .withConfigProperty(NAME_TRANS_TEXT).withInputSchemaField(getTextField());
          }
          if (getPartsField() != null && fields.contains(getPartsField())) {
            collector.addFailure(
              String.format("Transcript parts field '%s' already exists in the input schema.", getPartsField()),
              "Change the field name.")
              .withConfigProperty(NAME_TRANS_PART).withInputSchemaField(getPartsField());
          }
        }
      }
      try {
        Integer sampleRate = getSampleRate();
        if (sampleRate != null && (sampleRate < 8000 || sampleRate > 48000)) {
          collector.addFailure("Invalid sample rate.", "Ensure the value is between 8000 and 48000.")
            .withConfigProperty(NAME_RATE);
        }
      } catch (IllegalArgumentException e) {
        collector.addFailure("Invalid sample rate.", "Ensure the value is between 8000 and 48000.")
          .withConfigProperty(NAME_RATE);
      }
      collector.getOrThrowException();
    }
  }
}