All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.modality.audio.SampledAudioFactory Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.modality.audio;

import ai.djl.modality.cv.ImageFactory;

import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.ShortBuffer;
import java.nio.file.Path;

import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import javax.sound.sampled.UnsupportedAudioFileException;

/**
 * {@code SampledAudioFactory} is an implementation of {@link AudioFactory} using the Java Sampled
 * Package.
 *
 * <p>The decoded bytes are interpreted as 16-bit samples, using the byte order reported by the
 * stream's {@link AudioFormat}, and are scaled by {@link Short#MAX_VALUE}.
 *
 * @see <a href="https://docs.oracle.com/javase/tutorial/sound/sampled-overview.html">Sampled
 *     Package overview</a>
 */
public class SampledAudioFactory extends AudioFactory {

    /** Target size, in bytes, of the chunk used when draining an audio stream. */
    private static final int READ_CHUNK_SIZE = 4096;

    /**
     * {@inheritDoc}
     *
     * <p>This implementation accepts only the value {@code 0} and rejects every other value.
     */
    @Override
    public AudioFactory setChannels(int channel) {
        // NOTE(review): the message mentions "1 channel" while the check only lets 0 through;
        // presumably 0 denotes the factory default -- confirm against AudioFactory.
        if (channel != 0) {
            throw new UnsupportedOperationException("SampledAudioFactory only support 1 channel.");
        }
        return this;
    }

    /** {@inheritDoc} */
    @Override
    public AudioFactory setSampleRate(int sampleRate) {
        throw new UnsupportedOperationException("Not supported.");
    }

    /** {@inheritDoc} */
    @Override
    public AudioFactory setSampleFormat(int sampleFormat) {
        throw new UnsupportedOperationException("Not supported.");
    }

    /** {@inheritDoc} */
    @Override
    public Audio fromFile(Path path) throws IOException {
        try (AudioInputStream ais = AudioSystem.getAudioInputStream(path.toFile())) {
            return toAudio(ais);
        } catch (UnsupportedAudioFileException e) {
            throw new IOException("Unsupported Audio file", e);
        }
    }

    /** {@inheritDoc} */
    @Override
    public Audio fromInputStream(InputStream is) throws IOException {
        // The stream is wrapped in a BufferedInputStream before being handed to AudioSystem.
        try (AudioInputStream ais = AudioSystem.getAudioInputStream(new BufferedInputStream(is))) {
            return toAudio(ais);
        } catch (UnsupportedAudioFileException e) {
            throw new IOException("Unsupported Audio file", e);
        }
    }

    /**
     * Drains the audio stream and converts its content to an {@link Audio}.
     *
     * @param ais the opened audio stream; the caller remains responsible for closing it
     * @return the audio built from the samples, sample rate and channel count of the stream
     * @throws IOException if reading the stream fails
     */
    private Audio toAudio(AudioInputStream ais) throws IOException {
        AudioFormat format = ais.getFormat();
        float[] samples = bytesToFloats(read(ais), format.isBigEndian());
        return new Audio(samples, format.getSampleRate(), format.getChannels());
    }

    /**
     * Reads every remaining byte of the audio stream.
     *
     * <p>The declared frame length is only used as a capacity hint: the result always contains
     * exactly the bytes the stream delivered, so streams with an unspecified, understated or
     * very large frame length are handled without failing on array allocation or copy.
     *
     * @param ais the audio stream to drain
     * @return the bytes read from the stream
     * @throws IOException if reading the stream fails
     */
    private byte[] read(AudioInputStream ais) throws IOException {
        int frameSize = ais.getFormat().getFrameSize();

        // Some audio formats may have unspecified (non-positive) frame size
        if (frameSize <= 0) {
            frameSize = 1;
        }

        // AudioInputStream only hands out whole frames and returns 0 when the buffer is smaller
        // than one frame, so the chunk must be a non-zero multiple of the frame size.
        int chunkSize = Math.max(frameSize, READ_CHUNK_SIZE - READ_CHUNK_SIZE % frameSize);

        // Pre-size from the declared length only when it is known and fits in an array.
        long declaredFrames = ais.getFrameLength();
        int capacity = chunkSize;
        if (declaredFrames > 0 && declaredFrames <= (Integer.MAX_VALUE - 8) / frameSize) {
            capacity = (int) (declaredFrames * frameSize);
        }

        ByteArrayOutputStream out = new ByteArrayOutputStream(capacity);
        byte[] chunk = new byte[chunkSize];
        int count;
        while ((count = ais.read(chunk)) != -1) {
            out.write(chunk, 0, count);
        }
        return out.toByteArray();
    }

    /**
     * Converts raw bytes to floats by reading them as consecutive 16-bit values.
     *
     * <p>Each value is divided by {@link Short#MAX_VALUE}, so results lie approximately between
     * -1.0f and 1.0f ({@link Short#MIN_VALUE} maps marginally below -1.0f). A trailing odd byte
     * is not part of any 16-bit value and is ignored.
     *
     * @param bytes the raw sample bytes
     * @param isBigEndian {@code true} to read the bytes in big-endian order, {@code false} for
     *     little-endian
     * @return one float per 16-bit value
     */
    private float[] bytesToFloats(byte[] bytes, boolean isBigEndian) {
        ByteOrder order = isBigEndian ? ByteOrder.BIG_ENDIAN : ByteOrder.LITTLE_ENDIAN;
        ShortBuffer buffer = ByteBuffer.wrap(bytes).order(order).asShortBuffer();
        short[] shorts = new short[buffer.capacity()];
        buffer.get(shorts);

        // Feed in float values between -1.0f and 1.0f.
        float[] floats = new float[shorts.length];
        for (int i = 0; i < shorts.length; i++) {
            floats[i] = ((float) shorts[i]) / (float) Short.MAX_VALUE;
        }
        return floats;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy