Shaka Packager SDK
aes_decryptor.cc
1 // Copyright 2016 Google LLC. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file or at
5 // https://developers.google.com/open-source/licenses/bsd
6 
#include <packager/media/base/aes_decryptor.h>

#include <algorithm>
#include <cstring>

#include <absl/log/check.h>
#include <absl/log/log.h>

#include <packager/macros/crypto.h>
15 
16 namespace shaka {
17 namespace media {
18 
19 AesCbcDecryptor::AesCbcDecryptor(CbcPaddingScheme padding_scheme)
20  : AesCbcDecryptor(padding_scheme, kDontUseConstantIv) {}
21 
22 AesCbcDecryptor::AesCbcDecryptor(CbcPaddingScheme padding_scheme,
23  ConstantIvFlag constant_iv_flag)
24  : AesCryptor(constant_iv_flag), padding_scheme_(padding_scheme) {
25  if (padding_scheme_ != kNoPadding) {
26  CHECK_EQ(constant_iv_flag, kUseConstantIv)
27  << "non-constant iv (cipher block chain across calls) only makes sense "
28  "if the padding_scheme is kNoPadding.";
29  }
30 }
31 
32 AesCbcDecryptor::~AesCbcDecryptor() {}
33 
34 bool AesCbcDecryptor::InitializeWithIv(const std::vector<uint8_t>& key,
35  const std::vector<uint8_t>& iv) {
36  if (!SetupCipher(key.size(), kCbcMode)) {
37  return false;
38  }
39 
40  if (mbedtls_cipher_setkey(&cipher_ctx_, key.data(),
41  static_cast<int>(8 * key.size()),
42  MBEDTLS_DECRYPT) != 0) {
43  LOG(ERROR) << "Failed to set CBC decryption key";
44  return false;
45  }
46 
47  return SetIv(iv);
48 }
49 
50 size_t AesCbcDecryptor::RequiredOutputSize(size_t plaintext_size) {
51  return plaintext_size;
52 }
53 
54 bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
55  size_t ciphertext_size,
56  uint8_t* plaintext,
57  size_t* plaintext_size) {
58  DCHECK(plaintext_size);
59  // Plaintext size is the same as ciphertext size except for pkcs5 padding.
60  // Will update later if using pkcs5 padding. For pkcs5 padding, we still
61  // need at least |ciphertext_size| bytes for intermediate operation.
62  if (*plaintext_size < ciphertext_size) {
63  LOG(ERROR) << "Expecting output size of at least " << ciphertext_size
64  << " bytes.";
65  return false;
66  }
67  *plaintext_size = ciphertext_size;
68 
69  // If the ciphertext size is 0, this can be a no-op decrypt, so long as the
70  // padding mode isn't PKCS5.
71  if (ciphertext_size == 0) {
72  if (padding_scheme_ == kPkcs5Padding) {
73  LOG(ERROR) << "Expected ciphertext to be at least " << AES_BLOCK_SIZE
74  << " bytes with Pkcs5 padding.";
75  return false;
76  }
77  return true;
78  }
79  DCHECK(plaintext);
80 
81  const size_t residual_block_size = ciphertext_size % AES_BLOCK_SIZE;
82  const size_t cbc_size = ciphertext_size - residual_block_size;
83  if (residual_block_size == 0) {
84  CbcDecryptBlocks(ciphertext, ciphertext_size, plaintext,
85  internal_iv_.data());
86  if (padding_scheme_ != kPkcs5Padding)
87  return true;
88 
89  // Strip off PKCS5 padding bytes.
90  const uint8_t num_padding_bytes = plaintext[ciphertext_size - 1];
91  if (num_padding_bytes > AES_BLOCK_SIZE) {
92  LOG(ERROR) << "Padding length is too large : "
93  << static_cast<int>(num_padding_bytes);
94  return false;
95  }
96  *plaintext_size -= num_padding_bytes;
97  return true;
98  } else if (padding_scheme_ == kNoPadding) {
99  if (cbc_size > 0) {
100  CbcDecryptBlocks(ciphertext, cbc_size, plaintext, internal_iv_.data());
101  }
102  // The residual block is not encrypted.
103  memcpy(plaintext + cbc_size, ciphertext + cbc_size, residual_block_size);
104  return true;
105  } else if (padding_scheme_ != kCtsPadding) {
106  LOG(ERROR) << "Expecting cipher text size to be multiple of "
107  << AES_BLOCK_SIZE << ", got " << ciphertext_size;
108  return false;
109  }
110 
111  DCHECK_EQ(padding_scheme_, kCtsPadding);
112  if (ciphertext_size < AES_BLOCK_SIZE) {
113  // Don't have a full block, leave unencrypted.
114  memcpy(plaintext, ciphertext, ciphertext_size);
115  return true;
116  }
117 
118  // AES-CBC decrypt everything up to the next-to-last full block.
119  if (cbc_size > AES_BLOCK_SIZE) {
120  CbcDecryptBlocks(ciphertext, cbc_size - AES_BLOCK_SIZE, plaintext,
121  internal_iv_.data());
122  }
123 
124  const uint8_t* next_to_last_ciphertext_block =
125  ciphertext + ciphertext_size - residual_block_size - AES_BLOCK_SIZE;
126  uint8_t* next_to_last_plaintext_block =
127  plaintext + ciphertext_size - residual_block_size - AES_BLOCK_SIZE;
128 
129  // Determine what the last IV should be so that we can "skip ahead" in the
130  // CBC decryption.
131  std::vector<uint8_t> last_iv(
132  ciphertext + ciphertext_size - residual_block_size,
133  ciphertext + ciphertext_size);
134  last_iv.resize(AES_BLOCK_SIZE, 0);
135 
136  // Decrypt the next-to-last block using the IV determined above. This decrypts
137  // the residual block bits.
138  CbcDecryptBlocks(next_to_last_ciphertext_block, AES_BLOCK_SIZE,
139  next_to_last_plaintext_block, last_iv.data());
140 
141  // Swap back the residual block bits and the next-to-last block.
142  if (plaintext == ciphertext) {
143  std::swap_ranges(next_to_last_plaintext_block,
144  next_to_last_plaintext_block + residual_block_size,
145  next_to_last_plaintext_block + AES_BLOCK_SIZE);
146  } else {
147  memcpy(next_to_last_plaintext_block + AES_BLOCK_SIZE,
148  next_to_last_plaintext_block, residual_block_size);
149  memcpy(next_to_last_plaintext_block,
150  next_to_last_ciphertext_block + AES_BLOCK_SIZE, residual_block_size);
151  }
152 
153  // Decrypt the next-to-last full block.
154  CbcDecryptBlocks(next_to_last_plaintext_block, AES_BLOCK_SIZE,
155  next_to_last_plaintext_block, internal_iv_.data());
156  return true;
157 }
158 
159 void AesCbcDecryptor::SetIvInternal() {
160  internal_iv_ = iv();
161  internal_iv_.resize(AES_BLOCK_SIZE, 0);
162 }
163 
164 void AesCbcDecryptor::CbcDecryptBlocks(const uint8_t* ciphertext,
165  size_t ciphertext_size,
166  uint8_t* plaintext,
167  uint8_t* iv) {
168  CHECK_EQ(ciphertext_size % AES_BLOCK_SIZE, 0u);
169  CHECK_GT(ciphertext_size, 0u);
170 
171  // Copy the final block of ciphertext before decryption, since we could be
172  // decrypting in-place.
173  const uint8_t* last_block = ciphertext + ciphertext_size - AES_BLOCK_SIZE;
174  std::vector<uint8_t> next_iv(last_block, last_block + AES_BLOCK_SIZE);
175 
176  size_t output_size = 0;
177  CHECK_EQ(mbedtls_cipher_crypt(&cipher_ctx_, iv, AES_BLOCK_SIZE, ciphertext,
178  ciphertext_size, plaintext, &output_size),
179  0);
180  DCHECK_EQ(output_size % AES_BLOCK_SIZE, 0u);
181 
182  memcpy(iv, next_iv.data(), next_iv.size());
183 }
184 
185 } // namespace media
186 } // namespace shaka
Class which implements AES-CBC (Cipher block chaining) decryption.
Definition: aes_decryptor.h:25
bool InitializeWithIv(const std::vector< uint8_t > &key, const std::vector< uint8_t > &iv) override
AesCbcDecryptor(CbcPaddingScheme padding_scheme)
const std::vector< uint8_t > & iv() const
Definition: aes_cryptor.h:85
bool SetIv(const std::vector< uint8_t > &iv)
Definition: aes_cryptor.cc:70
All the methods that are virtual are virtual for mocking.
Definition: crypto_flags.cc:66