All downloads are free. Search and download functionality uses the official Maven repository.

io.bitrise.gradle.cache.BitriseBuildCacheService.kt Maven / Gradle / Ivy

/**
 * Copyright (C)2022 Bitrise
 * All rights reserved.
 */
package io.bitrise.gradle.cache

import build.bazel.remote.execution.v2.DigestFunction
import build.bazel.remote.execution.v2.ServerCapabilities
import build.bazel.remote.execution.v2.SymlinkAbsolutePathStrategy
import build.bazel.remote.execution.v2.getCapabilitiesRequest
import build.bazel.remote.execution.v2.requestMetadata
import build.bazel.remote.execution.v2.toolDetails
import com.google.bytestream.ByteStreamProto.WriteRequest
import com.google.bytestream.readRequest
import com.google.bytestream.writeRequest
import com.google.protobuf.ByteString
import io.bitrise.gradle.cache.connection.CapabilitiesClient
import io.bitrise.gradle.cache.connection.ClientBalancer
import io.bitrise.gradle.cache.connection.KVStorageClient
import io.grpc.CallCredentials
import io.grpc.CallOptions
import io.grpc.ManagedChannel
import io.grpc.ManagedChannelBuilder
import io.grpc.Metadata
import io.grpc.Status
import io.grpc.netty.GrpcSslContexts
import io.grpc.netty.NettyChannelBuilder
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.Runnable
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeout
import org.gradle.caching.BuildCacheEntryReader
import org.gradle.caching.BuildCacheEntryWriter
import org.gradle.caching.BuildCacheException
import org.gradle.caching.BuildCacheKey
import org.gradle.caching.BuildCacheService
import java.io.ByteArrayOutputStream
import java.io.File
import java.text.SimpleDateFormat
import java.util.*
import java.util.concurrent.Executor
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.logging.Level
import java.util.logging.Logger
import kotlin.math.floor
import kotlin.math.max

/**
 * Gradle [BuildCacheService] that stores and loads build cache entries in the Bitrise remote cache over gRPC,
 * using the ByteStream read/write API together with Bazel remote-execution capabilities and request metadata.
 *
 * Exceptions thrown inside the init block are logged rather than propagated. Whether [store] is allowed is decided
 * by a capability check that starts in the background during construction.
 *
 * @param endpoint cache address; a `grpc://` prefix selects plaintext, a `grpcs://` prefix or no prefix selects TLS.
 * @param authToken sent as a `Bearer` token in the `authorization` header.
 * @param chunkSize maximum payload size in bytes of each upload request.
 * @param timeout timeout in seconds, passed (as milliseconds) to `withRetryAndTimeout` for reads and writes.
 * @param debug enables `DEBUG:` log lines and sets the `io.grpc` logger to [Level.ALL].
 * @param orgSlug sent as `x-org-id` when not blank.
 * @param retryCount passed to `withRetryAndTimeout` and used as the gRPC channel's `maxRetryAttempts`.
 * @param numChannels number of channels managed by the [ClientBalancer].
 * @param maxConcurrencyPerChannel per-channel concurrency given to the [ClientBalancer]; also sizes the default [pool].
 * @param pool executor used by every gRPC channel for callbacks; shut down in [close].
 * @param blobValidationLevel how integrity warnings from the server and local SHA-256 mismatches are handled.
 * @param tlsCertPath optional certificate file used as the TLS trust root; forces TLS when set.
 * @param overrideAuthority optional TLS authority override.
 */
