All downloads are free. Search and download functionality uses the official Maven repository.

com.bazaarvoice.ostrich.pool.ServicePool Maven / Gradle / Ivy

package com.bazaarvoice.ostrich.pool;

import com.bazaarvoice.ostrich.HealthCheckResult;
import com.bazaarvoice.ostrich.HealthCheckResults;
import com.bazaarvoice.ostrich.HostDiscovery;
import com.bazaarvoice.ostrich.LoadBalanceAlgorithm;
import com.bazaarvoice.ostrich.PartitionContext;
import com.bazaarvoice.ostrich.PartitionContextBuilder;
import com.bazaarvoice.ostrich.RetryPolicy;
import com.bazaarvoice.ostrich.ServiceCallback;
import com.bazaarvoice.ostrich.ServiceEndPoint;
import com.bazaarvoice.ostrich.ServiceFactory;
import com.bazaarvoice.ostrich.ServicePoolStatistics;
import com.bazaarvoice.ostrich.exceptions.MaxRetriesException;
import com.bazaarvoice.ostrich.exceptions.NoAvailableHostsException;
import com.bazaarvoice.ostrich.exceptions.NoCachedInstancesAvailableException;
import com.bazaarvoice.ostrich.exceptions.NoSuitableHostsException;
import com.bazaarvoice.ostrich.exceptions.OnlyBadHostsException;
import com.bazaarvoice.ostrich.healthcheck.DefaultHealthCheckResults;
import com.bazaarvoice.ostrich.metrics.Metrics;
import com.bazaarvoice.ostrich.partition.PartitionFilter;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Objects;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.base.Stopwatch;
import com.google.common.base.Throwables;
import com.google.common.base.Ticker;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.yammer.metrics.core.Gauge;
import com.yammer.metrics.core.Meter;
import com.yammer.metrics.core.Timer;
import com.yammer.metrics.core.TimerContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Set;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

import static com.google.common.base.Preconditions.checkNotNull;

class ServicePool implements com.bazaarvoice.ostrich.ServicePool {
    private static final Logger LOG = LoggerFactory.getLogger(ServicePool.class);

    // By default check every minute to see if a previously unhealthy end point has become healthy.
    @VisibleForTesting
    static final long HEALTH_CHECK_POLL_INTERVAL_IN_SECONDS = 60;

    private final Ticker _ticker;
    private final HostDiscovery _hostDiscovery;
    private final boolean _cleanupHostDiscoveryOnClose;
    private final HostDiscovery.EndPointListener _hostDiscoveryListener;
    private final ServiceFactory _serviceFactory;
    private final ScheduledExecutorService _healthCheckExecutor;
    private final boolean _shutdownHealthCheckExecutorOnClose;
    private final PartitionFilter _partitionFilter;
    private final LoadBalanceAlgorithm _loadBalanceAlgorithm;
    private final ServicePoolStatistics _servicePoolStatistics;
    private final Set _badEndPoints;
    private final Predicate _badEndPointFilter;
    private final Set _recentlyRemovedEndPoints;
    private final Future _batchHealthChecksFuture;
    private final ServiceCache _serviceCache;
    private final Metrics _metrics;
    private final Timer _callbackExecutionTime;
    private final Timer _healthCheckTime;
    private final Meter _numExecuteSuccesses;
    private final Meter _numExecuteAttemptFailures;

