magic-nix-cache/gha-cache/src/api.rs

626 lines
17 KiB
Rust
Raw Normal View History

2023-05-08 09:48:11 +00:00
//! GitHub Actions Cache API client.
//!
//! We expose a high-level API that deals with "files."
2023-05-08 16:05:43 +00:00
use std::fmt;
#[cfg(debug_assertions)]
2024-06-12 20:01:21 +00:00
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::{AtomicBool, Ordering};
2023-05-08 18:59:57 +00:00
use std::sync::Arc;
2023-05-08 16:05:43 +00:00
2023-05-08 09:48:11 +00:00
use async_trait::async_trait;
use bytes::{Bytes, BytesMut};
2023-05-08 18:59:57 +00:00
use futures::future;
2023-05-08 09:48:11 +00:00
use rand::{distributions::Alphanumeric, Rng};
use reqwest::{
header::{HeaderMap, HeaderValue, CONTENT_RANGE, CONTENT_TYPE},
Client, StatusCode,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use sha2::{Digest, Sha256};
use thiserror::Error;
2023-05-08 18:59:57 +00:00
use tokio::{io::AsyncRead, sync::Semaphore};
use unicode_bom::Bom;
2023-05-08 09:48:11 +00:00
use crate::credentials::Credentials;
use crate::util::read_chunk_async;
/// The API version we implement.
///
/// Sent as the `api-version` parameter of the `Accept` header on every request.
///
/// <https://github.com/actions/toolkit/blob/0d44da2b87f9ed48ae889d15c6cc19667aa37ec0/packages/cache/src/internal/cacheHttpClient.ts>
const API_VERSION: &str = "6.0-preview.1";
/// The User-Agent string for the client.
///
/// Has the form `<crate name>/<crate version>`, filled in at compile time
/// from the Cargo package metadata. We want to be polite :)
const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
/// The default cache version/namespace.
///
/// This is the seed fed to the version hasher; the version actually sent to
/// the service is the hex-encoded SHA-256 of this seed plus any data later
/// mixed in through `Api::mutate_version`.
const DEFAULT_VERSION: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
/// The chunk size in bytes (32 MiB).
///
/// We greedily read this much from the input stream at a time; each chunk
/// read is uploaded with its own `PATCH` request.
const CHUNK_SIZE: usize = 32 * 1024 * 1024;
2023-05-08 09:48:11 +00:00
2023-05-08 18:59:57 +00:00
/// The number of chunks to upload at the same time.
///
/// Used as the initial permit count of the semaphore that every chunk
/// upload task must acquire before sending its request.
const MAX_CONCURRENCY: usize = 4;
2023-05-08 18:59:57 +00:00
2023-05-08 09:48:11 +00:00
type Result<T> = std::result::Result<T, Error>;
/// An API error.
#[derive(Error, Debug)]
pub enum Error {
2023-05-08 16:05:43 +00:00
#[error("Failed to initialize the client: {0}")]
2023-05-08 09:48:11 +00:00
InitError(Box<dyn std::error::Error + Send + Sync>),
#[error(
"GitHub Actions Cache throttled Magic Nix Cache. Not trying to use it again on this run."
)]
CircuitBreakerTripped,
2023-05-08 16:05:43 +00:00
#[error("Request error: {0}")]
2023-05-08 09:48:11 +00:00
RequestError(#[from] reqwest::Error), // TODO: Better errors
2023-05-08 16:05:43 +00:00
#[error("Failed to decode response ({status}): {error}")]
2023-05-08 09:48:11 +00:00
DecodeError {
status: StatusCode,
bytes: Bytes,
error: serde_json::Error,
},
2023-05-08 16:05:43 +00:00
#[error("API error ({status}): {info}")]
2023-05-08 09:48:11 +00:00
ApiError {
status: StatusCode,
info: ApiErrorInfo,
},
2023-05-08 16:05:43 +00:00
#[error("I/O error: {0}")]
2023-05-08 09:48:11 +00:00
IoError(#[from] std::io::Error),
#[error("Too many collisions")]
TooManyCollisions,
}
#[derive(Debug)]
2023-05-08 09:48:11 +00:00
pub struct Api {
/// Credentials to access the cache.
credentials: Credentials,
/// The version used for all caches.
///
/// This value should be tied to everything that affects
/// the compatibility of the cached objects.
version: String,
/// The hasher of the version.
version_hasher: Sha256,
/// The HTTP client for authenticated requests.
client: Client,
2023-05-08 18:59:57 +00:00
/// The concurrent upload limit.
concurrency_limit: Arc<Semaphore>,
circuit_breaker_429_tripped: Arc<AtomicBool>,
/// Backend request statistics.
#[cfg(debug_assertions)]
stats: RequestStats,
2023-05-08 09:48:11 +00:00
}
/// A file allocation.
///
/// Opaque handle wrapping the cache ID obtained when a cache key is
/// reserved; it is handed back to `Api::upload_file` to upload and commit
/// the contents.
#[derive(Debug, Clone, Copy)]
pub struct FileAllocation(CacheId);
/// The ID of a cache.
///
/// The value is assigned by the service and deserialized from the
/// `cacheId` field of the reserve response. It is stored as a 64-bit
/// integer so that IDs above `i32::MAX` do not fail to decode.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(transparent)]
struct CacheId(pub i64);
/// An API error.
///
/// Payload of a response with a non-success status, either decoded or raw.
#[derive(Debug, Clone)]
pub enum ApiErrorInfo {
/// An error that we couldn't decode.
///
/// Holds the response body (with any leading byte-order mark removed).
Unstructured(Bytes),
/// A structured API error.
Structured(StructuredApiError),
}
/// A structured API error.
///
/// Deserialized from the JSON body of a response with a non-success status.
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
pub struct StructuredApiError {
/// A human-readable error message.
///
/// Inspected for the text "Cache already exists" when allocating a file
/// with a random suffix.
message: String,
}
/// A cache entry.
///
/// Field names are camelCase on the wire (`cacheKey`, `archiveLocation`, ...).
/// Only `archive_location` is mandatory. A valid entry looks like:
///
/// ```text
/// ArtifactCacheEntry {
///     cache_key: Some("hello-224".to_string()),
///     scope: Some("refs/heads/main".to_string()),
///     cache_version: Some("gha-cache/0.1.0".to_string()),
///     creation_time: Some("2023-01-01T00:00:00.0000000Z".to_string()),
///     archive_location:
///         "https://[...].blob.core.windows.net/[...]/[...]?sv=2019-07-07&sr=b&sig=[...]".to_string(),
/// }
/// ```
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)]
struct ArtifactCacheEntry {
    /// The cache key.
    cache_key: Option<String>,

