// inout/reserved.rs

1use crate::{InOutBuf, errors::OutIsTooSmallError};
2use core::{marker::PhantomData, slice};
3
4#[cfg(feature = "block-padding")]
5use {
6    crate::{InOut, errors::PadError},
7    block_padding::Padding,
8    hybrid_array::{Array, ArraySize},
9};
10
/// Custom slice type which references one immutable (input) slice and one
/// mutable (output) slice.
///
/// Invariants established by the safe constructors (and required from callers
/// of the unsafe `from_raw`):
/// - the length of the output slice is always equal to or bigger than the
///   length of the input slice (`in_len <= out_len`);
/// - the input and output slices either start at the same address (the input
///   is then a prefix of the output, as created by `from_mut_slice`) or do not
///   overlap at all.
///
/// The output elements past the first `in_len` form the "reserved" space that
/// [`split_reserved`](Self::split_reserved) exposes separately.
pub struct InOutBufReserved<'inp, 'out, T> {
    /// Start of the input slice.
    in_ptr: *const T,
    /// Start of the output slice; equal to `in_ptr` for in-place buffers.
    out_ptr: *mut T,
    /// Number of input elements; never greater than `out_len`.
    in_len: usize,
    /// Number of output elements, including the reserved suffix.
    out_len: usize,
    /// Ties the raw pointers to the borrowed input (`'inp`) and output (`'out`).
    _pd: PhantomData<(&'inp T, &'out mut T)>,
}
22
23impl<'a, T> InOutBufReserved<'a, 'a, T> {
24    /// Crate [`InOutBufReserved`] from a single mutable slice.
25    pub fn from_mut_slice(buf: &'a mut [T], msg_len: usize) -> Result<Self, OutIsTooSmallError> {
26        if msg_len > buf.len() {
27            return Err(OutIsTooSmallError);
28        }
29        let p = buf.as_mut_ptr();
30        let out_len = buf.len();
31        Ok(Self {
32            in_ptr: p,
33            out_ptr: p,
34            in_len: msg_len,
35            out_len,
36            _pd: PhantomData,
37        })
38    }
39}
40
41impl<T> InOutBufReserved<'_, '_, T> {
42    /// Create [`InOutBufReserved`] from raw input and output pointers.
43    ///
44    /// # Safety
45    /// Behavior is undefined if any of the following conditions are violated:
46    /// - `in_ptr` must point to a properly initialized value of type `T` and
47    ///   must be valid for reads for `in_len * mem::size_of::<T>()` many bytes.
48    /// - `out_ptr` must point to a properly initialized value of type `T` and
49    ///   must be valid for both reads and writes for `out_len * mem::size_of::<T>()`
50    ///   many bytes.
51    /// - `in_ptr` and `out_ptr` must be either equal or non-overlapping.
52    /// - If `in_ptr` and `out_ptr` are equal, then the memory referenced by
53    ///   them must not be accessed through any other pointer (not derived from
54    ///   the return value) for the duration of lifetime 'a. Both read and write
55    ///   accesses are forbidden.
56    /// - If `in_ptr` and `out_ptr` are not equal, then the memory referenced by
57    ///   `out_ptr` must not be accessed through any other pointer (not derived from
58    ///   the return value) for the duration of lifetime 'a. Both read and write
59    ///   accesses are forbidden. The memory referenced by `in_ptr` must not be
60    ///   mutated for the duration of lifetime `'a`, except inside an `UnsafeCell`.
61    /// - The total size `in_len * mem::size_of::<T>()` and
62    ///   `out_len * mem::size_of::<T>()`  must be no larger than `isize::MAX`.
63    #[inline(always)]
64    pub unsafe fn from_raw(
65        in_ptr: *const T,
66        in_len: usize,
67        out_ptr: *mut T,
68        out_len: usize,
69    ) -> Self {
70        Self {
71            in_ptr,
72            out_ptr,
73            in_len,
74            out_len,
75            _pd: PhantomData,
76        }
77    }
78
79    /// Get raw input and output pointers.
80    #[inline(always)]
81    pub fn into_raw(self) -> (*const T, *mut T) {
82        (self.in_ptr, self.out_ptr)
83    }
84
85    /// Get input buffer length.
86    #[inline(always)]
87    pub fn get_in_len(&self) -> usize {
88        self.in_len
89    }
90
91    /// Get output buffer length.
92    #[inline(always)]
93    pub fn get_out_len(&self) -> usize {
94        self.out_len
95    }
96
97    /// Split buffer into `InOutBuf` with input length and mutable slice pointing to
98    /// the remaining reserved suffix.
99    pub fn split_reserved(&mut self) -> (InOutBuf<'_, '_, T>, &mut [T]) {
100        let in_len = self.get_in_len();
101        let out_len = self.get_out_len();
102        let in_ptr = self.get_in().as_ptr();
103        let out_ptr = self.get_out().as_mut_ptr();
104        // This never underflows because the type ensures that `out_len` is
105        // bigger or equal to `in_len`.
106        let tail_len = out_len - in_len;
107        unsafe {
108            let body = InOutBuf::from_raw(in_ptr, out_ptr, in_len);
109            let tail = slice::from_raw_parts_mut(out_ptr.add(in_len), tail_len);
110            (body, tail)
111        }
112    }
113}
114
115impl<'inp, 'out, T> InOutBufReserved<'inp, 'out, T> {
116    /// Crate [`InOutBufReserved`] from two separate slices.
117    pub fn from_slices(
118        in_buf: &'inp [T],
119        out_buf: &'out mut [T],
120    ) -> Result<Self, OutIsTooSmallError> {
121        if in_buf.len() > out_buf.len() {
122            return Err(OutIsTooSmallError);
123        }
124        Ok(Self {
125            in_ptr: in_buf.as_ptr(),
126            out_ptr: out_buf.as_mut_ptr(),
127            in_len: in_buf.len(),
128            out_len: out_buf.len(),
129            _pd: PhantomData,
130        })
131    }
132
133    /// Get input slice.
134    #[inline(always)]
135    pub fn get_in(&self) -> &[T] {
136        unsafe { slice::from_raw_parts(self.in_ptr, self.in_len) }
137    }
138
139    /// Get output slice.
140    #[inline(always)]
141    pub fn get_out(&mut self) -> &mut [T] {
142        unsafe { slice::from_raw_parts_mut(self.out_ptr, self.out_len) }
143    }
144
145    /// Consume `self` and get output slice with lifetime `'out`.
146    #[inline(always)]
147    pub fn into_out(self) -> &'out mut [T] {
148        unsafe { slice::from_raw_parts_mut(self.out_ptr, self.out_len) }
149    }
150}
151
#[cfg(feature = "block-padding")]
impl<'inp, 'out> InOutBufReserved<'inp, 'out, u8> {
    /// Transform buffer into [`PaddedInOutBuf`] using padding algorithm `P`.
    ///
    /// The first `in_len / BS::USIZE` input blocks are paired with the
    /// corresponding blocks at the start of the output buffer. If `P` produces a
    /// padded tail block, it is paired with the `BS::USIZE` output bytes that
    /// directly follow those blocks.
    ///
    /// # Errors
    /// Returns [`PadError`] if `P` reports a padding error, or if the output
    /// buffer does not have room for the padded tail block.
    ///
    /// # Panics
    /// Panics if `P` splits off a number of full blocks other than
    /// `in_len / BS::USIZE`.
    #[inline(always)]
    pub fn into_padded_blocks<P, BS>(self) -> Result<PaddedInOutBuf<'inp, 'out, BS>, PadError>
    where
        P: Padding,
        BS: ArraySize,
    {
        use block_padding::PaddedData;

        let bs = BS::USIZE;
        let blocks_len = self.in_len / bs;

        let (blocks, tail_block) = match P::pad_detached(self.get_in()) {
            PaddedData::Pad { blocks, tail_block } => (blocks, Some(tail_block)),
            PaddedData::NoPad { blocks } => (blocks, None),
            PaddedData::Error => return Err(PadError),
        };

        // Sanity check: the block count used for the pointer casts below must
        // agree with the number of full blocks `P` split off.
        assert_eq!(blocks.len(), blocks_len);

        let out_len = self.out_len;
        let (in_ptr, out_ptr) = self.into_raw();

        // SAFETY: `blocks_len * bs <= in_len <= out_len`, so both buffers hold at
        // least `blocks_len` full blocks, and `in_ptr`/`out_ptr` are either equal
        // or non-overlapping by the invariants of `InOutBufReserved`. The casts
        // rely on `Array<u8, BS>` having the layout of `[u8; BS::USIZE]`
        // (alignment 1), so any byte pointer is suitably aligned.
        let blocks = unsafe {
            InOutBuf::from_raw(
                in_ptr.cast::<Array<u8, BS>>(),
                out_ptr.cast::<Array<u8, BS>>(),
                blocks_len,
            )
        };

        let Some(tail_block) = tail_block else {
            let tail_inout = None;
            return Ok(PaddedInOutBuf { blocks, tail_inout });
        };

        let blocks_byte_len = blocks_len * bs;
        // Does not underflow: `blocks_byte_len <= in_len <= out_len`.
        let reserve_len = out_len - blocks_byte_len;
        if reserve_len < tail_block.len() {
            return Err(PadError);
        }
        // SAFETY: we checked that the out buffer has enough bytes in reserve for
        // one more block directly after `blocks`, so this block is in bounds and
        // does not overlap the output blocks. It is derived from `out_ptr`, which
        // the consumed `self` borrowed exclusively for `'out`.
        let tail_out: &mut Array<u8, BS> = unsafe {
            let tail_out_ptr = out_ptr.add(blocks_byte_len);
            &mut *(tail_out_ptr.cast())
        };

        let tail_inout = Some((tail_block, tail_out));

        Ok(PaddedInOutBuf { blocks, tail_inout })
    }
}
205
/// Variant of [`InOutBuf`] with optional padded tail block.
///
/// Produced by [`InOutBufReserved::into_padded_blocks`].
#[cfg(feature = "block-padding")]
// The tuple type of `tail_inout` trips `clippy::type_complexity`; a dedicated
// type alias would not make this private field any clearer.
#[allow(clippy::type_complexity)]
pub struct PaddedInOutBuf<'inp, 'out, BS: ArraySize> {
    /// Full blocks: input blocks paired with the matching output blocks.
    blocks: InOutBuf<'inp, 'out, Array<u8, BS>>,
    /// Owned padded tail block paired with the output block that directly
    /// follows `blocks`, or `None` if the padding produced no tail block.
    tail_inout: Option<(Array<u8, BS>, &'out mut Array<u8, BS>)>,
}
213
#[cfg(feature = "block-padding")]
impl<'out, BS: ArraySize> PaddedInOutBuf<'_, 'out, BS> {
    /// Reborrow the full (unpadded) blocks as an [`InOutBuf`].
    #[inline(always)]
    pub fn get_blocks(&mut self) -> InOutBuf<'_, '_, Array<u8, BS>> {
        self.blocks.reborrow()
    }

    /// Get the padded tail block paired with its output block.
    ///
    /// Yields `None` only when the padding produced no tail block; most
    /// padding implementations always produce one, so this is usually `Some`.
    #[inline(always)]
    pub fn get_tail_block(&mut self) -> Option<InOut<'_, '_, Array<u8, BS>>> {
        let (tail_in, tail_out) = self.tail_inout.as_mut()?;
        Some(InOut::from((&*tail_in, &mut **tail_out)))
    }

    /// Consume the buffer and return the written output bytes.
    ///
    /// The returned slice spans every output block: all full blocks followed by
    /// the padded tail block, when one is present.
    #[inline(always)]
    pub fn into_out(self) -> &'out [u8] {
        let tail_count = usize::from(self.tail_inout.is_some());
        let block_count = self.blocks.len() + tail_count;
        let byte_len = block_count * BS::USIZE;
        let (_, out_blocks_ptr) = self.blocks.into_raw();
        // SAFETY: construction verified that the full blocks lie inside the
        // output buffer and that a tail block is only stored when the reserve
        // right after them can hold it, so `byte_len` bytes are in bounds.
        unsafe { slice::from_raw_parts(out_blocks_ptr.cast::<u8>(), byte_len) }
    }
}