
// cohort.CohortService.kt — Maven / Gradle / Ivy
package com.amplitude.experiment.cohort
import com.amplitude.experiment.util.Logger
import com.amplitude.experiment.util.Once
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
/** Default for [CohortServiceConfig.maxCohortSize]. */
internal const val DEFAULT_MAX_COHORT_SIZE: Int = 15_000

/** Default for [CohortServiceConfig.cohortSyncIntervalSeconds], expressed in seconds. */
internal const val DEFAULT_SYNC_INTERVAL_SECONDS: Long = 60L
/**
 * Tuning options for the cohort service.
 *
 * @property maxCohortSize size threshold used when deciding which cohorts to download; the
 *   service in this file only downloads cohorts whose reported size is strictly below it.
 * @property cohortSyncIntervalSeconds number of seconds used both as the initial delay and as
 *   the delay between periodic cohort refreshes.
 */
internal data class CohortServiceConfig(
    val maxCohortSize: Int = DEFAULT_MAX_COHORT_SIZE,
    val cohortSyncIntervalSeconds: Long = DEFAULT_SYNC_INTERVAL_SECONDS,
)
/**
 * Service that keeps cohorts up to date and answers cohort-membership lookups for a user.
 *
 * NOTE(review): `Set` appears here without type arguments; they were presumably lost when this
 * listing was extracted — confirm the element types against the upstream source.
 */
internal interface CohortService {
/**
 * Starts the service. The implementation in this file performs one refresh immediately and
 * then schedules periodic refreshes.
 */
fun start()
/** Stops the service. The implementation in this file shuts down its polling executor. */
fun stop()
/**
 * Refreshes cohorts.
 *
 * @param cohortIds cohort ids to restrict the refresh to; when empty (the default), the
 *   implementation in this file falls back to the ids supplied by its cohort id provider.
 */
fun refresh(cohortIds: Set = setOf())
/**
 * Returns the cohorts recorded for the given user.
 *
 * @param userId the user to look up.
 */
fun getCohorts(userId: String): Set
}
/**
 * [CohortService] implementation that fetches cohorts through a [CohortApi] and keeps them in a
 * [CohortStorage].
 *
 * A refresh runs four steps in order: fetch the cohort descriptions, filter them down to the
 * ones worth downloading, download those cohorts, and store the results. [start] runs one
 * refresh on the calling thread and then repeats it on a single-threaded scheduled executor.
 *
 * NOTE(review): several `Set`/`List` types below appear without type arguments; they were
 * presumably lost when this listing was extracted — confirm against the upstream source.
 *
 * @param config size limit and polling interval.
 * @param cohortApi remote API used to list and download cohorts.
 * @param cohortStorage local store that is read for filtering and written with downloads.
 * @param cohortIdProvider supplies the managed cohort ids used when [refresh] is called
 *   without explicit ids.
 */
internal class CohortServiceImpl(
    private val config: CohortServiceConfig,
    private val cohortApi: CohortApi,
    private val cohortStorage: CohortStorage,
    private val cohortIdProvider: CohortIdProvider,
) : CohortService {

    // Guards start(); presumably runs its block at most once — confirm against Once.
    private val lock = Once()
    private val poller = Executors.newSingleThreadScheduledExecutor()

    /**
     * Fetches, filters, downloads and stores cohorts, in that order.
     *
     * Exceptions thrown by any step propagate to the caller.
     *
     * @param cohortIds explicit cohort ids to refresh; when empty, the ids from the cohort id
     *   provider are used instead (see [filterCohorts]).
     */
    override fun refresh(cohortIds: Set) {
        Logger.d("Refreshing cohorts $cohortIds")
        getCohortDescriptions()
            .filterCohorts(cohortIds)
            .downloadCohorts()
            .storeCohorts()
    }

    /**
     * Runs an initial refresh on the calling thread, then schedules a refresh every
     * [CohortServiceConfig.cohortSyncIntervalSeconds] seconds (fixed delay, with the same value
     * as initial delay).
     *
     * A failure of the initial refresh propagates to the caller. Failures of scheduled
     * refreshes are logged and do not stop the polling.
     */
    override fun start() {
        lock.once {
            refresh()
            poller.scheduleWithFixedDelay(
                { refreshFromPoller() },
                config.cohortSyncIntervalSeconds,
                config.cohortSyncIntervalSeconds,
                TimeUnit.SECONDS,
            )
        }
    }

    /** Shuts down the polling executor so that no further refreshes are scheduled. */
    override fun stop() {
        poller.shutdown()
    }

    /**
     * Returns what [CohortStorage.getCohortsForUser] reports for the user.
     *
     * @param userId the user to look up.
     */
    override fun getCohorts(userId: String): Set {
        return cohortStorage.getCohortsForUser(userId)
    }

    /**
     * Requests the cohort descriptions from the API and blocks until the response is available.
     *
     * @return the `cohorts` of the response.
     */
    internal fun getCohortDescriptions(): List {
        Logger.d("Getting cohort descriptions.")
        return cohortApi.getCohorts(GetCohortsRequest).get().cohorts.also {
            Logger.d("Got cohort descriptions: $it")
        }
    }

    /**
     * Selects the descriptions that should be downloaded.
     *
     * When [cohortIds] is empty the managed ids from the cohort id provider are used, and every
     * stored cohort that is not among them is deleted from storage as a side effect.
     *
     * A description is kept only if its id is targeted, its size is strictly below
     * [CohortServiceConfig.maxCohortSize], and its `lastComputed` is greater than the stored
     * description's `lastComputed` (or than -1 when nothing is stored for that id).
     *
     * @param cohortDescriptions candidate descriptions, typically from [getCohortDescriptions].
     * @param cohortIds explicit ids to target; empty means "use the managed ids".
     * @return the descriptions that passed all three checks, in their original order.
     */
    internal fun filterCohorts(cohortDescriptions: List, cohortIds: Set = setOf()): List {
        // Filter for explicit cohort ids, otherwise use the cohort ID provider.
        val includedCohortIds = cohortIds.ifEmpty {
            val managedCohorts = cohortIdProvider.invoke()
            // Delete stored cohorts that are not being managed.
            cohortStorage.getAllCohortDescriptions().keys
                .filterNot { storedCohortId -> managedCohorts.contains(storedCohortId) }
                .forEach { unmanagedCohortId ->
                    cohortStorage.deleteCohort(unmanagedCohortId)
                    Logger.d("Deleting unmanaged cohort $unmanagedCohortId")
                }
            managedCohorts
        }
        Logger.d("Filtering cohorts for download: $includedCohortIds")
        // Filter out cohorts which are (1) not being targeted (2) too large (3) not updated.
        // The storage lookup comes last so it only runs for cohorts that pass the cheap checks.
        return cohortDescriptions.filter { candidate ->
            includedCohortIds.contains(candidate.id) &&
                candidate.size < config.maxCohortSize &&
                candidate.lastComputed >
                (cohortStorage.getCohortDescription(candidate.id)?.lastComputed ?: -1)
        }.also {
            Logger.d("Cohorts filtered: $it")
        }
    }

    /**
     * Downloads the cohort for each description.
     *
     * All requests are issued before any result is awaited. A request that completes with an
     * exception or a null response is logged and left out of the result.
     *
     * @param cohortDescriptions the cohorts to download.
     * @return the successful responses, in request order.
     */
    internal fun downloadCohorts(cohortDescriptions: List): List {
        Logger.d("Downloading cohorts.")
        // Make a request to download each cohort before waiting on any of them.
        val pendingResponses = cohortDescriptions.map { description ->
            Logger.d("Downloading cohort ${description.id}")
            cohortApi.getCohort(
                GetCohortRequest(
                    cohortId = description.id,
                    lastComputed = description.lastComputed,
                )
            )
        }
        // Handle exceptions and get the response.
        return pendingResponses.mapNotNull { pending ->
            pending.handle { response, t ->
                if (response == null || t != null) {
                    Logger.e("get cohort request failed", t)
                    null
                } else {
                    // Only report a download once it is known to have succeeded.
                    Logger.d("Downloaded cohort ${response.cohort.id}")
                    response
                }
            }.join()
        }.also {
            Logger.d("Downloaded cohorts.")
        }
    }

    /**
     * Writes the cohort and user ids of every response to storage.
     *
     * @param getCohortResponses responses as returned by [downloadCohorts].
     */
    internal fun storeCohorts(getCohortResponses: List) {
        Logger.d("Storing cohorts.")
        // Store the cohort included in the response
        getCohortResponses.forEach { response ->
            cohortStorage.putCohort(response.cohort, response.userIds)
        }
        Logger.d("Cohorts stored.")
    }

    /**
     * Body of the scheduled task: a refresh that never throws.
     *
     * A scheduled executor cancels a periodic task for good once one run throws, so a single
     * failed request would otherwise end cohort syncing silently. The failure is logged and the
     * next run is left to retry.
     */
    private fun refreshFromPoller() {
        try {
            refresh()
        } catch (e: Exception) {
            Logger.e("Scheduled cohort refresh failed", e)
        }
    }

    // Extension wrappers so that refresh() can be written as a single pipeline.
    private fun List.filterCohorts(cohortIds: Set): List =
        filterCohorts(this, cohortIds)

    private fun List.downloadCohorts(): List =
        downloadCohorts(this)

    private fun List.storeCohorts() =
        storeCohorts(this)
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy