All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cohort.CohortLoader.kt Maven / Gradle / Ivy

@file:OptIn(ExperimentalApi::class)

package com.amplitude.experiment.cohort

import com.amplitude.experiment.ExperimentalApi
import com.amplitude.experiment.LocalEvaluationMetrics
import com.amplitude.experiment.util.LocalEvaluationMetricsWrapper
import com.amplitude.experiment.util.Logger
import com.amplitude.experiment.util.daemonFactory
import com.amplitude.experiment.util.wrapMetrics
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit

internal class CohortLoader(
    private val maxCohortSize: Int,
    private val cohortDownloadApi: CohortDownloadApi,
    private val cohortStorage: CohortStorage,
    // TODO: Remove once dedicated API is implemented.
    private val directCohortDownloadApi: DirectCohortDownloadApiV5? = null,
    private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper(),
) {

    private val jobs = ConcurrentHashMap>()
    private val cachedJobs = ConcurrentHashMap>()
    private val executor = ThreadPoolExecutor(
        32,
        32,
        60,
        TimeUnit.SECONDS,
        LinkedBlockingQueue(),
        daemonFactory,
    ).apply {
        allowCoreThreadTimeOut(true)
    }

    fun loadCohort(cohortId: String): CompletableFuture<*> {
        return jobs.getOrPut(cohortId) {
            CompletableFuture.runAsync({
                Logger.d("Loading cohort $cohortId")
                val cohortDescription = getCohortDescription(cohortId)
                if (shouldDownloadCohort(cohortDescription)) {
                    val cohortMembers = downloadCohort(cohortDescription)
                    cohortStorage.putCohort(cohortDescription, cohortMembers)
                }
            }, executor).whenComplete { _, _ -> jobs.remove(cohortId) }
        }
    }

    fun loadCachedCohort(cohortId: String): CompletableFuture<*> {
        return cachedJobs.getOrPut(cohortId) {
            CompletableFuture.runAsync({
                Logger.d("Loading cohort from cache $cohortId")
                // Cached cohorts should be refreshed, so set last computed to 0.
                val cohortDescription = getCohortDescription(cohortId).copy(lastComputed = 0)
                if (shouldDownloadCohort(cohortDescription)) {
                    val cohortMembers = downloadCachedCohort(cohortDescription)
                    cohortStorage.putCohort(cohortDescription, cohortMembers)
                }
            }, executor).whenComplete { _, _ -> cachedJobs.remove(cohortId) }
        }
    }

    private fun getCohortDescription(cohortIds: String): CohortDescription {
        return wrapMetrics(
            metric = metrics::onCohortDescriptionsFetch,
            failure = metrics::onCohortDescriptionsFetchFailure,
        ) {
            cohortDownloadApi.getCohortDescription(cohortIds)
        }
    }

    private fun shouldDownloadCohort(cohortDescription: CohortDescription): Boolean {
        val storageDescription = cohortStorage.getCohortDescription(cohortDescription.id)
        return cohortDescription.size <= maxCohortSize &&
            cohortDescription.lastComputed > (storageDescription?.lastComputed ?: -1)
    }

    private fun downloadCohort(cohortDescription: CohortDescription): Set {
        return wrapMetrics(
            metric = metrics::onCohortDownload,
            failure = metrics::onCohortDownloadFailure,
        ) {
            try {
                cohortDownloadApi.getCohortMembers(cohortDescription)
            } catch (e: CachedCohortDownloadException) {
                metrics.onCohortDownloadFailureCachedFallback(e.cause)
                e.members
            }
        }
    }

    private fun downloadCachedCohort(cohortDescription: CohortDescription): Set {
        return wrapMetrics(
            metric = metrics::onCohortDownload,
            failure = metrics::onCohortDownloadFailure,
        ) {
            directCohortDownloadApi?.getCachedCohortMembers(cohortDescription.id, cohortDescription.groupType)
                ?: cohortDownloadApi.getCohortMembers(cohortDescription)
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy