Shaka Packager SDK
widevine_key_source.cc
1 // Copyright 2014 Google LLC. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file or at
5 // https://developers.google.com/open-source/licenses/bsd
6 
#include <packager/media/base/widevine_key_source.h>

#include <chrono>
#include <functional>
#include <iterator>
#include <memory>
#include <thread>

#include <absl/base/internal/endian.h>
#include <absl/flags/flag.h>
#include <absl/log/check.h>
#include <absl/strings/escaping.h>

#include <packager/macros/logging.h>
#include <packager/media/base/http_key_fetcher.h>
#include <packager/media/base/producer_consumer_queue.h>
#include <packager/media/base/protection_system_ids.h>
#include <packager/media/base/protection_system_specific_info.h>
#include <packager/media/base/proto_json_util.h>
#include <packager/media/base/pssh_generator_util.h>
#include <packager/media/base/rcheck.h>
#include <packager/media/base/request_signer.h>
#include <packager/media/base/widevine_common_encryption.pb.h>
27 
28 ABSL_FLAG(std::string,
29  video_feature,
30  "",
31  "Specify the optional video feature, e.g. HDR.");
32 
33 namespace shaka {
34 namespace media {
35 namespace {
36 
// Value passed as the `enable_key_rotation` argument of FetchKeysInternal();
// `!kEnableKeyRotation` selects a one-shot, non-rotating fetch.
constexpr bool kEnableKeyRotation = true;

// Client-side handling of transient key server failures: at most this many
// fetch attempts, sleeping between attempts with a delay that starts at
// kFirstRetryDelayMilliseconds and doubles each time.
constexpr int kNumTransientErrorRetries = 5;
constexpr int kFirstRetryDelayMilliseconds = 1000;

// Default number of crypto periods requested by every key-rotation fetch.
constexpr int kDefaultCryptoPeriodCount = 10;
// How long GetKeyInternal() waits on the key pool for a crypto period.
constexpr int kGetKeyTimeoutInSeconds = 5 * 60;  // 5 minutes.
// Timeout given to the default HttpKeyFetcher.
constexpr int kKeyFetchTimeoutInSeconds = 60;  // 1 minute.
49 
50 CommonEncryptionRequest::ProtectionScheme ToCommonEncryptionProtectionScheme(
51  FourCC protection_scheme) {
52  switch (protection_scheme) {
53  case FOURCC_cenc:
54  return CommonEncryptionRequest::CENC;
55  case FOURCC_cbcs:
56  case kAppleSampleAesProtectionScheme:
57  // Treat sample aes as a variant of cbcs.
58  return CommonEncryptionRequest::CBCS;
59  case FOURCC_cbc1:
60  return CommonEncryptionRequest::CBC1;
61  case FOURCC_cens:
62  return CommonEncryptionRequest::CENS;
63  default:
64  LOG(WARNING) << "Ignore unrecognized protection scheme "
65  << FourCCToString(protection_scheme);
66  return CommonEncryptionRequest::UNSPECIFIED;
67  }
68 }
69 
70 ProtectionSystemSpecificInfo ProtectionSystemInfoFromPsshProto(
71  const CommonEncryptionResponse::Track::Pssh& pssh_proto) {
72  PsshBoxBuilder pssh_builder;
73  pssh_builder.set_system_id(kWidevineSystemId, std::size(kWidevineSystemId));
74 
75  if (pssh_proto.has_boxes()) {
76  return {pssh_builder.system_id(),
77  std::vector<uint8_t>(pssh_proto.boxes().begin(),
78  pssh_proto.boxes().end())};
79  } else {
80  pssh_builder.set_pssh_box_version(0);
81  const std::vector<uint8_t> pssh_data(pssh_proto.data().begin(),
82  pssh_proto.data().end());
83  pssh_builder.set_pssh_data(pssh_data);
84  return {pssh_builder.system_id(), pssh_builder.CreateBox()};
85  }
86 }
87 
88 } // namespace
89 
90 WidevineKeySource::WidevineKeySource(const std::string& server_url,
91  ProtectionSystem protection_systems,
92  FourCC protection_scheme)
93  // Widevine PSSH is fetched from Widevine license server.
94  : generate_widevine_protection_system_(
95  // Generate Widevine protection system if there are no other
96  // protection system specified.
97  protection_systems == ProtectionSystem::kNone ||
98  has_flag(protection_systems, ProtectionSystem::kWidevine)),
99  key_fetcher_(new HttpKeyFetcher(kKeyFetchTimeoutInSeconds)),
100  server_url_(server_url),
101  crypto_period_count_(kDefaultCryptoPeriodCount),
102  protection_scheme_(protection_scheme),
103  key_production_thread_(
104  std::bind(&WidevineKeySource::FetchKeysTask, this)) {}
105 
106 WidevineKeySource::~WidevineKeySource() {
107  if (key_pool_)
108  key_pool_->Stop();
109  // Signal the production thread to start key production if it is not
110  // signaled yet so the thread can be joined.
111  if (!start_key_production_.HasBeenNotified())
112  start_key_production_.Notify();
113  key_production_thread_.join();
114 }
115 
116 Status WidevineKeySource::FetchKeys(const std::vector<uint8_t>& content_id,
117  const std::string& policy) {
118  absl::MutexLock scoped_lock(&mutex_);
119  common_encryption_request_.reset(new CommonEncryptionRequest);
120  common_encryption_request_->set_content_id(content_id.data(),
121  content_id.size());
122  common_encryption_request_->set_policy(policy);
123  common_encryption_request_->set_protection_scheme(
124  ToCommonEncryptionProtectionScheme(protection_scheme_));
125  if (enable_entitlement_license_)
126  common_encryption_request_->set_enable_entitlement_license(true);
127 
128  return FetchKeysInternal(!kEnableKeyRotation, 0, false);
129 }
130 
131 Status WidevineKeySource::FetchKeys(EmeInitDataType init_data_type,
132  const std::vector<uint8_t>& init_data) {
133  std::vector<uint8_t> pssh_data;
134  uint32_t asset_id = 0;
135  switch (init_data_type) {
136  case EmeInitDataType::CENC: {
137  const std::vector<uint8_t> widevine_system_id(
138  kWidevineSystemId, kWidevineSystemId + std::size(kWidevineSystemId));
139  std::vector<ProtectionSystemSpecificInfo> protection_systems_info;
141  init_data.data(), init_data.size(), &protection_systems_info)) {
142  return Status(error::PARSER_FAILURE, "Error parsing the PSSH boxes.");
143  }
144  for (const auto& info : protection_systems_info) {
145  std::unique_ptr<PsshBoxBuilder> pssh_builder =
146  PsshBoxBuilder::ParseFromBox(info.psshs.data(), info.psshs.size());
147  if (!pssh_builder)
148  return Status(error::PARSER_FAILURE, "Error parsing the PSSH box.");
149  // Use Widevine PSSH if available otherwise construct a Widevine PSSH
150  // from the first available key ids.
151  if (info.system_id == widevine_system_id) {
152  pssh_data = pssh_builder->pssh_data();
153  break;
154  } else if (pssh_data.empty() && !pssh_builder->key_ids().empty()) {
155  pssh_data =
156  GenerateWidevinePsshDataFromKeyIds(pssh_builder->key_ids());
157  // Continue to see if there is any Widevine PSSH. The KeyId generated
158  // PSSH is only used if a Widevine PSSH could not be found.
159  continue;
160  }
161  }
162  if (pssh_data.empty())
163  return Status(error::INVALID_ARGUMENT, "No supported PSSHs found.");
164  break;
165  }
166  case EmeInitDataType::WEBM: {
167  pssh_data = GenerateWidevinePsshDataFromKeyIds({init_data});
168  break;
169  }
170  case EmeInitDataType::WIDEVINE_CLASSIC:
171  if (init_data.size() < sizeof(asset_id))
172  return Status(error::INVALID_ARGUMENT, "Invalid asset id.");
173  asset_id = absl::big_endian::Load32(init_data.data());
174  break;
175  default:
176  LOG(ERROR) << "Init data type " << static_cast<int>(init_data_type)
177  << " not supported.";
178  return Status(error::INVALID_ARGUMENT, "Unsupported init data type.");
179  }
180  const bool widevine_classic =
181  init_data_type == EmeInitDataType::WIDEVINE_CLASSIC;
182  absl::MutexLock scoped_lock(&mutex_);
183  common_encryption_request_.reset(new CommonEncryptionRequest);
184  if (widevine_classic) {
185  common_encryption_request_->set_asset_id(asset_id);
186  } else {
187  common_encryption_request_->set_pssh_data(pssh_data.data(),
188  pssh_data.size());
189  }
190  return FetchKeysInternal(!kEnableKeyRotation, 0, widevine_classic);
191 }
192 
193 Status WidevineKeySource::GetKey(const std::string& stream_label,
194  EncryptionKey* key) {
195  DCHECK(key);
196  if (encryption_key_map_.find(stream_label) == encryption_key_map_.end()) {
197  return Status(error::INTERNAL_ERROR,
198  "Cannot find key for '" + stream_label + "'.");
199  }
200  *key = *encryption_key_map_[stream_label];
201  return Status::OK;
202 }
203 
204 Status WidevineKeySource::GetKey(const std::vector<uint8_t>& key_id,
205  EncryptionKey* key) {
206  DCHECK(key);
207  for (const auto& pair : encryption_key_map_) {
208  if (pair.second->key_id == key_id) {
209  *key = *pair.second;
210  return Status::OK;
211  }
212  }
213  return Status(error::INTERNAL_ERROR,
214  "Cannot find key with specified key ID");
215 }
216 
218  uint32_t crypto_period_index,
219  int32_t crypto_period_duration_in_seconds,
220  const std::string& stream_label,
221  EncryptionKey* key) {
222  // TODO(kqyang): This is not elegant. Consider refactoring later.
223  {
224  absl::MutexLock scoped_lock(&mutex_);
225  if (!key_production_started_) {
226  crypto_period_duration_in_seconds_ = crypto_period_duration_in_seconds;
227  // Another client may have a slightly smaller starting crypto period
228  // index. Set the initial value to account for that.
229  first_crypto_period_index_ =
230  crypto_period_index ? crypto_period_index - 1 : 0;
231  DCHECK(!key_pool_);
232  const size_t queue_size = crypto_period_count_ * 10;
233  key_pool_.reset(
234  new EncryptionKeyQueue(queue_size, first_crypto_period_index_));
235  start_key_production_.Notify();
236  key_production_started_ = true;
237  } else if (crypto_period_duration_in_seconds_ !=
238  crypto_period_duration_in_seconds) {
239  return Status(error::INVALID_ARGUMENT,
240  "Crypto period duration should not change.");
241  }
242  }
243  return GetKeyInternal(crypto_period_index, stream_label, key);
244 }
245 
246 void WidevineKeySource::set_signer(std::unique_ptr<RequestSigner> signer) {
247  signer_ = std::move(signer);
248 }
249 
251  std::unique_ptr<KeyFetcher> key_fetcher) {
252  key_fetcher_ = std::move(key_fetcher);
253 }
254 
255 Status WidevineKeySource::GetKeyInternal(uint32_t crypto_period_index,
256  const std::string& stream_label,
257  EncryptionKey* key) {
258  DCHECK(key_pool_);
259  DCHECK(key);
260 
261  std::shared_ptr<EncryptionKeyMap> encryption_key_map;
262  Status status = key_pool_->Peek(crypto_period_index, &encryption_key_map,
263  kGetKeyTimeoutInSeconds * 1000);
264  if (!status.ok()) {
265  if (status.error_code() == error::STOPPED) {
266  CHECK(!common_encryption_request_status_.ok());
267  return common_encryption_request_status_;
268  }
269  return status;
270  }
271 
272  if (encryption_key_map->find(stream_label) == encryption_key_map->end()) {
273  return Status(error::INTERNAL_ERROR,
274  "Cannot find key for '" + stream_label + "'.");
275  }
276  *key = *encryption_key_map->at(stream_label);
277  return Status::OK;
278 }
279 
280 void WidevineKeySource::FetchKeysTask() {
281  // Wait until key production is signaled.
282  start_key_production_.WaitForNotification();
283  if (!key_pool_ || key_pool_->Stopped())
284  return;
285 
286  Status status = FetchKeysInternal(kEnableKeyRotation,
287  first_crypto_period_index_,
288  false);
289  while (status.ok()) {
290  first_crypto_period_index_ += crypto_period_count_;
291  status = FetchKeysInternal(kEnableKeyRotation,
292  first_crypto_period_index_,
293  false);
294  }
295  common_encryption_request_status_ = status;
296  key_pool_->Stop();
297 }
298 
299 Status WidevineKeySource::FetchKeysInternal(bool enable_key_rotation,
300  uint32_t first_crypto_period_index,
301  bool widevine_classic) {
302  CommonEncryptionRequest request;
303  FillRequest(enable_key_rotation, first_crypto_period_index, &request);
304 
305  std::string message;
306  Status status = GenerateKeyMessage(request, &message);
307  if (!status.ok())
308  return status;
309  VLOG(1) << "Message: " << message;
310 
311  std::string raw_response;
312  int64_t sleep_duration = kFirstRetryDelayMilliseconds;
313 
314  // Perform client side retries if seeing server transient error to workaround
315  // server limitation.
316  for (int i = 0; i < kNumTransientErrorRetries; ++i) {
317  status = key_fetcher_->FetchKeys(server_url_, message, &raw_response);
318  if (status.ok()) {
319  VLOG(1) << "Retry [" << i << "] Response:" << raw_response;
320 
321  bool transient_error = false;
322  if (ExtractEncryptionKey(enable_key_rotation, widevine_classic,
323  raw_response, &transient_error))
324  return Status::OK;
325 
326  if (!transient_error) {
327  return Status(
328  error::SERVER_ERROR,
329  "Failed to extract encryption key from '" + raw_response + "'.");
330  }
331  } else if (status.error_code() != error::TIME_OUT) {
332  return status;
333  }
334 
335  // Exponential backoff.
336  if (i != kNumTransientErrorRetries - 1) {
337  std::this_thread::sleep_for(std::chrono::milliseconds(sleep_duration));
338  sleep_duration *= 2;
339  }
340  }
341  return Status(error::SERVER_ERROR,
342  "Failed to recover from server internal error.");
343 }
344 
345 void WidevineKeySource::FillRequest(bool enable_key_rotation,
346  uint32_t first_crypto_period_index,
347  CommonEncryptionRequest* request) {
348  DCHECK(common_encryption_request_);
349  DCHECK(request);
350  *request = *common_encryption_request_;
351 
352  request->add_tracks()->set_type("SD");
353  request->add_tracks()->set_type("HD");
354  request->add_tracks()->set_type("UHD1");
355  request->add_tracks()->set_type("UHD2");
356  request->add_tracks()->set_type("AUDIO");
357 
358  request->add_drm_types(ModularDrmType::WIDEVINE);
359 
360  if (enable_key_rotation) {
361  request->set_first_crypto_period_index(first_crypto_period_index);
362  request->set_crypto_period_count(crypto_period_count_);
363  request->set_crypto_period_seconds(crypto_period_duration_in_seconds_);
364  }
365 
366  if (!group_id_.empty())
367  request->set_group_id(group_id_.data(), group_id_.size());
368 
369  std::string video_feature = absl::GetFlag(FLAGS_video_feature);
370  if (!video_feature.empty())
371  request->set_video_feature(video_feature);
372 }
373 
374 Status WidevineKeySource::GenerateKeyMessage(
375  const CommonEncryptionRequest& request,
376  std::string* message) {
377  DCHECK(message);
378 
379  SignedModularDrmRequest signed_request;
380  signed_request.set_request(MessageToJsonString(request));
381 
382  // Sign the request.
383  if (signer_) {
384  std::string signature;
385  if (!signer_->GenerateSignature(signed_request.request(), &signature))
386  return Status(error::INTERNAL_ERROR, "Signature generation failed.");
387 
388  signed_request.set_signature(signature);
389  signed_request.set_signer(signer_->signer_name());
390  }
391 
392  *message = MessageToJsonString(signed_request);
393  return Status::OK;
394 }
395 
396 bool WidevineKeySource::ExtractEncryptionKey(
397  bool enable_key_rotation,
398  bool widevine_classic,
399  const std::string& response,
400  bool* transient_error) {
401  DCHECK(transient_error);
402  *transient_error = false;
403 
404  SignedModularDrmResponse signed_response_proto;
405  if (!JsonStringToMessage(response, &signed_response_proto)) {
406  LOG(ERROR) << "Failed to convert JSON to proto: " << response;
407  return false;
408  }
409 
410  CommonEncryptionResponse response_proto;
411  if (!JsonStringToMessage(signed_response_proto.response(), &response_proto)) {
412  LOG(ERROR) << "Failed to convert JSON to proto: "
413  << signed_response_proto.response();
414  return false;
415  }
416 
417  if (response_proto.status() != CommonEncryptionResponse::OK) {
418  LOG(ERROR) << "Received non-OK license response: " << response;
419  // Server may return INTERNAL_ERROR intermittently, which is a transient
420  // error and the next client request may succeed without problem.
421  *transient_error =
422  (response_proto.status() == CommonEncryptionResponse::INTERNAL_ERROR);
423  return false;
424  }
425 
426  RCHECK(enable_key_rotation
427  ? response_proto.tracks_size() >= crypto_period_count_
428  : response_proto.tracks_size() >= 1);
429 
430  uint32_t current_crypto_period_index = first_crypto_period_index_;
431 
432  std::vector<std::vector<uint8_t>> key_ids;
433  for (const auto& track : response_proto.tracks()) {
434  if (!widevine_classic)
435  key_ids.emplace_back(track.key_id().begin(), track.key_id().end());
436  }
437 
438  EncryptionKeyMap encryption_key_map;
439  for (const auto& track : response_proto.tracks()) {
440  VLOG(2) << "track " << track.ShortDebugString();
441 
442  if (enable_key_rotation) {
443  if (track.crypto_period_index() != current_crypto_period_index) {
444  if (track.crypto_period_index() != current_crypto_period_index + 1) {
445  LOG(ERROR) << "Expecting crypto period index "
446  << current_crypto_period_index << " or "
447  << current_crypto_period_index + 1 << "; Seen "
448  << track.crypto_period_index();
449  return false;
450  }
451  if (!PushToKeyPool(&encryption_key_map))
452  return false;
453  ++current_crypto_period_index;
454  }
455  }
456 
457  const std::string& stream_label = track.type();
458  RCHECK(encryption_key_map.find(stream_label) == encryption_key_map.end());
459 
460  std::unique_ptr<EncryptionKey> encryption_key(new EncryptionKey());
461  encryption_key->key.assign(track.key().begin(), track.key().end());
462 
463  // Get key ID and PSSH data for CENC content only.
464  if (!widevine_classic) {
465  encryption_key->key_id.assign(track.key_id().begin(),
466  track.key_id().end());
467  encryption_key->iv.assign(track.iv().begin(), track.iv().end());
468  encryption_key->key_ids = key_ids;
469 
470  if (generate_widevine_protection_system_) {
471  if (track.pssh_size() != 1) {
472  LOG(ERROR) << "Expecting one and only one pssh, seeing "
473  << track.pssh_size();
474  return false;
475  }
476  encryption_key->key_system_info.push_back(
477  ProtectionSystemInfoFromPsshProto(track.pssh(0)));
478  }
479  }
480  encryption_key_map[stream_label] = std::move(encryption_key);
481  }
482 
483  DCHECK(!encryption_key_map.empty());
484  if (!enable_key_rotation) {
485  // Merge with previously requested keys.
486  for (auto& pair : encryption_key_map)
487  encryption_key_map_[pair.first] = std::move(pair.second);
488  return true;
489  }
490 
491  return PushToKeyPool(&encryption_key_map);
492 }
493 
494 bool WidevineKeySource::PushToKeyPool(
495  EncryptionKeyMap* encryption_key_map) {
496  DCHECK(key_pool_);
497  DCHECK(encryption_key_map);
498  auto encryption_key_map_shared = std::make_shared<EncryptionKeyMap>();
499  encryption_key_map_shared->swap(*encryption_key_map);
500  Status status = key_pool_->Push(encryption_key_map_shared, kInfiniteTimeout);
501  if (!status.ok()) {
502  DCHECK_EQ(error::STOPPED, status.error_code());
503  return false;
504  }
505  return true;
506 }
507 
508 } // namespace media
509 } // namespace shaka
static std::unique_ptr< PsshBoxBuilder > ParseFromBox(const uint8_t *data, size_t data_size)
Status GetCryptoPeriodKey(uint32_t crypto_period_index, int32_t crypto_period_duration_in_seconds, const std::string &stream_label, EncryptionKey *key) override
void set_signer(std::unique_ptr< RequestSigner > signer)
void set_key_fetcher(std::unique_ptr< KeyFetcher > key_fetcher)
WidevineKeySource(const std::string &server_url, ProtectionSystem protection_systems, FourCC protection_scheme)
Status GetKey(const std::string &stream_label, EncryptionKey *key) override
Status FetchKeys(EmeInitDataType init_data_type, const std::vector< uint8_t > &init_data) override
All the methods that are virtual are virtual for mocking.
Definition: crypto_flags.cc:66
static bool ParseBoxes(const uint8_t *data, size_t data_size, std::vector< ProtectionSystemSpecificInfo > *pssh_boxes)