Shaka Packager SDK
Loading...
Searching...
No Matches
widevine_key_source.cc
1// Copyright 2014 Google LLC. All rights reserved.
2//
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file or at
5// https://developers.google.com/open-source/licenses/bsd
6
7#include <packager/media/base/widevine_key_source.h>
8
9#include <functional>
10#include <iterator>
11
12#include <absl/base/internal/endian.h>
13#include <absl/flags/flag.h>
14#include <absl/log/check.h>
15#include <absl/strings/escaping.h>
16
17#include <packager/macros/logging.h>
18#include <packager/media/base/http_key_fetcher.h>
19#include <packager/media/base/producer_consumer_queue.h>
20#include <packager/media/base/protection_system_ids.h>
21#include <packager/media/base/protection_system_specific_info.h>
22#include <packager/media/base/proto_json_util.h>
23#include <packager/media/base/pssh_generator_util.h>
24#include <packager/media/base/rcheck.h>
25#include <packager/media/base/request_signer.h>
26#include <packager/media/base/widevine_common_encryption.pb.h>
27
28ABSL_FLAG(std::string,
29 video_feature,
30 "",
31 "Specify the optional video feature, e.g. HDR.");
32
33namespace shaka {
34namespace media {
35namespace {
36
// Readable name for the |enable_key_rotation| argument of FetchKeysInternal().
constexpr bool kEnableKeyRotation = true;

// Retry policy for transient server errors: the maximum number of request
// attempts, and the delay before the second attempt (the delay is doubled
// after every further failed attempt).
constexpr int kNumTransientErrorRetries = 5;
constexpr int kFirstRetryDelayMilliseconds = 1000;

// Default crypto period count, i.e. how many crypto periods worth of keys are
// requested by each key-rotation-enabled request.
constexpr int kDefaultCryptoPeriodCount = 10;
constexpr int kGetKeyTimeoutInSeconds = 5 * 60;  // 5 minutes.
constexpr int kKeyFetchTimeoutInSeconds = 60;    // 1 minute.
49
50CommonEncryptionRequest::ProtectionScheme ToCommonEncryptionProtectionScheme(
51 FourCC protection_scheme) {
52 switch (protection_scheme) {
53 case FOURCC_cenc:
54 return CommonEncryptionRequest::CENC;
55 case FOURCC_cbcs:
56 case kAppleSampleAesProtectionScheme:
57 // Treat sample aes as a variant of cbcs.
58 return CommonEncryptionRequest::CBCS;
59 case FOURCC_cbc1:
60 return CommonEncryptionRequest::CBC1;
61 case FOURCC_cens:
62 return CommonEncryptionRequest::CENS;
63 default:
64 LOG(WARNING) << "Ignore unrecognized protection scheme "
65 << FourCCToString(protection_scheme);
66 return CommonEncryptionRequest::UNSPECIFIED;
67 }
68}
69
70ProtectionSystemSpecificInfo ProtectionSystemInfoFromPsshProto(
71 const CommonEncryptionResponse::Track::Pssh& pssh_proto) {
72 PsshBoxBuilder pssh_builder;
73 pssh_builder.set_system_id(kWidevineSystemId, std::size(kWidevineSystemId));
74
75 if (pssh_proto.has_boxes()) {
76 return {pssh_builder.system_id(),
77 std::vector<uint8_t>(pssh_proto.boxes().begin(),
78 pssh_proto.boxes().end())};
79 } else {
80 pssh_builder.set_pssh_box_version(0);
81 const std::vector<uint8_t> pssh_data(pssh_proto.data().begin(),
82 pssh_proto.data().end());
83 pssh_builder.set_pssh_data(pssh_data);
84 return {pssh_builder.system_id(), pssh_builder.CreateBox()};
85 }
86}
87
88} // namespace
89
90WidevineKeySource::WidevineKeySource(const std::string& server_url,
91 ProtectionSystem protection_systems,
92 FourCC protection_scheme)
93 // Widevine PSSH is fetched from Widevine license server.
94 : generate_widevine_protection_system_(
95 // Generate Widevine protection system if there are no other
96 // protection system specified.
97 protection_systems == ProtectionSystem::kNone ||
98 has_flag(protection_systems, ProtectionSystem::kWidevine)),
99 key_fetcher_(new HttpKeyFetcher(kKeyFetchTimeoutInSeconds)),
100 server_url_(server_url),
101 crypto_period_count_(kDefaultCryptoPeriodCount),
102 protection_scheme_(protection_scheme),
103 key_production_thread_(
104 std::bind(&WidevineKeySource::FetchKeysTask, this)) {}
105
106WidevineKeySource::~WidevineKeySource() {
107 if (key_pool_)
108 key_pool_->Stop();
109 // Signal the production thread to start key production if it is not
110 // signaled yet so the thread can be joined.
111 if (!start_key_production_.HasBeenNotified())
112 start_key_production_.Notify();
113 key_production_thread_.join();
114}
115
116Status WidevineKeySource::FetchKeys(const std::vector<uint8_t>& content_id,
117 const std::string& policy) {
118 absl::MutexLock scoped_lock(&mutex_);
119 common_encryption_request_.reset(new CommonEncryptionRequest);
120 common_encryption_request_->set_content_id(content_id.data(),
121 content_id.size());
122 common_encryption_request_->set_policy(policy);
123 common_encryption_request_->set_protection_scheme(
124 ToCommonEncryptionProtectionScheme(protection_scheme_));
125 if (enable_entitlement_license_)
126 common_encryption_request_->set_enable_entitlement_license(true);
127
128 return FetchKeysInternal(!kEnableKeyRotation, 0, false);
129}
130
131Status WidevineKeySource::FetchKeys(EmeInitDataType init_data_type,
132 const std::vector<uint8_t>& init_data) {
133 std::vector<uint8_t> pssh_data;
134 uint32_t asset_id = 0;
135 switch (init_data_type) {
136 case EmeInitDataType::CENC: {
137 const std::vector<uint8_t> widevine_system_id(
138 kWidevineSystemId, kWidevineSystemId + std::size(kWidevineSystemId));
139 std::vector<ProtectionSystemSpecificInfo> protection_systems_info;
141 init_data.data(), init_data.size(), &protection_systems_info)) {
142 return Status(error::PARSER_FAILURE, "Error parsing the PSSH boxes.");
143 }
144 for (const auto& info : protection_systems_info) {
145 std::unique_ptr<PsshBoxBuilder> pssh_builder =
146 PsshBoxBuilder::ParseFromBox(info.psshs.data(), info.psshs.size());
147 if (!pssh_builder)
148 return Status(error::PARSER_FAILURE, "Error parsing the PSSH box.");
149 // Use Widevine PSSH if available otherwise construct a Widevine PSSH
150 // from the first available key ids.
151 if (info.system_id == widevine_system_id) {
152 pssh_data = pssh_builder->pssh_data();
153 break;
154 } else if (pssh_data.empty() && !pssh_builder->key_ids().empty()) {
155 pssh_data =
156 GenerateWidevinePsshDataFromKeyIds(pssh_builder->key_ids());
157 // Continue to see if there is any Widevine PSSH. The KeyId generated
158 // PSSH is only used if a Widevine PSSH could not be found.
159 continue;
160 }
161 }
162 if (pssh_data.empty())
163 return Status(error::INVALID_ARGUMENT, "No supported PSSHs found.");
164 break;
165 }
166 case EmeInitDataType::WEBM: {
167 pssh_data = GenerateWidevinePsshDataFromKeyIds({init_data});
168 break;
169 }
170 case EmeInitDataType::WIDEVINE_CLASSIC:
171 if (init_data.size() < sizeof(asset_id))
172 return Status(error::INVALID_ARGUMENT, "Invalid asset id.");
173 asset_id = absl::big_endian::Load32(init_data.data());
174 break;
175 default:
176 LOG(ERROR) << "Init data type " << static_cast<int>(init_data_type)
177 << " not supported.";
178 return Status(error::INVALID_ARGUMENT, "Unsupported init data type.");
179 }
180 const bool widevine_classic =
181 init_data_type == EmeInitDataType::WIDEVINE_CLASSIC;
182 absl::MutexLock scoped_lock(&mutex_);
183 common_encryption_request_.reset(new CommonEncryptionRequest);
184 if (widevine_classic) {
185 common_encryption_request_->set_asset_id(asset_id);
186 } else {
187 common_encryption_request_->set_pssh_data(pssh_data.data(),
188 pssh_data.size());
189 }
190 return FetchKeysInternal(!kEnableKeyRotation, 0, widevine_classic);
191}
192
193Status WidevineKeySource::GetKey(const std::string& stream_label,
194 EncryptionKey* key) {
195 DCHECK(key);
196 if (encryption_key_map_.find(stream_label) == encryption_key_map_.end()) {
197 return Status(error::INTERNAL_ERROR,
198 "Cannot find key for '" + stream_label + "'.");
199 }
200 *key = *encryption_key_map_[stream_label];
201 return Status::OK;
202}
203
204Status WidevineKeySource::GetKey(const std::vector<uint8_t>& key_id,
205 EncryptionKey* key) {
206 DCHECK(key);
207 for (const auto& pair : encryption_key_map_) {
208 if (pair.second->key_id == key_id) {
209 *key = *pair.second;
210 return Status::OK;
211 }
212 }
213 return Status(error::INTERNAL_ERROR,
214 "Cannot find key with specified key ID");
215}
216
218 uint32_t crypto_period_index,
219 int32_t crypto_period_duration_in_seconds,
220 const std::string& stream_label,
221 EncryptionKey* key) {
222 // TODO(kqyang): This is not elegant. Consider refactoring later.
223 {
224 absl::MutexLock scoped_lock(&mutex_);
225 if (!key_production_started_) {
226 crypto_period_duration_in_seconds_ = crypto_period_duration_in_seconds;
227 // Another client may have a slightly smaller starting crypto period
228 // index. Set the initial value to account for that.
229 first_crypto_period_index_ =
230 crypto_period_index ? crypto_period_index - 1 : 0;
231 DCHECK(!key_pool_);
232 const size_t queue_size = crypto_period_count_ * 10;
233 key_pool_.reset(
234 new EncryptionKeyQueue(queue_size, first_crypto_period_index_));
235 start_key_production_.Notify();
236 key_production_started_ = true;
237 } else if (crypto_period_duration_in_seconds_ !=
238 crypto_period_duration_in_seconds) {
239 return Status(error::INVALID_ARGUMENT,
240 "Crypto period duration should not change.");
241 }
242 }
243 return GetKeyInternal(crypto_period_index, stream_label, key);
244}
245
246void WidevineKeySource::set_signer(std::unique_ptr<RequestSigner> signer) {
247 signer_ = std::move(signer);
248}
249
251 std::unique_ptr<KeyFetcher> key_fetcher) {
252 key_fetcher_ = std::move(key_fetcher);
253}
254
255Status WidevineKeySource::GetKeyInternal(uint32_t crypto_period_index,
256 const std::string& stream_label,
257 EncryptionKey* key) {
258 DCHECK(key_pool_);
259 DCHECK(key);
260
261 std::shared_ptr<EncryptionKeyMap> encryption_key_map;
262 Status status = key_pool_->Peek(crypto_period_index, &encryption_key_map,
263 kGetKeyTimeoutInSeconds * 1000);
264 if (!status.ok()) {
265 if (status.error_code() == error::STOPPED) {
266 CHECK(!common_encryption_request_status_.ok());
267 return common_encryption_request_status_;
268 }
269 return status;
270 }
271
272 if (encryption_key_map->find(stream_label) == encryption_key_map->end()) {
273 return Status(error::INTERNAL_ERROR,
274 "Cannot find key for '" + stream_label + "'.");
275 }
276 *key = *encryption_key_map->at(stream_label);
277 return Status::OK;
278}
279
280void WidevineKeySource::FetchKeysTask() {
281 // Wait until key production is signaled.
282 start_key_production_.WaitForNotification();
283 if (!key_pool_ || key_pool_->Stopped())
284 return;
285
286 Status status = FetchKeysInternal(kEnableKeyRotation,
287 first_crypto_period_index_,
288 false);
289 while (status.ok()) {
290 first_crypto_period_index_ += crypto_period_count_;
291 status = FetchKeysInternal(kEnableKeyRotation,
292 first_crypto_period_index_,
293 false);
294 }
295 common_encryption_request_status_ = status;
296 key_pool_->Stop();
297}
298
299Status WidevineKeySource::FetchKeysInternal(bool enable_key_rotation,
300 uint32_t first_crypto_period_index,
301 bool widevine_classic) {
302 CommonEncryptionRequest request;
303 FillRequest(enable_key_rotation, first_crypto_period_index, &request);
304
305 std::string message;
306 Status status = GenerateKeyMessage(request, &message);
307 if (!status.ok())
308 return status;
309 VLOG(1) << "Message: " << message;
310
311 std::string raw_response;
312 int64_t sleep_duration = kFirstRetryDelayMilliseconds;
313
314 // Perform client side retries if seeing server transient error to workaround
315 // server limitation.
316 for (int i = 0; i < kNumTransientErrorRetries; ++i) {
317 status = key_fetcher_->FetchKeys(server_url_, message, &raw_response);
318 if (status.ok()) {
319 VLOG(1) << "Retry [" << i << "] Response:" << raw_response;
320
321 bool transient_error = false;
322 if (ExtractEncryptionKey(enable_key_rotation, widevine_classic,
323 raw_response, &transient_error))
324 return Status::OK;
325
326 if (!transient_error) {
327 return Status(
328 error::SERVER_ERROR,
329 "Failed to extract encryption key from '" + raw_response + "'.");
330 }
331 } else if (status.error_code() != error::TIME_OUT) {
332 return status;
333 }
334
335 // Exponential backoff.
336 if (i != kNumTransientErrorRetries - 1) {
337 std::this_thread::sleep_for(std::chrono::milliseconds(sleep_duration));
338 sleep_duration *= 2;
339 }
340 }
341 return Status(error::SERVER_ERROR,
342 "Failed to recover from server internal error.");
343}
344
345void WidevineKeySource::FillRequest(bool enable_key_rotation,
346 uint32_t first_crypto_period_index,
347 CommonEncryptionRequest* request) {
348 DCHECK(common_encryption_request_);
349 DCHECK(request);
350 *request = *common_encryption_request_;
351
352 request->add_tracks()->set_type("SD");
353 request->add_tracks()->set_type("HD");
354 request->add_tracks()->set_type("UHD1");
355 request->add_tracks()->set_type("UHD2");
356 request->add_tracks()->set_type("AUDIO");
357
358 request->add_drm_types(ModularDrmType::WIDEVINE);
359
360 if (enable_key_rotation) {
361 request->set_first_crypto_period_index(first_crypto_period_index);
362 request->set_crypto_period_count(crypto_period_count_);
363 request->set_crypto_period_seconds(crypto_period_duration_in_seconds_);
364 }
365
366 if (!group_id_.empty())
367 request->set_group_id(group_id_.data(), group_id_.size());
368
369 std::string video_feature = absl::GetFlag(FLAGS_video_feature);
370 if (!video_feature.empty())
371 request->set_video_feature(video_feature);
372}
373
374Status WidevineKeySource::GenerateKeyMessage(
375 const CommonEncryptionRequest& request,
376 std::string* message) {
377 DCHECK(message);
378
379 SignedModularDrmRequest signed_request;
380 signed_request.set_request(MessageToJsonString(request));
381
382 // Sign the request.
383 if (signer_) {
384 std::string signature;
385 if (!signer_->GenerateSignature(signed_request.request(), &signature))
386 return Status(error::INTERNAL_ERROR, "Signature generation failed.");
387
388 signed_request.set_signature(signature);
389 signed_request.set_signer(signer_->signer_name());
390 }
391
392 *message = MessageToJsonString(signed_request);
393 return Status::OK;
394}
395
396bool WidevineKeySource::ExtractEncryptionKey(
397 bool enable_key_rotation,
398 bool widevine_classic,
399 const std::string& response,
400 bool* transient_error) {
401 DCHECK(transient_error);
402 *transient_error = false;
403
404 SignedModularDrmResponse signed_response_proto;
405 if (!JsonStringToMessage(response, &signed_response_proto)) {
406 LOG(ERROR) << "Failed to convert JSON to proto: " << response;
407 return false;
408 }
409
410 CommonEncryptionResponse response_proto;
411 if (!JsonStringToMessage(signed_response_proto.response(), &response_proto)) {
412 LOG(ERROR) << "Failed to convert JSON to proto: "
413 << signed_response_proto.response();
414 return false;
415 }
416
417 if (response_proto.status() != CommonEncryptionResponse::OK) {
418 LOG(ERROR) << "Received non-OK license response: " << response;
419 // Server may return INTERNAL_ERROR intermittently, which is a transient
420 // error and the next client request may succeed without problem.
421 *transient_error =
422 (response_proto.status() == CommonEncryptionResponse::INTERNAL_ERROR);
423 return false;
424 }
425
426 RCHECK(enable_key_rotation
427 ? response_proto.tracks_size() >= crypto_period_count_
428 : response_proto.tracks_size() >= 1);
429
430 uint32_t current_crypto_period_index = first_crypto_period_index_;
431
432 std::vector<std::vector<uint8_t>> key_ids;
433 for (const auto& track : response_proto.tracks()) {
434 if (!widevine_classic)
435 key_ids.emplace_back(track.key_id().begin(), track.key_id().end());
436 }
437
438 EncryptionKeyMap encryption_key_map;
439 for (const auto& track : response_proto.tracks()) {
440 VLOG(2) << "track " << track.ShortDebugString();
441
442 if (enable_key_rotation) {
443 if (track.crypto_period_index() != current_crypto_period_index) {
444 if (track.crypto_period_index() != current_crypto_period_index + 1) {
445 LOG(ERROR) << "Expecting crypto period index "
446 << current_crypto_period_index << " or "
447 << current_crypto_period_index + 1 << "; Seen "
448 << track.crypto_period_index();
449 return false;
450 }
451 if (!PushToKeyPool(&encryption_key_map))
452 return false;
453 ++current_crypto_period_index;
454 }
455 }
456
457 const std::string& stream_label = track.type();
458 RCHECK(encryption_key_map.find(stream_label) == encryption_key_map.end());
459
460 std::unique_ptr<EncryptionKey> encryption_key(new EncryptionKey());
461 encryption_key->key.assign(track.key().begin(), track.key().end());
462
463 // Get key ID and PSSH data for CENC content only.
464 if (!widevine_classic) {
465 encryption_key->key_id.assign(track.key_id().begin(),
466 track.key_id().end());
467 encryption_key->iv.assign(track.iv().begin(), track.iv().end());
468 encryption_key->key_ids = key_ids;
469
470 if (generate_widevine_protection_system_) {
471 if (track.pssh_size() != 1) {
472 LOG(ERROR) << "Expecting one and only one pssh, seeing "
473 << track.pssh_size();
474 return false;
475 }
476 encryption_key->key_system_info.push_back(
477 ProtectionSystemInfoFromPsshProto(track.pssh(0)));
478 }
479 }
480 encryption_key_map[stream_label] = std::move(encryption_key);
481 }
482
483 DCHECK(!encryption_key_map.empty());
484 if (!enable_key_rotation) {
485 // Merge with previously requested keys.
486 for (auto& pair : encryption_key_map)
487 encryption_key_map_[pair.first] = std::move(pair.second);
488 return true;
489 }
490
491 return PushToKeyPool(&encryption_key_map);
492}
493
494bool WidevineKeySource::PushToKeyPool(
495 EncryptionKeyMap* encryption_key_map) {
496 DCHECK(key_pool_);
497 DCHECK(encryption_key_map);
498 auto encryption_key_map_shared = std::make_shared<EncryptionKeyMap>();
499 encryption_key_map_shared->swap(*encryption_key_map);
500 Status status = key_pool_->Push(encryption_key_map_shared, kInfiniteTimeout);
501 if (!status.ok()) {
502 DCHECK_EQ(error::STOPPED, status.error_code());
503 return false;
504 }
505 return true;
506}
507
508} // namespace media
509} // namespace shaka
static std::unique_ptr< PsshBoxBuilder > ParseFromBox(const uint8_t *data, size_t data_size)
Status GetCryptoPeriodKey(uint32_t crypto_period_index, int32_t crypto_period_duration_in_seconds, const std::string &stream_label, EncryptionKey *key) override
void set_signer(std::unique_ptr< RequestSigner > signer)
void set_key_fetcher(std::unique_ptr< KeyFetcher > key_fetcher)
WidevineKeySource(const std::string &server_url, ProtectionSystem protection_systems, FourCC protection_scheme)
Status GetKey(const std::string &stream_label, EncryptionKey *key) override
Status FetchKeys(EmeInitDataType init_data_type, const std::vector< uint8_t > &init_data) override
All the methods that are virtual are virtual for mocking.
static bool ParseBoxes(const uint8_t *data, size_t data_size, std::vector< ProtectionSystemSpecificInfo > *pssh_boxes)