// org.apache.flink.python.util.ProtoUtils (artifact-listing header from the Maven / Gradle / Ivy page this file was copied from; not part of the original source)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.python.util;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.fnexecution.v1.FlinkFnApi;
import org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
import org.apache.flink.streaming.api.utils.PythonTypeUtils;
import org.apache.flink.table.functions.python.PythonAggregateFunctionInfo;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.functions.python.PythonFunctionKind;
import org.apache.flink.table.runtime.dataview.DataViewSpec;
import org.apache.flink.table.runtime.dataview.ListViewSpec;
import org.apache.flink.table.runtime.dataview.MapViewSpec;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.util.Preconditions;
import com.google.protobuf.ByteString;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static org.apache.flink.python.Constants.FLINK_CODER_URN;
import static org.apache.flink.table.runtime.typeutils.PythonTypeUtils.toProtoType;
/** Utilities used to construct protobuf objects or construct objects from protobuf objects. */
@Internal
public enum ProtoUtils {
;
public static RunnerApi.Coder createCoderProto(
FlinkFnApi.CoderInfoDescriptor coderInfoDescriptor) {
return RunnerApi.Coder.newBuilder()
.setSpec(
RunnerApi.FunctionSpec.newBuilder()
.setUrn(FLINK_CODER_URN)
.setPayload(
org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf
.ByteString.copyFrom(
coderInfoDescriptor.toByteArray()))
.build())
.build();
}
// ------------------------------------------------------------------------
// Table API related utilities
// ------------------------------------------------------------------------
// coder utilities
public static FlinkFnApi.CoderInfoDescriptor createRowTypeCoderInfoDescriptorProto(
RowType rowType,
FlinkFnApi.CoderInfoDescriptor.Mode mode,
boolean separatedWithEndMessage) {
return createCoderInfoDescriptorProto(
null,
FlinkFnApi.CoderInfoDescriptor.RowType.newBuilder()
.setSchema(toProtoType(rowType).getRowSchema())
.build(),
null,
null,
null,
mode,
separatedWithEndMessage);
}
public static FlinkFnApi.CoderInfoDescriptor createFlattenRowTypeCoderInfoDescriptorProto(
RowType rowType,
FlinkFnApi.CoderInfoDescriptor.Mode mode,
boolean separatedWithEndMessage) {
FlinkFnApi.CoderInfoDescriptor.FlattenRowType flattenRowType =
FlinkFnApi.CoderInfoDescriptor.FlattenRowType.newBuilder()
.setSchema(toProtoType(rowType).getRowSchema())
.build();
return createCoderInfoDescriptorProto(
flattenRowType, null, null, null, null, mode, separatedWithEndMessage);
}
public static FlinkFnApi.CoderInfoDescriptor createArrowTypeCoderInfoDescriptorProto(
RowType rowType,
FlinkFnApi.CoderInfoDescriptor.Mode mode,
boolean separatedWithEndMessage) {
return createCoderInfoDescriptorProto(
null,
null,
FlinkFnApi.CoderInfoDescriptor.ArrowType.newBuilder()
.setSchema(toProtoType(rowType).getRowSchema())
.build(),
null,
null,
mode,
separatedWithEndMessage);
}
public static FlinkFnApi.CoderInfoDescriptor createOverWindowArrowTypeCoderInfoDescriptorProto(
RowType rowType,
FlinkFnApi.CoderInfoDescriptor.Mode mode,
boolean separatedWithEndMessage) {
return createCoderInfoDescriptorProto(
null,
null,
null,
FlinkFnApi.CoderInfoDescriptor.OverWindowArrowType.newBuilder()
.setSchema(toProtoType(rowType).getRowSchema())
.build(),
null,
mode,
separatedWithEndMessage);
}
// function utilities
public static FlinkFnApi.UserDefinedFunctions createUserDefinedFunctionsProto(
RuntimeContext runtimeContext,
PythonFunctionInfo[] userDefinedFunctions,
boolean isMetricEnabled,
boolean isProfileEnabled) {
FlinkFnApi.UserDefinedFunctions.Builder builder =
FlinkFnApi.UserDefinedFunctions.newBuilder();
for (PythonFunctionInfo userDefinedFunction : userDefinedFunctions) {
builder.addUdfs(createUserDefinedFunctionProto(userDefinedFunction));
}
builder.setMetricEnabled(isMetricEnabled);
builder.setProfileEnabled(isProfileEnabled);
builder.addAllJobParameters(
runtimeContext.getGlobalJobParameters().entrySet().stream()
.map(
entry ->
FlinkFnApi.JobParameter.newBuilder()
.setKey(entry.getKey())
.setValue(entry.getValue())
.build())
.collect(Collectors.toList()));
return builder.build();
}
public static FlinkFnApi.UserDefinedFunction createUserDefinedFunctionProto(
PythonFunctionInfo pythonFunctionInfo) {
FlinkFnApi.UserDefinedFunction.Builder builder =
FlinkFnApi.UserDefinedFunction.newBuilder();
builder.setPayload(
ByteString.copyFrom(
pythonFunctionInfo.getPythonFunction().getSerializedPythonFunction()));
for (Object input : pythonFunctionInfo.getInputs()) {
FlinkFnApi.Input.Builder inputProto = FlinkFnApi.Input.newBuilder();
if (input instanceof PythonFunctionInfo) {
inputProto.setUdf(createUserDefinedFunctionProto((PythonFunctionInfo) input));
} else if (input instanceof Integer) {
inputProto.setInputOffset((Integer) input);
} else {
inputProto.setInputConstant(ByteString.copyFrom((byte[]) input));
}
builder.addInputs(inputProto);
}
builder.setTakesRowAsInput(pythonFunctionInfo.getPythonFunction().takesRowAsInput());
builder.setIsPandasUdf(
pythonFunctionInfo.getPythonFunction().getPythonFunctionKind()
== PythonFunctionKind.PANDAS);
return builder.build();
}
public static FlinkFnApi.UserDefinedAggregateFunction createUserDefinedAggregateFunctionProto(
PythonAggregateFunctionInfo pythonFunctionInfo, DataViewSpec[] dataViewSpecs) {
FlinkFnApi.UserDefinedAggregateFunction.Builder builder =
FlinkFnApi.UserDefinedAggregateFunction.newBuilder();
builder.setPayload(
ByteString.copyFrom(
pythonFunctionInfo.getPythonFunction().getSerializedPythonFunction()));
builder.setDistinct(pythonFunctionInfo.isDistinct());
builder.setFilterArg(pythonFunctionInfo.getFilterArg());
builder.setTakesRowAsInput(pythonFunctionInfo.getPythonFunction().takesRowAsInput());
for (Object input : pythonFunctionInfo.getInputs()) {
FlinkFnApi.Input.Builder inputProto = FlinkFnApi.Input.newBuilder();
if (input instanceof Integer) {
inputProto.setInputOffset((Integer) input);
} else {
inputProto.setInputConstant(ByteString.copyFrom((byte[]) input));
}
builder.addInputs(inputProto);
}
if (dataViewSpecs != null) {
for (DataViewSpec spec : dataViewSpecs) {
FlinkFnApi.UserDefinedAggregateFunction.DataViewSpec.Builder specBuilder =
FlinkFnApi.UserDefinedAggregateFunction.DataViewSpec.newBuilder();
specBuilder.setName(spec.getStateId());
if (spec instanceof ListViewSpec) {
ListViewSpec listViewSpec = (ListViewSpec) spec;
specBuilder.setListView(
FlinkFnApi.UserDefinedAggregateFunction.DataViewSpec.ListView
.newBuilder()
.setElementType(
toProtoType(
listViewSpec
.getElementDataType()
.getLogicalType())));
} else {
MapViewSpec mapViewSpec = (MapViewSpec) spec;
FlinkFnApi.UserDefinedAggregateFunction.DataViewSpec.MapView.Builder
mapViewBuilder =
FlinkFnApi.UserDefinedAggregateFunction.DataViewSpec.MapView
.newBuilder();
mapViewBuilder.setKeyType(
toProtoType(mapViewSpec.getKeyDataType().getLogicalType()));
mapViewBuilder.setValueType(
toProtoType(mapViewSpec.getValueDataType().getLogicalType()));
specBuilder.setMapView(mapViewBuilder.build());
}
specBuilder.setFieldIndex(spec.getFieldIndex());
builder.addSpecs(specBuilder.build());
}
}
return builder.build();
}
// ------------------------------------------------------------------------
// DataStream API related utilities
// ------------------------------------------------------------------------
public static FlinkFnApi.UserDefinedDataStreamFunction createUserDefinedDataStreamFunctionProto(
DataStreamPythonFunctionInfo dataStreamPythonFunctionInfo,
RuntimeContext runtimeContext,
Map internalParameters,
boolean inBatchExecutionMode,
boolean isMetricEnabled,
boolean isProfileEnabled,
boolean hasSideOutput,
int stateCacheSize,
int mapStateReadCacheSize,
int mapStateWriteCacheSize) {
FlinkFnApi.UserDefinedDataStreamFunction.Builder builder =
FlinkFnApi.UserDefinedDataStreamFunction.newBuilder();
builder.setFunctionType(
FlinkFnApi.UserDefinedDataStreamFunction.FunctionType.forNumber(
dataStreamPythonFunctionInfo.getFunctionType()));
builder.setRuntimeContext(
FlinkFnApi.UserDefinedDataStreamFunction.RuntimeContext.newBuilder()
.setTaskName(runtimeContext.getTaskInfo().getTaskName())
.setTaskNameWithSubtasks(
runtimeContext.getTaskInfo().getTaskNameWithSubtasks())
.setNumberOfParallelSubtasks(
runtimeContext.getTaskInfo().getNumberOfParallelSubtasks())
.setMaxNumberOfParallelSubtasks(
runtimeContext.getTaskInfo().getMaxNumberOfParallelSubtasks())
.setIndexOfThisSubtask(runtimeContext.getTaskInfo().getIndexOfThisSubtask())
.setAttemptNumber(runtimeContext.getTaskInfo().getAttemptNumber())
.addAllJobParameters(
runtimeContext.getGlobalJobParameters().entrySet().stream()
.map(
entry ->
FlinkFnApi.JobParameter.newBuilder()
.setKey(entry.getKey())
.setValue(entry.getValue())
.build())
.collect(Collectors.toList()))
.addAllJobParameters(
internalParameters.entrySet().stream()
.map(
entry ->
FlinkFnApi.JobParameter.newBuilder()
.setKey(entry.getKey())
.setValue(entry.getValue())
.build())
.collect(Collectors.toList()))
.setInBatchExecutionMode(inBatchExecutionMode)
.build());
builder.setPayload(
ByteString.copyFrom(
dataStreamPythonFunctionInfo
.getPythonFunction()
.getSerializedPythonFunction()));
builder.setMetricEnabled(isMetricEnabled);
builder.setProfileEnabled(isProfileEnabled);
builder.setHasSideOutput(hasSideOutput);
builder.setStateCacheSize(stateCacheSize);
builder.setMapStateReadCacheSize(mapStateReadCacheSize);
builder.setMapStateWriteCacheSize(mapStateWriteCacheSize);
return builder.build();
}
public static FlinkFnApi.UserDefinedDataStreamFunction
createReviseOutputDataStreamFunctionProto() {
return FlinkFnApi.UserDefinedDataStreamFunction.newBuilder()
.setFunctionType(
FlinkFnApi.UserDefinedDataStreamFunction.FunctionType.REVISE_OUTPUT)
.build();
}
public static List
createUserDefinedDataStreamFunctionProtos(
DataStreamPythonFunctionInfo dataStreamPythonFunctionInfo,
RuntimeContext runtimeContext,
Map internalParameters,
boolean inBatchExecutionMode,
boolean isMetricEnabled,
boolean isProfileEnabled,
boolean hasSideOutput,
int stateCacheSize,
int mapStateReadCacheSize,
int mapStateWriteCacheSize) {
List results = new ArrayList<>();
Object[] inputs = dataStreamPythonFunctionInfo.getInputs();
if (inputs != null && inputs.length > 0) {
Preconditions.checkArgument(inputs.length == 1);
results.addAll(
createUserDefinedDataStreamFunctionProtos(
(DataStreamPythonFunctionInfo) inputs[0],
runtimeContext,
internalParameters,
inBatchExecutionMode,
isMetricEnabled,
isProfileEnabled,
false,
stateCacheSize,
mapStateReadCacheSize,
mapStateWriteCacheSize));
}
results.add(
createUserDefinedDataStreamFunctionProto(
dataStreamPythonFunctionInfo,
runtimeContext,
internalParameters,
inBatchExecutionMode,
isMetricEnabled,
isProfileEnabled,
hasSideOutput,
stateCacheSize,
mapStateReadCacheSize,
mapStateWriteCacheSize));
return results;
}
public static List
createUserDefinedDataStreamStatefulFunctionProtos(
DataStreamPythonFunctionInfo dataStreamPythonFunctionInfo,
RuntimeContext runtimeContext,
Map internalParameters,
TypeInformation> keyTypeInfo,
boolean inBatchExecutionMode,
boolean isMetricEnabled,
boolean isProfileEnabled,
boolean hasSideOutput,
int stateCacheSize,
int mapStateReadCacheSize,
int mapStateWriteCacheSize) {
List results =
createUserDefinedDataStreamFunctionProtos(
dataStreamPythonFunctionInfo,
runtimeContext,
internalParameters,
inBatchExecutionMode,
isMetricEnabled,
isProfileEnabled,
hasSideOutput,
stateCacheSize,
mapStateReadCacheSize,
mapStateWriteCacheSize);
// set the key typeinfo for the head operator
FlinkFnApi.TypeInfo builtKeyTypeInfo =
PythonTypeUtils.TypeInfoToProtoConverter.toTypeInfoProto(
keyTypeInfo, runtimeContext.getUserCodeClassLoader());
results.set(0, results.get(0).toBuilder().setKeyTypeInfo(builtKeyTypeInfo).build());
return results;
}
public static FlinkFnApi.CoderInfoDescriptor createRawTypeCoderInfoDescriptorProto(
TypeInformation> typeInformation,
FlinkFnApi.CoderInfoDescriptor.Mode mode,
boolean separatedWithEndMessage,
ClassLoader userCodeClassLoader) {
FlinkFnApi.TypeInfo typeinfo =
PythonTypeUtils.TypeInfoToProtoConverter.toTypeInfoProto(
typeInformation, userCodeClassLoader);
return createCoderInfoDescriptorProto(
null,
null,
null,
null,
FlinkFnApi.CoderInfoDescriptor.RawType.newBuilder().setTypeInfo(typeinfo).build(),
mode,
separatedWithEndMessage);
}
private static FlinkFnApi.CoderInfoDescriptor createCoderInfoDescriptorProto(
FlinkFnApi.CoderInfoDescriptor.FlattenRowType flattenRowType,
FlinkFnApi.CoderInfoDescriptor.RowType rowType,
FlinkFnApi.CoderInfoDescriptor.ArrowType arrowType,
FlinkFnApi.CoderInfoDescriptor.OverWindowArrowType overWindowArrowType,
FlinkFnApi.CoderInfoDescriptor.RawType rawType,
FlinkFnApi.CoderInfoDescriptor.Mode mode,
boolean separatedWithEndMessage) {
FlinkFnApi.CoderInfoDescriptor.Builder builder =
FlinkFnApi.CoderInfoDescriptor.newBuilder();
if (flattenRowType != null) {
builder.setFlattenRowType(flattenRowType);
} else if (rowType != null) {
builder.setRowType(rowType);
} else if (arrowType != null) {
builder.setArrowType(arrowType);
} else if (overWindowArrowType != null) {
builder.setOverWindowArrowType(overWindowArrowType);
} else if (rawType != null) {
builder.setRawType(rawType);
}
builder.setMode(mode);
builder.setSeparatedWithEndMessage(separatedWithEndMessage);
return builder.build();
}
// ------------------------------------------------------------------------
// State related utilities
// ------------------------------------------------------------------------
public static StateTtlConfig parseStateTtlConfigFromProto(
FlinkFnApi.StateDescriptor.StateTTLConfig stateTTLConfigProto) {
StateTtlConfig.Builder builder =
StateTtlConfig.newBuilder(Time.milliseconds(stateTTLConfigProto.getTtl()))
.setUpdateType(
parseUpdateTypeFromProto(stateTTLConfigProto.getUpdateType()))
.setStateVisibility(
parseStateVisibilityFromProto(
stateTTLConfigProto.getStateVisibility()))
.setTtlTimeCharacteristic(
parseTtlTimeCharacteristicFromProto(
stateTTLConfigProto.getTtlTimeCharacteristic()));
FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies cleanupStrategiesProto =
stateTTLConfigProto.getCleanupStrategies();
if (!cleanupStrategiesProto.getIsCleanupInBackground()) {
builder.disableCleanupInBackground();
}
for (FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies.MapStrategiesEntry
mapStrategiesEntry : cleanupStrategiesProto.getStrategiesList()) {
FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies.Strategies strategyProto =
mapStrategiesEntry.getStrategy();
if (strategyProto
== FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies.Strategies
.FULL_STATE_SCAN_SNAPSHOT) {
builder.cleanupFullSnapshot();
} else if (strategyProto
== FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies.Strategies
.INCREMENTAL_CLEANUP) {
FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies
.IncrementalCleanupStrategy
incrementalCleanupStrategyProto =
mapStrategiesEntry.getIncrementalCleanupStrategy();
builder.cleanupIncrementally(
incrementalCleanupStrategyProto.getCleanupSize(),
incrementalCleanupStrategyProto.getRunCleanupForEveryRecord());
} else if (strategyProto
== FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies.Strategies
.ROCKSDB_COMPACTION_FILTER) {
FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies
.RocksdbCompactFilterCleanupStrategy
rocksdbCompactFilterCleanupStrategyProto =
mapStrategiesEntry.getRocksdbCompactFilterCleanupStrategy();
builder.cleanupInRocksdbCompactFilter(
rocksdbCompactFilterCleanupStrategyProto.getQueryTimeAfterNumEntries());
}
}
return builder.build();
}
private static StateTtlConfig.UpdateType parseUpdateTypeFromProto(
FlinkFnApi.StateDescriptor.StateTTLConfig.UpdateType updateType) {
if (updateType == FlinkFnApi.StateDescriptor.StateTTLConfig.UpdateType.Disabled) {
return StateTtlConfig.UpdateType.Disabled;
} else if (updateType
== FlinkFnApi.StateDescriptor.StateTTLConfig.UpdateType.OnCreateAndWrite) {
return StateTtlConfig.UpdateType.OnCreateAndWrite;
} else if (updateType
== FlinkFnApi.StateDescriptor.StateTTLConfig.UpdateType.OnReadAndWrite) {
return StateTtlConfig.UpdateType.OnReadAndWrite;
}
throw new RuntimeException("Unknown UpdateType " + updateType);
}
private static StateTtlConfig.StateVisibility parseStateVisibilityFromProto(
FlinkFnApi.StateDescriptor.StateTTLConfig.StateVisibility stateVisibility) {
if (stateVisibility
== FlinkFnApi.StateDescriptor.StateTTLConfig.StateVisibility
.ReturnExpiredIfNotCleanedUp) {
return StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp;
} else if (stateVisibility
== FlinkFnApi.StateDescriptor.StateTTLConfig.StateVisibility.NeverReturnExpired) {
return StateTtlConfig.StateVisibility.NeverReturnExpired;
}
throw new RuntimeException("Unknown StateVisibility " + stateVisibility);
}
private static StateTtlConfig.TtlTimeCharacteristic parseTtlTimeCharacteristicFromProto(
FlinkFnApi.StateDescriptor.StateTTLConfig.TtlTimeCharacteristic ttlTimeCharacteristic) {
if (ttlTimeCharacteristic
== FlinkFnApi.StateDescriptor.StateTTLConfig.TtlTimeCharacteristic.ProcessingTime) {
return StateTtlConfig.TtlTimeCharacteristic.ProcessingTime;
}
throw new RuntimeException("Unknown TtlTimeCharacteristic " + ttlTimeCharacteristic);
}
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy (scraped page footer; not part of the original source)