Mercurial > code
view C++/modules/Socket/Sockets.cpp @ 444:fc055d2a4a2c
Socket: fix kqueue update
author | David Demelier <markand@malikania.fr> |
---|---|
date | Fri, 23 Oct 2015 10:11:34 +0200 |
parents | 9c85d9158990 |
children | 828d3dc89f2d |
line wrap: on
line source
/* * Sockets.cpp -- portable C++ socket wrappers * * Copyright (c) 2013-2015 David Demelier <markand@malikania.fr> * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above * copyright notice and this permission notice appear in all copies. * * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ #include <algorithm> #include <atomic> #include <cstring> #include <mutex> #include "Sockets.h" namespace net { /* * Portable constants * ------------------------------------------------------------------ */ /* {{{ Constants */ #if defined(_WIN32) const Handle Invalid{INVALID_SOCKET}; const int Failure{SOCKET_ERROR}; #else const Handle Invalid{-1}; const int Failure{-1}; #endif /* }}} */ /* * Portable functions * ------------------------------------------------------------------ */ /* {{{ Functions */ #if defined(_WIN32) namespace { static std::mutex s_mutex; static std::atomic<bool> s_initialized{false}; } // !namespace #endif // !_WIN32 void init() noexcept { #if defined(_WIN32) std::lock_guard<std::mutex> lock(s_mutex); if (!s_initialized) { s_initialized = true; WSADATA wsa; WSAStartup(MAKEWORD(2, 2), &wsa); /* * If SOCKET_WSA_NO_INIT is not set then the user * must also call finish himself. */ #if !defined(SOCKET_NO_AUTO_INIT) atexit(finish); #endif } #endif } void finish() noexcept { #if defined(_WIN32) WSACleanup(); #endif } std::string error(int errn) { #if defined(_WIN32) LPSTR str = nullptr; std::string errmsg = "Unknown error"; FormatMessageA( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM, NULL, errn, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&str, 0, NULL); if (str) { errmsg = std::string(str); LocalFree(str); } return errmsg; #else return strerror(errn); #endif } std::string error() { #if defined(_WIN32) return error(WSAGetLastError()); #else return error(errno); #endif } /* }}} */ /* * Error class * ------------------------------------------------------------------ */ /* {{{ Error */ Error::Error(Code code, std::string function) : m_code{code} , m_function{std::move(function)} , m_error{error()} { } Error::Error(Code code, std::string function, int n) : m_code{code} , m_function{std::move(function)} , m_error{error(n)} { } Error::Error(Code code, std::string function, std::string error) : m_code{code} , m_function{std::move(function)} , m_error{std::move(error)} { } /* }}} */ /* * Predefine addressed to be used * ------------------------------------------------------------------ */ /* {{{ Addresses */ Ip::Ip(int domain) noexcept : m_domain{domain} { if (m_domain == AF_INET6) { std::memset(&m_sin6, 0, sizeof (sockaddr_in6)); } else { std::memset(&m_sin, 0, sizeof (sockaddr_in)); } } Ip::Ip(int domain, const std::string &host, int port) : m_domain{domain} { if (host == "*") { if (m_domain == AF_INET6) { std::memset(&m_sin6, 0, sizeof (sockaddr_in6)); m_length = sizeof (sockaddr_in6); m_sin6.sin6_addr = in6addr_any; m_sin6.sin6_family = AF_INET6; m_sin6.sin6_port = htons(port); } else { std::memset(&m_sin, 0, sizeof (sockaddr_in)); m_length = sizeof (sockaddr_in); m_sin.sin_addr.s_addr = INADDR_ANY; m_sin.sin_family = AF_INET; m_sin.sin_port = htons(port); } } else { addrinfo hints, *res; std::memset(&hints, 0, sizeof (addrinfo)); hints.ai_family = domain; auto error = getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &res); if (error != 0) { throw Error{Error::System, "getaddrinfo", gai_strerror(error)}; } if (m_domain == AF_INET6) { std::memcpy(&m_sin6, res->ai_addr, res->ai_addrlen); } else { std::memcpy(&m_sin, res->ai_addr, res->ai_addrlen); } m_length = res->ai_addrlen; freeaddrinfo(res); } } Ip::Ip(const struct sockaddr_storage *ss, socklen_t length) : m_length{length} { if (ss->ss_family == AF_INET6) { std::memcpy(&m_sin6, ss, length); } else if (ss->ss_family == AF_INET) { std::memcpy(&m_sin, ss, length); } else { throw std::invalid_argument{"invalid domain for Ip constructor"}; } } #if !defined(_WIN32) Local::Local(std::string path, bool rm) : m_path{std::move(path)} { /* Silently remove the file even if it fails */ if (rm) { ::remove(m_path.c_str()); } /* Copy the path */ std::memset(m_sun.sun_path, 0, sizeof (m_sun.sun_path)); std::strncpy(m_sun.sun_path, m_path.c_str(), sizeof (m_sun.sun_path) - 1); /* Set the parameters */ m_sun.sun_family = AF_UNIX; } Local::Local(const sockaddr_storage &ss, socklen_t length) { if (ss.ss_family == AF_UNIX) { std::memcpy(&m_sun, &ss, length); m_path = reinterpret_cast<const sockaddr_un &>(m_sun).sun_path; } else { throw std::invalid_argument{"invalid domain for local constructor"}; } } #endif // !_WIN32 /* }}} */ const int FlagRead{1 << 0}; const int FlagWrite{1 << 1}; std::vector<ListenerStatus> Select::wait(const ListenerTable &table, int ms) { timeval maxwait, *towait; fd_set readset; fd_set writeset; FD_ZERO(&readset); FD_ZERO(&writeset); Handle max = 0; for (const auto &pair : table) { if (pair.second & FlagRead) { FD_SET(pair.first, &readset); } if (pair.second & FlagWrite) { FD_SET(pair.first, &writeset); } if (pair.first > max) { max = pair.first; } } maxwait.tv_sec = 0; maxwait.tv_usec = ms * 1000; // Set to nullptr for infinite timeout. towait = (ms < 0) ? nullptr : &maxwait; auto error = ::select(max + 1, &readset, &writeset, nullptr, towait); if (error == Failure) { throw Error{Error::System, "select"}; } if (error == 0) { throw Error{Error::Timeout, "select", "Timeout while listening"}; } std::vector<ListenerStatus> sockets; for (const auto &pair : table) { if (FD_ISSET(pair.first, &readset)) { sockets.push_back(ListenerStatus{pair.first, FlagRead}); } if (FD_ISSET(pair.first, &writeset)) { sockets.push_back(ListenerStatus{pair.first, FlagWrite}); } } return sockets; } /* -------------------------------------------------------- * Poll implementation * -------------------------------------------------------- */ #if defined(SOCKET_HAVE_POLL) #if defined(_WIN32) # define poll WSAPoll #endif short Poll::topoll(int flags) const noexcept { short result(0); if (flags & FlagRead) { result |= POLLIN; } if (flags & FlagWrite) { result |= POLLOUT; } return result; } int Poll::toflags(short &event) const noexcept { int flags = 0; /* * Poll implementations mark the socket differently regarding * the disconnection of a socket. * * At least, even if POLLHUP or POLLIN is set, recv() always * return 0 so we mark the socket as readable. */ if ((event & POLLIN) || (event & POLLHUP)) { flags |= FlagRead; } if (event & POLLOUT) { flags |= FlagWrite; } /* Reset event for safety */ event = 0; return flags; } void Poll::set(const ListenerTable &, Handle h, int flags, bool add) { if (add) { m_fds.push_back(pollfd{h, topoll(flags), 0}); } else { auto it = std::find_if(m_fds.begin(), m_fds.end(), [&] (const struct pollfd &pfd) { return pfd.fd == h; }); it->events |= topoll(flags); } } void Poll::unset(const ListenerTable &, Handle h, int flags, bool remove) { auto it = std::find_if(m_fds.begin(), m_fds.end(), [&] (const struct pollfd &pfd) { return pfd.fd == h; }); if (remove) { m_fds.erase(it); } else { it->events &= ~(topoll(flags)); } } std::vector<ListenerStatus> Poll::wait(const ListenerTable &, int ms) { auto result = poll(m_fds.data(), m_fds.size(), ms); if (result == 0) { throw Error{Error::Timeout, "select", "Timeout while listening"}; } if (result < 0) { throw Error{Error::System, "poll"}; } std::vector<ListenerStatus> sockets; for (auto &fd : m_fds) { if (fd.revents != 0) { sockets.push_back(ListenerStatus{fd.fd, toflags(fd.revents)}); } } return sockets; } #endif // !SOCKET_HAVE_POLL /* -------------------------------------------------------- * Epoll implementation * -------------------------------------------------------- */ #if defined(SOCKET_HAVE_EPOLL) uint32_t Epoll::toepoll(int flags) const noexcept { uint32_t events = 0; if (flags & FlagRead) { events |= EPOLLIN; } if (flags & FlagWrite) { events |= EPOLLOUT; } return events; } int Epoll::toflags(uint32_t events) const noexcept { int flags = 0; if ((events & EPOLLIN) || (events & EPOLLHUP)) { flags |= FlagRead; } if (events & EPOLLOUT) { flags |= FlagWrite; } return flags; } void Epoll::update(Handle h, int op, int flags) { struct epoll_event ev; std::memset(&ev, 0, sizeof (struct epoll_event)); ev.events = flags; ev.data.fd = h; if (epoll_ctl(m_handle, op, h, &ev) < 0) { throw Error{Error::System, "epoll_ctl"}; } } Epoll::Epoll() : m_handle(epoll_create1(0)) { if (m_handle < 0) { throw Error{Error::System, "epoll_create"}; } } Epoll::~Epoll() { close(m_handle); } /* * Add a new epoll_event or just update it. */ void Epoll::set(const ListenerTable &, Handle h, int flags, bool add) { update(h, add ? EPOLL_CTL_ADD : EPOLL_CTL_MOD, toepoll(flags)); if (add) { m_events.resize(m_events.size() + 1); } } /* * Unset is a bit complicated case because SocketListener tells us which * flag to remove but to update epoll descriptor we need to pass * the effective flags that we want to be applied. * * So we put the same flags that are currently effective and remove the * requested one. */ void Epoll::unset(const ListenerTable &table, Handle sc, int flags, bool remove) { if (remove) { update(sc, EPOLL_CTL_DEL, 0); m_events.resize(m_events.size() - 1); } else { update(sc, EPOLL_CTL_MOD, table.at(sc) & ~(toepoll(flags))); } } std::vector<ListenerStatus> Epoll::wait(const ListenerTable &, int ms) { int ret = epoll_wait(m_handle, m_events.data(), m_events.size(), ms); std::vector<ListenerStatus> result; if (ret == 0) { throw Error{Error::Timeout, "epoll_wait"}; } if (ret < 0) { throw Error{Error::System, "epoll_wait"}; } for (int i = 0; i < ret; ++i) { result.push_back(ListenerStatus{m_events[i].data.fd, toflags(m_events[i].events)}); } return result; } #endif // !SOCKET_HAVE_EPOLL /* -------------------------------------------------------- * Kqueue implementation * -------------------------------------------------------- */ #if defined(SOCKET_HAVE_KQUEUE) Kqueue::Kqueue() : m_handle(kqueue()) { if (m_handle < 0) { throw Error{Error::System, "kqueue"}; } } Kqueue::~Kqueue() { close(m_handle); } void Kqueue::update(Handle h, int filter, int flags) { struct kevent ev; EV_SET(&ev, h, filter, flags, 0, 0, nullptr); if (kevent(m_handle, &ev, 1, nullptr, 0, nullptr) < 0) { throw Error{Error::System, "kevent"}; } } void Kqueue::set(const ListenerTable &, Handle h, int flags, bool add) { if (flags & FlagRead) { update(h, EVFILT_READ, EV_ADD | EV_ENABLE); } if (flags & FlagWrite) { update(h, EVFILT_WRITE, EV_ADD | EV_ENABLE); } if (add) { m_result.resize(m_result.size() + 1); } } void Kqueue::unset(const ListenerTable &, Handle h, int flags, bool remove) { if (flags & FlagRead) { update(h, EVFILT_READ, EV_DELETE); } if (flags & FlagWrite) { update(h, EVFILT_WRITE, EV_DELETE); } if (remove) { m_result.resize(m_result.size() - 1); } } std::vector<ListenerStatus> Kqueue::wait(const ListenerTable &, int ms) { std::vector<ListenerStatus> sockets; timespec ts = { 0, 0 }; timespec *pts = (ms <= 0) ? nullptr : &ts; ts.tv_sec = ms / 1000; ts.tv_nsec = (ms % 1000) * 1000000; int nevents = kevent(m_handle, nullptr, 0, &m_result[0], m_result.capacity(), pts); if (nevents == 0) { throw Error{Error::Timeout, "kevent"}; } if (nevents < 0) { throw Error{Error::System, "kevent"}; } for (int i = 0; i < nevents; ++i) { sockets.push_back(ListenerStatus{ static_cast<Handle>(m_result[i].ident), m_result[i].filter == EVFILT_READ ? FlagRead : FlagWrite }); } return sockets; } #endif // !SOCKET_HAVE_KQUEUE #if !defined(SOCKET_NO_SSL) namespace ssl { namespace { std::mutex mutex; std::atomic<bool> initialized{false}; } // !namespace void finish() noexcept { ERR_free_strings(); } void init() noexcept { std::lock_guard<std::mutex> lock{mutex}; if (!initialized) { initialized = true; SSL_library_init(); SSL_load_error_strings(); #if !defined(SOCKET_NO_AUTO_SSL_INIT) atexit(finish); #endif // SOCKET_NO_AUTO_SSL_INIT } } } // !ssl #endif // SOCKET_NO_SSL } // !net