diff --git a/bin/tee-vault-unseal/src/main.rs b/bin/tee-vault-unseal/src/main.rs index 4b62483..b30b977 100644 --- a/bin/tee-vault-unseal/src/main.rs +++ b/bin/tee-vault-unseal/src/main.rs @@ -12,7 +12,7 @@ mod unseal; use actix_web::rt::time::sleep; use actix_web::web::Data; use actix_web::{web, App, HttpServer}; -use anyhow::{Context, Result}; +use anyhow::{bail, Context, Result}; use awc::Client; use clap::Parser; use init::post_init; @@ -97,8 +97,10 @@ struct Args { /// port to listen on #[arg(long, env = "PORT", default_value = "8443")] port: u16, + /// the sha256 of the `vault_auth_tee` plugin, with precedence over the file #[arg(long, env = "VAULT_AUTH_TEE_SHA256")] - vault_auth_tee_sha: String, + vault_auth_tee_sha: Option, + /// the file containing the sha256 of the `vault_auth_tee` plugin #[arg(long, env = "VAULT_AUTH_TEE_SHA256_FILE")] vault_auth_tee_sha_file: Option, #[arg(long, env = "VAULT_AUTH_TEE_VERSION")] @@ -123,7 +125,7 @@ async fn main() -> Result<()> { ); tracing::subscriber::set_global_default(subscriber).unwrap(); - let mut args = Args::parse(); + let args = Args::parse(); info!("Starting up"); @@ -146,20 +148,23 @@ async fn main() -> Result<()> { let server_state = get_vault_status(&args.attestation.vault_addr, conn.client()).await; - // If sha file given, override env variable with contents - if let Some(sha_file) = args.vault_auth_tee_sha_file { + let vault_auth_tee_sha = if let Some(vault_auth_tee_sha) = args.vault_auth_tee_sha { + vault_auth_tee_sha + } else if let Some(sha_file) = args.vault_auth_tee_sha_file { let mut file = std::fs::File::open(sha_file)?; let mut contents = String::new(); file.read_to_string(&mut contents)?; - args.vault_auth_tee_sha = contents.trim_end().into(); - } + contents.trim_end().into() + } else { + bail!("Neither `VAULT_AUTH_TEE_SHA256_FILE` nor `VAULT_AUTH_TEE_SHA256` set!"); + }; info!("Starting HTTPS server at port {}", args.port); let server_config = Arc::new(UnsealServerConfig { vault_url: args.attestation.vault_addr, report_data: Box::from(report_data), allowed_tcb_levels: Some(args.allowed_tcb_levels), - vault_auth_tee_sha: args.vault_auth_tee_sha, + vault_auth_tee_sha, vault_auth_tee_version: args.vault_auth_tee_version, ca_cert_file: args.ca_cert_file, });