Coverage Report

Created: 2025-02-21 14:37

/root/bitcoin/src/util/sock.cpp
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) 2020-2022 The Bitcoin Core developers
2
// Distributed under the MIT software license, see the accompanying
3
// file COPYING or http://www.opensource.org/licenses/mit-license.php.
4
5
#include <common/system.h>
6
#include <compat/compat.h>
7
#include <logging.h>
8
#include <tinyformat.h>
9
#include <util/sock.h>
10
#include <util/syserror.h>
11
#include <util/threadinterrupt.h>
12
#include <util/time.h>
13
14
#include <memory>
15
#include <stdexcept>
16
#include <string>
17
18
#ifdef USE_POLL
19
#include <poll.h>
20
#endif
21
22
static inline bool IOErrorIsPermanent(int err)
23
0
{
24
0
    return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS;
25
0
}
26
27
0
Sock::Sock(SOCKET s) : m_socket(s) {}
28
29
Sock::Sock(Sock&& other)
30
0
{
31
0
    m_socket = other.m_socket;
32
0
    other.m_socket = INVALID_SOCKET;
33
0
}
34
35
0
/** Destructor: releases the wrapped socket by calling Close(). */
Sock::~Sock()
{
    Close();
}
36
37
/**
 * Move assignment: close the socket currently held by this object (if any)
 * and take over the one held by `other`. Afterwards `other` holds
 * INVALID_SOCKET.
 *
 * Self-assignment is a no-op. Without the identity check, Close() would
 * close our own socket and the object would end up holding INVALID_SOCKET.
 *
 * @return Reference to this object.
 */
Sock& Sock::operator=(Sock&& other)
{
    if (this != &other) {
        Close();
        m_socket = other.m_socket;
        other.m_socket = INVALID_SOCKET;
    }
    return *this;
}
44
45
ssize_t Sock::Send(const void* data, size_t len, int flags) const
46
0
{
47
0
    return send(m_socket, static_cast<const char*>(data), len, flags);
48
0
}
49
50
ssize_t Sock::Recv(void* buf, size_t len, int flags) const
51
0
{
52
0
    return recv(m_socket, static_cast<char*>(buf), len, flags);
53
0
}
54
55
int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const
56
0
{
57
0
    return connect(m_socket, addr, addr_len);
58
0
}
59
60
int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const
61
0
{
62
0
    return bind(m_socket, addr, addr_len);
63
0
}
64
65
int Sock::Listen(int backlog) const
66
0
{
67
0
    return listen(m_socket, backlog);
68
0
}
69
70
std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const
71
0
{
72
#ifdef WIN32
73
    static constexpr auto ERR = INVALID_SOCKET;
74
#else
75
0
    static constexpr auto ERR = SOCKET_ERROR;
76
0
#endif
77
78
0
    std::unique_ptr<Sock> sock;
79
80
0
    const auto socket = accept(m_socket, addr, addr_len);
81
0
    if (socket != ERR) {
82
0
        try {
83
0
            sock = std::make_unique<Sock>(socket);
84
0
        } catch (const std::exception&) {
85
#ifdef WIN32
86
            closesocket(socket);
87
#else
88
0
            close(socket);
89
0
#endif
90
0
        }
91
0
    }
92
93
0
    return sock;
94
0
}
95
96
int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
97
0
{
98
0
    return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);
99
0
}
100
101
int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const
102
0
{
103
0
    return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len);
104
0
}
105
106
int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const
107
0
{
108
0
    return getsockname(m_socket, name, name_len);
109
0
}
110
111
bool Sock::SetNonBlocking() const
112
0
{
113
#ifdef WIN32
114
    u_long on{1};
115
    if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) {
116
        return false;
117
    }
118
#else
119
0
    const int flags{fcntl(m_socket, F_GETFL, 0)};
120
0
    if (flags == SOCKET_ERROR) {
121
0
        return false;
122
0
    }
123
0
    if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) {
124
0
        return false;
125
0
    }
126
0
#endif
127
0
    return true;
128
0
}
129
130
bool Sock::IsSelectable() const
131
0
{
132
0
#if defined(USE_POLL) || defined(WIN32)
133
0
    return true;
134
#else
135
    return m_socket < FD_SETSIZE;
136
#endif
137
0
}
138
139
bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
140
0
{
141
    // We need a `shared_ptr` owning `this` for `WaitMany()`, but don't want
142
    // `this` to be destroyed when the `shared_ptr` goes out of scope at the
143
    // end of this function. Create it with a custom noop deleter.
144
0
    std::shared_ptr<const Sock> shared{this, [](const Sock*) {}};
145
146
0
    EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})};
147
148
0
    if (!WaitMany(timeout, events_per_sock)) {
149
0
        return false;
150
0
    }
151
152
0
    if (occurred != nullptr) {
153
0
        *occurred = events_per_sock.begin()->second.occurred;
154
0
    }
155
156
0
    return true;
157
0
}
158
159
bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const
160
0
{
161
0
#ifdef USE_POLL
162
0
    std::vector<pollfd> pfds;
163
0
    for (const auto& [sock, events] : events_per_sock) {
164
0
        pfds.emplace_back();
165
0
        auto& pfd = pfds.back();
166
0
        pfd.fd = sock->m_socket;
167
0
        if (events.requested & RECV) {
168
0
            pfd.events |= POLLIN;
169
0
        }
170
0
        if (events.requested & SEND) {
171
0
            pfd.events |= POLLOUT;
172
0
        }
173
0
    }
174
175
0
    if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) {
176
0
        return false;
177
0
    }
178
179
0
    assert(pfds.size() == events_per_sock.size());
180
0
    size_t i{0};
181
0
    for (auto& [sock, events] : events_per_sock) {
182
0
        assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd));
183
0
        events.occurred = 0;
184
0
        if (pfds[i].revents & POLLIN) {
185
0
            events.occurred |= RECV;
186
0
        }
187
0
        if (pfds[i].revents & POLLOUT) {
188
0
            events.occurred |= SEND;
189
0
        }
190
0
        if (pfds[i].revents & (POLLERR | POLLHUP)) {
191
0
            events.occurred |= ERR;
192
0
        }
193
0
        ++i;
194
0
    }
195
196
0
    return true;
197
#else
198
    fd_set recv;
199
    fd_set send;
200
    fd_set err;
201
    FD_ZERO(&recv);
202
    FD_ZERO(&send);
203
    FD_ZERO(&err);
204
    SOCKET socket_max{0};
205
206
    for (const auto& [sock, events] : events_per_sock) {
207
        if (!sock->IsSelectable()) {
208
            return false;
209
        }
210
        const auto& s = sock->m_socket;
211
        if (events.requested & RECV) {
212
            FD_SET(s, &recv);
213
        }
214
        if (events.requested & SEND) {
215
            FD_SET(s, &send);
216
        }
217
        FD_SET(s, &err);
218
        socket_max = std::max(socket_max, s);
219
    }
220
221
    timeval tv = MillisToTimeval(timeout);
222
223
    if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) {
224
        return false;
225
    }
226
227
    for (auto& [sock, events] : events_per_sock) {
228
        const auto& s = sock->m_socket;
229
        events.occurred = 0;
230
        if (FD_ISSET(s, &recv)) {
231
            events.occurred |= RECV;
232
        }
233
        if (FD_ISSET(s, &send)) {
234
            events.occurred |= SEND;
235
        }
236
        if (FD_ISSET(s, &err)) {
237
            events.occurred |= ERR;
238
        }
239
    }
240
241
    return true;
242
#endif /* USE_POLL */
243
0
}
244
245
void Sock::SendComplete(Span<const unsigned char> data,
246
                        std::chrono::milliseconds timeout,
247
                        CThreadInterrupt& interrupt) const
248
0
{
249
0
    const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
250
0
    size_t sent{0};
251
252
0
    for (;;) {
253
0
        const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)};
254
255
0
        if (ret > 0) {
256
0
            sent += static_cast<size_t>(ret);
257
0
            if (sent == data.size()) {
258
0
                break;
259
0
            }
260
0
        } else {
261
0
            const int err{WSAGetLastError()};
262
0
            if (IOErrorIsPermanent(err)) {
263
0
                throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err)));
264
0
            }
265
0
        }
266
267
0
        const auto now = GetTime<std::chrono::milliseconds>();
268
269
0
        if (now >= deadline) {
270
0
            throw std::runtime_error(strprintf(
271
0
                "Send timeout (sent only %u of %u bytes before that)", sent, data.size()));
272
0
        }
273
274
0
        if (interrupt) {
275
0
            throw std::runtime_error(strprintf(
276
0
                "Send interrupted (sent only %u of %u bytes before that)", sent, data.size()));
277
0
        }
278
279
        // Wait for a short while (or the socket to become ready for sending) before retrying
280
        // if nothing was sent.
281
0
        const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
282
0
        (void)Wait(wait_time, SEND);
283
0
    }
284
0
}
285
286
void Sock::SendComplete(Span<const char> data,
287
                        std::chrono::milliseconds timeout,
288
                        CThreadInterrupt& interrupt) const
289
0
{
290
0
    SendComplete(MakeUCharSpan(data), timeout, interrupt);
291
0
}
292
293
std::string Sock::RecvUntilTerminator(uint8_t terminator,
294
                                      std::chrono::milliseconds timeout,
295
                                      CThreadInterrupt& interrupt,
296
                                      size_t max_data) const
297
0
{
298
0
    const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
299
0
    std::string data;
300
0
    bool terminator_found{false};
301
302
    // We must not consume any bytes past the terminator from the socket.
303
    // One option is to read one byte at a time and check if we have read a terminator.
304
    // However that is very slow. Instead, we peek at what is in the socket and only read
305
    // as many bytes as possible without crossing the terminator.
306
    // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read
307
    // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte
308
    // at a time is about 50 times slower.
309
310
0
    for (;;) {
311
0
        if (data.size() >= max_data) {
312
0
            throw std::runtime_error(
313
0
                strprintf("Received too many bytes without a terminator (%u)", data.size()));
314
0
        }
315
316
0
        char buf[512];
317
318
0
        const ssize_t peek_ret{Recv(buf, std::min(sizeof(buf), max_data - data.size()), MSG_PEEK)};
319
320
0
        switch (peek_ret) {
321
0
        case -1: {
322
0
            const int err{WSAGetLastError()};
323
0
            if (IOErrorIsPermanent(err)) {
324
0
                throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err)));
325
0
            }
326
0
            break;
327
0
        }
328
0
        case 0:
329
0
            throw std::runtime_error("Connection unexpectedly closed by peer");
330
0
        default:
331
0
            auto end = buf + peek_ret;
332
0
            auto terminator_pos = std::find(buf, end, terminator);
333
0
            terminator_found = terminator_pos != end;
334
335
0
            const size_t try_len{terminator_found ? terminator_pos - buf + 1 :
336
0
                                                    static_cast<size_t>(peek_ret)};
337
338
0
            const ssize_t read_ret{Recv(buf, try_len, 0)};
339
340
0
            if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) {
341
0
                throw std::runtime_error(
342
0
                    strprintf("recv() returned %u bytes on attempt to read %u bytes but previous "
343
0
                              "peek claimed %u bytes are available",
344
0
                              read_ret, try_len, peek_ret));
345
0
            }
346
347
            // Don't include the terminator in the output.
348
0
            const size_t append_len{terminator_found ? try_len - 1 : try_len};
349
350
0
            data.append(buf, buf + append_len);
351
352
0
            if (terminator_found) {
353
0
                return data;
354
0
            }
355
0
        }
356
357
0
        const auto now = GetTime<std::chrono::milliseconds>();
358
359
0
        if (now >= deadline) {
360
0
            throw std::runtime_error(strprintf(
361
0
                "Receive timeout (received %u bytes without terminator before that)", data.size()));
362
0
        }
363
364
0
        if (interrupt) {
365
0
            throw std::runtime_error(strprintf(
366
0
                "Receive interrupted (received %u bytes without terminator before that)",
367
0
                data.size()));
368
0
        }
369
370
        // Wait for a short while (or the socket to become ready for reading) before retrying.
371
0
        const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
372
0
        (void)Wait(wait_time, RECV);
373
0
    }
374
0
}
375
376
bool Sock::IsConnected(std::string& errmsg) const
377
0
{
378
0
    if (m_socket == INVALID_SOCKET) {
379
0
        errmsg = "not connected";
380
0
        return false;
381
0
    }
382
383
0
    char c;
384
0
    switch (Recv(&c, sizeof(c), MSG_PEEK)) {
385
0
    case -1: {
386
0
        const int err = WSAGetLastError();
387
0
        if (IOErrorIsPermanent(err)) {
388
0
            errmsg = NetworkErrorString(err);
389
0
            return false;
390
0
        }
391
0
        return true;
392
0
    }
393
0
    case 0:
394
0
        errmsg = "closed";
395
0
        return false;
396
0
    default:
397
0
        return true;
398
0
    }
399
0
}
400
401
void Sock::Close()
402
0
{
403
0
    if (m_socket == INVALID_SOCKET) {
404
0
        return;
405
0
    }
406
#ifdef WIN32
407
    int ret = closesocket(m_socket);
408
#else
409
0
    int ret = close(m_socket);
410
0
#endif
411
0
    if (ret) {
412
0
        LogPrintf("Error closing socket %d: %s\n", m_socket, NetworkErrorString(WSAGetLastError()));
413
0
    }
414
0
    m_socket = INVALID_SOCKET;
415
0
}
416
417
/**
 * Compare the wrapped socket handle with `s`.
 *
 * @return true if `m_socket` equals `s`.
 */
bool Sock::operator==(SOCKET s) const
{
    return m_socket == s;
}
421
422
std::string NetworkErrorString(int err)
423
0
{
424
#if defined(WIN32)
425
    return Win32ErrorString(err);
426
#else
427
    // On BSD sockets implementations, NetworkErrorString is the same as SysErrorString.
428
0
    return SysErrorString(err);
429
0
#endif
430
0
}