class BitriseBuildCacheService internal constructor(
    private val endpoint: String,
    private val authToken: String,
    private val chunkSize: Int,
    timeout: Long,
    private val debug: Boolean,
    private val orgSlug: String,
    private val retryCount: Int,
    private val numChannels: Int = max(2, floor(Runtime.getRuntime().availableProcessors() / 6.0).toInt()),
    private val maxConcurrencyPerChannel: Int = Runtime.getRuntime().availableProcessors(),
    private val pool: ExecutorService = Executors.newFixedThreadPool(maxConcurrencyPerChannel),
    private val blobValidationLevel: ValidationLevel,
    private val tlsCertPath: String? = null,
    private val overrideAuthority: String? = null,
) : BuildCacheService {

    /** Scope for background work (the startup capability check); cancelled in [close]. */
    private val mainCoroutineScope = CoroutineScope(SupervisorJob() + Dispatchers.IO)

    // SimpleDateFormat is not thread-safe; `log` synchronizes on this instance before formatting.
    private val dateTimeFormatter by lazy { SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS") }

    private val bazelReqMetadataKey = "build.bazel.remote.execution.v2.requestmetadata-bin"

    /**
     * Invocation ID attached to every request.
     *
     * Resolution order: an ID already stored in [InvocationIdHolder] (shared among cache plugin instances), then the
     * analytics plugin's ID, then a freshly generated random UUID that is stored in [InvocationIdHolder].
     */
    private val invocationId by lazy {
        // For syncing the invocation ID among cache plugin instances, we lock on the holder object.
        // For syncing with the analytics plugin, we lock on the analytics plugin's holder object.
        synchronized(InvocationIdHolder) {
            if (InvocationIdHolder.invocationId != "") {
                log("Reusing invocation ID: ${InvocationIdHolder.invocationId}")
                return@lazy InvocationIdHolder.invocationId
            }

            runCatching {
                analyticsInvocationId?.let {
                    log("Got invocation ID from analytics: $it")
                    return@lazy it
                }
            }.onFailure { log("Failed to get invocation ID from analytics: $it") }

            debugln("Generating new invocation ID")
            InvocationIdHolder.invocationId = UUID.randomUUID().toString()
            InvocationIdHolder.invocationId
        }
    }

    /**
     * The analytics plugin's invocation ID, read reflectively from its `InvocationIdHolder` object.
     *
     * Returns `null` when the analytics plugin is not on the classpath, its holder is not a Kotlin object, it has no
     * `invocationId` field, or the field does not hold a non-empty String.
     */
    private val analyticsInvocationId: String?
        get() {
            val clazz = try {
                Class.forName("io.bitrise.gradle.analytics.services.InvocationIdHolder")
            } catch (e: ClassNotFoundException) {
                // The analytics plugin is simply not applied; there is nothing to sync with.
                debugln("Analytics plugin not found; not reusing its invocation ID")
                return null
            }
            val instance = clazz.kotlin.objectInstance ?: return null
            // Note: both the cache and the analytics plugin lock on the analytics plugin's holder object.
            synchronized(instance) {
                val field = clazz.declaredFields.find { it.name == "invocationId" } ?: return null
                field.isAccessible = true
                return (field.get(instance) as? String)?.takeIf { it.isNotEmpty() }
            }
        }

    /**
     * gRPC headers sent with every call: auth token, org/user/app/build/step identifiers, the Bazel
     * `RequestMetadata` (binary header), the blob validation level (unless NONE) and the AC validation mode.
     */
    private val requestMetadata: Metadata by lazy {
        Metadata().apply {
            put(Metadata.Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER), "Bearer $authToken")

            if (orgSlug.isNotBlank()) {
                put(Metadata.Key.of("x-org-id", Metadata.ASCII_STRING_MARSHALLER), orgSlug)
            }

            // user.name is a platform String and may be null; Metadata.put rejects null values.
            System.getProperty("user.name")?.let {
                put(Metadata.Key.of("x-flare-builduser", Metadata.ASCII_STRING_MARSHALLER), it)
            }
            System.getenv("BITRISE_APP_SLUG")?.let {
                put(Metadata.Key.of("x-app-id", Metadata.ASCII_STRING_MARSHALLER), it)
            }
            System.getenv("BITRISE_BUILD_SLUG")?.let {
                put(Metadata.Key.of("x-flare-build-id", Metadata.ASCII_STRING_MARSHALLER), it)
            }
            System.getenv("BITRISE_STEP_EXECUTION_ID")?.let {
                put(Metadata.Key.of("x-flare-step-id", Metadata.ASCII_STRING_MARSHALLER), it)
            }
            put(
                Metadata.Key.of(bazelReqMetadataKey, Metadata.BINARY_BYTE_MARSHALLER),
                requestMetadata {
                    toolInvocationId = invocationId
                    toolDetails = toolDetails {
                        toolName = "gradle"
                    }
                }.toByteArray(),
            )
            if (blobValidationLevel != ValidationLevel.NONE) {
                put(
                    Metadata.Key.of(METADATA_KEY_BLOB_VALIDATION_LEVEL, Metadata.ASCII_STRING_MARSHALLER),
                    blobValidationLevel.lvl,
                )
            }
            put(Metadata.Key.of("x-flare-ac-validation-mode", Metadata.ASCII_STRING_MARSHALLER), "fast")

            debugln("Request metadata invocationId: $invocationId")
            // The token and the binary Bazel header are excluded from the debug dump.
            debugln(
                "Request metadata: ${
                    this.keys().minus(listOf("authorization", bazelReqMetadataKey))
                        .map { it to this.get(Metadata.Key.of(it, Metadata.ASCII_STRING_MARSHALLER)) }
                        .joinToString(";")
                }",
            )
        }
    }

    /** Default call options whose credentials attach [requestMetadata] to every RPC. */
    private val clientCallOptions
        get() = CallOptions.DEFAULT
            .withCallCredentials(
                object : CallCredentials() {
                    override fun applyRequestMetadata(
                        requestInfo: RequestInfo?,
                        appExecutor: Executor?,
                        applier: MetadataApplier?,
                    ) {
                        debugln("Applying request metadata")
                        runCatching {
                            applier?.apply(requestMetadata)
                        }.onFailure { err ->
                            log("Failed to apply request metadata: $err")
                            // The applier must be completed with apply() or fail(); otherwise the RPC is left
                            // waiting for credentials. Guarded in case apply() had already completed it.
                            runCatching {
                                applier?.fail(
                                    Status.UNAUTHENTICATED
                                        .withDescription("Failed to apply request metadata")
                                        .withCause(err),
                                )
                            }
                        }
                    }

                    @Deprecated("Deprecated in Java")
                    override fun thisUsesUnstableApi() {
                    }
                },
            )

    /**
     * Whether the server accepts uploads, as reported by its capabilities. It is `null` until a capability check has
     * got far enough to read it. Accessed while holding [writeEnabledMutex].
     */
    private var writeEnabled: Boolean? = null
    private val writeEnabledMutex = Object()

    private val instanceName = "" // TODO: clean up instanceName from this class
    private val timeoutMillis = timeout * 1000
    private val getCapsTimeout = 10 * 1000L

    // Kept as an explicit Lazy so close() can tell whether the client was ever created.
    private val capabilitiesClientDelegate = lazy {
        debugln("Creating capabilities client")
        CapabilitiesClient(channelFactory(), clientCallOptions)
    }
    private val capabilitiesClient: CapabilitiesClient by capabilitiesClientDelegate
    private lateinit var clients: ClientBalancer

    init {
        try {
            initGrpcLogging()
            initClients()
            // Probe the server in the background so construction does not block. `store` takes the same mutex and
            // re-runs the check itself if this one did not produce an answer.
            mainCoroutineScope.launch {
                try {
                    ensureCapabilitiesChecked()
                } catch (e: Exception) {
                    log("warn: Failed to check remote cache capabilities: $e")
                }
            }
            log("🤖 Bitrise remote cache enabled")
            debugln("Channels: $numChannels, concurrency per channel: $maxConcurrencyPerChannel")
        } catch (e: Exception) {
            log("fatal: Failed to initialize remote cache: $e")
        }
    }

    private inline val blobValidationEnabled: Boolean
        get() {
            return this.blobValidationLevel == ValidationLevel.ERROR || this.blobValidationLevel == ValidationLevel.WARNING
        }

    /**
     * Handles the server's blob-validation response headers.
     *
     * Reports any validation warning according to [blobValidationLevel].
     *
     * @return the server-reported SHA-256 of the blob, or `null` when validation is disabled or the header is absent.
     */
    private fun handleBlobValidationMetadata(responseMetadata: Metadata): String? {
        if (!blobValidationEnabled) return null
        blobValidation(
            responseMetadata.get(
                Metadata.Key.of(
                    METADATA_KEY_BLOB_VALIDATION_WARNING,
                    Metadata.ASCII_STRING_MARSHALLER,
                ),
            ),
        )
        return responseMetadata.get(Metadata.Key.of(METADATA_KEY_BLOB_VALIDATION_SHA256, Metadata.ASCII_STRING_MARSHALLER))
    }

    /**
     * Reports a non-empty validation [warning]. It is logged for WARNING and thrown as a [BuildCacheException]
     * for ERROR.
     */
    private fun blobValidation(warning: String?) {
        when (blobValidationLevel) {
            ValidationLevel.NONE -> {}
            ValidationLevel.WARNING -> {
                if (warning.isNullOrEmpty()) {
                    return
                }
                log("warn: $warning")
            }

            ValidationLevel.ERROR -> {
                // on mode=error, the remote should have already returned an error, meaning it didnt bother to
                // return the error in metadata. So in effect this is not ever hit when errors come back, but note it
                // WILL be used by internal validation use cases!
                if (warning.isNullOrEmpty()) {
                    return
                }
                throw BuildCacheException(warning)
            }
        }
    }

    private fun initGrpcLogging() {
        Logger.getLogger("io.grpc").level = if (debug) Level.ALL else Level.WARNING
    }

    private fun initClients() {
        clients =
            ClientBalancer(numChannels, maxConcurrencyPerChannel, this::channelFactory, this::clientFactory, debug)
    }

    private fun clientFactory(channel: ManagedChannel): KVStorageClient {
        return KVStorageClient(channel, clientCallOptions)
    }

    /**
     * Builds a new gRPC channel to [endpoint].
     *
     * The channel has retries enabled, runs its callbacks through [pool], and uses plaintext or TLS as described on
     * the class.
     *
     * @throws BuildCacheException wrapping any failure while building the channel.
     */
    private fun channelFactory(): ManagedChannel {
        try {
            debugln("Creating channel to $endpoint")

            // A configured certificate always forces TLS, with that certificate as the trust root.
            val certPath = tlsCertPath?.takeIf { it.isNotEmpty() }
            val target = when {
                endpoint.startsWith("grpc://") -> endpoint.removePrefix("grpc://")
                endpoint.startsWith("grpcs://") -> endpoint.removePrefix("grpcs://")
                else -> endpoint
            }
            val plainText = certPath == null && endpoint.startsWith("grpc://")

            val builder: ManagedChannelBuilder<*> = if (certPath != null) {
                NettyChannelBuilder
                    .forTarget(target)
                    .sslContext(
                        GrpcSslContexts
                            .forClient()
                            .trustManager(File(certPath))
                            .build(),
                    )
            } else {
                ManagedChannelBuilder.forTarget(target)
            }
            builder.enableRetry()
                .maxRetryAttempts(retryCount)
                // NOTE(review): each callback runs on `pool`, but the submitting thread blocks on get() until it
                // finishes. Once the pool is shut down, commands are silently dropped.
                .executor { command: Runnable ->
                    if (!pool.isShutdown && !pool.isTerminated) {
                        pool.submit(command).get()
                    }
                }
            if (plainText) {
                builder.usePlaintext()
            } else {
                if (!overrideAuthority.isNullOrEmpty()) {
                    builder.overrideAuthority(overrideAuthority)
                }
                builder.useTransportSecurity()
            }
            return builder.build()
        } catch (t: Throwable) {
            throw BuildCacheException("Failed to initialize cache connection: $t", t)
        }
    }

    /**
     * Runs [checkCapabilities] while holding [writeEnabledMutex], unless an earlier check already set [writeEnabled].
     *
     * The guard also keeps a check from running again on the capabilities client after a successful check has
     * closed it. The calling thread blocks while another caller holds the mutex and while the check runs.
     *
     * @throws BuildCacheException if the check fails.
     */
    private fun ensureCapabilitiesChecked() {
        synchronized(writeEnabledMutex) {
            if (writeEnabled == null) {
                runBlocking {
                    checkCapabilities()
                }
            }
        }
    }

    /**
     * Queries the server's capabilities (with a [getCapsTimeout] ms timeout), records [writeEnabled], and closes the
     * capabilities client on success.
     *
     * @throws BuildCacheException if the call fails or the server lacks SHA-256 digests or allowed absolute symlinks.
     */
    private suspend fun checkCapabilities() {
        val startT = currT
        val baseErr = "Failed to initialize build cache connection to $endpoint"
        try {
            withContext(Dispatchers.IO) {
                withTimeout(getCapsTimeout) {
                    val c = capabilitiesClient.getCapabilities(
                        getCapabilitiesRequest {
                            // Qualified: inside the builder DSL, plain `instanceName` is the request's own field.
                            this.instanceName = this@BitriseBuildCacheService.instanceName
                        },
                    )
                    if (!checkServerCapabilities(c)) {
                        throw BuildCacheException("The remote server has missing or unsupported cache capabilities.")
                    }
                    debugln("Connected to remote server successfully; write-enabled: $writeEnabled")
                    // don't need this client/channel really so just close it
                    capabilitiesClient.close()
                    debugln("Checking capabilities took ${currT - startT} ms")
                }
            }
        } catch (t: Throwable) {
            if (t is BuildCacheException) {
                throw t
            } else {
                throw BuildCacheException("$baseErr: $t", t)
            }
        }
    }

    /**
     * Records [writeEnabled] from the action-cache update capability.
     *
     * @return `true` when the server supports SHA-256 digests and allows absolute symlinks.
     */
    private fun checkServerCapabilities(capabilities: ServerCapabilities): Boolean {
        if (!capabilities.hasCacheCapabilities()) return false
        writeEnabled = capabilities.cacheCapabilities.actionCacheUpdateCapabilities.updateEnabled
        return capabilities.cacheCapabilities.digestFunctionList.any {
            it == DigestFunction.Value.SHA256
        } && capabilities.cacheCapabilities.symlinkAbsolutePathStrategy == SymlinkAbsolutePathStrategy.Value.ALLOWED
    }

    /**
     * Uploads the entry produced by [writer] under [key].
     *
     * If the write capability is still unknown (the background check is pending or failed), the capability check is
     * completed synchronously first, so the decision below reflects an actual server answer.
     *
     * @throws BuildCacheException if the check fails, pushes are not enabled, or the upload fails.
     */
    @ExperimentalCoroutinesApi
    override fun store(key: BuildCacheKey, writer: BuildCacheEntryWriter) {
        val canWrite = synchronized(writeEnabledMutex) {
            // Reentrant: ensureCapabilitiesChecked synchronizes on the same monitor.
            ensureCapabilitiesChecked()
            writeEnabled == true
        }
        if (!canWrite) {
            throw BuildCacheException(
                "pushes to this cache are not supported.\n" +
                    "\tAre you using a read only token? If so, set push = false in your gradle config",
            )
        }

        try {
            runBlocking {
                // grab the gradle cache entry from the writer
                val content = ByteArrayOutputStream()
                writer.writeTo(content)
                val contentBytes = content.toByteArray()
                writeBlob(key, contentBytes)
            }
        } catch (t: Throwable) {
            throw BuildCacheException("Failed to `store` cache entry: $t", t)
        }
    }

    /**
     * Uploads [contentBytes] for [key] as chunked ByteStream write requests, with retries.
     *
     * When validation is enabled, the content's SHA-256 is sent alongside so the server can verify it.
     * Note that the Blob does NOT contain `contents` bytes.
     *
     * @throws BuildCacheException once all retries are exhausted.
     */
    private suspend fun writeBlob(key: BuildCacheKey, contentBytes: ByteArray) =
        withContext(Dispatchers.IO) {
            var sha256 = ""
            if (blobValidationEnabled) {
                sha256 = contentBytes.sha256String()
                debugln("writeBlob start: ${key.hashCode} is ${contentBytes.size / 1024.0} KiB\nsha256=$sha256")
            } else {
                debugln("writeBlob start: ${key.hashCode} is ${contentBytes.size / 1024.0} KiB")
            }
            val requests =
                createWriteRequests(instanceName, key, contentBytes, this@BitriseBuildCacheService.chunkSize)
            val start = currT
            try {
                clients.next.use { client ->
                    withRetryAndTimeout(
                        retryCount,
                        timeoutMillis,
                        {
                            var preparedClient = client
                            if (blobValidationEnabled) {
                                preparedClient = client.withMetadata(
                                    mapOf(
                                        METADATA_KEY_BLOB_VALIDATION_LEVEL to listOf(blobValidationLevel.lvl!!),
                                        METADATA_KEY_BLOB_VALIDATION_SHA256 to listOf(sha256),
                                    ),
                                )
                            }

                            val resp = preparedClient.put(
                                flow {
                                    for (r in requests) {
                                        debugln("sending ${r.data.size()} bytes to remote for ${r.resourceName}")
                                        emit(r)
                                    }
                                },
                                ::handleBlobValidationMetadata,
                            )

                            if (resp.committedSize != contentBytes.size.toLong()) {
                                throw Throwable("didnt commit expected number of bytes to remote (EOF?). Got ${resp.committedSize}, wanted ${contentBytes.size}")
                            }
                        },
                        // do on each err:
                        { t, attemptCount, attemptElapsed ->
                            log(
                                "Attempt #$attemptCount failed to send ${contentBytes.size} byte blob ${key.hashCode}\n" +
                                    "  err: '${t.message}' duration: $attemptElapsed ms, attempts remaining: " +
                                    "${retryCount + 1 - attemptCount}, request messages: ${requests.size}",
                            )
                        },
                    )
                }.join()
            } catch (ret: RetriesExhaustedException) {
                debugln("writeBlob failed: ${key.hashCode} failed to send after ${(currT - start) / 1000.0} seconds and $retryCount retries")
                throw BuildCacheException(
                    "Failed to write blob ${key.hashCode} and exhausted " +
                        "retry limit of $retryCount. Inner Error: ${ret.cause?.message}",
                )
            }
            debugln("writeBlob success: ${key.hashCode} finished after ${currT - start} ms")
        }

    /**
     * Downloads the entry for [key] and hands it to [reader].
     *
     * Failures are translated by `mapToFoundOrThrow` (defined elsewhere; presumably a NOT_FOUND becomes `false`
     * and other errors become the given [BuildCacheException] — TODO confirm).
     */
    override fun load(key: BuildCacheKey, reader: BuildCacheEntryReader): Boolean = runBlocking {
        runCatching {
            reader.readFrom(readBlob(key).inputStream())
        }
            .mapToFoundOrThrow { BuildCacheException("Failed to load: $it") }
    }

    /**
     * Streams the blob for [key] from the remote with retries, and when validation is enabled, checks its SHA-256
     * against the value the server reported.
     *
     * @throws io.grpc.StatusException NOT_FOUND on a local integrity mismatch with validation level WARNING.
     * @throws BuildCacheException on a local integrity mismatch with validation level ERROR.
     */
    private suspend fun readBlob(key: BuildCacheKey): ByteArray = withContext(Dispatchers.IO) {
        val resource = resourceNameForKey(instanceName, key)
        // Streamed chunks are appended here instead of re-allocating the whole array for every chunk.
        val buffer = ByteArrayOutputStream()
        val start = currT
        debugln("readBlob start: ${key.hashCode} resourceName: $resource")
        val req = readRequest {
            resourceName = resource
            readLimit = 0
        }
        var expectedIntegrity: String? = null
        try {
            clients.next.use { client ->
                withRetryAndTimeout(
                    retryCount,
                    timeoutMillis,
                    {
                        // Every attempt re-sends the same request (offset 0), so bytes collected by a failed
                        // attempt must be discarded rather than having the blob appended a second time.
                        buffer.reset()
                        client.get(req) {
                            expectedIntegrity = handleBlobValidationMetadata(it)
                        }.collect {
                            if (!it.isInitialized) {
                                debugln("warn: bad response from remote server")
                            } else {
                                it.data.writeTo(buffer)
                            }
                        }
                    },
                    // on err:
                    { t, attemptCount, attemptElapsed ->
                        if (t.isNotFound()) {
                            debugln("blob ${key.hashCode} not found after $attemptElapsed ms")
                            throw t
                        }
                        log(
                            "Attempt #$attemptCount failed to Get() blob ${key.hashCode}\n" +
                                "  err: '${t.message}', duration: $attemptElapsed ms, attempts remaining: " +
                                "${retryCount + 1 - attemptCount}",
                        )
                    },
                )
            }
        } catch (ret: RetriesExhaustedException) {
            throw ret.cause ?: ret
        }
        val result = buffer.toByteArray()
        if (blobValidationEnabled) {
            // warnings/errors from remote are already handled above
            val fetchTime = currT - start
            val hashStart = currT
            val actualIntegrity = result.sha256String()
            val hashTime = currT - hashStart
            debugln("fetched ${key.hashCode} from remote in $fetchTime ms\nsha256=$actualIntegrity\nhashed in $hashTime ms")
            if (expectedIntegrity != actualIntegrity) {
                val msg =
                    "Local blob integrity validation failed: wanted $expectedIntegrity, got $actualIntegrity for blob ${key.hashCode}"
                when (blobValidationLevel) {
                    ValidationLevel.NONE -> {}
                    ValidationLevel.WARNING -> {
                        log("warn: $msg;\nHandling this as a NOT_FOUND because blobValidationLevel=warn")
                        // throw a not found to continue gracefully
                        throw Status.NOT_FOUND.asException()
                    }

                    ValidationLevel.ERROR -> {
                        log("err: $msg")
                        // throw a cache error; gradle will disable the cache but shouldn't die.
                        throw BuildCacheException("err: $msg")
                    }
                }
            }
        } else {
            debugln("fetched ${key.hashCode} from remote in ${currT - start} ms")
        }
        result
    }

    /**
     * Resets the shared invocation ID and releases resources.
     *
     * Closes the channels, shuts down the pool, closes the capabilities client if it was created, and cancels
     * background work.
     */
    override fun close() {
        debugln("CLOSING REMOTE CONNECTION")
        synchronized(InvocationIdHolder) {
            InvocationIdHolder.invocationId = ""
        }
        // `clients` stays uninitialized if the init block failed before initClients() completed.
        if (::clients.isInitialized) {
            clients.close()
        }
        pool.shutdownNow()
        // Avoid creating the lazy client (and a brand-new channel) only to close it.
        if (capabilitiesClientDelegate.isInitialized()) {
            capabilitiesClient.close()
        }
        mainCoroutineScope.cancel()
    }

    private fun debugln(msg: String) {
        if (debug) log("DEBUG: $msg")
    }

    /** Prints [msg] to stdout with the plugin prefix and a timestamp. */
    private fun log(msg: String) {
        // SimpleDateFormat keeps mutable internal state, and log may be called from several threads at once.
        val timestamp = synchronized(dateTimeFormatter) { dateTimeFormatter.format(Date()) }
        println("[Bitrise Build Cache] [$timestamp] $msg")
    }
}

