Please wait. This can take a few minutes...
Downloading a project requires significant server resources. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
io.trino.verifier.Verifier Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.verifier;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableSet;
import io.airlift.event.client.EventClient;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.trino.spi.ErrorCode;
import io.trino.spi.TrinoException;
import javax.annotation.Nullable;
import java.io.Closeable;
import java.io.IOException;
import java.io.PrintStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.Set;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import static com.google.common.base.Throwables.getStackTraceAsString;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Streams.concat;
import static io.trino.spi.StandardErrorCode.PAGE_TRANSPORT_TIMEOUT;
import static io.trino.spi.StandardErrorCode.REMOTE_TASK_MISMATCH;
import static io.trino.spi.StandardErrorCode.TOO_MANY_REQUESTS_FAILED;
import static io.trino.verifier.QueryResult.State.SUCCESS;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.nio.file.Files.createDirectories;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newFixedThreadPool;
import static java.util.concurrent.TimeUnit.SECONDS;
public class Verifier
{
private static final Logger log = Logger.get(Verifier.class);
private static final Set EXPECTED_ERRORS = ImmutableSet.builder()
.add(REMOTE_TASK_MISMATCH.toErrorCode())
.add(TOO_MANY_REQUESTS_FAILED.toErrorCode())
.add(PAGE_TRANSPORT_TIMEOUT.toErrorCode())
.build();
@Deprecated // TODO do not keep mutable config instance on a field
private final VerifierConfig config;
private final Set eventClients;
private final int threadCount;
private final Set allowedQueries;
private final Set bannedQueries;
private final int precision;
public Verifier(PrintStream out, VerifierConfig config, Set eventClients)
{
requireNonNull(out, "out is null");
this.config = requireNonNull(config, "config is null");
this.eventClients = requireNonNull(eventClients, "eventClients is null");
this.allowedQueries = requireNonNull(config.getAllowedQueries(), "allowedQueries is null");
this.bannedQueries = requireNonNull(config.getBannedQueries(), "bannedQueries is null");
this.threadCount = config.getThreadCount();
this.precision = config.getDoublePrecision();
}
// Returns number of failed queries
public int run(List queries)
throws InterruptedException
{
ExecutorService executor = newFixedThreadPool(threadCount);
CompletionService completionService = new ExecutorCompletionService<>(executor);
int totalQueries = queries.size() * config.getSuiteRepetitions() * config.getQueryRepetitions();
log.info("Total Queries: %d", totalQueries);
log.info("Allowed Queries: %s", Joiner.on(',').join(allowedQueries));
int queriesSubmitted = 0;
for (int i = 0; i < config.getSuiteRepetitions(); i++) {
for (QueryPair query : queries) {
for (int j = 0; j < config.getQueryRepetitions(); j++) {
// If we have allowed queries, only run the tests on those
if (!allowedQueries.isEmpty() && !allowedQueries.contains(query.getName())) {
log.debug("Query %s is not allowed", query.getName());
continue;
}
if (bannedQueries.contains(query.getName())) {
log.debug("Query %s is banned", query.getName());
continue;
}
Validator validator = new Validator(
config.getControlGateway(),
config.getTestGateway(),
config.getControlTimeout(),
config.getTestTimeout(),
config.getMaxRowCount(),
config.isExplainOnly(),
config.getDoublePrecision(),
isCheckCorrectness(query),
config.isCheckDeterminismEnabled(),
config.isVerboseResultsComparison(),
config.getControlTeardownRetries(),
config.getTestTeardownRetries(),
config.getRunTearDownOnResultMismatch(),
config.isSkipControl(),
query);
completionService.submit(validator::valid, validator);
queriesSubmitted++;
}
}
}
log.info("Allowed Queries: %d", queriesSubmitted);
log.info("Skipped Queries: %d", (totalQueries - queriesSubmitted));
log.info("---------------------");
executor.shutdown();
int total = 0;
int valid = 0;
int failed = 0;
int skipped = 0;
double lastProgress = 0;
while (total < queriesSubmitted) {
total++;
Validator validator = takeUnchecked(completionService);
if (validator.isSkipped()) {
if (!config.isQuiet()) {
log.warn("%s", validator.getSkippedMessage());
}
skipped++;
continue;
}
QueryResult controlResult = validator.getControlResult();
if (config.isSimplifiedControlQueriesGenerationEnabled() && controlResult.getState() == SUCCESS) {
QueryPair queryPair = validator.getQueryPair();
Path path = Paths.get(format(
"%s/%s/%s/%s.sql",
config.getSimplifiedControlQueriesOutputDirectory(),
config.getRunId(),
queryPair.getSuite(),
queryPair.getName()));
try {
String content = generateCorrespondingSelect(controlResult.getColumnTypes(), controlResult.getResults());
createDirectories(path.getParent());
Files.write(path, content.getBytes(UTF_8));
}
catch (IOException | RuntimeException e) {
log.error(e, "Failed generating corresponding select statement for expected results for query %s", queryPair.getName());
}
}
if (validator.valid()) {
valid++;
}
else {
failed++;
}
for (EventClient eventClient : eventClients) {
eventClient.post(buildEvent(validator));
}
double progress = (((double) total) / totalQueries) * 100;
if (!config.isQuiet() || (progress - lastProgress) > 1) {
log.info("Progress: %s valid, %s failed, %s skipped, %.2f%% done", valid, failed, skipped, progress);
lastProgress = progress;
}
}
log.info("Results: %s / %s (%s skipped)", valid, failed, skipped);
log.info("");
for (EventClient eventClient : eventClients) {
if (eventClient instanceof Closeable) {
try {
((Closeable) eventClient).close();
}
catch (IOException ignored) {
}
log.info("");
}
}
return failed;
}
private boolean isCheckCorrectness(QueryPair query)
{
// Check if either the control query or the test query matches the regex
if (Pattern.matches(config.getSkipCorrectnessRegex(), query.getTest().getQuery()) ||
Pattern.matches(config.getSkipCorrectnessRegex(), query.getControl().getQuery())) {
// If so disable correctness checking
return false;
}
return config.isCheckCorrectnessEnabled();
}
private VerifierQueryEvent buildEvent(Validator validator)
{
String errorMessage = null;
QueryPair queryPair = validator.getQueryPair();
QueryResult control = validator.getControlResult();
QueryResult test = validator.getTestResult();
if (!validator.valid()) {
errorMessage = format("Test state %s, Control state %s\n", test.getState(), control.getState());
Exception e = test.getException();
if (e != null && shouldAddStackTrace(e)) {
errorMessage += getStackTraceAsString(e);
}
if (control.getState() == SUCCESS && test.getState() == SUCCESS) {
errorMessage += validator.getResultsComparison(precision).trim();
}
}
return new VerifierQueryEvent(
queryPair.getSuite(),
config.getRunId(),
config.getSource(),
queryPair.getName(),
!validator.valid(),
queryPair.getTest().getCatalog(),
queryPair.getTest().getSchema(),
queryPair.getTest().getPreQueries(),
queryPair.getTest().getQuery(),
queryPair.getTest().getPostQueries(),
validator.getTestPreQueryResults().stream()
.map(QueryResult::getQueryId)
.filter(Objects::nonNull)
.collect(toImmutableList()),
test.getQueryId(),
validator.getTestPostQueryResults().stream()
.map(QueryResult::getQueryId)
.filter(Objects::nonNull)
.collect(toImmutableList()),
getTotalDurationInSeconds(validator.getTestPreQueryResults(), validator.getTestResult(), validator.getTestPostQueryResults(), QueryResult::getCpuTime),
getTotalDurationInSeconds(validator.getTestPreQueryResults(), validator.getTestResult(), validator.getTestPostQueryResults(), QueryResult::getWallTime),
queryPair.getControl().getCatalog(),
queryPair.getControl().getSchema(),
queryPair.getControl().getPreQueries(),
queryPair.getControl().getQuery(),
queryPair.getControl().getPostQueries(),
validator.getControlPreQueryResults().stream()
.map(QueryResult::getQueryId)
.filter(Objects::nonNull)
.collect(toImmutableList()),
control.getQueryId(),
validator.getControlPostQueryResults().stream()
.map(QueryResult::getQueryId)
.filter(Objects::nonNull)
.collect(toImmutableList()),
getTotalDurationInSeconds(validator.getControlPreQueryResults(), validator.getControlResult(), validator.getControlPostQueryResults(), QueryResult::getCpuTime),
getTotalDurationInSeconds(validator.getControlPreQueryResults(), validator.getControlResult(), validator.getControlPostQueryResults(), QueryResult::getWallTime),
errorMessage);
}
@Nullable
private static Double getTotalDurationInSeconds(List preQueries, QueryResult query, List postQueries, Function metric)
{
OptionalDouble result = concat(preQueries.stream(), Stream.of(query), postQueries.stream())
.map(metric)
.filter(Objects::nonNull)
.mapToDouble(duration -> duration.getValue(SECONDS))
.reduce(Double::sum);
if (result.isEmpty()) {
return null;
}
return result.getAsDouble();
}
private static T takeUnchecked(CompletionService completionService)
throws InterruptedException
{
try {
return completionService.take().get();
}
catch (ExecutionException e) {
throw new RuntimeException(e);
}
}
private static boolean shouldAddStackTrace(Exception e)
{
if (e instanceof TrinoException) {
ErrorCode errorCode = ((TrinoException) e).getErrorCode();
if (EXPECTED_ERRORS.contains(errorCode)) {
return false;
}
}
return true;
}
private static String generateCorrespondingSelect(List columnTypes, List> rows)
{
StringBuilder sb = new StringBuilder("SELECT *\nFROM\n(\n VALUES\n");
for (int rowIndex = 0; rowIndex < rows.size(); rowIndex++) {
List row = rows.get(rowIndex);
sb.append(" (");
for (int columnIndex = 0; columnIndex < columnTypes.size(); columnIndex++) {
String type = columnTypes.get(columnIndex);
Optional value = Optional.ofNullable(row.get(columnIndex)).map(Object::toString);
String literal = getLiteral(type, value);
sb.append(literal);
if (columnIndex < columnTypes.size() - 1) {
sb.append(", ");
}
}
sb.append(")");
if (rowIndex < rows.size() - 1) {
sb.append(",");
}
sb.append("\n");
}
if (rows.isEmpty()) {
sb.append(" (");
for (int columnIndex = 0; columnIndex < columnTypes.size(); columnIndex++) {
sb.append("NULL");
if (columnIndex < columnTypes.size() - 1) {
sb.append(", ");
}
}
sb.append(")\n");
}
sb.append(")\n");
if (rows.isEmpty()) {
sb.append("WHERE 1=0\n");
}
return sb.toString();
}
private static String getLiteral(String type, Optional value)
{
String baseType = getBaseType(type);
switch (baseType) {
case "TINYINT":
case "SMALLINT":
case "INTEGER":
case "BIGINT":
case "DECIMAL":
case "DATE":
case "TIME":
case "REAL":
case "DOUBLE":
return value.map(v -> baseType + " '" + v + "'").orElse("NULL");
case "CHAR":
case "VARCHAR":
return value.map(v -> baseType + " '" + v.replaceAll("'", "''") + "'").orElse("NULL");
case "VARBINARY":
return value.map(v -> "X'" + v + "'").orElse("NULL");
case "UNKNOWN":
return "NULL";
default:
throw new IllegalArgumentException(format("Unexpected type: %s", type));
}
}
private static String getBaseType(String type)
{
String baseType = type.toUpperCase(ENGLISH);
int index = baseType.indexOf('(');
if (index != -1) {
baseType = baseType.substring(0, index);
}
return baseType;
}
}