Shaka Packager SDK
Loading...
Searching...
No Matches
rsa_key.cc
1// Copyright 2014 Google LLC. All rights reserved.
2//
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file or at
5// https://developers.google.com/open-source/licenses/bsd
6//
7// RSA signature details:
8// Algorithm: RSASSA-PSS
9// Hash algorithm: SHA1
10// Mask generation function: mgf1SHA1
11// Salt length: 20 bytes
12// Trailer field: 0xbc
13//
14// RSA encryption details:
15// Algorithm: RSA-OAEP
16// Mask generation function: mgf1SHA1
17// Label (encoding parameter): empty string
18
19#include <packager/media/base/rsa_key.h>
20
21#include <memory>
22#include <vector>
23
24#include <absl/log/check.h>
25#include <absl/log/log.h>
26#include <mbedtls/error.h>
27#include <mbedtls/md.h>
28
29namespace {
30
31const size_t kPssSaltLength = 20u;
32
33std::string mbedtls_strerr(int rv) {
34 // There is always a "high level" error.
35 std::string output(mbedtls_high_level_strerr(rv));
36
37 // Some errors have a "low level" error, which is like an inner error code
38 // with a deeper explanation. But on mac and Windows, ostream crashes if you
39 // give it NULL. So we combine them ourselves with a NULL check.
40 const char* low_level_error = mbedtls_low_level_strerr(rv);
41 if (low_level_error) {
42 output += ": ";
43 output += low_level_error;
44 }
45
46 return output;
47}
48
49std::string sha1(const std::string& message) {
50 const mbedtls_md_info_t* md_info = mbedtls_md_info_from_type(MBEDTLS_MD_SHA1);
51 DCHECK(md_info);
52
53 std::string hash(mbedtls_md_get_size(md_info), 0);
54 CHECK_EQ(0,
55 mbedtls_md(md_info, reinterpret_cast<const uint8_t*>(message.data()),
56 message.size(), reinterpret_cast<uint8_t*>(hash.data())));
57
58 return hash;
59}
60
61} // namespace
62
63namespace shaka {
64namespace media {
65
66RsaPrivateKey::RsaPrivateKey() {
67 mbedtls_pk_init(&pk_context_);
68 mbedtls_entropy_init(&entropy_context_);
69 mbedtls_ctr_drbg_init(&prng_context_);
70}
71
72RsaPrivateKey::~RsaPrivateKey() {
73 mbedtls_pk_free(&pk_context_);
74 mbedtls_entropy_free(&entropy_context_);
75 mbedtls_ctr_drbg_free(&prng_context_);
76}
77
78RsaPrivateKey* RsaPrivateKey::Create(const std::string& serialized_key) {
79 std::unique_ptr<RsaPrivateKey> key(new RsaPrivateKey());
80 if (!key->Deserialize(serialized_key)) {
81 return NULL;
82 }
83 return key.release();
84}
85
86bool RsaPrivateKey::Deserialize(const std::string& serialized_key) {
87 const mbedtls_pk_info_t* pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
88 DCHECK(pk_info);
89
90 CHECK_EQ(mbedtls_ctr_drbg_seed(&prng_context_, mbedtls_entropy_func,
91 &entropy_context_, /* custom= */ NULL,
92 /* custom_len= */ 0),
93 0);
94
95 int rv = mbedtls_pk_parse_key(
96 &pk_context_, reinterpret_cast<const uint8_t*>(serialized_key.data()),
97 serialized_key.size(),
98 /* password= */ NULL,
99 /* password_len= */ 0, mbedtls_ctr_drbg_random, &prng_context_);
100 if (rv != 0) {
101 LOG(ERROR) << "RSA private key failed to load: " << mbedtls_strerr(rv);
102 return false;
103 }
104
105 // Set the padding mode and digest mode.
106 mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
107 rv = mbedtls_rsa_set_padding(rsa_context, MBEDTLS_RSA_PKCS_V21,
108 MBEDTLS_MD_SHA1);
109 if (rv != 0) {
110 LOG(ERROR) << "RSA private key failed to set padding: "
111 << mbedtls_strerr(rv);
112 return false;
113 }
114
115 return true;
116}
117
118bool RsaPrivateKey::Decrypt(const std::string& encrypted_message,
119 std::string* decrypted_message) {
120 DCHECK(decrypted_message);
121
122 mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
123
124 size_t rsa_size = mbedtls_rsa_get_len(rsa_context);
125 if (encrypted_message.size() != rsa_size) {
126 LOG(ERROR) << "Encrypted RSA message has the wrong size (expected "
127 << rsa_size << ", actual " << encrypted_message.size() << ").";
128 return false;
129 }
130 decrypted_message->resize(encrypted_message.size());
131
132 size_t decrypted_size = 0;
133 int rv = mbedtls_rsa_rsaes_oaep_decrypt(
134 rsa_context, mbedtls_ctr_drbg_random, &prng_context_,
135 /* label= */ NULL,
136 /* label_len= */ 0, &decrypted_size,
137 reinterpret_cast<const uint8_t*>(encrypted_message.data()),
138 reinterpret_cast<uint8_t*>(decrypted_message->data()),
139 decrypted_message->size());
140
141 if (rv != 0) {
142 LOG(ERROR) << "RSA private decrypt failure: " << mbedtls_strerr(rv);
143 return false;
144 }
145 decrypted_message->resize(decrypted_size);
146 return true;
147}
148
149bool RsaPrivateKey::GenerateSignature(const std::string& message,
150 std::string* signature) {
151 DCHECK(signature);
152 if (message.empty()) {
153 LOG(ERROR) << "Message to be signed is empty.";
154 return false;
155 }
156
157 mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
158
159 size_t rsa_size = mbedtls_rsa_get_len(rsa_context);
160 signature->resize(rsa_size);
161
162 std::string hash = sha1(message);
163 int rv = mbedtls_rsa_rsassa_pss_sign_ext(
164 rsa_context, mbedtls_ctr_drbg_random, &prng_context_, MBEDTLS_MD_SHA1,
165 static_cast<unsigned int>(hash.size()),
166 reinterpret_cast<const uint8_t*>(hash.data()), kPssSaltLength,
167 reinterpret_cast<uint8_t*>(signature->data()));
168
169 if (rv != 0) {
170 LOG(ERROR) << "RSA sign failure: " << mbedtls_strerr(rv);
171 return false;
172 }
173 return true;
174}
175
176RsaPublicKey::RsaPublicKey() {
177 mbedtls_pk_init(&pk_context_);
178 mbedtls_entropy_init(&entropy_context_);
179 mbedtls_ctr_drbg_init(&prng_context_);
180}
181
182RsaPublicKey::~RsaPublicKey() {
183 mbedtls_pk_free(&pk_context_);
184 mbedtls_entropy_free(&entropy_context_);
185 mbedtls_ctr_drbg_free(&prng_context_);
186}
187
188RsaPublicKey* RsaPublicKey::Create(const std::string& serialized_key) {
189 std::unique_ptr<RsaPublicKey> key(new RsaPublicKey());
190 if (!key->Deserialize(serialized_key)) {
191 return NULL;
192 }
193 return key.release();
194}
195
196bool RsaPublicKey::Deserialize(const std::string& serialized_key) {
197 const mbedtls_pk_info_t* pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
198 DCHECK(pk_info);
199
200 CHECK_EQ(mbedtls_ctr_drbg_seed(&prng_context_, mbedtls_entropy_func,
201 &entropy_context_, /* custom= */ NULL,
202 /* custom_len= */ 0),
203 0);
204
205 int rv = mbedtls_pk_parse_public_key(
206 &pk_context_, reinterpret_cast<const uint8_t*>(serialized_key.data()),
207 serialized_key.size());
208 if (rv != 0) {
209 LOG(ERROR) << "RSA public key failed to load: " << mbedtls_strerr(rv);
210 return false;
211 }
212
213 // Set the padding mode and digest mode.
214 mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
215 rv = mbedtls_rsa_set_padding(rsa_context, MBEDTLS_RSA_PKCS_V21,
216 MBEDTLS_MD_SHA1);
217 if (rv != 0) {
218 LOG(ERROR) << "RSA public key failed to set padding: "
219 << mbedtls_strerr(rv);
220 return false;
221 }
222
223 return true;
224}
225
226bool RsaPublicKey::Encrypt(const std::string& clear_message,
227 std::string* encrypted_message) {
228 DCHECK(encrypted_message);
229 if (clear_message.empty()) {
230 LOG(ERROR) << "Message to be encrypted is empty.";
231 return false;
232 }
233
234 mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
235
236 size_t rsa_size = mbedtls_rsa_get_len(rsa_context);
237 encrypted_message->resize(rsa_size);
238
239 int rv = mbedtls_rsa_rsaes_oaep_encrypt(
240 rsa_context, mbedtls_ctr_drbg_random, &prng_context_,
241 /* label= */ NULL,
242 /* label_len= */ 0, clear_message.size(),
243 reinterpret_cast<const uint8_t*>(clear_message.data()),
244 reinterpret_cast<uint8_t*>(encrypted_message->data()));
245
246 if (rv != 0) {
247 LOG(ERROR) << "RSA public encrypt failure: " << mbedtls_strerr(rv);
248 return false;
249 }
250 return true;
251}
252
253bool RsaPublicKey::VerifySignature(const std::string& message,
254 const std::string& signature) {
255 if (message.empty()) {
256 LOG(ERROR) << "Signed message is empty.";
257 return false;
258 }
259
260 mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
261
262 size_t rsa_size = mbedtls_rsa_get_len(rsa_context);
263 if (signature.size() != rsa_size) {
264 LOG(ERROR) << "Message signature is of the wrong size (expected "
265 << rsa_size << ", actual " << signature.size() << ").";
266 return false;
267 }
268
269 // Verify the signature.
270 std::string hash = sha1(message);
271 int rv = mbedtls_rsa_rsassa_pss_verify_ext(
272 rsa_context, MBEDTLS_MD_SHA1, static_cast<unsigned int>(hash.size()),
273 reinterpret_cast<const uint8_t*>(hash.data()), MBEDTLS_MD_SHA1,
274 kPssSaltLength, reinterpret_cast<const uint8_t*>(signature.data()));
275
276 if (rv != 0) {
277 LOG(ERROR) << "RSA signature verification failed: " << mbedtls_strerr(rv);
278 return false;
279 }
280 return true;
281}
282
283} // namespace media
284} // namespace shaka
RSA private key, used for message signing and decryption.
Definition rsa_key.h:25
bool Decrypt(const std::string &encrypted_message, std::string *decrypted_message)
Definition rsa_key.cc:118
bool GenerateSignature(const std::string &message, std::string *signature)
Definition rsa_key.cc:149
static RsaPrivateKey * Create(const std::string &serialized_key)
Definition rsa_key.cc:78
RSA public key, used for signature verification and encryption.
Definition rsa_key.h:57
bool VerifySignature(const std::string &message, const std::string &signature)
Definition rsa_key.cc:253
static RsaPublicKey * Create(const std::string &serialized_key)
Definition rsa_key.cc:188
bool Encrypt(const std::string &clear_message, std::string *encrypted_message)
Definition rsa_key.cc:226
The virtual methods are virtual only so that they can be mocked in tests.