
// cohort/CohortDownloadApi.kt — Maven / Gradle / Ivy
package com.amplitude.experiment.cohort
import com.amplitude.experiment.util.Logger
import com.amplitude.experiment.util.get
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import okhttp3.HttpUrl.Companion.toHttpUrl
import okhttp3.OkHttpClient
import okio.IOException
import org.apache.commons.csv.CSVFormat
import org.apache.commons.csv.CSVParser
import java.lang.Thread.sleep
import java.util.Base64
import java.util.concurrent.Semaphore
import java.util.concurrent.TimeUnit
/*
* Based on the Behavioral Cohort API:
* https://www.docs.developers.amplitude.com/analytics/apis/behavioral-cohorts-api/
*/
// TODO make configurable to support EU datacenter
/**
 * Base URL that both cohort download implementations in this file resolve their request paths against.
 * It is hard-coded; per the TODO above, no alternative host (e.g. EU datacenter) is supported yet.
 */
private const val CDN_COHORT_SYNC_URL = "https://cohort.lab.amplitude.com/"
/**
 * Serialized cohort description returned by the v3 cohort endpoints.
 *
 * It is used as the element type of the `cohorts` listing and as the `cohort` field of the v3 members response.
 * This file only reads [id], [lastComputed] and [size]. The remaining properties are declared but not read here.
 * No property has a default value, so kotlinx.serialization requires every key to be present in the JSON.
 */
@Serializable
private data class SerialCohortDescription(
@SerialName("lastComputed") val lastComputed: Long,
@SerialName("published") val published: Boolean,
@SerialName("archived") val archived: Boolean,
@SerialName("appId") val appId: Int,
@SerialName("lastMod") val lastMod: Long,
@SerialName("type") val type: String,
@SerialName("id") val id: String,
@SerialName("size") val size: Int,
)
/**
 * Serialized description of a single cohort, using snake_case JSON keys.
 *
 * Its [cohortId], [lastComputed] and [size] properties match what the v5 implementation reads from
 * `api/3/cohorts/info/{id}`, so this is presumably that endpoint's response type.
 * Only [cohortId] is required; every other key falls back to the default shown here when absent.
 */
@Serializable
private data class SerialSingleCohortDescription(
@SerialName("cohort_id") val cohortId: String,
@SerialName("app_id") val appId: Int = 0,
@SerialName("org_id") val orgId: Int = 0,
@SerialName("name") val name: String? = null,
// NOTE(review): a missing size defaults to Int.MAX_VALUE, presumably so an unknown size is treated as "very large" — confirm.
@SerialName("size") val size: Int = Int.MAX_VALUE,
@SerialName("description") val description: String? = null,
@SerialName("last_computed") val lastComputed: Long = 0,
)
/**
 * Body of the v3 `api/3/cohorts` listing response.
 *
 * @property cohorts one entry per cohort the server returned; callers map each entry to a `CohortDescription`.
 */
@Serializable
private data class GetCohortDescriptionsResponse(
    @SerialName("cohorts") val cohorts: List<SerialCohortDescription>,
)
/**
 * Body of the v3 `api/3/cohorts/{id}` members response.
 *
 * @property cohort description of the cohort the members belong to.
 * @property userIds member user IDs. Elements are nullable so JSON `null` entries still decode;
 *   callers drop them with `filterNotNull()`.
 */
@Serializable
private data class GetCohortMembersResponse(
    @SerialName("cohort") val cohort: SerialCohortDescription,
    @SerialName("user_ids") val userIds: List<String?>,
)
/**
 * Downloads cohort metadata and cohort membership.
 */
internal interface CohortDownloadApi {
    /**
     * Fetches descriptions (id, last-computed time, size) for the requested cohorts.
     *
     * @param cohortIds IDs of the cohorts to describe.
     * @return the descriptions obtained from the server.
     */
    fun getCohortDescriptions(cohortIds: Set<String>): List<CohortDescription>

