Shaka Packager SDK
Loading...
Searching...
No Matches
widevine_key_source.cc
1// Copyright 2014 Google LLC. All rights reserved.
2//
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file or at
5// https://developers.google.com/open-source/licenses/bsd
6
7#include <packager/media/base/widevine_key_source.h>
8
9#include <functional>
10#include <iterator>
11
12#include <absl/base/internal/endian.h>
13#include <absl/flags/flag.h>
14#include <absl/log/check.h>
15#include <absl/strings/escaping.h>
16
17#include <packager/macros/logging.h>
18#include <packager/media/base/http_key_fetcher.h>
19#include <packager/media/base/producer_consumer_queue.h>
20#include <packager/media/base/protection_system_ids.h>
21#include <packager/media/base/protection_system_specific_info.h>
22#include <packager/media/base/proto_json_util.h>
23#include <packager/media/base/pssh_generator_util.h>
24#include <packager/media/base/rcheck.h>
25#include <packager/media/base/request_signer.h>
26#include <packager/media/base/widevine_common_encryption.pb.h>
27
28ABSL_FLAG(std::string,
29 video_feature,
30 "",
31 "Specify the optional video feature, e.g. HDR.");
32
33namespace shaka {
34namespace media {
35namespace {
36
// Passed as the `enable_key_rotation` argument of FetchKeysInternal(); the
// negation (!kEnableKeyRotation) requests a single, non-rotating key set.
constexpr bool kEnableKeyRotation = true;

// Maximum number of fetch attempts FetchKeysInternal() makes when the server
// times out or reports a transient (INTERNAL_ERROR) failure.
constexpr int kNumTransientErrorRetries = 5;
// Delay before the first retry; doubled after every further attempt.
constexpr int kFirstRetryDelayMilliseconds = 1000;

// Default crypto period count: how many crypto periods' worth of keys each
// key-rotation request asks for.
constexpr int kDefaultCryptoPeriodCount = 10;
// How long GetKeyInternal() waits for a crypto period's keys to show up.
constexpr int kGetKeyTimeoutInSeconds = 5 * 60;  // 5 minutes.
// Per-request timeout given to the default HttpKeyFetcher.
constexpr int kKeyFetchTimeoutInSeconds = 60;  // 1 minute.
49
50CommonEncryptionRequest::ProtectionScheme ToCommonEncryptionProtectionScheme(
51 FourCC protection_scheme) {
52 switch (protection_scheme) {
53 case FOURCC_cenc:
54 return CommonEncryptionRequest::CENC;
55 case FOURCC_cbcs:
56 case kAppleSampleAesProtectionScheme:
57 // Treat sample aes as a variant of cbcs.
58 return CommonEncryptionRequest::CBCS;
59 case FOURCC_cbc1:
60 return CommonEncryptionRequest::CBC1;
61 case FOURCC_cens:
62 return CommonEncryptionRequest::CENS;
63 default:
64 LOG(WARNING) << "Ignore unrecognized protection scheme "
65 << FourCCToString(protection_scheme);
66 return CommonEncryptionRequest::UNSPECIFIED;
67 }
68}
69
70ProtectionSystemSpecificInfo ProtectionSystemInfoFromPsshProto(
71 const CommonEncryptionResponse::Track::Pssh& pssh_proto) {
72 PsshBoxBuilder pssh_builder;
73 pssh_builder.set_system_id(kWidevineSystemId, std::size(kWidevineSystemId));
74
75 if (pssh_proto.has_boxes()) {
76 return {pssh_builder.system_id(),
77 std::vector<uint8_t>(pssh_proto.boxes().begin(),
78 pssh_proto.boxes().end())};
79 } else {
80 pssh_builder.set_pssh_box_version(0);
81 const std::vector<uint8_t> pssh_data(pssh_proto.data().begin(),
82 pssh_proto.data().end());
83 pssh_builder.set_pssh_data(pssh_data);
84 return {pssh_builder.system_id(), pssh_builder.CreateBox()};
85 }
86}
87
88} // namespace
89
90WidevineKeySource::WidevineKeySource(const std::string& server_url,
91 ProtectionSystem protection_systems,
92 FourCC protection_scheme)
93 // Widevine PSSH is fetched from Widevine license server.
94 : generate_widevine_protection_system_(
95 // Generate Widevine protection system if there are no other
96 // protection system specified.
97 protection_systems == ProtectionSystem::kNone ||
98 has_flag(protection_systems, ProtectionSystem::kWidevine)),
99 key_fetcher_(new HttpKeyFetcher(kKeyFetchTimeoutInSeconds)),
100 server_url_(server_url),
101 crypto_period_count_(kDefaultCryptoPeriodCount),
102 protection_scheme_(protection_scheme),
103 key_production_thread_(
104 std::bind(&WidevineKeySource::FetchKeysTask, this)) {}
105
106WidevineKeySource::~WidevineKeySource() {
107 if (key_pool_)
108 key_pool_->Stop();
109 // Signal the production thread to start key production if it is not
110 // signaled yet so the thread can be joined.
111 if (!start_key_production_.HasBeenNotified())
112 start_key_production_.Notify();
113 key_production_thread_.join();
114}
115
116Status WidevineKeySource::FetchKeys(const std::vector<uint8_t>& content_id,
117 const std::string& policy) {
118 absl::MutexLock scoped_lock(mutex_);
119 common_encryption_request_.reset(new CommonEncryptionRequest);
120 common_encryption_request_->set_content_id(content_id.data(),
121 content_id.size());
122 common_encryption_request_->set_policy(policy);
123 common_encryption_request_->set_protection_scheme(
124 ToCommonEncryptionProtectionScheme(protection_scheme_));
125 if (enable_entitlement_license_)
126 common_encryption_request_->set_enable_entitlement_license(true);
127
128 return FetchKeysInternal(!kEnableKeyRotation, 0, false);
129}
130
131Status WidevineKeySource::FetchKeys(EmeInitDataType init_data_type,
132 const std::vector<uint8_t>& init_data) {
133 std::vector<uint8_t> pssh_data;
134 uint32_t asset_id = 0;
135 switch (init_data_type) {
136 case EmeInitDataType::CENC: {
137 const std::vector<uint8_t> widevine_system_id(
138 kWidevineSystemId, kWidevineSystemId + std::size(kWidevineSystemId));
139 std::vector<ProtectionSystemSpecificInfo> protection_systems_info;
141 init_data.data(), init_data.size(), &protection_systems_info)) {
142 return Status(error::PARSER_FAILURE, "Error parsing the PSSH boxes.");
143 }
144 for (const auto& info : protection_systems_info) {
145 std::unique_ptr<PsshBoxBuilder> pssh_builder =
146 PsshBoxBuilder::ParseFromBox(info.psshs.data(), info.psshs.size());
147 if (!pssh_builder)
148 return Status(error::PARSER_FAILURE, "Error parsing the PSSH box.");
149 // Use Widevine PSSH if available otherwise construct a Widevine PSSH
150 // from the first available key ids.
151 if (info.system_id == widevine_system_id) {
152 pssh_data = pssh_builder->pssh_data();
153 break;
154 } else if (pssh_data.empty() && !pssh_builder->key_ids().empty()) {
155 pssh_data =
156 GenerateWidevinePsshDataFromKeyIds(pssh_builder->key_ids());
157 // Continue to see if there is any Widevine PSSH. The KeyId generated
158 // PSSH is only used if a Widevine PSSH could not be found.
159 continue;
160 }
161 }
162 if (pssh_data.empty())
163 return Status(error::INVALID_ARGUMENT, "No supported PSSHs found.");
164 break;
165 }
166 case EmeInitDataType::WEBM: {
167 pssh_data = GenerateWidevinePsshDataFromKeyIds({init_data});
168 break;
169 }
170 case EmeInitDataType::WIDEVINE_CLASSIC:
171 if (init_data.size() < sizeof(asset_id))
172 return Status(error::INVALID_ARGUMENT, "Invalid asset id.");
173 asset_id = absl::big_endian::Load32(init_data.data());
174 break;
175 default:
176 LOG(ERROR) << "Init data type " << static_cast<int>(init_data_type)
177 << " not supported.";
178 return Status(error::INVALID_ARGUMENT, "Unsupported init data type.");
179 }
180 const bool widevine_classic =
181 init_data_type == EmeInitDataType::WIDEVINE_CLASSIC;
182 absl::MutexLock scoped_lock(mutex_);
183 common_encryption_request_.reset(new CommonEncryptionRequest);
184 if (widevine_classic) {
185 common_encryption_request_->set_asset_id(asset_id);
186 } else {
187 common_encryption_request_->set_pssh_data(pssh_data.data(),
188 pssh_data.size());
189 }
190 return FetchKeysInternal(!kEnableKeyRotation, 0, widevine_classic);
191}
192
193Status WidevineKeySource::GetKey(const std::string& stream_label,
194 EncryptionKey* key) {
195 DCHECK(key);
196 if (encryption_key_map_.find(stream_label) == encryption_key_map_.end()) {
197 return Status(error::INTERNAL_ERROR,
198 "Cannot find key for '" + stream_label + "'.");
199 }
200 *key = *encryption_key_map_[stream_label];
201 return Status::OK;
202}
203
204Status WidevineKeySource::GetKey(const std::vector<uint8_t>& key_id,
205 EncryptionKey* key) {
206 DCHECK(key);
207 for (const auto& pair : encryption_key_map_) {
208 if (pair.second->key_id == key_id) {
209 *key = *pair.second;
210 return Status::OK;
211 }
212 }
213 return Status(error::INTERNAL_ERROR, "Cannot find key with specified key ID");
214}
215
217 uint32_t crypto_period_index,
218 int32_t crypto_period_duration_in_seconds,
219 const std::string& stream_label,
220 EncryptionKey* key) {
221 // TODO(kqyang): This is not elegant. Consider refactoring later.
222 {
223 absl::MutexLock scoped_lock(mutex_);
224 if (!key_production_started_) {
225 crypto_period_duration_in_seconds_ = crypto_period_duration_in_seconds;
226 // Another client may have a slightly smaller starting crypto period
227 // index. Set the initial value to account for that.
228 first_crypto_period_index_ =
229 crypto_period_index ? crypto_period_index - 1 : 0;
230 DCHECK(!key_pool_);
231 const size_t queue_size = crypto_period_count_ * 10;
232 key_pool_.reset(
233 new EncryptionKeyQueue(queue_size, first_crypto_period_index_));
234 start_key_production_.Notify();
235 key_production_started_ = true;
236 } else if (crypto_period_duration_in_seconds_ !=
237 crypto_period_duration_in_seconds) {
238 return Status(error::INVALID_ARGUMENT,
239 "Crypto period duration should not change.");
240 }
241 }
242 return GetKeyInternal(crypto_period_index, stream_label, key);
243}
244
245void WidevineKeySource::set_signer(std::unique_ptr<RequestSigner> signer) {
246 signer_ = std::move(signer);
247}
248
250 std::unique_ptr<KeyFetcher> key_fetcher) {
251 key_fetcher_ = std::move(key_fetcher);
252}
253
254Status WidevineKeySource::GetKeyInternal(uint32_t crypto_period_index,
255 const std::string& stream_label,
256 EncryptionKey* key) {
257 DCHECK(key_pool_);
258 DCHECK(key);
259
260 std::shared_ptr<EncryptionKeyMap> encryption_key_map;
261 Status status = key_pool_->Peek(crypto_period_index, &encryption_key_map,
262 kGetKeyTimeoutInSeconds * 1000);
263 if (!status.ok()) {
264 if (status.error_code() == error::STOPPED) {
265 CHECK(!common_encryption_request_status_.ok());
266 return common_encryption_request_status_;
267 }
268 return status;
269 }
270
271 if (encryption_key_map->find(stream_label) == encryption_key_map->end()) {
272 return Status(error::INTERNAL_ERROR,
273 "Cannot find key for '" + stream_label + "'.");
274 }
275 *key = *encryption_key_map->at(stream_label);
276 return Status::OK;
277}
278
279void WidevineKeySource::FetchKeysTask() {
280 // Wait until key production is signaled.
281 start_key_production_.WaitForNotification();
282 if (!key_pool_ || key_pool_->Stopped())
283 return;
284
285 Status status =
286 FetchKeysInternal(kEnableKeyRotation, first_crypto_period_index_, false);
287 while (status.ok()) {
288 first_crypto_period_index_ += crypto_period_count_;
289 status = FetchKeysInternal(kEnableKeyRotation, first_crypto_period_index_,
290 false);
291 }
292 common_encryption_request_status_ = status;
293 key_pool_->Stop();
294}
295
296Status WidevineKeySource::FetchKeysInternal(bool enable_key_rotation,
297 uint32_t first_crypto_period_index,
298 bool widevine_classic) {
299 CommonEncryptionRequest request;
300 FillRequest(enable_key_rotation, first_crypto_period_index, &request);
301
302 std::string message;
303 Status status = GenerateKeyMessage(request, &message);
304 if (!status.ok())
305 return status;
306 VLOG(1) << "Message: " << message;
307
308 std::string raw_response;
309 int64_t sleep_duration = kFirstRetryDelayMilliseconds;
310
311 // Perform client side retries if seeing server transient error to workaround
312 // server limitation.
313 for (int i = 0; i < kNumTransientErrorRetries; ++i) {
314 status = key_fetcher_->FetchKeys(server_url_, message, &raw_response);
315 if (status.ok()) {
316 VLOG(1) << "Retry [" << i << "] Response:" << raw_response;
317
318 bool transient_error = false;
319 if (ExtractEncryptionKey(enable_key_rotation, widevine_classic,
320 raw_response, &transient_error))
321 return Status::OK;
322
323 if (!transient_error) {
324 return Status(
325 error::SERVER_ERROR,
326 "Failed to extract encryption key from '" + raw_response + "'.");
327 }
328 } else if (status.error_code() != error::TIME_OUT) {
329 return status;
330 }
331
332 // Exponential backoff.
333 if (i != kNumTransientErrorRetries - 1) {
334 std::this_thread::sleep_for(std::chrono::milliseconds(sleep_duration));
335 sleep_duration *= 2;
336 }
337 }
338 return Status(error::SERVER_ERROR,
339 "Failed to recover from server internal error.");
340}
341
342void WidevineKeySource::FillRequest(bool enable_key_rotation,
343 uint32_t first_crypto_period_index,
344 CommonEncryptionRequest* request) {
345 DCHECK(common_encryption_request_);
346 DCHECK(request);
347 *request = *common_encryption_request_;
348
349 request->add_tracks()->set_type("SD");
350 request->add_tracks()->set_type("HD");
351 request->add_tracks()->set_type("UHD1");
352 request->add_tracks()->set_type("UHD2");
353 request->add_tracks()->set_type("AUDIO");
354
355 request->add_drm_types(ModularDrmType::WIDEVINE);
356
357 if (enable_key_rotation) {
358 request->set_first_crypto_period_index(first_crypto_period_index);
359 request->set_crypto_period_count(crypto_period_count_);
360 request->set_crypto_period_seconds(crypto_period_duration_in_seconds_);
361 }
362
363 if (!group_id_.empty())
364 request->set_group_id(group_id_.data(), group_id_.size());
365
366 std::string video_feature = absl::GetFlag(FLAGS_video_feature);
367 if (!video_feature.empty())
368 request->set_video_feature(video_feature);
369}
370
371Status WidevineKeySource::GenerateKeyMessage(
372 const CommonEncryptionRequest& request,
373 std::string* message) {
374 DCHECK(message);
375
376 SignedModularDrmRequest signed_request;
377 signed_request.set_request(MessageToJsonString(request));
378
379 // Sign the request.
380 if (signer_) {
381 std::string signature;
382 if (!signer_->GenerateSignature(signed_request.request(), &signature))
383 return Status(error::INTERNAL_ERROR, "Signature generation failed.");
384
385 signed_request.set_signature(signature);
386 signed_request.set_signer(signer_->signer_name());
387 }
388
389 *message = MessageToJsonString(signed_request);
390 return Status::OK;
391}
392
393bool WidevineKeySource::ExtractEncryptionKey(bool enable_key_rotation,
394 bool widevine_classic,
395 const std::string& response,
396 bool* transient_error) {
397 DCHECK(transient_error);
398 *transient_error = false;
399
400 SignedModularDrmResponse signed_response_proto;
401 if (!JsonStringToMessage(response, &signed_response_proto)) {
402 LOG(ERROR) << "Failed to convert JSON to proto: " << response;
403 return false;
404 }
405
406 CommonEncryptionResponse response_proto;
407 if (!JsonStringToMessage(signed_response_proto.response(), &response_proto)) {
408 LOG(ERROR) << "Failed to convert JSON to proto: "
409 << signed_response_proto.response();
410 return false;
411 }
412
413 if (response_proto.status() != CommonEncryptionResponse::OK) {
414 LOG(ERROR) << "Received non-OK license response: " << response;
415 // Server may return INTERNAL_ERROR intermittently, which is a transient
416 // error and the next client request may succeed without problem.
417 *transient_error =
418 (response_proto.status() == CommonEncryptionResponse::INTERNAL_ERROR);
419 return false;
420 }
421
422 RCHECK(enable_key_rotation
423 ? response_proto.tracks_size() >= crypto_period_count_
424 : response_proto.tracks_size() >= 1);
425
426 uint32_t current_crypto_period_index = first_crypto_period_index_;
427
428 std::vector<std::vector<uint8_t>> key_ids;
429 for (const auto& track : response_proto.tracks()) {
430 if (!widevine_classic)
431 key_ids.emplace_back(track.key_id().begin(), track.key_id().end());
432 }
433
434 EncryptionKeyMap encryption_key_map;
435 for (const auto& track : response_proto.tracks()) {
436 VLOG(2) << "track " << track.ShortDebugString();
437
438 if (enable_key_rotation) {
439 if (track.crypto_period_index() != current_crypto_period_index) {
440 if (track.crypto_period_index() != current_crypto_period_index + 1) {
441 LOG(ERROR) << "Expecting crypto period index "
442 << current_crypto_period_index << " or "
443 << current_crypto_period_index + 1 << "; Seen "
444 << track.crypto_period_index();
445 return false;
446 }
447 if (!PushToKeyPool(&encryption_key_map))
448 return false;
449 ++current_crypto_period_index;
450 }
451 }
452
453 const std::string& stream_label = track.type();
454 RCHECK(encryption_key_map.find(stream_label) == encryption_key_map.end());
455
456 std::unique_ptr<EncryptionKey> encryption_key(new EncryptionKey());
457 encryption_key->key.assign(track.key().begin(), track.key().end());
458
459 // Get key ID and PSSH data for CENC content only.
460 if (!widevine_classic) {
461 encryption_key->key_id.assign(track.key_id().begin(),
462 track.key_id().end());
463 encryption_key->iv.assign(track.iv().begin(), track.iv().end());
464 encryption_key->key_ids = key_ids;
465
466 if (generate_widevine_protection_system_) {
467 if (track.pssh_size() != 1) {
468 LOG(ERROR) << "Expecting one and only one pssh, seeing "
469 << track.pssh_size();
470 return false;
471 }
472 encryption_key->key_system_info.push_back(
473 ProtectionSystemInfoFromPsshProto(track.pssh(0)));
474 }
475 }
476 encryption_key_map[stream_label] = std::move(encryption_key);
477 }
478
479 DCHECK(!encryption_key_map.empty());
480 if (!enable_key_rotation) {
481 // Merge with previously requested keys.
482 for (auto& pair : encryption_key_map)
483 encryption_key_map_[pair.first] = std::move(pair.second);
484 return true;
485 }
486
487 return PushToKeyPool(&encryption_key_map);
488}
489
490bool WidevineKeySource::PushToKeyPool(EncryptionKeyMap* encryption_key_map) {
491 DCHECK(key_pool_);
492 DCHECK(encryption_key_map);
493 auto encryption_key_map_shared = std::make_shared<EncryptionKeyMap>();
494 encryption_key_map_shared->swap(*encryption_key_map);
495 Status status = key_pool_->Push(encryption_key_map_shared, kInfiniteTimeout);
496 if (!status.ok()) {
497 DCHECK_EQ(error::STOPPED, status.error_code());
498 return false;
499 }
500 return true;
501}
502
503} // namespace media
504} // namespace shaka
static std::unique_ptr< PsshBoxBuilder > ParseFromBox(const uint8_t *data, size_t data_size)
Status GetCryptoPeriodKey(uint32_t crypto_period_index, int32_t crypto_period_duration_in_seconds, const std::string &stream_label, EncryptionKey *key) override
void set_signer(std::unique_ptr< RequestSigner > signer)
void set_key_fetcher(std::unique_ptr< KeyFetcher > key_fetcher)
WidevineKeySource(const std::string &server_url, ProtectionSystem protection_systems, FourCC protection_scheme)
Status GetKey(const std::string &stream_label, EncryptionKey *key) override
Status FetchKeys(EmeInitDataType init_data_type, const std::vector< uint8_t > &init_data) override
All the methods that are virtual are virtual for mocking.
static bool ParseBoxes(const uint8_t *data, size_t data_size, std::vector< ProtectionSystemSpecificInfo > *pssh_boxes)