// Class: software.amazon.awssdk.services.sagemakerruntime.DefaultSageMakerRuntimeAsyncClient
// Artifact: sagemakerruntime (available via Maven / Gradle / Ivy)
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with
* the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package software.amazon.awssdk.services.sagemakerruntime;
import static software.amazon.awssdk.utils.FunctionalUtils.runAndLogError;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.annotations.Generated;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler;
import software.amazon.awssdk.awscore.eventstream.EventStreamAsyncResponseTransformer;
import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionPojoSupplier;
import software.amazon.awssdk.awscore.eventstream.RestEventStreamAsyncResponseTransformer;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata;
import software.amazon.awssdk.awscore.internal.AwsServiceProtocol;
import software.amazon.awssdk.core.RequestOverrideConfiguration;
import software.amazon.awssdk.core.SdkPlugin;
import software.amazon.awssdk.core.SdkPojoBuilder;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.SdkResponse;
import software.amazon.awssdk.core.client.config.SdkAdvancedAsyncClientOption;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
import software.amazon.awssdk.core.client.handler.AsyncClientHandler;
import software.amazon.awssdk.core.client.handler.AttachHttpMetadataResponseHandler;
import software.amazon.awssdk.core.client.handler.ClientExecutionParams;
import software.amazon.awssdk.core.http.HttpResponseHandler;
import software.amazon.awssdk.core.metrics.CoreMetric;
import software.amazon.awssdk.core.protocol.VoidSdkResponse;
import software.amazon.awssdk.metrics.MetricCollector;
import software.amazon.awssdk.metrics.MetricPublisher;
import software.amazon.awssdk.metrics.NoOpMetricCollector;
import software.amazon.awssdk.protocols.core.ExceptionMetadata;
import software.amazon.awssdk.protocols.json.AwsJsonProtocol;
import software.amazon.awssdk.protocols.json.AwsJsonProtocolFactory;
import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory;
import software.amazon.awssdk.protocols.json.JsonOperationMetadata;
import software.amazon.awssdk.services.sagemakerruntime.internal.SageMakerRuntimeServiceClientConfigurationBuilder;
import software.amazon.awssdk.services.sagemakerruntime.model.InternalDependencyException;
import software.amazon.awssdk.services.sagemakerruntime.model.InternalFailureException;
import software.amazon.awssdk.services.sagemakerruntime.model.InternalStreamFailureException;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointAsyncRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointAsyncResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
import software.amazon.awssdk.services.sagemakerruntime.model.ModelErrorException;
import software.amazon.awssdk.services.sagemakerruntime.model.ModelNotReadyException;
import software.amazon.awssdk.services.sagemakerruntime.model.ModelStreamErrorException;
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;
import software.amazon.awssdk.services.sagemakerruntime.model.SageMakerRuntimeException;
import software.amazon.awssdk.services.sagemakerruntime.model.ServiceUnavailableException;
import software.amazon.awssdk.services.sagemakerruntime.model.ValidationErrorException;
import software.amazon.awssdk.services.sagemakerruntime.transform.InvokeEndpointAsyncRequestMarshaller;
import software.amazon.awssdk.services.sagemakerruntime.transform.InvokeEndpointRequestMarshaller;
import software.amazon.awssdk.services.sagemakerruntime.transform.InvokeEndpointWithResponseStreamRequestMarshaller;
import software.amazon.awssdk.utils.CompletableFutureUtils;
/**
* Internal implementation of {@link SageMakerRuntimeAsyncClient}.
*
* @see SageMakerRuntimeAsyncClient#builder()
*/
@Generated("software.amazon.awssdk:codegen")
@SdkInternalApi
final class DefaultSageMakerRuntimeAsyncClient implements SageMakerRuntimeAsyncClient {
// Logger used when reporting failures raised by response-handler callbacks.
private static final Logger log = LoggerFactory.getLogger(DefaultSageMakerRuntimeAsyncClient.class);
// Protocol metadata attached to every request's execution parameters (REST-JSON).
private static final AwsProtocolMetadata protocolMetadata = AwsProtocolMetadata.builder()
.serviceProtocol(AwsServiceProtocol.REST_JSON).build();
// Executes the marshalled requests; released in close().
private final AsyncClientHandler clientHandler;
// Creates the success and error response handlers used by each operation.
private final AwsJsonProtocolFactory protocolFactory;
// Client-level configuration; request-level plugins may be applied on top of it per call.
private final SdkClientConfiguration clientConfiguration;
// Executor passed to the event-stream response transformer.
private final Executor executor;
/**
 * Creates the client from an SDK client configuration.
 *
 * @param clientConfiguration configuration used to build the request handler and the protocol factory, and from
 *        which the {@code FUTURE_COMPLETION_EXECUTOR} option is read
 */
protected DefaultSageMakerRuntimeAsyncClient(SdkClientConfiguration clientConfiguration) {
this.clientHandler = new AwsAsyncClientHandler(clientConfiguration);
this.clientConfiguration = clientConfiguration;
// init(...) reads this.clientConfiguration, so the field must already be assigned before this line runs.
this.protocolFactory = init(AwsJsonProtocolFactory.builder()).build();
this.executor = clientConfiguration.option(SdkAdvancedAsyncClientOption.FUTURE_COMPLETION_EXECUTOR);
}
/**
 * <p>
 * After you deploy a model into production using Amazon SageMaker hosting services, your client applications use
 * this API to get inferences from the model hosted at the specified endpoint.
 * </p>
 * <p>
 * For an overview of Amazon SageMaker, see How It Works.
 * </p>
 * <p>
 * Amazon SageMaker strips all POST headers except those supported by the API. Amazon SageMaker might add additional
 * headers. You should not rely on the behavior of headers outside those enumerated in the request syntax.
 * </p>
 * <p>
 * Calls to {@code InvokeEndpoint} are authenticated by using Amazon Web Services Signature Version 4. For
 * information, see Authenticating Requests (Amazon Web Services Signature Version 4) in the <i>Amazon S3 API
 * Reference</i>.
 * </p>
 * <p>
 * A customer's model containers must respond to requests within 60 seconds. The model itself can have a maximum
 * processing time of 60 seconds before responding to invocations. If your model is going to take 50-60 seconds of
 * processing time, the SDK socket timeout should be set to be 70 seconds.
 * </p>
 * <p>
 * Endpoints are scoped to an individual account, and are not public. The URL does not contain the account ID, but
 * Amazon SageMaker determines the account ID from the authentication token that is supplied by the caller.
 * </p>
 *
 * @param invokeEndpointRequest
 *        the request to send
 * @return A Java Future containing the result of the InvokeEndpoint operation returned by the service.<br/>
 *         The CompletableFuture returned by this method can be completed exceptionally with the following
 *         exceptions.
 *         <ul>
 *         <li>InternalFailureException An internal failure occurred.</li>
 *         <li>ServiceUnavailableException The service is unavailable. Try your call again.</li>
 *         <li>ValidationErrorException Inspect your request and try again.</li>
 *         <li>ModelErrorException Model (owned by the customer in the container) returned 4xx or 5xx error code.</li>
 *         <li>InternalDependencyException Your request caused an exception with an internal dependency. Contact
 *         customer support.</li>
 *         <li>ModelNotReadyException Either a serverless endpoint variant's resources are still being provisioned,
 *         or a multi-model endpoint is still downloading or loading the target model. Wait and try your request
 *         again.</li>
 *         <li>SdkException Base class for all exceptions that can be thrown by the SDK (both service and client).
 *         Can be used for catch all scenarios.</li>
 *         <li>SdkClientException If any client side error occurs such as an IO related failure, failure to get
 *         credentials, etc.</li>
 *         <li>SageMakerRuntimeException Base class for all service exceptions. Unknown exceptions will be thrown as
 *         an instance of this type.</li>
 *         </ul>
 * @sample SageMakerRuntimeAsyncClient.InvokeEndpoint
 * @see "AWS API Documentation"
 */
@Override
public CompletableFuture<InvokeEndpointResponse> invokeEndpoint(InvokeEndpointRequest invokeEndpointRequest) {
    // Apply any request-level plugins on top of the client-level configuration.
    SdkClientConfiguration clientConfiguration = updateSdkClientConfiguration(invokeEndpointRequest, this.clientConfiguration);
    List<MetricPublisher> metricPublishers = resolveMetricPublishers(clientConfiguration, invokeEndpointRequest
            .overrideConfiguration().orElse(null));
    // Only collect metrics when there is at least one publisher to receive them.
    MetricCollector apiCallMetricCollector = metricPublishers.isEmpty() ? NoOpMetricCollector.create() : MetricCollector
            .create("ApiCall");
    try {
        apiCallMetricCollector.reportMetric(CoreMetric.SERVICE_ID, "SageMaker Runtime");
        apiCallMetricCollector.reportMetric(CoreMetric.OPERATION_NAME, "InvokeEndpoint");
        JsonOperationMetadata operationMetadata = JsonOperationMetadata.builder().hasStreamingSuccessResponse(false)
                .isPayloadJson(false).build();

        HttpResponseHandler<InvokeEndpointResponse> responseHandler = protocolFactory.createResponseHandler(
                operationMetadata, InvokeEndpointResponse::builder);

        HttpResponseHandler<AwsServiceException> errorResponseHandler = createErrorResponseHandler(protocolFactory,
                operationMetadata);

        CompletableFuture<InvokeEndpointResponse> executeFuture = clientHandler
                .execute(new ClientExecutionParams<InvokeEndpointRequest, InvokeEndpointResponse>()
                        .withOperationName("InvokeEndpoint").withProtocolMetadata(protocolMetadata)
                        .withMarshaller(new InvokeEndpointRequestMarshaller(protocolFactory))
                        .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler)
                        .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector)
                        .withInput(invokeEndpointRequest));
        // Publish the collected metrics once the call finishes, whether it succeeded or failed.
        CompletableFuture<InvokeEndpointResponse> whenCompleted = executeFuture.whenComplete((r, e) -> {
            metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect()));
        });
        // NOTE(review): presumably links failure/cancellation of the returned future back to the underlying
        // execution - confirm against CompletableFutureUtils.forwardExceptionTo.
        executeFuture = CompletableFutureUtils.forwardExceptionTo(whenCompleted, executeFuture);
        return executeFuture;
    } catch (Throwable t) {
        // Failure while setting up the call: still publish metrics, then report the error through the future.
        metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect()));
        return CompletableFutureUtils.failedFuture(t);
    }
}
/**
 * <p>
 * After you deploy a model into production using Amazon SageMaker hosting services, your client applications use
 * this API to get inferences from the model hosted at the specified endpoint in an asynchronous manner.
 * </p>
 * <p>
 * Inference requests sent to this API are enqueued for asynchronous processing. The processing of the inference
 * request may or may not complete before you receive a response from this API. The response from this API will not
 * contain the result of the inference request but contain information about where you can locate it.
 * </p>
 * <p>
 * Amazon SageMaker strips all POST headers except those supported by the API. Amazon SageMaker might add additional
 * headers. You should not rely on the behavior of headers outside those enumerated in the request syntax.
 * </p>
 * <p>
 * Calls to {@code InvokeEndpointAsync} are authenticated by using Amazon Web Services Signature Version 4. For
 * information, see Authenticating Requests (Amazon Web Services Signature Version 4) in the <i>Amazon S3 API
 * Reference</i>.
 * </p>
 *
 * @param invokeEndpointAsyncRequest
 *        the request to send
 * @return A Java Future containing the result of the InvokeEndpointAsync operation returned by the service.<br/>
 *         The CompletableFuture returned by this method can be completed exceptionally with the following
 *         exceptions.
 *         <ul>
 *         <li>InternalFailureException An internal failure occurred.</li>
 *         <li>ServiceUnavailableException The service is unavailable. Try your call again.</li>
 *         <li>ValidationErrorException Inspect your request and try again.</li>
 *         <li>SdkException Base class for all exceptions that can be thrown by the SDK (both service and client).
 *         Can be used for catch all scenarios.</li>
 *         <li>SdkClientException If any client side error occurs such as an IO related failure, failure to get
 *         credentials, etc.</li>
 *         <li>SageMakerRuntimeException Base class for all service exceptions. Unknown exceptions will be thrown as
 *         an instance of this type.</li>
 *         </ul>
 * @sample SageMakerRuntimeAsyncClient.InvokeEndpointAsync
 * @see "AWS API Documentation"
 */
@Override
public CompletableFuture<InvokeEndpointAsyncResponse> invokeEndpointAsync(
        InvokeEndpointAsyncRequest invokeEndpointAsyncRequest) {
    // Apply any request-level plugins on top of the client-level configuration.
    SdkClientConfiguration clientConfiguration = updateSdkClientConfiguration(invokeEndpointAsyncRequest,
            this.clientConfiguration);
    List<MetricPublisher> metricPublishers = resolveMetricPublishers(clientConfiguration, invokeEndpointAsyncRequest
            .overrideConfiguration().orElse(null));
    // Only collect metrics when there is at least one publisher to receive them.
    MetricCollector apiCallMetricCollector = metricPublishers.isEmpty() ? NoOpMetricCollector.create() : MetricCollector
            .create("ApiCall");
    try {
        apiCallMetricCollector.reportMetric(CoreMetric.SERVICE_ID, "SageMaker Runtime");
        apiCallMetricCollector.reportMetric(CoreMetric.OPERATION_NAME, "InvokeEndpointAsync");
        JsonOperationMetadata operationMetadata = JsonOperationMetadata.builder().hasStreamingSuccessResponse(false)
                .isPayloadJson(true).build();

        HttpResponseHandler<InvokeEndpointAsyncResponse> responseHandler = protocolFactory.createResponseHandler(
                operationMetadata, InvokeEndpointAsyncResponse::builder);

        HttpResponseHandler<AwsServiceException> errorResponseHandler = createErrorResponseHandler(protocolFactory,
                operationMetadata);

        CompletableFuture<InvokeEndpointAsyncResponse> executeFuture = clientHandler
                .execute(new ClientExecutionParams<InvokeEndpointAsyncRequest, InvokeEndpointAsyncResponse>()
                        .withOperationName("InvokeEndpointAsync").withProtocolMetadata(protocolMetadata)
                        .withMarshaller(new InvokeEndpointAsyncRequestMarshaller(protocolFactory))
                        .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler)
                        .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector)
                        .withInput(invokeEndpointAsyncRequest));
        // Publish the collected metrics once the call finishes, whether it succeeded or failed.
        CompletableFuture<InvokeEndpointAsyncResponse> whenCompleted = executeFuture.whenComplete((r, e) -> {
            metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect()));
        });
        // NOTE(review): presumably links failure/cancellation of the returned future back to the underlying
        // execution - confirm against CompletableFutureUtils.forwardExceptionTo.
        executeFuture = CompletableFutureUtils.forwardExceptionTo(whenCompleted, executeFuture);
        return executeFuture;
    } catch (Throwable t) {
        // Failure while setting up the call: still publish metrics, then report the error through the future.
        metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect()));
        return CompletableFutureUtils.failedFuture(t);
    }
}
/**
*
* Invokes a model at the specified endpoint to return the inference response as a stream. The inference stream
* provides the response payload incrementally as a series of parts. Before you can get an inference stream, you
* must have access to a model that's deployed using Amazon SageMaker hosting services, and the container for that
* model must support inference streaming.
*
*
* For more information that can help you use this API, see the following sections in the Amazon SageMaker
* Developer Guide:
*
*
* -
*
* For information about how to add streaming support to a model, see How Containers Serve Requests.
*
*
* -
*
* For information about how to process the streaming response, see Invoke real-time
* endpoints.
*
*
*
*
* Amazon SageMaker strips all POST headers except those supported by the API. Amazon SageMaker might add additional
* headers. You should not rely on the behavior of headers outside those enumerated in the request syntax.
*
*
* Calls to InvokeEndpointWithResponseStream
are authenticated by using Amazon Web Services Signature
* Version 4. For information, see Authenticating
* Requests (Amazon Web Services Signature Version 4) in the Amazon S3 API Reference.
*
*
* @param invokeEndpointWithResponseStreamRequest
* @return A Java Future containing the result of the InvokeEndpointWithResponseStream operation returned by the
* service.
* The CompletableFuture returned by this method can be completed exceptionally with the following
* exceptions.
*
* - InternalFailureException An internal failure occurred.
* - ServiceUnavailableException The service is unavailable. Try your call again.
* - ValidationErrorException Inspect your request and try again.
* - ModelErrorException Model (owned by the customer in the container) returned 4xx or 5xx error code.
* - ModelStreamErrorException An error occurred while streaming the response body. This error can have
* the following error codes:
*
* - ModelInvocationTimeExceeded
* -
*
* The model failed to finish sending the response within the timeout period allowed by Amazon SageMaker.
*
*
* - StreamBroken
* -
*
* The Transmission Control Protocol (TCP) connection between the client and the model was reset or closed.
*
*
* - InternalStreamFailureException The stream processing failed because of an unknown error, exception or
* failure. Try your request again.
* - SdkException Base class for all exceptions that can be thrown by the SDK (both service and client).
* Can be used for catch all scenarios.
* - SdkClientException If any client side error occurs such as an IO related failure, failure to get
* credentials, etc.
* - SageMakerRuntimeException Base class for all service exceptions. Unknown exceptions will be thrown as
* an instance of this type.
*
* @sample SageMakerRuntimeAsyncClient.InvokeEndpointWithResponseStream
* @see AWS API Documentation
*/
@Override
public CompletableFuture invokeEndpointWithResponseStream(
InvokeEndpointWithResponseStreamRequest invokeEndpointWithResponseStreamRequest,
InvokeEndpointWithResponseStreamResponseHandler asyncResponseHandler) {
SdkClientConfiguration clientConfiguration = updateSdkClientConfiguration(invokeEndpointWithResponseStreamRequest,
this.clientConfiguration);
List metricPublishers = resolveMetricPublishers(clientConfiguration,
invokeEndpointWithResponseStreamRequest.overrideConfiguration().orElse(null));
MetricCollector apiCallMetricCollector = metricPublishers.isEmpty() ? NoOpMetricCollector.create() : MetricCollector
.create("ApiCall");
try {
apiCallMetricCollector.reportMetric(CoreMetric.SERVICE_ID, "SageMaker Runtime");
apiCallMetricCollector.reportMetric(CoreMetric.OPERATION_NAME, "InvokeEndpointWithResponseStream");
JsonOperationMetadata operationMetadata = JsonOperationMetadata.builder().hasStreamingSuccessResponse(false)
.isPayloadJson(true).build();
HttpResponseHandler responseHandler = new AttachHttpMetadataResponseHandler(
protocolFactory.createResponseHandler(operationMetadata, InvokeEndpointWithResponseStreamResponse::builder));
HttpResponseHandler voidResponseHandler = protocolFactory.createResponseHandler(JsonOperationMetadata
.builder().isPayloadJson(false).hasStreamingSuccessResponse(true).build(), VoidSdkResponse::builder);
HttpResponseHandler extends ResponseStream> eventResponseHandler = protocolFactory.createResponseHandler(
JsonOperationMetadata.builder().isPayloadJson(true).hasStreamingSuccessResponse(false).build(),
EventStreamTaggedUnionPojoSupplier.builder()
.putSdkPojoSupplier("PayloadPart", ResponseStream::payloadPartBuilder)
.defaultSdkPojoSupplier(() -> new SdkPojoBuilder(ResponseStream.UNKNOWN)).build());
HttpResponseHandler errorResponseHandler = createErrorResponseHandler(protocolFactory,
operationMetadata);
CompletableFuture future = new CompletableFuture<>();
EventStreamAsyncResponseTransformer asyncResponseTransformer = EventStreamAsyncResponseTransformer
. builder()
.eventStreamResponseHandler(asyncResponseHandler).eventResponseHandler(eventResponseHandler)
.initialResponseHandler(responseHandler).exceptionResponseHandler(errorResponseHandler).future(future)
.executor(executor).serviceName(serviceName()).build();
RestEventStreamAsyncResponseTransformer restAsyncResponseTransformer = RestEventStreamAsyncResponseTransformer
. builder()
.eventStreamAsyncResponseTransformer(asyncResponseTransformer)
.eventStreamResponseHandler(asyncResponseHandler).build();
CompletableFuture executeFuture = clientHandler
.execute(
new ClientExecutionParams()
.withOperationName("InvokeEndpointWithResponseStream").withProtocolMetadata(protocolMetadata)
.withMarshaller(new InvokeEndpointWithResponseStreamRequestMarshaller(protocolFactory))
.withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler)
.withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector)
.withInput(invokeEndpointWithResponseStreamRequest), restAsyncResponseTransformer);
CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> {
if (e != null) {
try {
asyncResponseHandler.exceptionOccurred(e);
} finally {
future.completeExceptionally(e);
}
}
metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect()));
});
executeFuture = CompletableFutureUtils.forwardExceptionTo(whenCompleted, executeFuture);
return CompletableFutureUtils.forwardExceptionTo(future, executeFuture);
} catch (Throwable t) {
runAndLogError(log, "Exception thrown in exceptionOccurred callback, ignoring",
() -> asyncResponseHandler.exceptionOccurred(t));
metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect()));
return CompletableFutureUtils.failedFuture(t);
}
}
/**
 * Builds the service-level view of this client's configuration.
 *
 * @return the configuration produced by a {@link SageMakerRuntimeServiceClientConfigurationBuilder} that wraps
 *         {@code clientConfiguration.toBuilder()}
 */
@Override
public final SageMakerRuntimeServiceClientConfiguration serviceClientConfiguration() {
    SageMakerRuntimeServiceClientConfigurationBuilder serviceConfigBuilder =
            new SageMakerRuntimeServiceClientConfigurationBuilder(clientConfiguration.toBuilder());
    return serviceConfigBuilder.build();
}
/**
 * Returns the service name used by this client.
 *
 * @return the {@code SERVICE_NAME} constant (not declared in this class; presumably inherited from
 *         {@link SageMakerRuntimeAsyncClient} - confirm)
 */
@Override
public final String serviceName() {
return SERVICE_NAME;
}
private > T init(T builder) {
return builder
.clientConfiguration(clientConfiguration)
.defaultServiceExceptionSupplier(SageMakerRuntimeException::builder)
.protocol(AwsJsonProtocol.REST_JSON)
.protocolVersion("1.1")
.registerModeledException(
ExceptionMetadata.builder().errorCode("ModelStreamError")
.exceptionBuilderSupplier(ModelStreamErrorException::builder).build())
.registerModeledException(
ExceptionMetadata.builder().errorCode("ModelError")
.exceptionBuilderSupplier(ModelErrorException::builder).httpStatusCode(424).build())
.registerModeledException(
ExceptionMetadata.builder().errorCode("InternalStreamFailure")
.exceptionBuilderSupplier(InternalStreamFailureException::builder).build())
.registerModeledException(
ExceptionMetadata.builder().errorCode("ValidationError")
.exceptionBuilderSupplier(ValidationErrorException::builder).httpStatusCode(400).build())
.registerModeledException(
ExceptionMetadata.builder().errorCode("ServiceUnavailable")
.exceptionBuilderSupplier(ServiceUnavailableException::builder).httpStatusCode(503).build())
.registerModeledException(
ExceptionMetadata.builder().errorCode("InternalFailure")
.exceptionBuilderSupplier(InternalFailureException::builder).httpStatusCode(500).build())
.registerModeledException(
ExceptionMetadata.builder().errorCode("InternalDependencyException")
.exceptionBuilderSupplier(InternalDependencyException::builder).httpStatusCode(530).build())
.registerModeledException(
ExceptionMetadata.builder().errorCode("ModelNotReadyException")
.exceptionBuilderSupplier(ModelNotReadyException::builder).httpStatusCode(429).build());
}
/**
 * Chooses the metric publishers for a single call.
 * <p>
 * Publishers from the request-level override are used when that list is non-empty; otherwise the client-level
 * {@code METRIC_PUBLISHERS} option is used. If neither yields a list, an empty list is returned.
 *
 * @param clientConfiguration
 *        configuration consulted for the client-level publishers
 * @param requestOverrideConfiguration
 *        request-level override, or {@code null} when the request has none
 * @return the publishers to use; never {@code null}
 */
private static List<MetricPublisher> resolveMetricPublishers(SdkClientConfiguration clientConfiguration,
        RequestOverrideConfiguration requestOverrideConfiguration) {
    List<MetricPublisher> publishers = null;
    if (requestOverrideConfiguration != null) {
        publishers = requestOverrideConfiguration.metricPublishers();
    }
    // An empty request-level list is treated the same as "not set" and falls back to the client option.
    if (publishers == null || publishers.isEmpty()) {
        publishers = clientConfiguration.option(SdkClientOption.METRIC_PUBLISHERS);
    }
    if (publishers == null) {
        publishers = Collections.emptyList();
    }
    return publishers;
}
/**
 * Applies the plugins attached to a request's override configuration on top of the given client configuration.
 *
 * @param request
 *        the request whose override configuration may carry plugins
 * @param clientConfiguration
 *        the configuration to start from
 * @return {@code clientConfiguration} itself when the request carries no plugins; otherwise a new configuration
 *         built after every plugin has configured the service configuration builder
 */
private SdkClientConfiguration updateSdkClientConfiguration(SdkRequest request, SdkClientConfiguration clientConfiguration) {
    List<SdkPlugin> plugins = request.overrideConfiguration().map(c -> c.plugins()).orElse(Collections.emptyList());
    if (plugins.isEmpty()) {
        // Nothing to apply: reuse the existing configuration instead of rebuilding it.
        return clientConfiguration;
    }
    SdkClientConfiguration.Builder configuration = clientConfiguration.toBuilder();
    // The service-level builder wraps the same underlying builder, so plugin changes land in 'configuration'.
    SageMakerRuntimeServiceClientConfigurationBuilder serviceConfigBuilder = new SageMakerRuntimeServiceClientConfigurationBuilder(
            configuration);
    for (SdkPlugin plugin : plugins) {
        plugin.configureClient(serviceConfigBuilder);
    }
    return configuration.build();
}
/**
 * Creates the handler that turns an error response into a service exception.
 *
 * @param protocolFactory
 *        factory that creates the handler
 * @param operationMetadata
 *        metadata of the operation the handler is created for
 * @return the error response handler created by {@code protocolFactory} for {@code operationMetadata}
 */
private HttpResponseHandler<AwsServiceException> createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory,
        JsonOperationMetadata operationMetadata) {
    return protocolFactory.createErrorResponseHandler(operationMetadata);
}
/**
 * Closes this client by closing the underlying request handler.
 */
@Override
public void close() {
clientHandler.close();
}
}