// openzeppelin_relayer/models/secret_string.rs
use std::{fmt, sync::Mutex};
use secrets::SecretVec;
use serde::{Deserialize, Serialize};
use zeroize::Zeroizing;
/// A UTF-8 string whose bytes are stored in a `secrets::SecretVec` instead of an
/// ordinary `String`.
///
/// Invariant: the stored bytes are always valid UTF-8. The field is private, and every
/// value is produced either from a `&str` (`new`, and `Deserialize` via `new`) or by
/// cloning an existing `SecretString`. The unchecked UTF-8 conversion used to expose
/// the contents as `&str` relies on this.
///
/// The `SecretVec` sits behind a `Mutex` and is only accessed while that lock is held.
/// NOTE(review): presumably the mutex exists because `SecretVec` is not `Sync` — confirm.
///
/// `Debug` and `Serialize` emit the placeholder `REDACTED` rather than the contents.
pub struct SecretString(Mutex<SecretVec<u8>>);
impl Clone for SecretString {
fn clone(&self) -> Self {
let secret_vec = self.with_secret_vec(|secret_vec| secret_vec.clone());
Self(Mutex::new(secret_vec))
}
}
impl SecretString {
    /// Creates a `SecretString` holding a copy of `s` in a `secrets::SecretVec`.
    ///
    /// The bytes are copied directly from `s` into the `SecretVec`. No intermediate
    /// heap buffer is allocated, so this function leaves no extra plaintext copy behind.
    /// Clearing the caller's own `s` remains the caller's responsibility.
    pub fn new(s: &str) -> Self {
        let bytes = s.as_bytes();
        let secret_vec = SecretVec::new(bytes.len(), |buffer| {
            buffer.copy_from_slice(bytes);
        });
        Self(Mutex::new(secret_vec))
    }

    /// Runs `f` with shared access to the inner `SecretVec` while holding the mutex.
    ///
    /// A poisoned mutex is recovered rather than propagated. `f` only ever receives a
    /// shared reference, so no panic in an earlier caller can have left the bytes
    /// partially modified.
    fn with_secret_vec<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&SecretVec<u8>) -> R,
    {
        let guard = self
            .0
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        f(&guard)
    }

    /// Calls `f` with the secret as a `&str` and returns `f`'s result.
    ///
    /// The `&str` cannot outlive the closure. The internal lock is held for the whole
    /// call. `f` must therefore not call back into this same `SecretString` (e.g.
    /// `is_empty` or `to_str` on it): re-locking a `std::sync::Mutex` from the thread
    /// that holds it deadlocks or panics.
    pub fn as_str<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&str) -> R,
    {
        self.with_secret_vec(|secret_vec| {
            let bytes = secret_vec.borrow();
            // SAFETY: the field is private and every `SecretString` is built from a
            // `&str` (`new`) or cloned from another `SecretString`, so the stored bytes
            // are always valid UTF-8.
            let s = unsafe { std::str::from_utf8_unchecked(&bytes) };
            f(s)
        })
    }

    /// Returns an owned copy of the secret wrapped in `Zeroizing`, which clears the
    /// copy's buffer when it is dropped.
    pub fn to_str(&self) -> Zeroizing<String> {
        self.as_str(|s| Zeroizing::new(s.to_owned()))
    }

    /// Returns `true` if the secret has zero length.
    pub fn is_empty(&self) -> bool {
        self.with_secret_vec(|secret_vec| secret_vec.is_empty())
    }

    /// Returns `true` if the secret is at least `min_length` bytes long.
    ///
    /// Length is counted in UTF-8 bytes, not characters: `"\u{e9}"` ("é") has length 2.
    /// Only the length is read; the secret bytes themselves are not borrowed.
    pub fn has_minimum_length(&self, min_length: usize) -> bool {
        self.with_secret_vec(|secret_vec| secret_vec.len() >= min_length)
    }
}
impl Serialize for SecretString {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str("REDACTED")
}
}
impl<'de> Deserialize<'de> for SecretString {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = Zeroizing::new(String::deserialize(deserializer)?);
Ok(SecretString::new(&s))
}
}
impl PartialEq for SecretString {
    /// Compares two secrets; equal-length contents are compared in constant time.
    ///
    /// Lengths are compared first and are not treated as secret. Equal-length buffers are
    /// compared with `subtle::ConstantTimeEq`.
    ///
    /// Locking:
    /// - A `std::sync::Mutex` must not be locked twice by the same thread, so comparing
    ///   a value with itself returns `true` before any lock is taken.
    /// - For two distinct values, the mutexes are always acquired in address order.
    ///   Concurrent `a == b` and `b == a` on different threads therefore cannot
    ///   deadlock each other.
    fn eq(&self, other: &Self) -> bool {
        if std::ptr::eq(self, other) {
            return true;
        }

        // Equality is symmetric, so the operands may be swapped freely. Ordering them by
        // address gives every thread the same lock-acquisition order.
        let (first, second) = if (self as *const Self) < (other as *const Self) {
            (self, other)
        } else {
            (other, self)
        };

        first.with_secret_vec(|first_vec| {
            second.with_secret_vec(|second_vec| {
                let first_bytes = first_vec.borrow();
                let second_bytes = second_vec.borrow();
                first_bytes.len() == second_bytes.len()
                    && subtle::ConstantTimeEq::ct_eq(&*first_bytes, &*second_bytes).into()
            })
        })
    }
}
impl fmt::Debug for SecretString {
    /// Prints a fixed placeholder so `{:?}` output (logs, panics, assertions) never
    /// contains the secret.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("SecretString(REDACTED)")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Barrier};
    use std::thread;

    #[test]
    fn test_new_creates_valid_secret_string() {
        let secret = SecretString::new("test_secret_value");
        secret.as_str(|s| {
            assert_eq!(s, "test_secret_value");
        });
    }

    #[test]
    fn test_empty_string_is_handled_correctly() {
        let empty = SecretString::new("");
        assert!(empty.is_empty());
        empty.as_str(|s| {
            assert_eq!(s, "");
        });
    }

    #[test]
    fn test_to_str_creates_correct_zeroizing_copy() {
        let secret = SecretString::new("temporary_copy");
        let copy = secret.to_str();
        assert_eq!(&*copy, "temporary_copy");
    }

    #[test]
    fn test_is_empty_returns_correct_value() {
        let empty = SecretString::new("");
        let non_empty = SecretString::new("not empty");
        assert!(empty.is_empty());
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn test_serialization_redacts_content() {
        let secret = SecretString::new("should_not_appear_in_serialized_form");
        let serialized = serde_json::to_string(&secret).unwrap();
        assert_eq!(serialized, "\"REDACTED\"");
        assert!(!serialized.contains("should_not_appear_in_serialized_form"));
    }

    #[test]
    fn test_deserialization_creates_valid_secret_string() {
        let json_str = "\"deserialized_secret\"";
        let deserialized: SecretString = serde_json::from_str(json_str).unwrap();
        deserialized.as_str(|s| {
            assert_eq!(s, "deserialized_secret");
        });
    }

    #[test]
    fn test_equality_comparison_works_correctly() {
        let secret1 = SecretString::new("same_value");
        let secret2 = SecretString::new("same_value");
        let secret3 = SecretString::new("different_value");
        assert_eq!(secret1, secret2);
        assert_ne!(secret1, secret3);
    }

    #[test]
    fn test_clone_is_equal_and_independent() {
        let original = SecretString::new("cloned_secret");
        let copy = original.clone();
        assert_eq!(original, copy);
        // The copy must not depend on the original's storage.
        drop(original);
        copy.as_str(|s| assert_eq!(s, "cloned_secret"));
    }

    #[test]
    fn test_debug_output_redacts_content() {
        let secret = SecretString::new("should_not_appear_in_debug");
        let debug_str = format!("{:?}", secret);
        assert_eq!(debug_str, "SecretString(REDACTED)");
        assert!(!debug_str.contains("should_not_appear_in_debug"));
    }

    #[test]
    fn test_thread_safety() {
        let secret = SecretString::new("shared_across_threads");
        let num_threads = 10;
        let barrier = Arc::new(Barrier::new(num_threads));

        let handles: Vec<_> = (0..num_threads)
            .map(|i| {
                let thread_secret = secret.clone();
                let thread_barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    thread_barrier.wait();
                    thread_secret.as_str(|s| {
                        assert_eq!(s, "shared_across_threads");
                    });
                    assert!(!thread_secret.is_empty());
                    let copy = thread_secret.to_str();
                    assert_eq!(&*copy, "shared_across_threads");
                    i
                })
            })
            .collect();

        let mut completed_threads: Vec<_> = handles
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .collect();
        completed_threads.sort();
        assert_eq!(completed_threads, (0..num_threads).collect::<Vec<_>>());
    }

    #[test]
    fn test_concurrent_access_to_shared_instance() {
        // Unlike `test_thread_safety`, all threads hit the same instance (and its mutex).
        let secret = Arc::new(SecretString::new("shared_instance"));
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let secret = Arc::clone(&secret);
                thread::spawn(move || {
                    for _ in 0..50 {
                        assert!(secret.has_minimum_length(15));
                        assert_eq!(&*secret.to_str(), "shared_instance");
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_unicode_handling() {
        let unicode_string = "こんにちは世界!";
        let secret = SecretString::new(unicode_string);
        secret.as_str(|s| {
            assert_eq!(s, unicode_string);
            assert_eq!(s.chars().count(), 8);
        });
    }

    #[test]
    fn test_special_characters_handling() {
        let special_chars = "!@#$%^&*()_+{}|:<>?~`-=[]\\;',./";
        let secret = SecretString::new(special_chars);
        secret.as_str(|s| {
            assert_eq!(s, special_chars);
        });
    }

    #[test]
    fn test_very_long_string() {
        let long_string = "a".repeat(100_000);
        let secret = SecretString::new(&long_string);
        secret.as_str(|s| {
            assert_eq!(s.len(), 100_000);
            assert_eq!(s, long_string);
        });
        assert_eq!(secret.0.lock().unwrap().len(), 100_000);
    }

    #[test]
    fn test_has_minimum_length() {
        let empty = SecretString::new("");
        let short = SecretString::new("abc"); // 3 bytes
        let medium = SecretString::new("abcdefghij"); // 10 bytes
        let long = SecretString::new("abcdefghijklmnopqrst"); // 20 bytes

        // Every value satisfies a zero minimum.
        for secret in [&empty, &short, &medium, &long] {
            assert!(secret.has_minimum_length(0));
        }

        // Exactly the stored length passes; one byte more fails.
        assert!(!empty.has_minimum_length(1));
        assert!(short.has_minimum_length(3));
        assert!(!short.has_minimum_length(4));
        assert!(medium.has_minimum_length(10));
        assert!(!medium.has_minimum_length(11));
        assert!(long.has_minimum_length(20));
        assert!(!long.has_minimum_length(21));

        // Non-empty values pass a minimum of 1; none of them reach 100.
        for secret in [&short, &medium, &long] {
            assert!(secret.has_minimum_length(1));
            assert!(!secret.has_minimum_length(100));
        }
    }

    #[test]
    fn test_has_minimum_length_counts_bytes() {
        // "\u{e9}" ("é") is one character but two UTF-8 bytes.
        let accented = SecretString::new("\u{e9}");
        assert!(accented.has_minimum_length(2));
        assert!(!accented.has_minimum_length(3));
    }
}