Shaka Packager SDK
media_handler_test_base.h
1 // Copyright 2022 Google LLC. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file or at
5 // https://developers.google.com/open-source/licenses/bsd
6 
7 #ifndef PACKAGER_MEDIA_BASE_MEDIA_HANDLER_TEST_BASE_H_
8 #define PACKAGER_MEDIA_BASE_MEDIA_HANDLER_TEST_BASE_H_
9 
10 #include <absl/strings/escaping.h>
11 #include <absl/strings/numbers.h>
12 #include <gmock/gmock.h>
13 #include <gtest/gtest.h>
14 
15 #include <packager/media/base/media_handler.h>
16 #include <packager/media/base/video_stream_info.h>
17 #include <packager/utils/bytes_to_string_view.h>
18 
19 namespace shaka {
20 namespace media {
21 
22 std::string BoolToString(bool value);
23 std::string ToPrettyString(const std::string& str);
24 
25 bool TryMatchStreamDataType(const StreamDataType& actual,
26  const StreamDataType& expected,
27  ::testing::MatchResultListener* listener);
28 
29 bool TryMatchStreamType(const StreamType& actual,
30  const StreamType& expected,
31  ::testing::MatchResultListener* listener);
32 
33 template <typename T, typename M>
34 bool TryMatch(const T& value,
35  const M& matcher,
36  ::testing::MatchResultListener* listener,
37  const char* value_name) {
38  if (!ExplainMatchResult(matcher, value, listener)) {
39  // Need a space at the start of the string in the case that
40  // it gets combined with another string.
41  *listener << " Mismatch on " << value_name;
42  return false;
43  }
44 
45  return true;
46 }
47 
48 MATCHER_P(IsPsshInfoWithSystemId,
49  system_id,
50  std::string(negation ? "doesn't " : "") + " have system ID " +
51  testing::PrintToString(system_id)) {
52  *result_listener << "which is (" << testing::PrintToString(arg.system_id)
53  << ")";
54  return arg.system_id == system_id;
55 }
56 
57 MATCHER_P4(IsStreamInfo, stream_index, time_scale, encrypted, language, "") {
58  if (!TryMatchStreamDataType(arg->stream_data_type,
59  StreamDataType::kStreamInfo, result_listener)) {
60  return false;
61  }
62 
63  const std::string is_encrypted_string =
64  BoolToString(arg->stream_info->is_encrypted());
65 
66  *result_listener << "which is (" << arg->stream_index << ", "
67  << arg->stream_info->time_scale() << ", "
68  << is_encrypted_string << ", "
69  << arg->stream_info->language() << ")";
70 
71  return TryMatch(arg->stream_index, stream_index, result_listener,
72  "stream_index") &&
73  TryMatch(arg->stream_info->time_scale(), time_scale, result_listener,
74  "time_scale") &&
75  TryMatch(arg->stream_info->is_encrypted(), encrypted, result_listener,
76  "is_encrypted") &&
77  TryMatch(arg->stream_info->language(), language, result_listener,
78  "language");
79 }
80 
81 MATCHER_P3(IsVideoStream, stream_index, trick_play_factor, playback_rate, "") {
82  if (!TryMatchStreamDataType(arg->stream_data_type,
83  StreamDataType::kStreamInfo, result_listener)) {
84  return false;
85  }
86 
87  if (!TryMatchStreamType(arg->stream_info->stream_type(), kStreamVideo,
88  result_listener)) {
89  return false;
90  }
91 
92  const VideoStreamInfo* info =
93  static_cast<const VideoStreamInfo*>(arg->stream_info.get());
94 
95  *result_listener << "which is (" << arg->stream_index << ", "
96  << info->trick_play_factor() << ", " << info->playback_rate()
97  << ")";
98 
99  return TryMatch(arg->stream_index, stream_index, result_listener,
100  "stream_index") &&
101  TryMatch(info->trick_play_factor(), trick_play_factor, result_listener,
102  "trick_play_factor") &&
103  TryMatch(info->playback_rate(), playback_rate, result_listener,
104  "playback_rate");
105 }
106 
107 MATCHER_P5(IsSegmentInfo,
108  stream_index,
109  start_timestamp,
110  duration,
111  subsegment,
112  encrypted,
113  "") {
114  if (!TryMatchStreamDataType(arg->stream_data_type,
115  StreamDataType::kSegmentInfo, result_listener)) {
116  return false;
117  }
118 
119  const std::string is_subsegment_string =
120  BoolToString(arg->segment_info->is_subsegment);
121  const std::string is_encrypted_string =
122  BoolToString(arg->segment_info->is_encrypted);
123 
124  *result_listener << "which is (" << arg->stream_index << ", "
125  << arg->segment_info->start_timestamp << ", "
126  << arg->segment_info->duration << ", "
127  << is_subsegment_string << ", " << is_encrypted_string
128  << ")";
129 
130  return TryMatch(arg->stream_index, stream_index, result_listener,
131  "stream_index") &&
132  TryMatch(arg->segment_info->start_timestamp, start_timestamp,
133  result_listener, "start_timestamp") &&
134  TryMatch(arg->segment_info->duration, duration, result_listener,
135  "duration") &&
136  TryMatch(arg->segment_info->is_subsegment, subsegment, result_listener,
137  "is_subsegment") &&
138  TryMatch(arg->segment_info->is_encrypted, encrypted, result_listener,
139  "is_encrypted");
140 }
141 
142 MATCHER_P6(MatchEncryptionConfig,
143  protection_scheme,
144  crypt_byte_block,
145  skip_byte_block,
146  per_sample_iv_size,
147  constant_iv,
148  key_id,
149  "") {
150  const std::string constant_iv_hex = absl::BytesToHexString(
151  std::string(std::begin(arg.constant_iv), std::end(arg.constant_iv)));
152  const std::string key_id_hex = absl::BytesToHexString(
153  std::string(std::begin(arg.key_id), std::end(arg.key_id)));
154  const std::string protection_scheme_as_string =
155  FourCCToString(arg.protection_scheme);
156  // Convert to integers so that they will print as a number and not a uint8_t
157  // (char).
158  const int crypt_byte_as_int = static_cast<int>(arg.crypt_byte_block);
159  const int skip_byte_as_int = static_cast<int>(arg.skip_byte_block);
160 
161  *result_listener << "which is (" << protection_scheme_as_string << ", "
162  << crypt_byte_as_int << ", " << skip_byte_as_int << ", "
163  << arg.per_sample_iv_size << ", " << constant_iv_hex << ", "
164  << key_id_hex << ")";
165 
166  return TryMatch(arg.protection_scheme, protection_scheme, result_listener,
167  "protection_scheme") &&
168  TryMatch(arg.crypt_byte_block, crypt_byte_block, result_listener,
169  "crypt_byte_block") &&
170  TryMatch(arg.skip_byte_block, skip_byte_block, result_listener,
171  "skip_byte_block") &&
172  TryMatch(arg.per_sample_iv_size, per_sample_iv_size, result_listener,
173  "per_sample_iv_size") &&
174  TryMatch(arg.constant_iv, constant_iv, result_listener,
175  "constant_iv") &&
176  TryMatch(arg.key_id, key_id, result_listener, "key_id");
177 }
178 
179 MATCHER_P5(IsMediaSample,
180  stream_index,
181  timestamp,
182  duration,
183  encrypted,
184  keyframe,
185  "") {
186  if (!TryMatchStreamDataType(arg->stream_data_type,
187  StreamDataType::kMediaSample, result_listener)) {
188  return false;
189  }
190 
191  const std::string is_encrypted_string =
192  BoolToString(arg->media_sample->is_encrypted());
193  const std::string is_key_frame_string =
194  BoolToString(arg->media_sample->is_key_frame());
195 
196  *result_listener << "which is (" << arg->stream_index << ", "
197  << arg->media_sample->dts() << ", "
198  << arg->media_sample->duration() << ", "
199  << is_encrypted_string << ", " << is_key_frame_string << ")";
200 
201  return TryMatch(arg->stream_index, stream_index, result_listener,
202  "stream_index") &&
203  TryMatch(arg->media_sample->dts(), timestamp, result_listener,
204  "dts") &&
205  TryMatch(arg->media_sample->duration(), duration, result_listener,
206  "duration") &&
207  TryMatch(arg->media_sample->is_encrypted(), encrypted, result_listener,
208  "is_encrypted") &&
209  TryMatch(arg->media_sample->is_key_frame(), keyframe, result_listener,
210  "is_key_frame");
211 }
212 
213 MATCHER_P4(IsTextSample, stream_index, id, start_time, end_time, "") {
214  if (!TryMatchStreamDataType(arg->stream_data_type,
215  StreamDataType::kTextSample, result_listener)) {
216  return false;
217  }
218 
219  *result_listener << "which is (" << arg->stream_index << ", "
220  << ToPrettyString(arg->text_sample->id()) << ", "
221  << arg->text_sample->start_time() << ", "
222  << arg->text_sample->EndTime() << ")";
223 
224  return TryMatch(arg->stream_index, stream_index, result_listener,
225  "stream_index") &&
226  TryMatch(arg->text_sample->id(), id, result_listener, "id") &&
227  TryMatch(arg->text_sample->start_time(), start_time, result_listener,
228  "start_time") &&
229  TryMatch(arg->text_sample->EndTime(), end_time, result_listener,
230  "EndTime");
231 }
232 
233 MATCHER_P2(IsCueEvent, stream_index, time_in_seconds, "") {
234  if (!TryMatchStreamDataType(arg->stream_data_type, StreamDataType::kCueEvent,
235  result_listener)) {
236  return false;
237  }
238 
239  *result_listener << "which is (" << arg->stream_index << ", "
240  << arg->cue_event->time_in_seconds << ")";
241 
242  return TryMatch(arg->stream_index, stream_index, result_listener,
243  "stream_index") &&
244  TryMatch(arg->cue_event->time_in_seconds, time_in_seconds,
245  result_listener, "time_in_seconds");
246 }
247 
249  public:
253 
254  private:
255  bool ValidateOutputStreamIndex(size_t index) const override;
256  Status InitializeInternal() override;
257  Status Process(std::unique_ptr<StreamData> stream_data) override;
258 };
259 
261  public:
262  MOCK_METHOD1(OnProcess, void(const StreamData*));
263  MOCK_METHOD1(OnFlush, void(size_t index));
264 
265  private:
266  Status InitializeInternal() override;
267  Status Process(std::unique_ptr<StreamData> stream_data) override;
268  Status OnFlushRequest(size_t index) override;
269 };
270 
272  public:
273  const std::vector<std::unique_ptr<StreamData>>& Cache() const {
274  return stream_data_vector_;
275  }
276 
277  // TODO(vaage) : Remove the use of clear in our tests as it can make flow
278  // of the test harder to understand.
279  void Clear() { stream_data_vector_.clear(); }
280 
281  private:
282  Status InitializeInternal() override;
283  Status Process(std::unique_ptr<StreamData> stream_data) override;
284  Status OnFlushRequest(size_t input_stream_index) override;
285  bool ValidateOutputStreamIndex(size_t stream_index) const override;
286 
287  std::vector<std::unique_ptr<StreamData>> stream_data_vector_;
288 };
289 
290 class MediaHandlerTestBase : public ::testing::Test {
291  public:
292  MediaHandlerTestBase() = default;
293 
294  protected:
295  bool IsVideoCodec(Codec codec) const;
296 
297  std::unique_ptr<StreamInfo> GetVideoStreamInfo(int32_t time_scale) const;
298 
299  std::unique_ptr<StreamInfo> GetVideoStreamInfo(int32_t time_scale,
300  uint32_t width,
301  uint32_t height) const;
302 
303  std::unique_ptr<StreamInfo> GetVideoStreamInfo(int32_t time_scale,
304  Codec codec) const;
305 
306  std::unique_ptr<StreamInfo> GetVideoStreamInfo(int32_t time_scale,
307  Codec codec,
308  uint32_t width,
309  uint32_t height) const;
310 
311  std::unique_ptr<StreamInfo> GetAudioStreamInfo(int32_t time_scale) const;
312 
313  std::unique_ptr<StreamInfo> GetAudioStreamInfo(int32_t time_scale,
314  Codec codec) const;
315 
316  std::shared_ptr<MediaSample> GetMediaSample(int64_t timestamp,
317  int64_t duration,
318  bool is_keyframe) const;
319 
320  std::shared_ptr<MediaSample> GetMediaSample(int64_t timestamp,
321  int64_t duration,
322  bool is_keyframe,
323  const uint8_t* data,
324  size_t data_length) const;
325 
326  std::unique_ptr<SegmentInfo> GetSegmentInfo(int64_t start_timestamp,
327  int64_t duration,
328  bool is_subsegment,
329  int64_t segment_number) const;
330 
331  std::unique_ptr<StreamInfo> GetTextStreamInfo(int32_t timescale) const;
332 
333  std::unique_ptr<TextSample> GetTextSample(const std::string& id,
334  int64_t start,
335  int64_t end,
336  const std::string& payload) const;
337 
338  std::unique_ptr<CueEvent> GetCueEvent(double time_in_seconds) const;
339 
340  // Connect and initialize all handlers.
341  Status SetUpAndInitializeGraph(std::shared_ptr<MediaHandler> handler,
342  size_t input_count,
343  size_t output_count);
344 
345  // Get the input handler at |index|. The values of |index| will match the
346  // call to |AddInput|.
347  FakeInputMediaHandler* Input(size_t index);
348 
349  // Get the output handler at |index|. The values of |index| will match the
350  // call to |AddOutput|.
351  MockOutputMediaHandler* Output(size_t index);
352 
353  private:
355  MediaHandlerTestBase& operator=(const MediaHandlerTestBase&) = delete;
356 
357  std::shared_ptr<MediaHandler> handler_;
358 
359  std::vector<std::shared_ptr<FakeInputMediaHandler>> inputs_;
360  std::vector<std::shared_ptr<MockOutputMediaHandler>> outputs_;
361 };
362 
364  public:
366 
367  protected:
369  void SetUpGraph(size_t num_inputs,
370  size_t num_outputs,
371  std::shared_ptr<MediaHandler> handler);
372 
374  const std::vector<std::unique_ptr<StreamData>>& GetOutputStreamDataVector()
375  const;
376 
379 
381  std::shared_ptr<MediaHandler> some_handler() { return some_handler_; }
382 
384  std::shared_ptr<CachingMediaHandler> next_handler() { return next_handler_; }
385 
386  private:
389  delete;
390 
391  // Downstream handler used in testing graph.
392  std::shared_ptr<CachingMediaHandler> next_handler_;
393  // Some random handler which can be used for testing.
394  std::shared_ptr<MediaHandler> some_handler_;
395 };
396 
397 } // namespace media
398 } // namespace shaka
399 
400 #endif // PACKAGER_MEDIA_BASE_MEDIA_HANDLER_TEST_BASE_H_
std::shared_ptr< MediaHandler > some_handler()
void ClearOutputStreamDataVector()
Clear the output stream data vector.
const std::vector< std::unique_ptr< StreamData > > & GetOutputStreamDataVector() const
std::shared_ptr< CachingMediaHandler > next_handler()
void SetUpGraph(size_t num_inputs, size_t num_outputs, std::shared_ptr< MediaHandler > handler)
Setup a graph using |handler| with |num_inputs| and |num_outputs|.
Status FlushAllDownstreams()
Flush all connected downstream handlers.
Status FlushDownstream(size_t output_stream_index)
Flush the downstream connected at the specified output stream index.
Status Dispatch(std::unique_ptr< StreamData > stream_data) const
All the methods that are virtual are virtual for mocking.
Definition: crypto_flags.cc:66