diff --git a/src/lib.rs b/src/lib.rs index b052a53..faa60d5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -214,11 +214,7 @@ impl ChainError { /// return the root cause of the error chain, if any exists pub fn root_cause(&self) -> Option<&(dyn Error + 'static)> { - let mut cause = self as &(dyn Error + 'static); - while let Some(c) = cause.source() { - cause = c; - } - Some(cause) + self.iter().last() } /** find the first error cause of type U, if any exists @@ -270,17 +266,7 @@ impl ChainError { ~~~ **/ pub fn find_cause(&self) -> Option<&U> { - let mut cause = self as &(dyn Error + 'static); - loop { - if cause.is::() { - return cause.downcast_ref::(); - } - - match cause.source() { - Some(c) => cause = c, - None => return None, - } - } + self.iter().filter_map(Error::downcast_ref::()).next() } /** find the first error cause of type ChainError, if any exists @@ -299,17 +285,9 @@ impl ChainError { **/ pub fn find_chain_cause(&self) -> Option<&ChainError> { - let mut cause = self as &(dyn Error + 'static); - loop { - if cause.is::>() { - return cause.downcast_ref::>(); - } - - match cause.source() { - Some(c) => cause = c, - None => return None, - } - } + self.iter() + .filter_map(Error::downcast_ref::>()) + .next() } /** return a reference to T of `ChainError` @@ -374,6 +352,26 @@ impl ChainError { pub fn kind(&self) -> &T { &self.kind } + + pub fn iter(&self) -> impl Iterator { + ErrorIter { + current: Some(self), + } + } +} + +struct ErrorIter<'a> { + current: Option<&'a (dyn Error + 'static)>, +} + +impl<'a> Iterator for ErrorIter<'a> { + type Item = &'a (dyn Error + 'static); + + fn next(&mut self) -> Option { + let current = self.current; + self.current = self.current.and_then(Error::source); + current + } } /** convenience trait to hide the `ChainError` implementation internals diff --git a/tests/test_iter.rs b/tests/test_iter.rs new file mode 100644 index 0000000..c351160 --- /dev/null +++ b/tests/test_iter.rs @@ -0,0 +1,66 @@ +use chainerror::*; +use std::error::Error; +use std::fmt::Write; +use std::io; + +#[test] +fn test_iter() -> Result<(), Box> { + let err = io::Error::from(io::ErrorKind::NotFound); + let err = cherr!(err, "1"); + let err = cherr!(err, "2"); + let err = cherr!(err, "3"); + let err = cherr!(err, "4"); + let err = cherr!(err, "5"); + let err = cherr!(err, "6"); + + let mut res = String::new(); + + for e in err.iter() { + write!(res, "{}", e.to_string())?; + } + assert_eq!(res, "654321entity not found"); + + let io_error: Option<&io::Error> = err + .iter() + .filter_map(Error::downcast_ref::) + .next(); + + assert_eq!(io_error.unwrap().kind(), io::ErrorKind::NotFound); + + Ok(()) +} + +#[test] +fn test_find_cause() -> Result<(), Box> { + let err = io::Error::from(io::ErrorKind::NotFound); + let err = cherr!(err, "1"); + let err = cherr!(err, "2"); + let err = cherr!(err, "3"); + let err = cherr!(err, "4"); + let err = cherr!(err, "5"); + let err = cherr!(err, "6"); + + let io_error: Option<&io::Error> = err.find_cause::(); + + assert_eq!(io_error.unwrap().kind(), io::ErrorKind::NotFound); + + Ok(()) +} + +#[test] +fn test_root_cause() -> Result<(), Box> { + let err = io::Error::from(io::ErrorKind::NotFound); + let err = cherr!(err, "1"); + let err = cherr!(err, "2"); + let err = cherr!(err, "3"); + let err = cherr!(err, "4"); + let err = cherr!(err, "5"); + let err = cherr!(err, "6"); + + let err: Option<&(dyn std::error::Error + 'static)> = err.root_cause(); + let io_error: Option<&io::Error> = err.and_then(Error::downcast_ref::); + + assert_eq!(io_error.unwrap().kind(), io::ErrorKind::NotFound); + + Ok(()) +}