Software APIs
wots.c
1 // Copyright lowRISC contributors (OpenTitan project).
2 // Licensed under the Apache License, Version 2.0, see LICENSE for details.
3 // SPDX-License-Identifier: Apache-2.0
4 //
5 // Derived from code in the SPHINCS+ reference implementation (CC0 license):
6 // https://github.com/sphincs/sphincsplus/blob/ed15dd78658f63288c7492c00260d86154b84637/ref/wots.h
7 
8 #include "sw/device/silicon_creator/lib/sigverify/sphincsplus/wots.h"
9 
11 #include "sw/device/silicon_creator/lib/drivers/hmac.h"
12 #include "sw/device/silicon_creator/lib/error.h"
13 #include "sw/device/silicon_creator/lib/sigverify/sphincsplus/address.h"
14 #include "sw/device/silicon_creator/lib/sigverify/sphincsplus/params.h"
15 #include "sw/device/silicon_creator/lib/sigverify/sphincsplus/sha2.h"
16 #include "sw/device/silicon_creator/lib/sigverify/sphincsplus/thash.h"
17 #include "sw/device/silicon_creator/lib/sigverify/sphincsplus/utils.h"
18 
19 // Throughout this file, we need to assume that integers in base-w will fit
20 // into a single byte.
21 static_assert(sizeof(uint8_t) <= kSpxWotsLogW,
22  "Base-w integers must fit in a `uint8_t`.");
23 /**
24  * Computes the chaining function.
25  *
26  * Interprets `in` as the value of the chain at index `start`. `addr` must
27  * contain the address of the chain.
28  *
29  * The chain `hash` value that is incremented at each step is stored in a
30  * single byte, so the caller must ensure that `start + steps <= UINT8_MAX`.
31  *
32  * @param in Input buffer (`kSpxN` bytes).
33  * @param start Start index.
34  * @param steps Number of steps.
35  * @param addr Hypertree address.
36  * @param[out] Output buffer (`kSpxNWords` words).
37  */
38 static void gen_chain(const uint32_t *in, uint8_t start, const spx_ctx_t *ctx,
39  spx_addr_t *addr, uint32_t *out) {
40  // Initialize out with the value at position `start`.
41  memcpy(out, in, kSpxN);
42 
43  // Iterate `kSpxWotsW - 1` calls to the hash function. This loop is
44  // performance-critical.
45  spx_addr_hash_set(addr, start);
46  for (uint8_t i = start; i + 1 < kSpxWotsW; i++) {
47  // This loop body is essentially just `thash`, inlined for performance.
48  hmac_sha256_restore(&ctx->state_seeded);
49  hmac_sha256_update((unsigned char *)addr->addr, kSpxSha256AddrBytes);
50  hmac_sha256_update_words(out, kSpxNWords);
51  hmac_sha256_process();
52  // Update the address while HMAC is processing for performance reasons.
53  spx_addr_hash_set(addr, i + 1);
54  hmac_sha256_final_truncated(out, kSpxNWords);
55  }
56 }
57 
58 /**
59  * Interprets an array of bytes as integers in base w.
60  *
61  * The NIST submission describes this operation in detail (section 2.5):
62  * https://sphincs.org/data/sphincs+-r3.1-specification.pdf
63  *
64  * The caller is responsible for ensuring that `input` has at least
65  * `kSpxWotsLogW * out_len` bits available.
66  *
67  * This implementation assumes log2(w) is a divisor of 8 (1, 2, 4, or 8).
68  *
69  * @param input Input buffer.
70  * @param out_len Length of output buffer.
71  * @param[out] output Resulting array of integers.
72  */
73 static_assert(8 % kSpxWotsLogW == 0, "log2(w) must be a divisor of 8.");
74 static void base_w(const uint8_t *input, const size_t out_len,
75  uint8_t *output) {
76  size_t bits = 0;
77  size_t in_idx = 0;
78  uint8_t total;
79  for (size_t out_idx = 0; out_idx < out_len; out_idx++) {
80  if (bits == 0) {
81  total = input[in_idx];
82  in_idx++;
83  bits += 8;
84  }
85  bits -= kSpxWotsLogW;
86  output[out_idx] = (total >> bits) & (kSpxWotsW - 1);
87  }
88 }
89 
90 /**
91  * Computes the WOTS+ checksum over a message (in base-w).
92  *
93  * The length of the checksum is `kSpxWotsLen2` integers in base-w; the caller
94  * must ensure that `csum_base_w` has at least this length.
95  *
96  * This implementation uses a 32-bit integer to store the checksum, which
97  * assumes that the maximum checksum value (len1 * (w - 1)) fits in that range.
98  *
99  * See section 3.1 of the NIST submission for explanation about the WOTS
100  * parameters here (e.g. `kSpxWotsLen2`):
101  * https://sphincs.org/data/sphincs+-r3.1-specification.pdf
102  *
103  * @param msg_base_w Message in base-w.
104  * @param[out] csum_base_w Resulting checksum in base-w.
105  */
106 static_assert(kSpxWotsLen1 * (kSpxWotsW - 1) <= UINT32_MAX,
107  "WOTS checksum may not fit in a 32-bit integer.");
108 static void wots_checksum(const uint8_t *msg_base_w, uint8_t *csum_base_w) {
109  // Compute checksum.
110  uint32_t csum = 0;
111  for (size_t i = 0; i < kSpxWotsLen1; i++) {
112  csum += kSpxWotsW - 1 - msg_base_w[i];
113  }
114 
115  // Make sure any expected empty zero bits are the least significant bits by
116  // shifting csum left.
117  size_t csum_nbits = kSpxWotsLen2 * kSpxWotsLogW;
118  csum <<= ((32 - (csum_nbits % 32)) % 32);
119 
120  // Convert checksum to big-endian bytes and then to base-w.
121  csum = __builtin_bswap32(csum);
122  base_w((unsigned char *)&csum, kSpxWotsLen2, csum_base_w);
123 }
124 
125 /**
126  * Derive the matching chain lengths from a message.
127  *
128  * The `lengths` buffer should be at least `kSpxWotsLen` words long.
129  *
130  * @param msg Input message.
131  * @param[out] lengths Resulting chain lengths.
132  */
133 static void chain_lengths(const uint32_t *msg, uint8_t *lengths) {
134  base_w((unsigned char *)msg, kSpxWotsLen1, lengths);
135  wots_checksum(lengths, &lengths[kSpxWotsLen1]);
136 }
137 
138 static_assert(kSpxWotsLen - 1 <= UINT8_MAX,
139  "Maximum chain value must fit into a `uint8_t`");
140 void wots_pk_from_sig(const uint32_t *sig, const uint32_t *msg,
141  const spx_ctx_t *ctx, spx_addr_t *addr, uint32_t *pk) {
142  uint8_t lengths[kSpxWotsLen];
143  chain_lengths(msg, lengths);
144 
145  for (uint8_t i = 0; i < kSpxWotsLen; i++) {
146  spx_addr_chain_set(addr, i);
147  size_t word_offset = i * kSpxNWords;
148  gen_chain(sig + word_offset, lengths[i], ctx, addr, pk + word_offset);
149  }
150 }