Shaka Packager SDK
rsa_key.cc
1 // Copyright 2014 Google LLC. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file or at
5 // https://developers.google.com/open-source/licenses/bsd
6 //
7 // RSA signature details:
8 // Algorithm: RSASSA-PSS
9 // Hash algorithm: SHA1
10 // Mask generation function: mgf1SHA1
11 // Salt length: 20 bytes
12 // Trailer field: 0xbc
13 //
14 // RSA encryption details:
15 // Algorithm: RSA-OAEP
16 // Mask generation function: mgf1SHA1
17 // Label (encoding parameter): empty string
18 
#include <packager/media/base/rsa_key.h>

#include <memory>
#include <string>
#include <vector>

#include <absl/log/check.h>
#include <absl/log/log.h>
#include <mbedtls/error.h>
#include <mbedtls/md.h>
#include <mbedtls/pk.h>
28 
29 namespace {
30 
31 const size_t kPssSaltLength = 20u;
32 
33 std::string mbedtls_strerr(int rv) {
34  // There is always a "high level" error.
35  std::string output(mbedtls_high_level_strerr(rv));
36 
37  // Some errors have a "low level" error, which is like an inner error code
38  // with a deeper explanation. But on mac and Windows, ostream crashes if you
39  // give it NULL. So we combine them ourselves with a NULL check.
40  const char* low_level_error = mbedtls_low_level_strerr(rv);
41  if (low_level_error) {
42  output += ": ";
43  output += low_level_error;
44  }
45 
46  return output;
47 }
48 
49 std::string sha1(const std::string& message) {
50  const mbedtls_md_info_t* md_info = mbedtls_md_info_from_type(MBEDTLS_MD_SHA1);
51  DCHECK(md_info);
52 
53  std::string hash(mbedtls_md_get_size(md_info), 0);
54  CHECK_EQ(0,
55  mbedtls_md(md_info, reinterpret_cast<const uint8_t*>(message.data()),
56  message.size(), reinterpret_cast<uint8_t*>(hash.data())));
57 
58  return hash;
59 }
60 
61 } // namespace
62 
63 namespace shaka {
64 namespace media {
65 
66 RsaPrivateKey::RsaPrivateKey() {
67  mbedtls_pk_init(&pk_context_);
68  mbedtls_entropy_init(&entropy_context_);
69  mbedtls_ctr_drbg_init(&prng_context_);
70 }
71 
72 RsaPrivateKey::~RsaPrivateKey() {
73  mbedtls_pk_free(&pk_context_);
74  mbedtls_entropy_free(&entropy_context_);
75  mbedtls_ctr_drbg_free(&prng_context_);
76 }
77 
78 RsaPrivateKey* RsaPrivateKey::Create(const std::string& serialized_key) {
79  std::unique_ptr<RsaPrivateKey> key(new RsaPrivateKey());
80  if (!key->Deserialize(serialized_key)) {
81  return NULL;
82  }
83  return key.release();
84 }
85 
86 bool RsaPrivateKey::Deserialize(const std::string& serialized_key) {
87  const mbedtls_pk_info_t* pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
88  DCHECK(pk_info);
89 
90  CHECK_EQ(mbedtls_ctr_drbg_seed(&prng_context_, mbedtls_entropy_func,
91  &entropy_context_, /* custom= */ NULL,
92  /* custom_len= */ 0),
93  0);
94 
95  int rv = mbedtls_pk_parse_key(
96  &pk_context_, reinterpret_cast<const uint8_t*>(serialized_key.data()),
97  serialized_key.size(),
98  /* password= */ NULL,
99  /* password_len= */ 0, mbedtls_ctr_drbg_random, &prng_context_);
100  if (rv != 0) {
101  LOG(ERROR) << "RSA private key failed to load: " << mbedtls_strerr(rv);
102  return false;
103  }
104 
105  // Set the padding mode and digest mode.
106  mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
107  rv = mbedtls_rsa_set_padding(rsa_context, MBEDTLS_RSA_PKCS_V21,
108  MBEDTLS_MD_SHA1);
109  if (rv != 0) {
110  LOG(ERROR) << "RSA private key failed to set padding: "
111  << mbedtls_strerr(rv);
112  return false;
113  }
114 
115  return true;
116 }
117 
118 bool RsaPrivateKey::Decrypt(const std::string& encrypted_message,
119  std::string* decrypted_message) {
120  DCHECK(decrypted_message);
121 
122  mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
123 
124  size_t rsa_size = mbedtls_rsa_get_len(rsa_context);
125  if (encrypted_message.size() != rsa_size) {
126  LOG(ERROR) << "Encrypted RSA message has the wrong size (expected "
127  << rsa_size << ", actual " << encrypted_message.size() << ").";
128  return false;
129  }
130  decrypted_message->resize(encrypted_message.size());
131 
132  size_t decrypted_size = 0;
133  int rv = mbedtls_rsa_rsaes_oaep_decrypt(
134  rsa_context, mbedtls_ctr_drbg_random, &prng_context_,
135  /* label= */ NULL,
136  /* label_len= */ 0, &decrypted_size,
137  reinterpret_cast<const uint8_t*>(encrypted_message.data()),
138  reinterpret_cast<uint8_t*>(decrypted_message->data()),
139  decrypted_message->size());
140 
141  if (rv != 0) {
142  LOG(ERROR) << "RSA private decrypt failure: " << mbedtls_strerr(rv);
143  return false;
144  }
145  decrypted_message->resize(decrypted_size);
146  return true;
147 }
148 
149 bool RsaPrivateKey::GenerateSignature(const std::string& message,
150  std::string* signature) {
151  DCHECK(signature);
152  if (message.empty()) {
153  LOG(ERROR) << "Message to be signed is empty.";
154  return false;
155  }
156 
157  mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
158 
159  size_t rsa_size = mbedtls_rsa_get_len(rsa_context);
160  signature->resize(rsa_size);
161 
162  std::string hash = sha1(message);
163  int rv = mbedtls_rsa_rsassa_pss_sign_ext(
164  rsa_context, mbedtls_ctr_drbg_random, &prng_context_, MBEDTLS_MD_SHA1,
165  static_cast<unsigned int>(hash.size()),
166  reinterpret_cast<const uint8_t*>(hash.data()), kPssSaltLength,
167  reinterpret_cast<uint8_t*>(signature->data()));
168 
169  if (rv != 0) {
170  LOG(ERROR) << "RSA sign failure: " << mbedtls_strerr(rv);
171  return false;
172  }
173  return true;
174 }
175 
176 RsaPublicKey::RsaPublicKey() {
177  mbedtls_pk_init(&pk_context_);
178  mbedtls_entropy_init(&entropy_context_);
179  mbedtls_ctr_drbg_init(&prng_context_);
180 }
181 
182 RsaPublicKey::~RsaPublicKey() {
183  mbedtls_pk_free(&pk_context_);
184  mbedtls_entropy_free(&entropy_context_);
185  mbedtls_ctr_drbg_free(&prng_context_);
186 }
187 
188 RsaPublicKey* RsaPublicKey::Create(const std::string& serialized_key) {
189  std::unique_ptr<RsaPublicKey> key(new RsaPublicKey());
190  if (!key->Deserialize(serialized_key)) {
191  return NULL;
192  }
193  return key.release();
194 }
195 
196 bool RsaPublicKey::Deserialize(const std::string& serialized_key) {
197  const mbedtls_pk_info_t* pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
198  DCHECK(pk_info);
199 
200  CHECK_EQ(mbedtls_ctr_drbg_seed(&prng_context_, mbedtls_entropy_func,
201  &entropy_context_, /* custom= */ NULL,
202  /* custom_len= */ 0),
203  0);
204 
205  int rv = mbedtls_pk_parse_public_key(
206  &pk_context_, reinterpret_cast<const uint8_t*>(serialized_key.data()),
207  serialized_key.size());
208  if (rv != 0) {
209  LOG(ERROR) << "RSA public key failed to load: " << mbedtls_strerr(rv);
210  return false;
211  }
212 
213  // Set the padding mode and digest mode.
214  mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
215  rv = mbedtls_rsa_set_padding(rsa_context, MBEDTLS_RSA_PKCS_V21,
216  MBEDTLS_MD_SHA1);
217  if (rv != 0) {
218  LOG(ERROR) << "RSA public key failed to set padding: "
219  << mbedtls_strerr(rv);
220  return false;
221  }
222 
223  return true;
224 }
225 
226 bool RsaPublicKey::Encrypt(const std::string& clear_message,
227  std::string* encrypted_message) {
228  DCHECK(encrypted_message);
229  if (clear_message.empty()) {
230  LOG(ERROR) << "Message to be encrypted is empty.";
231  return false;
232  }
233 
234  mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
235 
236  size_t rsa_size = mbedtls_rsa_get_len(rsa_context);
237  encrypted_message->resize(rsa_size);
238 
239  int rv = mbedtls_rsa_rsaes_oaep_encrypt(
240  rsa_context, mbedtls_ctr_drbg_random, &prng_context_,
241  /* label= */ NULL,
242  /* label_len= */ 0, clear_message.size(),
243  reinterpret_cast<const uint8_t*>(clear_message.data()),
244  reinterpret_cast<uint8_t*>(encrypted_message->data()));
245 
246  if (rv != 0) {
247  LOG(ERROR) << "RSA public encrypt failure: " << mbedtls_strerr(rv);
248  return false;
249  }
250  return true;
251 }
252 
253 bool RsaPublicKey::VerifySignature(const std::string& message,
254  const std::string& signature) {
255  if (message.empty()) {
256  LOG(ERROR) << "Signed message is empty.";
257  return false;
258  }
259 
260  mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
261 
262  size_t rsa_size = mbedtls_rsa_get_len(rsa_context);
263  if (signature.size() != rsa_size) {
264  LOG(ERROR) << "Message signature is of the wrong size (expected "
265  << rsa_size << ", actual " << signature.size() << ").";
266  return false;
267  }
268 
269  // Verify the signature.
270  std::string hash = sha1(message);
271  int rv = mbedtls_rsa_rsassa_pss_verify_ext(
272  rsa_context, MBEDTLS_MD_SHA1, static_cast<unsigned int>(hash.size()),
273  reinterpret_cast<const uint8_t*>(hash.data()), MBEDTLS_MD_SHA1,
274  kPssSaltLength, reinterpret_cast<const uint8_t*>(signature.data()));
275 
276  if (rv != 0) {
277  LOG(ERROR) << "RSA signature verification failed: " << mbedtls_strerr(rv);
278  return false;
279  }
280  return true;
281 }
282 
283 } // namespace media
284 } // namespace shaka
RSA private key, used for message signing and decryption.
Definition: rsa_key.h:25
bool Decrypt(const std::string &encrypted_message, std::string *decrypted_message)
Definition: rsa_key.cc:118
bool GenerateSignature(const std::string &message, std::string *signature)
Definition: rsa_key.cc:149
static RsaPrivateKey * Create(const std::string &serialized_key)
Definition: rsa_key.cc:78
RSA public key, used for signature verification and encryption.
Definition: rsa_key.h:57
bool VerifySignature(const std::string &message, const std::string &signature)
Definition: rsa_key.cc:253
static RsaPublicKey * Create(const std::string &serialized_key)
Definition: rsa_key.cc:188
bool Encrypt(const std::string &clear_message, std::string *encrypted_message)
Definition: rsa_key.cc:226
All the methods that are virtual are virtual for mocking.
Definition: crypto_flags.cc:66