org.apache.druid.query.monomorphicprocessing.SpecializationService Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of druid-processing Show documentation
Show all versions of druid-processing Show documentation
A module that is everything required to understands Druid Segments
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.monomorphicprocessing;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.ByteStreams;
import org.apache.druid.java.util.common.DefineClassUtils;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.logger.Logger;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.commons.ClassRemapper;
import org.objectweb.asm.commons.SimpleRemapper;
import javax.annotation.Nullable;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
/**
* Manages class specialization during query processing.
* Usage:
*
* String runtimeShape = stringRuntimeShape.of(bufferAggregator);
* SpecializationState specializationState = SpecializationService.getSpecializationState(
* ProcessingAlgorithmImpl.class,
* runtimeShape
* );
* ProcessingAlgorithm algorithm = specializationState.getSpecializedOrDefault(new ProcessingAlgorithmImpl());
* long loopIterations = new ProcessingAlgorithmImpl().run(bufferAggregator, ...);
* specializationState.accountLoopIterations(loopIterations);
*
* ProcessingAlgorithmImpl.class, passed as prototypeClass to {@link #getSpecializationState} methods must have public
* no-arg constructor and must be stateless (no fields).
*
* @see SpecializationState
*/
public final class SpecializationService
{
private static final Logger LOG = new Logger(SpecializationService.class);
/**
* If true, specialization is not actually done, an instance of prototypeClass is used as a "specialized" instance.
* Useful for analysis of generated assembly with JITWatch (https://github.com/AdoptOpenJDK/jitwatch), because
* JITWatch shows only classes present in the loaded JAR (prototypeClass should be), not classes generated during
* runtime.
*/
private static final boolean FAKE_SPECIALIZE = Boolean.getBoolean("fakeSpecialize");
/**
* Number of loop iterations, accounted via {@link SpecializationState#accountLoopIterations(long)} in
* {@link WindowedLoopIterationCounter} during the last hour window, after which WindowedLoopIterationCounter decides
* to specialize class for the specific runtimeShape. The default value is chosen to be so that the specialized
* class will likely be compiled with C2 HotSpot compiler with the default values of *BackEdgeThreshold options.
*/
private static final int TRIGGER_SPECIALIZATION_ITERATIONS_THRESHOLD =
Integer.getInteger("triggerSpecializationIterationsThreshold", 10_000);
/**
* The maximum number of specializations, that this service is allowed to make. It's not unlimited because each
* specialization takes some JVM memory (machine code cache, byte code, etc.)
*/
private static final int MAX_SPECIALIZATIONS = Integer.getInteger("maxSpecializations", 1000);
private static final AtomicBoolean MAX_SPECIALIZATIONS_WARNING_EMITTED = new AtomicBoolean(false);
private static final ExecutorService CLASS_SPECIALIZATION_EXECUTOR = Execs.singleThreaded("class-specialization-%d");
private static final AtomicLong SPECIALIZED_CLASS_COUNTER = new AtomicLong();
private static final ClassValue PER_PROTOTYPE_CLASS_STATE =
new ClassValue()
{
@Override
protected PerPrototypeClassState computeValue(Class> type)
{
return new PerPrototypeClassState<>(type);
}
};
/**
* @param type of query processing algorithm
* @see SpecializationService class-level javadoc for details
*/
public static SpecializationState getSpecializationState(
Class extends T> prototypeClass,
String runtimeShape
)
{
return getSpecializationState(prototypeClass, runtimeShape, ImmutableMap.of());
}
/**
* @param classRemapping classes, that should be replaced in the bytecode of the given prototypeClass when specialized
* @see #getSpecializationState(Class, String)
*/
@SuppressWarnings("unchecked")
public static SpecializationState getSpecializationState(
Class extends T> prototypeClass,
String runtimeShape,
ImmutableMap, Class>> classRemapping
)
{
return PER_PROTOTYPE_CLASS_STATE.get(prototypeClass).getSpecializationState(runtimeShape, classRemapping);
}
static class PerPrototypeClassState
{
private final Class prototypeClass;
private final ConcurrentHashMap> specializationStates =
new ConcurrentHashMap<>();
private final String prototypeClassBytecodeName;
private final String specializedClassNamePrefix;
private byte[] prototypeClassBytecode;
PerPrototypeClassState(Class prototypeClass)
{
this.prototypeClass = prototypeClass;
String prototypeClassName = prototypeClass.getName();
prototypeClassBytecodeName = classBytecodeName(prototypeClassName);
specializedClassNamePrefix = prototypeClassName + "$Copy";
}
SpecializationState getSpecializationState(String runtimeShape, ImmutableMap, Class>> classRemapping)
{
SpecializationId specializationId = new SpecializationId(runtimeShape, classRemapping);
// get() before computeIfAbsent() is an optimization to avoid locking in computeIfAbsent() if not needed.
// See https://github.com/apache/druid/pull/6898#discussion_r251384586.
SpecializationState alreadyExistingState = specializationStates.get(specializationId);
if (alreadyExistingState != null) {
return alreadyExistingState;
}
return specializationStates.computeIfAbsent(specializationId, id -> new WindowedLoopIterationCounter<>(this, id));
}
T specialize(ImmutableMap, Class>> classRemapping)
{
String specializedClassName = specializedClassNamePrefix + SPECIALIZED_CLASS_COUNTER.get();
ClassWriter specializedClassWriter = new ClassWriter(0);
SimpleRemapper remapper = new SimpleRemapper(createRemapping(classRemapping, specializedClassName));
ClassVisitor classTransformer = new ClassRemapper(specializedClassWriter, remapper);
try {
ClassReader prototypeClassReader = new ClassReader(getPrototypeClassBytecode());
prototypeClassReader.accept(classTransformer, 0);
byte[] specializedClassBytecode = specializedClassWriter.toByteArray();
@SuppressWarnings("unchecked")
Class specializedClass = (Class) DefineClassUtils.defineClass(
prototypeClass,
specializedClassBytecode,
specializedClassName
);
SPECIALIZED_CLASS_COUNTER.incrementAndGet();
return specializedClass.newInstance();
}
catch (InstantiationException | IllegalAccessException | IOException e) {
throw new RuntimeException(e);
}
}
private HashMap createRemapping(
ImmutableMap, Class>> classRemapping,
String specializedClassName
)
{
HashMap remapping = new HashMap<>();
remapping.put(prototypeClassBytecodeName, classBytecodeName(specializedClassName));
for (Map.Entry, Class>> classRemappingEntry : classRemapping.entrySet()) {
Class> sourceClass = classRemappingEntry.getKey();
Class> remappingClass = classRemappingEntry.getValue();
remapping.put(classBytecodeName(sourceClass.getName()), classBytecodeName(remappingClass.getName()));
}
return remapping;
}
/**
* No synchronization, because {@link #specialize} is called only from {@link #CLASS_SPECIALIZATION_EXECUTOR}, i. e.
* from a single thread.
*/
byte[] getPrototypeClassBytecode() throws IOException
{
if (prototypeClassBytecode == null) {
ClassLoader cl = prototypeClass.getClassLoader();
try (InputStream prototypeClassBytecodeStream =
cl.getResourceAsStream(prototypeClassBytecodeName + ".class")) {
prototypeClassBytecode = ByteStreams.toByteArray(prototypeClassBytecodeStream);
}
}
return prototypeClassBytecode;
}
private static String classBytecodeName(String className)
{
return className.replace('.', '/');
}
}
private static class SpecializationId
{
private final String runtimeShape;
private final ImmutableMap, Class>> classRemapping;
private final int hashCode;
private SpecializationId(String runtimeShape, ImmutableMap, Class>> classRemapping)
{
this.runtimeShape = runtimeShape;
this.classRemapping = classRemapping;
this.hashCode = runtimeShape.hashCode() * 1000003 + classRemapping.hashCode();
}
@Override
public boolean equals(Object obj)
{
if (!(obj instanceof SpecializationId)) {
return false;
}
SpecializationId other = (SpecializationId) obj;
return runtimeShape.equals(other.runtimeShape) && classRemapping.equals(other.classRemapping);
}
@Override
public int hashCode()
{
return hashCode;
}
}
/**
* Accumulates the number of iterations during the last hour. (Window size = 1 hour)
*/
static class WindowedLoopIterationCounter extends SpecializationState implements Runnable
{
private final PerPrototypeClassState perPrototypeClassState;
private final SpecializationId specializationId;
/** A map with the number of iterations per each minute during the last hour */
private final ConcurrentHashMap perMinuteIterations = new ConcurrentHashMap<>();
private final AtomicBoolean specializationScheduled = new AtomicBoolean(false);
WindowedLoopIterationCounter(
PerPrototypeClassState perPrototypeClassState,
SpecializationId specializationId
)
{
this.perPrototypeClassState = perPrototypeClassState;
this.specializationId = specializationId;
}
@Nullable
@Override
public T getSpecialized()
{
// Returns null because the class is not yet specialized. The purpose of WindowedLoopIterationCounter is to decide
// whether specialization should be done, or not.
return null;
}
@Override
public void accountLoopIterations(long loopIterations)
{
if (specializationScheduled.get()) {
return;
}
if (loopIterations > TRIGGER_SPECIALIZATION_ITERATIONS_THRESHOLD ||
addAndGetTotalIterationsOverTheLastHour(loopIterations) > TRIGGER_SPECIALIZATION_ITERATIONS_THRESHOLD) {
if (specializationScheduled.compareAndSet(false, true)) {
CLASS_SPECIALIZATION_EXECUTOR.submit(this);
}
}
}
private long addAndGetTotalIterationsOverTheLastHour(long newIterations)
{
long currentMillis = System.currentTimeMillis();
long currentMinute = TimeUnit.MILLISECONDS.toMinutes(currentMillis);
long minuteOneHourAgo = currentMinute - TimeUnit.HOURS.toMinutes(1);
long totalIterations = 0;
boolean currentMinutePresent = false;
for (Iterator> it = perMinuteIterations.entrySet().iterator(); it.hasNext(); ) {
Map.Entry minuteStats = it.next();
long minute = minuteStats.getKey();
if (minute < minuteOneHourAgo) {
it.remove();
} else if (minute == currentMinute) {
totalIterations += minuteStats.getValue().addAndGet(newIterations);
currentMinutePresent = true;
} else {
totalIterations += minuteStats.getValue().get();
}
}
if (!currentMinutePresent) {
perMinuteIterations.computeIfAbsent(currentMinute, m -> new AtomicLong()).addAndGet(newIterations);
totalIterations += newIterations;
}
return totalIterations;
}
@Override
public void run()
{
try {
T specialized;
if (SPECIALIZED_CLASS_COUNTER.get() > MAX_SPECIALIZATIONS) {
// Don't specialize, just instantiate the prototype class and emit a warning.
// The "better" approach is probably to implement some kind of cache eviction from
// PerPrototypeClassState.specializationStates. But it might be that nobody ever hits even the current
// maxSpecializations limit, so implementing cache eviction is an unnecessary complexity.
specialized = perPrototypeClassState.prototypeClass.newInstance();
if (!MAX_SPECIALIZATIONS_WARNING_EMITTED.get() && MAX_SPECIALIZATIONS_WARNING_EMITTED.compareAndSet(false, true)) {
LOG.warn(
"SpecializationService couldn't make more than [%d] specializations. " +
"Not doing specialization for runtime shape[%s] and class remapping[%s], using the prototype class[%s]",
MAX_SPECIALIZATIONS,
specializationId.runtimeShape,
specializationId.classRemapping,
perPrototypeClassState.prototypeClass
);
}
} else if (FAKE_SPECIALIZE) {
specialized = perPrototypeClassState.prototypeClass.newInstance();
LOG.info(
"Not specializing prototype class[%s] for runtime shape[%s] and class remapping[%s] because "
+ "fakeSpecialize=true, using the prototype class instead",
perPrototypeClassState.prototypeClass,
specializationId.runtimeShape,
specializationId.classRemapping
);
} else {
specialized = perPrototypeClassState.specialize(specializationId.classRemapping);
LOG.info(
"Specializing prototype class[%s] for runtime shape[%s] and class remapping[%s]",
perPrototypeClassState.prototypeClass,
specializationId.runtimeShape,
specializationId.classRemapping
);
}
perPrototypeClassState.specializationStates.put(specializationId, new Specialized<>(specialized));
}
catch (Exception e) {
LOG.error(
e,
"Error specializing prototype class[%s] for runtime shape[%s] and class remapping[%s]",
perPrototypeClassState.prototypeClass,
specializationId.runtimeShape,
specializationId.classRemapping
);
}
}
}
static class Specialized extends SpecializationState
{
private final T specialized;
Specialized(T specialized)
{
this.specialized = specialized;
}
@Override
public T getSpecialized()
{
return specialized;
}
@Override
public void accountLoopIterations(long loopIterations)
{
// do nothing
}
}
private SpecializationService()
{
}
}