    /// The scope of the cache.
    ///
    /// It appears to be the branch name.
    scope: Option<String>,

    /// The version of the cache.
    cache_version: Option<String>,

    /// The creation timestamp.
    creation_time: Option<String>,

    /// The archive location.
    archive_location: String,
}
/// JSON body of the request that reserves a new cache.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct ReserveCacheRequest<'a> {
    /// The cache key.
    key: &'a str,

    /// The cache version.
    ///
    /// This value should be tied to everything that affects
    /// the compatibility of the cached objects.
    version: &'a str,

    /// The size of the cache, in bytes.
    ///
    /// Left out of the serialized body entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    cache_size: Option<usize>,
}
/// JSON body of a successful reserve response.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ReserveCacheResponse {
    /// The reserved cache ID.
    cache_id: CacheId,
}
/// JSON body of the request that finalizes (commits) an uploaded cache.
#[derive(Debug, Clone, Serialize)]
struct CommitCacheRequest {
/// Total size of the uploaded data, in bytes.
size: usize,
}
/// Counters of requests issued to the backend, by HTTP method.
///
/// Only compiled into debug builds.
#[cfg(debug_assertions)]
#[derive(Default, Debug)]
struct RequestStats {
/// Number of cache lookups.
get: AtomicUsize,
/// Number of reserve and commit requests.
post: AtomicUsize,
/// Number of chunk uploads scheduled.
patch: AtomicUsize,
}
2023-05-08 09:48:11 +00:00
/// Extension methods that turn an HTTP response into this module's `Result`.
#[async_trait]
trait ResponseExt {
/// Consumes the response, succeeding only if its status is a success status.
async fn check(self) -> Result<()>;
/// Consumes the response and decodes its body as JSON into `T`, after
/// checking that the status is a success status.
async fn check_json<T: DeserializeOwned>(self) -> Result<T>;
}
impl Error {
    /// Boxes an arbitrary error and wraps it in [`Error::InitError`].
    ///
    /// Shaped so it can be passed directly to `Result::map_err`.
    fn init_error<E: std::error::Error + Send + Sync + 'static>(e: E) -> Self {
        let boxed: Box<dyn std::error::Error + Send + Sync> = Box::new(e);
        Self::InitError(boxed)
    }
}
2023-05-08 16:05:43 +00:00
impl fmt::Display for ApiErrorInfo {
    /// Renders undecodable payloads as lossy UTF-8 behind an
    /// `[Unstructured]` tag, and structured errors via their `Debug` form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Structured(structured) => write!(f, "{:?}", structured),
            Self::Unstructured(raw) => {
                let text = String::from_utf8_lossy(raw);
                write!(f, "[Unstructured] {}", text)
            }
        }
    }
}
2023-05-08 09:48:11 +00:00
impl Api {
pub fn new(credentials: Credentials) -> Result<Self> {
let mut headers = HeaderMap::new();
let auth_header = {
let mut h = HeaderValue::from_str(&format!("Bearer {}", credentials.runtime_token))
.map_err(Error::init_error)?;
h.set_sensitive(true);
h
};
headers.insert("Authorization", auth_header);
headers.insert(
"Accept",
HeaderValue::from_str(&format!("application/json;api-version={}", API_VERSION))
.map_err(Error::init_error)?,
);
let client = Client::builder()
.user_agent(USER_AGENT)
.default_headers(headers)
.build()
.map_err(Error::init_error)?;
let version_hasher = Sha256::new_with_prefix(DEFAULT_VERSION.as_bytes());
let initial_version = hex::encode(version_hasher.clone().finalize());
Ok(Self {
credentials,
version: initial_version,
version_hasher,
client,
2023-05-08 18:59:57 +00:00
concurrency_limit: Arc::new(Semaphore::new(MAX_CONCURRENCY)),
circuit_breaker_429_tripped: Arc::new(AtomicBool::from(false)),
#[cfg(debug_assertions)]
stats: Default::default(),
2023-05-08 09:48:11 +00:00
})
}
pub fn circuit_breaker_tripped(&self) -> bool {
2024-06-12 19:50:22 +00:00
self.circuit_breaker_429_tripped.load(Ordering::Relaxed)
}
2023-05-08 09:48:11 +00:00
/// Mutates the cache version/namespace.
///
/// Feeds `data` into the running version hasher and replaces the current
/// version with the hex digest of everything hashed so far.
pub fn mutate_version(&mut self, data: &[u8]) {
    self.version_hasher.update(data);
    let digest = self.version_hasher.clone().finalize();
    self.version = hex::encode(digest);
}
// Public
/// Allocates a file.
///
/// Reserves a cache under `key` (without declaring a size) and returns a
/// handle to it.
pub async fn allocate_file(&self, key: &str) -> Result<FileAllocation> {
    self.reserve_cache(key, None)
        .await
        .map(|reserved| FileAllocation(reserved.cache_id))
}
/// Allocates a file with a random suffix.
///
/// This is a hack to allow for easy "overwriting" without
/// deleting the original cache.
pub async fn allocate_file_with_random_suffix(&self, key: &str) -> Result<FileAllocation> {
for _ in 0..5 {
let nonce: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(4)
.map(char::from)
.collect();
let full_key = format!("{}-{}", key, nonce);
match self.allocate_file(&full_key).await {
Ok(allocation) => {
return Ok(allocation);
}
Err(e) => {
if let Error::ApiError {
info: ApiErrorInfo::Structured(structured),
..
} = &e
{
if structured.message.contains("Cache already exists") {
continue;
}
}
return Err(e);
}
}
}
Err(Error::TooManyCollisions)
}
2024-02-29 19:43:25 +00:00
/// Uploads a file. Returns the size of the file.
pub async fn upload_file<S>(&self, allocation: FileAllocation, mut stream: S) -> Result<usize>
2023-05-08 09:48:11 +00:00
where
S: AsyncRead + Unpin + Send,
{
2024-06-12 19:50:22 +00:00
if self.circuit_breaker_tripped() {
return Err(Error::CircuitBreakerTripped);
}
2023-05-08 09:48:11 +00:00
let mut offset = 0;
2023-05-08 18:59:57 +00:00
let mut futures = Vec::new();
2023-05-08 09:48:11 +00:00
loop {
let buf = BytesMut::with_capacity(CHUNK_SIZE);
let chunk = read_chunk_async(&mut stream, buf).await?;
if chunk.is_empty() {
offset += chunk.len();
break;
}
2023-05-19 08:48:52 +00:00
if offset == chunk.len() {
2023-05-19 19:00:36 +00:00
tracing::trace!("Received first chunk for cache {:?}", allocation.0);
2023-05-19 08:48:52 +00:00
}
2023-05-08 09:48:11 +00:00
let chunk_len = chunk.len();
#[cfg(debug_assertions)]
self.stats.patch.fetch_add(1, Ordering::SeqCst);
2023-05-08 18:59:57 +00:00
futures.push({
let client = self.client.clone();
let concurrency_limit = self.concurrency_limit.clone();
let circuit_breaker_429_tripped = self.circuit_breaker_429_tripped.clone();
2023-05-08 18:59:57 +00:00
let url = self.construct_url(&format!("caches/{}", allocation.0 .0));
tokio::task::spawn(async move {
let permit = concurrency_limit.acquire().await.unwrap();
2023-05-19 08:48:52 +00:00
tracing::trace!(
2023-05-08 18:59:57 +00:00
"Starting uploading chunk {}-{}",
offset,
offset + chunk_len - 1
);
let r = client
.patch(url)
.header(CONTENT_TYPE, "application/octet-stream")
.header(
CONTENT_RANGE,
format!("bytes {}-{}/*", offset, offset + chunk.len() - 1),
)
.body(chunk)
.send()
.await?
.check()
.await;
2023-05-19 08:48:52 +00:00
tracing::trace!(
2023-05-08 18:59:57 +00:00
"Finished uploading chunk {}-{}: {:?}",
offset,
offset + chunk_len - 1,
r
);
drop(permit);
circuit_breaker_429_tripped.check_result(&r);
2023-05-08 18:59:57 +00:00
r
})
});
2023-05-08 09:48:11 +00:00
offset += chunk_len;
}
2023-05-08 18:59:57 +00:00
future::join_all(futures)
.await
.into_iter()
.try_for_each(|join_result| join_result.unwrap())?;
2023-05-08 18:59:57 +00:00
2023-05-19 08:48:52 +00:00
tracing::debug!("Received all chunks for cache {:?}", allocation.0);
2023-05-08 09:48:11 +00:00
self.commit_cache(allocation.0, offset).await?;
2024-02-29 19:43:25 +00:00
Ok(offset)
2023-05-08 09:48:11 +00:00
}
/// Downloads a file based on a list of key prefixes.
pub async fn get_file_url(&self, keys: &[&str]) -> Result<Option<String>> {
2024-06-12 19:50:22 +00:00
if self.circuit_breaker_tripped() {
return Err(Error::CircuitBreakerTripped);
}
2023-05-08 09:48:11 +00:00
Ok(self
.get_cache_entry(keys)
.await?
.map(|entry| entry.archive_location))
}
/// Dumps statistics.
///
/// This is for debugging only.
pub fn dump_stats(&self) {
#[cfg(debug_assertions)]
2023-05-19 08:48:52 +00:00
tracing::trace!("Request stats: {:?}", self.stats);
}
2023-05-08 09:48:11 +00:00
// Private
/// Retrieves a cache based on a list of key prefixes.
async fn get_cache_entry(&self, keys: &[&str]) -> Result<Option<ArtifactCacheEntry>> {
2024-06-12 19:50:22 +00:00
if self.circuit_breaker_tripped() {
return Err(Error::CircuitBreakerTripped);
}
#[cfg(debug_assertions)]
self.stats.get.fetch_add(1, Ordering::SeqCst);
2023-05-08 09:48:11 +00:00
let res = self
.client
.get(self.construct_url("cache"))
.query(&[("version", &self.version), ("keys", &keys.join(","))])
.send()
.await?
.check_json()
2024-06-12 20:59:38 +00:00
.await;
self.circuit_breaker_429_tripped.check_result(&res);
2023-05-08 09:48:11 +00:00
match res {
Ok(entry) => Ok(Some(entry)),
Err(Error::DecodeError { status, .. }) if status == StatusCode::NO_CONTENT => Ok(None),
Err(e) => Err(e),
}
}
/// Reserves a new cache.
///
/// The cache key should be unique. A cache cannot be created
/// again if the same (cache_name, cache_version) pair already
/// exists.
async fn reserve_cache(
&self,
key: &str,
cache_size: Option<usize>,
) -> Result<ReserveCacheResponse> {
2024-06-12 19:50:22 +00:00
if self.circuit_breaker_tripped() {
return Err(Error::CircuitBreakerTripped);
}
2023-05-19 08:48:52 +00:00
tracing::debug!("Reserving cache for {}", key);
2023-05-08 09:48:11 +00:00
let req = ReserveCacheRequest {
key,
version: &self.version,
cache_size,
};
#[cfg(debug_assertions)]
self.stats.post.fetch_add(1, Ordering::SeqCst);
2023-05-08 09:48:11 +00:00
let res = self
.client
.post(self.construct_url("caches"))
.json(&req)
.send()
.await?
.check_json()
2024-06-12 20:59:38 +00:00
.await;
2023-05-08 09:48:11 +00:00
self.circuit_breaker_429_tripped.check_result(&res);
2024-06-12 20:59:38 +00:00
res
2023-05-08 09:48:11 +00:00
}
/// Finalizes uploading to a cache.
async fn commit_cache(&self, cache_id: CacheId, size: usize) -> Result<()> {
2024-06-12 19:50:22 +00:00
if self.circuit_breaker_tripped() {
return Err(Error::CircuitBreakerTripped);
}
2023-05-19 08:48:52 +00:00
tracing::debug!("Commiting cache {:?}", cache_id);
2023-05-08 09:48:11 +00:00
let req = CommitCacheRequest { size };
#[cfg(debug_assertions)]
self.stats.post.fetch_add(1, Ordering::SeqCst);
2024-06-12 20:59:38 +00:00
if let Err(e) = self
.client
2023-05-08 09:48:11 +00:00
.post(self.construct_url(&format!("caches/{}", cache_id.0)))
.json(&req)
.send()
.await?
.check()
2024-06-12 20:30:27 +00:00
.await
2024-06-12 20:59:38 +00:00
{
self.circuit_breaker_429_tripped.check_err(&e);
2024-06-12 20:59:38 +00:00
return Err(e);
}
2023-05-08 09:48:11 +00:00
Ok(())
}
/// Builds the full URL of an artifact cache resource by appending
/// `/_apis/artifactcache/<resource>` to the cache URL from the credentials.
fn construct_url(&self, resource: &str) -> String {
    let base = &self.credentials.cache_url;
    format!("{}/_apis/artifactcache/{}", base, resource)
}
}
#[async_trait]
impl ResponseExt for reqwest::Response {
async fn check(self) -> Result<()> {
let status = self.status();
if !status.is_success() {
return Err(handle_error(self).await);
}
Ok(())
}
async fn check_json<T: DeserializeOwned>(self) -> Result<T> {
let status = self.status();
if !status.is_success() {
return Err(handle_error(self).await);
}
// We don't do `Response::json()` directly to preserve
// the original response payload for troubleshooting.
let bytes = self.bytes().await?;
match serde_json::from_slice(&bytes) {
Ok(decoded) => Ok(decoded),
Err(error) => Err(Error::DecodeError {
status,
error,
bytes,
}),
}
}
}
async fn handle_error(res: reqwest::Response) -> Error {
let status = res.status();
let bytes = match res.bytes().await {
Ok(bytes) => {
let bom = Bom::from(bytes.as_ref());
bytes.slice(bom.len()..)
}
2023-05-08 09:48:11 +00:00
Err(e) => {
return e.into();
}
};
let info = match serde_json::from_slice(&bytes) {
Ok(structured) => ApiErrorInfo::Structured(structured),
Err(e) => {
tracing::info!("failed to decode error: {}", e);
ApiErrorInfo::Unstructured(bytes)
}
2023-05-08 09:48:11 +00:00
};
Error::ApiError { status, info }
}
/// A flag that can be tripped by inspecting the outcome of a request.
trait AtomicCircuitBreaker {
/// Inspects an error and trips the breaker if it warrants it.
fn check_err(&self, e: &Error);
/// Inspects a result; an `Ok` value leaves the breaker untouched.
fn check_result<T>(&self, r: &std::result::Result<T, Error>);
}
impl AtomicCircuitBreaker for AtomicBool {
fn check_result<T>(&self, r: &std::result::Result<T, Error>) {
if let Err(ref e) = r {
self.check_err(e)
}
}
fn check_err(&self, e: &Error) {
if let Error::ApiError {
status: reqwest::StatusCode::TOO_MANY_REQUESTS,
info: ref _info,
} = e
{
tracing::info!("Disabling GitHub Actions Cache due to 429: Too Many Requests");
self.store(true, Ordering::Relaxed);
}
}
}