Shaka Packager SDK
Loading...
Searching...
No Matches
udp_file.cc
1// Copyright 2014 Google LLC. All rights reserved.
2//
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file or at
5// https://developers.google.com/open-source/licenses/bsd
6
7#include <packager/file/udp_file.h>
8
9#if defined(OS_WIN)
10#include <ws2tcpip.h>
11#define close closesocket
12#define EINTR_CODE WSAEINTR
13#else
14#include <arpa/inet.h>
15#include <errno.h>
16#include <netinet/in.h>
17#include <string.h>
18#include <sys/socket.h>
19#include <sys/time.h>
20#include <unistd.h>
21#define INVALID_SOCKET -1
22#define EINTR_CODE EINTR
23// IP_MULTICAST_ALL has been supported since kernel version 2.6.31 but we may be
24// building on a machine that is older than that.
25#ifndef IP_MULTICAST_ALL
26#define IP_MULTICAST_ALL 49
27#endif
28#endif // defined(OS_WIN)
29
30#include <limits>
31
32#include <absl/log/check.h>
33#include <absl/log/log.h>
34
35#include <packager/file/udp_options.h>
36#include <packager/macros/classes.h>
37#include <packager/macros/compiler.h>
38#include <packager/macros/logging.h>
39
40namespace shaka {
41
namespace {

// Reports whether |addr| (stored in network byte order) falls inside the
// IPv4 multicast block 224.0.0.0/4, i.e. whether its top four bits are 1110.
bool IsIpv4MulticastAddress(const struct in_addr& addr) {
  const auto host_order = ntohl(addr.s_addr);
  return (host_order >> 28) == 0xe;
}

// Fetches the last socket-related error code: errno on POSIX platforms,
// WSAGetLastError() on Windows.
int GetSocketErrorCode() {
#if !defined(OS_WIN)
  return errno;
#else
  return WSAGetLastError();
#endif
}

}  // anonymous namespace
57
58UdpFile::UdpFile(const char* file_name)
59 : File(file_name), socket_(INVALID_SOCKET) {}
60
61UdpFile::~UdpFile() {}
62
63bool UdpFile::Close() {
64 if (socket_ != INVALID_SOCKET) {
65 close(socket_);
66 socket_ = INVALID_SOCKET;
67 }
68 delete this;
69#if defined(OS_WIN)
70 if (wsa_started_)
71 WSACleanup();
72#endif
73 return true;
74}
75
76int64_t UdpFile::Read(void* buffer, uint64_t length) {
77 DCHECK(buffer);
78 DCHECK_GE(length, 65535u)
79 << "Buffer may be too small to read entire datagram.";
80
81 if (socket_ == INVALID_SOCKET)
82 return -1;
83
84 int64_t result;
85 do {
86 result = recvfrom(socket_, reinterpret_cast<char*>(buffer),
87 static_cast<int>(length), 0, NULL, 0);
88 } while (result == -1 && GetSocketErrorCode() == EINTR_CODE);
89
90 return result;
91}
92
93int64_t UdpFile::Write(const void* buffer, uint64_t length) {
94 UNUSED(buffer);
95 UNUSED(length);
96 NOTIMPLEMENTED() << "UdpFile is unwritable!";
97 return -1;
98}
99
100void UdpFile::CloseForWriting() {
101#if defined(OS_WIN)
102 shutdown(socket_, SD_SEND);
103#else
104 shutdown(socket_, SHUT_WR);
105#endif
106}
107
108int64_t UdpFile::Size() {
109 if (socket_ == INVALID_SOCKET)
110 return -1;
111
112 return std::numeric_limits<int64_t>::max();
113}
114
115bool UdpFile::Flush() {
116 NOTIMPLEMENTED() << "UdpFile is unflushable!";
117 return false;
118}
119
120bool UdpFile::Seek(uint64_t position) {
121 UNUSED(position);
122 NOTIMPLEMENTED() << "UdpFile is unseekable!";
123 return false;
124}
125
126bool UdpFile::Tell(uint64_t* position) {
127 UNUSED(position);
128 NOTIMPLEMENTED() << "UdpFile is unseekable!";
129 return false;
130}
131
132class ScopedSocket {
133 public:
134 explicit ScopedSocket(SOCKET sock_fd) : sock_fd_(sock_fd) {}
135
136 ~ScopedSocket() {
137 if (sock_fd_ != INVALID_SOCKET)
138 close(sock_fd_);
139 }
140
141 SOCKET get() { return sock_fd_; }
142
143 SOCKET release() {
144 SOCKET socket = sock_fd_;
145 sock_fd_ = INVALID_SOCKET;
146 return socket;
147 }
148
149 private:
150 SOCKET sock_fd_;
151
152 DISALLOW_COPY_AND_ASSIGN(ScopedSocket);
153};
154
155bool UdpFile::Open() {
156#if defined(OS_WIN)
157 WSADATA wsa_data;
158 int wsa_error = WSAStartup(MAKEWORD(2, 2), &wsa_data);
159 if (wsa_error != 0) {
160 LOG(ERROR) << "Winsock start up failed with error " << wsa_error;
161 return false;
162 }
163 wsa_started_ = true;
164#endif // defined(OS_WIN)
165
166 DCHECK_EQ(INVALID_SOCKET, socket_);
167
168 std::unique_ptr<UdpOptions> options =
169 UdpOptions::ParseFromString(file_name());
170 if (!options)
171 return false;
172
173 ScopedSocket new_socket(socket(AF_INET, SOCK_DGRAM, 0));
174 if (new_socket.get() == INVALID_SOCKET) {
175 LOG(ERROR) << "Could not allocate socket, error = " << GetSocketErrorCode();
176 return false;
177 }
178
179 struct in_addr local_in_addr = {0};
180 if (inet_pton(AF_INET, options->address().c_str(), &local_in_addr) != 1) {
181 LOG(ERROR) << "Malformed IPv4 address " << options->address();
182 return false;
183 }
184
185 // TODO(kqyang): Support IPv6.
186 struct sockaddr_in local_sock_addr;
187 memset(&local_sock_addr, 0, sizeof(local_sock_addr));
188 local_sock_addr.sin_family = AF_INET;
189 local_sock_addr.sin_port = htons(options->port());
190
191 const bool is_multicast = IsIpv4MulticastAddress(local_in_addr);
192 if (is_multicast) {
193 local_sock_addr.sin_addr.s_addr = htonl(INADDR_ANY);
194 } else {
195 local_sock_addr.sin_addr = local_in_addr;
196 }
197
198 if (options->reuse()) {
199 const int optval = 1;
200 if (setsockopt(new_socket.get(), SOL_SOCKET, SO_REUSEADDR,
201 reinterpret_cast<const char*>(&optval),
202 sizeof(optval)) < 0) {
203 LOG(ERROR) << "Could not apply the SO_REUSEADDR property to the UDP "
204 "socket, error = "
205 << GetSocketErrorCode();
206 return false;
207 }
208 }
209
210 if (bind(new_socket.get(),
211 reinterpret_cast<struct sockaddr*>(&local_sock_addr),
212 sizeof(local_sock_addr)) < 0) {
213 LOG(ERROR) << "Could not bind UDP socket, error = " << GetSocketErrorCode();
214 return false;
215 }
216
217 if (is_multicast) {
218 if (options->is_source_specific_multicast()) {
219 struct ip_mreq_source source_multicast_group;
220
221 source_multicast_group.imr_multiaddr = local_in_addr;
222 if (inet_pton(AF_INET,
223 options->interface_address().c_str(),
224 &source_multicast_group.imr_interface) != 1) {
225 LOG(ERROR) << "Malformed IPv4 interface address "
226 << options->interface_address();
227 return false;
228 }
229 if (inet_pton(AF_INET,
230 options->source_address().c_str(),
231 &source_multicast_group.imr_sourceaddr) != 1) {
232 LOG(ERROR) << "Malformed IPv4 source specific multicast address "
233 << options->source_address();
234 return false;
235 }
236
237 if (setsockopt(new_socket.get(),
238 IPPROTO_IP,
239 IP_ADD_SOURCE_MEMBERSHIP,
240 reinterpret_cast<const char*>(&source_multicast_group),
241 sizeof(source_multicast_group)) < 0) {
242 LOG(ERROR) << "Failed to join multicast group, error = "
243 << GetSocketErrorCode();
244 return false;
245 }
246 } else {
247 // this is a v2 join without a specific source.
248 struct ip_mreq multicast_group;
249
250 multicast_group.imr_multiaddr = local_in_addr;
251
252 if (inet_pton(AF_INET, options->interface_address().c_str(),
253 &multicast_group.imr_interface) != 1) {
254 LOG(ERROR) << "Malformed IPv4 interface address "
255 << options->interface_address();
256 return false;
257 }
258
259 if (setsockopt(new_socket.get(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
260 reinterpret_cast<const char*>(&multicast_group),
261 sizeof(multicast_group)) < 0) {
262 LOG(ERROR) << "Failed to join multicast group, error = "
263 << GetSocketErrorCode();
264 return false;
265 }
266 }
267
268#if defined(__linux__)
269 // Disable IP_MULTICAST_ALL to avoid interference caused when two sockets
270 // are bound to the same port but joined to different multicast groups.
271 const int optval_zero = 0;
272 if (setsockopt(new_socket.get(), IPPROTO_IP, IP_MULTICAST_ALL,
273 reinterpret_cast<const char*>(&optval_zero),
274 sizeof(optval_zero)) < 0 &&
275 GetSocketErrorCode() != ENOPROTOOPT) {
276 LOG(ERROR) << "Failed to disable IP_MULTICAST_ALL option, error = "
277 << GetSocketErrorCode();
278 return false;
279 }
280#endif // #if defined(__linux__)
281 }
282
283 // Set timeout if needed.
284 if (options->timeout_us() != 0) {
285 struct timeval tv;
286 tv.tv_sec = options->timeout_us() / 1000000;
287 tv.tv_usec = options->timeout_us() % 1000000;
288 if (setsockopt(new_socket.get(), SOL_SOCKET, SO_RCVTIMEO,
289 reinterpret_cast<const char*>(&tv), sizeof(tv)) < 0) {
290 LOG(ERROR) << "Failed to set socket timeout, error = "
291 << GetSocketErrorCode();
292 return false;
293 }
294 }
295
296 if (options->buffer_size() > 0) {
297 const int receive_buffer_size = options->buffer_size();
298 if (setsockopt(new_socket.get(), SOL_SOCKET, SO_RCVBUF,
299 reinterpret_cast<const char*>(&receive_buffer_size),
300 sizeof(receive_buffer_size)) < 0) {
301 LOG(ERROR) << "Failed to set the maximum receive buffer size, error = "
302 << GetSocketErrorCode();
303 return false;
304 }
305 }
306
307 socket_ = new_socket.release();
308 return true;
309}
310
311} // namespace shaka
UdpFile(const char *address_and_port)
Definition udp_file.cc:58
static std::unique_ptr< UdpOptions > ParseFromString(std::string_view udp_url)
All the methods that are virtual are virtual for mocking.