8 #include "sw/device/silicon_creator/lib/sigverify/sphincsplus/wots.h"
11 #include "sw/device/silicon_creator/lib/drivers/hmac.h"
12 #include "sw/device/silicon_creator/lib/error.h"
13 #include "sw/device/silicon_creator/lib/sigverify/sphincsplus/address.h"
14 #include "sw/device/silicon_creator/lib/sigverify/sphincsplus/params.h"
15 #include "sw/device/silicon_creator/lib/sigverify/sphincsplus/sha2.h"
16 #include "sw/device/silicon_creator/lib/sigverify/sphincsplus/thash.h"
17 #include "sw/device/silicon_creator/lib/sigverify/sphincsplus/utils.h"
21 static_assert(
sizeof(uint8_t) <= kSpxWotsLogW,
22 "Base-w integers must fit in a `uint8_t`.");
// gen_chain: advances a hash chain from position `start`, hashing the address
// together with the running value held in `out` on every step.
// NOTE(review): only part of this definition is visible in this excerpt (the
// remaining parameters, the initialisation of `out` and the closing braces are
// elided); the comments below describe the visible statements only.
38 static void gen_chain(
const uint32_t *in, uint8_t start,
const spx_ctx_t *ctx,
// Record the starting chain position in the hash address.
45 spx_addr_hash_set(addr, start);
// Runs while `i + 1 < kSpxWotsW`, so a chain that starts at `kSpxWotsW - 1`
// (or above) performs no iterations. `i + 1` is evaluated as `int`, so it
// cannot wrap even when `i` is 255.
46 for (uint8_t i = start; i + 1 < kSpxWotsW; i++) {
// Feed `kSpxSha256AddrBytes` bytes of the address, then `kSpxNWords` words of
// the current chain value, to the SHA-256 driver.
49 hmac_sha256_update((
unsigned char *)addr->addr, kSpxSha256AddrBytes);
50 hmac_sha256_update_words(out, kSpxNWords);
51 hmac_sha256_process();
// The address for the next step is set between `process` and `final`.
// NOTE(review): presumably this overlaps the address update with the hash
// computation -- confirm against the hmac driver.
53 spx_addr_hash_set(addr, i + 1);
// Overwrite `out` with the result; presumably the digest truncated to
// `kSpxNWords` words -- confirm against the hmac driver.
54 hmac_sha256_final_truncated(out, kSpxNWords);
// A byte must split into a whole number of log2(w)-bit digits; otherwise a
// digit would straddle a byte boundary.
73 static_assert(8 % kSpxWotsLogW == 0,
"log2(w) must be a divisor of 8.");
// base_w: writes `out_len` base-w digits derived from the bytes of `input`
// into `output`, one digit per output element.
// NOTE(review): the `output` parameter, the local declarations and the logic
// that decides when a new input byte is loaded are elided from this excerpt.
74 static void base_w(
const uint8_t *input,
const size_t out_len,
// One iteration per output digit.
79 for (
size_t out_idx = 0; out_idx < out_len; out_idx++) {
// Load the next input byte into the working value.
81 total = input[in_idx];
// Extract one digit: shift the working value down by `bits` and mask with
// `kSpxWotsW - 1` (a digit mask assuming `kSpxWotsW` is a power of two).
86 output[out_idx] = (total >> bits) & (kSpxWotsW - 1);
// Each of the `kSpxWotsLen1` terms added to the checksum is at most
// `kSpxWotsW - 1`, so this product bounds the checksum value.
106 static_assert(kSpxWotsLen1 * (kSpxWotsW - 1) <= UINT32_MAX,
107 "WOTS checksum may not fit in a 32-bit integer.");
// wots_checksum: computes the checksum over the `kSpxWotsLen1` base-w digits
// in `msg_base_w` and writes it as `kSpxWotsLen2` base-w digits to
// `csum_base_w`.
// NOTE(review): the declaration of `csum` and the closing braces are elided
// from this excerpt; `csum` is presumably a zero-initialised `uint32_t`.
108 static void wots_checksum(
const uint8_t *msg_base_w, uint8_t *csum_base_w) {
// Accumulate the distance of every digit from the maximum digit value.
111 for (
size_t i = 0; i < kSpxWotsLen1; i++) {
112 csum += kSpxWotsW - 1 - msg_base_w[i];
// Number of bits occupied by the checksum digits.
117 size_t csum_nbits = kSpxWotsLen2 * kSpxWotsLogW;
// Left-align those bits within a 32-bit word; the outer `% 32` makes the
// shift amount 0 (instead of 32) when `csum_nbits` is a multiple of 32.
118 csum <<= ((32 - (csum_nbits % 32)) % 32);
// Reverse the byte order so that `base_w`, which reads the value byte by byte
// in memory order, sees the most significant byte first.
// NOTE(review): this assumes a little-endian target -- confirm.
121 csum = __builtin_bswap32(csum);
// Split the checksum bytes into `kSpxWotsLen2` base-w digits.
122 base_w((
unsigned char *)&csum, kSpxWotsLen2, csum_base_w);
// chain_lengths: fills `lengths` with the base-w digits of `msg` followed by
// the base-w digits of the checksum over those digits.
// `lengths` must have room for `kSpxWotsLen1` message digits plus the
// `kSpxWotsLen2` checksum digits that `wots_checksum` writes.
// NOTE(review): the closing brace is elided from this excerpt.
133 static void chain_lengths(
const uint32_t *msg, uint8_t *lengths) {
// The first `kSpxWotsLen1` entries are the digits of the message bytes.
134 base_w((
unsigned char *)msg, kSpxWotsLen1, lengths);
// The checksum digits are appended directly after the message digits.
135 wots_checksum(lengths, &lengths[kSpxWotsLen1]);
138 static_assert(kSpxWotsLen - 1 <= UINT8_MAX,
139 "Maximum chain value must fit into a `uint8_t`");
// wots_pk_from_sig: derives the chain start positions from `msg` and then
// runs one hash chain per position, reading each chain's input from `sig` and
// writing its result to `pk`.
// NOTE(review): the remaining parameters (`ctx`, `addr` and `pk` are used
// below) and the closing braces are elided from this excerpt.
140 void wots_pk_from_sig(
const uint32_t *sig,
const uint32_t *msg,
// One start position per chain, computed from the message.
142 uint8_t lengths[kSpxWotsLen];
143 chain_lengths(msg, lengths);
// Process every chain; the counter is a `uint8_t`, so `kSpxWotsLen` must be
// small enough for `i < kSpxWotsLen` to eventually become false.
145 for (uint8_t i = 0; i < kSpxWotsLen; i++) {
// Select chain `i` in the hash address.
146 spx_addr_chain_set(addr, i);
// Chain `i` occupies `kSpxNWords` words at the same word offset in both the
// signature and the public key buffer.
147 size_t word_offset = i * kSpxNWords;
// Advance chain `i` from position `lengths[i]`.
148 gen_chain(sig + word_offset, lengths[i], ctx, addr, pk + word_offset);