All Downloads are FREE. Search and download functionalities are using the official Maven repository.

apoc.load.LoadArrow Maven / Gradle / Ivy

There is a newer version: 5.24.0
Show newest version
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package apoc.load;

import apoc.result.MapResult;
import apoc.util.FileUtils;
import apoc.util.JsonUtil;
import apoc.util.Util;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.channels.SeekableByteChannel;
import java.time.Instant;
import java.time.ZoneOffset;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DateMilliVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowFileReader;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.util.Text;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.graphdb.security.URLAccessValidationError;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;
import org.neo4j.values.storable.Values;

public class LoadArrow {

    @Context
    public URLAccessChecker urlAccessChecker;

    private static class ArrowSpliterator extends Spliterators.AbstractSpliterator {

        private final ArrowReader reader;
        private final VectorSchemaRoot schemaRoot;
        private final AtomicInteger counter;

        public ArrowSpliterator(ArrowReader reader, VectorSchemaRoot schemaRoot) throws IOException {
            super(Long.MAX_VALUE, Spliterator.ORDERED);
            this.reader = reader;
            this.schemaRoot = schemaRoot;
            this.counter = new AtomicInteger();
            this.reader.loadNextBatch();
        }

        @Override
        public synchronized boolean tryAdvance(Consumer action) {
            try {
                if (counter.get() >= schemaRoot.getRowCount()) {
                    if (reader.loadNextBatch()) {
                        counter.set(0);
                    } else {
                        return false;
                    }
                }
                final Map row = schemaRoot.getFieldVectors().stream()
                        .collect(
                                HashMap::new,
                                (map, fieldVector) -> map.put(fieldVector.getName(), read(fieldVector, counter.get())),
                                HashMap::putAll); // please look at https://bugs.openjdk.java.net/browse/JDK-8148463
                counter.incrementAndGet();
                action.accept(new MapResult(row));
                return true;
            } catch (Exception e) {
                return false;
            }
        }
    }

    @Procedure(name = "apoc.load.arrow.stream")
    @Description("Imports `NODE` and `RELATIONSHIP` values from the provided arrow byte array.")
    public Stream stream(
            @Name("source") byte[] source, @Name(value = "config", defaultValue = "{}") Map config)
            throws IOException {
        RootAllocator allocator = new RootAllocator();
        ByteArrayInputStream inputStream = new ByteArrayInputStream(source);
        ArrowStreamReader streamReader = new ArrowStreamReader(inputStream, allocator);
        VectorSchemaRoot schemaRoot = streamReader.getVectorSchemaRoot();
        return StreamSupport.stream(new ArrowSpliterator(streamReader, schemaRoot), false)
                .onClose(() -> {
                    Util.close(allocator);
                    Util.close(streamReader);
                    Util.close(schemaRoot);
                    Util.close(inputStream);
                });
    }

    @Procedure(name = "apoc.load.arrow")
    @Description("Imports `NODE` and `RELATIONSHIP` values from the provided arrow file.")
    public Stream file(
            @Name("file") String fileName, @Name(value = "config", defaultValue = "{}") Map config)
            throws IOException, URISyntaxException, URLAccessValidationError {
        final SeekableByteChannel channel = FileUtils.inputStreamFor(fileName, null, null, null, urlAccessChecker)
                .asChannel();
        RootAllocator allocator = new RootAllocator();
        ArrowFileReader streamReader = new ArrowFileReader(channel, allocator);
        VectorSchemaRoot schemaRoot = streamReader.getVectorSchemaRoot();
        return StreamSupport.stream(new ArrowSpliterator(streamReader, schemaRoot), false)
                .onClose(() -> {
                    Util.close(allocator);
                    Util.close(streamReader);
                    Util.close(schemaRoot);
                    Util.close(channel);
                });
    }

    private static Object read(FieldVector fieldVector, int index) {
        if (fieldVector.isNull(index)) {
            return null;
        } else if (fieldVector instanceof DateMilliVector) {
            DateMilliVector fe = (DateMilliVector) fieldVector;
            return Instant.ofEpochMilli(fe.get(index)).atOffset(ZoneOffset.UTC);
        } else if (fieldVector instanceof BitVector) {
            BitVector fe = (BitVector) fieldVector;
            return fe.get(index) == 1;
        } else {
            Object object = fieldVector.getObject(index);
            return getObject(object);
        }
    }

    private static Object getObject(Object object) {
        if (object instanceof Collection) {
            return ((Collection) object).stream().map(LoadArrow::getObject).collect(Collectors.toList());
        }
        if (object instanceof Map) {
            return ((Map) object)
                    .entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> getObject(e.getValue())));
        }
        if (object instanceof Text) {
            return object.toString();
        }
        try {
            // we test if is a valid Neo4j type
            return Values.of(object);
        } catch (Exception e) {
            // otherwise we try coerce it
            return valueToString(object);
        }
    }

    private static String valueToString(Object value) {
        return JsonUtil.writeValueAsString(value);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy