1use core::{
56 cell::UnsafeCell,
57 marker::PhantomData,
58 ops::Deref,
59 sync::atomic::{AtomicU8, Ordering},
60};
61
62use crate::{
63 Algorithm, ByteArray, Encrypted, STATE_DECRYPTED, STATE_DECRYPTING, STATE_UNENCRYPTED,
64 StringLiteral,
65 drop_strategy::{DropStrategy, Zeroize},
66};
67
/// Drop strategy that re-encrypts the plaintext buffer in place with RC4
/// (keyed by the stored `extra` key) when the owning value is dropped,
/// instead of zeroizing it — the buffer returns to its ciphertext form.
pub struct ReEncrypt<const KEY_LEN: usize>;
71
72impl<const KEY_LEN: usize> DropStrategy for ReEncrypt<KEY_LEN> {
73 type Extra = [u8; KEY_LEN];
74
75 fn drop(data: &mut [u8], key: &[u8; KEY_LEN]) {
76 let mut s = [0u8; 256];
78 let mut j: u8 = 0;
79
80 let mut i = 0usize;
82 while i < 256 {
83 s[i] = i as u8;
84 i += 1;
85 }
86
87 let mut i = 0usize;
89 while i < 256 {
90 j = j.wrapping_add(s[i]).wrapping_add(key[i % KEY_LEN]);
91 s.swap(i, j as usize);
92 i += 1;
93 }
94
95 let mut i: u8 = 0;
97 j = 0;
98 let mut idx = 0usize;
99 let n = data.len();
100 while idx < n {
101 i = i.wrapping_add(1);
102 j = j.wrapping_add(s[i as usize]);
103 s.swap(i as usize, j as usize);
104 let k = s[(s[i as usize].wrapping_add(s[j as usize])) as usize];
105 data[idx] ^= k;
106 idx += 1;
107 }
108 }
109}
110
/// RC4 stream-cipher [`Algorithm`] marker with a compile-time key length
/// and a pluggable drop strategy `D` (defaults to [`Zeroize`]).
/// `PhantomData<D>` records the strategy at the type level without storing
/// a runtime value.
pub struct Rc4<const KEY_LEN: usize, D: DropStrategy = Zeroize>(PhantomData<D>);
118
impl<const KEY_LEN: usize, D: DropStrategy<Extra = [u8; KEY_LEN]>> Algorithm for Rc4<KEY_LEN, D> {
    // Forward the chosen drop strategy and its extra payload (the RC4 key)
    // so `Encrypted` can re-derive the keystream on drop.
    type Drop = D;
    type Extra = [u8; KEY_LEN];
}
123
124impl<const KEY_LEN: usize, D: DropStrategy<Extra = [u8; KEY_LEN]>, M, const N: usize>
125 Encrypted<Rc4<KEY_LEN, D>, M, N>
126{
127 pub const fn new(mut buffer: [u8; N], key: [u8; KEY_LEN]) -> Self {
138 let mut s = [0u8; 256];
141 let mut j: u8 = 0;
142
143 let mut i = 0usize;
145 while i < 256 {
146 s[i] = i as u8;
147 i += 1;
148 }
149
150 let mut i = 0usize;
152 while i < 256 {
153 let key_byte = key[i % KEY_LEN];
154 j = j.wrapping_add(s[i]).wrapping_add(key_byte);
155 let temp = s[i];
157 s[i] = s[j as usize];
158 s[j as usize] = temp;
159 i += 1;
160 }
161
162 let mut i: u8 = 0;
164 j = 0;
165 let mut idx = 0usize;
166 while idx < N {
167 i = i.wrapping_add(1);
168 j = j.wrapping_add(s[i as usize]);
169 let temp = s[i as usize];
171 s[i as usize] = s[j as usize];
172 s[j as usize] = temp;
173 let k = s[(s[i as usize].wrapping_add(s[j as usize])) as usize];
175 buffer[idx] ^= k;
176 idx += 1;
177 }
178
179 Encrypted {
180 buffer: UnsafeCell::new(buffer),
181 decryption_state: AtomicU8::new(STATE_UNENCRYPTED),
182 extra: key,
183 _phantom: PhantomData,
184 }
185 }
186}
187
188impl<const KEY_LEN: usize, D: DropStrategy<Extra = [u8; KEY_LEN]>, const N: usize> Deref
189 for Encrypted<Rc4<KEY_LEN, D>, ByteArray, N>
190{
191 type Target = [u8; N];
192
193 fn deref(&self) -> &Self::Target {
194 if self.decryption_state.load(Ordering::Acquire) == STATE_DECRYPTED {
196 return unsafe { &*self.buffer.get() };
198 }
199
200 match self.decryption_state.compare_exchange(
202 STATE_UNENCRYPTED,
203 STATE_DECRYPTING,
204 Ordering::AcqRel,
205 Ordering::Acquire,
206 ) {
207 Ok(_) => {
208 let data = unsafe { &mut *self.buffer.get() };
211 let key = &self.extra;
213 let mut s = [0u8; 256];
214 let mut j: u8 = 0;
215
216 let mut i = 0usize;
218 while i < 256 {
219 s[i] = i as u8;
220 i += 1;
221 }
222
223 let mut i = 0usize;
225 while i < 256 {
226 j = j.wrapping_add(s[i]).wrapping_add(key[i % KEY_LEN]);
227 s.swap(i, j as usize);
228 i += 1;
229 }
230
231 let mut i: u8 = 0;
233 j = 0;
234 let mut idx = 0usize;
235 while idx < N {
236 i = i.wrapping_add(1);
237 j = j.wrapping_add(s[i as usize]);
238 s.swap(i as usize, j as usize);
239 let k = s[(s[i as usize].wrapping_add(s[j as usize])) as usize];
240 data[idx] ^= k;
241 idx += 1;
242 }
243
244 self.decryption_state.store(STATE_DECRYPTED, Ordering::Release);
247 }
248 Err(_) => {
249 while self.decryption_state.load(Ordering::Acquire) != STATE_DECRYPTED {
252 core::hint::spin_loop();
253 }
254 }
255 }
256
257 unsafe { &*self.buffer.get() }
261 }
262}
263
264impl<const KEY_LEN: usize, D: DropStrategy<Extra = [u8; KEY_LEN]>, const N: usize> Deref
265 for Encrypted<Rc4<KEY_LEN, D>, StringLiteral, N>
266{
267 type Target = str;
268
269 fn deref(&self) -> &Self::Target {
270 if self.decryption_state.load(Ordering::Acquire) == STATE_DECRYPTED {
272 let bytes = unsafe { &*self.buffer.get() };
274 return unsafe { core::str::from_utf8_unchecked(bytes) };
278 }
279
280 match self.decryption_state.compare_exchange(
282 STATE_UNENCRYPTED,
283 STATE_DECRYPTING,
284 Ordering::AcqRel,
285 Ordering::Acquire,
286 ) {
287 Ok(_) => {
288 let data = unsafe { &mut *self.buffer.get() };
291 let key = &self.extra;
293 let mut s = [0u8; 256];
294 let mut j: u8 = 0;
295
296 let mut i = 0usize;
298 while i < 256 {
299 s[i] = i as u8;
300 i += 1;
301 }
302
303 let mut i = 0usize;
305 while i < 256 {
306 j = j.wrapping_add(s[i]).wrapping_add(key[i % KEY_LEN]);
307 s.swap(i, j as usize);
308 i += 1;
309 }
310
311 let mut i: u8 = 0;
313 j = 0;
314 let mut idx = 0usize;
315 while idx < N {
316 i = i.wrapping_add(1);
317 j = j.wrapping_add(s[i as usize]);
318 s.swap(i as usize, j as usize);
319 let k = s[(s[i as usize].wrapping_add(s[j as usize])) as usize];
320 data[idx] ^= k;
321 idx += 1;
322 }
323
324 self.decryption_state.store(STATE_DECRYPTED, Ordering::Release);
327 }
328 Err(_) => {
329 while self.decryption_state.load(Ordering::Acquire) != STATE_DECRYPTED {
332 core::hint::spin_loop();
333 }
334 }
335 }
336
337 let bytes = unsafe { &*self.buffer.get() };
341
342 unsafe { core::str::from_utf8_unchecked(bytes) }
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ByteArray, StringLiteral,
        drop_strategy::{NoOp, Zeroize},
        rc4::Rc4,
    };

    use alloc::vec;
    use alloc::vec::Vec;
    use core::sync::atomic::AtomicUsize;
    use std::sync::Arc;
    use std::thread;

    // Two keys of different lengths to exercise the KEY_LEN const generic.
    const RC4_KEY: [u8; 5] = *b"mykey";
    const RC4_KEY2: [u8; 16] = *b"sixteen-byte-key";

    // `new` is a const fn, so these are encrypted during const evaluation:
    // the plaintext literals never exist in the binary at runtime.
    const CONST_ENCRYPTED: Encrypted<Rc4<5, Zeroize<[u8; 5]>>, ByteArray, 5> =
        Encrypted::<Rc4<5, Zeroize<[u8; 5]>>, ByteArray, 5>::new(*b"hello", RC4_KEY);

    const CONST_ENCRYPTED_STR: Encrypted<Rc4<5, Zeroize<[u8; 5]>>, StringLiteral, 5> =
        Encrypted::<Rc4<5, Zeroize<[u8; 5]>>, StringLiteral, 5>::new(*b"hello", RC4_KEY);

    const CONST_ENCRYPTED_16: Encrypted<Rc4<16, Zeroize<[u8; 16]>>, ByteArray, 8> =
        Encrypted::<Rc4<16, Zeroize<[u8; 16]>>, ByteArray, 8>::new(*b"longdata", RC4_KEY2);

    // Peeks at the raw buffer (before any deref) to confirm the stored
    // bytes are ciphertext, and that the key is kept in `extra`.
    #[test]
    fn test_rc4_buffer_is_encrypted_before_deref() {
        let encrypted = CONST_ENCRYPTED;

        let raw = unsafe { &*encrypted.buffer.get() };
        assert_ne!(raw, b"hello", "buffer must NOT be plaintext before deref");
        assert_eq!(encrypted.extra, RC4_KEY, "key should be stored in extra");
    }

    // First deref of a ByteArray value must yield the original plaintext.
    #[test]
    fn test_rc4_bytearray_deref_decrypts() {
        let encrypted = CONST_ENCRYPTED;

        let plain: &[u8; 5] = &*encrypted;
        assert_eq!(plain, b"hello");
    }

    // First deref of a StringLiteral value must yield the original &str.
    #[test]
    fn test_rc4_string_deref_decrypts() {
        let encrypted = CONST_ENCRYPTED_STR;

        let plain: &str = &*encrypted;
        assert_eq!(plain, "hello");
    }

    // Decryption happens once; later derefs take the fast path and must
    // return the same plaintext (not double-decrypt).
    #[test]
    fn test_rc4_multiple_derefs_are_idempotent() {
        let encrypted = CONST_ENCRYPTED;

        let first: &[u8; 5] = &*encrypted;
        let second: &[u8; 5] = &*encrypted;
        assert_eq!(first, b"hello");
        assert_eq!(second, b"hello");
    }

    // Same round-trip with a 16-byte key and an 8-byte payload.
    #[test]
    fn test_rc4_different_key_length() {
        let encrypted = CONST_ENCRYPTED_16;

        let plain: &[u8; 8] = &*encrypted;
        assert_eq!(plain, b"longdata");
    }

    // Compile-time check that Encrypted<Rc4<..>, ..> is Sync for various
    // parameterizations (required for sharing via Arc below).
    #[test]
    fn test_rc4_encrypted_is_sync() {
        const fn assert_sync<T: Sync>() {}
        const fn check() {
            assert_sync::<Encrypted<Rc4<5, Zeroize<[u8; 5]>>, ByteArray, 8>>();
            assert_sync::<Encrypted<Rc4<16, Zeroize<[u8; 16]>>, StringLiteral, 10>>();
            assert_sync::<Encrypted<Rc4<32, NoOp<[u8; 32]>>, ByteArray, 16>>();
        }
        check();
    }

    // Ten threads deref the same shared value concurrently; every thread
    // must observe the same plaintext regardless of who decrypts first.
    #[test]
    fn test_rc4_concurrent_deref_same_value() {
        const SHARED: Encrypted<Rc4<5, Zeroize<[u8; 5]>>, StringLiteral, 5> =
            Encrypted::<Rc4<5, Zeroize<[u8; 5]>>, StringLiteral, 5>::new(*b"hello", RC4_KEY);

        let shared = Arc::new(SHARED);
        let mut handles: Vec<thread::JoinHandle<()>> = vec![];

        for _ in 0..10 {
            let shared_clone = Arc::clone(&shared);
            let handle = thread::spawn(move || {
                let decrypted: &str = &*shared_clone;
                assert_eq!(decrypted, "hello");
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }
    }

    // Same concurrency check for the ByteArray deref path.
    #[test]
    fn test_rc4_concurrent_deref_bytearray() {
        const SHARED: Encrypted<Rc4<16, Zeroize<[u8; 16]>>, ByteArray, 4> =
            Encrypted::<Rc4<16, Zeroize<[u8; 16]>>, ByteArray, 4>::new([1, 2, 3, 4], RC4_KEY2);

        let shared = Arc::new(SHARED);
        let mut handles: Vec<thread::JoinHandle<()>> = vec![];

        for _ in 0..20 {
            let shared_clone = Arc::clone(&shared);
            let handle = thread::spawn(move || {
                let decrypted: &[u8; 4] = &*shared_clone;
                assert_eq!(decrypted, &[1, 2, 3, 4]);
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }
    }

    // Stress the CAS race with 50 threads and count successes through an
    // atomic, so a torn/double decryption would show up as a miscount.
    #[test]
    fn test_rc4_concurrent_deref_race_condition() {
        const SHARED: Encrypted<Rc4<5, Zeroize<[u8; 5]>>, StringLiteral, 8> =
            Encrypted::<Rc4<5, Zeroize<[u8; 5]>>, StringLiteral, 8>::new(*b"racetest", RC4_KEY);

        let shared = Arc::new(SHARED);
        let results = Arc::new(AtomicUsize::new(0));
        let mut handles: Vec<thread::JoinHandle<()>> = vec![];

        for _ in 0..50 {
            let shared_clone = Arc::clone(&shared);
            let results_clone = Arc::clone(&results);
            let handle = thread::spawn(move || {
                let decrypted: &str = &*shared_clone;
                if decrypted == "racetest" {
                    results_clone.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let success_count = results.load(core::sync::atomic::Ordering::Relaxed);
        assert_eq!(success_count, 50, "all threads should see correct plaintext");
    }

    // Edge case: minimum non-empty buffer (N = 1).
    #[test]
    fn test_rc4_single_byte() {
        const ENCRYPTED: Encrypted<Rc4<5, Zeroize<[u8; 5]>>, ByteArray, 1> =
            Encrypted::<Rc4<5, Zeroize<[u8; 5]>>, ByteArray, 1>::new([42], RC4_KEY);

        let plain: &[u8; 1] = &*ENCRYPTED;
        assert_eq!(plain, &[42]);
    }

    // Edge case: all-zero plaintext XORs to the raw keystream and back.
    #[test]
    fn test_rc4_all_zeros() {
        const ENCRYPTED: Encrypted<Rc4<5, Zeroize<[u8; 5]>>, ByteArray, 4> =
            Encrypted::<Rc4<5, Zeroize<[u8; 5]>>, ByteArray, 4>::new([0, 0, 0, 0], RC4_KEY);

        let plain: &[u8; 4] = &*ENCRYPTED;
        assert_eq!(plain, &[0, 0, 0, 0]);
    }

    // Exercises the ReEncrypt drop strategy under concurrent derefs; the
    // shared values are re-encrypted (not zeroized) when dropped.
    #[test]
    fn test_rc4_reencrypt_drop() {
        use crate::rc4::ReEncrypt;

        const SHARED: Encrypted<Rc4<5, ReEncrypt<5>>, StringLiteral, 5> =
            Encrypted::<Rc4<5, ReEncrypt<5>>, StringLiteral, 5>::new(*b"hello", RC4_KEY);

        let shared = Arc::new(SHARED);
        let mut handles: Vec<thread::JoinHandle<()>> = vec![];

        for _ in 0..10 {
            let shared_clone = Arc::clone(&shared);
            let handle = thread::spawn(move || {
                let decrypted: &str = &*shared_clone;
                assert_eq!(decrypted, "hello");
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }
    }
}