// Source listing for org.apache.iceberg.spark.source.SparkBatchQueryScan (Maven / Gradle / Ivy).
// Artifact: iceberg-spark-3.5_2.13, newest version. Apache Iceberg: "A table format for huge analytic datasets".
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.spark.source;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.iceberg.PartitionField;
import org.apache.iceberg.PartitionScanTask;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Scan;
import org.apache.iceberg.ScanTask;
import org.apache.iceberg.ScanTaskGroup;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.Table;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.expressions.Binder;
import org.apache.iceberg.expressions.Evaluator;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.expressions.ExpressionUtil;
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.expressions.Projections;
import org.apache.iceberg.metrics.ScanReport;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.relocated.com.google.common.collect.Sets;
import org.apache.iceberg.spark.Spark3Util;
import org.apache.iceberg.spark.SparkReadConf;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.spark.SparkV2Filters;
import org.apache.iceberg.util.SnapshotUtil;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.connector.read.Statistics;
import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
class SparkBatchQueryScan extends SparkPartitioningAwareScan
implements SupportsRuntimeV2Filtering {
private static final Logger LOG = LoggerFactory.getLogger(SparkBatchQueryScan.class);
private final Long snapshotId;
private final Long startSnapshotId;
private final Long endSnapshotId;
private final Long asOfTimestamp;
private final String tag;
private final List runtimeFilterExpressions;
SparkBatchQueryScan(
SparkSession spark,
Table table,
Scan, ? extends ScanTask, ? extends ScanTaskGroup>> scan,
SparkReadConf readConf,
Schema expectedSchema,
List filters,
Supplier scanReportSupplier) {
super(spark, table, scan, readConf, expectedSchema, filters, scanReportSupplier);
this.snapshotId = readConf.snapshotId();
this.startSnapshotId = readConf.startSnapshotId();
this.endSnapshotId = readConf.endSnapshotId();
this.asOfTimestamp = readConf.asOfTimestamp();
this.tag = readConf.tag();
this.runtimeFilterExpressions = Lists.newArrayList();
}
Long snapshotId() {
return snapshotId;
}
@Override
protected Class taskJavaClass() {
return PartitionScanTask.class;
}
@Override
public NamedReference[] filterAttributes() {
Set partitionFieldSourceIds = Sets.newHashSet();
for (PartitionSpec spec : specs()) {
for (PartitionField field : spec.fields()) {
partitionFieldSourceIds.add(field.sourceId());
}
}
Map quotedNameById = SparkSchemaUtil.indexQuotedNameById(expectedSchema());
// the optimizer will look for an equality condition with filter attributes in a join
// as the scan has been already planned, filtering can only be done on projected attributes
// that's why only partition source fields that are part of the read schema can be reported
return partitionFieldSourceIds.stream()
.filter(fieldId -> expectedSchema().findField(fieldId) != null)
.map(fieldId -> Spark3Util.toNamedReference(quotedNameById.get(fieldId)))
.toArray(NamedReference[]::new);
}
@Override
public void filter(Predicate[] predicates) {
Expression runtimeFilterExpr = convertRuntimeFilters(predicates);
if (runtimeFilterExpr != Expressions.alwaysTrue()) {
Map evaluatorsBySpecId = Maps.newHashMap();
for (PartitionSpec spec : specs()) {
Expression inclusiveExpr =
Projections.inclusive(spec, caseSensitive()).project(runtimeFilterExpr);
Evaluator inclusive = new Evaluator(spec.partitionType(), inclusiveExpr);
evaluatorsBySpecId.put(spec.specId(), inclusive);
}
List filteredTasks =
tasks().stream()
.filter(
task -> {
Evaluator evaluator = evaluatorsBySpecId.get(task.spec().specId());
return evaluator.eval(task.partition());
})
.collect(Collectors.toList());
LOG.info(
"{} of {} task(s) for table {} matched runtime filter {}",
filteredTasks.size(),
tasks().size(),
table().name(),
ExpressionUtil.toSanitizedString(runtimeFilterExpr));
// don't invalidate tasks if the runtime filter had no effect to avoid planning splits again
if (filteredTasks.size() < tasks().size()) {
resetTasks(filteredTasks);
}
// save the evaluated filter for equals/hashCode
runtimeFilterExpressions.add(runtimeFilterExpr);
}
}
// at this moment, Spark can only pass IN filters for a single attribute
// if there are multiple filter attributes, Spark will pass two separate IN filters
private Expression convertRuntimeFilters(Predicate[] predicates) {
Expression runtimeFilterExpr = Expressions.alwaysTrue();
for (Predicate predicate : predicates) {
Expression expr = SparkV2Filters.convert(predicate);
if (expr != null) {
try {
Binder.bind(expectedSchema().asStruct(), expr, caseSensitive());
runtimeFilterExpr = Expressions.and(runtimeFilterExpr, expr);
} catch (ValidationException e) {
LOG.warn("Failed to bind {} to expected schema, skipping runtime filter", expr, e);
}
} else {
LOG.warn("Unsupported runtime filter {}", predicate);
}
}
return runtimeFilterExpr;
}
@Override
public Statistics estimateStatistics() {
if (scan() == null) {
return estimateStatistics(null);
} else if (snapshotId != null) {
Snapshot snapshot = table().snapshot(snapshotId);
return estimateStatistics(snapshot);
} else if (asOfTimestamp != null) {
long snapshotIdAsOfTime = SnapshotUtil.snapshotIdAsOfTime(table(), asOfTimestamp);
Snapshot snapshot = table().snapshot(snapshotIdAsOfTime);
return estimateStatistics(snapshot);
} else if (branch() != null) {
Snapshot snapshot = table().snapshot(branch());
return estimateStatistics(snapshot);
} else if (tag != null) {
Snapshot snapshot = table().snapshot(tag);
return estimateStatistics(snapshot);
} else {
Snapshot snapshot = table().currentSnapshot();
return estimateStatistics(snapshot);
}
}
@Override
@SuppressWarnings("checkstyle:CyclomaticComplexity")
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
SparkBatchQueryScan that = (SparkBatchQueryScan) o;
return table().name().equals(that.table().name())
&& Objects.equals(branch(), that.branch())
&& readSchema().equals(that.readSchema()) // compare Spark schemas to ignore field ids
&& filterExpressions().toString().equals(that.filterExpressions().toString())
&& runtimeFilterExpressions.toString().equals(that.runtimeFilterExpressions.toString())
&& Objects.equals(snapshotId, that.snapshotId)
&& Objects.equals(startSnapshotId, that.startSnapshotId)
&& Objects.equals(endSnapshotId, that.endSnapshotId)
&& Objects.equals(asOfTimestamp, that.asOfTimestamp)
&& Objects.equals(tag, that.tag);
}
@Override
public int hashCode() {
return Objects.hash(
table().name(),
branch(),
readSchema(),
filterExpressions().toString(),
runtimeFilterExpressions.toString(),
snapshotId,
startSnapshotId,
endSnapshotId,
asOfTimestamp,
tag);
}
@Override
public String toString() {
return String.format(
"IcebergScan(table=%s, branch=%s, type=%s, filters=%s, runtimeFilters=%s, caseSensitive=%s)",
table(),
branch(),
expectedSchema().asStruct(),
filterExpressions(),
runtimeFilterExpressions,
caseSensitive());
}
}