5 #include "sw/device/lib/crypto/impl/rsa/rsa_padding.h"
10 #include "sw/device/lib/crypto/drivers/entropy.h"
11 #include "sw/device/lib/crypto/drivers/kmac.h"
15 #define MODULE_ID MAKE_MODULE_ID('r', 'p', 'a')
22 static const uint8_t kSha256DigestIdentifier[] = {
23 0x20, 0x04, 0x00, 0x05, 0x01, 0x02, 0x04, 0x03, 0x65, 0x01,
24 0x48, 0x86, 0x60, 0x09, 0x06, 0x0d, 0x30, 0x31, 0x30,
26 static const uint8_t kSha384DigestIdentifier[] = {
27 0x30, 0x04, 0x00, 0x05, 0x02, 0x02, 0x04, 0x03, 0x65, 0x01,
28 0x48, 0x86, 0x60, 0x09, 0x06, 0x0d, 0x30, 0x41, 0x30,
30 static const uint8_t kSha512DigestIdentifier[] = {
31 0x40, 0x04, 0x00, 0x05, 0x03, 0x02, 0x04, 0x03, 0x65, 0x01,
32 0x48, 0x86, 0x60, 0x09, 0x06, 0x0d, 0x30, 0x51, 0x30,
39 static const uint8_t kSha3_224DigestIdentifier[] = {
40 0x1c, 0x04, 0x00, 0x05, 0x07, 0x02, 0x04, 0x03, 0x65, 0x01,
41 0x48, 0x86, 0x60, 0x09, 0x06, 0x0d, 0x30, 0x2d, 0x30,
43 static const uint8_t kSha3_256DigestIdentifier[] = {
44 0x20, 0x04, 0x00, 0x05, 0x08, 0x02, 0x04, 0x03, 0x65, 0x01,
45 0x48, 0x86, 0x60, 0x09, 0x06, 0x0d, 0x30, 0x31, 0x30,
47 static const uint8_t kSha3_384DigestIdentifier[] = {
48 0x30, 0x04, 0x00, 0x05, 0x09, 0x02, 0x04, 0x03, 0x65, 0x01,
49 0x48, 0x86, 0x60, 0x09, 0x06, 0x0d, 0x30, 0x41, 0x30,
51 static const uint8_t kSha3_512DigestIdentifier[] = {
52 0x40, 0x04, 0x00, 0x05, 0x0a, 0x02, 0x04, 0x03, 0x65, 0x01,
53 0x48, 0x86, 0x60, 0x09, 0x06, 0x0d, 0x30, 0x51, 0x30,
70 case kOtcryptoHashModeSha256:
71 *len =
sizeof(kSha256DigestIdentifier) + kSha256DigestBytes;
73 case kOtcryptoHashModeSha384:
74 *len =
sizeof(kSha384DigestIdentifier) + kSha384DigestBytes;
76 case kOtcryptoHashModeSha512:
77 *len =
sizeof(kSha512DigestIdentifier) + kSha512DigestBytes;
79 case kOtcryptoHashModeSha3_224:
80 *len =
sizeof(kSha3_224DigestIdentifier) + kSha3_224DigestBytes;
82 case kOtcryptoHashModeSha3_256:
83 *len =
sizeof(kSha3_256DigestIdentifier) + kSha3_256DigestBytes;
85 case kOtcryptoHashModeSha3_384:
86 *len =
sizeof(kSha3_384DigestIdentifier) + kSha3_384DigestBytes;
88 case kOtcryptoHashModeSha3_512:
89 *len =
sizeof(kSha512DigestIdentifier) + kSha3_512DigestBytes;
93 return OTCRYPTO_BAD_ARGS;
98 return OTCRYPTO_FATAL_ERR;
118 uint32_t *encoding) {
119 switch (message_digest.mode) {
120 case kOtcryptoHashModeSha256:
121 if (message_digest.len != kSha256DigestWords) {
122 return OTCRYPTO_BAD_ARGS;
124 memcpy(encoding + kSha256DigestWords, &kSha256DigestIdentifier,
125 sizeof(kSha256DigestIdentifier));
127 case kOtcryptoHashModeSha384:
128 if (message_digest.len != kSha384DigestWords) {
129 return OTCRYPTO_BAD_ARGS;
131 memcpy(encoding + kSha384DigestWords, &kSha384DigestIdentifier,
132 sizeof(kSha384DigestIdentifier));
134 case kOtcryptoHashModeSha512:
135 if (message_digest.len != kSha512DigestWords) {
136 return OTCRYPTO_BAD_ARGS;
138 memcpy(encoding + kSha512DigestWords, &kSha512DigestIdentifier,
139 sizeof(kSha512DigestIdentifier));
141 case kOtcryptoHashModeSha3_224:
142 if (message_digest.len != kSha3_224DigestWords) {
143 return OTCRYPTO_BAD_ARGS;
145 memcpy(encoding + kSha3_224DigestWords, &kSha3_224DigestIdentifier,
146 sizeof(kSha3_224DigestIdentifier));
148 case kOtcryptoHashModeSha3_256:
149 if (message_digest.len != kSha3_256DigestWords) {
150 return OTCRYPTO_BAD_ARGS;
152 memcpy(encoding + kSha3_256DigestWords, &kSha3_256DigestIdentifier,
153 sizeof(kSha3_256DigestIdentifier));
155 case kOtcryptoHashModeSha3_384:
156 if (message_digest.len != kSha3_384DigestWords) {
157 return OTCRYPTO_BAD_ARGS;
159 memcpy(encoding + kSha3_384DigestWords, &kSha3_384DigestIdentifier,
160 sizeof(kSha3_384DigestIdentifier));
162 case kOtcryptoHashModeSha3_512:
163 if (message_digest.len != kSha3_512DigestWords) {
164 return OTCRYPTO_BAD_ARGS;
166 memcpy(encoding + kSha3_512DigestWords, &kSha3_512DigestIdentifier,
167 sizeof(kSha3_512DigestIdentifier));
171 return OTCRYPTO_BAD_ARGS;
175 for (
size_t i = 0; i <
ceil_div(message_digest.len, 2); i++) {
177 __builtin_bswap32(message_digest.data[message_digest.len - 1 - i]);
178 encoding[message_digest.len - 1 - i] =
179 __builtin_bswap32(message_digest.data[i]);
185 status_t rsa_padding_pkcs1v15_encode(
187 uint32_t *encoded_message) {
189 size_t encoded_message_bytelen = encoded_message_len *
sizeof(uint32_t);
190 memset(encoded_message, 0xff, encoded_message_bytelen);
193 unsigned char *buf = (
unsigned char *)encoded_message;
196 buf[encoded_message_bytelen - 1] = 0x00;
197 buf[encoded_message_bytelen - 2] = 0x01;
201 HARDENED_TRY(digest_info_length_get(message_digest.mode, &tlen));
203 if (tlen + 3 + 8 >= encoded_message_bytelen) {
206 return OTCRYPTO_BAD_ARGS;
209 HARDENED_TRY(digest_info_write(message_digest, encoded_message));
217 status_t rsa_padding_pkcs1v15_verify(
219 const uint32_t *encoded_message,
const size_t encoded_message_len,
222 uint32_t expected_encoded_message[encoded_message_len];
223 HARDENED_TRY(rsa_padding_pkcs1v15_encode(message_digest, encoded_message_len,
224 expected_encoded_message));
226 *result =
hardened_memeq(encoded_message, expected_encoded_message,
246 case kOtcryptoHashModeSha3_224:
247 *num_words = 224 / 32;
249 case kOtcryptoHashModeSha256:
251 case kOtcryptoHashModeSha3_256:
252 *num_words = 256 / 32;
254 case kOtcryptoHashModeSha384:
256 case kOtcryptoHashModeSha3_384:
257 *num_words = 384 / 32;
259 case kOtcryptoHashModeSha512:
261 case kOtcryptoHashModeSha3_512:
262 *num_words = 512 / 32;
265 return OTCRYPTO_BAD_ARGS;
267 HARDENED_CHECK_GT(num_words, 0);
290 size_t seed_len,
size_t mask_len, uint32_t *mask) {
292 size_t digest_wordlen;
293 HARDENED_TRY(digest_wordlen_get(hash_mode, &digest_wordlen));
294 size_t digest_bytelen = digest_wordlen *
sizeof(uint32_t);
295 size_t num_iterations =
ceil_div(mask_len, digest_bytelen);
296 if (num_iterations > UINT32_MAX) {
297 return OTCRYPTO_BAD_ARGS;
302 uint8_t hash_input[seed_len +
sizeof(uint32_t)];
303 memcpy(hash_input, seed, seed_len);
304 for (uint32_t i = 0; i < num_iterations - 1; i++) {
305 uint32_t ctr = __builtin_bswap32(i);
306 memcpy(hash_input + seed_len, &ctr,
sizeof(uint32_t));
308 .data = mask, .len = digest_wordlen, .mode = hash_mode};
312 .len =
sizeof(hash_input),
315 mask += digest_wordlen;
316 mask_len -= digest_bytelen;
318 HARDENED_CHECK_LE(mask_len, digest_bytelen);
322 uint32_t ctr = __builtin_bswap32(num_iterations - 1);
323 memcpy(hash_input + seed_len, &ctr,
sizeof(uint32_t));
324 uint32_t digest_data[digest_wordlen];
326 .data = digest_data, .len = digest_wordlen, .mode = hash_mode};
329 .len =
sizeof(hash_input)},
341 static inline void reverse_bytes(
size_t input_len, uint32_t *input) {
342 for (
size_t i = 0; i < (input_len + 1) / 2; i++) {
343 size_t j = input_len - 1 - i;
344 uint32_t tmp = input[j];
345 input[j] = __builtin_bswap32(input[i]);
346 input[i] = __builtin_bswap32(tmp);
366 const uint32_t *salt,
size_t salt_len,
369 size_t m_prime_wordlen = 2 + message_digest.len + salt_len;
370 uint32_t m_prime[m_prime_wordlen];
373 uint32_t *digest_dst = &m_prime[2];
374 uint32_t *salt_dst = digest_dst + message_digest.len;
382 .data = h, .len = message_digest.len, .mode = message_digest.mode};
385 .len =
sizeof(m_prime)},
390 const uint32_t *salt,
size_t salt_len,
391 size_t encoded_message_len,
392 uint32_t *encoded_message) {
394 size_t digest_bytelen = message_digest.len *
sizeof(uint32_t);
395 size_t salt_bytelen = salt_len *
sizeof(uint32_t);
396 size_t encoded_message_bytelen = encoded_message_len *
sizeof(uint32_t);
397 if (encoded_message_bytelen < salt_bytelen + digest_bytelen + 2) {
398 return OTCRYPTO_BAD_ARGS;
402 uint32_t h[message_digest.len];
403 HARDENED_TRY(pss_construct_h(message_digest, salt, salt_len, h));
406 size_t db_bytelen = encoded_message_bytelen - digest_bytelen - 1;
407 uint32_t db[
ceil_div(db_bytelen,
sizeof(uint32_t))];
408 memset(db, 0,
sizeof(db));
409 unsigned char *db_bytes = (
unsigned char *)db;
410 db_bytes[db_bytelen - 1 - salt_bytelen] = 0x01;
411 if (salt_bytelen > 0) {
412 memcpy(db_bytes + (db_bytelen - salt_bytelen), salt, salt_bytelen);
417 HARDENED_TRY(mgf1(message_digest.mode, (
unsigned char *)h,
sizeof(h),
421 for (
size_t i = 0; i <
ARRAYSIZE(db); i++) {
432 unsigned char *encoded_message_bytes = (
unsigned char *)encoded_message;
434 memcpy(encoded_message_bytes + db_bytelen, h,
sizeof(h));
435 encoded_message_bytes[encoded_message_bytelen - 1] = 0xbc;
436 reverse_bytes(encoded_message_len, encoded_message);
441 uint32_t *encoded_message,
442 size_t encoded_message_len,
448 size_t digest_bytelen = message_digest.len *
sizeof(uint32_t);
449 size_t salt_bytelen = digest_bytelen;
450 size_t encoded_message_bytelen = encoded_message_len *
sizeof(uint32_t);
451 if (encoded_message_bytelen < salt_bytelen + digest_bytelen + 2) {
452 return OTCRYPTO_BAD_ARGS;
456 reverse_bytes(encoded_message_len, encoded_message);
459 unsigned char *encoded_message_bytes = (
unsigned char *)encoded_message;
460 if (encoded_message_bytes[encoded_message_bytelen - 1] != 0xbc) {
466 size_t db_bytelen = encoded_message_bytelen - digest_bytelen - 1;
467 uint32_t db[
ceil_div(db_bytelen,
sizeof(uint32_t))];
468 memcpy(db, encoded_message_bytes, db_bytelen);
469 if (
sizeof(db) > db_bytelen) {
470 memset(((
unsigned char *)db) + db_bytelen, 0,
sizeof(db) - db_bytelen);
474 uint32_t h[message_digest.len];
475 memcpy(h, encoded_message_bytes + db_bytelen,
sizeof(h));
480 HARDENED_TRY(mgf1(message_digest.mode, (
unsigned char *)h,
sizeof(h),
482 if (
sizeof(mask) > db_bytelen) {
483 memset(((
unsigned char *)mask) + db_bytelen, 0,
sizeof(mask) - db_bytelen);
487 for (
size_t i = 0; i <
ARRAYSIZE(db); i++) {
493 unsigned char *db_bytes = (
unsigned char *)db;
499 size_t padding_bytelen = db_bytelen - salt_bytelen;
500 uint32_t exp_padding[
ceil_div(padding_bytelen,
sizeof(uint32_t))];
501 unsigned char *exp_padding_bytes = (
unsigned char *)exp_padding;
502 memset(exp_padding, 0, padding_bytelen - 1);
503 exp_padding_bytes[padding_bytelen - 1] = 0x01;
504 memcpy(exp_padding_bytes + padding_bytelen, db_bytes + padding_bytelen,
505 sizeof(exp_padding) - padding_bytelen);
514 uint32_t salt[message_digest.len];
515 memcpy(salt, db_bytes + db_bytelen - salt_bytelen,
sizeof(salt));
518 uint32_t exp_h[message_digest.len];
519 HARDENED_TRY(pss_construct_h(message_digest, salt,
ARRAYSIZE(salt), exp_h));
524 status_t rsa_padding_oaep_max_message_bytelen(
526 size_t *max_message_bytelen) {
529 size_t digest_wordlen = 0;
530 HARDENED_TRY(digest_wordlen_get(hash_mode, &digest_wordlen));
532 size_t digest_bytelen = digest_wordlen *
sizeof(uint32_t);
533 size_t rsa_bytelen = rsa_wordlen *
sizeof(uint32_t);
534 if (2 * digest_bytelen + 2 > rsa_bytelen) {
536 return OTCRYPTO_BAD_ARGS;
539 *max_message_bytelen = rsa_bytelen - 2 * digest_bytelen - 2;
544 const uint8_t *message,
size_t message_bytelen,
545 const uint8_t *label,
size_t label_bytelen,
546 size_t encoded_message_len,
547 uint32_t *encoded_message) {
549 size_t max_message_bytelen = 0;
550 HARDENED_TRY(rsa_padding_oaep_max_message_bytelen(
551 hash_mode, encoded_message_len, &max_message_bytelen));
552 if (message_bytelen > max_message_bytelen) {
553 return OTCRYPTO_BAD_ARGS;
558 size_t digest_wordlen = 0;
559 HARDENED_TRY(digest_wordlen_get(hash_mode, &digest_wordlen));
564 .len = label_bytelen,
566 uint32_t lhash_data[digest_wordlen];
575 uint32_t seed[digest_wordlen];
576 HARDENED_TRY(entropy_complex_check());
577 HARDENED_TRY(entropy_csrng_instantiate(
579 HARDENED_TRY(entropy_csrng_generate(&kEntropyEmptySeed, seed,
ARRAYSIZE(seed),
581 HARDENED_TRY(entropy_csrng_uninstantiate());
584 size_t digest_bytelen = digest_wordlen *
sizeof(uint32_t);
585 size_t encoded_message_bytelen = encoded_message_len *
sizeof(uint32_t);
586 size_t db_bytelen = encoded_message_bytelen - digest_bytelen - 1;
587 size_t db_wordlen =
ceil_div(db_bytelen,
sizeof(uint32_t));
588 uint32_t db[db_wordlen];
590 mgf1(hash_mode, (
unsigned char *)seed,
sizeof(seed), db_bytelen, db));
596 for (
size_t i = 0; i <
ARRAYSIZE(lhash_data); i++) {
597 db[i] ^= lhash_data[i];
599 size_t message_start_idx = db_bytelen - message_bytelen;
600 unsigned char *db_bytes = (
unsigned char *)db;
601 db_bytes[message_start_idx - 1] ^= 0x01;
602 for (
size_t i = 0; i < message_bytelen; i++) {
603 db_bytes[message_start_idx + i] ^= message[i];
607 uint32_t seed_mask[digest_wordlen];
608 HARDENED_TRY(mgf1(hash_mode, (
unsigned char *)db, db_bytelen, digest_bytelen,
612 for (
size_t i = 0; i <
ARRAYSIZE(seed); i++) {
613 seed[i] ^= seed_mask[i];
617 unsigned char *encoded_message_bytes = (
unsigned char *)encoded_message;
618 encoded_message_bytes[0] = 0x00;
619 memcpy(encoded_message_bytes + 1, seed,
sizeof(seed));
620 memcpy(encoded_message_bytes + 1 +
sizeof(seed), db,
sizeof(db));
623 reverse_bytes(encoded_message_len, encoded_message);
628 const uint8_t *label,
size_t label_bytelen,
629 uint32_t *encoded_message,
630 size_t encoded_message_len, uint8_t *message,
631 size_t *message_bytelen) {
633 reverse_bytes(encoded_message_len, encoded_message);
634 *message_bytelen = 0;
638 size_t digest_wordlen = 0;
639 HARDENED_TRY(digest_wordlen_get(hash_mode, &digest_wordlen));
643 uint32_t seed[digest_wordlen];
644 unsigned char *encoded_message_bytes = (
unsigned char *)encoded_message;
645 memcpy(seed, encoded_message_bytes + 1,
sizeof(seed));
649 size_t digest_bytelen = digest_wordlen *
sizeof(uint32_t);
650 size_t encoded_message_bytelen = encoded_message_len *
sizeof(uint32_t);
651 size_t db_bytelen = encoded_message_bytelen - digest_bytelen - 1;
652 size_t db_wordlen =
ceil_div(db_bytelen,
sizeof(uint32_t));
653 uint32_t db[db_wordlen];
654 memcpy(db, encoded_message_bytes + 1 +
sizeof(seed), db_bytelen);
657 uint32_t seed_mask[digest_wordlen];
658 HARDENED_TRY(mgf1(hash_mode, (
unsigned char *)db, db_bytelen, digest_bytelen,
662 for (
size_t i = 0; i <
ARRAYSIZE(seed); i++) {
663 seed[i] ^= seed_mask[i];
667 uint32_t db_mask[db_wordlen];
668 HARDENED_TRY(mgf1(hash_mode, (
unsigned char *)seed,
sizeof(seed), db_bytelen,
672 size_t num_trailing_bytes =
sizeof(db) - db_bytelen;
673 if (num_trailing_bytes > 0) {
674 memset(((
unsigned char *)db) + db_bytelen, 0, num_trailing_bytes);
675 memset(((
unsigned char *)db_mask) + db_bytelen, 0, num_trailing_bytes);
679 for (
size_t i = 0; i <
ARRAYSIZE(db); i++) {
686 .len = label_bytelen,
688 uint32_t lhash_data[digest_wordlen];
691 .len = digest_wordlen,
705 unsigned char *db_bytes = (
unsigned char *)db;
706 uint32_t message_start_idx = 0;
708 for (
size_t i = digest_bytelen; i < db_bytelen; i++) {
710 memcpy(&
byte, db_bytes + i, 1);
713 ct_bool32_t is_message_start = is_one & is_before_message;
714 message_start_idx =
ct_cmov32(is_message_start, i + 1, message_start_idx);
716 ct_bool32_t padding_failure = is_before_message & (~is_zero) & (~is_one);
717 decode_failure |= padding_failure;
719 HARDENED_CHECK_LE(message_start_idx, db_bytelen);
724 decode_failure |= message_start_not_found;
731 decode_failure |= lhash_mismatch;
734 uint32_t leading_byte = 0;
735 memcpy(&leading_byte, encoded_message_bytes, 1);
737 decode_failure |= leading_byte_nonzero;
741 if (launder32(decode_failure) != 0) {
742 return OTCRYPTO_BAD_ARGS;
750 *message_bytelen = db_bytelen - message_start_idx;
751 if (*message_bytelen > 0) {
752 memcpy(message, db_bytes + message_start_idx, *message_bytelen);