All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.arrow.flight.integration.tests.IntegrationProducer Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.arrow.flight.integration.tests;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.Collectors;

import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.flight.PutResult;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.DictionaryUtility;

/**
 * A FlightProducer that hosts an in memory store of Arrow buffers. Used for integration testing.
 */
public class IntegrationProducer extends NoOpFlightProducer implements AutoCloseable {
  private final ConcurrentMap datasets = new ConcurrentHashMap<>();
  private final BufferAllocator allocator;
  private Location location;

  /**
   * Constructs a new instance.
   *
   * @param allocator The allocator for creating new Arrow buffers.
   * @param location The location of the storage.
   */
  public IntegrationProducer(BufferAllocator allocator, Location location) {
    super();
    this.allocator = allocator;
    this.location = location;
  }

  /**
   * Update the location after server start.
   *
   * 

Useful for binding to port 0 to get a free port. */ public void setLocation(Location location) { this.location = location; } @Override public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { try { FlightDescriptor descriptor = FlightDescriptor.deserialize(ByteBuffer.wrap(ticket.getBytes())); Dataset dataset = datasets.get(descriptor); if (dataset == null) { listener.error(CallStatus.NOT_FOUND.withDescription("Unknown ticket: " + descriptor).toRuntimeException()); return; } dataset.streamTo(allocator, listener); } catch (Exception ex) { listener.error(IntegrationAssertions.toFlightRuntimeException(ex)); } } @Override public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { Dataset h = datasets.get(descriptor); if (h == null) { throw CallStatus.NOT_FOUND.withDescription("Unknown descriptor: " + descriptor).toRuntimeException(); } return h.getFlightInfo(location); } @Override public Runnable acceptPut(CallContext context, final FlightStream flightStream, final StreamListener ackStream) { return () -> { List batches = new ArrayList<>(); try { try (VectorSchemaRoot root = flightStream.getRoot()) { VectorUnloader unloader = new VectorUnloader(root); while (flightStream.next()) { ackStream.onNext(PutResult.metadata(flightStream.getLatestMetadata())); batches.add(unloader.getRecordBatch()); } // Closing the stream will release the dictionaries, take ownership final Dataset dataset = new Dataset( flightStream.getDescriptor(), flightStream.getSchema(), flightStream.takeDictionaryOwnership(), batches); batches.clear(); datasets.put(flightStream.getDescriptor(), dataset); } finally { AutoCloseables.close(batches); } } catch (Exception ex) { ackStream.onError(IntegrationAssertions.toFlightRuntimeException(ex)); } }; } @Override public void close() throws Exception { AutoCloseables.close(datasets.values()); datasets.clear(); } private static final class Dataset implements AutoCloseable { private final FlightDescriptor 
descriptor; private final Schema schema; private final DictionaryProvider dictionaryProvider; private final List batches; private Dataset(FlightDescriptor descriptor, Schema schema, DictionaryProvider dictionaryProvider, List batches) { this.descriptor = descriptor; this.schema = schema; this.dictionaryProvider = dictionaryProvider; this.batches = new ArrayList<>(batches); } public FlightInfo getFlightInfo(Location location) { ByteBuffer serializedDescriptor = descriptor.serialize(); byte[] descriptorBytes = new byte[serializedDescriptor.remaining()]; serializedDescriptor.get(descriptorBytes); final List endpoints = Collections.singletonList( new FlightEndpoint(new Ticket(descriptorBytes), location)); return new FlightInfo( messageFormatSchema(), descriptor, endpoints, batches.stream().mapToLong(ArrowRecordBatch::computeBodyLength).sum(), batches.stream().mapToInt(ArrowRecordBatch::getLength).sum()); } private Schema messageFormatSchema() { final Set dictionaryIdsUsed = new HashSet<>(); final List messageFormatFields = schema.getFields() .stream() .map(f -> DictionaryUtility.toMessageFormat(f, dictionaryProvider, dictionaryIdsUsed)) .collect(Collectors.toList()); return new Schema(messageFormatFields, schema.getCustomMetadata()); } @Override public void close() throws Exception { AutoCloseables.close(batches); } public void streamTo(BufferAllocator allocator, ServerStreamListener listener) { try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { listener.start(root, dictionaryProvider); final VectorLoader loader = new VectorLoader(root); int counter = 0; for (ArrowRecordBatch batch : batches) { final byte[] rawMetadata = Integer.toString(counter).getBytes(StandardCharsets.UTF_8); final ArrowBuf metadata = allocator.buffer(rawMetadata.length); metadata.writeBytes(rawMetadata); loader.load(batch); // Transfers ownership of the buffer - do not free buffer ourselves listener.putNext(metadata); counter++; } listener.completed(); } } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy