Shaka Packager SDK
udp_file.cc
1 // Copyright 2014 Google LLC. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file or at
5 // https://developers.google.com/open-source/licenses/bsd
6 
7 #include <packager/file/udp_file.h>
8 
9 #if defined(OS_WIN)
10 #include <ws2tcpip.h>
11 #define close closesocket
12 #define EINTR_CODE WSAEINTR
13 #else
14 #include <arpa/inet.h>
15 #include <errno.h>
16 #include <netinet/in.h>
17 #include <string.h>
18 #include <sys/socket.h>
19 #include <sys/time.h>
20 #include <unistd.h>
21 #define INVALID_SOCKET -1
22 #define EINTR_CODE EINTR
23 // IP_MULTICAST_ALL has been supported since kernel version 2.6.31 but we may be
24 // building on a machine that is older than that.
25 #ifndef IP_MULTICAST_ALL
26 #define IP_MULTICAST_ALL 49
27 #endif
28 #endif // defined(OS_WIN)
29 
30 #include <limits>
31 
32 #include <absl/log/check.h>
33 #include <absl/log/log.h>
34 
35 #include <packager/file/udp_options.h>
36 #include <packager/macros/classes.h>
37 #include <packager/macros/compiler.h>
38 #include <packager/macros/logging.h>
39 
40 namespace shaka {
41 
namespace {

// Returns true when |addr| lies in 224.0.0.0/4, i.e. its four high-order bits
// (after converting from network to host byte order) are 1110.
bool IsIpv4MulticastAddress(const struct in_addr& addr) {
  const uint32_t host_order_addr = ntohl(addr.s_addr);
  return (host_order_addr >> 28) == 0xEu;
}

// Returns the error code of the most recent failed socket call: Winsock's
// per-thread error on Windows, errno elsewhere.
int GetSocketErrorCode() {
#if defined(OS_WIN)
  const int error_code = WSAGetLastError();
#else
  const int error_code = errno;
#endif
  return error_code;
}

}  // anonymous namespace
57 
58 UdpFile::UdpFile(const char* file_name)
59  : File(file_name), socket_(INVALID_SOCKET) {}
60 
61 UdpFile::~UdpFile() {}
62 
63 bool UdpFile::Close() {
64  if (socket_ != INVALID_SOCKET) {
65  close(socket_);
66  socket_ = INVALID_SOCKET;
67  }
68  delete this;
69 #if defined(OS_WIN)
70  if (wsa_started_)
71  WSACleanup();
72 #endif
73  return true;
74 }
75 
76 int64_t UdpFile::Read(void* buffer, uint64_t length) {
77  DCHECK(buffer);
78  DCHECK_GE(length, 65535u)
79  << "Buffer may be too small to read entire datagram.";
80 
81  if (socket_ == INVALID_SOCKET)
82  return -1;
83 
84  int64_t result;
85  do {
86  result = recvfrom(socket_, reinterpret_cast<char*>(buffer),
87  static_cast<int>(length), 0, NULL, 0);
88  } while (result == -1 && GetSocketErrorCode() == EINTR_CODE);
89 
90  return result;
91 }
92 
93 int64_t UdpFile::Write(const void* buffer, uint64_t length) {
94  UNUSED(buffer);
95  UNUSED(length);
96  NOTIMPLEMENTED() << "UdpFile is unwritable!";
97  return -1;
98 }
99 
100 void UdpFile::CloseForWriting() {
101 #if defined(OS_WIN)
102  shutdown(socket_, SD_SEND);
103 #else
104  shutdown(socket_, SHUT_WR);
105 #endif
106 }
107 
108 int64_t UdpFile::Size() {
109  if (socket_ == INVALID_SOCKET)
110  return -1;
111 
112  return std::numeric_limits<int64_t>::max();
113 }
114 
115 bool UdpFile::Flush() {
116  NOTIMPLEMENTED() << "UdpFile is unflushable!";
117  return false;
118 }
119 
120 bool UdpFile::Seek(uint64_t position) {
121  UNUSED(position);
122  NOTIMPLEMENTED() << "UdpFile is unseekable!";
123  return false;
124 }
125 
126 bool UdpFile::Tell(uint64_t* position) {
127  UNUSED(position);
128  NOTIMPLEMENTED() << "UdpFile is unseekable!";
129  return false;
130 }
131 
132 class ScopedSocket {
133  public:
134  explicit ScopedSocket(SOCKET sock_fd) : sock_fd_(sock_fd) {}
135 
136  ~ScopedSocket() {
137  if (sock_fd_ != INVALID_SOCKET)
138  close(sock_fd_);
139  }
140 
141  SOCKET get() { return sock_fd_; }
142 
143  SOCKET release() {
144  SOCKET socket = sock_fd_;
145  sock_fd_ = INVALID_SOCKET;
146  return socket;
147  }
148 
149  private:
150  SOCKET sock_fd_;
151 
152  DISALLOW_COPY_AND_ASSIGN(ScopedSocket);
153 };
154 
155 bool UdpFile::Open() {
156 #if defined(OS_WIN)
157  WSADATA wsa_data;
158  int wsa_error = WSAStartup(MAKEWORD(2, 2), &wsa_data);
159  if (wsa_error != 0) {
160  LOG(ERROR) << "Winsock start up failed with error " << wsa_error;
161  return false;
162  }
163  wsa_started_ = true;
164 #endif // defined(OS_WIN)
165 
166  DCHECK_EQ(INVALID_SOCKET, socket_);
167 
168  std::unique_ptr<UdpOptions> options =
169  UdpOptions::ParseFromString(file_name());
170  if (!options)
171  return false;
172 
173  ScopedSocket new_socket(socket(AF_INET, SOCK_DGRAM, 0));
174  if (new_socket.get() == INVALID_SOCKET) {
175  LOG(ERROR) << "Could not allocate socket, error = " << GetSocketErrorCode();
176  return false;
177  }
178 
179  struct in_addr local_in_addr = {0};
180  if (inet_pton(AF_INET, options->address().c_str(), &local_in_addr) != 1) {
181  LOG(ERROR) << "Malformed IPv4 address " << options->address();
182  return false;
183  }
184 
185  // TODO(kqyang): Support IPv6.
186  struct sockaddr_in local_sock_addr;
187  memset(&local_sock_addr, 0, sizeof(local_sock_addr));
188  local_sock_addr.sin_family = AF_INET;
189  local_sock_addr.sin_port = htons(options->port());
190 
191  const bool is_multicast = IsIpv4MulticastAddress(local_in_addr);
192  if (is_multicast) {
193  local_sock_addr.sin_addr.s_addr = htonl(INADDR_ANY);
194  } else {
195  local_sock_addr.sin_addr = local_in_addr;
196  }
197 
198  if (options->reuse()) {
199  const int optval = 1;
200  if (setsockopt(new_socket.get(), SOL_SOCKET, SO_REUSEADDR,
201  reinterpret_cast<const char*>(&optval),
202  sizeof(optval)) < 0) {
203  LOG(ERROR) << "Could not apply the SO_REUSEADDR property to the UDP "
204  "socket, error = "
205  << GetSocketErrorCode();
206  return false;
207  }
208  }
209 
210  if (bind(new_socket.get(),
211  reinterpret_cast<struct sockaddr*>(&local_sock_addr),
212  sizeof(local_sock_addr)) < 0) {
213  LOG(ERROR) << "Could not bind UDP socket, error = " << GetSocketErrorCode();
214  return false;
215  }
216 
217  if (is_multicast) {
218  if (options->is_source_specific_multicast()) {
219  struct ip_mreq_source source_multicast_group;
220 
221  source_multicast_group.imr_multiaddr = local_in_addr;
222  if (inet_pton(AF_INET,
223  options->interface_address().c_str(),
224  &source_multicast_group.imr_interface) != 1) {
225  LOG(ERROR) << "Malformed IPv4 interface address "
226  << options->interface_address();
227  return false;
228  }
229  if (inet_pton(AF_INET,
230  options->source_address().c_str(),
231  &source_multicast_group.imr_sourceaddr) != 1) {
232  LOG(ERROR) << "Malformed IPv4 source specific multicast address "
233  << options->source_address();
234  return false;
235  }
236 
237  if (setsockopt(new_socket.get(),
238  IPPROTO_IP,
239  IP_ADD_SOURCE_MEMBERSHIP,
240  reinterpret_cast<const char*>(&source_multicast_group),
241  sizeof(source_multicast_group)) < 0) {
242  LOG(ERROR) << "Failed to join multicast group, error = "
243  << GetSocketErrorCode();
244  return false;
245  }
246  } else {
247  // this is a v2 join without a specific source.
248  struct ip_mreq multicast_group;
249 
250  multicast_group.imr_multiaddr = local_in_addr;
251 
252  if (inet_pton(AF_INET, options->interface_address().c_str(),
253  &multicast_group.imr_interface) != 1) {
254  LOG(ERROR) << "Malformed IPv4 interface address "
255  << options->interface_address();
256  return false;
257  }
258 
259  if (setsockopt(new_socket.get(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
260  reinterpret_cast<const char*>(&multicast_group),
261  sizeof(multicast_group)) < 0) {
262  LOG(ERROR) << "Failed to join multicast group, error = "
263  << GetSocketErrorCode();
264  return false;
265  }
266  }
267 
268 #if defined(__linux__)
269  // Disable IP_MULTICAST_ALL to avoid interference caused when two sockets
270  // are bound to the same port but joined to different multicast groups.
271  const int optval_zero = 0;
272  if (setsockopt(new_socket.get(), IPPROTO_IP, IP_MULTICAST_ALL,
273  reinterpret_cast<const char*>(&optval_zero),
274  sizeof(optval_zero)) < 0 &&
275  GetSocketErrorCode() != ENOPROTOOPT) {
276  LOG(ERROR) << "Failed to disable IP_MULTICAST_ALL option, error = "
277  << GetSocketErrorCode();
278  return false;
279  }
280 #endif // #if defined(__linux__)
281  }
282 
283  // Set timeout if needed.
284  if (options->timeout_us() != 0) {
285  struct timeval tv;
286  tv.tv_sec = options->timeout_us() / 1000000;
287  tv.tv_usec = options->timeout_us() % 1000000;
288  if (setsockopt(new_socket.get(), SOL_SOCKET, SO_RCVTIMEO,
289  reinterpret_cast<const char*>(&tv), sizeof(tv)) < 0) {
290  LOG(ERROR) << "Failed to set socket timeout, error = "
291  << GetSocketErrorCode();
292  return false;
293  }
294  }
295 
296  if (options->buffer_size() > 0) {
297  const int receive_buffer_size = options->buffer_size();
298  if (setsockopt(new_socket.get(), SOL_SOCKET, SO_RCVBUF,
299  reinterpret_cast<const char*>(&receive_buffer_size),
300  sizeof(receive_buffer_size)) < 0) {
301  LOG(ERROR) << "Failed to set the maximum receive buffer size, error = "
302  << GetSocketErrorCode();
303  return false;
304  }
305  }
306 
307  socket_ = new_socket.release();
308  return true;
309 }
310 
311 } // namespace shaka
UdpFile(const char *address_and_port)
Definition: udp_file.cc:58
static std::unique_ptr< UdpOptions > ParseFromString(std::string_view udp_url)
Definition: udp_options.cc:84
All the methods that are virtual are virtual for mocking.
Definition: crypto_flags.cc:66