Mercurial > code
changeset 315:c9356cb38c86
Split sockets into SocketTcp and SocketUdp
author | David Demelier <markand@malikania.fr> |
---|---|
date | Mon, 02 Mar 2015 14:00:48 +0100 |
parents | 4c3019385769 |
children | 4c0af1143fc4 |
files | C++/Socket.cpp C++/Socket.h C++/SocketAddress.cpp C++/SocketListener.cpp C++/SocketListener.h C++/SocketSsl.cpp C++/SocketSsl.h C++/SocketTcp.cpp C++/SocketTcp.h C++/SocketUdp.cpp C++/SocketUdp.h C++/Tests/Sockets/CMakeLists.txt C++/Tests/Sockets/main.cpp |
diffstat | 13 files changed, 1191 insertions(+), 1329 deletions(-) [+] |
line wrap: on
line diff
--- a/C++/Socket.cpp Wed Feb 25 13:53:41 2015 +0100 +++ b/C++/Socket.cpp Mon Mar 02 14:00:48 2015 +0100 @@ -20,112 +20,6 @@ #include "Socket.h" #include "SocketAddress.h" -#include "SocketListener.h" - -using namespace direction; - -/* -------------------------------------------------------- - * Socket exceptions - * -------------------------------------------------------- */ - -namespace error { - -/* -------------------------------------------------------- - * Error - * -------------------------------------------------------- */ - -Error::Error(std::string function, std::string error, int code) - : m_function(std::move(function)) - , m_error(std::move(error)) - , m_code(std::move(code)) -{ - m_shortcut = m_function + ": " + m_error; -} - -const std::string &Error::function() const noexcept -{ - return m_function; -} - -const std::string &Error::error() const noexcept -{ - return m_error; -} - -int Error::code() const noexcept -{ - return m_code; -} - -const char *Error::what() const noexcept -{ - return m_shortcut.c_str(); -} - -/* -------------------------------------------------------- - * InProgress - * -------------------------------------------------------- */ - -InProgress::InProgress(std::string func, std::string reason, int code, int direction) - : Error(std::move(func), std::move(reason), code) - , m_direction(direction) -{ -} - -int InProgress::direction() const noexcept -{ - return m_direction; -} - -/* -------------------------------------------------------- - * WouldBlock - * -------------------------------------------------------- */ - -WouldBlock::WouldBlock(std::string func, std::string reason, int code, int direction) - : Error(std::move(func), std::move(reason), code) - , m_direction(direction) -{ -} - -int WouldBlock::direction() const noexcept -{ - return m_direction; -} - -/* -------------------------------------------------------- - * Timeout - * -------------------------------------------------------- */ - -Timeout::Timeout(std::string func) - : m_error(func + ": Timeout exception") -{ -} - -const char *Timeout::what() const noexcept -{ - return m_error.c_str(); -} - -} // !error - -/* -------------------------------------------------------- - * Socket implementation - * -------------------------------------------------------- */ - -void Socket::init() -{ -#if defined(_WIN32) - WSADATA wsa; - WSAStartup(MAKEWORD(2, 2), &wsa); -#endif -} - -void Socket::finish() -{ -#if defined(_WIN32) - WSACleanup(); -#endif -} /* -------------------------------------------------------- * System dependent code @@ -175,47 +69,86 @@ } /* -------------------------------------------------------- - * SocketStandard clear implementation + * Socket class * -------------------------------------------------------- */ -void SocketStandard::bind(Socket &s, const SocketAddress &addr) +Socket::Socket(int domain, int type, int protocol) +{ +#if defined(_WIN32) && !defined(SOCKET_NO_WSA_INIT) + if (!s_initialized) + initialize(); +#endif + + m_handle = ::socket(domain, type, protocol); + + if (m_handle == Invalid) + throw SocketError(); +} + +void Socket::bind(const SocketAddress &address) { - auto &sa = addr.address(); - auto addrlen = addr.length(); + const auto &sa = address.address(); + const auto addrlen = address.length(); + + if (::bind(m_handle, reinterpret_cast<const sockaddr *>(&sa), addrlen) == Error) + throw SocketError("bind", Socket::syserror(), errno); +} - if (::bind(s.handle(), (sockaddr *)&sa, addrlen) == SOCKET_ERROR) - throw error::Error("bind", Socket::syserror(), errno); +void Socket::close() +{ +#if defined(_WIN32) + ::closesocket(m_handle); +#else + ::close(m_handle); +#endif } -void SocketStandard::connect(Socket &s, const SocketAddress &addr) +void Socket::setBlockMode(bool block) { - if (m_connected) - return; +#if defined(O_NONBLOCK) && !defined(_WIN32) + int flags; - auto &sa = addr.address(); - auto addrlen = addr.length(); + if ((flags = fcntl(m_handle, F_GETFL, 0)) == -1) + flags = 0; - if (::connect(s.handle(), (sockaddr *)&sa, addrlen) == SOCKET_ERROR) { - /* - * Determine if the error comes from a non-blocking connect that cannot be - * accomplished yet. - */ -#if defined(_WIN32) - if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::InProgress("connect", Socket::syserror(WSAEWOULDBLOCK), WSAEWOULDBLOCK, Write); + if (block) + flags &= ~(O_NONBLOCK); + else + flags |= O_NONBLOCK; - throw error::Error("connect", Socket::syserror(WSAEWOULDBLOCK), WSAGetLastError()); + if (fcntl(m_handle, F_SETFL, flags) == Error) + throw SocketError("setBlockMode", Socket::syserror(), errno); #else - if (errno == EINPROGRESS) - throw error::InProgress("connect", Socket::syserror(EINPROGRESS), EINPROGRESS, Write); + unsigned long flags = (block) ? 0 : 1; - throw error::Error("connect", Socket::syserror(), errno); + if (ioctlsocket(m_handle, FIONBIO, &flags) == Error) + throw SocketError("setBlockMode", Socket::syserror(), WSAGetLastError()); #endif - } - - m_connected = true; } +bool operator==(const Socket &s1, const Socket &s2) +{ + return s1.handle() == s2.handle(); +} + +bool operator<(const Socket &s1, const Socket &s2) +{ + return s1.handle() < s2.handle(); +} + + + + + + + + + + + + + +#if 0 void SocketStandard::tryConnect(Socket &s, const SocketAddress &address, int timeout) { if (m_connected) @@ -239,37 +172,6 @@ m_connected = true; } -Socket SocketStandard::accept(Socket &s, SocketAddress &info) -{ - Socket::Handle handle; - - // Store the information - sockaddr_storage address; - socklen_t addrlen; - - addrlen = sizeof (sockaddr_storage); - handle = ::accept(s.handle(), (sockaddr *)&address, &addrlen); - - if (handle == INVALID_SOCKET) { -#if defined(_WIN32) - if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::WouldBlock("accept", Socket::syserror(WSAEWOULDBLOCK), WSAEWOULDBLOCK, Read); - - throw error::Error("accept", Socket::syserror(), WSAGetLastError()); -#else - if (errno == EAGAIN || errno == EWOULDBLOCK) - throw error::WouldBlock("accept", Socket::syserror(EWOULDBLOCK), EWOULDBLOCK, Read); - - throw error::Error("accept", Socket::syserror(), errno); -#endif - } - - // Usually accept works only with SOCK_STREAM - info = SocketAddress(address, addrlen); - - return Socket(handle, std::make_shared<SocketStandard>()); -} - Socket SocketStandard::tryAccept(Socket &s, SocketAddress &info, int timeout) { SocketListener listener{{s, Read}}; @@ -279,34 +181,6 @@ return accept(s, info); } -void SocketStandard::listen(Socket &s, int max) -{ - if (::listen(s.handle(), max) == SOCKET_ERROR) - throw error::Error("listen", Socket::syserror(), errno); -} - -unsigned SocketStandard::recv(Socket &s, void *data, unsigned dataLen) -{ - int nbread; - - nbread = ::recv(s.handle(), (Socket::Arg)data, dataLen, 0); - if (nbread == SOCKET_ERROR) { -#if defined(_WIN32) - if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::WouldBlock("recv", Socket::syserror(), WSAEWOULDBLOCK, Read); - - throw error::Error("recv", Socket::syserror(), WSAGetLastError()); -#else - if (errno == EAGAIN || errno == EWOULDBLOCK) - throw error::WouldBlock("recv", Socket::syserror(), errno, Read); - - throw error::Error("recv", Socket::syserror(), errno); -#endif - } - - return (unsigned)nbread; -} - unsigned SocketStandard::tryRecv(Socket &s, void *data, unsigned len, int timeout) { SocketListener listener{{s, Read}}; @@ -316,28 +190,6 @@ return recv(s, data, len); } -unsigned SocketStandard::send(Socket &s, const void *data, unsigned dataLen) -{ - int nbsent; - - nbsent = ::send(s.handle(), (Socket::ConstArg)data, dataLen, 0); - if (nbsent == SOCKET_ERROR) { -#if defined(_WIN32) - if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::WouldBlock("send", Socket::syserror(), WSAEWOULDBLOCK, Write); - - throw error::Error("send", Socket::syserror(), WSAGetLastError()); -#else - if (errno == EAGAIN || errno == EWOULDBLOCK) - throw error::WouldBlock("send", Socket::syserror(), errno, Write); - - throw error::Error("send", Socket::syserror(), errno); -#endif - } - - return (unsigned)nbsent; -} - unsigned SocketStandard::trySend(Socket &s, const void *data, unsigned len, int timeout) { SocketListener listener{{s, Write}}; @@ -347,36 +199,6 @@ return send(s, data, len); } -unsigned SocketStandard::recvfrom(Socket &s, void *data, unsigned dataLen, SocketAddress &info) -{ - int nbread; - - // Store information - sockaddr_storage address; - socklen_t addrlen; - - addrlen = sizeof (struct sockaddr_storage); - nbread = ::recvfrom(s.handle(), (Socket::Arg)data, dataLen, 0, (sockaddr *)&address, &addrlen); - - info = SocketAddress(address, addrlen); - - if (nbread == SOCKET_ERROR) { -#if defined(_WIN32) - if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::WouldBlock("recvfrom", Socket::syserror(), WSAEWOULDBLOCK, Read); - - throw error::Error("recvfrom", Socket::syserror(), WSAGetLastError()); -#else - if (errno == EAGAIN || errno == EWOULDBLOCK) - throw error::WouldBlock("recvfrom", Socket::syserror(), errno, Read); - - throw error::Error("recvfrom", Socket::syserror(), errno); -#endif - } - - return (unsigned)nbread; -} - unsigned SocketStandard::tryRecvfrom(Socket &s, void *data, unsigned len, SocketAddress &info, int timeout) { SocketListener listener{{s, Read}}; @@ -386,28 +208,6 @@ return recvfrom(s, data, len, info); } -unsigned SocketStandard::sendto(Socket &s, const void *data, unsigned dataLen, const SocketAddress &info) -{ - int nbsent; - - nbsent = ::sendto(s.handle(), (Socket::ConstArg)data, dataLen, 0, (const sockaddr *)&info.address(), info.length()); - if (nbsent == SOCKET_ERROR) { -#if defined(_WIN32) - if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::WouldBlock("sendto", Socket::syserror(), WSAEWOULDBLOCK, Write); - - throw error::Error("sendto", Socket::syserror(), errno); -#else - if (errno == EAGAIN || errno == EWOULDBLOCK) - throw error::WouldBlock("sendto", Socket::syserror(), errno, Write); - - throw error::Error("sendto", Socket::syserror(), errno); -#endif - } - - return (unsigned)nbsent; -} - unsigned SocketStandard::trySendto(Socket &s, const void *data, unsigned len, const SocketAddress &info, int timeout) { SocketListener listener{{s, Write}}; @@ -417,42 +217,6 @@ return sendto(s, data, len, info); } -void SocketStandard::close(Socket &s) -{ - (void)closesocket(s.handle()); - - m_connected = false; -} - -/* -------------------------------------------------------- - * Socket code - * -------------------------------------------------------- */ - -Socket::Socket() - : m_interface(std::make_shared<SocketStandard>()) -{ -} - -Socket::Socket(int domain, int type, int protocol) - : Socket() -{ - m_handle = socket(domain, type, protocol); - - if (m_handle == INVALID_SOCKET) - throw error::Error("socket", syserror(), errno); -} - -Socket::Socket(Handle handle, std::shared_ptr<SocketInterface> iface) - : m_interface(std::move(iface)) - , m_handle(handle) -{ -} - -Socket::Handle Socket::handle() const -{ - return m_handle; -} - void Socket::blockMode(bool block) { #if defined(O_NONBLOCK) && !defined(_WIN32) @@ -476,20 +240,6 @@ #endif } -bool operator==(const Socket &s1, const Socket &s2) -{ - return s1.handle() == s2.handle(); -} - -bool operator<(const Socket &s1, const Socket &s2) -{ - return s1.handle() < s2.handle(); -} - -/* - * - */ - Socket Socket::tryAccept(int timeout) { SocketAddress dummy; @@ -517,3 +267,5 @@ return recvfrom(count, dummy); } + +#endif
--- a/C++/Socket.h Wed Feb 25 13:53:41 2015 +0100 +++ b/C++/Socket.h Mon Mar 02 14:00:48 2015 +0100 @@ -16,15 +16,39 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#ifndef _SOCKET_NG_H_ -#define _SOCKET_NG_H_ +#ifndef _SOCKET_NG_3_H_ +#define _SOCKET_NG_3_H_ + +/** + * @file Socket.h + * @brief Portable socket abstraction + * + * User may set the following variables before compiling these files: + * + * SOCKET_NO_WSA_INIT - (bool) Set to false if you don't want Socket class to + * automatically calls WSAStartup() when creating sockets. + * + * Otherwise, you will need to call Socket::init, + * Socket::finish yourself. + * + * SOCKET_NO_SSL_INIT - (bool) Set to false if you don't want OpenSSL to be + * initialized when the first SocketSsl object is created. + * + * SOCKET_HAVE_POLL - (bool) Set to true if poll(2) function is available. + * + * Note: on Windows, this is automatically set if the + * _WIN32_WINNT variable is greater or equal to 0x0600. + */ #include <cstring> #include <exception> -#include <memory> #include <string> #if defined(_WIN32) +# include <atomic> +# include <cstdlib> +# include <mutex> + # include <WinSock2.h> # include <WS2tcpip.h> #else @@ -41,406 +65,42 @@ # include <fcntl.h> # include <netdb.h> # include <unistd.h> - -# define ioctlsocket(s, p, a) ::ioctl(s, p, a) -# define closesocket(s) ::close(s) - -# define gai_strerrorA gai_strerror - -# define INVALID_SOCKET -1 -# define SOCKET_ERROR -1 #endif -class Socket; class SocketAddress; -/** - * Namespace for listener flags. - */ -namespace direction { - -constexpr const int Read = (1 << 0); //!< Wants to read -constexpr const int Write = (1 << 1); //!< Wants to write - -} // !direction - -/** - * Various errors. - */ -namespace error { - -/** - * Base error class, contains - */ -class Error : public std::exception { -private: - std::string m_function; - std::string m_error; - int m_code; - std::string m_shortcut; - -public: - /** - * Construct a full error. - * - * @param function which function - * @param error the error - * @param code the native code - */ - Error(std::string function, std::string error, int code); - - /** - * Get the function which thrown an exception. - * - * @return the function name - */ - const std::string &function() const noexcept; - - /** - * Get the error string. - * - * @return the error - */ - const std::string &error() const noexcept; - - /** - * Get the native code. Use with care because it varies from the system, - * the type of socket and such. - * - * @return the code - */ - int code() const noexcept; - - /** - * Get a brief error message - * - * @return a message - */ - const char *what() const noexcept override; -}; - -/** - * @class InProgress - * @brief Operation cannot be accomplished now - * - * Usually thrown in a non-blocking connect call. - */ -class InProgress final : public Error { -private: - int m_direction; - +class SocketError : public std::exception { public: - /** - * Operation is in progress, contains the direction needed to complete - * the operation. - * - * The direction may be different from the requested operation because - * some sockets (e.g SSL) requires reading even when writing! - * - * @param func the function - * @param reason the reason - * @param code the native code - * @param direction the direction - */ - InProgress(std::string func, std::string reason, int code, int direction); - - /** - * Get the required direction for listening operation requires. - * - * @return the direction required - */ - int direction() const noexcept; -}; - -/** - * @class WouldBlock - * @brief The operation would block - * - * Usually thrown in a non-blocking connect send or receive. - */ -class WouldBlock final : public Error { -private: - int m_direction; - -public: - /** - * Operation would block, contains the direction needed to complete the - * operation. - * - * The direction may be different from the requested operation because - * some sockets (e.g SSL) requires reading even when writing! - * - * @param func the function - * @param reason the reason - * @param code the native code - * @param direction the direction - */ - WouldBlock(std::string func, std::string reason, int code, int direction); - - /** - * Get the required direction for listening operation requires. - * - * @return the direction required - */ - int direction() const noexcept; -}; - -/** - * @class Timeout - * @brief Describe a timeout expiration - * - * Usually thrown on timeout in SocketListener::select. - */ -class Timeout final : public std::exception { -private: - std::string m_error; - -public: - /** - * Timeout exception. - * - * @param func the function name - */ - Timeout(std::string func); - - /** - * The error message. - */ - const char *what() const noexcept override; -}; - -} // !error - -/** - * @class SocketInterface - * @brief Interface to implement - * - * This class implements the socket functions. - */ -class SocketInterface { -public: - /** - * Default destructor. - */ - virtual ~SocketInterface() = default; - - /** - * Bind the socket. - * - * @param s the socket - * @param address the address - * @throw error::Failure on error - */ - virtual void bind(Socket &s, const SocketAddress &address) = 0; - - /** - * Close the socket. - * - * @param s the socket - */ - virtual void close(Socket &s) = 0; + enum { + WouldBlock, ///!< The operation requested would block + InProgress, ///!< The operation is in progress + Timeout, ///!< The action did timeout + System ///!< There is a system error + }; - /** - * Try to connect to the specific address - * - * @param s the socket - * @param address the address - * @throw error::Failure on error - * @throw error::InProgress if the socket is marked non-blocking and connection cannot be established yet - */ - virtual void connect(Socket &s, const SocketAddress &address) = 0; - - /** - * Try to connect without blocking. - * - * @param s the socket - * @param address the address - * @param timeout the timeout in milliseconds - * @warning This function won't block only if the socket is marked non blocking - */ - virtual void tryConnect(Socket &s, const SocketAddress &address, int timeout) = 0; - - /** - * Accept a client. - * - * @param s the socket - * @param info the client information - * @return a client ready to use - * @throw error::Failure on error - */ - virtual Socket accept(Socket &s, SocketAddress &info) = 0; - - /** - * Try to accept a client with a timeout. - * - * @param s the socket - * @param info the client information - * @param timeout the timeout in milliseconds - */ - virtual Socket tryAccept(Socket &s, SocketAddress &info, int timeout) = 0; - - /** - * Listen to a specific number of pending connections. - * - * @param s the socket - * @param max the max number of clients - * @throw error::Failure on error - */ - virtual void listen(Socket &s, int max) = 0; - - /** - * Receive some data. - * - * @param s the socket - * @param data the destination pointer - * @param len max length to receive - * @return the number of bytes received - * @throw error::Failure on error - * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block - */ - virtual unsigned recv(Socket &s, void *data, unsigned len) = 0; - - /** - * Try to receive data without blocking. - * - * @param s the socket - * @param data the destination pointer - * @param len max length to receive - * @param timeout the timeout in milliseconds - * @return the number of bytes received - * @throw error::Failure on error - */ - virtual unsigned tryRecv(Socket &s, void *data, unsigned len, int timeout) = 0; - - /** - * Receive from a connection-less socket. - * - * @param s the socket - * @param data the destination pointer - * @param dataLen max length to receive - * @param info the client information - * @return the number of bytes received - * @throw error::Failure on error - * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block - */ - virtual unsigned recvfrom(Socket &s, void *data, unsigned len, SocketAddress &info) = 0; + template <typename... Args> + inline SocketError(const Args&... args) + { - /** - * Try to receive data from a connection-less socket without blocking. - * - * @param s the socket - * @param data the data - * @param len max length to receive - * @param info the client information - * @param timeout the optional timeout in milliseconds - * @return the number of bytes received - * @throw error::Failure on error - */ - virtual unsigned tryRecvfrom(Socket &s, void *data, unsigned len, SocketAddress &info, int timeout) = 0; - - /** - * Send some data. - * - * @param s the socket - * @param data the data to send - * @param dataLen the data length - * @return the number of bytes sent - * @throw error::Failure on error - * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block - */ - virtual unsigned send(Socket &s, const void *data, unsigned len) = 0; - - /** - * Try to send some data without blocking. - * - * @param s the socket - * @param data the data to send - * @param dataLen the data length - * @param timeout the optional timeout in milliseconds - * @return the number of bytes sent - * @throw error::Failure on error - * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block - * @warning This function won't block only if the socket is marked non blocking - */ - virtual unsigned trySend(Socket &s, const void *data, unsigned len, int timeout) = 0; + } - /** - * Send some data to a connection-less socket. - * - * @param s the socket - * @param data the data to send - * @param dataLen the data length - * @param address the address - * @return the number of bytes sent - * @throw error::Failure on error - * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block - */ - virtual unsigned sendto(Socket &s, const void *data, unsigned len, const SocketAddress &info) = 0; - - /** - * Try to send some data to a connection-less without blocking. - * - * @param s the socket - * @param data the data to send - * @param dataLen the data length - * @param address the address - * @return the number of bytes sent - * @throw error::Failure on error - * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block - * @warning This function won't block only if the socket is marked non blocking - */ - virtual unsigned trySendto(Socket &s, const void *data, unsigned len, const SocketAddress &info, int timeout) = 0; -}; - -/** - * Standard interface for sockets. - * - * This will use standard clear functions: - * - * bind(2) - * connect(2) - * close(2), closesocket on Windows, - * accept(2), - * listen(2), - * recv(2), - * recvfrom(2), - * send(2), - * sendto(2) - */ -class SocketStandard : public SocketInterface { -protected: - bool m_connected{false}; - -public: - void bind(Socket &s, const SocketAddress &address) override; - void close(Socket &s) override; - void connect(Socket &s, const SocketAddress &address) override; - void tryConnect(Socket &s, const SocketAddress &address, int timeout) override; - Socket accept(Socket &s, SocketAddress &info) override; - Socket tryAccept(Socket &s, SocketAddress &info, int timeout) override; - void listen(Socket &s, int max) override; - unsigned recv(Socket &s, void *data, unsigned len) override; - unsigned tryRecv(Socket &s, void *data, unsigned len, int timeout) override; - unsigned recvfrom(Socket &s, void *data, unsigned len, SocketAddress &info) override; - unsigned tryRecvfrom(Socket &s, void *data, unsigned len, SocketAddress &info, int timeout) override; - unsigned send(Socket &s, const void *data, unsigned len) override; - unsigned trySend(Socket &s, const void *data, unsigned len, int timeout) override; - unsigned sendto(Socket &s, const void *data, unsigned len, const SocketAddress &info) override; - unsigned trySendto(Socket &s, const void *data, unsigned len, const SocketAddress &info, int timeout) override; + const char *what() const noexcept { + return "Error failure"; + } }; /** * @class Socket - * @brief socket abstraction - * - * This class is a big wrapper around sockets functions but portable, - * there is some functions that helps for getting error reporting. - * - * This class is implemented as a PIMPL idiom, it is perfectly - * safe to cast the object to any other derivate children. + * @brief Base socket class for socket operations */ class Socket { public: + /* {{{ Portable types */ + + /* + * The following types are defined differently between Unix + * and Windows. + */ #if defined(_WIN32) using Handle = SOCKET; using ConstArg = const char *; @@ -451,19 +111,86 @@ using Arg = void *; #endif - using Iface = std::shared_ptr<SocketInterface>; + /* }}} */ + + /* {{{ Portable constants */ + + /* + * The following constants are defined differently from Unix + * to Windows. + */ +#if defined(_WIN32) + static constexpr const int Invalid = INVALID_SOCKET; + static constexpr const int Error = SOCKET_ERROR; +#else + static constexpr const int Invalid = -1; + static constexpr const int Error = -1; +#endif + + /* }}} */ + + /* {{{ Portable initialization */ + + /* + * Initialization stuff. + * + * The function init and finish are threadsafe. + */ +#if defined(_WIN32) +private: + static std::mutex s_mutex{}; + static std::atomic<bool> s_initialized{false}; + +public: + static inline void finish() noexcept + { + WSACleanup(); + } + + static inline void init() noexcept + { + std::lock_guard<std::mutex> lock(s_mutex); + + if (!s_initialized) { + s_initialized = true; + + WSADATA wsa; + WSAStartup(MAKEWORD(2, 2), &wsa); + + /* + * If SOCKET_WSA_NO_INIT is not set then the user + * must also call finish himself. + */ +#if !defined(SOCKET_WSA_NO_INIT) + std::atexit(finish); +#endif + } + } +#else +public: + /** + * no-op. + */ + static inline void init() noexcept {} + + /** + * no-op. + */ + static inline void finish() noexcept {} +#endif + + /* }}} */ protected: - Iface m_interface; //!< the interface - Handle m_handle{INVALID_SOCKET}; //!< the socket shared pointer + Handle m_handle; + + inline Socket(Handle handle) + : m_handle(handle) + { + } public: /** - * To be called before any socket operation. - */ - static void init(); - - /** * Get the last socket system error. The error is set from errno or from * WSAGetLastError on Windows. * @@ -480,46 +207,20 @@ static std::string syserror(int errn); /** - * To be called before exiting. - */ - static void finish(); - - /** - * Default constructor. - */ - Socket(); - - /** - * Constructor to create a new socket. + * Create a socket handle. * - * @param domain the domain - * @param type the type + * @param domain the domain AF_* + * @param type the type SOCK_* * @param protocol the protocol - * @throw error::Failure on error */ Socket(int domain, int type, int protocol); /** - * Create a socket object with a already initialized socket. - * - * @param handle the handle - * @param interface the interface to use - */ - Socket(Handle handle, std::shared_ptr<SocketInterface> iface); - - /** - * Close the socket. + * Default destructor. */ virtual ~Socket() = default; /** - * Get the socket. - * - * @return the socket - */ - Handle handle() const; - - /** * Set an option for the socket. * * @param level the setting level @@ -528,10 +229,14 @@ * @throw error::Failure on error */ template <typename Argument> - void set(int level, int name, const Argument &arg) + inline void set(int level, int name, const Argument &arg) { +#if defined(_WIN32) if (setsockopt(m_handle, level, name, (Socket::ConstArg)&arg, sizeof (arg)) == SOCKET_ERROR) - throw error::Error("set", syserror(), errno); +#else + if (setsockopt(m_handle, level, name, (Socket::ConstArg)&arg, sizeof (arg)) < 0) +#endif + throw SocketError("set", syserror(), errno); } /** @@ -542,13 +247,17 @@ * @throw error::Failure on error */ template <typename Argument> - Argument get(int level, int name) + inline Argument get(int level, int name) { Argument desired, result{}; socklen_t size = sizeof (result); +#if defined(_WIN32) if (getsockopt(m_handle, level, name, (Socket::Arg)&desired, &size) == SOCKET_ERROR) - throw error::Error("get", syserror(), errno); +#else + if (getsockopt(m_handle, level, name, (Socket::Arg)&desired, &size) < 0) +#endif + throw SocketError("get", syserror(), errno); std::memcpy(&result, &desired, size); @@ -556,234 +265,41 @@ } /** - * Enable or disable blocking mode. - * - * @param block the mode - * @throw error::Failure on error - */ - void blockMode(bool block = true); - - /** - * @copydoc SocketInterface::bind - */ - inline void bind(const SocketAddress &address) - { - m_interface->bind(*this, address); - } - - /** - * @copydoc SocketInterface::close - */ - inline void close() - { - m_interface->close(*this); - } - - /** - * @copydoc SocketInterface::connect - */ - inline void connect(const SocketAddress &address) - { - m_interface->connect(*this, address); - } - - /** - * @copydoc SocketInterface::tryConnect - */ - inline void tryConnect(const SocketAddress &address, int timeout) - { - m_interface->tryConnect(*this, address, timeout); - } - - /** - * @copydoc SocketInterface::tryAccept - */ - inline Socket tryAccept(SocketAddress &address, int timeout) - { - return m_interface->tryAccept(*this, address, timeout); - } - - /** - * Overload without client information. + * Get the native handle. * - * @param timeout the timeout - * @return the socket - * @throw error::Error on errors - * @throw error::Timeout if operation has timeout - */ - Socket tryAccept(int timeout); - - /** - * Accept a client without getting its info. - * - * @return a client ready to use - * @throw error::Failure on error + * @return the handle + * @warning Not portable */ - Socket accept(); - - /** - * @copydoc SocketInterface::accept - */ - inline Socket accept(SocketAddress &info) - { - return m_interface->accept(*this, info); - } - - /** - * @copydoc SocketInterface::listen - */ - inline void listen(int max) + inline Handle handle() const noexcept { - m_interface->listen(*this, max); - } - - /** - * @copydoc SocketInterface::recv - */ - inline unsigned recv(void *data, unsigned dataLen) - { - return m_interface->recv(*this, data, dataLen); - } - - /** - * Overload for strings. - * - * @param count number of bytes to recv - * @return the string - * @throw error::Error on failures - * @throw error::WouldBlock if operation would block - */ - inline std::string recv(unsigned count) - { - std::string result; - - result.resize(count); - auto n = recv(const_cast<char *>(result.data()), count); - result.resize(n); - - return result; + return m_handle; } /** - * @copydoc SocketInterface::tryRecv - */ - inline unsigned tryRecv(void *data, unsigned dataLen, int timeout = 0) - { - return m_interface->tryRecv(*this, data, dataLen, timeout); - } - - /** - * Overload for string. + * Bind to an address. * - * @param count - * @param timeout - * @return the string - * @throw error::Error on failures - * @throw error::Timeout if operation has timeout + * @param address the address + * @throw SocketError on any error */ - inline std::string tryRecv(unsigned count, int timeout = 0) - { - std::string result; - - result.resize(count); - auto n = tryRecv(const_cast<char *>(result.data()), count, timeout); - result.resize(n); - - return result; - } - - /** - * Receive from a connection-less socket without getting - * client information. - * - * @param data the destination pointer - * @param dataLen max length to receive - * @return the number of bytes received - * @throw error::Failure on error - */ - unsigned recvfrom(void *data, unsigned dataLen); - - /** - * @copydoc SocketInterface::recvfrom - */ - inline unsigned recvfrom(void *data, unsigned dataLen, SocketAddress &info) - { - return m_interface->recvfrom(*this, data, dataLen, info); - } - - /** - * Overload for string. - * - * @param count the number of bytes to receive - * @return the string - * @throw error::Error on failures - * @throw error::WouldBlock if operation would block - */ - std::string recvfrom(unsigned count); + void bind(const SocketAddress &address); /** - * Overload with client information. + * Set the blocking mode, if set to false, the socket will be marked + * **non-blocking**. * - * @param count the number of bytes to receive - * @return the string - * @throw error::Error on failures - * @throw error::WouldBlock if operation would block + * @param block set to false to mark **non-blocking** + * @throw SocketError on any error */ - inline std::string recvfrom(unsigned count, SocketAddress &info) - { - std::string result; - - result.resize(count); - auto n = recvfrom(const_cast<char *>(result.data()), count, info); - result.resize(n); - - return result; - } - - /** - * @copydoc SocketInterface::send - */ - inline unsigned send(const void *data, unsigned dataLen) - { - return m_interface->send(*this, data, dataLen); - } + void setBlockMode(bool block); /** - * Send a message as a string. - * - * @param message the message - * @return the number of bytes sent - * @throw SocketError on error - */ - inline unsigned send(const std::string &message) - { - return send(message.data(), message.size()); - } - - /** - * @copydoc SocketInterface::sendto + * Close the socket. */ - inline unsigned sendto(const void *data, unsigned dataLen, const SocketAddress &info) - { - return m_interface->sendto(*this, data, dataLen, info); - } - - /** - * Send a message to a connection-less socket. - * - * @param message the message - * @param address the address - * @return the number of bytes sent - * @throw SocketError on error - */ - inline unsigned sendto(const std::string &message, const SocketAddress &info) - { - return sendto(message.data(), message.size(), info); - } + virtual void close(); }; bool operator==(const Socket &s1, const Socket &s2); bool operator<(const Socket &s, const Socket &s2); -#endif // !_SOCKET_NG_H_ +#endif // !_SOCKET_NG_3_H_
--- a/C++/SocketAddress.cpp Wed Feb 25 13:53:41 2015 +0100 +++ b/C++/SocketAddress.cpp Mon Mar 02 14:00:48 2015 +0100 @@ -56,7 +56,7 @@ auto error = getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &res); if (error != 0) - throw error::Error("getaddrinfo", gai_strerror(error), error); + throw SocketError("getaddrinfo", gai_strerror(error), error); std::memcpy(&m_addr, res->ai_addr, res->ai_addrlen); m_addrlen = res->ai_addrlen;
--- a/C++/SocketListener.cpp Wed Feb 25 13:53:41 2015 +0100 +++ b/C++/SocketListener.cpp Mon Mar 02 14:00:48 2015 +0100 @@ -24,131 +24,131 @@ #include "SocketListener.h" -using namespace direction; - /* -------------------------------------------------------- * Select implementation * -------------------------------------------------------- */ +namespace { + /** * @class SelectMethod * @brief Implements select(2) * * This class is the fallback of any other method, it is not preferred at all for many reasons. */ -class SelectMethod final : public SocketListener::Interface { +class SelectMethod final : public SocketListenerInterface { private: - std::map<Socket::Handle, std::pair<Socket, int>> m_table; + std::map<Socket::Handle, std::pair<std::reference_wrapper<Socket>, int>> m_table; public: - void add(Socket s, int direction) override; - void remove(const Socket &s, int direction) override; - void list(const SocketListener::MapFunc &func) override; - void clear() override; - unsigned size() const override; - SocketStatus select(int ms) override; - std::vector<SocketStatus> selectMultiple(int ms) override; + void set(Socket &s, int direction) override + { + if (m_table.count(s.handle()) > 0) { + m_table.at(s.handle()).second |= direction; + } else { + m_table.insert({s.handle(), {s, direction}}); + } + + } + + void unset(Socket &s, int direction) override + { + if (m_table.count(s.handle()) != 0) { + m_table.at(s.handle()).second &= ~(direction); + + // If no read, no write is requested, remove it + if (m_table.at(s.handle()).second == 0) { + m_table.erase(s.handle()); + } + } + } + + void remove(Socket &sc) override + { + m_table.erase(sc.handle()); + } + + void clear() override + { + m_table.clear(); + } + + SocketStatus select(int ms) override + { + auto result = selectMultiple(ms); + + if (result.size() == 0) { + throw SocketError("select", "No socket found", 0); + } + + return result[0]; + } + + std::vector<SocketStatus> selectMultiple(int ms) override + { + timeval maxwait, *towait; + fd_set readset; + fd_set writeset; + + FD_ZERO(&readset); + FD_ZERO(&writeset); + + Socket::Handle max = 0; + + for (auto &s : m_table) { + if (s.second.second & SocketListener::Read) { + FD_SET(s.first, &readset); + } + if (s.second.second & SocketListener::Write) { + FD_SET(s.first, &writeset); + } + + if (s.first > max) { + max = s.first; + } + } + + maxwait.tv_sec = 0; + maxwait.tv_usec = ms * 1000; + + // Set to nullptr for infinite timeout. + towait = (ms < 0) ? nullptr : &maxwait; + + auto error = ::select(max + 1, &readset, &writeset, nullptr, towait); + if (error == Socket::Error) { +#if defined(_WIN32) + throw SocketError("select", Socket::syserror(), WSAGetLastError()); +#else + throw SocketError("select", Socket::syserror(), errno); +#endif + } + + if (error == 0) { + throw SocketError("select"); + } + + std::vector<SocketStatus> sockets; + + for (auto &c : m_table) { + if (FD_ISSET(c.first, &readset)) { + sockets.push_back({ c.second.first, SocketListener::Read }); + } + if (FD_ISSET(c.first, &writeset)) { + sockets.push_back({ c.second.first, SocketListener::Write }); + } + } + + return sockets; + } }; -void SelectMethod::add(Socket s, int direction) -{ - if (m_table.count(s.handle()) > 0) - m_table[s.handle()].second |= direction; - else - m_table[s.handle()] = { s, direction }; -} - -void SelectMethod::remove(const Socket &s, int direction) -{ - if (m_table.count(s.handle()) != 0) { - m_table[s.handle()].second &= ~(direction); - - // If no read, no write is requested, remove it - if (m_table[s.handle()].second == 0) - m_table.erase(s.handle()); - } -} - -void SelectMethod::list(const SocketListener::MapFunc &func) -{ - for (auto &s : m_table) - func(s.second.first, s.second.second); -} - -void SelectMethod::clear() -{ - m_table.clear(); -} - -unsigned SelectMethod::size() const -{ - return m_table.size(); -} - -SocketStatus SelectMethod::select(int ms) -{ - auto result = selectMultiple(ms); - - if (result.size() == 0) - throw error::Error("select", "No socket found", 0); - - return result[0]; -} - -std::vector<SocketStatus> SelectMethod::selectMultiple(int ms) -{ - timeval maxwait, *towait; - fd_set readset; - fd_set writeset; - - FD_ZERO(&readset); - FD_ZERO(&writeset); - - Socket::Handle max = 0; - - for (auto &s : m_table) { - if (s.second.second & Read) - FD_SET(s.first, &readset); - if (s.second.second & Write) - FD_SET(s.first, &writeset); - - if (s.first > max) - max = s.first; - } - - maxwait.tv_sec = 0; - maxwait.tv_usec = ms * 1000; - - // Set to nullptr for infinite timeout. - towait = (ms < 0) ? nullptr : &maxwait; - - auto error = ::select(max + 1, &readset, &writeset, nullptr, towait); - if (error == SOCKET_ERROR) -#if defined(_WIN32) - throw error::Error("select", Socket::syserror(), WSAGetLastError()); -#else - throw error::Error("select", Socket::syserror(), errno); -#endif - if (error == 0) - throw error::Timeout("select"); - - std::vector<SocketStatus> sockets; - - for (auto &c : m_table) { - if (FD_ISSET(c.first, &readset)) - sockets.push_back({ c.second.first, Read }); - if (FD_ISSET(c.first, &writeset)) - sockets.push_back({ c.second.first, Write }); - } - - return sockets; -} +} // !namespace /* -------------------------------------------------------- * Poll implementation * -------------------------------------------------------- */ -#if defined(SOCKET_LISTENER_HAVE_POLL) +#if defined(SOCKET_HAVE_POLL) #if defined(_WIN32) # include <Winsock2.h> @@ -159,18 +159,18 @@ namespace { -class PollMethod final : public SocketListener::Interface { +class PollMethod final : public SocketListenerInterface { private: std::vector<pollfd> m_fds; - std::map<Socket::Handle, Socket> m_lookup; + std::map<Socket::Handle, std::reference_wrapper<Socket>> m_lookup; inline short topoll(int direction) { short result(0); - if (direction & Read) + if (direction & SocketListener::Read) result |= POLLIN; - if (direction & Write) + if (direction & SocketListener::Write) result |= POLLOUT; return result; @@ -188,121 +188,128 @@ * return 0 so we mark the socket as readable. */ if ((event & POLLIN) || (event & POLLHUP)) - direction |= Read; + direction |= SocketListener::Read; if (event & POLLOUT) - direction |= Write; + direction |= SocketListener::Write; return direction; } public: - void add(Socket s, int direction) override; - void remove(const Socket &s, int direction) override; - void list(const SocketListener::MapFunc &func) override; - void clear() override; - unsigned size() const override; - SocketStatus select(int ms) override; - std::vector<SocketStatus> selectMultiple(int ms) override; -}; + void set(Socket &s, int direction) override + { + auto it = std::find_if(m_fds.begin(), m_fds.end(), [&] (const auto &pfd) { return pfd.fd == s.handle(); }); -void PollMethod::add(Socket s, int direction) -{ - auto it = std::find_if(m_fds.begin(), m_fds.end(), [&] (const auto &pfd) { return pfd.fd == s.handle(); }); + // If found, add the new direction, otherwise add a new socket + if (it != m_fds.end()) + it->events |= topoll(direction); + else { + m_lookup.insert({s.handle(), s}); + m_fds.push_back({ s.handle(), topoll(direction), 0 }); + } + } - // If found, add the new direction, otherwise add a new socket - if (it != m_fds.end()) - it->events |= topoll(direction); - else { - m_lookup[s.handle()] = s; - m_fds.push_back({ s.handle(), topoll(direction), 0 }); - } -} + void unset(Socket &s, int direction) override + { + for (auto i = m_fds.begin(); i != m_fds.end();) { + if (i->fd == s.handle()) { + i->events &= ~(topoll(direction)); -void PollMethod::remove(const Socket &s, int direction) -{ - for (auto i = m_fds.begin(); i != m_fds.end();) { - if (i->fd == s.handle()) { - i->events &= ~(topoll(direction)); - - if (i->events == 0) { - m_lookup.erase(i->fd); - i = m_fds.erase(i); - } else { + if (i->events == 0) { + m_lookup.erase(i->fd); + i = m_fds.erase(i); + } else { + ++i; + } + } else ++i; - } - } else - ++i; + } } -} -void PollMethod::list(const SocketListener::MapFunc &func) -{ - for (auto &fd : m_fds) - func(m_lookup[fd.fd], todirection(fd.events)); -} + void remove(Socket &s) override + { + auto it = std::find_if(m_fds.begin(), m_fds.end(), [&] (const auto &pfd) { return pfd.fd == s.handle(); }); + + if (it != m_fds.end()) { + m_fds.erase(it); + m_lookup.erase(s.handle()); + } + } -void PollMethod::clear() -{ - m_fds.clear(); - m_lookup.clear(); -} - -unsigned PollMethod::size() const -{ - return static_cast<unsigned>(m_fds.size()); -} + void clear() override + { + m_fds.clear(); + m_lookup.clear(); + } -SocketStatus PollMethod::select(int ms) -{ - auto result = poll(m_fds.data(), m_fds.size(), ms); - if (result == 0) - throw error::Timeout("select"); - if (result < 0) + SocketStatus select(int ms) override + { + auto result = poll(m_fds.data(), m_fds.size(), ms); + if (result == 0) + throw SocketError("select"); + if (result < 0) #if defined(_WIN32) - throw error::Error("poll", Socket::syserror(), WSAGetLastError()); + throw SocketError("poll", Socket::syserror(), WSAGetLastError()); #else - throw error::Error("poll", Socket::syserror(), errno); + throw SocketError("poll", Socket::syserror(), errno); #endif - for (auto &fd : m_fds) - if (fd.revents != 0) - return { m_lookup[fd.fd], todirection(fd.revents) }; + for (auto &fd : m_fds) { + if (fd.revents != 0) { + return { m_lookup.at(fd.fd), todirection(fd.revents) }; + } + } - throw error::Error("select", "No socket found", 0); -} + throw SocketError("select", "No socket found", 0); + } -std::vector<SocketStatus> PollMethod::selectMultiple(int ms) -{ - auto result = poll(m_fds.data(), m_fds.size(), ms); - if (result == 0) - throw error::Timeout("select"); - if (result < 0) + std::vector<SocketStatus> selectMultiple(int ms) override + { + auto result = poll(m_fds.data(), m_fds.size(), ms); + if (result == 0) { + throw SocketError("select"); + } + if (result < 0) { #if defined(_WIN32) - throw error::Error("poll", Socket::syserror(), WSAGetLastError()); + throw SocketError("poll", Socket::syserror(), WSAGetLastError()); #else - throw error::Error("poll", Socket::syserror(), errno); + throw SocketError("poll", Socket::syserror(), errno); #endif + } - std::vector<SocketStatus> sockets; - for (auto &fd : m_fds) - if (fd.revents != 0) - sockets.push_back({ m_lookup[fd.fd], todirection(fd.revents) }); + std::vector<SocketStatus> sockets; + for (auto &fd : m_fds) { + if (fd.revents != 0) { + sockets.push_back({ m_lookup.at(fd.fd), todirection(fd.revents) }); + } + } - return sockets; -} + return sockets; + } +}; } // !namespace -#endif // !_SOCKET_LISTENER_HAVE_POLL +#endif // !_SOCKET_HAVE_POLL /* -------------------------------------------------------- - * Socket listener + * SocketListener * -------------------------------------------------------- */ -SocketListener::SocketListener(int method) +const int SocketListener::Read{1 << 0}; +const int SocketListener::Write{1 << 1}; + +SocketListener::SocketListener(std::initializer_list<std::pair<std::reference_wrapper<Socket>, int>> list) + : SocketListener() { -#if defined(SOCKET_LISTENER_HAVE_POLL) - if (method == Poll) + for (const auto &p : list) + set(p.first, p.second); +} + +SocketListener::SocketListener(SocketMethod method) +{ +#if defined(SOCKET_HAVE_POLL) + if (method == SocketMethod::Poll) m_interface = std::make_unique<PollMethod>(); else #endif @@ -311,9 +318,26 @@ (void)method; } -SocketListener::SocketListener(std::initializer_list<std::pair<Socket, int>> list, int method) - : SocketListener(method) +void SocketListener::set(Socket &sc, int flags) { - for (const auto &p : list) - add(p.first, p.second); + if (m_map.count(sc) > 0) { + m_map[sc] |= flags; + m_interface->set(sc, flags); + } else { + m_map.insert({sc, flags}); + m_interface->set(sc, flags); + } } + +void SocketListener::unset(Socket &sc, int flags) noexcept +{ + if (m_map.count(sc) > 0) { + m_map[sc] &= ~(flags); + m_interface->unset(sc, flags); + + // No more flags, remove it + if (m_map[sc] == 0) { + m_map.erase(sc); + } + } +}
--- a/C++/SocketListener.h Wed Feb 25 13:53:41 2015 +0100 +++ b/C++/SocketListener.h Mon Mar 02 14:00:48 2015 +0100 @@ -22,6 +22,8 @@ #include <chrono> #include <functional> #include <initializer_list> +#include <map> +#include <memory> #include <utility> #include <vector> @@ -29,10 +31,10 @@ #if defined(_WIN32) # if _WIN32_WINNT >= 0x0600 -# define SOCKET_LISTENER_HAVE_POLL +# define SOCKET_HAVE_POLL # endif #else -# define SOCKET_LISTENER_HAVE_POLL +# define SOCKET_HAVE_POLL #endif /** @@ -42,7 +44,7 @@ * Select the method of polling. It is only a preferred method, for example if you * request for poll but it is not available, select will be used. */ -enum SocketMethod { +enum class SocketMethod { Select, //!< select(2) method, fallback Poll //!< poll(2), everywhere possible }; @@ -54,9 +56,70 @@ * Result of a select call, returns the first ready socket found with its * direction. */ -struct SocketStatus { - Socket socket; //!< which socket is ready - int direction; //!< the direction +class SocketStatus { +public: + Socket &socket; //!< which socket is ready + int direction; //!< the direction +}; + +/** + * @class SocketListenerInterface + * @brief Implement the polling method + */ +class SocketListenerInterface { +public: + /** + * Default destructor. + */ + virtual ~SocketListenerInterface() = default; + + /** + * Add a socket with a specified direction. + * + * @param s the socket + * @param direction the direction + */ + virtual void set(Socket &sc, int direction) = 0; + + /** + * Remove a socket with a specified direction. + * + * @param s the socket + * @param direction the direction + */ + virtual void unset(Socket &sc, int direction) = 0; + + /** + * Remove completely a socket. + * + * @param sc the socket to remove + */ + virtual void remove(Socket &sc) = 0; + + /** + * Remove all sockets. + */ + virtual void clear() = 0; + + /** + * Select one socket. + * + * @param ms the number of milliseconds to wait, -1 means forever + * @return the socket status + * @throw error::Failure on failure + * @throw error::Timeout on timeout + */ + virtual SocketStatus select(int ms) = 0; + + /** + * Select many sockets. + * + * @param ms the number of milliseconds to wait, -1 means forever + * @return a vector of ready sockets + * @throw error::Failure on failure + * @throw error::Timeout on timeout + */ + virtual std::vector<SocketStatus> selectMultiple(int ms) = 0; }; /** @@ -64,86 +127,28 @@ * @brief Synchronous multiplexing * * Convenient wrapper around the select() system call. + * + * This wrappers takes abstract sockets as non-const reference but it does not + * own them so you must take care that sockets are still alive until the + * SocketListener is destroyed. */ class SocketListener final { public: - /** - * @brief Function for listing all sockets - */ - using MapFunc = std::function<void (Socket &, int)>; - -#if defined(SOCKET_LISTENER_HAVE_POLL) +#if defined(SOCKET_HAVE_POLL) static constexpr const SocketMethod PreferredMethod = SocketMethod::Poll; #else static constexpr const SocketMethod PreferredMethod = SocketMethod::Select; #endif - /** - * @class Interface - * @brief Implement the polling method - */ - class Interface { - public: - /** - * Default destructor. - */ - virtual ~Interface() = default; - - /** - * List all sockets in the interface. - * - * @param func the function - */ - virtual void list(const MapFunc &func) = 0; - - /** - * Add a socket with a specified direction. - * - * @param s the socket - * @param direction the direction - */ - virtual void add(Socket s, int direction) = 0; + static const int Read; + static const int Write; - /** - * Remove a socket with a specified direction. - * - * @param s the socket - * @param direction the direction - */ - virtual void remove(const Socket &s, int direction) = 0; - - /** - * Remove all sockets. - */ - virtual void clear() = 0; - - /** - * Get the total number of sockets in the listener. - */ - virtual unsigned size() const = 0; + using Map = std::map<std::reference_wrapper<Socket>, int>; + using Iface = std::unique_ptr<SocketListenerInterface>; - /** - * Select one socket. - * - * @param ms the number of milliseconds to wait, -1 means forever - * @return the socket status - * @throw error::Failure on failure - * @throw error::Timeout on timeout - */ - virtual SocketStatus select(int ms) = 0; - - /** - * Select many sockets. - * - * @param ms the number of milliseconds to wait, -1 means forever - * @return a vector of ready sockets - * @throw error::Failure on failure - * @throw error::Timeout on timeout - */ - virtual std::vector<SocketStatus> selectMultiple(int ms) = 0; - }; - - std::unique_ptr<Interface> m_interface; +private: + Map m_map; + Iface m_interface; public: /** @@ -166,54 +171,59 @@ * * @param method the preferred method */ - SocketListener(int method = PreferredMethod); + SocketListener(SocketMethod method = PreferredMethod); + + SocketListener(std::initializer_list<std::pair<std::reference_wrapper<Socket>, int>> list); - /** - * Createa listener with some sockets. - * - * @param list the initializer list - * @param method the preferred method - */ - SocketListener(std::initializer_list<std::pair<Socket, int>> list, int method = PreferredMethod); + inline auto begin() noexcept + { + return m_map.begin(); + } - /** - * Add a socket to listen to. - * - * @param s the socket - * @param direction the direction - */ - inline void add(Socket s, int direction) + inline auto begin() const noexcept + { + return m_map.begin(); + } + + inline auto cbegin() const noexcept { - m_interface->add(std::move(s), direction); + return m_map.cbegin(); + } + + inline auto end() noexcept + { + return m_map.end(); } - /** - * Remove a socket from the list. - * - * @param s the socket - * @param direction the direction - */ - inline void remove(const Socket &s, int direction) + inline auto end() const noexcept { - m_interface->remove(s, direction); + return m_map.end(); + } + + inline auto cend() const noexcept + { + return m_map.cend(); } - /** - * Remove every sockets in the listener. - */ - inline void clear() + void set(Socket &sc, int flags); + + void unset(Socket &sc, int flags) noexcept; + + inline void remove(Socket &sc) noexcept { + m_map.erase(sc); + m_interface->remove(sc); + } + + inline void clear() noexcept + { + m_map.clear(); m_interface->clear(); } - /** - * Get the number of clients in listener. - * - * @return the total number of sockets in the listener - */ - inline unsigned size() const + unsigned size() const noexcept { - return m_interface->size(); + return m_map.size(); } /** @@ -272,17 +282,6 @@ { return m_interface->selectMultiple(timeout); } - - /** - * List every socket in the listener. - * - * @param func the function to call - */ - template <typename Func> - inline void list(Func func) - { - m_interface->list(func); - } }; #endif // !_SOCKET_LISTENER_H_
--- a/C++/SocketSsl.cpp Wed Feb 25 13:53:41 2015 +0100 +++ b/C++/SocketSsl.cpp Mon Mar 02 14:00:48 2015 +0100 @@ -202,17 +202,6 @@ throw error::Error("sendto", "SSL socket is not UDP compatible", 0); } -void SocketSsl::init() -{ - SSL_library_init(); - SSL_load_error_strings(); -} - -void SocketSsl::finish() -{ - ERR_free_strings(); -} - SocketSsl::SocketSsl(int family, SocketSslOptions options) : Socket(family, SOCK_STREAM, 0) {
--- a/C++/SocketSsl.h Wed Feb 25 13:53:41 2015 +0100 +++ b/C++/SocketSsl.h Mon Mar 02 14:00:48 2015 +0100 @@ -16,23 +16,24 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#ifndef _SOCKET_SSL_H_ -#define _SOCKET_SSL_H_ +#ifndef _SOCKET_SSL_NG_3_H_ +#define _SOCKET_SSL_NG_3_H_ #include <openssl/err.h> #include <openssl/evp.h> #include <openssl/ssl.h> -#include "Socket.h" +#include "SocketTcp.h" -struct SocketSslOptions { +class SocketSslOptions { +public: enum { SSLv3 = (1 << 0), TLSv1 = (1 << 1), All = (0xf) }; - unsigned short method{All}; + int method{All}; std::string certificate; std::string privateKey; bool verify{false}; @@ -48,42 +49,37 @@ } }; -class SocketSslInterface : public SocketStandard { -private: - using Ssl = std::shared_ptr<SSL>; - using SslContext = std::shared_ptr<SSL_CTX>; - - SslContext m_context; - Ssl m_ssl; - SocketSslOptions m_options; - -public: - SocketSslInterface(SSL_CTX *context, SSL *ssl, SocketSslOptions options = {}); - SocketSslInterface(SocketSslOptions options = {}); - void connect(Socket &s, const SocketAddress &address) override; - void tryConnect(Socket &s, const SocketAddress &address, int timeout) override; - Socket accept(Socket &s, SocketAddress &info) override; - unsigned recv(Socket &s, void *data, unsigned len) override; - unsigned recvfrom(Socket &s, void *data, unsigned len, SocketAddress &info) override; - unsigned tryRecv(Socket &s, void *data, unsigned len, int timeout) override; - unsigned tryRecvfrom(Socket &s, void *data, unsigned len, SocketAddress &info, int timeout) override; - unsigned send(Socket &s, const void *data, unsigned len) override; - unsigned sendto(Socket &s, const void *data, unsigned len, const SocketAddress &info) override; - unsigned trySend(Socket &s, const void *data, unsigned len, int timeout) override; - unsigned trySendto(Socket &s, const void *data, unsigned len, const SocketAddress &info, int timeout) override; -}; - /** * @class SocketSsl * @brief SSL interface for sockets * - * This class derives from Socket and provide SSL support through OpenSSL. - * - * It is perfectly safe to cast SocketSsl to Socket and vice-versa. + * This class derives from SocketAbstractTcp and provide SSL support through OpenSSL. */ -class SocketSsl : public Socket { +class SocketSsl : public SocketAbstractTcp { private: - using Socket::Socket; +#if defined(_WIN32) && !defined(SOCKET_NO_WSA_INIT) + static std::mutex s_mutex{}; + static std::atomic<bool> s_initialized{false}; + + static inline void terminateSsl() + { + ERR_free_strings(); + } + + static inline void initializeSsl() + { + std::lock_guard<std::mutex> lock(s_mutex); + + if (!s_initialized) { + s_initialized = true; + + SSL_library_init(); + SSL_load_error_strings(); + + std::atexit(terminate); + } + } +#endif public: /** @@ -106,4 +102,4 @@ SocketSsl(int family, SocketSslOptions options = {}); }; -#endif // !_SOCKET_SSL_H_ +#endif // !_SOCKET_SSL_NG_3_H_
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/C++/SocketTcp.cpp Mon Mar 02 14:00:48 2015 +0100 @@ -0,0 +1,140 @@ +/* + * SocketTcp.cpp -- portable C++ socket wrappers + * + * Copyright (c) 2013, 2014 David Demelier <markand@malikania.fr> + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include "SocketAddress.h" +#include "SocketTcp.h" + +/* -------------------------------------------------------- + * SocketAbstractTcp + * -------------------------------------------------------- */ + +void SocketAbstractTcp::listen(int max) +{ + if (::listen(m_handle, max) == Error) + throw SocketError("listen", Socket::syserror(), errno); +} + +/* -------------------------------------------------------- + * SocketTcp + * -------------------------------------------------------- */ + +SocketTcp SocketTcp::accept() +{ + SocketAddress dummy; + + return accept(dummy); +} + +SocketTcp SocketTcp::accept(SocketAddress &info) +{ + Socket::Handle handle; + + // Store the information + sockaddr_storage address; + socklen_t addrlen; + + addrlen = sizeof (sockaddr_storage); + handle = ::accept(m_handle, reinterpret_cast<sockaddr *>(&address), &addrlen); + + if (handle == Invalid) { +#if defined(_WIN32) + if (WSAGetLastError() == WSAEWOULDBLOCK) + throw SocketError("accept", Socket::syserror(WSAEWOULDBLOCK), WSAEWOULDBLOCK /* TODO: Read */); + + throw SocketError("accept", Socket::syserror(), WSAGetLastError()); +#else + if (errno == EAGAIN || errno == EWOULDBLOCK) + throw SocketError("accept", Socket::syserror(EWOULDBLOCK), EWOULDBLOCK /* TODO: Read */); + + throw SocketError("accept", Socket::syserror(), errno); +#endif + } + + // Usually accept works only with SOCK_STREAM + info = SocketAddress(address, addrlen); + + return SocketTcp(handle); +} + +void SocketTcp::connect(const SocketAddress &address) +{ + auto &sa = address.address(); + auto addrlen = address.length(); + + if (::connect(m_handle, reinterpret_cast<const sockaddr *>(&sa), addrlen) == Error) { + /* + * Determine if the error comes from a non-blocking connect that cannot be + * accomplished yet. + */ +#if defined(_WIN32) + if (WSAGetLastError() == WSAEWOULDBLOCK) + throw SocketError("connect", Socket::syserror(WSAEWOULDBLOCK), WSAEWOULDBLOCK /*, Write */); + + throw SocketError("connect", Socket::syserror(WSAEWOULDBLOCK), WSAGetLastError()); +#else + if (errno == EINPROGRESS) + throw SocketError("connect", Socket::syserror(EINPROGRESS), EINPROGRESS /*, Write */); + + throw SocketError("connect", Socket::syserror(), errno); +#endif + } +} + +unsigned SocketTcp::recv(void *data, unsigned dataLen) +{ + int nbread; + + nbread = ::recv(m_handle, (Socket::Arg)data, dataLen, 0); + if (nbread == Error) { +#if defined(_WIN32) + if (WSAGetLastError() == WSAEWOULDBLOCK) + throw SocketError("recv", Socket::syserror(), WSAEWOULDBLOCK /* TODO: Read */); + + throw SocketError("recv", Socket::syserror(), WSAGetLastError()); +#else + if (errno == EAGAIN || errno == EWOULDBLOCK) + throw SocketError("recv", Socket::syserror(), errno /* TODO: Read */); + + throw SocketError("recv", Socket::syserror(), errno); +#endif + } + + return (unsigned)nbread; +} + +unsigned SocketTcp::send(const void *data, unsigned length) +{ + int nbsent; + + nbsent = ::send(m_handle, (Socket::ConstArg)data, length, 0); + if (nbsent == Error) { +#if defined(_WIN32) + if (WSAGetLastError() == WSAEWOULDBLOCK) + throw SocketError("send", Socket::syserror(), WSAEWOULDBLOCK /* Write */); + + throw SocketError("send", Socket::syserror(), WSAGetLastError()); +#else + if (errno == EAGAIN || errno == EWOULDBLOCK) + throw SocketError("send", Socket::syserror(), errno /*, Write */); + + throw SocketError("send", Socket::syserror(), errno); +#endif + } + + return (unsigned)nbsent; +}
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/C++/SocketTcp.h Mon Mar 02 14:00:48 2015 +0100 @@ -0,0 +1,81 @@ +/* + * SocketTcp.h -- portable C++ socket wrappers + * + * Copyright (c) 2013, 2014 David Demelier <markand@malikania.fr> + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#ifndef _SOCKET_TCP_NG_3_H_ +#define _SOCKET_TCP_NG_3_H_ + +#include "Socket.h" + +/** + * @class SocketAbstractTcp + * @brief Base class for TCP sockets + */ +class SocketAbstractTcp : public Socket { +public: + using Socket::Socket; + + inline SocketAbstractTcp(int domain, int protocol) + : Socket(domain, SOCK_STREAM, protocol) + { + } + + void listen(int max = 128); + + inline std::string recv(unsigned count) + { + std::string result; + + result.resize(count); + auto n = recv(const_cast<char *>(result.data()), count); + result.resize(n); + + return result; + } + + inline unsigned send(const std::string &data) + { + return send(data.c_str(), data.size()); + } + + virtual unsigned recv(void *data, unsigned length) = 0; + + virtual unsigned send(const void *data, unsigned length) = 0; +}; + +/** + * @class SocketTcp + * @brief End-user class for TCP sockets + */ +class SocketTcp final : public SocketAbstractTcp { +public: + using SocketAbstractTcp::SocketAbstractTcp; + using SocketAbstractTcp::recv; + using SocketAbstractTcp::send; + + SocketTcp accept(); + + SocketTcp accept(SocketAddress &info); + + void connect(const SocketAddress &address); + + unsigned recv(void *data, unsigned length) override; + + unsigned send(const void *data, unsigned length) override; +}; + +#endif // !_SOCKET_TCP_NG_3_H_
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/C++/SocketUdp.cpp Mon Mar 02 14:00:48 2015 +0100 @@ -0,0 +1,77 @@ +/* + * SocketUdp.cpp -- portable C++ socket wrappers + * + * Copyright (c) 2013, 2014 David Demelier <markand@malikania.fr> + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include "SocketAddress.h" +#include "SocketUdp.h" + +SocketUdp::SocketUdp(int domain, int protocol) + : Socket(domain, SOCK_DGRAM, protocol) +{ +} + +unsigned SocketUdp::recvfrom(void *data, unsigned length, SocketAddress &info) +{ + int nbread; + + // Store information + sockaddr_storage address; + socklen_t addrlen; + + addrlen = sizeof (struct sockaddr_storage); + nbread = ::recvfrom(m_handle, (Socket::Arg)data, length, 0, (sockaddr *)&address, &addrlen); + + info = SocketAddress(address, addrlen); + + if (nbread == Error) { +#if defined(_WIN32) + if (WSAGetLastError() == WSAEWOULDBLOCK) + throw SocketError("recvfrom", Socket::syserror(), WSAEWOULDBLOCK /*, Read */); + + throw SocketError("recvfrom", Socket::syserror(), WSAGetLastError()); +#else + if (errno == EAGAIN || errno == EWOULDBLOCK) + throw SocketError("recvfrom", Socket::syserror(), errno /* , Read */); + + throw SocketError("recvfrom", Socket::syserror(), errno); +#endif + } + + return (unsigned)nbread; +} + +unsigned SocketUdp::sendto(const void *data, unsigned length, const SocketAddress &info) +{ + int nbsent; + + nbsent = ::sendto(m_handle, (Socket::ConstArg)data, length, 0, (const sockaddr *)&info.address(), info.length()); + if (nbsent == Error) { +#if defined(_WIN32) + if (WSAGetLastError() == WSAEWOULDBLOCK) + throw SocketError("sendto", Socket::syserror(), WSAEWOULDBLOCK /*, Write */); + + throw SocketError("sendto", Socket::syserror(), errno); +#else + if (errno == EAGAIN || errno == EWOULDBLOCK) + throw SocketError("sendto", Socket::syserror(), errno /*, Write */); + + throw SocketError("sendto", Socket::syserror(), errno); +#endif + } + + return (unsigned)nbsent; +}
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/C++/SocketUdp.h Mon Mar 02 14:00:48 2015 +0100 @@ -0,0 +1,51 @@ +/* + * SocketUdp.h -- portable C++ socket wrappers + * + * Copyright (c) 2013, 2014 David Demelier <markand@malikania.fr> + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#ifndef _SOCKET_UDP_NG_3_H_ +#define _SOCKET_UDP_NG_3_H_ + +#include "Socket.h" + +class SocketUdp : public Socket { +public: + using Socket::Socket; + + SocketUdp(int domain, int protocol); + + inline unsigned sendto(const std::string &data, const SocketAddress &info) + { + return sendto(data.c_str(), data.length(), info); + } + + inline std::string recvfrom(unsigned count, SocketAddress &info) + { + std::string result; + + result.resize(count); + auto n = recvfrom(const_cast<char *>(result.data()), count, info); + result.resize(n); + + return result; + } + + virtual unsigned recvfrom(void *data, unsigned length, SocketAddress &info); + + virtual unsigned sendto(const void *data, unsigned length, const SocketAddress &info); +}; + +#endif // !_SOCKET_UDP_NG_3_H_
--- a/C++/Tests/Sockets/CMakeLists.txt Wed Feb 25 13:53:41 2015 +0100 +++ b/C++/Tests/Sockets/CMakeLists.txt Mon Mar 02 14:00:48 2015 +0100 @@ -22,12 +22,16 @@ SOURCES ${code_SOURCE_DIR}/C++/Socket.cpp ${code_SOURCE_DIR}/C++/Socket.h + ${code_SOURCE_DIR}/C++/SocketTcp.cpp + ${code_SOURCE_DIR}/C++/SocketTcp.h + ${code_SOURCE_DIR}/C++/SocketUdp.cpp + ${code_SOURCE_DIR}/C++/SocketUdp.h ${code_SOURCE_DIR}/C++/SocketAddress.cpp ${code_SOURCE_DIR}/C++/SocketAddress.h ${code_SOURCE_DIR}/C++/SocketListener.cpp ${code_SOURCE_DIR}/C++/SocketListener.h - ${code_SOURCE_DIR}/C++/SocketSsl.cpp - ${code_SOURCE_DIR}/C++/SocketSsl.h + #${code_SOURCE_DIR}/C++/SocketSsl.cpp + #${code_SOURCE_DIR}/C++/SocketSsl.h main.cpp )
--- a/C++/Tests/Sockets/main.cpp Wed Feb 25 13:53:41 2015 +0100 +++ b/C++/Tests/Sockets/main.cpp Mon Mar 02 14:00:48 2015 +0100 @@ -24,6 +24,243 @@ #include <gtest/gtest.h> +#include "Socket.h" +#include "SocketAddress.h" +#include "SocketListener.h" +#include "SocketTcp.h" +#include "SocketUdp.h" + +using namespace address; +using namespace std::literals::chrono_literals; + +/* -------------------------------------------------------- + * TCP tests + * -------------------------------------------------------- */ + +class TcpServerTest : public testing::Test { +protected: + SocketTcp m_server{AF_INET, 0}; + SocketTcp m_client{AF_INET, 0}; + + std::thread m_tserver; + std::thread m_tclient; + +public: + TcpServerTest() + { + m_server.set(SOL_SOCKET, SO_REUSEADDR, 1); + } + + ~TcpServerTest() + { + if (m_tserver.joinable()) + m_tserver.join(); + if (m_tclient.joinable()) + m_tclient.join(); + } +}; + +TEST_F(TcpServerTest, connect) +{ + m_tserver = std::thread([this] () { + m_server.bind(Internet("*", 16000, AF_INET)); + m_server.listen(); + m_server.accept(); + m_server.close(); + }); + + std::this_thread::sleep_for(500ms); + + m_tclient = std::thread([this] () { + m_client.connect(Internet("127.0.0.1", 16000, AF_INET)); + m_client.close(); + }); +} + +TEST_F(TcpServerTest, io) +{ + m_tserver = std::thread([this] () { + m_server.bind(Internet("*", 16000, AF_INET)); + m_server.listen(); + + auto client = m_server.accept(); + auto msg = client.recv(512); + + ASSERT_EQ("hello world", msg); + + client.send(msg); + client.close(); + + m_server.close(); + }); + + std::this_thread::sleep_for(500ms); + + m_tclient = std::thread([this] () { + m_client.connect(Internet("127.0.0.1", 16000, AF_INET)); + m_client.send("hello world"); + + ASSERT_EQ("hello world", m_client.recv(512)); + + m_client.close(); + }); +} + +/* -------------------------------------------------------- + * UDP tests + * -------------------------------------------------------- */ + +class UdpServerTest : public testing::Test { +protected: + SocketUdp m_server{AF_INET, 0}; + SocketUdp m_client{AF_INET, 0}; + + std::thread m_tserver; + std::thread m_tclient; + +public: + UdpServerTest() + { + m_server.set(SOL_SOCKET, SO_REUSEADDR, 1); + } + + ~UdpServerTest() + { + if (m_tserver.joinable()) + m_tserver.join(); + if (m_tclient.joinable()) + m_tclient.join(); + } +}; + +TEST_F(UdpServerTest, io) +{ + m_tserver = std::thread([this] () { + SocketAddress info; + + m_server.bind(Internet("*", 16000, AF_INET)); + + auto msg = m_server.recvfrom(512, info); + + ASSERT_EQ("hello world", msg); + + m_server.sendto(msg, info); + m_server.close(); + }); + + std::this_thread::sleep_for(500ms); + + m_tclient = std::thread([this] () { + Internet info("127.0.0.1", 16000, AF_INET); + + m_client.sendto("hello world", info); + + ASSERT_EQ("hello world", m_client.recvfrom(512, info)); + + m_client.close(); + }); +} + +/* -------------------------------------------------------- + * Listener tests (standard) + * -------------------------------------------------------- */ + +class ListenerTest : public testing::Test { +protected: + SocketListener m_listener; + SocketTcp socket1{AF_INET, 0}; + SocketUdp socket2{AF_INET, 0}; + +public: + ~ListenerTest() + { + socket1.close(); + socket2.close(); + } +}; + +TEST_F(ListenerTest, set) +{ + m_listener.set(socket1, SocketListener::Read); + + ASSERT_EQ(1, static_cast<int>(m_listener.size())); + ASSERT_EQ(SocketListener::Read, m_listener.begin()->second); + + m_listener.set(socket1, SocketListener::Write); + + ASSERT_EQ(1, static_cast<int>(m_listener.size())); + ASSERT_EQ(0x3, m_listener.begin()->second); + + // Fake a re-insert of the same socket + m_listener.set(socket1, SocketListener::Write); + + ASSERT_EQ(1, static_cast<int>(m_listener.size())); + ASSERT_EQ(0x3, m_listener.begin()->second); + + // Add an other socket now + m_listener.set(socket2, SocketListener::Read | SocketListener::Write); + + ASSERT_EQ(2, static_cast<int>(m_listener.size())); + + for (auto &pair : m_listener) { + ASSERT_EQ(0x3, pair.second); + ASSERT_TRUE(pair.first == socket1 || pair.first == socket2); + } +} + +TEST_F(ListenerTest, unset) +{ + m_listener.set(socket1, SocketListener::Read | SocketListener::Write); + m_listener.set(socket2, SocketListener::Read | SocketListener::Write); + + m_listener.unset(socket1, SocketListener::Read); + + ASSERT_EQ(2, static_cast<int>(m_listener.size())); + + // Use a for loop since it can be ordered differently + for (auto &pair : m_listener) { + if (pair.first == socket1) { + ASSERT_EQ(0x2, pair.second); + } else if (pair.first == socket2) { + ASSERT_EQ(0x3, pair.second); + } + } + + m_listener.unset(socket1, SocketListener::Write); + + ASSERT_EQ(1, static_cast<int>(m_listener.size())); + ASSERT_EQ(0x3, m_listener.begin()->second); +} + +TEST_F(ListenerTest, remove) +{ + m_listener.set(socket1, SocketListener::Read | SocketListener::Write); + m_listener.set(socket2, SocketListener::Read | SocketListener::Write); + m_listener.remove(socket1); + + ASSERT_EQ(1, static_cast<int>(m_listener.size())); + ASSERT_EQ(0x3, m_listener.begin()->second); +} + +TEST_F(ListenerTest, clear) +{ + m_listener.set(socket1, SocketListener::Read | SocketListener::Write); + m_listener.set(socket2, SocketListener::Read | SocketListener::Write); + m_listener.clear(); + + ASSERT_EQ(0, static_cast<int>(m_listener.size())); +} + + + + + + + + + +#if 0 + #include <Socket.h> #include <SocketListener.h> #include <SocketAddress.h> @@ -1152,15 +1389,11 @@ s.close(); } +#endif + int main(int argc, char **argv) { - Socket::init(); - testing::InitGoogleTest(&argc, argv); - auto ret = RUN_ALL_TESTS(); - - Socket::finish(); - - return ret; + return RUN_ALL_TESTS(); }