All downloads are free. Search and download functionality uses the official Maven repository.

software.amazon.awssdk.services.sagemakerruntime.DefaultSageMakerRuntimeAsyncClient Maven / Gradle / Ivy

Go to download

The AWS Java SDK for SageMaker Runtime module holds the client classes that are used for communicating with SageMaker Runtime.

There is a newer version: 2.29.39
Show newest version
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with
 * the License. A copy of the License is located at
 * 
 * http://aws.amazon.com/apache2.0
 * 
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */

package software.amazon.awssdk.services.sagemakerruntime;

import static software.amazon.awssdk.utils.FunctionalUtils.runAndLogError;

import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.annotations.Generated;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler;
import software.amazon.awssdk.awscore.eventstream.EventStreamAsyncResponseTransformer;
import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionPojoSupplier;
import software.amazon.awssdk.awscore.eventstream.RestEventStreamAsyncResponseTransformer;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata;
import software.amazon.awssdk.awscore.internal.AwsServiceProtocol;
import software.amazon.awssdk.core.RequestOverrideConfiguration;
import software.amazon.awssdk.core.SdkPlugin;
import software.amazon.awssdk.core.SdkPojoBuilder;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.SdkResponse;
import software.amazon.awssdk.core.client.config.SdkAdvancedAsyncClientOption;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
import software.amazon.awssdk.core.client.handler.AsyncClientHandler;
import software.amazon.awssdk.core.client.handler.AttachHttpMetadataResponseHandler;
import software.amazon.awssdk.core.client.handler.ClientExecutionParams;
import software.amazon.awssdk.core.http.HttpResponseHandler;
import software.amazon.awssdk.core.metrics.CoreMetric;
import software.amazon.awssdk.core.protocol.VoidSdkResponse;
import software.amazon.awssdk.metrics.MetricCollector;
import software.amazon.awssdk.metrics.MetricPublisher;
import software.amazon.awssdk.metrics.NoOpMetricCollector;
import software.amazon.awssdk.protocols.core.ExceptionMetadata;
import software.amazon.awssdk.protocols.json.AwsJsonProtocol;
import software.amazon.awssdk.protocols.json.AwsJsonProtocolFactory;
import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory;
import software.amazon.awssdk.protocols.json.JsonOperationMetadata;
import software.amazon.awssdk.services.sagemakerruntime.internal.SageMakerRuntimeServiceClientConfigurationBuilder;
import software.amazon.awssdk.services.sagemakerruntime.model.InternalDependencyException;
import software.amazon.awssdk.services.sagemakerruntime.model.InternalFailureException;
import software.amazon.awssdk.services.sagemakerruntime.model.InternalStreamFailureException;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointAsyncRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointAsyncResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
import software.amazon.awssdk.services.sagemakerruntime.model.ModelErrorException;
import software.amazon.awssdk.services.sagemakerruntime.model.ModelNotReadyException;
import software.amazon.awssdk.services.sagemakerruntime.model.ModelStreamErrorException;
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;
import software.amazon.awssdk.services.sagemakerruntime.model.SageMakerRuntimeException;
import software.amazon.awssdk.services.sagemakerruntime.model.ServiceUnavailableException;
import software.amazon.awssdk.services.sagemakerruntime.model.ValidationErrorException;
import software.amazon.awssdk.services.sagemakerruntime.transform.InvokeEndpointAsyncRequestMarshaller;
import software.amazon.awssdk.services.sagemakerruntime.transform.InvokeEndpointRequestMarshaller;
import software.amazon.awssdk.services.sagemakerruntime.transform.InvokeEndpointWithResponseStreamRequestMarshaller;
import software.amazon.awssdk.utils.CompletableFutureUtils;

/**
 * Internal implementation of {@link SageMakerRuntimeAsyncClient}.
 *
 * @see SageMakerRuntimeAsyncClient#builder()
 */
@Generated("software.amazon.awssdk:codegen")
@SdkInternalApi
final class DefaultSageMakerRuntimeAsyncClient implements SageMakerRuntimeAsyncClient {
    private static final Logger log = LoggerFactory.getLogger(DefaultSageMakerRuntimeAsyncClient.class);

    private static final AwsProtocolMetadata protocolMetadata = AwsProtocolMetadata.builder()
            .serviceProtocol(AwsServiceProtocol.REST_JSON).build();