/**
 * ByteStream resource name for a Gradle cache [key]: `<instanceName>/<key hash>`.
 *
 * An empty [instanceName] falls back to the prefix `gradle`.
 */
internal fun resourceNameForKey(instanceName: String, key: BuildCacheKey): String {
    val prefix = if (instanceName.isEmpty()) "gradle" else instanceName
    return "$prefix/${key.hashCode}"
}

/**
 * Splits [contentBytes] into ByteStream [WriteRequest]s carrying at most [chunkSize] bytes each.
 *
 * Only the first request carries the resource name (from [resourceNameForKey]). Every request carries its write
 * offset, and only the last one has `finishWrite` set. Empty [contentBytes] yields an empty array.
 *
 * @throws IllegalArgumentException if [chunkSize] is not positive; the chunking loop could never finish otherwise.
 */
internal fun createWriteRequests(
    instanceName: String,
    key: BuildCacheKey,
    contentBytes: ByteArray,
    chunkSize: Int,
): Array<WriteRequest> {
    require(chunkSize > 0) { "chunkSize must be positive, was $chunkSize" }
    val firstResourceName = resourceNameForKey(instanceName, key)
    val blobSize = contentBytes.size
    val reqs = mutableListOf<WriteRequest>()
    var chunkStart = 0

    while (chunkStart < blobSize) {
        val chunkLen = minOf(chunkSize, blobSize - chunkStart)
        reqs.add(
            writeRequest {
                resourceName = if (chunkStart == 0) firstResourceName else ""
                writeOffset = chunkStart.toLong()
                finishWrite = chunkStart + chunkLen == blobSize
                // Copies the range directly, without the boxed List<Byte> that slice() would build.
                data = ByteString.copyFrom(contentBytes, chunkStart, chunkLen)
            },
        )
        chunkStart += chunkLen
    }
    return reqs.toTypedArray()
}

/**
 * Process-wide holder for the invocation ID.
 *
 * It is shared among multiple cache plugin instances (composite builds) and is also meant to be used by the
 * analytics plugin.
 */
object InvocationIdHolder {
    /** Current invocation ID; an empty string means none has been assigned yet. */
    var invocationId = ""
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy