1use core::{
45 cell::UnsafeCell,
46 marker::PhantomData,
47 ops::Deref,
48 sync::atomic::{AtomicU8, Ordering},
49};
50
51use crate::{
52 Algorithm, ByteArray, Encrypted, STATE_DECRYPTED, STATE_DECRYPTING, STATE_UNENCRYPTED,
53 StringLiteral,
54 drop_strategy::{DropStrategy, Zeroize},
55};
56
57pub struct ReEncrypt<const KEY: u8>;
58
59impl<const KEY: u8> DropStrategy for ReEncrypt<KEY> {
60 type Extra = ();
61 fn drop(data: &mut [u8], _extra: &()) {
62 for byte in data {
63 *byte ^= KEY;
64 }
65 }
66}
67
68pub struct Xor<const KEY: u8, D: DropStrategy = Zeroize>(PhantomData<D>);
71
72impl<const KEY: u8, D: DropStrategy<Extra = ()>> Algorithm for Xor<KEY, D> {
73 type Drop = D;
74 type Extra = ();
75}
76
77impl<const KEY: u8, D: DropStrategy<Extra = ()>, M, const N: usize> Encrypted<Xor<KEY, D>, M, N> {
78 pub const fn new(mut buffer: [u8; N]) -> Self {
79 let mut i = 0;
81 while i < N {
82 buffer[i] ^= KEY;
83 i += 1;
84 }
85
86 Encrypted {
87 buffer: UnsafeCell::new(buffer),
88 decryption_state: AtomicU8::new(STATE_UNENCRYPTED),
89 extra: (),
90 _phantom: PhantomData,
91 }
92 }
93}
94
95impl<const KEY: u8, D: DropStrategy<Extra = ()>, const N: usize> Deref
96 for Encrypted<Xor<KEY, D>, ByteArray, N>
97{
98 type Target = [u8; N];
99
100 fn deref(&self) -> &Self::Target {
101 if self.decryption_state.load(Ordering::Acquire) == STATE_DECRYPTED {
103 return unsafe { &*self.buffer.get() };
105 }
106
107 match self.decryption_state.compare_exchange(
109 STATE_UNENCRYPTED,
110 STATE_DECRYPTING,
111 Ordering::AcqRel,
112 Ordering::Acquire,
113 ) {
114 Ok(_) => {
115 let data = unsafe { &mut *self.buffer.get() };
118 for byte in data.iter_mut() {
119 *byte ^= KEY;
120 }
121
122 self.decryption_state.store(STATE_DECRYPTED, Ordering::Release);
125 }
126 Err(_) => {
127 while self.decryption_state.load(Ordering::Acquire) != STATE_DECRYPTED {
130 core::hint::spin_loop();
131 }
132 }
133 }
134
135 unsafe { &*self.buffer.get() }
139 }
140}
141
142impl<const KEY: u8, D: DropStrategy<Extra = ()>, const N: usize> Deref
143 for Encrypted<Xor<KEY, D>, StringLiteral, N>
144{
145 type Target = str;
146
147 fn deref(&self) -> &Self::Target {
148 if self.decryption_state.load(Ordering::Acquire) == STATE_DECRYPTED {
150 let bytes = unsafe { &*self.buffer.get() };
152 return unsafe { core::str::from_utf8_unchecked(bytes) };
154 }
155
156 match self.decryption_state.compare_exchange(
158 STATE_UNENCRYPTED,
159 STATE_DECRYPTING,
160 Ordering::AcqRel,
161 Ordering::Acquire,
162 ) {
163 Ok(_) => {
164 let data = unsafe { &mut *self.buffer.get() };
167 for byte in data.iter_mut() {
168 *byte ^= KEY;
169 }
170
171 self.decryption_state.store(STATE_DECRYPTED, Ordering::Release);
174 }
175 Err(_) => {
176 while self.decryption_state.load(Ordering::Acquire) != STATE_DECRYPTED {
179 core::hint::spin_loop();
180 }
181 }
182 }
183
184 let bytes = unsafe { &*self.buffer.get() };
188
189 unsafe { core::str::from_utf8_unchecked(bytes) }
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ByteArray, StringLiteral,
        align::{Aligned8, Aligned16},
        drop_strategy::{DropStrategy, NoOp, Zeroize},
        xor::Xor,
    };

    use core::{
        mem::size_of,
        sync::atomic::{AtomicUsize, Ordering},
    };
    use std::{sync::Barrier, thread};

    /// Runs `check` against `value` from `threads` threads simultaneously.
    ///
    /// A barrier releases every thread at once, so the first dereferences
    /// genuinely race on the `STATE_UNENCRYPTED -> STATE_DECRYPTING` transition
    /// instead of being serialized by thread start-up latency. `thread::scope`
    /// joins all threads and panics if any of them panicked (e.g. a failed
    /// assertion), so failures surface in the calling test.
    fn deref_from_many_threads<T: Sync>(value: &T, threads: usize, check: impl Fn(&T) + Sync) {
        let barrier = Barrier::new(threads);
        thread::scope(|scope| {
            for _ in 0..threads {
                scope.spawn(|| {
                    barrier.wait();
                    check(value);
                });
            }
        });
    }

    #[test]
    fn test_size() {
        assert_eq!(17, size_of::<Encrypted<Xor<0xAA, Zeroize>, ByteArray, 16>>());
        assert_eq!(17, size_of::<Encrypted<Xor<0xAA, NoOp>, ByteArray, 16>>());
        assert_eq!(17, size_of::<Encrypted<Xor<0xAA, ReEncrypt<0xAA>>, ByteArray, 16>>());

        assert_eq!(24, size_of::<Aligned8<Encrypted<Xor<0xAA, ReEncrypt<0xAA>>, ByteArray, 16>>>());
        assert_eq!(
            32,
            size_of::<Aligned16<Encrypted<Xor<0xAA, ReEncrypt<0xAA>>, ByteArray, 16>>>()
        );
    }

    const CONST_ENCRYPTED: Encrypted<Xor<0xAA, Zeroize>, ByteArray, 5> =
        Encrypted::<Xor<0xAA, Zeroize>, ByteArray, 5>::new(*b"hello");

    const CONST_ENCRYPTED_STR: Encrypted<Xor<0xFF, Zeroize>, StringLiteral, 3> =
        Encrypted::<Xor<0xFF, Zeroize>, StringLiteral, 3>::new(*b"abc");

    const CONST_ENCRYPTED_SINGLE: Encrypted<Xor<0xFF, Zeroize>, ByteArray, 1> =
        Encrypted::<Xor<0xFF, Zeroize>, ByteArray, 1>::new([42]);

    const CONST_ENCRYPTED_ZEROS: Encrypted<Xor<0xAA, Zeroize>, ByteArray, 4> =
        Encrypted::<Xor<0xAA, Zeroize>, ByteArray, 4>::new([0, 0, 0, 0]);

    const CONST_ENCRYPTED_NOOP_KEY: Encrypted<Xor<0x00, Zeroize>, ByteArray, 3> =
        Encrypted::<Xor<0x00, Zeroize>, ByteArray, 3>::new(*b"abc");

    #[test]
    fn test_new_in_const_context() {
        let plain: &[u8; 5] = &*CONST_ENCRYPTED;
        assert_eq!(plain, b"hello");
    }

    #[test]
    fn test_buffer_is_encrypted_before_deref() {
        let encrypted = CONST_ENCRYPTED;

        // SAFETY: `encrypted` has not been dereferenced, so nothing borrows the
        // buffer; the bytes are copied out by value.
        let raw = unsafe { *encrypted.buffer.get() };
        let expected = [b'h' ^ 0xAA, b'e' ^ 0xAA, b'l' ^ 0xAA, b'l' ^ 0xAA, b'o' ^ 0xAA];
        assert_eq!(raw, expected, "buffer should be XOR-encrypted before deref");
        assert_ne!(&raw, b"hello", "buffer must NOT be plaintext before deref");
    }

    #[test]
    fn test_string_buffer_is_encrypted_before_deref() {
        let encrypted = CONST_ENCRYPTED_STR;

        // SAFETY: not yet dereferenced; copy the ciphertext out by value.
        let raw = unsafe { *encrypted.buffer.get() };
        let expected = [b'a' ^ 0xFF, b'b' ^ 0xFF, b'c' ^ 0xFF];
        assert_eq!(raw, expected, "string buffer should be XOR-encrypted before deref");
        assert_ne!(&raw, b"abc");
    }

    #[test]
    fn test_bytearray_deref_decrypts() {
        let encrypted = CONST_ENCRYPTED;

        let plain: &[u8; 5] = &*encrypted;
        assert_eq!(plain, b"hello");
    }

    #[test]
    fn test_bytearray_deref_single_byte() {
        let encrypted = CONST_ENCRYPTED_SINGLE;
        // SAFETY: not yet dereferenced; copy the ciphertext out by value.
        let raw = unsafe { *encrypted.buffer.get() };
        assert_eq!(raw, [42 ^ 0xFF]);

        let plain: &[u8; 1] = &*encrypted;
        assert_eq!(plain, &[42]);
    }

    #[test]
    fn test_bytearray_deref_all_zeros() {
        let encrypted = CONST_ENCRYPTED_ZEROS;
        // SAFETY: not yet dereferenced; copy the ciphertext out by value.
        let raw = unsafe { *encrypted.buffer.get() };
        assert_eq!(raw, [0xAA, 0xAA, 0xAA, 0xAA]);

        let plain: &[u8; 4] = &*encrypted;
        assert_eq!(plain, &[0, 0, 0, 0]);
    }

    #[test]
    fn test_bytearray_deref_key_zero_is_identity() {
        let encrypted = CONST_ENCRYPTED_NOOP_KEY;
        // SAFETY: not yet dereferenced; copy the ciphertext out by value.
        let raw = unsafe { *encrypted.buffer.get() };
        assert_eq!(&raw, b"abc", "key 0x00 should leave buffer unchanged");

        let plain: &[u8; 3] = &*encrypted;
        assert_eq!(plain, b"abc");
    }

    #[test]
    fn test_bytearray_multiple_derefs_are_idempotent() {
        let encrypted = CONST_ENCRYPTED;

        let first: &[u8; 5] = &*encrypted;
        let second: &[u8; 5] = &*encrypted;
        assert_eq!(first, b"hello");
        assert_eq!(second, b"hello");
    }

    #[test]
    fn test_string_multiple_derefs_are_idempotent() {
        let encrypted = CONST_ENCRYPTED_STR;

        let first: &str = &encrypted;
        let second: &str = &encrypted;
        assert_eq!(first, "abc");
        assert_eq!(second, "abc");
    }

    #[test]
    fn test_deref_marks_value_decrypted() {
        let encrypted = CONST_ENCRYPTED;
        assert_eq!(encrypted.decryption_state.load(Ordering::Acquire), STATE_UNENCRYPTED);

        let plain: &[u8; 5] = &encrypted;
        assert_eq!(plain, b"hello");
        assert_eq!(encrypted.decryption_state.load(Ordering::Acquire), STATE_DECRYPTED);
    }

    #[test]
    fn test_reencrypt_drop_strategy_round_trips() {
        let mut buffer = *b"abc";

        <ReEncrypt<0x5A> as DropStrategy>::drop(&mut buffer, &());
        assert_eq!(buffer, [b'a' ^ 0x5A, b'b' ^ 0x5A, b'c' ^ 0x5A]);

        // XOR with the same key is an involution, so a second pass restores
        // the plaintext.
        <ReEncrypt<0x5A> as DropStrategy>::drop(&mut buffer, &());
        assert_eq!(&buffer, b"abc");
    }

    #[test]
    fn test_encrypted_is_sync() {
        const fn assert_sync<T: Sync>() {}
        const fn check() {
            assert_sync::<Encrypted<Xor<0xAA, Zeroize>, ByteArray, 5>>();
            assert_sync::<Encrypted<Xor<0xBB, ReEncrypt<0xBB>>, StringLiteral, 5>>();
            assert_sync::<Encrypted<Xor<0xCC, NoOp>, ByteArray, 8>>();
        }
        check();
    }

    #[test]
    fn test_concurrent_deref_same_value() {
        const SHARED: Encrypted<Xor<0xAA, Zeroize>, StringLiteral, 5> =
            Encrypted::<Xor<0xAA, Zeroize>, StringLiteral, 5>::new(*b"hello");

        let shared = SHARED;
        deref_from_many_threads(&shared, 10, |value| {
            let decrypted: &str = value;
            assert_eq!(decrypted, "hello");
        });
    }

    #[test]
    fn test_concurrent_deref_bytearray() {
        const SHARED: Encrypted<Xor<0xFF, Zeroize>, ByteArray, 4> =
            Encrypted::<Xor<0xFF, Zeroize>, ByteArray, 4>::new([1, 2, 3, 4]);

        let shared = SHARED;
        deref_from_many_threads(&shared, 20, |value| {
            let decrypted: &[u8; 4] = value;
            assert_eq!(decrypted, &[1, 2, 3, 4]);
        });
    }

    #[test]
    fn test_concurrent_deref_reencrypt() {
        const SHARED: Encrypted<Xor<0xBB, ReEncrypt<0xBB>>, StringLiteral, 6> =
            Encrypted::<Xor<0xBB, ReEncrypt<0xBB>>, StringLiteral, 6>::new(*b"secret");

        let shared = SHARED;
        deref_from_many_threads(&shared, 15, |value| {
            let decrypted: &str = value;
            assert_eq!(decrypted, "secret");
        });
    }

    #[test]
    fn test_concurrent_deref_race_condition() {
        const SHARED: Encrypted<Xor<0x42, Zeroize>, StringLiteral, 8> =
            Encrypted::<Xor<0x42, Zeroize>, StringLiteral, 8>::new(*b"racetest");

        let shared = SHARED;
        let successes = AtomicUsize::new(0);
        deref_from_many_threads(&shared, 50, |value| {
            let decrypted: &str = value;
            if decrypted == "racetest" {
                successes.fetch_add(1, Ordering::Relaxed);
            }
        });

        let success_count = successes.load(Ordering::Relaxed);
        assert_eq!(success_count, 50, "all threads should see correct plaintext");
    }

    #[test]
    fn test_concurrent_multiple_values() {
        const SECRET1: Encrypted<Xor<0xAA, Zeroize>, StringLiteral, 5> =
            Encrypted::<Xor<0xAA, Zeroize>, StringLiteral, 5>::new(*b"hello");
        const SECRET2: Encrypted<Xor<0xFF, Zeroize>, ByteArray, 4> =
            Encrypted::<Xor<0xFF, Zeroize>, ByteArray, 4>::new([1, 2, 3, 4]);

        let secret1 = SECRET1;
        let secret2 = SECRET2;
        // Release all 20 threads together so both values are decrypted under
        // contention at the same time.
        let barrier = Barrier::new(20);

        thread::scope(|scope| {
            for i in 0..20 {
                if i % 2 == 0 {
                    scope.spawn(|| {
                        barrier.wait();
                        let decrypted: &str = &secret1;
                        assert_eq!(decrypted, "hello");
                    });
                } else {
                    scope.spawn(|| {
                        barrier.wait();
                        let decrypted: &[u8; 4] = &secret2;
                        assert_eq!(decrypted, &[1, 2, 3, 4]);
                    });
                }
            }
        });
    }
}