    ServicePool(Ticker ticker, HostDiscovery hostDiscovery, boolean cleanupHostDiscoveryOnClose,
                ServiceFactory serviceFactory, ServiceCachingPolicy cachingPolicy,
                PartitionFilter partitionFilter, LoadBalanceAlgorithm loadBalanceAlgorithm,
                ScheduledExecutorService healthCheckExecutor, boolean shutdownHealthCheckExecutorOnClose) {
        _ticker = checkNotNull(ticker);
        _hostDiscovery = checkNotNull(hostDiscovery);
        _cleanupHostDiscoveryOnClose = cleanupHostDiscoveryOnClose;
        _serviceFactory = checkNotNull(serviceFactory);
        _healthCheckExecutor = checkNotNull(healthCheckExecutor);
        _shutdownHealthCheckExecutorOnClose = shutdownHealthCheckExecutorOnClose;
        _badEndPoints = Sets.newSetFromMap(Maps.newConcurrentMap());
        _badEndPointFilter = Predicates.not(Predicates.in(_badEndPoints));
        _recentlyRemovedEndPoints = Sets.newSetFromMap(CacheBuilder.newBuilder()
                .ticker(_ticker)
                .expireAfterWrite(10, TimeUnit.MINUTES)  // TODO: Make this a constant
                .build()
                .asMap());
        checkNotNull(cachingPolicy);
        _serviceCache = new ServiceCache(cachingPolicy, serviceFactory);
        _partitionFilter = checkNotNull(partitionFilter);
        _loadBalanceAlgorithm = checkNotNull(loadBalanceAlgorithm);

        _servicePoolStatistics = new ServicePoolStatistics() {
            @Override
            public int getNumIdleCachedInstances(ServiceEndPoint endPoint) {
                return _serviceCache.getNumIdleInstances(endPoint);
            }

            @Override
            public int getNumActiveInstances(ServiceEndPoint endPoint) {
                return _serviceCache.getNumActiveInstances(endPoint);
            }
        };

        // Watch end points as they are removed from host discovery so that we can remove them from our set of bad
        // end points as well.  This will prevent the bad end points set from growing in an unbounded fashion.
        // There is a minor race condition that could happen here, but it's not anything to be concerned about.  The
        // HostDiscovery component could lose its connection to its backing data store and then immediately regain it
        // right afterwards.  If that happens it could remove all of its end points only to re-add them right back again
        // and we will "forget" that an end point was bad and try to use it again.  This isn't fatal though because
        // we'll just rediscover that it's a bad end point again in the future.  Also in the future it might be useful
        // to measure how long an end point has been considered bad and potentially take action for end points that are
        // bad for long periods of time.
        _hostDiscoveryListener = new HostDiscovery.EndPointListener() {
            @Override
            public void onEndPointAdded(ServiceEndPoint endPoint) {
                addEndPoint(endPoint);
            }

            @Override
            public void onEndPointRemoved(ServiceEndPoint endPoint) {
                removeEndPoint(endPoint);
            }
        };
        _hostDiscovery.addListener(_hostDiscoveryListener);

        // Periodically wake up and check any bad end points to see if they're now healthy.
        _batchHealthChecksFuture = _healthCheckExecutor.scheduleAtFixedRate(new BatchHealthChecks(),
                HEALTH_CHECK_POLL_INTERVAL_IN_SECONDS, HEALTH_CHECK_POLL_INTERVAL_IN_SECONDS, TimeUnit.SECONDS);

        String serviceName = _serviceFactory.getServiceName();
        _metrics = Metrics.forInstance(this, serviceName);
        _callbackExecutionTime = _metrics.newTimer(serviceName, "callback-execution-time", TimeUnit.MILLISECONDS,
                TimeUnit.SECONDS);
        _healthCheckTime = _metrics.newTimer(serviceName, "health-check-time", TimeUnit.MILLISECONDS, TimeUnit.SECONDS);
        _numExecuteSuccesses = _metrics.newMeter(serviceName, "num-execute-successes", "successes", TimeUnit.SECONDS);
        _numExecuteAttemptFailures = _metrics.newMeter(serviceName, "num-execute-attempt-failures", "failures",
                TimeUnit.SECONDS);
        _metrics.newGauge(serviceName, "num-valid-end-points", new Gauge() {
            @Override
            public Integer value() {
                return getNumValidEndPoints();
            }
        });
        _metrics.newGauge(serviceName, "num-bad-end-points", new Gauge() {
            @Override
            public Integer value() {
                return getNumBadEndPoints();
            }
        });
    }

    @Override
    public void close() {
        _batchHealthChecksFuture.cancel(true);

        _hostDiscovery.removeListener(_hostDiscoveryListener);
        if (_cleanupHostDiscoveryOnClose) {
            try {
                _hostDiscovery.close();
            } catch (IOException e) {
                // NOP
            }
        }

        _metrics.close();

        if (_shutdownHealthCheckExecutorOnClose) {
            _healthCheckExecutor.shutdownNow();
        }
    }

    @Override
    public  R execute(RetryPolicy retry, ServiceCallback callback) {
        return execute(PartitionContextBuilder.empty(), retry, callback);
    }

    @Override
    public  R execute(PartitionContext partitionContext, RetryPolicy retry, ServiceCallback callback) {
        Stopwatch sw = new Stopwatch(_ticker).start();
        int numAttempts = 0;
        Exception lastException;

        do {
            ServiceEndPoint endPoint = chooseEndPoint(getValidEndPoints(), partitionContext);

            try {
                R result = executeOnEndPoint(endPoint, callback);
                _numExecuteSuccesses.mark();
                return result;
            } catch (Exception e) {
                _numExecuteAttemptFailures.mark();

                // Don't retry if exception is too severe.
                if (!isRetriableException(e)) {
                    throw Throwables.propagate(e);
                }

                LOG.info("Retriable exception from end point id: " + endPoint.getId(), e);
                lastException = e;
            }
        } while (retry.allowRetry(++numAttempts, sw.elapsedMillis()));

        throw new MaxRetriesException(lastException);
    }

    @Override
    public int getNumValidEndPoints() {
        return Iterables.size(_hostDiscovery.getHosts()) - _badEndPoints.size();
    }

    @Override
    public int getNumBadEndPoints() {
        return _badEndPoints.size();
    }

    /**
     * Determine the set of all {@link ServiceEndPoint}s.
     * 

* NOTE: This method is package private specifically so that {@link AsyncServicePool} can call it. */ Iterable getAllEndPoints() { Iterable hosts = _hostDiscovery.getHosts(); if (Iterables.isEmpty(hosts)) { // There were no service end points available, we have no choice but to stop trying and just exit. throw new NoAvailableHostsException(); } return hosts; } /** * Determine the set of usable {@link ServiceEndPoint}s. *

* NOTE: This method is package private specifically so that {@link AsyncServicePool} can call it. */ Iterable getValidEndPoints() { Iterable goodHosts = Iterables.filter(getAllEndPoints(), _badEndPointFilter); if (Iterables.isEmpty(goodHosts)) { // All available hosts are bad, so we must give up. throw new OnlyBadHostsException(); } return goodHosts; } private ServiceEndPoint chooseEndPoint(Iterable endPoints, PartitionContext partitionContext) { endPoints = _partitionFilter.filter(endPoints, partitionContext); if (endPoints == null || Iterables.isEmpty(endPoints)) { throw new NoSuitableHostsException(); } ServiceEndPoint endPoint = _loadBalanceAlgorithm.choose(endPoints, _servicePoolStatistics); if (endPoint == null) { throw new NoSuitableHostsException(); } return endPoint; } /** * Execute a callback on a specific end point. *

* NOTE: This method is package private specifically so that {@link AsyncServicePool} can call it. */ R executeOnEndPoint(ServiceEndPoint endPoint, ServiceCallback callback) throws Exception { S service = null; try { service = _serviceCache.checkOut(endPoint); TimerContext timer = _callbackExecutionTime.time(); try { return callback.call(service); } finally { timer.stop(); } } catch (NoCachedInstancesAvailableException e) { LOG.debug("Service cache exhausted. End point ID: " + endPoint.getId(), e); // Don't mark an end point as bad just because there are no cached end points for it. throw e; } catch (Exception e) { if (_serviceFactory.isRetriableException(e)) { // This is a known and supported exception indicating that something went wrong somewhere in the service // layer while trying to communicate with the end point. These errors are often transient, so we // enqueue a health check for the end point and mark it as unavailable for the time being. markEndPointAsBad(endPoint); LOG.debug("Bad end point discovered. End point ID: " + endPoint.getId(), e); } throw e; } finally { if (service != null) { try { _serviceCache.checkIn(endPoint, service); } catch (Exception e) { // This should never happen, but log just in case. LOG.warn("Error returning end point to cache. End point ID: " + endPoint.getId(), e); } } } } /** * Check if an exception is retriable. *

* NOTE: This method is package private specifically so that {@link AsyncServicePool} can call it. */ boolean isRetriableException(Exception exception) { return _serviceFactory.isRetriableException(exception); } /** * NOTE: This method is package private specifically so that {@link AsyncServicePool} can call it. * @return The name of the service for this pool. */ String getServiceName() { return _serviceFactory.getServiceName(); } @VisibleForTesting HostDiscovery getHostDiscovery() { return _hostDiscovery; } @VisibleForTesting PartitionFilter getPartitionFilter() { return _partitionFilter; } @VisibleForTesting LoadBalanceAlgorithm getLoadBalanceAlgorithm() { return _loadBalanceAlgorithm; } @VisibleForTesting ServicePoolStatistics getServicePoolStatistics() { return _servicePoolStatistics; } @VisibleForTesting Set getBadEndPoints() { return ImmutableSet.copyOf(_badEndPoints); } @Override public HealthCheckResults checkForHealthyEndPoint() { Set endPoints; DefaultHealthCheckResults aggregate = new DefaultHealthCheckResults(); try { // Take a snapshot of the current end points. endPoints = Sets.newHashSet(getValidEndPoints()); } catch (Exception e) { // No valid end points means no healthy end points. return aggregate; } while (!endPoints.isEmpty()) { ServiceEndPoint endPoint; try { // Prefer end points in the order the load balancer recommends. endPoint = chooseEndPoint(endPoints, PartitionContextBuilder.empty()); } catch (Exception e) { // Load balancer didn't like our end points, so just go sequentially. endPoint = endPoints.iterator().next(); } HealthCheckResult result = checkHealth(endPoint); aggregate.addHealthCheckResult(result); if (!result.isHealthy()) { Exception exception = ((FailedHealthCheckResult) result).getException(); if (exception == null || isRetriableException(exception)) { LOG.debug("Unhealthy end point discovered. 
End point ID: " + endPoint.getId()); endPoints.remove(endPoint); markEndPointAsBad(endPoint); continue; } } break; } return aggregate; } private synchronized void addEndPoint(ServiceEndPoint endPoint) { _recentlyRemovedEndPoints.remove(endPoint); _badEndPoints.remove(endPoint); LOG.debug("End point added to service pool. End point ID: " + endPoint.getId()); } private synchronized void removeEndPoint(ServiceEndPoint endPoint) { // Mark this end point as recently removed. We do this in order to keep a positive set of removed // end points so that we avoid a potential race condition where someone was using this end point while // we noticed it disappeared from host discovery. In that case there is the potential that they // would add it to the bad end points set after we've already processed the removal, thus leading to a // memory leak in the bad end points set. Having this time-limited view of the recently removed // end points ensures that this memory leak doesn't happen. _recentlyRemovedEndPoints.add(endPoint); _badEndPoints.remove(endPoint); _serviceCache.evict(endPoint); LOG.debug("End point removed from service pool. End point ID: " + endPoint.getId()); } private synchronized void markEndPointAsBad(ServiceEndPoint endPoint) { if (_recentlyRemovedEndPoints.contains(endPoint)) { // Nothing to do, we've already removed this end point return; } _serviceCache.evict(endPoint); // Only schedule a health check if this is the first time we've seen this end point as bad... if (_badEndPoints.add(endPoint)) { _healthCheckExecutor.submit(new HealthCheck(endPoint)); } } @VisibleForTesting HealthCheckResult checkHealth(ServiceEndPoint endPoint) { // We have to be very careful to not allow any exceptions to make it out of of this method, if they do then // subsequent scheduled invocations of the Runnable may not happen, and we could stop checking health checks // completely. So we intentionally handle all possible exceptions here. 
Stopwatch sw = new Stopwatch(_ticker).start(); try { return _serviceFactory.isHealthy(endPoint) ? new SuccessfulHealthCheckResult(endPoint.getId(), sw.stop().elapsedTime(TimeUnit.NANOSECONDS)) : new FailedHealthCheckResult(endPoint.getId(), sw.stop().elapsedTime(TimeUnit.NANOSECONDS)); } catch (Exception e) { return new FailedHealthCheckResult(endPoint.getId(), sw.stop().elapsedTime(TimeUnit.NANOSECONDS), e); } finally { _healthCheckTime.update(sw.elapsedTime(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS); } } @VisibleForTesting final class HealthCheck implements Runnable { private final ServiceEndPoint _endPoint; public HealthCheck(ServiceEndPoint endPoint) { _endPoint = endPoint; } @Override public void run() { HealthCheckResult result = checkHealth(_endPoint); if (result.isHealthy()) { _badEndPoints.remove(_endPoint); } } } @VisibleForTesting final class BatchHealthChecks implements Runnable { @Override public void run() { for (ServiceEndPoint endPoint : _badEndPoints) { HealthCheckResult result = checkHealth(endPoint); if (result.isHealthy()) { _badEndPoints.remove(endPoint); } // If we were interrupted during checking the health (but weren't blocked so an InterruptedException // couldn't be thrown), then we should exit now. 
if (Thread.currentThread().isInterrupted()) { break; } } } } private static final class SuccessfulHealthCheckResult implements HealthCheckResult { private final String _endPointId; private final long _responseTimeInNanos; public SuccessfulHealthCheckResult(String endPointId, long responseTimeInNanos) { _endPointId = endPointId; _responseTimeInNanos = responseTimeInNanos; } @Override public boolean isHealthy() { return true; } @Override public String getEndPointId() { return _endPointId; } @Override public long getResponseTime(TimeUnit unit) { return unit.convert(_responseTimeInNanos, TimeUnit.NANOSECONDS); } @Override public String toString() { return Objects.toStringHelper(this) .add("endPointId", _endPointId) .toString(); } } private static final class FailedHealthCheckResult implements HealthCheckResult { private final String _endPointId; private final long _responseTimeInNanos; private final Exception _exception; public FailedHealthCheckResult(String endPointId, long responseTimeInNanos, Exception exception) { _endPointId = endPointId; _responseTimeInNanos = responseTimeInNanos; _exception = exception; } public FailedHealthCheckResult(String endPointId, long responseTimeInNanos) { this(endPointId, responseTimeInNanos, null); } @Override public boolean isHealthy() { return false; } @Override public String getEndPointId() { return _endPointId; } @Override public long getResponseTime(TimeUnit unit) { return unit.convert(_responseTimeInNanos, TimeUnit.NANOSECONDS); } public Exception getException() { return _exception; } @Override public String toString() { return Objects.toStringHelper(this) .add("endPointId", _endPointId) .add("exception", _exception) .toString(); } } }