    /**
     * Fetches the user IDs belonging to the described cohort.
     *
     * @param cohortDescription the cohort to download; its id and last-computed time are sent to the server.
     * @return the distinct member user IDs.
     */
    fun getCohortMembers(cohortDescription: CohortDescription): Set<String>
}
/**
 * [CohortDownloadApi] backed by the v3 Behavioral Cohort API on [CDN_COHORT_SYNC_URL].
 *
 * Requests authenticate with HTTP Basic auth built from `apiKey:secretKey`. They go through a copy of
 * [httpClient] that has a 5-minute read timeout. At most 5 calls run concurrently on one instance,
 * enforced by a fair [Semaphore].
 */
internal class DirectCohortDownloadApiV3(
    apiKey: String,
    secretKey: String,
    httpClient: OkHttpClient,
) : CohortDownloadApi {
    private val httpClient: OkHttpClient = httpClient.newBuilder()
        .readTimeout(5, TimeUnit.MINUTES)
        .build()
    private val serverUrl = CDN_COHORT_SYNC_URL.toHttpUrl()
    private val semaphore = Semaphore(5, true)
    private val basicAuth = Base64.getEncoder().encodeToString("$apiKey:$secretKey".toByteArray(Charsets.UTF_8))

    // Built once; every request in this class sends the same header.
    private val authHeaders = mapOf("Authorization" to "Basic $basicAuth")

    /**
     * Fetches all requested descriptions in one `api/3/cohorts` call.
     *
     * The IDs are sorted, so a given set always produces the same query string.
     */
    override fun getCohortDescriptions(cohortIds: Set<String>): List<CohortDescription> {
        return semaphore.limit {
            // NOTE(review): this separator reproduces the default joinToString() output (", "). Confirm the
            // server trims the spaces; otherwise use ",".
            val idList = cohortIds.sorted().joinToString(separator = ", ")
            val response = httpClient.get<GetCohortDescriptionsResponse>(
                serverUrl = serverUrl,
                path = "api/3/cohorts",
                headers = authHeaders,
                queries = mapOf("cohortIds" to idList),
            )
            response.cohorts.map { CohortDescription(id = it.id, lastComputed = it.lastComputed, size = it.size) }
        }
    }

    /**
     * Downloads the members of one cohort from `api/3/cohorts/{id}`.
     *
     * The cohort's `lastComputed` is sent, with `refreshCohort` and `amp_ids` disabled.
     * Null user IDs in the response are dropped.
     */
    override fun getCohortMembers(cohortDescription: CohortDescription): Set<String> {
        return semaphore.limit {
            val response = httpClient.get<GetCohortMembersResponse>(
                serverUrl = serverUrl,
                path = "api/3/cohorts/${cohortDescription.id}",
                headers = authHeaders,
                queries = mapOf(
                    "lastComputed" to "${cohortDescription.lastComputed}",
                    "refreshCohort" to "false",
                    "amp_ids" to "false",
                ),
            )
            response.userIds.filterNotNull().toSet()
        }
    }
}
/**
 * Response to a v5 cohort export request (`api/5/cohorts/request/{cohortId}`).
 *
 * Its [requestId] names the asynchronous export job. The v5 implementation in this file polls
 * `api/5/cohorts/request-status/{requestId}` with it and then downloads `api/5/cohorts/request/{requestId}/file`.
 */
@Serializable
data class GetCohortAsyncResponse(
// ID of the cohort the export was requested for.
@SerialName("cohort_id")
val cohortId: String,
// ID of the export job; used in the status-poll and file-download paths.
@SerialName("request_id")
val requestId: String,
)
/**
 * [CohortDownloadApi] that reads cohort metadata from the v3 `cohorts/info` endpoint and downloads
 * members through the asynchronous v5 export flow: request an export, poll its status, then stream
 * the CSV file.
 *
 * Requests authenticate with HTTP Basic auth built from `apiKey:secretKey`. They go through a copy of
 * [httpClient] that has a 5-minute read timeout. At most 5 calls run concurrently on one instance, enforced
 * by a fair [Semaphore]. Each call holds its permit for its entire duration, including status polling.
 */
