opentitanlib/util/bigint.rs

1// Copyright lowRISC contributors (OpenTitan project).
2// Licensed under the Apache License, Version 2.0, see LICENSE for details.
3// SPDX-License-Identifier: Apache-2.0
4
5use num_bigint_dig::BigUint;
6use num_traits::Num;
7use std::cmp::Ordering;
8use std::fmt;
9use std::iter;
10use thiserror::Error;
11
12use crate::util::parse_int::ParseInt;
13
14#[derive(Error, Debug, Clone, PartialEq, Eq)]
15pub enum ParseBigIntError {
16    #[error("integer is too large")]
17    Overflow,
18    #[error("integer is too small")]
19    Underflow,
20    #[error(transparent)]
21    ParseBigIntError(#[from] num_bigint_dig::ParseBigIntError),
22}
23
24/// A fixed-size unsigned big integer.
25///
26/// This struct wraps a `BigUint` to facilitate defining new fixed-size unsigned integer types for
27/// better type safety.
28///
29/// An integer stored in this type is fixed-size in the sense that the minimum number of bits
30/// required to represent it, i.e. its bit length, is at most `BIT_LEN`. This size can be specified
31/// using the const parameters `BIT_LEN` and `EXACT_LEN` as follows:
32///   - When `EXACT_LEN` is `false`, the bit length of the integer can be at most `BIT_LEN` bits,
33///     e.g. SHA-256 digests (at most 256 bits) or RSA-3072 signatures (at most 3072 bits),
34///   - When `EXACT_LEN` is `true`, the number of bits required to represent the integer must be
35///     exactly `BIT_LEN` bits, e.g. RSA-3072 moduli (exactly 3072 bits).
36///
37/// Note that while the type encapsulates the size information, the actual check is performed at
38/// runtime when an instance is created (see `check_len()`).
39///
40/// This struct is not meant to be used directly, please see the `fixed_size_bigint` macro which
41/// also generates the required boilerplate code for new types.
42#[derive(Debug, Clone, Eq, PartialEq)]
43pub(crate) struct FixedSizeBigInt<const BIT_LEN: usize, const EXACT_LEN: bool>(BigUint);
44
45impl<const BIT_LEN: usize, const EXACT_LEN: bool> FixedSizeBigInt<BIT_LEN, EXACT_LEN> {
46    const BYTE_LEN: usize = BIT_LEN.saturating_add(u8::BITS as usize - 1) / u8::BITS as usize;
47
48    /// Checks the bit length of the `FixedSizeBigInt`.
49    ///
50    /// Bit length of a `FixedSizeBigInt` can be at most `BIT_LEN` if `EXACT_LEN` is `false`, must
51    /// be exactly `BIT_LEN` otherwise.
52    fn new_from_biguint(biguint: BigUint) -> Result<Self, ParseBigIntError> {
53        match (biguint.bits().cmp(&BIT_LEN), EXACT_LEN) {
54            (Ordering::Greater, _) => Err(ParseBigIntError::Overflow),
55            (Ordering::Equal, _) => Ok(Self(biguint)),
56            (Ordering::Less, true) => Err(ParseBigIntError::Underflow),
57            (Ordering::Less, false) => Ok(Self(biguint)),
58        }
59    }
60
61    /// Creates a `FixedSizeBigInt` from little-endian bytes.
62    pub(crate) fn from_le_bytes(bytes: impl AsRef<[u8]>) -> Result<Self, ParseBigIntError> {
63        Self::new_from_biguint(BigUint::from_bytes_le(bytes.as_ref()))
64    }
65
66    /// Creates a `FixedSizeBigInt` from big-endian bytes.
67    pub(crate) fn from_be_bytes(bytes: impl AsRef<[u8]>) -> Result<Self, ParseBigIntError> {
68        Self::new_from_biguint(BigUint::from_bytes_be(bytes.as_ref()))
69    }
70
71    /// Returns the bit length.
72    ///
73    /// Bit length of `FixedSizeBigInt` is the minimum number of bits required to represent its
74    /// value. The underlying storage may be larger.
75    pub(crate) fn bit_len(&self) -> usize {
76        self.0.bits()
77    }
78
79    /// Returns the byte representation in little-endian order.
80    pub(crate) fn to_le_bytes(&self) -> Vec<u8> {
81        let mut v = self.0.to_bytes_le();
82        assert!(Self::BYTE_LEN >= v.len());
83        // Append since `v` is little-endian.
84        v.resize(Self::BYTE_LEN, 0);
85        v
86    }
87
88    /// Returns the byte representation in big-endian order.
89    pub(crate) fn to_be_bytes(&self) -> Vec<u8> {
90        let mut v = self.0.to_bytes_be();
91        assert!(Self::BYTE_LEN >= v.len());
92        // Prepend since `v` is big-endian.
93        v.splice(0..0, iter::repeat_n(0, Self::BYTE_LEN - v.len()));
94        v
95    }
96
97    /// Returns the underlying `BigUint`.
98    pub(crate) fn as_biguint(&self) -> &BigUint {
99        &self.0
100    }
101}
102
103impl<const BIT_LEN: usize, const EXACT_LEN: bool> ParseInt for FixedSizeBigInt<BIT_LEN, EXACT_LEN> {
104    type FromStrRadixErr = ParseBigIntError;
105
106    fn from_str_radix(src: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
107        Self::new_from_biguint(
108            BigUint::from_str_radix(src, radix).map_err(ParseBigIntError::ParseBigIntError)?,
109        )
110    }
111}
112
113impl<const BIT_LEN: usize, const EXACT_LEN: bool> fmt::Display
114    for FixedSizeBigInt<BIT_LEN, EXACT_LEN>
115{
116    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
117        fmt::Display::fmt(
118            &format_args!("{:#0width$x}", self.0, width = Self::BYTE_LEN * 2 + 2),
119            f,
120        )
121    }
122}
123
/// Helper macro for the `fixed_size_bigint` macro.
///
/// Expands to a newtype `$struct_name` around `FixedSizeBigInt<$bit_len, $exact_len>` that
/// forwards the byte/bit accessors, implements `ParseInt` and `Display`, converts into `String`,
/// and serializes through that `String` form.
macro_rules! fixed_size_bigint_impl {
    ($struct_name:ident, $bit_len:expr, $exact_len:expr) => {
        /// Fixed-size unsigned big integer defined with `fixed_size_bigint!`.
        #[derive(serde::Serialize, Debug, Clone, Eq, PartialEq)]
        #[serde(into = "String")]
        pub struct $struct_name($crate::util::bigint::FixedSizeBigInt<$bit_len, $exact_len>);

        // The anonymous const keeps the helper `use` items below out of the caller's module.
        const _: () = {
            use num_bigint_dig::BigUint;
            use std::fmt;
            use std::result::Result;

            use $crate::util::bigint::{FixedSizeBigInt, ParseBigIntError};
            use $crate::util::parse_int::ParseInt;

            impl $struct_name {
                /// Builds a value from little-endian bytes, validating its bit length.
                pub fn from_le_bytes(bytes: impl AsRef<[u8]>) -> Result<Self, ParseBigIntError> {
                    FixedSizeBigInt::<$bit_len, $exact_len>::from_le_bytes(bytes).map(Self)
                }

                /// Builds a value from big-endian bytes, validating its bit length.
                pub fn from_be_bytes(bytes: impl AsRef<[u8]>) -> Result<Self, ParseBigIntError> {
                    FixedSizeBigInt::<$bit_len, $exact_len>::from_be_bytes(bytes).map(Self)
                }

                /// Minimum number of bits needed to represent the value.
                pub fn bit_len(&self) -> usize {
                    self.0.bit_len()
                }

                /// Full-width little-endian byte representation.
                pub fn to_le_bytes(&self) -> Vec<u8> {
                    self.0.to_le_bytes()
                }

                /// Full-width big-endian byte representation.
                pub fn to_be_bytes(&self) -> Vec<u8> {
                    self.0.to_be_bytes()
                }

                /// Borrows the wrapped `BigUint`.
                pub fn as_biguint(&self) -> &BigUint {
                    self.0.as_biguint()
                }
            }

            impl ParseInt for $struct_name {
                type FromStrRadixErr = ParseBigIntError;

                fn from_str_radix(src: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
                    FixedSizeBigInt::<$bit_len, $exact_len>::from_str_radix(src, radix).map(Self)
                }
            }

            impl fmt::Display for $struct_name {
                fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                    // Forward with the caller's formatter so flags reach the inner impl.
                    fmt::Display::fmt(&self.0, f)
                }
            }

            impl From<$struct_name> for String {
                fn from(value: $struct_name) -> String {
                    value.0.to_string()
                }
            }
        };
    };
}

pub(crate) use fixed_size_bigint_impl;
195
/// Macro for defining a new fixed-size unsigned big integer type.
///
/// Defines a new type that wraps a `FixedSizeBigInt`. This macro is intended to be used within this
/// crate to define types which can then be exported as needed:
///
/// - `fixed_size_bigint!(Name, BITS)` requires a bit length of exactly `BITS`.
/// - `fixed_size_bigint!(Name, at_most BITS)` allows a bit length of up to `BITS`.
///
/// The example is marked `ignore` because the macro is `pub(crate)` and `crate::` inside a
/// doctest refers to the doctest's own crate, so it cannot compile as a doctest.
///
/// ```ignore
/// use crate::util::bigint::fixed_size_bigint;
///
/// // Define a type for RSA-3072 moduli (exactly 3072 bits long):
/// fixed_size_bigint!(Rsa3072Modulus, 3072);
///
/// // Define a type for SHA-256 digests (at most 256 bits long):
/// fixed_size_bigint!(Sha256Digest, at_most 256);
/// ```
macro_rules! fixed_size_bigint {
    ($struct_name:ident, $bit_len:expr) => {
        $crate::util::bigint::fixed_size_bigint_impl!($struct_name, $bit_len, true);
    };
    ($struct_name:ident, at_most $bit_len:expr) => {
        $crate::util::bigint::fixed_size_bigint_impl!($struct_name, $bit_len, false);
    };
}

pub(crate) use fixed_size_bigint;
220
#[cfg(test)]
mod tests {
    use super::*;

    fixed_size_bigint!(TestArray, at_most 16);
    fixed_size_bigint!(TestArrayExact, 16);

    #[test]
    fn test_from_to_le_bytes() {
        fn check(slice: &[u8], data: &[u8]) {
            assert_eq!(TestArray::from_le_bytes(slice).unwrap().to_le_bytes(), data);
        }
        check(&[], &[0, 0]);
        check(&[1], &[1, 0]);
        check(&[0, 1], &[0, 1]);
        check(&[1, 0], &[1, 0]);

        assert!(TestArray::from_le_bytes([1, 2, 3]).is_err());
    }

    #[test]
    fn test_from_to_le_bytes_exact_len() {
        fn check(slice: &[u8], data: &[u8]) {
            assert_eq!(
                TestArrayExact::from_le_bytes(slice).unwrap().to_le_bytes(),
                data
            );
        }
        check(&[0, 128], &[0, 128]);
        check(&[255, 255, 0], &[255, 255]);

        assert!(TestArrayExact::from_le_bytes([1]).is_err());
        assert!(TestArrayExact::from_le_bytes([255, 127]).is_err());
        assert!(TestArrayExact::from_le_bytes([0, 0, 1]).is_err());
    }

    #[test]
    fn test_from_to_be_bytes() {
        fn check(slice: &[u8], data: &[u8]) {
            assert_eq!(TestArray::from_be_bytes(slice).unwrap().to_be_bytes(), data);
        }
        // Mirrors the empty-input case of the little-endian test.
        check(&[], &[0, 0]);
        check(&[1], &[0, 1]);
        check(&[1, 0], &[1, 0]);
        check(&[0, 1], &[0, 1]);

        assert!(TestArray::from_be_bytes([1, 2, 1]).is_err());
    }

    #[test]
    fn test_from_to_be_bytes_exact_len() {
        fn check(slice: &[u8], data: &[u8]) {
            assert_eq!(
                TestArrayExact::from_be_bytes(slice).unwrap().to_be_bytes(),
                data
            );
        }
        check(&[128, 1], &[128, 1]);
        check(&[0, 255, 255], &[255, 255]);

        assert!(TestArrayExact::from_be_bytes([1]).is_err());
        assert!(TestArrayExact::from_be_bytes([127, 1]).is_err());
        assert!(TestArrayExact::from_be_bytes([1, 0, 0]).is_err());
    }

    #[test]
    fn test_error_variants() {
        // Too many bits is `Overflow` regardless of `EXACT_LEN`.
        assert_eq!(
            TestArray::from_le_bytes([1, 2, 3]),
            Err(ParseBigIntError::Overflow)
        );
        assert_eq!(
            TestArrayExact::from_le_bytes([0, 0, 1]),
            Err(ParseBigIntError::Overflow)
        );
        // Too few bits is only an error (`Underflow`) for exact-length types.
        assert_eq!(
            TestArrayExact::from_le_bytes([255, 127]),
            Err(ParseBigIntError::Underflow)
        );
        assert_eq!(
            TestArrayExact::from_be_bytes([127, 1]),
            Err(ParseBigIntError::Underflow)
        );
        // Invalid digits surface as the wrapped `num_bigint_dig` error.
        assert!(matches!(
            TestArray::from_str_radix("zz", 16),
            Err(ParseBigIntError::ParseBigIntError(_))
        ));
    }

    #[test]
    fn test_bit_len() {
        fn check(slice: &[u8], bit_len: usize) {
            assert_eq!(TestArray::from_le_bytes(slice).unwrap().bit_len(), bit_len);
        }
        check(&[1], 1);
        check(&[1, 0], 1);
        check(&[255], 8);
        check(&[0, 1], 9);
        check(&[0, 128], 16);
    }

    #[test]
    fn test_from_str() {
        assert_eq!(TestArray::from_str("0x01").unwrap().to_le_bytes(), [1, 0]);
        assert_eq!(
            TestArray::from_str("0x00201").unwrap().to_le_bytes(),
            [1, 2]
        );
        assert!(TestArray::from_str("0x030201").is_err());
    }

    #[test]
    fn test_from_str_exact_len() {
        assert_eq!(
            TestArrayExact::from_str("0x08001").unwrap().to_le_bytes(),
            [1, 128]
        );

        assert!(TestArrayExact::from_str("0x01").is_err());
        assert!(TestArrayExact::from_str("0x0201").is_err());
        assert!(TestArrayExact::from_str("0x030201").is_err());
    }

    #[test]
    fn test_fmt() {
        let exact = TestArrayExact::from_str("0xabcd").unwrap();
        assert_eq!(exact.to_string(), "0xabcd");

        let at_most = TestArray::from_str("0xab").unwrap();
        assert_eq!(at_most.to_string(), "0x00ab");

        let at_most_full = TestArray::from_str("0xabcd").unwrap();
        assert_eq!(at_most_full.to_string(), "0xabcd");
    }
}
330}