cipher\stream/
wrapper.rs

1use crate::StreamCipherCounter;
2
3use super::{
4    OverflowError, SeekNum, StreamCipher, StreamCipherCore, StreamCipherSeek, StreamCipherSeekCore,
5    errors::StreamCipherError,
6};
7use block_buffer::{BlockSizes, ReadBuffer};
8use common::{
9    Iv, IvSizeUser, Key, KeyInit, KeyIvInit, KeySizeUser, array::Array, typenum::Unsigned,
10};
11use core::fmt;
12use inout::InOutBuf;
13#[cfg(feature = "zeroize")]
14use zeroize::ZeroizeOnDrop;
15
16/// Buffering wrapper around a [`StreamCipherCore`] implementation.
17///
18/// It handles data buffering and implements the slice-based traits.
19pub struct StreamCipherCoreWrapper<T>
20where
21    T: StreamCipherCore,
22    T::BlockSize: BlockSizes,
23{
24    core: T,
25    buffer: ReadBuffer<T::BlockSize>,
26}
27
28impl<T> Clone for StreamCipherCoreWrapper<T>
29where
30    T: StreamCipherCore + Clone,
31    T::BlockSize: BlockSizes,
32{
33    #[inline]
34    fn clone(&self) -> Self {
35        Self {
36            core: self.core.clone(),
37            buffer: self.buffer.clone(),
38        }
39    }
40}
41
42impl<T> fmt::Debug for StreamCipherCoreWrapper<T>
43where
44    T: StreamCipherCore + fmt::Debug,
45    T::BlockSize: BlockSizes,
46{
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        f.debug_struct("StreamCipherCoreWrapper")
49            .finish_non_exhaustive()
50    }
51}
52
53impl<T> StreamCipherCoreWrapper<T>
54where
55    T: StreamCipherCore,
56    T::BlockSize: BlockSizes,
57{
58    /// Initialize from a [`StreamCipherCore`] instance.
59    pub fn from_core(core: T) -> Self {
60        Self {
61            core,
62            buffer: Default::default(),
63        }
64    }
65
66    /// Get reference to the wrapped [`StreamCipherCore`] instance.
67    pub fn get_core(&self) -> &T {
68        &self.core
69    }
70}
71
72impl<T> StreamCipher for StreamCipherCoreWrapper<T>
73where
74    T: StreamCipherCore,
75    T::BlockSize: BlockSizes,
76{
77    #[inline]
78    fn check_remaining(&self, data_len: usize) -> Result<(), StreamCipherError> {
79        let Some(rem_blocks) = self.core.remaining_blocks() else {
80            return Ok(());
81        };
82        let Some(data_len) = data_len.checked_sub(self.buffer.remaining()) else {
83            return Ok(());
84        };
85        let req_blocks = data_len.div_ceil(T::BlockSize::USIZE);
86        if req_blocks > rem_blocks {
87            Err(StreamCipherError)
88        } else {
89            Ok(())
90        }
91    }
92
93    #[inline]
94    fn unchecked_apply_keystream_inout(&mut self, data: InOutBuf<'_, '_, u8>) {
95        let head_ks = self.buffer.read_cached(data.len());
96
97        let (mut head, data) = data.split_at(head_ks.len());
98        let (blocks, mut tail) = data.into_chunks();
99
100        head.xor_in2out(head_ks);
101        self.core.apply_keystream_blocks_inout(blocks);
102
103        self.buffer.write_block(
104            tail.len(),
105            |b| self.core.write_keystream_block(b),
106            |tail_ks| {
107                tail.xor_in2out(tail_ks);
108            },
109        );
110    }
111
112    #[inline]
113    fn unchecked_write_keystream(&mut self, data: &mut [u8]) {
114        let head_ks = self.buffer.read_cached(data.len());
115
116        let (head, data) = data.split_at_mut(head_ks.len());
117        let (blocks, tail) = Array::slice_as_chunks_mut(data);
118
119        head.copy_from_slice(head_ks);
120        self.core.write_keystream_blocks(blocks);
121
122        self.buffer.write_block(
123            tail.len(),
124            |b| self.core.write_keystream_block(b),
125            |tail_ks| tail.copy_from_slice(tail_ks),
126        );
127    }
128}
129
130impl<T> StreamCipherSeek for StreamCipherCoreWrapper<T>
131where
132    T: StreamCipherSeekCore,
133    T::BlockSize: BlockSizes,
134{
135    #[allow(clippy::unwrap_in_result)]
136    fn try_current_pos<SN: SeekNum>(&self) -> Result<SN, OverflowError> {
137        let pos = u8::try_from(self.buffer.get_pos())
138            .expect("buffer position is always smaller than 256");
139        SN::from_block_byte(self.core.get_block_pos(), pos, T::BlockSize::U8)
140    }
141
142    fn try_seek<SN: SeekNum>(&mut self, new_pos: SN) -> Result<(), StreamCipherError> {
143        let (block_pos, byte_pos) = new_pos.into_block_byte::<T::Counter>(T::BlockSize::U8)?;
144        if byte_pos != 0 && block_pos.is_max() {
145            return Err(StreamCipherError);
146        }
147        // For correct implementations of `SeekNum` the compiler should be able to
148        // eliminate this assert
149        assert!(byte_pos < T::BlockSize::U8);
150
151        self.core.set_block_pos(block_pos);
152
153        self.buffer.reset();
154
155        self.buffer.write_block(
156            usize::from(byte_pos),
157            |b| self.core.write_keystream_block(b),
158            |_| {},
159        );
160        Ok(())
161    }
162}
163
164// Note: ideally we would only implement the InitInner trait and everything
165// else would be handled by blanket impls, but, unfortunately, it will
166// not work properly without mutually exclusive traits, see:
167// https://github.com/rust-lang/rfcs/issues/1053
168
169impl<T> KeySizeUser for StreamCipherCoreWrapper<T>
170where
171    T: KeySizeUser + StreamCipherCore,
172    T::BlockSize: BlockSizes,
173{
174    type KeySize = T::KeySize;
175}
176
177impl<T> IvSizeUser for StreamCipherCoreWrapper<T>
178where
179    T: IvSizeUser + StreamCipherCore,
180    T::BlockSize: BlockSizes,
181{
182    type IvSize = T::IvSize;
183}
184
185impl<T> KeyIvInit for StreamCipherCoreWrapper<T>
186where
187    T: KeyIvInit + StreamCipherCore,
188    T::BlockSize: BlockSizes,
189{
190    #[inline]
191    fn new(key: &Key<Self>, iv: &Iv<Self>) -> Self {
192        Self {
193            core: T::new(key, iv),
194            buffer: Default::default(),
195        }
196    }
197}
198
199impl<T> KeyInit for StreamCipherCoreWrapper<T>
200where
201    T: KeyInit + StreamCipherCore,
202    T::BlockSize: BlockSizes,
203{
204    #[inline]
205    fn new(key: &Key<Self>) -> Self {
206        Self {
207            core: T::new(key),
208            buffer: Default::default(),
209        }
210    }
211}
212
// Marker impl, available only with the `zeroize` feature: the wrapper is
// `ZeroizeOnDrop` whenever the wrapped core is.
#[cfg(feature = "zeroize")]
impl<T> ZeroizeOnDrop for StreamCipherCoreWrapper<T>
where
    T: StreamCipherCore + ZeroizeOnDrop,
    T::BlockSize: BlockSizes,
{
    // Intentionally empty: `ZeroizeOnDrop` is used here as a marker only.
}
220
// Compile-time assertion that `ReadBuffer` implements `ZeroizeOnDrop`.
// The helper below is never called; it only has to type-check.
#[cfg(feature = "zeroize")]
const _: () = {
    // The function exists purely for the type check, so it is never used.
    #[allow(dead_code)]
    fn check_buffer<BS: BlockSizes>(v: &ReadBuffer<BS>) {
        // This coercion compiles only if `ReadBuffer<BS>: ZeroizeOnDrop`.
        let _: &dyn ZeroizeOnDrop = v;
    }
};