    private final AsyncClientHandler clientHandler;

    private final AwsJsonProtocolFactory protocolFactory;

    private final SdkClientConfiguration clientConfiguration;

    private final Executor executor;

    protected DefaultSageMakerRuntimeAsyncClient(SdkClientConfiguration clientConfiguration) {
        this.clientHandler = new AwsAsyncClientHandler(clientConfiguration);
        this.clientConfiguration = clientConfiguration;
        this.protocolFactory = init(AwsJsonProtocolFactory.builder()).build();
        this.executor = clientConfiguration.option(SdkAdvancedAsyncClientOption.FUTURE_COMPLETION_EXECUTOR);
    }

    /**
     * 

* After you deploy a model into production using Amazon SageMaker hosting services, your client applications use * this API to get inferences from the model hosted at the specified endpoint. *

*

* For an overview of Amazon SageMaker, see How It Works. *

*

* Amazon SageMaker strips all POST headers except those supported by the API. Amazon SageMaker might add additional * headers. You should not rely on the behavior of headers outside those enumerated in the request syntax. *

*

* Calls to InvokeEndpoint are authenticated by using Amazon Web Services Signature Version 4. For * information, see Authenticating * Requests (Amazon Web Services Signature Version 4) in the Amazon S3 API Reference. *

*

* A customer's model containers must respond to requests within 60 seconds. The model itself can have a maximum * processing time of 60 seconds before responding to invocations. If your model is going to take 50-60 seconds of * processing time, the SDK socket timeout should be set to be 70 seconds. *

* *

* Endpoints are scoped to an individual account, and are not public. The URL does not contain the account ID, but * Amazon SageMaker determines the account ID from the authentication token that is supplied by the caller. *

*
* * @param invokeEndpointRequest * @return A Java Future containing the result of the InvokeEndpoint operation returned by the service.
* The CompletableFuture returned by this method can be completed exceptionally with the following * exceptions. *
    *
  • InternalFailureException An internal failure occurred.
  • *
  • ServiceUnavailableException The service is unavailable. Try your call again.
  • *
  • ValidationErrorException Inspect your request and try again.
  • *
  • ModelErrorException Model (owned by the customer in the container) returned 4xx or 5xx error code.
  • *
  • InternalDependencyException Your request caused an exception with an internal dependency. Contact * customer support.
  • *
  • ModelNotReadyException Either a serverless endpoint variant's resources are still being provisioned, * or a multi-model endpoint is still downloading or loading the target model. Wait and try your request * again.
  • *
  • SdkException Base class for all exceptions that can be thrown by the SDK (both service and client). * Can be used for catch all scenarios.
  • *
  • SdkClientException If any client side error occurs such as an IO related failure, failure to get * credentials, etc.
  • *
  • SageMakerRuntimeException Base class for all service exceptions. Unknown exceptions will be thrown as * an instance of this type.
  • *
* @sample SageMakerRuntimeAsyncClient.InvokeEndpoint * @see AWS API Documentation */ @Override public CompletableFuture invokeEndpoint(InvokeEndpointRequest invokeEndpointRequest) { SdkClientConfiguration clientConfiguration = updateSdkClientConfiguration(invokeEndpointRequest, this.clientConfiguration); List metricPublishers = resolveMetricPublishers(clientConfiguration, invokeEndpointRequest .overrideConfiguration().orElse(null)); MetricCollector apiCallMetricCollector = metricPublishers.isEmpty() ? NoOpMetricCollector.create() : MetricCollector .create("ApiCall"); try { apiCallMetricCollector.reportMetric(CoreMetric.SERVICE_ID, "SageMaker Runtime"); apiCallMetricCollector.reportMetric(CoreMetric.OPERATION_NAME, "InvokeEndpoint"); JsonOperationMetadata operationMetadata = JsonOperationMetadata.builder().hasStreamingSuccessResponse(false) .isPayloadJson(false).build(); HttpResponseHandler responseHandler = protocolFactory.createResponseHandler( operationMetadata, InvokeEndpointResponse::builder); HttpResponseHandler errorResponseHandler = createErrorResponseHandler(protocolFactory, operationMetadata); CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() .withOperationName("InvokeEndpoint").withProtocolMetadata(protocolMetadata) .withMarshaller(new InvokeEndpointRequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) .withInput(invokeEndpointRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); executeFuture = CompletableFutureUtils.forwardExceptionTo(whenCompleted, executeFuture); return executeFuture; } catch (Throwable t) { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); return CompletableFutureUtils.failedFuture(t); } } /** *

* After you deploy a model into production using Amazon SageMaker hosting services, your client applications use * this API to get inferences from the model hosted at the specified endpoint in an asynchronous manner. *

*

* Inference requests sent to this API are enqueued for asynchronous processing. The processing of the inference * request may or may not complete before you receive a response from this API. The response from this API will not * contain the result of the inference request but contain information about where you can locate it. *

*

* Amazon SageMaker strips all POST headers except those supported by the API. Amazon SageMaker might add additional * headers. You should not rely on the behavior of headers outside those enumerated in the request syntax. *

*

* Calls to InvokeEndpointAsync are authenticated by using Amazon Web Services Signature Version 4. For * information, see Authenticating * Requests (Amazon Web Services Signature Version 4) in the Amazon S3 API Reference. *

* * @param invokeEndpointAsyncRequest * @return A Java Future containing the result of the InvokeEndpointAsync operation returned by the service.
* The CompletableFuture returned by this method can be completed exceptionally with the following * exceptions. *
    *
  • InternalFailureException An internal failure occurred.
  • *
  • ServiceUnavailableException The service is unavailable. Try your call again.
  • *
  • ValidationErrorException Inspect your request and try again.
  • *
  • SdkException Base class for all exceptions that can be thrown by the SDK (both service and client). * Can be used for catch all scenarios.
  • *
  • SdkClientException If any client side error occurs such as an IO related failure, failure to get * credentials, etc.
  • *
  • SageMakerRuntimeException Base class for all service exceptions. Unknown exceptions will be thrown as * an instance of this type.
  • *
* @sample SageMakerRuntimeAsyncClient.InvokeEndpointAsync * @see AWS API Documentation */ @Override public CompletableFuture invokeEndpointAsync( InvokeEndpointAsyncRequest invokeEndpointAsyncRequest) { SdkClientConfiguration clientConfiguration = updateSdkClientConfiguration(invokeEndpointAsyncRequest, this.clientConfiguration); List metricPublishers = resolveMetricPublishers(clientConfiguration, invokeEndpointAsyncRequest .overrideConfiguration().orElse(null)); MetricCollector apiCallMetricCollector = metricPublishers.isEmpty() ? NoOpMetricCollector.create() : MetricCollector .create("ApiCall"); try { apiCallMetricCollector.reportMetric(CoreMetric.SERVICE_ID, "SageMaker Runtime"); apiCallMetricCollector.reportMetric(CoreMetric.OPERATION_NAME, "InvokeEndpointAsync"); JsonOperationMetadata operationMetadata = JsonOperationMetadata.builder().hasStreamingSuccessResponse(false) .isPayloadJson(true).build(); HttpResponseHandler responseHandler = protocolFactory.createResponseHandler( operationMetadata, InvokeEndpointAsyncResponse::builder); HttpResponseHandler errorResponseHandler = createErrorResponseHandler(protocolFactory, operationMetadata); CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() .withOperationName("InvokeEndpointAsync").withProtocolMetadata(protocolMetadata) .withMarshaller(new InvokeEndpointAsyncRequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) .withInput(invokeEndpointAsyncRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); executeFuture = CompletableFutureUtils.forwardExceptionTo(whenCompleted, executeFuture); return executeFuture; } catch (Throwable t) { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); return 
CompletableFutureUtils.failedFuture(t); } } /** *

* Invokes a model at the specified endpoint to return the inference response as a stream. The inference stream * provides the response payload incrementally as a series of parts. Before you can get an inference stream, you * must have access to a model that's deployed using Amazon SageMaker hosting services, and the container for that * model must support inference streaming. *

*

* For more information that can help you use this API, see the following sections in the Amazon SageMaker * Developer Guide: *

* *

* Amazon SageMaker strips all POST headers except those supported by the API. Amazon SageMaker might add additional * headers. You should not rely on the behavior of headers outside those enumerated in the request syntax. *

*

* Calls to InvokeEndpointWithResponseStream are authenticated by using Amazon Web Services Signature * Version 4. For information, see Authenticating * Requests (Amazon Web Services Signature Version 4) in the Amazon S3 API Reference. *

* * @param invokeEndpointWithResponseStreamRequest * @return A Java Future containing the result of the InvokeEndpointWithResponseStream operation returned by the * service.
* The CompletableFuture returned by this method can be completed exceptionally with the following * exceptions. *
    *
  • InternalFailureException An internal failure occurred.
  • *
  • ServiceUnavailableException The service is unavailable. Try your call again.
  • *
  • ValidationErrorException Inspect your request and try again.
  • *
  • ModelErrorException Model (owned by the customer in the container) returned 4xx or 5xx error code.
  • *
  • ModelStreamErrorException An error occurred while streaming the response body. This error can have * the following error codes:

    *
    *
    ModelInvocationTimeExceeded
    *
    *

    * The model failed to finish sending the response within the timeout period allowed by Amazon SageMaker. *

    *
    *
    StreamBroken
    *
    *

    * The Transmission Control Protocol (TCP) connection between the client and the model was reset or closed. *

    *
  • *
  • InternalStreamFailureException The stream processing failed because of an unknown error, exception or * failure. Try your request again.
  • *
  • SdkException Base class for all exceptions that can be thrown by the SDK (both service and client). * Can be used for catch all scenarios.
  • *
  • SdkClientException If any client side error occurs such as an IO related failure, failure to get * credentials, etc.
  • *
  • SageMakerRuntimeException Base class for all service exceptions. Unknown exceptions will be thrown as * an instance of this type.
  • *
* @sample SageMakerRuntimeAsyncClient.InvokeEndpointWithResponseStream * @see AWS API Documentation */ @Override public CompletableFuture invokeEndpointWithResponseStream( InvokeEndpointWithResponseStreamRequest invokeEndpointWithResponseStreamRequest, InvokeEndpointWithResponseStreamResponseHandler asyncResponseHandler) { SdkClientConfiguration clientConfiguration = updateSdkClientConfiguration(invokeEndpointWithResponseStreamRequest, this.clientConfiguration); List metricPublishers = resolveMetricPublishers(clientConfiguration, invokeEndpointWithResponseStreamRequest.overrideConfiguration().orElse(null)); MetricCollector apiCallMetricCollector = metricPublishers.isEmpty() ? NoOpMetricCollector.create() : MetricCollector .create("ApiCall"); try { apiCallMetricCollector.reportMetric(CoreMetric.SERVICE_ID, "SageMaker Runtime"); apiCallMetricCollector.reportMetric(CoreMetric.OPERATION_NAME, "InvokeEndpointWithResponseStream"); JsonOperationMetadata operationMetadata = JsonOperationMetadata.builder().hasStreamingSuccessResponse(false) .isPayloadJson(true).build(); HttpResponseHandler responseHandler = new AttachHttpMetadataResponseHandler( protocolFactory.createResponseHandler(operationMetadata, InvokeEndpointWithResponseStreamResponse::builder)); HttpResponseHandler voidResponseHandler = protocolFactory.createResponseHandler(JsonOperationMetadata .builder().isPayloadJson(false).hasStreamingSuccessResponse(true).build(), VoidSdkResponse::builder); HttpResponseHandler eventResponseHandler = protocolFactory.createResponseHandler( JsonOperationMetadata.builder().isPayloadJson(true).hasStreamingSuccessResponse(false).build(), EventStreamTaggedUnionPojoSupplier.builder() .putSdkPojoSupplier("PayloadPart", ResponseStream::payloadPartBuilder) .defaultSdkPojoSupplier(() -> new SdkPojoBuilder(ResponseStream.UNKNOWN)).build()); HttpResponseHandler errorResponseHandler = createErrorResponseHandler(protocolFactory, operationMetadata); CompletableFuture future = new 
CompletableFuture<>(); EventStreamAsyncResponseTransformer asyncResponseTransformer = EventStreamAsyncResponseTransformer . builder() .eventStreamResponseHandler(asyncResponseHandler).eventResponseHandler(eventResponseHandler) .initialResponseHandler(responseHandler).exceptionResponseHandler(errorResponseHandler).future(future) .executor(executor).serviceName(serviceName()).build(); RestEventStreamAsyncResponseTransformer restAsyncResponseTransformer = RestEventStreamAsyncResponseTransformer . builder() .eventStreamAsyncResponseTransformer(asyncResponseTransformer) .eventStreamResponseHandler(asyncResponseHandler).build(); CompletableFuture executeFuture = clientHandler .execute( new ClientExecutionParams() .withOperationName("InvokeEndpointWithResponseStream").withProtocolMetadata(protocolMetadata) .withMarshaller(new InvokeEndpointWithResponseStreamRequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) .withInput(invokeEndpointWithResponseStreamRequest), restAsyncResponseTransformer); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { if (e != null) { try { asyncResponseHandler.exceptionOccurred(e); } finally { future.completeExceptionally(e); } } metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); executeFuture = CompletableFutureUtils.forwardExceptionTo(whenCompleted, executeFuture); return CompletableFutureUtils.forwardExceptionTo(future, executeFuture); } catch (Throwable t) { runAndLogError(log, "Exception thrown in exceptionOccurred callback, ignoring", () -> asyncResponseHandler.exceptionOccurred(t)); metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); return CompletableFutureUtils.failedFuture(t); } } @Override public final SageMakerRuntimeServiceClientConfiguration serviceClientConfiguration() { return new 
SageMakerRuntimeServiceClientConfigurationBuilder(this.clientConfiguration.toBuilder()).build(); } @Override public final String serviceName() { return SERVICE_NAME; } private > T init(T builder) { return builder .clientConfiguration(clientConfiguration) .defaultServiceExceptionSupplier(SageMakerRuntimeException::builder) .protocol(AwsJsonProtocol.REST_JSON) .protocolVersion("1.1") .registerModeledException( ExceptionMetadata.builder().errorCode("ModelStreamError") .exceptionBuilderSupplier(ModelStreamErrorException::builder).build()) .registerModeledException( ExceptionMetadata.builder().errorCode("ModelError") .exceptionBuilderSupplier(ModelErrorException::builder).httpStatusCode(424).build()) .registerModeledException( ExceptionMetadata.builder().errorCode("InternalStreamFailure") .exceptionBuilderSupplier(InternalStreamFailureException::builder).build()) .registerModeledException( ExceptionMetadata.builder().errorCode("ValidationError") .exceptionBuilderSupplier(ValidationErrorException::builder).httpStatusCode(400).build()) .registerModeledException( ExceptionMetadata.builder().errorCode("ServiceUnavailable") .exceptionBuilderSupplier(ServiceUnavailableException::builder).httpStatusCode(503).build()) .registerModeledException( ExceptionMetadata.builder().errorCode("InternalFailure") .exceptionBuilderSupplier(InternalFailureException::builder).httpStatusCode(500).build()) .registerModeledException( ExceptionMetadata.builder().errorCode("InternalDependencyException") .exceptionBuilderSupplier(InternalDependencyException::builder).httpStatusCode(530).build()) .registerModeledException( ExceptionMetadata.builder().errorCode("ModelNotReadyException") .exceptionBuilderSupplier(ModelNotReadyException::builder).httpStatusCode(429).build()); } private static List resolveMetricPublishers(SdkClientConfiguration clientConfiguration, RequestOverrideConfiguration requestOverrideConfiguration) { List publishers = null; if (requestOverrideConfiguration != null) { publishers = 
requestOverrideConfiguration.metricPublishers(); } if (publishers == null || publishers.isEmpty()) { publishers = clientConfiguration.option(SdkClientOption.METRIC_PUBLISHERS); } if (publishers == null) { publishers = Collections.emptyList(); } return publishers; } private SdkClientConfiguration updateSdkClientConfiguration(SdkRequest request, SdkClientConfiguration clientConfiguration) { List plugins = request.overrideConfiguration().map(c -> c.plugins()).orElse(Collections.emptyList()); if (plugins.isEmpty()) { return clientConfiguration; } SdkClientConfiguration.Builder configuration = clientConfiguration.toBuilder(); SageMakerRuntimeServiceClientConfigurationBuilder serviceConfigBuilder = new SageMakerRuntimeServiceClientConfigurationBuilder( configuration); for (SdkPlugin plugin : plugins) { plugin.configureClient(serviceConfigBuilder); } return configuration.build(); } private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, JsonOperationMetadata operationMetadata) { return protocolFactory.createErrorResponseHandler(operationMetadata); } @Override public void close() { clientHandler.close(); } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy