Software APIs
rsa_padding.c
1 // Copyright lowRISC contributors (OpenTitan project).
2 // Licensed under the Apache License, Version 2.0, see LICENSE for details.
3 // SPDX-License-Identifier: Apache-2.0
4 
5 #include "sw/device/lib/crypto/impl/rsa/rsa_padding.h"
6 
10 #include "sw/device/lib/crypto/drivers/entropy.h"
11 #include "sw/device/lib/crypto/drivers/kmac.h"
13 
// Module ID for status codes.
#define MODULE_ID MAKE_MODULE_ID('r', 'p', 'a')

/**
 * DER-encoded DigestInfo prefixes ("T" minus the digest) per hash function.
 *
 * The byte strings come from RFC 8017, Section 9.2, Note 1, stored
 * byte-reversed (little-endian) so each one can be copied directly after a
 * byte-reversed digest. Reading from index 0 upward, each table is laid out as:
 *   row 1: digest length, OCTET STRING tag (0x04), then NULL params (0x00 0x05)
 *   row 2: the 9 content bytes of the hash algorithm OID, reversed
 *   row 3: OID length/tag (0x09 0x06), AlgorithmIdentifier SEQUENCE header
 *          (0x0d 0x30), and the outer DigestInfo SEQUENCE header (len, 0x30)
 *
 * The SHA-3 tables are adapted from the SHA-2 ones by substituting the
 * algorithm identifiers (OID arcs 2.16.840.1.101.3.4.2.7 through .10) listed at
 * https://csrc.nist.gov/projects/computer-security-objects-register/algorithm-registration
 */

// SHA-256: OID 2.16.840.1.101.3.4.2.1
static const uint8_t kSha256DigestIdentifier[] = {
    0x20, 0x04, 0x00, 0x05,
    0x01, 0x02, 0x04, 0x03, 0x65, 0x01, 0x48, 0x86, 0x60,
    0x09, 0x06, 0x0d, 0x30, 0x31, 0x30,
};
// SHA-384: OID 2.16.840.1.101.3.4.2.2
static const uint8_t kSha384DigestIdentifier[] = {
    0x30, 0x04, 0x00, 0x05,
    0x02, 0x02, 0x04, 0x03, 0x65, 0x01, 0x48, 0x86, 0x60,
    0x09, 0x06, 0x0d, 0x30, 0x41, 0x30,
};
// SHA-512: OID 2.16.840.1.101.3.4.2.3
static const uint8_t kSha512DigestIdentifier[] = {
    0x40, 0x04, 0x00, 0x05,
    0x03, 0x02, 0x04, 0x03, 0x65, 0x01, 0x48, 0x86, 0x60,
    0x09, 0x06, 0x0d, 0x30, 0x51, 0x30,
};
// SHA3-224: OID 2.16.840.1.101.3.4.2.7
static const uint8_t kSha3_224DigestIdentifier[] = {
    0x1c, 0x04, 0x00, 0x05,
    0x07, 0x02, 0x04, 0x03, 0x65, 0x01, 0x48, 0x86, 0x60,
    0x09, 0x06, 0x0d, 0x30, 0x2d, 0x30,
};
// SHA3-256: OID 2.16.840.1.101.3.4.2.8
static const uint8_t kSha3_256DigestIdentifier[] = {
    0x20, 0x04, 0x00, 0x05,
    0x08, 0x02, 0x04, 0x03, 0x65, 0x01, 0x48, 0x86, 0x60,
    0x09, 0x06, 0x0d, 0x30, 0x31, 0x30,
};
// SHA3-384: OID 2.16.840.1.101.3.4.2.9
static const uint8_t kSha3_384DigestIdentifier[] = {
    0x30, 0x04, 0x00, 0x05,
    0x09, 0x02, 0x04, 0x03, 0x65, 0x01, 0x48, 0x86, 0x60,
    0x09, 0x06, 0x0d, 0x30, 0x41, 0x30,
};
// SHA3-512: OID 2.16.840.1.101.3.4.2.10
static const uint8_t kSha3_512DigestIdentifier[] = {
    0x40, 0x04, 0x00, 0x05,
    0x0a, 0x02, 0x04, 0x03, 0x65, 0x01, 0x48, 0x86, 0x60,
    0x09, 0x06, 0x0d, 0x30, 0x51, 0x30,
};
55 
56 /**
57  * Get the length of the DER encoding for the given hash function's digests.
58  *
59  * See RFC 8017, Appendix B.1. The encoding consists of the digest algorithm
60  * identifier and then the digest itself.
61  *
62  * @param hash_mode Hash function to use.
63  * @param[out] len Byte-length of the DER encoding of the digest.
64  * @param OTCRYPTO_BAD_ARGS if the hash function is not valid, otherwise OK.
65  */
67 static status_t digest_info_length_get(const otcrypto_hash_mode_t hash_mode,
68  size_t *len) {
69  switch (hash_mode) {
70  case kOtcryptoHashModeSha256:
71  *len = sizeof(kSha256DigestIdentifier) + kSha256DigestBytes;
72  return OTCRYPTO_OK;
73  case kOtcryptoHashModeSha384:
74  *len = sizeof(kSha384DigestIdentifier) + kSha384DigestBytes;
75  return OTCRYPTO_OK;
76  case kOtcryptoHashModeSha512:
77  *len = sizeof(kSha512DigestIdentifier) + kSha512DigestBytes;
78  return OTCRYPTO_OK;
79  case kOtcryptoHashModeSha3_224:
80  *len = sizeof(kSha3_224DigestIdentifier) + kSha3_224DigestBytes;
81  return OTCRYPTO_OK;
82  case kOtcryptoHashModeSha3_256:
83  *len = sizeof(kSha3_256DigestIdentifier) + kSha3_256DigestBytes;
84  return OTCRYPTO_OK;
85  case kOtcryptoHashModeSha3_384:
86  *len = sizeof(kSha3_384DigestIdentifier) + kSha3_384DigestBytes;
87  return OTCRYPTO_OK;
88  case kOtcryptoHashModeSha3_512:
89  *len = sizeof(kSha512DigestIdentifier) + kSha3_512DigestBytes;
90  return OTCRYPTO_OK;
91  default:
92  // Unsupported or unrecognized hash function.
93  return OTCRYPTO_BAD_ARGS;
94  };
95 
96  // Unreachable.
97  HARDENED_TRAP();
98  return OTCRYPTO_FATAL_ERR;
99 }
100 
101 /**
102  * Get the DER encoding for the hash function's digests.
103  *
104  * See RFC 8017, Appendix B.1.
105  *
106  * The caller must ensure that enough space is allocated for the encoding; use
107  * `digest_info_length()` to check before calling this function. Only certain
108  * hash functions are supported.
109  *
110  * Writes the encoding in little-endian, which is reversed compared to the RFC.
111  *
112  * @param message_digest Message digest to encode.
113  * @param[out] encoding DER encoding of the digest.
114  * @return OTCRYPTO_BAD_ARGS if the hash function is not valid, otherwise OK.
115  */
117 static status_t digest_info_write(const otcrypto_hash_digest_t message_digest,
118  uint32_t *encoding) {
119  switch (message_digest.mode) {
120  case kOtcryptoHashModeSha256:
121  if (message_digest.len != kSha256DigestWords) {
122  return OTCRYPTO_BAD_ARGS;
123  }
124  memcpy(encoding + kSha256DigestWords, &kSha256DigestIdentifier,
125  sizeof(kSha256DigestIdentifier));
126  break;
127  case kOtcryptoHashModeSha384:
128  if (message_digest.len != kSha384DigestWords) {
129  return OTCRYPTO_BAD_ARGS;
130  }
131  memcpy(encoding + kSha384DigestWords, &kSha384DigestIdentifier,
132  sizeof(kSha384DigestIdentifier));
133  break;
134  case kOtcryptoHashModeSha512:
135  if (message_digest.len != kSha512DigestWords) {
136  return OTCRYPTO_BAD_ARGS;
137  }
138  memcpy(encoding + kSha512DigestWords, &kSha512DigestIdentifier,
139  sizeof(kSha512DigestIdentifier));
140  break;
141  case kOtcryptoHashModeSha3_224:
142  if (message_digest.len != kSha3_224DigestWords) {
143  return OTCRYPTO_BAD_ARGS;
144  }
145  memcpy(encoding + kSha3_224DigestWords, &kSha3_224DigestIdentifier,
146  sizeof(kSha3_224DigestIdentifier));
147  break;
148  case kOtcryptoHashModeSha3_256:
149  if (message_digest.len != kSha3_256DigestWords) {
150  return OTCRYPTO_BAD_ARGS;
151  }
152  memcpy(encoding + kSha3_256DigestWords, &kSha3_256DigestIdentifier,
153  sizeof(kSha3_256DigestIdentifier));
154  break;
155  case kOtcryptoHashModeSha3_384:
156  if (message_digest.len != kSha3_384DigestWords) {
157  return OTCRYPTO_BAD_ARGS;
158  }
159  memcpy(encoding + kSha3_384DigestWords, &kSha3_384DigestIdentifier,
160  sizeof(kSha3_384DigestIdentifier));
161  break;
162  case kOtcryptoHashModeSha3_512:
163  if (message_digest.len != kSha3_512DigestWords) {
164  return OTCRYPTO_BAD_ARGS;
165  }
166  memcpy(encoding + kSha3_512DigestWords, &kSha3_512DigestIdentifier,
167  sizeof(kSha3_512DigestIdentifier));
168  break;
169  default:
170  // Unsupported or unrecognized hash function.
171  return OTCRYPTO_BAD_ARGS;
172  };
173 
174  // Copy the digest into the encoding, reversing the order of bytes.
175  for (size_t i = 0; i < ceil_div(message_digest.len, 2); i++) {
176  encoding[i] =
177  __builtin_bswap32(message_digest.data[message_digest.len - 1 - i]);
178  encoding[message_digest.len - 1 - i] =
179  __builtin_bswap32(message_digest.data[i]);
180  }
181 
182  return OTCRYPTO_OK;
183 }
184 
185 status_t rsa_padding_pkcs1v15_encode(
186  const otcrypto_hash_digest_t message_digest, size_t encoded_message_len,
187  uint32_t *encoded_message) {
188  // Initialize all bits of the encoded message to 1.
189  size_t encoded_message_bytelen = encoded_message_len * sizeof(uint32_t);
190  memset(encoded_message, 0xff, encoded_message_bytelen);
191 
192  // Get a byte-sized pointer to the encoded message data.
193  unsigned char *buf = (unsigned char *)encoded_message;
194 
195  // Set the last byte to 0x00 and the second-to-last byte to 0x01.
196  buf[encoded_message_bytelen - 1] = 0x00;
197  buf[encoded_message_bytelen - 2] = 0x01;
198 
199  // Get the length of the digest info (called T in the RFC).
200  size_t tlen;
201  HARDENED_TRY(digest_info_length_get(message_digest.mode, &tlen));
202 
203  if (tlen + 3 + 8 >= encoded_message_bytelen) {
204  // Invalid encoded message length/hash function combination; the RFC
205  // specifies that the 0xff padding must be at least 8 octets.
206  return OTCRYPTO_BAD_ARGS;
207  }
208  // Write the digest info to the start of the buffer.
209  HARDENED_TRY(digest_info_write(message_digest, encoded_message));
210 
211  // Set one byte to 0 just after the digest info.
212  buf[tlen] = 0x00;
213 
214  return OTCRYPTO_OK;
215 }
216 
217 status_t rsa_padding_pkcs1v15_verify(
218  const otcrypto_hash_digest_t message_digest,
219  const uint32_t *encoded_message, const size_t encoded_message_len,
220  hardened_bool_t *result) {
221  // Re-encode the message.
222  uint32_t expected_encoded_message[encoded_message_len];
223  HARDENED_TRY(rsa_padding_pkcs1v15_encode(message_digest, encoded_message_len,
224  expected_encoded_message));
225  // Compare with the expected value.
226  *result = hardened_memeq(encoded_message, expected_encoded_message,
227  ARRAYSIZE(expected_encoded_message));
228  return OTCRYPTO_OK;
229 }
230 
231 /**
232  * Get the output size in words for the given hash function.
233  *
234  * Returns an error if the hash mode is unsupported, unrecognized, or does not
235  * have a fixed length.
236  *
237  * @param hash_mode Hash function.
238  * @param[out] num_words Output length in 32-bit words.
239  * @return Result of the operation (OK or error).
240  */
242 static status_t digest_wordlen_get(otcrypto_hash_mode_t hash_mode,
243  size_t *num_words) {
244  *num_words = 0;
245  switch (hash_mode) {
246  case kOtcryptoHashModeSha3_224:
247  *num_words = 224 / 32;
248  break;
249  case kOtcryptoHashModeSha256:
251  case kOtcryptoHashModeSha3_256:
252  *num_words = 256 / 32;
253  break;
254  case kOtcryptoHashModeSha384:
256  case kOtcryptoHashModeSha3_384:
257  *num_words = 384 / 32;
258  break;
259  case kOtcryptoHashModeSha512:
261  case kOtcryptoHashModeSha3_512:
262  *num_words = 512 / 32;
263  break;
264  default:
265  return OTCRYPTO_BAD_ARGS;
266  }
267  HARDENED_CHECK_GT(num_words, 0);
268 
269  return OTCRYPTO_OK;
270 }
271 
272 /**
273  * Mask generation function MGF1 (RFC 8017, appendix B.2.1).
274  *
275  * The `mask` parameter is 32-bit aligned because this makes it more secure and
276  * efficient to operate and compare with the mask. However, the mask length is
277  * not necessarily a multiple of the word size. This routine guarantees that
278  * any extra bytes at the end of the mask will be initialized, but does not
279  * make any guarantees about their values.
280  *
281  * @param hash_mode Hash function to use.
282  * @param seed Seed data.
283  * @param seed_len Length of seed data in bytes.
284  * @param mask_len Intended byte-length of the mask.
285  * @param[out] mask Destination buffer for mask.
286  * @return Result of the operation (OK or error).
287  */
289 static status_t mgf1(otcrypto_hash_mode_t hash_mode, const uint8_t *seed,
290  size_t seed_len, size_t mask_len, uint32_t *mask) {
291  // Check that the number of iterations won't overflow the counter.
292  size_t digest_wordlen;
293  HARDENED_TRY(digest_wordlen_get(hash_mode, &digest_wordlen));
294  size_t digest_bytelen = digest_wordlen * sizeof(uint32_t);
295  size_t num_iterations = ceil_div(mask_len, digest_bytelen);
296  if (num_iterations > UINT32_MAX) {
297  return OTCRYPTO_BAD_ARGS;
298  }
299 
300  // First, process the iterations in which the entire digest will fit in the
301  // `mask` buffer.
302  uint8_t hash_input[seed_len + sizeof(uint32_t)];
303  memcpy(hash_input, seed, seed_len);
304  for (uint32_t i = 0; i < num_iterations - 1; i++) {
305  uint32_t ctr = __builtin_bswap32(i);
306  memcpy(hash_input + seed_len, &ctr, sizeof(uint32_t));
307  otcrypto_hash_digest_t digest = {
308  .data = mask, .len = digest_wordlen, .mode = hash_mode};
309  HARDENED_TRY(otcrypto_hash(
311  .data = hash_input,
312  .len = sizeof(hash_input),
313  },
314  digest));
315  mask += digest_wordlen;
316  mask_len -= digest_bytelen;
317  }
318  HARDENED_CHECK_LE(mask_len, digest_bytelen);
319 
320  // Last iteration is special; use an intermediate buffer in case the digest
321  // is longer than the remaining mask buffer.
322  uint32_t ctr = __builtin_bswap32(num_iterations - 1);
323  memcpy(hash_input + seed_len, &ctr, sizeof(uint32_t));
324  uint32_t digest_data[digest_wordlen];
325  otcrypto_hash_digest_t digest = {
326  .data = digest_data, .len = digest_wordlen, .mode = hash_mode};
327  HARDENED_TRY(
328  otcrypto_hash((otcrypto_const_byte_buf_t){.data = hash_input,
329  .len = sizeof(hash_input)},
330  digest));
331  hardened_memcpy(mask, digest_data, ceil_div(mask_len, sizeof(uint32_t)));
332  return OTCRYPTO_OK;
333 }
334 
/**
 * Reverse the byte-order of a word array in-place.
 *
 * Viewing the array as one contiguous byte string, this reverses that string:
 * words trade places end-for-end and each word is also byte-swapped.
 *
 * @param input_len Length of input in 32-bit words.
 * @param[in,out] input Input array, modified in-place.
 */
static inline void reverse_bytes(size_t input_len, uint32_t *input) {
  if (input_len == 0) {
    return;
  }
  size_t front = 0;
  size_t back = input_len - 1;
  while (front < back) {
    uint32_t front_word = input[front];
    input[front] = __builtin_bswap32(input[back]);
    input[back] = __builtin_bswap32(front_word);
    front++;
    back--;
  }
  // With an odd word count the middle word stays in place; it only needs its
  // own bytes swapped.
  if (front == back) {
    input[front] = __builtin_bswap32(input[front]);
  }
}
349 
350 /**
351  * Helper function to construct the "H" value for PSS encoding.
352  *
353  * As described in RFC 8017, H = Hash(0x0000000000000000 || digest || salt).
354  * This value needs to be computed for both encryption and decryption. The hash
355  * function should match the hash function from the message digest, so the
356  * caller is responsible for ensuring that there is enough space in `h` to hold
357  * another digest of the same type.
358  *
359  * @param message_digest Message digest to encode.
360  * @param salt Salt value.
361  * @param salt_len Length of the salt in 32-bit words.
362  * @param[out] h Resulting digest, H.
363  */
365 static status_t pss_construct_h(const otcrypto_hash_digest_t message_digest,
366  const uint32_t *salt, size_t salt_len,
367  uint32_t *h) {
368  // Create a buffer for M' = (0x0000000000000000 || digest || salt).
369  size_t m_prime_wordlen = 2 + message_digest.len + salt_len;
370  uint32_t m_prime[m_prime_wordlen];
371  m_prime[0] = 0;
372  m_prime[1] = 0;
373  uint32_t *digest_dst = &m_prime[2];
374  uint32_t *salt_dst = digest_dst + message_digest.len;
375  hardened_memcpy(digest_dst, message_digest.data, message_digest.len);
376  if (salt_len > 0) {
377  hardened_memcpy(salt_dst, salt, salt_len);
378  }
379 
380  // Construct H = Hash(M').
381  otcrypto_hash_digest_t h_buffer = {
382  .data = h, .len = message_digest.len, .mode = message_digest.mode};
383  return otcrypto_hash(
384  (otcrypto_const_byte_buf_t){.data = (unsigned char *)m_prime,
385  .len = sizeof(m_prime)},
386  h_buffer);
387 }
388 
389 status_t rsa_padding_pss_encode(const otcrypto_hash_digest_t message_digest,
390  const uint32_t *salt, size_t salt_len,
391  size_t encoded_message_len,
392  uint32_t *encoded_message) {
393  // Check that the message length is long enough.
394  size_t digest_bytelen = message_digest.len * sizeof(uint32_t);
395  size_t salt_bytelen = salt_len * sizeof(uint32_t);
396  size_t encoded_message_bytelen = encoded_message_len * sizeof(uint32_t);
397  if (encoded_message_bytelen < salt_bytelen + digest_bytelen + 2) {
398  return OTCRYPTO_BAD_ARGS;
399  }
400 
401  // Construct H = Hash(0x0000000000000000 || digest || salt).
402  uint32_t h[message_digest.len];
403  HARDENED_TRY(pss_construct_h(message_digest, salt, salt_len, h));
404 
405  // Construct DB = 00...00 || 0x01 || salt.
406  size_t db_bytelen = encoded_message_bytelen - digest_bytelen - 1;
407  uint32_t db[ceil_div(db_bytelen, sizeof(uint32_t))];
408  memset(db, 0, sizeof(db));
409  unsigned char *db_bytes = (unsigned char *)db;
410  db_bytes[db_bytelen - 1 - salt_bytelen] = 0x01;
411  if (salt_bytelen > 0) {
412  memcpy(db_bytes + (db_bytelen - salt_bytelen), salt, salt_bytelen);
413  }
414 
415  // Compute the mask.
416  uint32_t mask[ARRAYSIZE(db)];
417  HARDENED_TRY(mgf1(message_digest.mode, (unsigned char *)h, sizeof(h),
418  db_bytelen, mask));
419 
420  // Compute maskedDB = DB ^ mask.
421  for (size_t i = 0; i < ARRAYSIZE(db); i++) {
422  db[i] ^= mask[i];
423  }
424 
425  // Set the most significant bit of the first byte of maskedDB to 0. This
426  // ensures the encoded message is less than the modulus. Corresponds to RFC
427  // 8017, section 9.1.1, step 11 (where emBits is modLen - 1).
428  db_bytes[0] &= 0x7f;
429 
430  // Compute the final encoded message and reverse the byte-order.
431  // EM = maskedDB || H || 0xbc
432  unsigned char *encoded_message_bytes = (unsigned char *)encoded_message;
433  hardened_memcpy(encoded_message, db, ARRAYSIZE(db));
434  memcpy(encoded_message_bytes + db_bytelen, h, sizeof(h));
435  encoded_message_bytes[encoded_message_bytelen - 1] = 0xbc;
436  reverse_bytes(encoded_message_len, encoded_message);
437  return OTCRYPTO_OK;
438 }
439 
440 status_t rsa_padding_pss_verify(const otcrypto_hash_digest_t message_digest,
441  uint32_t *encoded_message,
442  size_t encoded_message_len,
443  hardened_bool_t *result) {
444  // Initialize the result to false.
445  *result = kHardenedBoolFalse;
446 
447  // Check that the message length is long enough.
448  size_t digest_bytelen = message_digest.len * sizeof(uint32_t);
449  size_t salt_bytelen = digest_bytelen;
450  size_t encoded_message_bytelen = encoded_message_len * sizeof(uint32_t);
451  if (encoded_message_bytelen < salt_bytelen + digest_bytelen + 2) {
452  return OTCRYPTO_BAD_ARGS;
453  }
454 
455  // Reverse the byte-order.
456  reverse_bytes(encoded_message_len, encoded_message);
457 
458  // Check the last byte.
459  unsigned char *encoded_message_bytes = (unsigned char *)encoded_message;
460  if (encoded_message_bytes[encoded_message_bytelen - 1] != 0xbc) {
461  *result = kHardenedBoolFalse;
462  return OTCRYPTO_OK;
463  }
464 
465  // Extract the masked "DB" value. Zero the last bytes if needed.
466  size_t db_bytelen = encoded_message_bytelen - digest_bytelen - 1;
467  uint32_t db[ceil_div(db_bytelen, sizeof(uint32_t))];
468  memcpy(db, encoded_message_bytes, db_bytelen);
469  if (sizeof(db) > db_bytelen) {
470  memset(((unsigned char *)db) + db_bytelen, 0, sizeof(db) - db_bytelen);
471  }
472 
473  // Extract H.
474  uint32_t h[message_digest.len];
475  memcpy(h, encoded_message_bytes + db_bytelen, sizeof(h));
476 
477  // Compute the mask = MFG(H, emLen - hLen - 1). Zero the last bytes if
478  // needed.
479  uint32_t mask[ARRAYSIZE(db)];
480  HARDENED_TRY(mgf1(message_digest.mode, (unsigned char *)h, sizeof(h),
481  db_bytelen, mask));
482  if (sizeof(mask) > db_bytelen) {
483  memset(((unsigned char *)mask) + db_bytelen, 0, sizeof(mask) - db_bytelen);
484  }
485 
486  // Unmask the "DB" value.
487  for (size_t i = 0; i < ARRAYSIZE(db); i++) {
488  db[i] ^= mask[i];
489  }
490 
491  // Set the most significant bit of the first byte of maskedDB to 0.
492  // Corresponds to RFC 8017, section 9.1.2 step 9 (emBits is modLen - 1).
493  unsigned char *db_bytes = (unsigned char *)db;
494  db_bytes[0] &= 0x7f;
495 
496  // Check that DB starts with all zeroes followed by a single 1 byte. Copy in
497  // enough trailing bytes to fill the last word, so that we can use
498  // `hardened_memeq` here.
499  size_t padding_bytelen = db_bytelen - salt_bytelen;
500  uint32_t exp_padding[ceil_div(padding_bytelen, sizeof(uint32_t))];
501  unsigned char *exp_padding_bytes = (unsigned char *)exp_padding;
502  memset(exp_padding, 0, padding_bytelen - 1);
503  exp_padding_bytes[padding_bytelen - 1] = 0x01;
504  memcpy(exp_padding_bytes + padding_bytelen, db_bytes + padding_bytelen,
505  sizeof(exp_padding) - padding_bytelen);
506  hardened_bool_t padding_eq =
507  hardened_memeq(db, exp_padding, ARRAYSIZE(exp_padding));
508  if (padding_eq != kHardenedBoolTrue) {
509  *result = kHardenedBoolFalse;
510  return OTCRYPTO_OK;
511  }
512 
513  // Extract the salt.
514  uint32_t salt[message_digest.len];
515  memcpy(salt, db_bytes + db_bytelen - salt_bytelen, sizeof(salt));
516 
517  // Construct the expected value of H and compare.
518  uint32_t exp_h[message_digest.len];
519  HARDENED_TRY(pss_construct_h(message_digest, salt, ARRAYSIZE(salt), exp_h));
520  *result = hardened_memeq(h, exp_h, ARRAYSIZE(exp_h));
521  return OTCRYPTO_OK;
522 }
523 
524 status_t rsa_padding_oaep_max_message_bytelen(
525  const otcrypto_hash_mode_t hash_mode, size_t rsa_wordlen,
526  size_t *max_message_bytelen) {
527  // Get the hash digest length for the given hash function (and check that it
528  // is one of the supported hash functions).
529  size_t digest_wordlen = 0;
530  HARDENED_TRY(digest_wordlen_get(hash_mode, &digest_wordlen));
531 
532  size_t digest_bytelen = digest_wordlen * sizeof(uint32_t);
533  size_t rsa_bytelen = rsa_wordlen * sizeof(uint32_t);
534  if (2 * digest_bytelen + 2 > rsa_bytelen) {
535  // This case would cause underflow if we continue; return an error.
536  return OTCRYPTO_BAD_ARGS;
537  }
538 
539  *max_message_bytelen = rsa_bytelen - 2 * digest_bytelen - 2;
540  return OTCRYPTO_OK;
541 }
542 
543 status_t rsa_padding_oaep_encode(const otcrypto_hash_mode_t hash_mode,
544  const uint8_t *message, size_t message_bytelen,
545  const uint8_t *label, size_t label_bytelen,
546  size_t encoded_message_len,
547  uint32_t *encoded_message) {
548  // Check that the message is not too long (RFC 8017, section 7.1.1, step 1a).
549  size_t max_message_bytelen = 0;
550  HARDENED_TRY(rsa_padding_oaep_max_message_bytelen(
551  hash_mode, encoded_message_len, &max_message_bytelen));
552  if (message_bytelen > max_message_bytelen) {
553  return OTCRYPTO_BAD_ARGS;
554  }
555 
556  // Get the hash digest length for the given hash function (and check that it
557  // is one of the supported hash functions).
558  size_t digest_wordlen = 0;
559  HARDENED_TRY(digest_wordlen_get(hash_mode, &digest_wordlen));
560 
561  // Hash the label (step 2a).
562  otcrypto_const_byte_buf_t label_buf = {
563  .data = label,
564  .len = label_bytelen,
565  };
566  uint32_t lhash_data[digest_wordlen];
567  otcrypto_hash_digest_t lhash = {
568  .data = lhash_data,
569  .len = ARRAYSIZE(lhash_data),
570  .mode = hash_mode,
571  };
572  HARDENED_TRY(otcrypto_hash(label_buf, lhash));
573 
574  // Generate a random string the same length as a hash digest (step 2d).
575  uint32_t seed[digest_wordlen];
576  HARDENED_TRY(entropy_complex_check());
577  HARDENED_TRY(entropy_csrng_instantiate(
578  /*disable_trng_input=*/kHardenedBoolFalse, &kEntropyEmptySeed));
579  HARDENED_TRY(entropy_csrng_generate(&kEntropyEmptySeed, seed, ARRAYSIZE(seed),
580  /*fips_check=*/kHardenedBoolTrue));
581  HARDENED_TRY(entropy_csrng_uninstantiate());
582 
583  // Generate dbMask = MGF(seed, k - hLen - 1) (step 2e).
584  size_t digest_bytelen = digest_wordlen * sizeof(uint32_t);
585  size_t encoded_message_bytelen = encoded_message_len * sizeof(uint32_t);
586  size_t db_bytelen = encoded_message_bytelen - digest_bytelen - 1;
587  size_t db_wordlen = ceil_div(db_bytelen, sizeof(uint32_t));
588  uint32_t db[db_wordlen];
589  HARDENED_TRY(
590  mgf1(hash_mode, (unsigned char *)seed, sizeof(seed), db_bytelen, db));
591 
592  // Construct maskedDB = dbMask XOR (lhash || PS || 0x01 || M), where PS is
593  // all-zero (step 2f). By computing the mask first, we can simply XOR with
594  // lhash, 0x01, and M, skipping PS because XOR with zero is the identity
595  // function.
596  for (size_t i = 0; i < ARRAYSIZE(lhash_data); i++) {
597  db[i] ^= lhash_data[i];
598  }
599  size_t message_start_idx = db_bytelen - message_bytelen;
600  unsigned char *db_bytes = (unsigned char *)db;
601  db_bytes[message_start_idx - 1] ^= 0x01;
602  for (size_t i = 0; i < message_bytelen; i++) {
603  db_bytes[message_start_idx + i] ^= message[i];
604  }
605 
606  // Compute seedMask = MGF(maskedDB, hLen) (step 2g).
607  uint32_t seed_mask[digest_wordlen];
608  HARDENED_TRY(mgf1(hash_mode, (unsigned char *)db, db_bytelen, digest_bytelen,
609  seed_mask));
610 
611  // Construct maskedSeed = seed XOR seedMask (step 2h).
612  for (size_t i = 0; i < ARRAYSIZE(seed); i++) {
613  seed[i] ^= seed_mask[i];
614  }
615 
616  // Construct EM = 0x00 || maskedSeed || maskedDB (step 2i).
617  unsigned char *encoded_message_bytes = (unsigned char *)encoded_message;
618  encoded_message_bytes[0] = 0x00;
619  memcpy(encoded_message_bytes + 1, seed, sizeof(seed));
620  memcpy(encoded_message_bytes + 1 + sizeof(seed), db, sizeof(db));
621 
622  // Reverse the byte-order.
623  reverse_bytes(encoded_message_len, encoded_message);
624  return OTCRYPTO_OK;
625 }
626 
627 status_t rsa_padding_oaep_decode(const otcrypto_hash_mode_t hash_mode,
628  const uint8_t *label, size_t label_bytelen,
629  uint32_t *encoded_message,
630  size_t encoded_message_len, uint8_t *message,
631  size_t *message_bytelen) {
632  // Reverse the byte-order.
633  reverse_bytes(encoded_message_len, encoded_message);
634  *message_bytelen = 0;
635 
636  // Get the hash digest length for the given hash function (and check that it
637  // is one of the supported hash functions).
638  size_t digest_wordlen = 0;
639  HARDENED_TRY(digest_wordlen_get(hash_mode, &digest_wordlen));
640 
641  // Extract maskedSeed from the encoded message (RFC 8017, section 7.1.2, step
642  // 3b).
643  uint32_t seed[digest_wordlen];
644  unsigned char *encoded_message_bytes = (unsigned char *)encoded_message;
645  memcpy(seed, encoded_message_bytes + 1, sizeof(seed));
646 
647  // Extract maskedDB from the encoded message (RFC 8017, section 7.1.2, step
648  // 3b).
649  size_t digest_bytelen = digest_wordlen * sizeof(uint32_t);
650  size_t encoded_message_bytelen = encoded_message_len * sizeof(uint32_t);
651  size_t db_bytelen = encoded_message_bytelen - digest_bytelen - 1;
652  size_t db_wordlen = ceil_div(db_bytelen, sizeof(uint32_t));
653  uint32_t db[db_wordlen];
654  memcpy(db, encoded_message_bytes + 1 + sizeof(seed), db_bytelen);
655 
656  // Compute seedMask = MGF(maskedDB, hLen) (step 3c).
657  uint32_t seed_mask[digest_wordlen];
658  HARDENED_TRY(mgf1(hash_mode, (unsigned char *)db, db_bytelen, digest_bytelen,
659  seed_mask));
660 
661  // Construct seed = maskedSeed XOR seedMask (step 3d).
662  for (size_t i = 0; i < ARRAYSIZE(seed); i++) {
663  seed[i] ^= seed_mask[i];
664  }
665 
666  // Generate dbMask = MGF(seed, k - hLen - 1) (step 3e).
667  uint32_t db_mask[db_wordlen];
668  HARDENED_TRY(mgf1(hash_mode, (unsigned char *)seed, sizeof(seed), db_bytelen,
669  db_mask));
670 
671  // Zero trailing bytes of DB and dbMask if needed.
672  size_t num_trailing_bytes = sizeof(db) - db_bytelen;
673  if (num_trailing_bytes > 0) {
674  memset(((unsigned char *)db) + db_bytelen, 0, num_trailing_bytes);
675  memset(((unsigned char *)db_mask) + db_bytelen, 0, num_trailing_bytes);
676  }
677 
678  // Construct DB = dbMask XOR maskedDB.
679  for (size_t i = 0; i < ARRAYSIZE(db); i++) {
680  db[i] ^= db_mask[i];
681  }
682 
683  // Hash the label (step 3a).
684  otcrypto_const_byte_buf_t label_buf = {
685  .data = label,
686  .len = label_bytelen,
687  };
688  uint32_t lhash_data[digest_wordlen];
689  otcrypto_hash_digest_t lhash = {
690  .data = lhash_data,
691  .len = digest_wordlen,
692  .mode = hash_mode,
693  };
694  HARDENED_TRY(otcrypto_hash(label_buf, lhash));
695 
696  // Note: as we compare parts of the encoded message to their expected values,
697  // we must be careful that the attacker cannot differentiate error codes or
698  // get partial information about the encoded message. See the note in RCC
699  // 8017, section 7.1.2. This implementation currently protects against
700  // revealing this information through error codes or timing, but does not yet
701  // defend against power side channels.
702 
703  // Locate the start of the message in DB = lhash || 0x00..0x00 || 0x01 || M
704  // by searching for the 0x01 byte in constant time.
705  unsigned char *db_bytes = (unsigned char *)db;
706  uint32_t message_start_idx = 0;
707  ct_bool32_t decode_failure = 0;
708  for (size_t i = digest_bytelen; i < db_bytelen; i++) {
709  uint32_t byte = 0;
710  memcpy(&byte, db_bytes + i, 1);
711  ct_bool32_t is_one = ct_seq32(byte, 0x01);
712  ct_bool32_t is_before_message = ct_seqz32(message_start_idx);
713  ct_bool32_t is_message_start = is_one & is_before_message;
714  message_start_idx = ct_cmov32(is_message_start, i + 1, message_start_idx);
715  ct_bool32_t is_zero = ct_seqz32(byte);
716  ct_bool32_t padding_failure = is_before_message & (~is_zero) & (~is_one);
717  decode_failure |= padding_failure;
718  }
719  HARDENED_CHECK_LE(message_start_idx, db_bytelen);
720 
721  // If we never found a message start index, we should fail. However, don't
722  // fail yet to avoid leaking timing information.
723  ct_bool32_t message_start_not_found = ct_seqz32(message_start_idx);
724  decode_failure |= message_start_not_found;
725 
726  // Check that the first part of DB is equal to lhash.
727  hardened_bool_t lhash_matches =
728  hardened_memeq(lhash_data, db, digest_wordlen);
729  ct_bool32_t lhash_match = ct_seq32(lhash_matches, kHardenedBoolTrue);
730  ct_bool32_t lhash_mismatch = ~lhash_match;
731  decode_failure |= lhash_mismatch;
732 
733  // Check that the leading byte is 0.
734  uint32_t leading_byte = 0;
735  memcpy(&leading_byte, encoded_message_bytes, 1);
736  ct_bool32_t leading_byte_nonzero = ~ct_seqz32(leading_byte);
737  decode_failure |= leading_byte_nonzero;
738 
739  // Now, decode_failure is all-zero if the decode succeeded and all-one if the
740  // decode failed.
741  if (launder32(decode_failure) != 0) {
742  return OTCRYPTO_BAD_ARGS;
743  }
744  HARDENED_CHECK_EQ(decode_failure, 0);
745 
746  // TODO: re-check the padding as an FI hardening measure?
747 
748  // If we get here, then the encoded message has a proper format and it is
749  // safe to copy the message into the output buffer.
750  *message_bytelen = db_bytelen - message_start_idx;
751  if (*message_bytelen > 0) {
752  memcpy(message, db_bytes + message_start_idx, *message_bytelen);
753  }
754  return OTCRYPTO_OK;
755 }