All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.druid.query.monomorphicprocessing.SpecializationService Maven / Gradle / Ivy

There is a newer version: 31.0.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.druid.query.monomorphicprocessing;

import com.google.common.collect.ImmutableMap;
import com.google.common.io.ByteStreams;
import org.apache.druid.java.util.common.DefineClassUtils;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.logger.Logger;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.commons.ClassRemapper;
import org.objectweb.asm.commons.SimpleRemapper;

import javax.annotation.Nullable;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Manages class specialization during query processing.
 * Usage:
 *
 * String runtimeShape = stringRuntimeShape.of(bufferAggregator);
 * SpecializationState specializationState = SpecializationService.getSpecializationState(
 *   ProcessingAlgorithmImpl.class,
 *   runtimeShape
 * );
 * ProcessingAlgorithm algorithm = specializationState.getSpecializedOrDefault(new ProcessingAlgorithmImpl());
 * long loopIterations = new ProcessingAlgorithmImpl().run(bufferAggregator, ...);
 * specializationState.accountLoopIterations(loopIterations);
 *
 * ProcessingAlgorithmImpl.class, passed as prototypeClass to {@link #getSpecializationState} methods must have public
 * no-arg constructor and must be stateless (no fields).
 *
 * @see SpecializationState
 */
public final class SpecializationService
{
  private static final Logger LOG = new Logger(SpecializationService.class);

  /**
   * If true, specialization is not actually done, an instance of prototypeClass is used as a "specialized" instance.
   * Useful for analysis of generated assembly with JITWatch (https://github.com/AdoptOpenJDK/jitwatch), because
   * JITWatch shows only classes present in the loaded JAR (prototypeClass should be), not classes generated during
   * runtime.
   */
  private static final boolean FAKE_SPECIALIZE = Boolean.getBoolean("fakeSpecialize");

  /**
   * Number of loop iterations, accounted via {@link SpecializationState#accountLoopIterations(long)} in
   * {@link WindowedLoopIterationCounter} during the last hour window, after which WindowedLoopIterationCounter decides
   * to specialize class for the specific runtimeShape. The default value is chosen to be so that the specialized
   * class will likely be compiled with C2 HotSpot compiler with the default values of *BackEdgeThreshold options.
   */
  private static final int TRIGGER_SPECIALIZATION_ITERATIONS_THRESHOLD =
      Integer.getInteger("triggerSpecializationIterationsThreshold", 10_000);

  /**
   * The maximum number of specializations, that this service is allowed to make. It's not unlimited because each
   * specialization takes some JVM memory (machine code cache, byte code, etc.)
   */
  private static final int MAX_SPECIALIZATIONS = Integer.getInteger("maxSpecializations", 1000);
  private static final AtomicBoolean MAX_SPECIALIZATIONS_WARNING_EMITTED = new AtomicBoolean(false);

  private static final ExecutorService CLASS_SPECIALIZATION_EXECUTOR = Execs.singleThreaded("class-specialization-%d");

  private static final AtomicLong SPECIALIZED_CLASS_COUNTER = new AtomicLong();

  private static final ClassValue PER_PROTOTYPE_CLASS_STATE =
      new ClassValue()
      {
        @Override
        protected PerPrototypeClassState computeValue(Class type)
        {
          return new PerPrototypeClassState<>(type);
        }
      };

  /**
   * @param  type of query processing algorithm
   * @see SpecializationService class-level javadoc for details
   */
  public static  SpecializationState getSpecializationState(
      Class prototypeClass,
      String runtimeShape
  )
  {
    return getSpecializationState(prototypeClass, runtimeShape, ImmutableMap.of());
  }

  /**
   * @param classRemapping classes, that should be replaced in the bytecode of the given prototypeClass when specialized
   * @see #getSpecializationState(Class, String)
   */
  @SuppressWarnings("unchecked")
  public static  SpecializationState getSpecializationState(
      Class prototypeClass,
      String runtimeShape,
      ImmutableMap, Class> classRemapping
  )
  {
    return PER_PROTOTYPE_CLASS_STATE.get(prototypeClass).getSpecializationState(runtimeShape, classRemapping);
  }

  static class PerPrototypeClassState
  {
    private final Class prototypeClass;
    private final ConcurrentHashMap> specializationStates =
        new ConcurrentHashMap<>();
    private final String prototypeClassBytecodeName;
    private final String specializedClassNamePrefix;

    private byte[] prototypeClassBytecode;

    PerPrototypeClassState(Class prototypeClass)
    {
      this.prototypeClass = prototypeClass;
      String prototypeClassName = prototypeClass.getName();
      prototypeClassBytecodeName = classBytecodeName(prototypeClassName);
      specializedClassNamePrefix = prototypeClassName + "$Copy";
    }

    SpecializationState getSpecializationState(String runtimeShape, ImmutableMap, Class> classRemapping)
    {
      SpecializationId specializationId = new SpecializationId(runtimeShape, classRemapping);
      // get() before computeIfAbsent() is an optimization to avoid locking in computeIfAbsent() if not needed.
      // See https://github.com/apache/druid/pull/6898#discussion_r251384586.
      SpecializationState alreadyExistingState = specializationStates.get(specializationId);
      if (alreadyExistingState != null) {
        return alreadyExistingState;
      }
      return specializationStates.computeIfAbsent(specializationId, id -> new WindowedLoopIterationCounter<>(this, id));
    }

    T specialize(ImmutableMap, Class> classRemapping)
    {
      String specializedClassName = specializedClassNamePrefix + SPECIALIZED_CLASS_COUNTER.get();
      ClassWriter specializedClassWriter = new ClassWriter(0);
      SimpleRemapper remapper = new SimpleRemapper(createRemapping(classRemapping, specializedClassName));
      ClassVisitor classTransformer = new ClassRemapper(specializedClassWriter, remapper);
      try {
        ClassReader prototypeClassReader = new ClassReader(getPrototypeClassBytecode());
        prototypeClassReader.accept(classTransformer, 0);
        byte[] specializedClassBytecode = specializedClassWriter.toByteArray();
        @SuppressWarnings("unchecked")
        Class specializedClass = (Class) DefineClassUtils.defineClass(
            prototypeClass,
            specializedClassBytecode,
            specializedClassName
        );
        SPECIALIZED_CLASS_COUNTER.incrementAndGet();
        return specializedClass.newInstance();
      }
      catch (InstantiationException | IllegalAccessException | IOException e) {
        throw new RuntimeException(e);
      }
    }

    private HashMap createRemapping(
        ImmutableMap, Class> classRemapping,
        String specializedClassName
    )
    {
      HashMap remapping = new HashMap<>();
      remapping.put(prototypeClassBytecodeName, classBytecodeName(specializedClassName));
      for (Map.Entry, Class> classRemappingEntry : classRemapping.entrySet()) {
        Class sourceClass = classRemappingEntry.getKey();
        Class remappingClass = classRemappingEntry.getValue();
        remapping.put(classBytecodeName(sourceClass.getName()), classBytecodeName(remappingClass.getName()));
      }
      return remapping;
    }

    /**
     * No synchronization, because {@link #specialize} is called only from {@link #CLASS_SPECIALIZATION_EXECUTOR}, i. e.
     * from a single thread.
     */
    byte[] getPrototypeClassBytecode() throws IOException
    {
      if (prototypeClassBytecode == null) {
        ClassLoader cl = prototypeClass.getClassLoader();
        try (InputStream prototypeClassBytecodeStream =
                 cl.getResourceAsStream(prototypeClassBytecodeName + ".class")) {
          prototypeClassBytecode = ByteStreams.toByteArray(prototypeClassBytecodeStream);
        }
      }
      return prototypeClassBytecode;
    }

    private static String classBytecodeName(String className)
    {
      return className.replace('.', '/');
    }
  }

  private static class SpecializationId
  {
    private final String runtimeShape;
    private final ImmutableMap, Class> classRemapping;
    private final int hashCode;

    private SpecializationId(String runtimeShape, ImmutableMap, Class> classRemapping)
    {
      this.runtimeShape = runtimeShape;
      this.classRemapping = classRemapping;
      this.hashCode = runtimeShape.hashCode() * 1000003 + classRemapping.hashCode();
    }

    @Override
    public boolean equals(Object obj)
    {
      if (!(obj instanceof SpecializationId)) {
        return false;
      }
      SpecializationId other = (SpecializationId) obj;
      return runtimeShape.equals(other.runtimeShape) && classRemapping.equals(other.classRemapping);
    }

    @Override
    public int hashCode()
    {
      return hashCode;
    }
  }

  /**
   * Accumulates the number of iterations during the last hour. (Window size = 1 hour)
   */
  static class WindowedLoopIterationCounter extends SpecializationState implements Runnable
  {
    private final PerPrototypeClassState perPrototypeClassState;
    private final SpecializationId specializationId;
    /** A map with the number of iterations per each minute during the last hour */
    private final ConcurrentHashMap perMinuteIterations = new ConcurrentHashMap<>();
    private final AtomicBoolean specializationScheduled = new AtomicBoolean(false);

    WindowedLoopIterationCounter(
        PerPrototypeClassState perPrototypeClassState,
        SpecializationId specializationId
    )
    {
      this.perPrototypeClassState = perPrototypeClassState;
      this.specializationId = specializationId;
    }

    @Nullable
    @Override
    public T getSpecialized()
    {
      // Returns null because the class is not yet specialized. The purpose of WindowedLoopIterationCounter is to decide
      // whether specialization should be done, or not.
      return null;
    }

    @Override
    public void accountLoopIterations(long loopIterations)
    {
      if (specializationScheduled.get()) {
        return;
      }
      if (loopIterations > TRIGGER_SPECIALIZATION_ITERATIONS_THRESHOLD ||
          addAndGetTotalIterationsOverTheLastHour(loopIterations) > TRIGGER_SPECIALIZATION_ITERATIONS_THRESHOLD) {
        if (specializationScheduled.compareAndSet(false, true)) {
          CLASS_SPECIALIZATION_EXECUTOR.submit(this);
        }
      }
    }

    private long addAndGetTotalIterationsOverTheLastHour(long newIterations)
    {
      long currentMillis = System.currentTimeMillis();
      long currentMinute = TimeUnit.MILLISECONDS.toMinutes(currentMillis);
      long minuteOneHourAgo = currentMinute - TimeUnit.HOURS.toMinutes(1);
      long totalIterations = 0;
      boolean currentMinutePresent = false;
      for (Iterator> it = perMinuteIterations.entrySet().iterator(); it.hasNext(); ) {
        Map.Entry minuteStats = it.next();
        long minute = minuteStats.getKey();
        if (minute < minuteOneHourAgo) {
          it.remove();
        } else if (minute == currentMinute) {
          totalIterations += minuteStats.getValue().addAndGet(newIterations);
          currentMinutePresent = true;
        } else {
          totalIterations += minuteStats.getValue().get();
        }
      }
      if (!currentMinutePresent) {
        perMinuteIterations.computeIfAbsent(currentMinute, m -> new AtomicLong()).addAndGet(newIterations);
        totalIterations += newIterations;
      }
      return totalIterations;
    }

    @Override
    public void run()
    {
      try {
        T specialized;
        if (SPECIALIZED_CLASS_COUNTER.get() > MAX_SPECIALIZATIONS) {
          // Don't specialize, just instantiate the prototype class and emit a warning.
          // The "better" approach is probably to implement some kind of cache eviction from
          // PerPrototypeClassState.specializationStates. But it might be that nobody ever hits even the current
          // maxSpecializations limit, so implementing cache eviction is an unnecessary complexity.
          specialized = perPrototypeClassState.prototypeClass.newInstance();
          if (!MAX_SPECIALIZATIONS_WARNING_EMITTED.get() && MAX_SPECIALIZATIONS_WARNING_EMITTED.compareAndSet(false, true)) {
            LOG.warn(
                "SpecializationService couldn't make more than [%d] specializations. " +
                "Not doing specialization for runtime shape[%s] and class remapping[%s], using the prototype class[%s]",
                MAX_SPECIALIZATIONS,
                specializationId.runtimeShape,
                specializationId.classRemapping,
                perPrototypeClassState.prototypeClass
            );
          }
        } else if (FAKE_SPECIALIZE) {
          specialized = perPrototypeClassState.prototypeClass.newInstance();
          LOG.info(
              "Not specializing prototype class[%s] for runtime shape[%s] and class remapping[%s] because "
              + "fakeSpecialize=true, using the prototype class instead",
              perPrototypeClassState.prototypeClass,
              specializationId.runtimeShape,
              specializationId.classRemapping
          );
        } else {
          specialized = perPrototypeClassState.specialize(specializationId.classRemapping);
          LOG.info(
              "Specializing prototype class[%s] for runtime shape[%s] and class remapping[%s]",
              perPrototypeClassState.prototypeClass,
              specializationId.runtimeShape,
              specializationId.classRemapping
          );
        }
        perPrototypeClassState.specializationStates.put(specializationId, new Specialized<>(specialized));
      }
      catch (Exception e) {
        LOG.error(
            e,
            "Error specializing prototype class[%s] for runtime shape[%s] and class remapping[%s]",
            perPrototypeClassState.prototypeClass,
            specializationId.runtimeShape,
            specializationId.classRemapping
        );
      }
    }
  }

  static class Specialized extends SpecializationState
  {
    private final T specialized;

    Specialized(T specialized)
    {
      this.specialized = specialized;
    }

    @Override
    public T getSpecialized()
    {
      return specialized;
    }

    @Override
    public void accountLoopIterations(long loopIterations)
    {
      // do nothing
    }
  }

  private SpecializationService()
  {
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy