All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.jyuzawa.onnxruntime.OnnxTensorStringImpl Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_CHAR;
import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_POINTER;

import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SegmentAllocator;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.stream.Stream;

final class OnnxTensorStringImpl extends OnnxTensorImpl {

    private final String[] buffer;

    OnnxTensorStringImpl(TensorInfoImpl tensorInfo, ValueContext valueContext, MemoryAddress ortValueAddress) {
        super(tensorInfo, valueContext);
        this.buffer = new String[Math.toIntExact(tensorInfo.getElementCount())];
        if (ortValueAddress != null) {
            ApiImpl api = valueContext.api();
            SegmentAllocator segmentAllocator = valueContext.segmentAllocator();
            int numOutputs = buffer.length;
            for (int i = 0; i < numOutputs; i++) {
                final long index = i;
                long length = api.extractLong(
                        segmentAllocator, out -> api.GetStringTensorElementLength.apply(ortValueAddress, index, out));
                // add a byte for the null termination
                MemorySegment output = segmentAllocator.allocateArray(C_CHAR, length + 1);
                api.checkStatus(api.GetStringTensorElement.apply(ortValueAddress, length, index, output.address()));
                buffer[i] = output.getUtf8String(0);
            }
        }
    }

    @Override
    public String toString() {
        return "{OnnxTensor: info=" + tensorInfo + ", buffer=" + Arrays.toString(buffer) + "}";
    }

    @Override
    public String[] getStringBuffer() {
        return buffer;
    }

    @Override
    public MemoryAddress toNative() {
        int numOutputs = buffer.length;
        ApiImpl api = valueContext.api();
        SegmentAllocator allocator = valueContext.segmentAllocator();
        MemorySegment stringArray = allocator.allocateArray(C_POINTER, numOutputs);
        for (int i = 0; i < numOutputs; i++) {
            stringArray.setAtIndex(C_POINTER, i, allocator.allocateUtf8String(buffer[i]));
        }
        MemoryAddress tensor = api.create(
                allocator,
                out -> api.CreateTensorAsOrtValue.apply(
                        valueContext.ortAllocatorAddress(),
                        tensorInfo.shapeData.address(),
                        tensorInfo.getShape().size(),
                        tensorInfo.getType().getNumber(),
                        out));
        api.checkStatus(api.FillStringTensor.apply(tensor, stringArray.address(), numOutputs));
        return tensor;
    }

    @Override
    public void putScalars(Collection scalars) {
        int i = 0;
        for (OnnxTensorImpl scalar : scalars) {
            buffer[i++] = scalar.getStringBuffer()[0];
        }
    }

    @Override
    public void getScalars(Stream scalars) {
        int i = 0;
        Iterator iter = scalars.iterator();
        while (iter.hasNext()) {
            OnnxTensorImpl scalar = iter.next();
            scalar.getStringBuffer()[0] = buffer[i++];
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy