refactor: improve type ergonomics

Lucille L. Blumire 2025-04-17 16:03:42 +01:00
parent 0768b0ad67
commit 2ff169da9f
GPG key ID: D168492023622329 (no known key found for this signature in database)
11 changed files with 22 additions and 25 deletions


@@ -160,7 +160,7 @@ fn main() -> Result<()> {
     let section_table = pe.get_section_table()?;
-    for section in section_table.iter() {
+    for section in &section_table {
         debug!(section_name = ?section.name()?);
     }
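For context on the loop change above: `&Vec<T>` implements `IntoIterator` yielding `&T`, so borrowing the collection in the `for` header is equivalent to calling `.iter()` explicitly. A minimal standalone sketch (hypothetical data, not the crate's actual section table):

fn main() {
    let names = vec!["text".to_string(), "reloc".to_string()];

    // Explicit form: .iter() yields shared references.
    for name in names.iter() {
        println!("{name}");
    }

    // Equivalent idiomatic form: &Vec<T> is IntoIterator<Item = &T>.
    for name in &names {
        println!("{name}");
    }
}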


@@ -260,7 +260,7 @@ fn decode_tdx_mrs(
         Some(mrs_array) => {
             let result = mrs_array
                 .into_iter()
-                .map(|strings| decode_and_combine_mrs(strings, bytes_length))
+                .map(|strings| decode_and_combine_mrs(&strings, bytes_length))
                 .collect::<Result<Vec<_>, _>>()?;
             Ok(Some(result))
         }
@@ -269,12 +269,12 @@ fn decode_tdx_mrs(
 // Helper function to decode and combine MRs
 fn decode_and_combine_mrs(
-    strings: [String; 5],
+    strings: &[String; 5],
     bytes_length: usize,
 ) -> Result<Bytes, hex::FromHexError> {
     let mut buffer = BytesMut::with_capacity(bytes_length * 5);
-    for s in &strings {
+    for s in strings {
         if s.len() > (bytes_length * 2) {
             return Err(hex::FromHexError::InvalidStringLength);
         }
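The signature change lets `decode_and_combine_mrs` borrow the five measurement-register strings rather than take them by value, so the caller passes `&strings` and keeps ownership. A rough sketch of the borrow-versus-move difference (hex decoding elided, names hypothetical):

// Borrowing the fixed-size array: nothing is moved out of the caller.
fn combined_hex_len(strings: &[String; 5]) -> usize {
    strings.iter().map(String::len).sum()
}

fn main() {
    let mrs = ["aa", "bb", "cc", "dd", "ee"].map(String::from);
    let total = combined_hex_len(&mrs);
    // `mrs` is still usable here because it was only borrowed.
    assert_eq!(total, 10);
    assert_eq!(mrs.len(), 5);
}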


@@ -96,7 +96,7 @@ impl Error {
 impl From<reqwest::Error> for Error {
     fn from(value: reqwest::Error) -> Self {
         Self::Http {
-            status_code: value.status().map(|v| v.as_u16()).unwrap_or(0),
+            status_code: value.status().map_or(0, |v| v.as_u16()),
             message: value.to_string(),
         }
     }
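`Option::map_or(default, f)` folds the `map(...).unwrap_or(...)` pair into a single call; this is what clippy's `map_unwrap_or` lint suggests. A small sketch of the equivalence:

fn main() {
    let status: Option<&str> = None;

    // Two-step form: map to the value we want, then supply a fallback.
    let two_step = status.map(|s| s.len()).unwrap_or(0);

    // Single call: map_or(default, f) expresses the same thing.
    let one_call = status.map_or(0, |s| s.len());

    assert_eq!(two_step, one_call);
    assert_eq!(one_call, 0);
}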


@@ -53,7 +53,7 @@ impl BatchProcessor {
         // Fetch proofs for the current batch across different TEE types
         let mut proofs = Vec::new();
-        for tee_type in self.config.args.tee_types.iter() {
+        for tee_type in self.config.args.tee_types.iter().copied() {
             match self
                 .proof_fetcher
                 .get_proofs(token, batch_number, tee_type)
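Adding `.copied()` makes the loop yield owned `TeeType` values instead of `&TeeType` references, which matches the by-value `tee_type: TeeType` parameters introduced in the `ProofFetcher` and `GetProofsRequest` hunks below. A minimal sketch with a hypothetical `Copy` enum standing in for the TEE type:

// Hypothetical stand-in for a small fieldless enum such as a TEE type.
#[derive(Clone, Copy, Debug)]
enum Kind {
    Sgx,
    Tdx,
}

// Takes the enum by value, so the caller needs owned Kinds, not &Kind.
fn process(kind: Kind) {
    println!("processing {kind:?}");
}

fn main() {
    let kinds = vec![Kind::Sgx, Kind::Tdx];

    // .iter() yields &Kind; .copied() converts each item to an owned Kind.
    for kind in kinds.iter().copied() {
        process(kind);
    }
}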


@@ -43,14 +43,14 @@ impl ProcessorFactory {
     /// Create a new processor based on the provided configuration
     pub fn create(config: VerifierConfig) -> Result<(ProcessorType, VerifierMode)> {
         let mode = if let Some((start, end)) = config.args.batch_range {
-            let processor = OneShotProcessor::new(config.clone(), start, end)?;
+            let processor = OneShotProcessor::new(config, start, end)?;
             let mode = VerifierMode::OneShot {
                 start_batch: start,
                 end_batch: end,
             };
             (ProcessorType::OneShot(processor), mode)
         } else if let Some(start) = config.args.continuous {
-            let processor = ContinuousProcessor::new(config.clone(), start)?;
+            let processor = ContinuousProcessor::new(config, start)?;
             let mode = VerifierMode::Continuous { start_batch: start };
             (ProcessorType::Continuous(processor), mode)
         } else {
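Dropping the two `.clone()` calls is sound because the branches are mutually exclusive: `config` is moved into exactly one constructor per call, and the borrow checker accepts a move in each arm of an `if`/`else if`. A reduced sketch of the pattern (hypothetical `Config` and processor types, not the crate's real ones):

#![allow(dead_code)] // fields exist only to mirror the shape of the real structs

struct Config {
    batch_range: Option<(u32, u32)>,
    continuous: Option<u32>,
}

struct OneShot { config: Config, start: u32, end: u32 }
struct Continuous { config: Config, start: u32 }

enum Processor {
    OneShot(OneShot),
    Continuous(Continuous),
}

fn create(config: Config) -> Option<Processor> {
    // Each arm moves `config`, which is fine: only one arm ever runs.
    if let Some((start, end)) = config.batch_range {
        Some(Processor::OneShot(OneShot { config, start, end }))
    } else if let Some(start) = config.continuous {
        Some(Processor::Continuous(Continuous { config, start }))
    } else {
        None
    }
}

fn main() {
    let p = create(Config { batch_range: Some((1, 10)), continuous: None });
    assert!(matches!(p, Some(Processor::OneShot(_))));
}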


@@ -36,7 +36,7 @@ impl ProofFetcher {
         &self,
         token: &CancellationToken,
         batch_number: L1BatchNumber,
-        tee_type: &TeeType,
+        tee_type: TeeType,
     ) -> Result<Vec<Proof>> {
         let mut proofs_request = GetProofsRequest::new(batch_number, tee_type);
         let mut backoff = Duration::from_secs(1);
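The same reasoning applies to `tee_type` here and to `GetProofsRequest::new` in the next hunk: a small fieldless enum is `Copy`, typically no larger than a reference to it, and taking it by value removes a borrow the caller has to manage. A hedged sketch of the signature style (hypothetical types):

#[derive(Clone, Copy, Debug)]
enum TeeKind { Sgx, Tdx } // hypothetical stand-in for TeeType

struct Request {
    batch: u64,
    tee: TeeKind,
}

impl Request {
    // By-value parameter: the Copy enum is simply stored, no borrow involved.
    fn new(batch: u64, tee: TeeKind) -> Self {
        Request { batch, tee }
    }
}

fn main() {
    for tee in [TeeKind::Sgx, TeeKind::Tdx] {
        let req = Request::new(42, tee);
        println!("batch {} via {:?}", req.batch, req.tee);
    }
}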


@@ -17,7 +17,7 @@ pub struct GetProofsRequest {
 impl GetProofsRequest {
     /// Create a new request for the given batch number
-    pub fn new(batch_number: L1BatchNumber, tee_type: &TeeType) -> Self {
+    pub fn new(batch_number: L1BatchNumber, tee_type: TeeType) -> Self {
         GetProofsRequest {
             jsonrpc: "2.0".to_string(),
             id: 1,


@@ -48,7 +48,7 @@ impl<C: JsonRpcClient> BatchVerifier<C> {
         let mut total_proofs_count: u32 = 0;
         let mut verified_proofs_count: u32 = 0;
-        for proof in proofs.into_iter() {
+        for proof in proofs {
             if token.is_cancelled() {
                 tracing::warn!("Stop signal received during batch verification");
                 return Ok(BatchVerificationResult {
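As with the earlier loop changes, the explicit `.into_iter()` is redundant: a `for` loop over a `Vec<T>` by value already goes through `IntoIterator`. A tiny sketch:

fn main() {
    let proofs = vec![1_u32, 2, 3];
    let mut verified = 0;

    // `for x in vec` desugars to `for x in vec.into_iter()`,
    // so writing the call out adds nothing.
    for proof in proofs {
        verified += proof;
    }
    assert_eq!(verified, 6);
}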


@@ -24,7 +24,7 @@ impl PolicyEnforcer {
         match &quote.report {
             Report::SgxEnclave(report_body) => {
                 // Validate TCB level
-                Self::validate_tcb_level(&attestation_policy.sgx_allowed_tcb_levels, tcblevel)?;
+                Self::validate_tcb_level(attestation_policy.sgx_allowed_tcb_levels, tcblevel)?;

                 // Validate SGX Advisories
                 for advisory in &quote_verification_result.advisories {
@@ -50,7 +50,7 @@ impl PolicyEnforcer {
             }
             Report::TD10(report_body) => {
                 // Validate TCB level
-                Self::validate_tcb_level(&attestation_policy.tdx_allowed_tcb_levels, tcblevel)?;
+                Self::validate_tcb_level(attestation_policy.tdx_allowed_tcb_levels, tcblevel)?;

                 // Validate TDX Advisories
                 for advisory in &quote_verification_result.advisories {
@@ -74,7 +74,7 @@ impl PolicyEnforcer {
             }
             Report::TD15(report_body) => {
                 // Validate TCB level
-                Self::validate_tcb_level(&attestation_policy.tdx_allowed_tcb_levels, tcblevel)?;
+                Self::validate_tcb_level(attestation_policy.tdx_allowed_tcb_levels, tcblevel)?;

                 // Validate TDX Advisories
                 for advisory in &quote_verification_result.advisories {
@@ -101,10 +101,7 @@ impl PolicyEnforcer {
     }

     /// Helper method to validate TCB levels
-    fn validate_tcb_level(
-        allowed_levels: &EnumSet<TcbLevel>,
-        actual_level: TcbLevel,
-    ) -> Result<()> {
+    fn validate_tcb_level(allowed_levels: EnumSet<TcbLevel>, actual_level: TcbLevel) -> Result<()> {
         if !allowed_levels.contains(actual_level) {
             let error_msg = format!(
                 "Quote verification failed: TCB level mismatch (expected one of: {allowed_levels:?}, actual: {actual_level})",
@@ -116,7 +113,7 @@ impl PolicyEnforcer {

     /// Helper method to build combined TDX measurement register
     fn build_tdx_mr<const N: usize>(parts: [&[u8]; N]) -> Vec<u8> {
-        parts.into_iter().flatten().cloned().collect()
+        parts.into_iter().flatten().copied().collect()
     }

     /// Check if a policy value matches the actual value
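Two ergonomic points in this file: `EnumSet<T>` from the `enumset` crate is a small bitset that implements `Copy`, so `validate_tcb_level` can take it by value instead of by reference, and in `build_tdx_mr` the flattened items are `&u8`, for which `.copied()` states more directly than `.cloned()` that only a bitwise copy happens. A sketch of the flatten-and-copy step with plain byte slices (the `EnumSet` claim comes from that crate's API, not from this diff):

// Concatenate several byte slices into one Vec<u8>, as build_tdx_mr does.
fn combine<const N: usize>(parts: [&[u8]; N]) -> Vec<u8> {
    // Flattening yields &u8 items; .copied() turns each into an owned u8
    // (equivalent to .cloned() here, just more explicit about Copy).
    parts.into_iter().flatten().copied().collect()
}

fn main() {
    let combined = combine([b"ab".as_slice(), b"cd".as_slice()]);
    assert_eq!(combined, b"abcd".to_vec());
}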


@@ -193,7 +193,7 @@ impl ApiClient {
     }

     /// Checks if a V4-only parameter is provided with a V3 API version.
-    pub(super) fn check_v4_only_param<T>(
+    pub(super) fn check_v4_only_param<T: Copy>(
         &self,
         param_value: Option<T>,
         param_name: &str,
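The diff doesn't show why `check_v4_only_param` now needs `T: Copy`; a common reason for such a bound is that `Option<T>` is itself `Copy` when `T` is, so a by-value `Option<T>` argument can be consumed more than once without cloning. A purely illustrative sketch (not the crate's actual body):

// `param.map(...)` consumes the Option; returning the original afterwards
// only works because Option<T> is Copy when T: Copy.
fn describe_and_return<T: Copy + std::fmt::Debug>(param: Option<T>) -> (String, Option<T>) {
    let description = param
        .map(|v| format!("provided: {v:?}"))
        .unwrap_or_else(|| "not provided".to_string());
    (description, param)
}

fn main() {
    let (desc, original) = describe_and_return(Some(4_u16));
    assert_eq!(desc, "provided: 4");
    assert_eq!(original, Some(4));
}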


@@ -24,7 +24,7 @@ fn extract_header_value(
         .ok_or_else(|| QuoteError::Unexpected(format!("Missing required header: {header_name}")))?
         .to_str()
         .map_err(|e| QuoteError::Unexpected(format!("Invalid header value: {e}")))
-        .map(|val| val.to_string())
+        .map(str::to_string)
 }

 /// Fetch collateral data from Intel's Provisioning Certification Service
@@ -74,14 +74,14 @@ pub(crate) fn get_collateral(quote: &[u8]) -> Result<Collateral, QuoteError> {
     let (collateral, pck_crl, pck_issuer_chain) = result;

     // Convert QuoteCollateralV3 to Collateral
-    convert_to_collateral(collateral, pck_crl, pck_issuer_chain)
+    convert_to_collateral(collateral, &pck_crl, &pck_issuer_chain)
 }

 // Helper function to convert QuoteCollateralV3 to Collateral
 fn convert_to_collateral(
     collateral: QuoteCollateralV3,
-    pck_crl: String,
-    pck_issuer_chain: Bytes,
+    pck_crl: &str,
+    pck_issuer_chain: &[u8],
 ) -> Result<Collateral, QuoteError> {
     let QuoteCollateralV3 {
         tcb_info_issuer_chain,
@@ -119,7 +119,7 @@ fn convert_to_collateral(
         root_ca_crl: Box::new([]),

         // Converted values
-        pck_crl_issuer_chain: pck_issuer_chain.as_ref().into(),
+        pck_crl_issuer_chain: pck_issuer_chain.into(),
         pck_crl: pck_crl.as_bytes().into(),
         tcb_info_issuer_chain: to_bytes_with_nul(tcb_info_issuer_chain, "tcb_info_issuer_chain")?,
         tcb_info: to_bytes_with_nul(tcb_info_json, "tcb_info")?,
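Two idioms appear in this last file: a closure that only forwards to a single method can be written as the method path (`str::to_string`), and helpers that only read their inputs can take `&str` / `&[u8]` instead of owned `String` / `Bytes`, so the caller keeps ownership. A small standalone sketch (hypothetical names):

// A read-only helper: borrowed inputs are enough, the caller keeps ownership.
fn summarize(pck_crl: &str, issuer_chain: &[u8]) -> String {
    format!("crl: {} chars, chain: {} bytes", pck_crl.len(), issuer_chain.len())
}

fn main() {
    // Method path instead of a |val| val.to_string() closure.
    let headers: Vec<String> = ["application/json", "text/plain"]
        .iter()
        .copied()
        .map(str::to_string)
        .collect();
    assert_eq!(headers.len(), 2);

    let crl = String::from("-----BEGIN X509 CRL-----");
    let chain = vec![0_u8; 16];
    println!("{}", summarize(&crl, &chain));

    // Both values remain usable; summarize only borrowed them.
    assert_eq!(chain.len(), 16);
    assert!(!crl.is_empty());
}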