All downloads are free. Search and download functionality uses the official Maven repository.

com.facebook.presto.client.StatementClientV1 Maven / Gradle / Ivy

There is a newer version: 0.290
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.client;

import com.facebook.airlift.json.JsonCodec;
import com.facebook.presto.client.OkHttpUtil.NullCallback;
import com.facebook.presto.common.type.TimeZoneKey;
import com.facebook.presto.spi.security.SelectedRole;
import com.google.common.base.Joiner;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.units.Duration;
import okhttp3.Headers;
import okhttp3.HttpUrl;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;

import javax.annotation.Nullable;
import javax.annotation.concurrent.ThreadSafe;

import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.net.URLDecoder;
import java.net.URLEncoder;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

import static com.facebook.airlift.json.JsonCodec.jsonCodec;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_ADDED_PREPARE;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_ADDED_SESSION_FUNCTION;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_CATALOG;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_CLEAR_SESSION;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_CLEAR_TRANSACTION_ID;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_CLIENT_INFO;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_CLIENT_TAGS;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_DEALLOCATED_PREPARE;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_EXTRA_CREDENTIAL;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_LANGUAGE;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_PREPARED_STATEMENT;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_REMOVED_SESSION_FUNCTION;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_RESOURCE_ESTIMATE;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_SCHEMA;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_SESSION;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_SESSION_FUNCTION;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_SET_CATALOG;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_SET_ROLE;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_SET_SCHEMA;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_SET_SESSION;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_SOURCE;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_STARTED_TRANSACTION_ID;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_TIME_ZONE;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_TRACE_TOKEN;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_TRANSACTION_ID;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_USER;
import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Sets.newConcurrentHashSet;
import static com.google.common.net.HttpHeaders.ACCEPT_ENCODING;
import static com.google.common.net.HttpHeaders.USER_AGENT;
import static java.lang.String.format;
import static java.net.HttpURLConnection.HTTP_OK;
import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED;
import static java.net.HttpURLConnection.HTTP_UNAVAILABLE;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

/**
 * {@link StatementClient} implementation that submits a query over HTTP and pages through its results.
 *
 * <p>The constructor POSTs the query text to {@code /v1/statement}. Each call to {@link #advance()}
 * then issues a GET against the {@code nextUri} of the most recent {@link QueryResults} until the
 * server stops returning one. Session-mutating response headers (set catalog/schema, session
 * properties, roles, prepared statements, transaction ids, session functions) are accumulated from
 * every response and exposed through the corresponding getters.
 */
@ThreadSafe
class StatementClientV1
        implements StatementClient
{
    private static final MediaType MEDIA_TYPE_TEXT = MediaType.parse("text/plain; charset=utf-8");
    private static final JsonCodec<QueryResults> QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class);

    // Splits "key=value" headers on the first '=' only, so that values may themselves contain '='.
    private static final Splitter SESSION_HEADER_SPLITTER = Splitter.on('=').limit(2).trimResults();
    private static final String USER_AGENT_VALUE = StatementClientV1.class.getSimpleName() +
            "/" +
            firstNonNull(StatementClientV1.class.getPackage().getImplementationVersion(), "unknown");

    private final OkHttpClient httpClient;
    private final String query;
    // Most recent results received from the server; replaced on every successful response.
    private final AtomicReference<QueryResults> currentResults = new AtomicReference<>();
    private final AtomicReference<String> setCatalog = new AtomicReference<>();
    private final AtomicReference<String> setSchema = new AtomicReference<>();
    private final Map<String, String> setSessionProperties = new ConcurrentHashMap<>();
    private final Set<String> resetSessionProperties = newConcurrentHashSet();
    private final Map<String, SelectedRole> setRoles = new ConcurrentHashMap<>();
    private final Map<String, String> addedPreparedStatements = new ConcurrentHashMap<>();
    private final Set<String> deallocatedPreparedStatements = newConcurrentHashSet();
    private final AtomicReference<String> startedTransactionId = new AtomicReference<>();
    private final AtomicBoolean clearTransactionId = new AtomicBoolean();
    private final TimeZoneKey timeZone;
    // Despite the name this is a Duration (taken from ClientSession#getClientRequestTimeout), not a raw nanosecond count.
    private final Duration requestTimeoutNanos;
    private final String user;
    private final boolean compressionDisabled;
    private final Map<String, String> addedSessionFunctions = new ConcurrentHashMap<>();
    private final Set<String> removedSessionFunctions = newConcurrentHashSet();

    private final AtomicReference<State> state = new AtomicReference<>(State.RUNNING);

    /**
     * Submits {@code query} to the server described by {@code session} and records the first response.
     *
     * @param httpClient client used for all HTTP calls made by this instance
     * @param session source of the server address, user and all session headers
     * @param query statement text sent as the request body
     * @throws NullPointerException if any argument is null
     * @throws RuntimeException if the initial request does not return HTTP 200 with a parsed body
     *         (the state is moved to {@code CLIENT_ERROR} first)
     */
    public StatementClientV1(OkHttpClient httpClient, ClientSession session, String query)
    {
        requireNonNull(httpClient, "httpClient is null");
        requireNonNull(session, "session is null");
        requireNonNull(query, "query is null");

        this.httpClient = httpClient;
        this.timeZone = session.getTimeZone();
        this.query = query;
        this.requestTimeoutNanos = session.getClientRequestTimeout();
        this.user = session.getUser();
        this.compressionDisabled = session.isCompressionDisabled();

        Request request = buildQueryRequest(session, query);

        JsonResponse<QueryResults> response = JsonResponse.execute(QUERY_RESULTS_CODEC, httpClient, request);
        if ((response.getStatusCode() != HTTP_OK) || !response.hasValue()) {
            state.compareAndSet(State.RUNNING, State.CLIENT_ERROR);
            throw requestFailedException("starting query", request, response);
        }

        processResponse(response.getHeaders(), response.getValue());
    }

    /**
     * Builds the initial {@code POST /v1/statement} request, translating every relevant part of the
     * session into request headers.
     *
     * @throws ClientException if the session's server address cannot be turned into an {@link HttpUrl}
     */
    private Request buildQueryRequest(ClientSession session, String query)
    {
        HttpUrl url = HttpUrl.get(session.getServer());
        if (url == null) {
            throw new ClientException("Invalid server URL: " + session.getServer());
        }
        url = url.newBuilder().encodedPath("/v1/statement").build();

        Request.Builder builder = prepareRequest(url)
                .post(RequestBody.create(MEDIA_TYPE_TEXT, query));

        Map<String, String> customHeaders = session.getCustomHeaders();
        for (Entry<String, String> entry : customHeaders.entrySet()) {
            builder.addHeader(entry.getKey(), entry.getValue());
        }

        if (session.getSource() != null) {
            builder.addHeader(PRESTO_SOURCE, session.getSource());
        }

        session.getTraceToken().ifPresent(token -> builder.addHeader(PRESTO_TRACE_TOKEN, token));

        if (session.getClientTags() != null && !session.getClientTags().isEmpty()) {
            builder.addHeader(PRESTO_CLIENT_TAGS, Joiner.on(",").join(session.getClientTags()));
        }
        if (session.getClientInfo() != null) {
            builder.addHeader(PRESTO_CLIENT_INFO, session.getClientInfo());
        }
        if (session.getCatalog() != null) {
            builder.addHeader(PRESTO_CATALOG, session.getCatalog());
        }
        if (session.getSchema() != null) {
            builder.addHeader(PRESTO_SCHEMA, session.getSchema());
        }
        builder.addHeader(PRESTO_TIME_ZONE, session.getTimeZone().getId());
        if (session.getLocale() != null) {
            builder.addHeader(PRESTO_LANGUAGE, session.getLocale().toLanguageTag());
        }

        // Session property values are URL-encoded; the property name is sent as-is.
        Map<String, String> property = session.getProperties();
        for (Entry<String, String> entry : property.entrySet()) {
            builder.addHeader(PRESTO_SESSION, entry.getKey() + "=" + urlEncode(entry.getValue()));
        }

        Map<String, String> resourceEstimates = session.getResourceEstimates();
        for (Entry<String, String> entry : resourceEstimates.entrySet()) {
            builder.addHeader(PRESTO_RESOURCE_ESTIMATE, entry.getKey() + "=" + entry.getValue());
        }

        Map<String, SelectedRole> roles = session.getRoles();
        for (Entry<String, SelectedRole> entry : roles.entrySet()) {
            builder.addHeader(PrestoHeaders.PRESTO_ROLE, entry.getKey() + '=' + urlEncode(entry.getValue().toString()));
        }

        Map<String, String> extraCredentials = session.getExtraCredentials();
        for (Entry<String, String> entry : extraCredentials.entrySet()) {
            builder.addHeader(PRESTO_EXTRA_CREDENTIAL, entry.getKey() + "=" + entry.getValue());
        }

        // Prepared statements and session functions encode both the name and the body.
        Map<String, String> statements = session.getPreparedStatements();
        for (Entry<String, String> entry : statements.entrySet()) {
            builder.addHeader(PRESTO_PREPARED_STATEMENT, urlEncode(entry.getKey()) + "=" + urlEncode(entry.getValue()));
        }

        // The transaction header is always sent; "NONE" stands in for a missing transaction id.
        builder.addHeader(PRESTO_TRANSACTION_ID, session.getTransactionId() == null ? "NONE" : session.getTransactionId());

        Map<String, String> sessionFunctions = session.getSessionFunctions();
        for (Entry<String, String> entry : sessionFunctions.entrySet()) {
            builder.addHeader(PRESTO_SESSION_FUNCTION, urlEncode(entry.getKey()) + "=" + urlEncode(entry.getValue()));
        }

        return builder.build();
    }

    /**
     * Returns the statement text passed to the constructor.
     */
    @Override
    public String getQuery()
    {
        return query;
    }

    /**
     * Returns the time zone captured from the session at construction time.
     */
    @Override
    public TimeZoneKey getTimeZone()
    {
        return timeZone;
    }

    /**
     * Returns true while the client is in the {@code RUNNING} state.
     */
    public boolean isRunning()
    {
        return state.get() == State.RUNNING;
    }

    /**
     * Returns true once {@link #close()} has moved the client from {@code RUNNING} to {@code CLIENT_ABORTED}.
     */
    public boolean isClientAborted()
    {
        return state.get() == State.CLIENT_ABORTED;
    }

    /**
     * Returns true if a request made by this client failed and moved it to {@code CLIENT_ERROR}.
     */
    public boolean isClientError()
    {
        return state.get() == State.CLIENT_ERROR;
    }

    /**
     * Returns true once {@link #advance()} has seen a response without a next URI.
     */
    public boolean isFinished()
    {
        return state.get() == State.FINISHED;
    }

    /**
     * Returns the statistics carried by the most recently received results.
     */
    @Override
    public StatementStats getStats()
    {
        return currentResults.get().getStats();
    }

    /**
     * Returns the most recent results as status information.
     *
     * @throws IllegalStateException if the client is no longer running
     */
    @Override
    public QueryStatusInfo currentStatusInfo()
    {
        checkState(isRunning(), "current position is not valid (cursor past end)");
        return currentResults.get();
    }

    /**
     * Returns the most recent results as row data.
     *
     * @throws IllegalStateException if the client is no longer running
     */
    @Override
    public QueryData currentData()
    {
        checkState(isRunning(), "current position is not valid (cursor past end)");
        return currentResults.get();
    }

    /**
     * Returns the last results received, once the client has left the running state.
     *
     * @throws IllegalStateException if the client is still running
     */
    @Override
    public QueryStatusInfo finalStatusInfo()
    {
        checkState(!isRunning(), "current position is still valid");
        return currentResults.get();
    }

    /**
     * Returns the catalog named by the set-catalog header of the latest response, if that header was present.
     */
    @Override
    public Optional<String> getSetCatalog()
    {
        return Optional.ofNullable(setCatalog.get());
    }

    /**
     * Returns the schema named by the set-schema header of the latest response, if that header was present.
     */
    @Override
    public Optional<String> getSetSchema()
    {
        return Optional.ofNullable(setSchema.get());
    }

    /**
     * Returns an immutable snapshot of the session properties set by responses so far (values URL-decoded).
     */
    @Override
    public Map<String, String> getSetSessionProperties()
    {
        return ImmutableMap.copyOf(setSessionProperties);
    }

    /**
     * Returns an immutable snapshot of the session property names cleared by responses so far.
     */
    @Override
    public Set<String> getResetSessionProperties()
    {
        return ImmutableSet.copyOf(resetSessionProperties);
    }

    /**
     * Returns an immutable snapshot of the roles set by responses so far.
     */
    @Override
    public Map<String, SelectedRole> getSetRoles()
    {
        return ImmutableMap.copyOf(setRoles);
    }

    /**
     * Returns an immutable snapshot of the prepared statements added by responses so far.
     */
    @Override
    public Map<String, String> getAddedPreparedStatements()
    {
        return ImmutableMap.copyOf(addedPreparedStatements);
    }

    /**
     * Returns an immutable snapshot of the prepared statement names deallocated by responses so far.
     */
    @Override
    public Set<String> getDeallocatedPreparedStatements()
    {
        return ImmutableSet.copyOf(deallocatedPreparedStatements);
    }

    /**
     * Returns the transaction id reported by a started-transaction header, or null if none has been seen.
     */
    @Override
    @Nullable
    public String getStartedTransactionId()
    {
        return startedTransactionId.get();
    }

    /**
     * Returns true once any response has carried the clear-transaction header.
     */
    @Override
    public boolean isClearTransactionId()
    {
        return clearTransactionId.get();
    }

    /**
     * Returns an immutable snapshot of the session functions added by responses so far.
     */
    @Override
    public Map<String, String> getAddedSessionFunctions()
    {
        return ImmutableMap.copyOf(addedSessionFunctions);
    }

    /**
     * Returns an immutable snapshot of the session function signatures removed by responses so far.
     */
    @Override
    public Set<String> getRemovedSessionFunctions()
    {
        return ImmutableSet.copyOf(removedSessionFunctions);
    }

    /**
     * Starts a request for {@code url} carrying the user and user-agent headers shared by every
     * call this client makes, plus {@code Accept-Encoding: identity} when compression is disabled.
     */
    private Request.Builder prepareRequest(HttpUrl url)
    {
        Request.Builder builder = new Request.Builder()
                .addHeader(PRESTO_USER, user)
                .addHeader(USER_AGENT, USER_AGENT_VALUE)
                .url(url);
        if (compressionDisabled) {
            builder.header(ACCEPT_ENCODING, "identity");
        }
        return builder;
    }

    /**
     * Fetches the next page of results.
     *
     * <p>The request is retried, with a back-off of {@code attempts * 100} ms, when the HTTP call
     * throws or the server answers 503. Retrying stops once at least one attempt has been made and
     * the elapsed time exceeds the session's request timeout.
     *
     * @return true if a new page was received; false if the client is not running, has been
     *         aborted, or the previous page had no next URI (in which case the state becomes
     *         {@code FINISHED})
     * @throws RuntimeException if retries time out, the thread is interrupted while backing off,
     *         or the server returns a non-retryable response; the state becomes
     *         {@code CLIENT_ERROR} (or {@code CLIENT_ABORTED} on interruption, via {@link #close()})
     */
    @Override
    public boolean advance()
    {
        if (!isRunning()) {
            return false;
        }

        URI nextUri = currentStatusInfo().getNextUri();
        if (nextUri == null) {
            state.compareAndSet(State.RUNNING, State.FINISHED);
            return false;
        }

        Request request = prepareRequest(HttpUrl.get(nextUri)).build();

        Exception cause = null;
        long start = System.nanoTime();
        long attempts = 0;

        while (true) {
            if (isClientAborted()) {
                return false;
            }

            Duration sinceStart = Duration.nanosSince(start);
            if (attempts > 0 && sinceStart.compareTo(requestTimeoutNanos) > 0) {
                state.compareAndSet(State.RUNNING, State.CLIENT_ERROR);
                throw new RuntimeException(format("Error fetching next (attempts: %s, duration: %s)", attempts, sinceStart), cause);
            }

            if (attempts > 0) {
                // back-off on retry
                try {
                    MILLISECONDS.sleep(attempts * 100);
                }
                catch (InterruptedException e) {
                    try {
                        close();
                    }
                    finally {
                        // restore the interrupt flag even if close() throws
                        Thread.currentThread().interrupt();
                    }
                    state.compareAndSet(State.RUNNING, State.CLIENT_ERROR);
                    // keep the original interruption as the cause so it is not lost from stack traces
                    throw new RuntimeException("StatementClient thread was interrupted", e);
                }
            }
            attempts++;

            JsonResponse<QueryResults> response;
            try {
                response = JsonResponse.execute(QUERY_RESULTS_CODEC, httpClient, request);
            }
            catch (RuntimeException e) {
                // remember the failure for the eventual timeout exception and retry
                cause = e;
                continue;
            }

            if ((response.getStatusCode() == HTTP_OK) && response.hasValue()) {
                processResponse(response.getHeaders(), response.getValue());
                return true;
            }

            // only 503 is retried; anything else is a hard failure
            if (response.getStatusCode() != HTTP_UNAVAILABLE) {
                state.compareAndSet(State.RUNNING, State.CLIENT_ERROR);
                throw requestFailedException("fetching next", request, response);
            }
        }
    }

    /**
     * Folds the session-mutating headers of a response into this client's accumulated state and
     * publishes {@code results} as the current results.
     *
     * <p>Set-catalog and set-schema are overwritten on every response (with null when the header is
     * absent); all other values accumulate. {@code key=value} headers that do not split into exactly
     * two parts are skipped.
     */
    private void processResponse(Headers headers, QueryResults results)
    {
        setCatalog.set(headers.get(PRESTO_SET_CATALOG));
        setSchema.set(headers.get(PRESTO_SET_SCHEMA));

        for (String setSession : headers.values(PRESTO_SET_SESSION)) {
            List<String> keyValue = SESSION_HEADER_SPLITTER.splitToList(setSession);
            if (keyValue.size() != 2) {
                continue;
            }
            setSessionProperties.put(keyValue.get(0), urlDecode(keyValue.get(1)));
        }
        resetSessionProperties.addAll(headers.values(PRESTO_CLEAR_SESSION));

        for (String setRole : headers.values(PRESTO_SET_ROLE)) {
            List<String> keyValue = SESSION_HEADER_SPLITTER.splitToList(setRole);
            if (keyValue.size() != 2) {
                continue;
            }
            setRoles.put(keyValue.get(0), SelectedRole.valueOf(urlDecode(keyValue.get(1))));
        }

        for (String entry : headers.values(PRESTO_ADDED_PREPARE)) {
            List<String> keyValue = SESSION_HEADER_SPLITTER.splitToList(entry);
            if (keyValue.size() != 2) {
                continue;
            }
            addedPreparedStatements.put(urlDecode(keyValue.get(0)), urlDecode(keyValue.get(1)));
        }
        for (String entry : headers.values(PRESTO_DEALLOCATED_PREPARE)) {
            deallocatedPreparedStatements.add(urlDecode(entry));
        }

        String startedTransactionId = headers.get(PRESTO_STARTED_TRANSACTION_ID);
        if (startedTransactionId != null) {
            this.startedTransactionId.set(startedTransactionId);
        }
        if (headers.get(PRESTO_CLEAR_TRANSACTION_ID) != null) {
            clearTransactionId.set(true);
        }

        for (String sessionFunction : headers.values(PRESTO_ADDED_SESSION_FUNCTION)) {
            List<String> keyValue = SESSION_HEADER_SPLITTER.splitToList(sessionFunction);
            if (keyValue.size() != 2) {
                continue;
            }
            addedSessionFunctions.put(urlDecode(keyValue.get(0)), urlDecode(keyValue.get(1)));
        }
        for (String signature : headers.values(PRESTO_REMOVED_SESSION_FUNCTION)) {
            removedSessionFunctions.add(urlDecode(signature));
        }

        // publish the results last, after all header-derived state has been updated
        currentResults.set(results);
    }

    /**
     * Builds the exception describing a failed request; the caller is responsible for throwing it.
     *
     * @param task short description of what was being attempted, used in the message
     * @param request the request that failed (only its URL is reported)
     * @param response the response received
     * @return a {@link ClientException} for 401 and 429 responses without a parsed body, otherwise a
     *         plain {@link RuntimeException}
     */
    private RuntimeException requestFailedException(String task, Request request, JsonResponse<QueryResults> response)
    {
        if (!response.hasValue()) {
            if (response.getStatusCode() == HTTP_UNAUTHORIZED) {
                return new ClientException("Authentication failed" +
                        Optional.ofNullable(response.getStatusMessage())
                                .map(message -> ": " + message)
                                .orElse(""));
            }
            if (response.getStatusCode() == 429) {
                // no trailing space before the optional ": <message>" suffix, matching the 401 message above
                return new ClientException("Request throttled" +
                        Optional.ofNullable(response.getStatusMessage())
                                .map(message -> ": " + message)
                                .orElse(""), true);
            }
            return new RuntimeException(
                    format("Error %s at %s returned an invalid response: %s [Error: %s]", task, request.url(), response, response.getResponseBody()),
                    response.getException());
        }
        return new RuntimeException(format("Error %s at %s returned HTTP %s", task, request.url(), response.getStatusCode()));
    }

    /**
     * Sends a DELETE to the partial-cancel URI of the current results, if there is one.
     *
     * @throws IllegalStateException if the client has been closed or is otherwise no longer running
     */
    @Override
    public void cancelLeafStage()
    {
        checkState(!isClientAborted(), "client is closed");

        URI uri = currentStatusInfo().getPartialCancelUri();
        if (uri != null) {
            httpDelete(uri);
        }
    }

    /**
     * Aborts the query if it is still running by sending a DELETE to the current next URI.
     * Does nothing if the client has already left the running state.
     */
    @Override
    public void close()
    {
        // If the query is not done, abort the query.
        if (state.compareAndSet(State.RUNNING, State.CLIENT_ABORTED)) {
            URI uri = currentResults.get().getNextUri();
            if (uri != null) {
                httpDelete(uri);
            }
        }
    }

    /**
     * Enqueues an asynchronous DELETE for {@code uri}; the outcome is handed to a {@link NullCallback}.
     */
    private void httpDelete(URI uri)
    {
        Request request = prepareRequest(HttpUrl.get(uri))
                .delete()
                .build();
        httpClient.newCall(request).enqueue(new NullCallback());
    }

    /**
     * URL-encodes {@code value} as UTF-8.
     */
    private static String urlEncode(String value)
    {
        try {
            return URLEncoder.encode(value, "UTF-8");
        }
        catch (UnsupportedEncodingException e) {
            // UTF-8 is a charset every Java platform must support
            throw new AssertionError(e);
        }
    }

    /**
     * URL-decodes {@code value} as UTF-8.
     */
    private static String urlDecode(String value)
    {
        try {
            return URLDecoder.decode(value, "UTF-8");
        }
        catch (UnsupportedEncodingException e) {
            // UTF-8 is a charset every Java platform must support
            throw new AssertionError(e);
        }
    }

    /**
     * Client-side lifecycle of a statement. Every transition in this class starts from {@link #RUNNING}.
     */
    private enum State
    {
        /**
         * submitted to server, not in terminal state (including planning, queued, running, etc)
         */
        RUNNING,
        /**
         * a request made by this client failed
         */
        CLIENT_ERROR,
        /**
         * {@link StatementClientV1#close()} was called while the statement was still running
         */
        CLIENT_ABORTED,
        /**
         * finished on remote Presto server (including failed and successfully completed)
         */
        FINISHED,
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy