Shaka Packager SDK
encryptor.cc
1 // Copyright 2015 Google LLC. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file or at
5 // https://developers.google.com/open-source/licenses/bsd
6 
7 #include <packager/media/formats/webm/encryptor.h>
8 
9 #include <absl/log/check.h>
10 
11 #include <packager/media/base/buffer_writer.h>
12 #include <packager/media/base/media_sample.h>
13 #include <packager/media/formats/webm/webm_constants.h>
14 
15 namespace shaka {
16 namespace media {
17 namespace webm {
18 namespace {
19 void WriteEncryptedFrameHeader(const DecryptConfig* decrypt_config,
20  BufferWriter* header_buffer) {
21  if (decrypt_config) {
22  const size_t iv_size = decrypt_config->iv().size();
23  DCHECK_EQ(iv_size, kWebMIvSize);
24  if (!decrypt_config->subsamples().empty()) {
25  const auto& subsamples = decrypt_config->subsamples();
26  // Use partitioned subsample encryption: | signal_byte(3) | iv
27  // | num_partitions | partition_offset * n | enc_data |
28  DCHECK_LT(subsamples.size(), kWebMMaxSubsamples);
29  const size_t num_partitions =
30  2 * subsamples.size() - 1 -
31  (subsamples.back().cipher_bytes == 0 ? 1 : 0);
32  const size_t header_size = kWebMSignalByteSize + iv_size +
33  kWebMNumPartitionsSize +
34  (kWebMPartitionOffsetSize * num_partitions);
35 
36  const uint8_t signal_byte = kWebMEncryptedSignal | kWebMPartitionedSignal;
37  header_buffer->AppendInt(signal_byte);
38  header_buffer->AppendVector(decrypt_config->iv());
39  header_buffer->AppendInt(static_cast<uint8_t>(num_partitions));
40 
41  uint32_t partition_offset = 0;
42  for (size_t i = 0; i < subsamples.size() - 1; ++i) {
43  partition_offset += subsamples[i].clear_bytes;
44  header_buffer->AppendInt(partition_offset);
45  partition_offset += subsamples[i].cipher_bytes;
46  header_buffer->AppendInt(partition_offset);
47  }
48  // Add another partition between the clear bytes and cipher bytes if
49  // cipher bytes is not zero.
50  if (subsamples.back().cipher_bytes != 0) {
51  partition_offset += subsamples.back().clear_bytes;
52  header_buffer->AppendInt(partition_offset);
53  }
54 
55  DCHECK_EQ(header_size, header_buffer->Size());
56  } else {
57  // Use whole-frame encryption: | signal_byte(1) | iv | enc_data |
58  const uint8_t signal_byte = kWebMEncryptedSignal;
59  header_buffer->AppendInt(signal_byte);
60  header_buffer->AppendVector(decrypt_config->iv());
61  }
62  } else {
63  // Clear sample: | signal_byte(0) | data |
64  const uint8_t signal_byte = 0x00;
65  header_buffer->AppendInt(signal_byte);
66  }
67 }
68 } // namespace
69 
70 Status UpdateTrackForEncryption(const std::vector<uint8_t>& key_id,
71  mkvmuxer::Track* track) {
72  DCHECK_EQ(track->content_encoding_entries_size(), 0u);
73 
74  if (!track->AddContentEncoding()) {
75  return Status(error::INTERNAL_ERROR,
76  "Could not add ContentEncoding to track.");
77  }
78 
79  mkvmuxer::ContentEncoding* const encoding =
80  track->GetContentEncodingByIndex(0);
81  if (!encoding) {
82  return Status(error::INTERNAL_ERROR,
83  "Could not add ContentEncoding to track.");
84  }
85 
86  mkvmuxer::ContentEncAESSettings* const aes = encoding->enc_aes_settings();
87  if (!aes) {
88  return Status(error::INTERNAL_ERROR,
89  "Error getting ContentEncAESSettings.");
90  }
91  if (aes->cipher_mode() != mkvmuxer::ContentEncAESSettings::kCTR) {
92  return Status(error::INTERNAL_ERROR, "Cipher Mode is not CTR.");
93  }
94 
95  if (!encoding->SetEncryptionID(key_id.data(), key_id.size())) {
96  return Status(error::INTERNAL_ERROR, "Error setting encryption ID.");
97  }
98  return Status::OK;
99 }
100 
101 void UpdateFrameForEncryption(MediaSample* sample) {
102  DCHECK(sample);
103  BufferWriter header_buffer;
104  WriteEncryptedFrameHeader(sample->decrypt_config(), &header_buffer);
105 
106  const size_t sample_size = header_buffer.Size() + sample->data_size();
107  std::shared_ptr<uint8_t> new_sample_data(new uint8_t[sample_size],
108  std::default_delete<uint8_t[]>());
109  memcpy(new_sample_data.get(), header_buffer.Buffer(), header_buffer.Size());
110  memcpy(&new_sample_data.get()[header_buffer.Size()], sample->data(),
111  sample->data_size());
112  sample->TransferData(std::move(new_sample_data), sample_size);
113 }
114 
115 } // namespace webm
116 } // namespace media
117 } // namespace shaka
All the methods that are virtual are virtual for mocking.
Definition: crypto_flags.cc:66