Mercurial > code
changeset 258:4ad3c85ab73e
Sockets:
* set(), get() now take template to determine the size
* recv(), recvfrom() can take a template char array to determine the size
SocketListener:
* Additional preferred poll method now supported
* Support for both reading and writing polling
author | David Demelier <markand@malikania.fr> |
---|---|
date | Sun, 05 Oct 2014 11:00:16 +0200 |
parents | 60f71f245c5b |
children | 0b3fcc5ed8eb |
files | C++/Socket.cpp C++/Socket.h C++/SocketListener.cpp C++/SocketListener.h C++/Tests/Sockets/CMakeLists.txt C++/Tests/Sockets/main.cpp CMakeLists.txt |
diffstat | 7 files changed, 1313 insertions(+), 116 deletions(-) [+] |
line wrap: on
line diff
--- a/C++/Socket.cpp Fri Oct 03 16:26:26 2014 +0200 +++ b/C++/Socket.cpp Sun Oct 05 11:00:16 2014 +0200 @@ -22,19 +22,28 @@ #include "SocketAddress.h" /* -------------------------------------------------------- - * SocketError implementation + * Socket exceptions * -------------------------------------------------------- */ -SocketError::SocketError(std::string error) +namespace error { + +const char *Timeout::what() const noexcept { - m_error = std::move(error); + return "Timeout exception"; } -const char *SocketError::what() const noexcept +const char *InProgress::what() const noexcept { - return m_error.c_str(); + return "Operation in progress"; } +const char *WouldBlock::what() const noexcept +{ + return "Operation would block"; +} + +} // !error + /* -------------------------------------------------------- * Socket implementation * -------------------------------------------------------- */ @@ -47,13 +56,20 @@ #endif } +void Socket::finish() +{ +#if defined(_WIN32) + WSACleanup(); +#endif +} + /* -------------------------------------------------------- * System dependent code * -------------------------------------------------------- */ #if defined(_WIN32) -std::string Socket::syserror() +std::string Socket::syserror(int errn) { LPSTR str = nullptr; std::string errmsg = "Unknown error"; @@ -61,7 +77,7 @@ FormatMessageA( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM, NULL, - WSAGetLastError(), + errn, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&str, 0, NULL); @@ -78,17 +94,19 @@ #include <cerrno> -std::string Socket::syserror() +std::string Socket::syserror(int errn) { - return strerror(errno); + return strerror(errn); } #endif -void Socket::finish() +std::string Socket::syserror() { #if defined(_WIN32) - WSACleanup(); + return syserror(WSAGetLastError()); +#else + return syserror(errno); #endif } @@ -115,7 +133,7 @@ auto addrlen = addr.length(); if (::bind(s.handle(), (sockaddr *)&sa, addrlen) == SOCKET_ERROR) - throw SocketError(Socket::syserror()); + throw error::Failure(Socket::syserror()); } void Standard::connect(Socket &s, const SocketAddress &addr) @@ -123,13 +141,27 @@ auto &sa = addr.address(); auto addrlen = addr.length(); - if (::connect(s.handle(), (sockaddr *)&sa, addrlen) == SOCKET_ERROR) - throw SocketError(Socket::syserror()); + if (::connect(s.handle(), (sockaddr *)&sa, addrlen) == SOCKET_ERROR) { + /* + * Determine if the error comes from a non-blocking connect that cannot be + * accomplished yet. + */ +#if defined(_WIN32) + if (WSAGetLastError() == WSAEWOULDBLOCK) + throw error::InProgress(); + + throw error::Failure(Socket::syserror()); +#else + if (errno == EINPROGRESS) + throw error::InProgress(); + + throw error::Failure(Socket::syserror()); +#endif + } } Socket Standard::accept(Socket &s, SocketAddress &info) { - Socket c; Socket::Handle handle; // Store the information @@ -140,7 +172,7 @@ handle = ::accept(s.handle(), (sockaddr *)&address, &addrlen); if (handle == INVALID_SOCKET) - throw SocketError(Socket::syserror()); + throw error::Failure(Socket::syserror()); // Usually accept works only with SOCK_STREAM info = SocketAddress(address, addrlen); @@ -151,7 +183,7 @@ void Standard::listen(Socket &s, int max) { if (::listen(s.handle(), max) == SOCKET_ERROR) - throw SocketError(Socket::syserror()); + throw error::Failure(Socket::syserror()); } unsigned Standard::recv(Socket &s, void *data, unsigned dataLen) @@ -159,8 +191,19 @@ int nbread; nbread = ::recv(s.handle(), (Socket::Arg)data, dataLen, 0); - if (nbread == SOCKET_ERROR) - throw SocketError(Socket::syserror()); + if (nbread == SOCKET_ERROR) { +#if defined(_WIN32) + if (WSAGetLastError() == WSAEWOULDBLOCK) + throw error::WouldBlock(); + + throw error::Failure(Socket::syserror()); +#else + if (errno == EAGAIN || errno == EWOULDBLOCK) + throw error::WouldBlock(); + + throw error::Failure(Socket::syserror()); +#endif + } return (unsigned)nbread; } @@ -170,8 +213,19 @@ int nbsent; nbsent = ::send(s.handle(), (Socket::ConstArg)data, dataLen, 0); - if (nbsent == SOCKET_ERROR) - throw SocketError(Socket::syserror()); + if (nbsent == SOCKET_ERROR) { +#if defined(_WIN32) + if (WSAGetLastError() == WSAEWOULDBLOCK) + throw error::WouldBlock(); + + throw error::Failure(Socket::syserror()); +#else + if (errno == EAGAIN || errno == EWOULDBLOCK) + throw error::WouldBlock(); + + throw error::Failure(Socket::syserror()); +#endif + } return (unsigned)nbsent; } @@ -187,8 +241,19 @@ addrlen = sizeof (struct sockaddr_storage); nbread = ::recvfrom(s.handle(), (Socket::Arg)data, dataLen, 0, (sockaddr *)&address, &addrlen); - if (nbread == SOCKET_ERROR) - throw SocketError(Socket::syserror()); + if (nbread == SOCKET_ERROR) { +#if defined(_WIN32) + if (WSAGetLastError() == WSAEWOULDBLOCK) + throw error::WouldBlock(); + + throw error::Failure(Socket::syserror()); +#else + if (errno == EAGAIN || errno == EWOULDBLOCK) + throw error::WouldBlock(); + + throw error::Failure(Socket::syserror()); +#endif + } return (unsigned)nbread; } @@ -198,8 +263,19 @@ int nbsent; nbsent = ::sendto(s.handle(), (Socket::ConstArg)data, dataLen, 0, (const sockaddr *)&info.address(), info.length()); - if (nbsent == SOCKET_ERROR) - throw SocketError(Socket::syserror()); + if (nbsent == SOCKET_ERROR) { +#if defined(_WIN32) + if (WSAGetLastError() == WSAEWOULDBLOCK) + throw error::WouldBlock(); + + throw error::Failure(Socket::syserror()); +#else + if (errno == EAGAIN || errno == EWOULDBLOCK) + throw error::WouldBlock(); + + throw error::Failure(Socket::syserror()); +#endif + } return (unsigned)nbsent; } @@ -219,7 +295,7 @@ m_handle = socket(domain, type, protocol); if (m_handle == INVALID_SOCKET) - throw SocketError(syserror()); + throw error::Failure(syserror()); } Socket::Socket(Handle handle) @@ -233,12 +309,6 @@ return m_handle; } -void Socket::set(int level, int name, const void *arg, unsigned argLen) -{ - if (setsockopt(m_handle, level, name, (Socket::ConstArg)arg, argLen) == SOCKET_ERROR) - throw SocketError(syserror()); -} - void Socket::blockMode(bool block) { #if defined(O_NONBLOCK) && !defined(_WIN32) @@ -247,18 +317,18 @@ if ((flags = fcntl(m_handle, F_GETFL, 0)) == -1) flags = 0; - if (!block) + if (block) flags &= ~(O_NONBLOCK); else flags |= O_NONBLOCK; if (fcntl(m_handle, F_SETFL, flags) == -1) - throw SocketError(Socket::syserror()); + throw error::Failure(Socket::syserror()); #else unsigned long flags = (block) ? 0 : 1; if (ioctlsocket(m_handle, FIONBIO, &flags) == SOCKET_ERROR) - throw SocketError(Socket::syserror()); + throw error::Failure(Socket::syserror()); #endif }
--- a/C++/Socket.h Fri Oct 03 16:26:26 2014 +0200 +++ b/C++/Socket.h Sun Oct 05 11:00:16 2014 +0200 @@ -51,22 +51,65 @@ class Socket; class SocketAddress; +namespace error { + /** - * @class SocketError - * @brief socket error reporting + * @class Timeout + * @brief Describe a timeout expiration + * + * Usually thrown on timeout in SocketListener::select. + */ +class Timeout final : public std::exception { +public: + const char *what() const noexcept override; +}; + +/** + * @class InProgress + * @brief Operation cannot be accomplished now * - * This class is mainly used in all socket operations that may fail. + * Usually thrown in a non-blocking connect call. + */ +class InProgress final : public std::exception { +public: + const char *what() const noexcept override; +}; + +/** + * @class WouldBlock + * @brief The operation would block + * + * Usually thrown in a non-blocking connect send or receive. */ -class SocketError final : public std::exception { +class WouldBlock final : public std::exception { +public: + const char *what() const noexcept override; +}; + +/** + * @class Failure + * @brief General socket failure + * + * An operation failed. + */ +class Failure final : public std::exception { private: - std::string m_error; + std::string m_message; public: - SocketError(std::string error); + inline Failure(std::string message) + : m_message(std::move(message)) + { + } - const char *what() const noexcept override; + inline const char *what() const noexcept override + { + return m_message.c_str(); + } }; +} // !error + /** * @class SocketInterface * @brief Interface to implement @@ -96,7 +139,8 @@ * * @param s the socket * @param addr the address - * @throw SocketError on error + * @throw error::Failure on error + * @throw error::InProgress if the socket is marked non-blocking and connection cannot be established yet */ virtual void connect(Socket &s, const SocketAddress &address) = 0; @@ -106,7 +150,7 @@ * @param s the socket * @param info the optional client info * @return a client ready to use - * @throw SocketError on error + * @throw error::Failure on error */ virtual Socket accept(Socket &s, SocketAddress &info) = 0; @@ -115,7 +159,7 @@ * * @param s the socket * @param max the max number of clients - * @throw SocketError on error + * @throw error::Failure on error */ virtual void listen(Socket &s, int max) = 0; @@ -126,7 +170,8 @@ * @param data the destination pointer * @param dataLen max length to receive * @return the number of bytes received - * @throw SocketError on error + * @throw error::Failure on error + * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block */ virtual unsigned recv(Socket &s, void *data, unsigned len) = 0; @@ -139,7 +184,8 @@ * @param dataLen max length to receive * @param info the client info * @return the number of bytes received - * @throw SocketError on error + * @throw error::Failure on error + * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block */ virtual unsigned recvfrom(Socket &s, void *data, unsigned len, SocketAddress &info) = 0; @@ -150,7 +196,8 @@ * @param data the data to send * @param dataLen the data length * @return the number of bytes sent - * @throw SocketError on error + * @throw error::Failure on error + * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block */ virtual unsigned send(Socket &s, const void *data, unsigned len) = 0; @@ -162,7 +209,8 @@ * @param dataLen the data length * @param address the address * @return the number of bytes sent - * @throw SocketError on error + * @throw error::Failure on error + * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block */ virtual unsigned sendto(Socket &s, const void *data, unsigned len, const SocketAddress &info) = 0; }; @@ -193,7 +241,7 @@ protected: Iface m_interface; //!< the interface - Handle m_handle { 0 }; //!< the socket shared pointer + Handle m_handle; //!< the socket shared pointer public: /** @@ -202,13 +250,22 @@ static void init(); /** - * Get the last socket system error. + * Get the last socket system error. The error is set from errno or from + * WSAGetLastError on Windows. * * @return a string message */ static std::string syserror(); /** + * Get the last system error. + * + * @param errn the error number (errno or WSAGetLastError) + * @return the error + */ + static std::string syserror(int errn); + + /** * To be called before exiting. */ static void finish(); @@ -224,7 +281,7 @@ * @param domain the domain * @param type the type * @param protocol the protocol - * @throw SocketError on error + * @throw error::Failure on error */ Socket(int domain, int type, int protocol); @@ -253,15 +310,41 @@ * @param level the setting level * @param name the name * @param arg the value - * @param argLen the argument length - * @throw SocketError on error + * @throw error::Failure on error */ - void set(int level, int name, const void *arg, unsigned argLen); + template <typename Argument> + void set(int level, int name, const Argument &arg) + { + if (setsockopt(m_handle, level, name, (Socket::ConstArg)&arg, sizeof (arg)) == SOCKET_ERROR) + throw error::Failure(syserror()); + } + + /** + * Get an option for the socket. + * + * @param level the setting level + * @param name the name + * @param def the default value + * @note since getsockopt return the real size of the argument, it is recommended to set integer and such to 0 + * @throw error::Failure on error + */ + template <typename Argument> + Argument get(int level, int name, const Argument &def = Argument()) + { + Argument arg(def); + socklen_t size = sizeof (arg); + + if (getsockopt(m_handle, level, name, (Socket::Arg)&level, &size) == SOCKET_ERROR) + throw error::Failure(syserror()); + + return arg; + } /** * Enable or disable blocking mode. * * @param block the mode + * @throw error::Failure on error */ void blockMode(bool block = true); @@ -293,7 +376,7 @@ * Accept a client without getting its info. * * @return a client ready to use - * @throw SocketError on error + * @throw error::Failure on error */ Socket accept(); @@ -322,13 +405,26 @@ } /** + * Overload for char array. + * + * @param data the destination buffer + * @throw error::Failure on error + * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block + */ + template <size_t Size> + inline unsigned recv(char (&data)[Size]) + { + return recv(data, sizeof (data)); + } + + /** * Receive from a connection-less socket without getting * client information. * * @param data the destination pointer * @param dataLen max length to receive * @return the number of bytes received - * @throw SocketError on error + * @throw error::Failure on error */ unsigned recvfrom(void *data, unsigned dataLen); @@ -341,6 +437,32 @@ } /** + * Overload for char array. + * + * @param data the destination buffer + * @throw error::Failure on error + * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block + */ + template <size_t Size> + inline unsigned recvfrom(char (&data)[Size]) + { + return recvfrom(data, sizeof (data)); + } + + /** + * Overload for char array. + * + * @param data the destination buffer + * @throw error::Failure on error + * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block + */ + template <size_t Size> + inline unsigned recvfrom(char (&data)[Size], SocketAddress &info) + { + return recvfrom(data, sizeof (data), info); + } + + /** * @copydoc SocketInterface::send */ inline unsigned send(const void *data, unsigned dataLen)
--- a/C++/SocketListener.cpp Fri Oct 03 16:26:26 2014 +0200 +++ b/C++/SocketListener.cpp Sun Oct 05 11:00:16 2014 +0200 @@ -17,67 +17,256 @@ */ #include <algorithm> +#include <map> +#include <set> +#include <tuple> +#include <vector> #include "SocketListener.h" -const char *SocketTimeout::what() const noexcept +/* -------------------------------------------------------- + * Select implementation + * -------------------------------------------------------- */ + +/** + * @class SelectMethod + * @brief Implements select(2) + * + * This class is the fallback of any other method, it is not preferred at all for many reasons. + */ +class SelectMethod final : public SocketListener::Interface { +private: + std::set<Socket> m_rsockets; + std::set<Socket> m_wsockets; + std::map<Socket::Handle, std::tuple<Socket, SocketDirection>> m_lookup; + + fd_set m_readset; + fd_set m_writeset; + + Socket::Handle m_max { 0 }; + +public: + SelectMethod(); + void add(Socket &&s, SocketDirection direction) override; + void remove(const Socket &s, SocketDirection direction) override; + void list(const SocketListener::MapFunc &func) override; + void clear() override; + unsigned size() const override; + SocketStatus select(int ms) override; +}; + +SelectMethod::SelectMethod() { - return "Timeout occured"; -} - -SocketListener::SocketListener(int count) -{ - m_sockets.reserve(count); + FD_ZERO(&m_readset); + FD_ZERO(&m_writeset); } -void SocketListener::add(Socket s) +void SelectMethod::add(Socket &&s, SocketDirection direction) { - m_sockets.push_back(std::move(s)); + if (m_lookup.count(s.handle()) > 0) + std::get<1>(m_lookup[s.handle()]) |= direction; + else + m_lookup[s.handle()] = std::make_tuple(s, direction); + + if ((direction & SocketDirection::Read) == SocketDirection::Read) { + m_rsockets.insert(s); + FD_SET(s.handle(), &m_readset); + } + if ((direction & SocketDirection::Write) == SocketDirection::Write) { + m_wsockets.insert(s); + FD_SET(s.handle(), &m_writeset); + } + + if (s.handle() > m_max) + m_max = s.handle(); } -void SocketListener::remove(const Socket &s) +void SelectMethod::remove(const Socket &s, SocketDirection direction) { - m_sockets.erase(std::remove(m_sockets.begin(), m_sockets.end(), s), m_sockets.end()); + std::get<1>(m_lookup[s.handle()]) &= ~(direction); + + if (static_cast<int>(std::get<1>(m_lookup[s.handle()])) == 0) + m_lookup.erase(s.handle()); + + if ((direction & SocketDirection::Read) == SocketDirection::Read) { + m_rsockets.erase(s.handle()); + FD_CLR(s.handle(), &m_readset); + } + if ((direction & SocketDirection::Write) == SocketDirection::Write) { + m_wsockets.erase(s.handle()); + FD_CLR(s.handle(), &m_writeset); + } + + // Refind the max file descriptor + if (m_lookup.size() > 0) { + m_max = std::get<0>(std::max_element(m_lookup.begin(), m_lookup.end())->second).handle(); + } else + m_max = 0; } -void SocketListener::clear() +void SelectMethod::list(const SocketListener::MapFunc &func) { - m_sockets.clear(); + for (auto &s : m_lookup) + func(std::get<0>(s.second), std::get<1>(s.second)); } -unsigned SocketListener::size() +void SelectMethod::clear() { - return m_sockets.size(); + m_rsockets.clear(); + m_wsockets.clear(); + m_lookup.clear(); + + FD_ZERO(&m_readset); + FD_ZERO(&m_writeset); + + m_max = 0; +} + +unsigned SelectMethod::size() const +{ + return m_rsockets.size() + m_wsockets.size(); } -Socket &SocketListener::select(int s, int us) +SocketStatus SelectMethod::select(int ms) { - fd_set fds; timeval maxwait, *towait; - auto fdmax = m_sockets.front().handle(); + + maxwait.tv_sec = 0; + maxwait.tv_usec = ms * 1000; + + // Set to nullptr for infinite timeout. + towait = (ms <= 0) ? nullptr : &maxwait; + + auto error = ::select(m_max + 1, &m_readset, &m_writeset, NULL, towait); + if (error == SOCKET_ERROR) + throw error::Failure(Socket::syserror()); + if (error == 0) + throw error::Timeout(); + + for (auto &c : m_lookup) + if (FD_ISSET(c.first, &m_readset)) + return { std::get<0>(c.second), SocketDirection::Read }; + for (auto &c : m_lookup) + if (FD_ISSET(c.first, &m_writeset)) + return { std::get<0>(c.second), SocketDirection::Write }; + + throw error::Failure("No socket found"); +} + +/* -------------------------------------------------------- + * Poll implementation + * -------------------------------------------------------- */ - FD_ZERO(&fds); - for (auto &c : m_sockets) { - FD_SET(c.handle(), &fds); - if ((int)c.handle() > fdmax) - fdmax = c.handle(); +#if defined(_WIN32) +# include <Winsock2.h> +# define poll WSAPoll +#else +# include <poll.h> +#endif + +namespace { + +class PollMethod final : public SocketListener::Interface { +private: + std::vector<pollfd> m_fds; + std::map<Socket::Handle, Socket> m_lookup; + + inline short topoll(SocketDirection direction) + { + short result(0); + + if ((direction & SocketDirection::Read) == SocketDirection::Read) + result |= POLLIN; + if ((direction & SocketDirection::Write) == SocketDirection::Write) + result |= POLLOUT; + + return result; + } + + inline SocketDirection todirection(short event) + { + SocketDirection direction = static_cast<SocketDirection>(0); + + if (event & POLLIN) + direction |= SocketDirection::Read; + if (event & POLLOUT) + direction |= SocketDirection::Write; + + return direction; } - maxwait.tv_sec = s; - maxwait.tv_usec = us; +public: + void add(Socket &&s, SocketDirection direction) override; + void remove(const Socket &s, SocketDirection direction) override; + void list(const SocketListener::MapFunc &func) override; + void clear() override; + unsigned size() const override; + SocketStatus select(int ms) override; +}; + +void PollMethod::add(Socket &&s, SocketDirection direction) +{ + m_lookup[s.handle()] = s; + m_fds.push_back({ s.handle(), topoll(direction), 0 }); +} - // Set to NULL for infinite timeout. - towait = (s == 0 && us == 0) ? nullptr : &maxwait; +void PollMethod::remove(const Socket &s, SocketDirection direction) +{ + for (auto i = m_fds.begin(); i != m_fds.end();) { + if (i->fd == s) { + i->events &= ~(topoll(direction)); + + if (i->events == 0) { + m_lookup.erase(i->fd); + i = m_fds.erase(i); + } else { + ++i; + } + } else + ++i; + } +} + +void PollMethod::list(const SocketListener::MapFunc &func) +{ + for (auto &fd : m_fds) + func(m_lookup[fd.fd], todirection(fd.events)); +} - auto error = ::select(fdmax + 1, &fds, NULL, NULL, towait); - if (error == SOCKET_ERROR) - throw SocketError(Socket::syserror()); - if (error == 0) - throw SocketTimeout(); +void PollMethod::clear() +{ + m_fds.clear(); + m_lookup.clear(); +} + +unsigned PollMethod::size() const +{ + return static_cast<unsigned>(m_fds.size()); +} + +SocketStatus PollMethod::select(int ms) +{ + auto result = poll(m_fds.data(), m_fds.size(), ms); + if (result == 0) + throw error::Timeout(); + if (result < 0) + throw error::Failure(Socket::syserror()); - for (Socket &c : m_sockets) - if (FD_ISSET(c.handle(), &fds)) - return c; + for (auto &fd : m_fds) + if (fd.revents != 0) + return { m_lookup[fd.fd], todirection(fd.revents) }; +} + +} // !namespace - throw SocketError("No socket found"); +/* -------------------------------------------------------- + * Socket listener + * -------------------------------------------------------- */ + +SocketListener::SocketListener(SocketMethod method) +{ + if (method == SocketMethod::Poll) + m_interface = std::make_unique<PollMethod>(); + else + m_interface = std::make_unique<SelectMethod>(); }
--- a/C++/SocketListener.h Fri Oct 03 16:26:26 2014 +0200 +++ b/C++/SocketListener.h Sun Oct 05 11:00:16 2014 +0200 @@ -19,17 +19,85 @@ #ifndef _SOCKET_LISTENER_H_ #define _SOCKET_LISTENER_H_ -#include <vector> +#include <chrono> +#include <functional> #include "Socket.h" /** - * @class SocketTimeout - * @brief thrown when a timeout occured + * @enum SocketDirection + * @brief The SocketDirection enum + * + * Bitmask that can be set to both reading and writing. */ -class SocketTimeout final : public std::exception { -public: - const char *what() const noexcept override; +enum class SocketDirection { + Read = (1 << 0), //!< only for receive + Write = (1 << 1) //!< only for sending +}; + +inline SocketDirection operator&(SocketDirection x, SocketDirection y) +{ + return static_cast<SocketDirection>(static_cast<int>(x) & static_cast<int>(y)); +} + +inline SocketDirection operator|(SocketDirection x, SocketDirection y) +{ + return static_cast<SocketDirection>(static_cast<int>(x) | static_cast<int>(y)); +} + +inline SocketDirection operator^(SocketDirection x, SocketDirection y) +{ + return static_cast<SocketDirection>(static_cast<int>(x) ^ static_cast<int>(y)); +} + +inline SocketDirection operator~(SocketDirection x) +{ + return static_cast<SocketDirection>(~static_cast<int>(x)); +} + +inline SocketDirection &operator&=(SocketDirection &x, SocketDirection y) +{ + x = x & y; + + return x; +} + +inline SocketDirection &operator|=(SocketDirection &x, SocketDirection y) +{ + x = x | y; + + return x; +} + +inline SocketDirection &operator^=(SocketDirection &x, SocketDirection y) +{ + x = x ^ y; + + return x; +} + +/** + * @enum SocketMethod + * @brief The SocketMethod enum + * + * Select the method of polling. It is only a preferred method, for example if you + * request for poll but it is not available, select will be used. + */ +enum class SocketMethod { + Select, //!< select(2) method, fallback + Poll //!< poll(2), everywhere possible +}; + +/** + * @struct SocketStatus + * @brief The SocketStatus struct + * + * Result of a select call, returns the first ready socket found with its + * direction. + */ +struct SocketStatus { + Socket socket; //!< which socket is ready + SocketDirection direction; //!< the direction }; /** @@ -39,55 +107,164 @@ * Convenient wrapper around the select() system call. */ class SocketListener final { -private: - std::vector<Socket> m_sockets; +public: + /** + * @brief Function for listing all sockets + */ + using MapFunc = std::function<void (Socket &, SocketDirection)>; + + /** + * @class Interface + * @brief Implement the polling method + */ + class Interface { + public: + /** + * Default destructor. + */ + virtual ~Interface() = default; + + /** + * List all sockets in the interface. + * + * @param func the function + */ + virtual void list(const MapFunc &func) = 0; + + /** + * Add a socket with a specified direction. + * + * @param s the socket + * @param direction the direction + */ + virtual void add(Socket &&s, SocketDirection direction) = 0; + + /** + * Remove a socket with a specified direction. + * + * @param s the socket + * @param direction the direction + */ + virtual void remove(const Socket &s, SocketDirection direction) = 0; + + /** + * Remove all sockets. + */ + virtual void clear() = 0; + + /** + * Get the total number of sockets in the listener. + */ + virtual unsigned size() const = 0; + + /** + * Select a socket. + * + * @param ms the number of milliseconds to wait, -1 means forever + * @return the socket status + * @throw error::Failure on failure + * @throw error::Timeout on timeout + */ + virtual SocketStatus select(int ms) = 0; + }; + + std::unique_ptr<Interface> m_interface; + +#if defined(_WIN32) +# if _WIN32_WINNT >= 0x0600 + static constexpr const SocketMethod PreferredMethod = SocketMethod::Poll; +# else + static constexpr const SocketMethod PreferredMethod = SocketMethod::Select; +# endif +#else + static constexpr const SocketMethod PreferredMethod = SocketMethod::Poll; +#endif public: /** - * Create a socket listener with a specific number of sockets to reserve. + * Create a socket listener. * - * @param count the number of socket to reserve (default: 0) + * @param method the preferred method */ - SocketListener(int count = 0); + SocketListener(SocketMethod method = SocketMethod::Poll); /** * Add a socket to listen to. * * @param s the socket + * @param direction the direction */ - void add(Socket s); + inline void add(Socket s, SocketDirection direction) + { + m_interface->add(std::move(s), direction); + } /** * Remove a socket from the list. * * @param s the socket + * @param direction the direction */ - void remove(const Socket &s); + inline void remove(const Socket &s, SocketDirection direction) + { + m_interface->remove(s, direction); + } /** * Remove every sockets in the listener. */ - void clear(); + inline void clear() + { + m_interface->clear(); + } /** * Get the number of clients in listener. * - * @return the number of clients in the listener. + * @return the total number of sockets in the listener */ - unsigned size(); + inline unsigned size() const + { + return m_interface->size(); + } /** - * Wait for an event in the socket list. If both s and us are set to 0 then - * it waits indefinitely. + * Select a socket. Waits for a specific amount of time specified as the duration. * - * @param s the timeout in seconds - * @param us the timeout in milliseconds - * @see take * @return the socket ready * @throw SocketError on error * @throw SocketTimeout on timeout */ - Socket &select(int s = 0, int us = 0); + template <typename Rep, typename Ratio> + inline SocketStatus select(const std::chrono::duration<Rep, Ratio> &duration) + { + auto cvt = std::chrono::duration_cast<std::chrono::milliseconds>(duration); + + return m_interface->select(cvt.count()); + } + + /** + * Overload that waits indefinitely. + * + * @return the socket ready + * @throw SocketError on error + * @throw SocketTimeout on timeout + */ + inline SocketStatus select() + { + return m_interface->select(-1); + } + + /** + * List every socket in the listener. + * + * @param func the function to call + */ + template <typename Func> + inline void list(Func func) + { + m_interface->list(func); + } }; #endif // !_SOCKET_LISTENER_H_
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/C++/Tests/Sockets/CMakeLists.txt Sun Oct 05 11:00:16 2014 +0200 @@ -0,0 +1,30 @@ +# +# CMakeLists.txt -- tests for sockets +# +# Copyright (c) 2013, 2014 David Demelier <markand@malikania.fr> +# +# Permission to use, copy, modify, and/or distribute this software for any +# purpose with or without fee is hereby granted, provided that the above +# copyright notice and this permission notice appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +# + +set( + SOURCES + ${code_SOURCE_DIR}/C++/Socket.cpp + ${code_SOURCE_DIR}/C++/Socket.h + ${code_SOURCE_DIR}/C++/SocketAddress.cpp + ${code_SOURCE_DIR}/C++/SocketAddress.h + ${code_SOURCE_DIR}/C++/SocketListener.cpp + ${code_SOURCE_DIR}/C++/SocketListener.h + main.cpp +) + +define_test(socket "${SOURCES}")
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/C++/Tests/Sockets/main.cpp Sun Oct 05 11:00:16 2014 +0200 @@ -0,0 +1,605 @@ +#include <chrono> +#include <sstream> +#include <string> +#include <thread> + +#include <gtest/gtest.h> + +#include <Socket.h> +#include <SocketListener.h> +#include <SocketAddress.h> + +using namespace std::literals::chrono_literals; + +using namespace address; + +/* -------------------------------------------------------- + * Miscellaneous + * -------------------------------------------------------- */ + +TEST(Misc, set) +{ + try { + Socket s(AF_INET6, SOCK_STREAM, 0); + + s.set(IPPROTO_IPV6, IPV6_V6ONLY, false); + ASSERT_FALSE(s.get<bool>(IPPROTO_IPV6, IPV6_V6ONLY)); + + s.set(IPPROTO_IPV6, IPV6_V6ONLY, true); + ASSERT_TRUE(s.get<bool>(IPPROTO_IPV6, IPV6_V6ONLY)); + } catch (const std::exception &ex) { + } +} + +/* -------------------------------------------------------- + * Select tests + * -------------------------------------------------------- */ + +TEST(ListenerMethodSelect, pollAdd) +{ + try { + Socket s(AF_INET, SOCK_STREAM, 0); + Socket s2(AF_INET, SOCK_STREAM, 0); + SocketListener listener(SocketMethod::Select); + + listener.add(s, SocketDirection::Read); + listener.add(s2, SocketDirection::Read); + + ASSERT_EQ(2, listener.size()); + } catch (const std::exception &ex) { + } +} + +TEST(ListenerMethodSelect, pollRemove) +{ + try { + Socket s(AF_INET, SOCK_STREAM, 0); + Socket s2(AF_INET, SOCK_STREAM, 0); + SocketListener listener(SocketMethod::Select); + + listener.add(s, SocketDirection::Read); + listener.add(s2, SocketDirection::Read); + listener.remove(s, SocketDirection::Read); + listener.remove(s2, SocketDirection::Read); + + ASSERT_EQ(0, listener.size()); + } catch (const std::exception &ex) { + } +} + +/* + * Add two sockets for both reading and writing, them remove only reading and then + * move only writing. + */ +TEST(ListenerMethodSelect, pollInOut) +{ + try { + Socket s(AF_INET, SOCK_STREAM, 0); + Socket s2(AF_INET, SOCK_STREAM, 0); + SocketListener listener(SocketMethod::Select); + + listener.add(s, SocketDirection::Read | SocketDirection::Write); + listener.add(s2, SocketDirection::Read | SocketDirection::Write); + + listener.list([&] (Socket &si, SocketDirection dir) { + ASSERT_TRUE(si == s || si == s2); + ASSERT_EQ(0x03, static_cast<int>(dir)); + }); + + listener.remove(s, SocketDirection::Write); + listener.remove(s2, SocketDirection::Write); + + ASSERT_EQ(2, listener.size()); + + listener.list([&] (Socket &si, SocketDirection dir) { + ASSERT_TRUE(si == s || si == s2); + ASSERT_EQ(SocketDirection::Read, dir); + }); + + listener.remove(s, SocketDirection::Read); + listener.remove(s2, SocketDirection::Read); + + ASSERT_EQ(0, listener.size()); + } catch (const std::exception &ex) { + } +} + +/* -------------------------------------------------------- + * Poll tests + * -------------------------------------------------------- */ + +TEST(ListenerMethodPoll, pollAdd) +{ + try { + Socket s(AF_INET, SOCK_STREAM, 0); + Socket s2(AF_INET, SOCK_STREAM, 0); + SocketListener listener(SocketMethod::Poll); + + listener.add(s, SocketDirection::Read); + listener.add(s2, SocketDirection::Read); + + ASSERT_EQ(2, listener.size()); + } catch (const std::exception &ex) { + } +} + +TEST(ListenerMethodPoll, pollRemove) +{ + try { + Socket s(AF_INET, SOCK_STREAM, 0); + Socket s2(AF_INET, SOCK_STREAM, 0); + SocketListener listener(SocketMethod::Poll); + + listener.add(s, SocketDirection::Read); + listener.add(s2, SocketDirection::Read); + listener.remove(s, SocketDirection::Read); + listener.remove(s2, SocketDirection::Read); + + ASSERT_EQ(0, listener.size()); + } catch (const std::exception &ex) { + } +} + +/* + * Add two sockets for both reading and writing, them remove only reading and then + * move only writing. + */ +TEST(ListenerMethodPoll, pollInOut) +{ + try { + Socket s(AF_INET, SOCK_STREAM, 0); + Socket s2(AF_INET, SOCK_STREAM, 0); + SocketListener listener(SocketMethod::Poll); + + listener.add(s, SocketDirection::Read | SocketDirection::Write); + listener.add(s2, SocketDirection::Read | SocketDirection::Write); + + listener.list([&] (Socket &si, SocketDirection dir) { + ASSERT_TRUE(si == s || si == s2); + ASSERT_EQ(0x03, static_cast<int>(dir)); + }); + + listener.remove(s, SocketDirection::Write); + listener.remove(s2, SocketDirection::Write); + + ASSERT_EQ(2, listener.size()); + + listener.list([&] (Socket &si, SocketDirection dir) { + ASSERT_TRUE(si == s || si == s2); + ASSERT_EQ(SocketDirection::Read, dir); + }); + + listener.remove(s, SocketDirection::Read); + listener.remove(s2, SocketDirection::Read); + + ASSERT_EQ(0, listener.size()); + } catch (const std::exception &ex) { + } +} + +/* -------------------------------------------------------- + * Socket listener class + * -------------------------------------------------------- */ + +TEST(Listener, connection) +{ + std::thread client([] () { + Socket client; + + std::this_thread::sleep_for(3s); + + try { + client = Socket(AF_INET, SOCK_STREAM, 0); + client.connect(Internet("localhost", 10000, AF_INET)); + } catch (const std::exception &ex) { + } + + client.close(); + }); + + Socket s; + SocketListener listener; + + try { + s = Socket(AF_INET, SOCK_STREAM, 0); + + s.bind(Internet("localhost", 10000, AF_INET)); + s.listen(8); + + listener.add(s, SocketDirection::Read); + + auto client = listener.select(10s); + + ASSERT_TRUE(client.direction == SocketDirection::Read); + ASSERT_TRUE(client.socket == s); + } catch (const std::exception &ex) { + FAIL() << ex.what(); + } + + s.close(); + client.join(); +} + +TEST(Listener, connectionAndRead) +{ + std::thread thread([] () { + Socket client; + + std::this_thread::sleep_for(3s); + + try { + client = Socket(AF_INET, SOCK_STREAM, 0); + client.connect(Internet("localhost", 10000, AF_INET)); + client.send("hello world"); + } catch (const std::exception &ex) { + } + + client.close(); + }); + + Socket s; + SocketListener listener; + + try { + s = Socket(AF_INET, SOCK_STREAM, 0); + + s.bind(Internet("localhost", 10000, AF_INET)); + s.listen(8); + + // Read for master + listener.add(s, SocketDirection::Read); + + auto result = listener.select(10s); + + ASSERT_TRUE(result.direction == SocketDirection::Read); + ASSERT_TRUE(result.socket == s); + + // Wait for client + auto client = s.accept(); + listener.add(client, SocketDirection::Read); + + result = listener.select(10s); + + ASSERT_TRUE(result.direction == SocketDirection::Read); + ASSERT_TRUE(result.socket == client); + + char data[512]; + auto nb = client.recv(data, sizeof (data) - 1); + + data[nb] = '\0'; + + client.close(); + ASSERT_STREQ("hello world", data); + } catch (const std::exception &ex) { + FAIL() << ex.what(); + } + + s.close(); + thread.join(); +} + +TEST(Listener, bigData) +{ + auto producer = [] () { + std::string data; + + data.reserve(9000000); + for (int i = 0; i < 9000000; ++i) + data.push_back('a'); + + data.push_back('\r'); + data.push_back('\n'); + + std::this_thread::sleep_for(3s); + + Socket client; + SocketListener listener; + + try { + client = Socket(AF_INET, SOCK_STREAM, 0); + + client.connect(Internet("localhost", 10000, AF_INET)); + client.blockMode(false); + listener.add(client, SocketDirection::Write); + + while (data.size() > 0) { + auto s = listener.select(30s).socket; + auto nb = s.send(data.data(), data.size()); + data.erase(0, nb); + } + } catch (const std::exception &ex) { + FAIL() << ex.what(); + } + + client.close(); + }; + + auto consumer = [] () { + std::ostringstream out; + + Socket server; + SocketListener listener; + bool finished(false); + + try { + server = Socket(AF_INET, SOCK_STREAM, 0); + + server.bind(Internet("*", 10000, AF_INET)); + server.listen(10); + listener.add(server, SocketDirection::Read); + + while (!finished) { + auto s = listener.select(60s).socket; + + if (s == server) { + listener.add(s.accept(), SocketDirection::Read); + } else { + char data[512]; + auto nb = s.recv(data, sizeof (data) - 1); + + if (nb == 0) + finished = true; + else { + data[nb] = '\0'; + out << data; + } + } + } + + ASSERT_EQ(9000002, out.str().size()); + } catch (const std::exception &ex) { + FAIL() << ex.what(); + } + + server.close(); + }; + + std::thread tconsumer(consumer); + std::thread tproducer(producer); + + tconsumer.join(); + tproducer.join(); +} + +/* -------------------------------------------------------- + * Basic TCP tests + * -------------------------------------------------------- */ + +TEST(BasicTcp, sendipv4) { + std::thread client([] () { + Socket s; + + std::this_thread::sleep_for(3s); + try { + s = Socket(AF_INET, SOCK_STREAM, 0); + + s.connect(Internet("localhost", 10000, AF_INET)); + s.send("hello"); + } catch (const std::exception &ex) { + } + + s.close(); + }); + + Socket server; + + try { + server = Socket(AF_INET, SOCK_STREAM, 0); + + server.bind(Internet("*", 10000, SOCK_STREAM)); + server.listen(8); + + auto client = server.accept(); + + char data[512]; + auto nb = client.recv(data, sizeof (data) - 1); + + data[nb] = '\0'; + + ASSERT_STREQ("hello", data); + } catch (const std::exception &ex) { + } + + server.close(); + client.join(); +} + +TEST(BasicTcp, sendipv6) { + std::thread client([] () { + Socket s; + + std::this_thread::sleep_for(3s); + try { + s = Socket(AF_INET6, SOCK_STREAM, 0); + + s.connect(Internet("localhost", 10000, AF_INET6)); + s.send("hello"); + } catch (const std::exception &ex) { + } + + s.close(); + }); + + Socket server; + + try { + server = Socket(AF_INET6, SOCK_STREAM, 0); + + server.bind(Internet("*", 10000, SOCK_STREAM)); + server.listen(8); + + auto client = server.accept(); + + char data[512]; + auto nb = client.recv(data, sizeof (data) - 1); + + data[nb] = '\0'; + + ASSERT_STREQ("hello", data); + } catch (const std::exception &ex) { + } + + server.close(); + client.join(); +} + +#if !defined(_WIN32) + +TEST(BasicTcp, sendunix) { + std::thread client([] () { + Socket s; + + std::this_thread::sleep_for(3s); + try { + s = Socket(AF_UNIX, SOCK_STREAM, 0); + + s.connect(Unix("/tmp/gtest-send-tcp-unix.sock")); + s.send("hello"); + } catch (const std::exception &ex) { + } + + s.close(); + }); + + Socket server; + + try { + server = Socket(AF_UNIX, SOCK_STREAM, 0); + + server.bind(Unix("/tmp/gtest-send-tcp-unix.sock", true)); + server.listen(8); + + auto client = server.accept(); + + char data[512]; + auto nb = client.recv(data, sizeof (data) - 1); + + data[nb] = '\0'; + + ASSERT_STREQ("hello", data); + } catch (const std::exception &ex) { + } + + server.close(); + client.join(); +} + +#endif + +TEST(BasicUdp, sendipv4) { + std::thread client([] () { + Socket s; + + std::this_thread::sleep_for(3s); + try { + s = Socket(AF_INET, SOCK_DGRAM, 0); + + s.sendto("hello", Internet("localhost", 10000, AF_INET)); + } catch (const std::exception &ex) { + } + + s.close(); + }); + + Socket server; + + try { + server = Socket(AF_INET, SOCK_DGRAM, 0); + + server.bind(Internet("*", 10000, SOCK_DGRAM)); + + char data[512]; + auto nb = server.recvfrom(data, sizeof (data) - 1); + + data[nb] = '\0'; + + ASSERT_STREQ("hello", data); + } catch (const std::exception &ex) { + } + + server.close(); + client.join(); +} + +TEST(BasicUdp, sendipv6) { + std::thread client([] () { + Socket s; + + std::this_thread::sleep_for(3s); + try { + s = Socket(AF_INET6, SOCK_DGRAM, 0); + + s.sendto("hello", Internet("localhost", 10000, AF_INET6)); + } catch (const std::exception &ex) { + } + + s.close(); + }); + + Socket server; + + try { + server = Socket(AF_INET6, SOCK_DGRAM, 0); + + server.bind(Internet("*", 10000, SOCK_DGRAM)); + + char data[512]; + auto nb = server.recvfrom(data, sizeof (data) - 1); + + data[nb] = '\0'; + + ASSERT_STREQ("hello", data); + } catch (const std::exception &ex) { + } + + server.close(); + client.join(); +} + +#if !defined(_WIN32) + +TEST(BasicUdp, sendunix) { + std::thread client([] () { + Socket s; + + std::this_thread::sleep_for(3s); + try { + s = Socket(AF_UNIX, SOCK_DGRAM, 0); + + s.sendto("hello", Unix("/tmp/gtest-send-udp-unix.sock")); + } catch (const std::exception &ex) { + } + + s.close(); + }); + + Socket server; + + try { + server = Socket(AF_UNIX, SOCK_DGRAM, 0); + + server.bind(Unix("/tmp/gtest-send-udp-unix.sock", true)); + + char data[512]; + auto nb = server.recvfrom(data, sizeof (data) - 1); + + data[nb] = '\0'; + + ASSERT_STREQ("hello", data); + } catch (const std::exception &ex) { + } + + server.close(); + client.join(); +} + +#endif + +int main(int argc, char **argv) +{ + Socket::init(); + testing::InitGoogleTest(&argc, argv); + Socket::finish(); + + return RUN_ALL_TESTS(); +}
--- a/CMakeLists.txt Fri Oct 03 16:26:26 2014 +0200 +++ b/CMakeLists.txt Sun Oct 05 11:00:16 2014 +0200 @@ -55,7 +55,7 @@ option(WITH_OPTIONPARSER "Enable option parser tests" On) option(WITH_PACK "Enable pack functions" On) option(WITH_PARSER "Enable parser tests" On) -option(WITH_SOCKET "Enable sockets tests" On) +option(WITH_SOCKETS "Enable sockets tests" On) option(WITH_TREENODE "Enable treenode tests" On) option(WITH_UTF8 "Enable Utf8 functions tests" On) option(WITH_XMLPARSER "Enable XML tests" On) @@ -96,6 +96,10 @@ add_subdirectory(C++/Tests/Parser) endif () +if (WITH_SOCKETS) + add_subdirectory(C++/Tests/Sockets) +endif () + if (WITH_TREENODE) add_subdirectory(C++/Tests/TreeNode) endif ()