Merge pull request #3 from DeterminateSystems/async-push

Push to attic from the post-build-hook
Eelco Dolstra 2023-12-15 17:01:19 +01:00 committed by GitHub
commit 684aa7a2c4
6 changed files with 144 additions and 106 deletions

Cargo.lock (generated)

@@ -128,7 +128,7 @@ dependencies = [
 [[package]]
 name = "attic"
 version = "0.2.0"
-source = "git+ssh://git@github.com/DeterminateSystems/attic-priv?branch=main#217cbe932ceb81f504621fead261edc792eb5e2c"
+source = "git+ssh://git@github.com/DeterminateSystems/attic-priv?branch=main#ce89f7b954b01b2bf6403c7d136c002730c550ce"
 dependencies = [
  "async-stream",
  "base64 0.21.2",
@@ -155,7 +155,7 @@ dependencies = [
 [[package]]
 name = "attic-client"
 version = "0.2.0"
-source = "git+ssh://git@github.com/DeterminateSystems/attic-priv?branch=main#217cbe932ceb81f504621fead261edc792eb5e2c"
+source = "git+ssh://git@github.com/DeterminateSystems/attic-priv?branch=main#ce89f7b954b01b2bf6403c7d136c002730c550ce"
 dependencies = [
  "anyhow",
  "async-channel",
@@ -1314,6 +1314,7 @@ dependencies = [
 name = "magic-nix-cache"
 version = "0.1.1"
 dependencies = [
+ "anyhow",
  "attic",
  "attic-client",
  "axum",

Cargo.toml

@@ -29,6 +29,7 @@ attic = { git = "ssh://git@github.com/DeterminateSystems/attic-priv", branch = "
 attic-client = { git = "ssh://git@github.com/DeterminateSystems/attic-priv", branch = "main" }
 #attic-client = { path = "../../attic-priv/client" }
 indicatif = "0.17"
+anyhow = "1.0.71"
 
 [dependencies.tokio]
 version = "1.28.0"

src/api.rs

@@ -6,7 +6,7 @@ use std::net::SocketAddr;
 
 use axum::{extract::Extension, http::uri::Uri, routing::post, Json, Router};
 use axum_macros::debug_handler;
-use serde::Serialize;
+use serde::{Deserialize, Serialize};
 
 use super::State;
 use crate::error::Result;
@@ -28,6 +28,7 @@ pub fn get_router() -> Router {
     Router::new()
         .route("/api/workflow-start", post(workflow_start))
         .route("/api/workflow-finish", post(workflow_finish))
+        .route("/api/enqueue-paths", post(enqueue_paths))
 }
 
 /// Record existing paths.
@@ -62,20 +63,14 @@ async fn workflow_finish(
         upload_paths(new_paths.clone(), &store_uri).await?;
     }
 
-    if let Some(attic_state) = &state.flakehub_state {
-        tracing::info!("Pushing {} new paths to Attic", new_paths.len());
-
-        let new_paths = new_paths
-            .iter()
-            .map(|path| state.store.follow_store_path(path).unwrap())
-            .collect();
-
-        crate::flakehub::push(attic_state, state.store.clone(), new_paths).await?;
-    }
-
     let sender = state.shutdown_sender.lock().await.take().unwrap();
     sender.send(()).unwrap();
 
+    // Wait for the Attic push workers to finish.
+    if let Some(attic_state) = state.flakehub_state.write().await.take() {
+        attic_state.push_session.wait().await.unwrap();
+    }
+
     let reply = WorkflowFinishResponse {
         num_original_paths: original_paths.len(),
         num_final_paths: final_paths.len(),
@@ -101,3 +96,31 @@ fn make_store_uri(self_endpoint: &SocketAddr) -> String {
         .unwrap()
         .to_string()
 }
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct EnqueuePathsRequest {
+    pub store_paths: Vec<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct EnqueuePathsResponse {}
+
+/// Schedule paths in the local Nix store for uploading.
+async fn enqueue_paths(
+    Extension(state): Extension<State>,
+    Json(req): Json<EnqueuePathsRequest>,
+) -> Result<Json<EnqueuePathsResponse>> {
+    tracing::info!("Enqueueing {:?}", req.store_paths);
+
+    let store_paths: Vec<_> = req
+        .store_paths
+        .iter()
+        .map(|path| state.store.follow_store_path(path).unwrap())
+        .collect();
+
+    if let Some(flakehub_state) = &*state.flakehub_state.read().await {
+        crate::flakehub::enqueue_paths(flakehub_state, store_paths).await?;
+    }
+
+    Ok(Json(EnqueuePathsResponse {}))
+}
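The new endpoint can also be exercised by hand. A minimal client sketch, assuming a magic-nix-cache server listening on 127.0.0.1:37515 (the fallback address used in src/main.rs below) and the reqwest crate with its json feature, plus serde_json and tokio; the store path is made up:

// Post a single (made-up) store path to /api/enqueue-paths and print the
// HTTP status; the handler replies with an empty JSON object on success.
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let body = json!({
        "store_paths": ["/nix/store/aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-example-1.0"]
    });
    let response = reqwest::Client::new()
        .post("http://127.0.0.1:37515/api/enqueue-paths")
        .json(&body) // sets Content-Type: application/json
        .send()
        .await?;
    println!("status: {}", response.status());
    Ok(())
}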

src/error.rs

@@ -27,6 +27,9 @@ pub enum Error {
     #[error("GHA cache is disabled")]
     GHADisabled,
+
+    #[error("FlakeHub cache error: {0}")]
+    FlakeHub(anyhow::Error),
 }
 
 impl IntoResponse for Error {
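The diff does not show how IntoResponse handles the new variant. For orientation, a typical axum mapping for such an error enum looks like the sketch below (needs axum, anyhow, and thiserror); the 500 status choice is an assumption, not the repo's actual code:

// Sketch: map an error enum to an HTTP response the way axum expects.
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};

#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("FlakeHub cache error: {0}")]
    FlakeHub(anyhow::Error),
}

impl IntoResponse for Error {
    fn into_response(self) -> Response {
        // Assumed mapping: surface the Display text with a 500 status.
        (StatusCode::INTERNAL_SERVER_ERROR, format!("{self}")).into_response()
    }
}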

src/flakehub.rs

@@ -1,7 +1,8 @@
-use crate::error::Result;
+use crate::error::{Error, Result};
 use attic::api::v1::cache_config::{CreateCacheRequest, KeypairConfig};
 use attic::cache::CacheSliceIdentifier;
 use attic::nix_store::{NixStore, StorePath};
+use attic_client::push::{PushSession, PushSessionConfig};
 use attic_client::{
     api::{ApiClient, ApiError},
     config::ServerConfig,
@@ -19,17 +20,16 @@ const JWT_PREFIX: &str = "flakehub1_";
 const USER_AGENT: &str = "magic-nix-cache";
 
 pub struct State {
-    cache: CacheSliceIdentifier,
-
     pub substituter: String,
 
-    api: ApiClient,
+    pub push_session: PushSession,
 }
 
 pub async fn init_cache(
     flakehub_api_server: &str,
     flakehub_api_server_netrc: &Path,
     flakehub_cache_server: &str,
+    store: Arc<NixStore>,
 ) -> Result<State> {
     // Parse netrc to get the credentials for api.flakehub.com.
     let netrc = {
@@ -225,15 +225,7 @@
         tracing::info!("Created cache {} on {}.", cache_name, flakehub_cache_server);
     }
 
-    Ok(State {
-        cache,
-        substituter: flakehub_cache_server.to_owned(),
-        api,
-    })
-}
-
-pub async fn push(state: &State, store: Arc<NixStore>, store_paths: Vec<StorePath>) -> Result<()> {
-    let cache_config = state.api.get_cache_config(&state.cache).await.unwrap();
+    let cache_config = api.get_cache_config(&cache).await.unwrap();
 
     let push_config = PushConfig {
         num_workers: 5, // FIXME: use number of CPUs?
@@ -242,37 +234,30 @@ pub async fn push(state: &State, store: Arc<NixStore>, store_paths: Vec<StorePat
     let mp = indicatif::MultiProgress::new();
 
-    let pusher = Pusher::new(
+    let push_session = Pusher::new(
         store.clone(),
-        state.api.clone(),
-        state.cache.to_owned(),
+        api.clone(),
+        cache.to_owned(),
         cache_config,
         mp,
         push_config,
-    );
-
-    let plan = pusher.plan(store_paths, false, false).await.unwrap();
-
-    for (_, path_info) in plan.store_path_map {
-        pusher.queue(path_info).await.unwrap();
-    }
-
-    let results = pusher.wait().await;
-
-    for (path, res) in &results {
-        if let Err(err) = res {
-            tracing::error!(
-                "Upload of {} failed: {}",
-                store.get_full_path(path).display(),
-                err
-            );
-        }
-    }
-
-    tracing::info!(
-        "Uploaded {} paths.",
-        results.iter().filter(|(_path, res)| res.is_ok()).count()
-    );
+    )
+    .into_push_session(PushSessionConfig {
+        no_closure: false,
+        ignore_upstream_cache_filter: false,
+    });
+
+    Ok(State {
+        substituter: flakehub_cache_server.to_owned(),
+        push_session,
+    })
+}
+
+pub async fn enqueue_paths(state: &State, store_paths: Vec<StorePath>) -> Result<()> {
+    state
+        .push_session
+        .queue_many(store_paths)
+        .map_err(Error::FlakeHub)?;
 
     Ok(())
 }
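The net effect of this file's changes: init_cache now builds one long-lived PushSession instead of a per-call Pusher, so enqueue_paths is a cheap, non-blocking queue operation and the final wait() (in workflow_finish above) drains the remaining uploads. A rough sketch of that pattern using plain tokio primitives, not attic_client's actual implementation (assumes the tokio crate with rt, macros, and sync features):

// A background "push session": callers enqueue work without blocking,
// and a single wait() call closes the queue and drains the worker.
use tokio::sync::mpsc;
use tokio::task::JoinHandle;

struct Session {
    queue: mpsc::UnboundedSender<String>,
    worker: JoinHandle<()>,
}

impl Session {
    fn new() -> Self {
        let (queue, mut rx) = mpsc::unbounded_channel::<String>();
        let worker = tokio::spawn(async move {
            while let Some(path) = rx.recv().await {
                // Real code would upload the path's closure to the cache here.
                println!("pushing {path}");
            }
        });
        Session { queue, worker }
    }

    // Like PushSession::queue_many: synchronous and non-blocking.
    fn queue_many(&self, paths: Vec<String>) -> Result<(), mpsc::error::SendError<String>> {
        paths.into_iter().try_for_each(|p| self.queue.send(p))
    }

    // Like PushSession::wait: close the queue, then let the worker finish.
    async fn wait(self) {
        drop(self.queue);
        self.worker.await.unwrap();
    }
}

#[tokio::main]
async fn main() {
    let session = Session::new();
    session.queue_many(vec!["/nix/store/example".into()]).unwrap();
    session.wait().await;
}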

src/main.rs

@@ -2,7 +2,6 @@
     asm_sub_register,
     deprecated,
     missing_abi,
-    unsafe_code,
     unused_macros,
     unused_must_use,
     unused_unsafe
@@ -21,21 +20,17 @@ mod telemetry;
 mod util;
 
 use std::collections::HashSet;
-use std::fs::{self, create_dir_all, File, OpenOptions};
+use std::fs::{self, create_dir_all, OpenOptions};
 use std::io::Write;
 use std::net::SocketAddr;
-use std::os::fd::OwnedFd;
+use std::os::fd::FromRawFd;
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
 
 use ::attic::nix_store::NixStore;
 use axum::{extract::Extension, routing::get, Router};
 use clap::Parser;
-use daemonize::Daemonize;
-use tokio::{
-    runtime::Runtime,
-    sync::{oneshot, Mutex, RwLock},
-};
+use tokio::sync::{oneshot, Mutex, RwLock};
 use tracing_subscriber::filter::EnvFilter;
 
 use gha_cache::{Api, Credentials};
@@ -82,12 +77,6 @@ struct Args {
     )]
     diagnostic_endpoint: String,
 
-    /// Daemonize the server.
-    ///
-    /// This is for use in the GitHub Action only.
-    #[arg(long, hide = true)]
-    daemon_dir: Option<PathBuf>,
-
     /// The FlakeHub API server.
     #[arg(long)]
     flakehub_api_server: Option<String>,
@@ -111,6 +100,10 @@ struct Args {
     /// Whether to use the FlakeHub binary cache.
     #[arg(long)]
     use_flakehub: bool,
+
+    /// File descriptor on which to send startup notification.
+    #[arg(long)]
+    notify_fd: Option<i32>,
 }
 
 /// The global server state.
@@ -142,10 +135,10 @@ struct StateInner {
     store: Arc<NixStore>,
 
     /// FlakeHub cache state.
-    flakehub_state: Option<flakehub::State>,
+    flakehub_state: RwLock<Option<flakehub::State>>,
 }
 
-fn main() {
+async fn main_cli() {
     init_logging();
 
     let args = Args::parse();
@@ -169,18 +162,16 @@ fn main() {
             .flakehub_api_server_netrc
             .expect("--flakehub-api-server-netrc is required");
 
-        let rt = Runtime::new().unwrap();
-
-        match rt.block_on(async {
-            flakehub::init_cache(
-                &args
-                    .flakehub_api_server
-                    .expect("--flakehub-api-server is required"),
-                &flakehub_api_server_netrc,
-                &flakehub_cache_server,
-            )
-            .await
-        }) {
+        match flakehub::init_cache(
+            &args
+                .flakehub_api_server
+                .expect("--flakehub-api-server is required"),
+            &flakehub_api_server_netrc,
+            &flakehub_cache_server,
+            store.clone(),
+        )
+        .await
+        {
             Ok(state) => {
                 nix_conf
                     .write_all(
@@ -236,7 +227,15 @@ fn main() {
     };
 
     nix_conf
-        .write_all("fallback = true\n".as_bytes())
+        .write_all(
+            format!(
+                "fallback = true\npost-build-hook = {}\n",
+                std::env::current_exe()
+                    .expect("Getting the path of magic-nix-cache")
+                    .display()
+            )
+            .as_bytes(),
+        )
         .expect("Writing to nix.conf");
 
     drop(nix_conf);
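Rendered, the write above appends two lines to the user's nix.conf; the hook path comes from std::env::current_exe(), shown here with a hypothetical value:

    fallback = true
    post-build-hook = /nix/store/<hash>-magic-nix-cache/bin/magic-nix-cache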
@@ -260,7 +259,7 @@ fn main() {
         self_endpoint: args.listen.to_owned(),
         metrics: telemetry::TelemetryReport::new(),
         store,
-        flakehub_state,
+        flakehub_state: RwLock::new(flakehub_state),
     });
 
     let app = Router::new()
@@ -275,35 +274,60 @@ fn main() {
     let app = app.layer(Extension(state.clone()));
 
-    if args.daemon_dir.is_some() {
-        let dir = args.daemon_dir.as_ref().unwrap();
-        let logfile: OwnedFd = File::create(dir.join("daemon.log")).unwrap().into();
-        let daemon = Daemonize::new()
-            .pid_file(dir.join("daemon.pid"))
-            .stdout(File::from(logfile.try_clone().unwrap()))
-            .stderr(File::from(logfile));
-
-        tracing::info!("Forking into the background");
-        daemon.start().expect("Failed to fork into the background");
+    tracing::info!("Listening on {}", args.listen);
+
+    if let Some(notify_fd) = args.notify_fd {
+        let mut f = unsafe { std::fs::File::from_raw_fd(notify_fd) };
+        write!(&mut f, "INIT\n").unwrap();
     }
 
-    let rt = Runtime::new().unwrap();
-
-    rt.block_on(async move {
-        tracing::info!("Listening on {}", args.listen);
-        let ret = axum::Server::bind(&args.listen)
-            .serve(app.into_make_service())
-            .with_graceful_shutdown(async move {
-                shutdown_receiver.await.ok();
-                tracing::info!("Shutting down");
-
-                if let Some(diagnostic_endpoint) = diagnostic_endpoint {
-                    state.metrics.send(diagnostic_endpoint).await;
-                }
-            })
-            .await;
-        ret.unwrap()
-    });
+    let ret = axum::Server::bind(&args.listen)
+        .serve(app.into_make_service())
+        .with_graceful_shutdown(async move {
+            shutdown_receiver.await.ok();
+            tracing::info!("Shutting down");
+        })
+        .await;
+
+    if let Some(diagnostic_endpoint) = diagnostic_endpoint {
+        state.metrics.send(diagnostic_endpoint).await;
+    }
+
+    ret.unwrap()
+}
+
+async fn post_build_hook(out_paths: &str) {
+    let store_paths: Vec<_> = out_paths.lines().map(str::to_owned).collect();
+
+    let request = api::EnqueuePathsRequest { store_paths };
+
+    let response = reqwest::Client::new()
+        .post(format!(
+            "http://{}/api/enqueue-paths",
+            std::env::var("INPUT_LISTEN").unwrap_or_else(|_| "127.0.0.1:37515".to_owned())
+        ))
+        .header("Content-Type", "application/json")
+        .body(serde_json::to_string(&request).unwrap())
+        .send()
+        .await
+        .unwrap();
+
+    if !response.status().is_success() {
+        eprintln!(
+            "magic-nix-cache server failed to enqueue the push request: {}",
+            response.status()
+        );
+    } else {
+        response.json::<api::EnqueuePathsResponse>().await.unwrap();
+    }
+}
+
+#[tokio::main]
+async fn main() {
+    match std::env::var("OUT_PATHS") {
+        Ok(out_paths) => post_build_hook(&out_paths).await,
+        Err(_) => main_cli().await,
+    }
 }
 
 fn init_logging() {
@@ -318,6 +342,7 @@ fn init_logging() {
     });
 
     tracing_subscriber::fmt()
+        .with_writer(std::io::stderr)
         .pretty()
        .with_env_filter(filter)
        .init();
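The daemonize/--daemon-dir flow is replaced by a --notify-fd handshake: the server stays in the foreground and writes "INIT\n" to an inherited file descriptor once it is listening (the descriptor closes when the write-end File is dropped, so readers see EOF). A minimal sketch of the parent's side, assuming a Unix host and the libc crate; the magic-nix-cache invocation here is illustrative, not the GitHub Action's actual launcher:

// Create a pipe, pass the write end to the server as --notify-fd, and
// block until the server reports readiness (or exits, closing the pipe).
use std::io::Read;
use std::os::fd::FromRawFd;
use std::process::Command;

fn main() -> std::io::Result<()> {
    let mut fds = [0i32; 2];
    if unsafe { libc::pipe(fds.as_mut_ptr()) } != 0 {
        return Err(std::io::Error::last_os_error());
    }
    let (read_fd, write_fd) = (fds[0], fds[1]);

    // The child inherits write_fd because libc::pipe does not set CLOEXEC.
    let mut child = Command::new("magic-nix-cache") // hypothetical path
        .arg("--notify-fd")
        .arg(write_fd.to_string())
        .spawn()?;

    // Drop our copy of the write end so we see EOF if the child dies early.
    unsafe { libc::close(write_fd) };

    let mut pipe = unsafe { std::fs::File::from_raw_fd(read_fd) };
    let mut buf = String::new();
    pipe.read_to_string(&mut buf)?;
    assert_eq!(buf, "INIT\n", "server failed to start");

    // ... proceed; the server keeps running until told to shut down.
    child.wait()?;
    Ok(())
}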