refactor: simplify downcasting logic with std::mem::transmute (#20)

Simplified the downcasting implementations by replacing pointer casting
logic with `std::mem::transmute`, ensuring type safety after matching.
Added tests to validate various downcasting behaviors for both owned and
trait-object error scenarios, improving overall reliability and test
coverage.
This commit is contained in:
Harald Hoyer 2025-03-31 14:55:16 +02:00 committed by GitHub
commit 28eb28e47d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -338,11 +338,8 @@ impl<U: 'static + Display + Debug> ErrorDown for Error<U> {
#[inline]
fn downcast_chain_ref<T: 'static + Display + Debug>(&self) -> Option<&Error<T>> {
if self.is_chain::<T>() {
#[allow(clippy::cast_ptr_alignment)]
unsafe {
#[allow(trivial_casts)]
Some(*(self as *const dyn StdError as *const &Error<T>))
}
// Use transmute when we've verified the types match
unsafe { Some(std::mem::transmute::<&Error<U>, &Error<T>>(self)) }
} else {
None
}
@ -351,11 +348,8 @@ impl<U: 'static + Display + Debug> ErrorDown for Error<U> {
#[inline]
fn downcast_chain_mut<T: 'static + Display + Debug>(&mut self) -> Option<&mut Error<T>> {
if self.is_chain::<T>() {
#[allow(clippy::cast_ptr_alignment)]
unsafe {
#[allow(trivial_casts)]
Some(&mut *(self as *mut dyn StdError as *mut &mut Error<T>))
}
// Use transmute when we've verified the types match
unsafe { Some(std::mem::transmute::<&mut Error<U>, &mut Error<T>>(self)) }
} else {
None
}
@ -363,11 +357,8 @@ impl<U: 'static + Display + Debug> ErrorDown for Error<U> {
#[inline]
fn downcast_inner_ref<T: 'static + StdError>(&self) -> Option<&T> {
if self.is_chain::<T>() {
#[allow(clippy::cast_ptr_alignment)]
unsafe {
#[allow(trivial_casts)]
Some(&(*(self as *const dyn StdError as *const &Error<T>)).kind)
}
// Use transmute when we've verified the types match
unsafe { Some(std::mem::transmute::<&U, &T>(&self.kind)) }
} else {
None
}
@ -376,11 +367,8 @@ impl<U: 'static + Display + Debug> ErrorDown for Error<U> {
#[inline]
fn downcast_inner_mut<T: 'static + StdError>(&mut self) -> Option<&mut T> {
if self.is_chain::<T>() {
#[allow(clippy::cast_ptr_alignment)]
unsafe {
#[allow(trivial_casts)]
Some(&mut (*(self as *mut dyn StdError as *mut &mut Error<T>)).kind)
}
// Use transmute when we've verified the types match
unsafe { Some(std::mem::transmute::<&mut U, &mut T>(&mut self.kind)) }
} else {
None
}
@ -913,4 +901,125 @@ mod tests {
assert!(err.source().is_some());
assert!(err.source().unwrap().is_chain::<io::Error>());
}
// Helper error types for testing
#[derive(Debug)]
struct TestError(String);
impl std::fmt::Display for TestError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl std::error::Error for TestError {}
#[test]
fn test_downcast_chain_operations() {
// Create a test error chain
let original_error = Error::new(
TestError("test message".to_string()),
None,
Some("test location".to_string()),
);
// Test is_chain
assert!(original_error.is_chain::<TestError>());
assert!(!original_error.is_chain::<io::Error>());
// Test downcast_chain_ref
let downcast_ref = original_error.downcast_chain_ref::<TestError>();
assert!(downcast_ref.is_some());
let downcast_kind = downcast_ref.unwrap().kind();
assert_eq!(format!("{}", downcast_kind), "test message");
assert_eq!(
format!("{:?}", downcast_kind),
"TestError(\"test message\")"
);
// Test invalid downcast_chain_ref
let invalid_downcast = original_error.downcast_chain_ref::<io::Error>();
assert!(invalid_downcast.is_none());
// Test downcast_chain_mut
let mut mutable_error = original_error;
let downcast_mut = mutable_error.downcast_chain_mut::<TestError>();
assert!(downcast_mut.is_some());
assert_eq!(downcast_mut.unwrap().kind().0, "test message");
// Test invalid downcast_chain_mut
let invalid_downcast_mut = mutable_error.downcast_chain_mut::<io::Error>();
assert!(invalid_downcast_mut.is_none());
}
#[test]
fn test_downcast_inner_operations() {
// Create a test error
let mut error = Error::new(
TestError("inner test".to_string()),
None,
Some("test location".to_string()),
);
// Test downcast_inner_ref
let inner_ref = error.downcast_inner_ref::<TestError>();
assert!(inner_ref.is_some());
assert_eq!(inner_ref.unwrap().0, "inner test");
// Test invalid downcast_inner_ref
let invalid_inner = error.downcast_inner_ref::<io::Error>();
assert!(invalid_inner.is_none());
// Test downcast_inner_mut
let inner_mut = error.downcast_inner_mut::<TestError>();
assert!(inner_mut.is_some());
assert_eq!(inner_mut.unwrap().0, "inner test");
// Test invalid downcast_inner_mut
let invalid_inner_mut = error.downcast_inner_mut::<io::Error>();
assert!(invalid_inner_mut.is_none());
}
#[test]
fn test_error_down_for_dyn_error() {
// Create a boxed error
let error: Box<dyn std::error::Error + 'static> = Box::new(Error::new(
TestError("dyn test".to_string()),
None,
Some("test location".to_string()),
));
// Test is_chain through trait object
assert!(error.is_chain::<TestError>());
assert!(!error.is_chain::<io::Error>());
// Test downcast_chain_ref through trait object
let chain_ref = error.downcast_chain_ref::<TestError>();
assert!(chain_ref.is_some());
assert_eq!(chain_ref.unwrap().kind().0, "dyn test");
// Test downcast_inner_ref through trait object
let inner_ref = error.downcast_inner_ref::<TestError>();
assert!(inner_ref.is_some());
assert_eq!(inner_ref.unwrap().0, "dyn test");
}
#[test]
fn test_error_down_with_sync_send() {
// Create a boxed error with Send + Sync
let error: Box<dyn std::error::Error + Send + Sync> = Box::new(Error::new(
TestError("sync test".to_string()),
None,
Some("test location".to_string()),
));
// Test operations on Send + Sync error
assert!(error.is_chain::<TestError>());
assert!(error.downcast_chain_ref::<TestError>().is_some());
assert!(error.downcast_inner_ref::<TestError>().is_some());
// Test invalid downcasts
assert!(!error.is_chain::<io::Error>());
assert!(error.downcast_chain_ref::<io::Error>().is_none());
assert!(error.downcast_inner_ref::<io::Error>().is_none());
}
}