internal class DirectCohortDownloadApiV5(
    apiKey: String,
    secretKey: String,
    httpClient: OkHttpClient,
) : CohortDownloadApi {
    private val httpClient: OkHttpClient = httpClient.newBuilder()
        .readTimeout(5, TimeUnit.MINUTES)
        .build()
    private val cdnServerUrl = CDN_COHORT_SYNC_URL.toHttpUrl()
    private val semaphore = Semaphore(5, true)
    private val basicAuth = Base64.getEncoder().encodeToString("$apiKey:$secretKey".toByteArray(Charsets.UTF_8))

    // Built once; every request in this class sends the same header.
    private val authHeaders = mapOf("Authorization" to "Basic $basicAuth")

    // setHeader() with no names makes the first CSV record the header, so columns can be read by name.
    private val csvFormat = CSVFormat.RFC4180.builder().apply {
        setHeader()
    }.build()

    /**
     * Fetches one description per ID, sequentially and in the set's iteration order, from
     * `api/3/cohorts/info/{id}`.
     */
    override fun getCohortDescriptions(cohortIds: Set<String>): List<CohortDescription> {
        return semaphore.limit {
            cohortIds.map { cohortId ->
                val info = httpClient.get<SerialSingleCohortDescription>(
                    serverUrl = cdnServerUrl,
                    path = "api/3/cohorts/info/$cohortId",
                    headers = authHeaders,
                )
                CohortDescription(
                    id = info.cohortId,
                    lastComputed = info.lastComputed,
                    size = info.size,
                )
            }
        }
    }

    /**
     * Downloads the members of one cohort through the v5 export flow.
     *
     * @return the distinct non-empty values of the CSV `user_id` column.
     * @throws IOException if the status endpoint answers with anything other than 200 or 202, or if
     *   the file response has no body.
     */
    override fun getCohortMembers(cohortDescription: CohortDescription): Set<String> {
        return semaphore.limit {
            Logger.d("getCohortMembers: start - $cohortDescription")
            val initialResponse = httpClient.get<GetCohortAsyncResponse>(
                serverUrl = cdnServerUrl,
                path = "api/5/cohorts/request/${cohortDescription.id}",
                headers = authHeaders,
                queries = mapOf("lastComputed" to cohortDescription.lastComputed.toString()),
            )
            Logger.d("getCohortMembers: requestId=${initialResponse.requestId}")
            awaitExportReady(initialResponse.requestId)
            downloadExportedMembers(initialResponse.requestId)
                .also { Logger.d("getCohortMembers: end - resultSize=${it.size}") }
        }
    }

    /**
     * Polls the export's status once per second until the server reports 200 (ready).
     *
     * A 202 response means the export is still pending; any other status code raises [IOException].
     * NOTE(review): there is no attempt limit. A request that stays at 202 keeps this thread and one
     * semaphore permit busy indefinitely.
     */
    private fun awaitExportReady(requestId: String) {
        while (true) {
            // Only the status code is needed. Close the response so its connection is released;
            // if `get` has already closed it, closing again is harmless.
            val code = httpClient.get(
                serverUrl = cdnServerUrl,
                path = "api/5/cohorts/request-status/$requestId",
                headers = authHeaders,
            ).use { it.code }
            Logger.d("getCohortMembers: status=$code")
            when (code) {
                200 -> return
                202 -> sleep(1000)
                else -> throw IOException("Cohort status request resulted in error response $code")
            }
        }
    }

    /**
     * Streams the finished export's CSV and collects the distinct non-empty `user_id` values.
     *
     * @throws IOException if the response has no body.
     */
    private fun downloadExportedMembers(requestId: String): Set<String> =
        httpClient.get(
            serverUrl = cdnServerUrl,
            path = "api/5/cohorts/request/$requestId/file",
            headers = authHeaders,
        ) { response ->
            val body = response.body ?: throw IOException("Cohort file response for request $requestId had no body")
            // Close the parser (and the reader it wraps) once every record has been consumed.
            CSVParser.parse(body.byteStream(), Charsets.UTF_8, csvFormat).use { csv ->
                csv.mapNotNull { record -> record.get("user_id")?.takeIf { it.isNotEmpty() } }.toSet()
            }
        }
}
/**
 * Runs [block] while holding one permit of this semaphore and returns its result.
 *
 * Blocks in [Semaphore.acquire] until a permit is free; that call may throw `InterruptedException`.
 * The permit is released even if [block] throws. `acquire()` sits outside the `try`, so a failed
 * acquire never releases a permit it did not obtain.
 */
private inline fun <T> Semaphore.limit(block: () -> T): T {
    acquire()
    try {
        return block()
    } finally {
        release()
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy