# HG changeset patch # User David Demelier # Date 1425809055 -3600 # Node ID 4f282297625bd9a7cb0e66fbb516def28826c2d2 # Parent 9e223d1de96fdc7ef153a2a14c387fd728fcab55# Parent cba77da58496c5aba1155d087779230f55357049 Merge sockets diff -r 9e223d1de96f -r 4f282297625b C++/Socket.cpp --- a/C++/Socket.cpp Fri Mar 06 22:04:38 2015 +0100 +++ b/C++/Socket.cpp Sun Mar 08 11:04:15 2015 +0100 @@ -20,112 +20,6 @@ #include "Socket.h" #include "SocketAddress.h" -#include "SocketListener.h" - -using namespace direction; - -/* -------------------------------------------------------- - * Socket exceptions - * -------------------------------------------------------- */ - -namespace error { - -/* -------------------------------------------------------- - * Error - * -------------------------------------------------------- */ - -Error::Error(std::string function, std::string error, int code) - : m_function(std::move(function)) - , m_error(std::move(error)) - , m_code(std::move(code)) -{ - m_shortcut = m_function + ": " + m_error; -} - -const std::string &Error::function() const noexcept -{ - return m_function; -} - -const std::string &Error::error() const noexcept -{ - return m_error; -} - -int Error::code() const noexcept -{ - return m_code; -} - -const char *Error::what() const noexcept -{ - return m_shortcut.c_str(); -} - -/* -------------------------------------------------------- - * InProgress - * -------------------------------------------------------- */ - -InProgress::InProgress(std::string func, std::string reason, int code, int direction) - : Error(std::move(func), std::move(reason), code) - , m_direction(direction) -{ -} - -int InProgress::direction() const noexcept -{ - return m_direction; -} - -/* -------------------------------------------------------- - * WouldBlock - * -------------------------------------------------------- */ - -WouldBlock::WouldBlock(std::string func, std::string reason, int code, int direction) - : Error(std::move(func), std::move(reason), code) - , m_direction(direction) -{ -} - -int WouldBlock::direction() const noexcept -{ - return m_direction; -} - -/* -------------------------------------------------------- - * Timeout - * -------------------------------------------------------- */ - -Timeout::Timeout(std::string func) - : m_error(func + ": Timeout exception") -{ -} - -const char *Timeout::what() const noexcept -{ - return m_error.c_str(); -} - -} // !error - -/* -------------------------------------------------------- - * Socket implementation - * -------------------------------------------------------- */ - -void Socket::init() -{ -#if defined(_WIN32) - WSADATA wsa; - WSAStartup(MAKEWORD(2, 2), &wsa); -#endif -} - -void Socket::finish() -{ -#if defined(_WIN32) - WSACleanup(); -#endif -} /* -------------------------------------------------------- * System dependent code @@ -175,285 +69,77 @@ } /* -------------------------------------------------------- - * SocketStandard clear implementation + * SocketError class * -------------------------------------------------------- */ -void SocketStandard::bind(Socket &s, const SocketAddress &addr) -{ - auto &sa = addr.address(); - auto addrlen = addr.length(); - - if (::bind(s.handle(), (sockaddr *)&sa, addrlen) == SOCKET_ERROR) - throw error::Error("bind", Socket::syserror(), errno); -} - -void SocketStandard::connect(Socket &s, const SocketAddress &addr) -{ - if (m_connected) - return; - - auto &sa = addr.address(); - auto addrlen = addr.length(); - - if (::connect(s.handle(), (sockaddr *)&sa, addrlen) == SOCKET_ERROR) { - /* - * Determine if the error comes from a non-blocking connect that cannot be - * accomplished yet. - */ -#if defined(_WIN32) - if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::InProgress("connect", Socket::syserror(WSAEWOULDBLOCK), WSAEWOULDBLOCK, Write); - - throw error::Error("connect", Socket::syserror(WSAEWOULDBLOCK), WSAGetLastError()); -#else - if (errno == EINPROGRESS) - throw error::InProgress("connect", Socket::syserror(EINPROGRESS), EINPROGRESS, Write); - - throw error::Error("connect", Socket::syserror(), errno); -#endif - } - - m_connected = true; -} - -void SocketStandard::tryConnect(Socket &s, const SocketAddress &address, int timeout) +SocketError::SocketError(Code code, std::string function) + : m_code(code) + , m_function(std::move(function)) + , m_error(Socket::syserror()) { - if (m_connected) - return; - - // Initial try - try { - connect(s, address); - } catch (const error::InProgress &) { - SocketListener listener{{s, Write}}; - - listener.select(timeout); - - // Socket is writable? Check if there is an error - auto error = s.get(SOL_SOCKET, SO_ERROR); - - if (error) - throw error::Error("connect", Socket::syserror(error), error); - } - - m_connected = true; -} - -Socket SocketStandard::accept(Socket &s, SocketAddress &info) -{ - Socket::Handle handle; - - // Store the information - sockaddr_storage address; - socklen_t addrlen; - - addrlen = sizeof (sockaddr_storage); - handle = ::accept(s.handle(), (sockaddr *)&address, &addrlen); - - if (handle == INVALID_SOCKET) { -#if defined(_WIN32) - if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::WouldBlock("accept", Socket::syserror(WSAEWOULDBLOCK), WSAEWOULDBLOCK, Read); - - throw error::Error("accept", Socket::syserror(), WSAGetLastError()); -#else - if (errno == EAGAIN || errno == EWOULDBLOCK) - throw error::WouldBlock("accept", Socket::syserror(EWOULDBLOCK), EWOULDBLOCK, Read); - - throw error::Error("accept", Socket::syserror(), errno); -#endif - } - - // Usually accept works only with SOCK_STREAM - info = SocketAddress(address, addrlen); - - return Socket(handle, std::make_shared()); -} - -Socket SocketStandard::tryAccept(Socket &s, SocketAddress &info, int timeout) -{ - SocketListener listener{{s, Read}}; - - listener.select(timeout); - - return accept(s, info); -} - -void SocketStandard::listen(Socket &s, int max) -{ - if (::listen(s.handle(), max) == SOCKET_ERROR) - throw error::Error("listen", Socket::syserror(), errno); } -unsigned SocketStandard::recv(Socket &s, void *data, unsigned dataLen) -{ - int nbread; - - nbread = ::recv(s.handle(), (Socket::Arg)data, dataLen, 0); - if (nbread == SOCKET_ERROR) { -#if defined(_WIN32) - if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::WouldBlock("recv", Socket::syserror(), WSAEWOULDBLOCK, Read); - - throw error::Error("recv", Socket::syserror(), WSAGetLastError()); -#else - if (errno == EAGAIN || errno == EWOULDBLOCK) - throw error::WouldBlock("recv", Socket::syserror(), errno, Read); - - throw error::Error("recv", Socket::syserror(), errno); -#endif - } - - return (unsigned)nbread; -} - -unsigned SocketStandard::tryRecv(Socket &s, void *data, unsigned len, int timeout) +SocketError::SocketError(Code code, std::string function, int error) + : m_code(code) + , m_function(std::move(function)) + , m_error(Socket::syserror(error)) { - SocketListener listener{{s, Read}}; - - listener.select(timeout); - - return recv(s, data, len); -} - -unsigned SocketStandard::send(Socket &s, const void *data, unsigned dataLen) -{ - int nbsent; - - nbsent = ::send(s.handle(), (Socket::ConstArg)data, dataLen, 0); - if (nbsent == SOCKET_ERROR) { -#if defined(_WIN32) - if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::WouldBlock("send", Socket::syserror(), WSAEWOULDBLOCK, Write); - - throw error::Error("send", Socket::syserror(), WSAGetLastError()); -#else - if (errno == EAGAIN || errno == EWOULDBLOCK) - throw error::WouldBlock("send", Socket::syserror(), errno, Write); - - throw error::Error("send", Socket::syserror(), errno); -#endif - } - - return (unsigned)nbsent; -} - -unsigned SocketStandard::trySend(Socket &s, const void *data, unsigned len, int timeout) -{ - SocketListener listener{{s, Write}}; - - listener.select(timeout); - - return send(s, data, len); } -unsigned SocketStandard::recvfrom(Socket &s, void *data, unsigned dataLen, SocketAddress &info) -{ - int nbread; - - // Store information - sockaddr_storage address; - socklen_t addrlen; - - addrlen = sizeof (struct sockaddr_storage); - nbread = ::recvfrom(s.handle(), (Socket::Arg)data, dataLen, 0, (sockaddr *)&address, &addrlen); - - info = SocketAddress(address, addrlen); - - if (nbread == SOCKET_ERROR) { -#if defined(_WIN32) - if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::WouldBlock("recvfrom", Socket::syserror(), WSAEWOULDBLOCK, Read); - - throw error::Error("recvfrom", Socket::syserror(), WSAGetLastError()); -#else - if (errno == EAGAIN || errno == EWOULDBLOCK) - throw error::WouldBlock("recvfrom", Socket::syserror(), errno, Read); - - throw error::Error("recvfrom", Socket::syserror(), errno); -#endif - } - - return (unsigned)nbread; -} - -unsigned SocketStandard::tryRecvfrom(Socket &s, void *data, unsigned len, SocketAddress &info, int timeout) +SocketError::SocketError(Code code, std::string function, std::string error) + : m_code(code) + , m_function(std::move(function)) + , m_error(std::move(error)) { - SocketListener listener{{s, Read}}; - - listener.select(timeout); - - return recvfrom(s, data, len, info); -} - -unsigned SocketStandard::sendto(Socket &s, const void *data, unsigned dataLen, const SocketAddress &info) -{ - int nbsent; - - nbsent = ::sendto(s.handle(), (Socket::ConstArg)data, dataLen, 0, (const sockaddr *)&info.address(), info.length()); - if (nbsent == SOCKET_ERROR) { -#if defined(_WIN32) - if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::WouldBlock("sendto", Socket::syserror(), WSAEWOULDBLOCK, Write); - - throw error::Error("sendto", Socket::syserror(), errno); -#else - if (errno == EAGAIN || errno == EWOULDBLOCK) - throw error::WouldBlock("sendto", Socket::syserror(), errno, Write); - - throw error::Error("sendto", Socket::syserror(), errno); -#endif - } - - return (unsigned)nbsent; -} - -unsigned SocketStandard::trySendto(Socket &s, const void *data, unsigned len, const SocketAddress &info, int timeout) -{ - SocketListener listener{{s, Write}}; - - listener.select(timeout); - - return sendto(s, data, len, info); -} - -void SocketStandard::close(Socket &s) -{ - (void)closesocket(s.handle()); - - m_connected = false; } /* -------------------------------------------------------- - * Socket code + * Socket class * -------------------------------------------------------- */ -Socket::Socket() - : m_interface(std::make_shared()) -{ -} +#if defined(_WIN32) +std::mutex Socket::s_mutex; +std::atomic Socket::s_initialized{false}; +#endif Socket::Socket(int domain, int type, int protocol) - : Socket() { - m_handle = socket(domain, type, protocol); +#if defined(_WIN32) && !defined(SOCKET_NO_WSA_INIT) + if (!s_initialized) + initialize(); +#endif - if (m_handle == INVALID_SOCKET) - throw error::Error("socket", syserror(), errno); + m_handle = ::socket(domain, type, protocol); + + if (m_handle == Invalid) + throw SocketError(SocketError::System, "socket"); + + m_state = SocketState::Opened; } -Socket::Socket(Handle handle, std::shared_ptr iface) - : m_interface(std::move(iface)) - , m_handle(handle) +void Socket::bind(const SocketAddress &address) { + const auto &sa = address.address(); + const auto addrlen = address.length(); + + if (::bind(m_handle, reinterpret_cast(&sa), addrlen) == Error) + throw SocketError(SocketError::System, "bind"); + + m_state = SocketState::Bound; } -Socket::Handle Socket::handle() const +void Socket::close() { - return m_handle; +#if defined(_WIN32) + ::closesocket(m_handle); +#else + ::close(m_handle); +#endif + + SocketState::Closed; } -void Socket::blockMode(bool block) +void Socket::setBlockMode(bool block) { #if defined(O_NONBLOCK) && !defined(_WIN32) int flags; @@ -466,13 +152,13 @@ else flags |= O_NONBLOCK; - if (fcntl(m_handle, F_SETFL, flags) == -1) - throw error::Error("blockMode", Socket::syserror(), errno); + if (fcntl(m_handle, F_SETFL, flags) == Error) + throw SocketError(SocketError::System, "setBlockMode"); #else unsigned long flags = (block) ? 0 : 1; - if (ioctlsocket(m_handle, FIONBIO, &flags) == SOCKET_ERROR) - throw error::Error("blockMode", Socket::syserror(), WSAGetLastError()); + if (ioctlsocket(m_handle, FIONBIO, &flags) == Error) + throw SocketError(SocketError::System, "setBlockMode"); #endif } @@ -485,35 +171,3 @@ { return s1.handle() < s2.handle(); } - -/* - * - */ - -Socket Socket::tryAccept(int timeout) -{ - SocketAddress dummy; - - return tryAccept(dummy, timeout); -} - -Socket Socket::accept() -{ - SocketAddress dummy; - - return accept(dummy); -} - -unsigned Socket::recvfrom(void *data, unsigned dataLen) -{ - SocketAddress dummy; - - return m_interface->recvfrom(*this, data, dataLen, dummy); -} - -std::string Socket::recvfrom(unsigned count) -{ - SocketAddress dummy; - - return recvfrom(count, dummy); -} diff -r 9e223d1de96f -r 4f282297625b C++/Socket.h --- a/C++/Socket.h Fri Mar 06 22:04:38 2015 +0100 +++ b/C++/Socket.h Sun Mar 08 11:04:15 2015 +0100 @@ -19,12 +19,36 @@ #ifndef _SOCKET_NG_H_ #define _SOCKET_NG_H_ +/** + * @file Socket.h + * @brief Portable socket abstraction + * + * User may set the following variables before compiling these files: + * + * SOCKET_NO_WSA_INIT - (bool) Set to false if you don't want Socket class to + * automatically calls WSAStartup() when creating sockets. + * + * Otherwise, you will need to call Socket::init, + * Socket::finish yourself. + * + * SOCKET_NO_SSL_INIT - (bool) Set to false if you don't want OpenSSL to be + * initialized when the first SocketSsl object is created. + * + * SOCKET_HAVE_POLL - (bool) Set to true if poll(2) function is available. + * + * Note: on Windows, this is automatically set if the + * _WIN32_WINNT variable is greater or equal to 0x0600. + */ + #include #include -#include #include #if defined(_WIN32) +# include +# include +# include + # include # include #else @@ -41,406 +65,109 @@ # include # include # include - -# define ioctlsocket(s, p, a) ::ioctl(s, p, a) -# define closesocket(s) ::close(s) - -# define gai_strerrorA gai_strerror - -# define INVALID_SOCKET -1 -# define SOCKET_ERROR -1 #endif -class Socket; class SocketAddress; /** - * Namespace for listener flags. + * @class SocketError + * @brief Base class for sockets error */ -namespace direction { - -constexpr const int Read = (1 << 0); //!< Wants to read -constexpr const int Write = (1 << 1); //!< Wants to write - -} // !direction +class SocketError : public std::exception { +public: + enum Code { + WouldBlockRead, ///!< The operation would block for reading + WouldBlockWrite, ///!< The operation would block for writing + Timeout, ///!< The action did timeout + System ///!< There is a system error + }; -/** - * Various errors. - */ -namespace error { - -/** - * Base error class, contains - */ -class Error : public std::exception { -private: + Code m_code; std::string m_function; std::string m_error; - int m_code; - std::string m_shortcut; -public: /** - * Construct a full error. + * Constructor that use the last system error. * - * @param function which function + * @param code which kind of error + * @param function the function name + */ + SocketError(Code code, std::string function); + + /** + * Constructor that use the system error set by the user. + * + * @param code which kind of error + * @param function the function name * @param error the error - * @param code the native code */ - Error(std::string function, std::string error, int code); + SocketError(Code code, std::string function, int error); /** - * Get the function which thrown an exception. + * Constructor that set the error specified by the user. * - * @return the function name + * @param code which kind of error + * @param function the function name + * @param error the error */ - const std::string &function() const noexcept; + SocketError(Code code, std::string function, std::string error); /** - * Get the error string. + * Get which function has triggered the error. * - * @return the error + * @return the function name (e.g connect) */ - const std::string &error() const noexcept; + inline const std::string &function() const noexcept + { + return m_function; + } /** - * Get the native code. Use with care because it varies from the system, - * the type of socket and such. + * The error code. * * @return the code */ - int code() const noexcept; - - /** - * Get a brief error message - * - * @return a message - */ - const char *what() const noexcept override; -}; - -/** - * @class InProgress - * @brief Operation cannot be accomplished now - * - * Usually thrown in a non-blocking connect call. - */ -class InProgress final : public Error { -private: - int m_direction; - -public: - /** - * Operation is in progress, contains the direction needed to complete - * the operation. - * - * The direction may be different from the requested operation because - * some sockets (e.g SSL) requires reading even when writing! - * - * @param func the function - * @param reason the reason - * @param code the native code - * @param direction the direction - */ - InProgress(std::string func, std::string reason, int code, int direction); + inline Code code() const noexcept + { + return m_code; + } /** - * Get the required direction for listening operation requires. + * Get the error (only the error content). * - * @return the direction required + * @return the error */ - int direction() const noexcept; -}; - -/** - * @class WouldBlock - * @brief The operation would block - * - * Usually thrown in a non-blocking connect send or receive. - */ -class WouldBlock final : public Error { -private: - int m_direction; - -public: - /** - * Operation would block, contains the direction needed to complete the - * operation. - * - * The direction may be different from the requested operation because - * some sockets (e.g SSL) requires reading even when writing! - * - * @param func the function - * @param reason the reason - * @param code the native code - * @param direction the direction - */ - WouldBlock(std::string func, std::string reason, int code, int direction); - - /** - * Get the required direction for listening operation requires. - * - * @return the direction required - */ - int direction() const noexcept; + const char *what() const noexcept + { + return m_error.c_str(); + } }; /** - * @class Timeout - * @brief Describe a timeout expiration - * - * Usually thrown on timeout in SocketListener::select. - */ -class Timeout final : public std::exception { -private: - std::string m_error; - -public: - /** - * Timeout exception. - * - * @param func the function name - */ - Timeout(std::string func); - - /** - * The error message. - */ - const char *what() const noexcept override; -}; - -} // !error - -/** - * @class SocketInterface - * @brief Interface to implement - * - * This class implements the socket functions. + * @enum SocketState + * @brief Category of error */ -class SocketInterface { -public: - /** - * Default destructor. - */ - virtual ~SocketInterface() = default; - - /** - * Bind the socket. - * - * @param s the socket - * @param address the address - * @throw error::Failure on error - */ - virtual void bind(Socket &s, const SocketAddress &address) = 0; - - /** - * Close the socket. - * - * @param s the socket - */ - virtual void close(Socket &s) = 0; - - /** - * Try to connect to the specific address - * - * @param s the socket - * @param address the address - * @throw error::Failure on error - * @throw error::InProgress if the socket is marked non-blocking and connection cannot be established yet - */ - virtual void connect(Socket &s, const SocketAddress &address) = 0; - - /** - * Try to connect without blocking. - * - * @param s the socket - * @param address the address - * @param timeout the timeout in milliseconds - * @warning This function won't block only if the socket is marked non blocking - */ - virtual void tryConnect(Socket &s, const SocketAddress &address, int timeout) = 0; - - /** - * Accept a client. - * - * @param s the socket - * @param info the client information - * @return a client ready to use - * @throw error::Failure on error - */ - virtual Socket accept(Socket &s, SocketAddress &info) = 0; - - /** - * Try to accept a client with a timeout. - * - * @param s the socket - * @param info the client information - * @param timeout the timeout in milliseconds - */ - virtual Socket tryAccept(Socket &s, SocketAddress &info, int timeout) = 0; - - /** - * Listen to a specific number of pending connections. - * - * @param s the socket - * @param max the max number of clients - * @throw error::Failure on error - */ - virtual void listen(Socket &s, int max) = 0; - - /** - * Receive some data. - * - * @param s the socket - * @param data the destination pointer - * @param len max length to receive - * @return the number of bytes received - * @throw error::Failure on error - * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block - */ - virtual unsigned recv(Socket &s, void *data, unsigned len) = 0; - - /** - * Try to receive data without blocking. - * - * @param s the socket - * @param data the destination pointer - * @param len max length to receive - * @param timeout the timeout in milliseconds - * @return the number of bytes received - * @throw error::Failure on error - */ - virtual unsigned tryRecv(Socket &s, void *data, unsigned len, int timeout) = 0; - - /** - * Receive from a connection-less socket. - * - * @param s the socket - * @param data the destination pointer - * @param dataLen max length to receive - * @param info the client information - * @return the number of bytes received - * @throw error::Failure on error - * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block - */ - virtual unsigned recvfrom(Socket &s, void *data, unsigned len, SocketAddress &info) = 0; - - /** - * Try to receive data from a connection-less socket without blocking. - * - * @param s the socket - * @param data the data - * @param len max length to receive - * @param info the client information - * @param timeout the optional timeout in milliseconds - * @return the number of bytes received - * @throw error::Failure on error - */ - virtual unsigned tryRecvfrom(Socket &s, void *data, unsigned len, SocketAddress &info, int timeout) = 0; - - /** - * Send some data. - * - * @param s the socket - * @param data the data to send - * @param dataLen the data length - * @return the number of bytes sent - * @throw error::Failure on error - * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block - */ - virtual unsigned send(Socket &s, const void *data, unsigned len) = 0; - - /** - * Try to send some data without blocking. - * - * @param s the socket - * @param data the data to send - * @param dataLen the data length - * @param timeout the optional timeout in milliseconds - * @return the number of bytes sent - * @throw error::Failure on error - * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block - * @warning This function won't block only if the socket is marked non blocking - */ - virtual unsigned trySend(Socket &s, const void *data, unsigned len, int timeout) = 0; - - /** - * Send some data to a connection-less socket. - * - * @param s the socket - * @param data the data to send - * @param dataLen the data length - * @param address the address - * @return the number of bytes sent - * @throw error::Failure on error - * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block - */ - virtual unsigned sendto(Socket &s, const void *data, unsigned len, const SocketAddress &info) = 0; - - /** - * Try to send some data to a connection-less without blocking. - * - * @param s the socket - * @param data the data to send - * @param dataLen the data length - * @param address the address - * @return the number of bytes sent - * @throw error::Failure on error - * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block - * @warning This function won't block only if the socket is marked non blocking - */ - virtual unsigned trySendto(Socket &s, const void *data, unsigned len, const SocketAddress &info, int timeout) = 0; -}; - -/** - * Standard interface for sockets. - * - * This will use standard clear functions: - * - * bind(2) - * connect(2) - * close(2), closesocket on Windows, - * accept(2), - * listen(2), - * recv(2), - * recvfrom(2), - * send(2), - * sendto(2) - */ -class SocketStandard : public SocketInterface { -protected: - bool m_connected{false}; - -public: - void bind(Socket &s, const SocketAddress &address) override; - void close(Socket &s) override; - void connect(Socket &s, const SocketAddress &address) override; - void tryConnect(Socket &s, const SocketAddress &address, int timeout) override; - Socket accept(Socket &s, SocketAddress &info) override; - Socket tryAccept(Socket &s, SocketAddress &info, int timeout) override; - void listen(Socket &s, int max) override; - unsigned recv(Socket &s, void *data, unsigned len) override; - unsigned tryRecv(Socket &s, void *data, unsigned len, int timeout) override; - unsigned recvfrom(Socket &s, void *data, unsigned len, SocketAddress &info) override; - unsigned tryRecvfrom(Socket &s, void *data, unsigned len, SocketAddress &info, int timeout) override; - unsigned send(Socket &s, const void *data, unsigned len) override; - unsigned trySend(Socket &s, const void *data, unsigned len, int timeout) override; - unsigned sendto(Socket &s, const void *data, unsigned len, const SocketAddress &info) override; - unsigned trySendto(Socket &s, const void *data, unsigned len, const SocketAddress &info, int timeout) override; +enum class SocketState { + Opened, ///!< Socket is opened + Closed, ///!< Socket has been closed + Bound, ///!< Socket is bound to address + Connected, ///!< Socket is connected to an end point + Disconnected, ///!< Socket is disconnected + Timeout ///!< Timeout has occured in a waiting operation }; /** * @class Socket - * @brief socket abstraction - * - * This class is a big wrapper around sockets functions but portable, - * there is some functions that helps for getting error reporting. - * - * This class is implemented as a PIMPL idiom, it is perfectly - * safe to cast the object to any other derivate children. + * @brief Base socket class for socket operations */ class Socket { public: + /* {{{ Portable types */ + + /* + * The following types are defined differently between Unix + * and Windows. + */ #if defined(_WIN32) using Handle = SOCKET; using ConstArg = const char *; @@ -451,19 +178,82 @@ using Arg = void *; #endif - using Iface = std::shared_ptr; + /* }}} */ + + /* {{{ Portable constants */ + + /* + * The following constants are defined differently from Unix + * to Windows. + */ +#if defined(_WIN32) + static constexpr const int Invalid = INVALID_SOCKET; + static constexpr const int Error = SOCKET_ERROR; +#else + static constexpr const int Invalid = -1; + static constexpr const int Error = -1; +#endif + + /* }}} */ + + /* {{{ Portable initialization */ + + /* + * Initialization stuff. + * + * The function init and finish are threadsafe. + */ +#if defined(_WIN32) +private: + static std::mutex s_mutex; + static std::atomic s_initialized; + +public: + static inline void finish() noexcept + { + WSACleanup(); + } + + static inline void init() noexcept + { + std::lock_guard lock(s_mutex); + + if (!s_initialized) { + s_initialized = true; + + WSADATA wsa; + WSAStartup(MAKEWORD(2, 2), &wsa); + + /* + * If SOCKET_WSA_NO_INIT is not set then the user + * must also call finish himself. + */ +#if !defined(SOCKET_WSA_NO_INIT) + std::atexit(finish); +#endif + } + } +#else +public: + /** + * no-op. + */ + static inline void init() noexcept {} + + /** + * no-op. + */ + static inline void finish() noexcept {} +#endif + + /* }}} */ protected: - Iface m_interface; //!< the interface - Handle m_handle{INVALID_SOCKET}; //!< the socket shared pointer + Handle m_handle; + SocketState m_state{SocketState::Opened}; public: /** - * To be called before any socket operation. - */ - static void init(); - - /** * Get the last socket system error. The error is set from errno or from * WSAGetLastError on Windows. * @@ -480,58 +270,48 @@ static std::string syserror(int errn); /** - * To be called before exiting. + * Construct a socket with an already created descriptor. + * + * @param handle the native descriptor */ - static void finish(); + inline Socket(Handle handle) + : m_handle(handle) + , m_state(SocketState::Opened) + { + } /** - * Default constructor. - */ - Socket(); - - /** - * Constructor to create a new socket. + * Create a socket handle. * - * @param domain the domain - * @param type the type + * @param domain the domain AF_* + * @param type the type SOCK_* * @param protocol the protocol - * @throw error::Failure on error + * @throw SocketError on failures */ Socket(int domain, int type, int protocol); /** - * Create a socket object with a already initialized socket. - * - * @param handle the handle - * @param interface the interface to use - */ - Socket(Handle handle, std::shared_ptr iface); - - /** - * Close the socket. + * Default destructor. */ virtual ~Socket() = default; /** - * Get the socket. - * - * @return the socket - */ - Handle handle() const; - - /** * Set an option for the socket. * * @param level the setting level * @param name the name * @param arg the value - * @throw error::Failure on error + * @throw SocketError on error */ template - void set(int level, int name, const Argument &arg) + inline void set(int level, int name, const Argument &arg) { +#if defined(_WIN32) if (setsockopt(m_handle, level, name, (Socket::ConstArg)&arg, sizeof (arg)) == SOCKET_ERROR) - throw error::Error("set", syserror(), errno); +#else + if (setsockopt(m_handle, level, name, (Socket::ConstArg)&arg, sizeof (arg)) < 0) +#endif + throw SocketError(SocketError::System, "set"); } /** @@ -539,16 +319,20 @@ * * @param level the setting level * @param name the name - * @throw error::Failure on error + * @throw SocketError on error */ template - Argument get(int level, int name) + inline Argument get(int level, int name) { Argument desired, result{}; socklen_t size = sizeof (result); +#if defined(_WIN32) if (getsockopt(m_handle, level, name, (Socket::Arg)&desired, &size) == SOCKET_ERROR) - throw error::Error("get", syserror(), errno); +#else + if (getsockopt(m_handle, level, name, (Socket::Arg)&desired, &size) < 0) +#endif + throw SocketError(SocketError::System, "get"); std::memcpy(&result, &desired, size); @@ -556,234 +340,65 @@ } /** - * Enable or disable blocking mode. + * Get the native handle. * - * @param block the mode - * @throw error::Failure on error - */ - void blockMode(bool block = true); - - /** - * @copydoc SocketInterface::bind + * @return the handle + * @warning Not portable */ - inline void bind(const SocketAddress &address) - { - m_interface->bind(*this, address); - } - - /** - * @copydoc SocketInterface::close - */ - inline void close() + inline Handle handle() const noexcept { - m_interface->close(*this); - } - - /** - * @copydoc SocketInterface::connect - */ - inline void connect(const SocketAddress &address) - { - m_interface->connect(*this, address); - } - - /** - * @copydoc SocketInterface::tryConnect - */ - inline void tryConnect(const SocketAddress &address, int timeout) - { - m_interface->tryConnect(*this, address, timeout); - } - - /** - * @copydoc SocketInterface::tryAccept - */ - inline Socket tryAccept(SocketAddress &address, int timeout) - { - return m_interface->tryAccept(*this, address, timeout); + return m_handle; } /** - * Overload without client information. - * - * @param timeout the timeout - * @return the socket - * @throw error::Error on errors - * @throw error::Timeout if operation has timeout - */ - Socket tryAccept(int timeout); - - /** - * Accept a client without getting its info. + * Get the socket state. * - * @return a client ready to use - * @throw error::Failure on error - */ - Socket accept(); - - /** - * @copydoc SocketInterface::accept - */ - inline Socket accept(SocketAddress &info) - { - return m_interface->accept(*this, info); - } - - /** - * @copydoc SocketInterface::listen + * @return */ - inline void listen(int max) - { - m_interface->listen(*this, max); - } - - /** - * @copydoc SocketInterface::recv - */ - inline unsigned recv(void *data, unsigned dataLen) + inline SocketState state() const noexcept { - return m_interface->recv(*this, data, dataLen); - } - - /** - * Overload for strings. - * - * @param count number of bytes to recv - * @return the string - * @throw error::Error on failures - * @throw error::WouldBlock if operation would block - */ - inline std::string recv(unsigned count) - { - std::string result; - - result.resize(count); - auto n = recv(const_cast(result.data()), count); - result.resize(n); - - return result; + return m_state; } /** - * @copydoc SocketInterface::tryRecv - */ - inline unsigned tryRecv(void *data, unsigned dataLen, int timeout = 0) - { - return m_interface->tryRecv(*this, data, dataLen, timeout); - } - - /** - * Overload for string. + * Bind to an address. * - * @param count - * @param timeout - * @return the string - * @throw error::Error on failures - * @throw error::Timeout if operation has timeout + * @param address the address + * @throw SocketError on any error */ - inline std::string tryRecv(unsigned count, int timeout = 0) - { - std::string result; - - result.resize(count); - auto n = tryRecv(const_cast(result.data()), count, timeout); - result.resize(n); - - return result; - } + void bind(const SocketAddress &address); /** - * Receive from a connection-less socket without getting - * client information. + * Set the blocking mode, if set to false, the socket will be marked + * **non-blocking**. * - * @param data the destination pointer - * @param dataLen max length to receive - * @return the number of bytes received - * @throw error::Failure on error - */ - unsigned recvfrom(void *data, unsigned dataLen); - - /** - * @copydoc SocketInterface::recvfrom + * @param block set to false to mark **non-blocking** + * @throw SocketError on any error */ - inline unsigned recvfrom(void *data, unsigned dataLen, SocketAddress &info) - { - return m_interface->recvfrom(*this, data, dataLen, info); - } - - /** - * Overload for string. - * - * @param count the number of bytes to receive - * @return the string - * @throw error::Error on failures - * @throw error::WouldBlock if operation would block - */ - std::string recvfrom(unsigned count); + void setBlockMode(bool block); /** - * Overload with client information. - * - * @param count the number of bytes to receive - * @return the string - * @throw error::Error on failures - * @throw error::WouldBlock if operation would block - */ - inline std::string recvfrom(unsigned count, SocketAddress &info) - { - std::string result; - - result.resize(count); - auto n = recvfrom(const_cast(result.data()), count, info); - result.resize(n); - - return result; - } - - /** - * @copydoc SocketInterface::send + * Close the socket. */ - inline unsigned send(const void *data, unsigned dataLen) - { - return m_interface->send(*this, data, dataLen); - } - - /** - * Send a message as a string. - * - * @param message the message - * @return the number of bytes sent - * @throw SocketError on error - */ - inline unsigned send(const std::string &message) - { - return send(message.data(), message.size()); - } - - /** - * @copydoc SocketInterface::sendto - */ - inline unsigned sendto(const void *data, unsigned dataLen, const SocketAddress &info) - { - return m_interface->sendto(*this, data, dataLen, info); - } - - /** - * Send a message to a connection-less socket. - * - * @param message the message - * @param address the address - * @return the number of bytes sent - * @throw SocketError on error - */ - inline unsigned sendto(const std::string &message, const SocketAddress &info) - { - return sendto(message.data(), message.size(), info); - } + virtual void close(); }; +/** + * Compare two sockets. + * + * @param s1 the first socket + * @param s2 the second socket + * @return true if they equals + */ bool operator==(const Socket &s1, const Socket &s2); -bool operator<(const Socket &s, const Socket &s2); +/** + * Compare two sockets, ideal for putting in a std::map. + * + * @param s1 the first socket + * @param s2 the second socket + * @return true if s1 < s2 + */ +bool operator<(const Socket &s1, const Socket &s2); #endif // !_SOCKET_NG_H_ diff -r 9e223d1de96f -r 4f282297625b C++/SocketAddress.cpp --- a/C++/SocketAddress.cpp Fri Mar 06 22:04:38 2015 +0100 +++ b/C++/SocketAddress.cpp Sun Mar 08 11:04:15 2015 +0100 @@ -56,7 +56,7 @@ auto error = getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &res); if (error != 0) - throw error::Error("getaddrinfo", gai_strerror(error), error); + throw SocketError(SocketError::System, "getaddrinfo", gai_strerror(error)); std::memcpy(&m_addr, res->ai_addr, res->ai_addrlen); m_addrlen = res->ai_addrlen; diff -r 9e223d1de96f -r 4f282297625b C++/SocketAddress.h --- a/C++/SocketAddress.h Fri Mar 06 22:04:38 2015 +0100 +++ b/C++/SocketAddress.h Sun Mar 08 11:04:15 2015 +0100 @@ -16,8 +16,8 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#ifndef _SOCKET_ADDRESS_H_ -#define _SOCKET_ADDRESS_H_ +#ifndef _SOCKET_ADDRESS_NG_H_ +#define _SOCKET_ADDRESS_NG_H_ #include @@ -140,4 +140,4 @@ } // !address -#endif // !_SOCKET_ADDRESS_H_ +#endif // !_SOCKET_ADDRESS_NG_H_ diff -r 9e223d1de96f -r 4f282297625b C++/SocketListener.cpp --- a/C++/SocketListener.cpp Fri Mar 06 22:04:38 2015 +0100 +++ b/C++/SocketListener.cpp Sun Mar 08 11:04:15 2015 +0100 @@ -24,131 +24,126 @@ #include "SocketListener.h" -using namespace direction; - /* -------------------------------------------------------- * Select implementation * -------------------------------------------------------- */ +namespace { + /** * @class SelectMethod * @brief Implements select(2) * * This class is the fallback of any other method, it is not preferred at all for many reasons. */ -class SelectMethod final : public SocketListener::Interface { +class SelectMethod final : public SocketListenerInterface { private: - std::map> m_table; + std::map, int>> m_table; public: - void add(Socket s, int direction) override; - void remove(const Socket &s, int direction) override; - void list(const SocketListener::MapFunc &func) override; - void clear() override; - unsigned size() const override; - SocketStatus select(int ms) override; - std::vector selectMultiple(int ms) override; + void set(Socket &s, int direction) override + { + if (m_table.count(s.handle()) > 0) { + m_table.at(s.handle()).second |= direction; + } else { + m_table.insert({s.handle(), {s, direction}}); + } + + } + + void unset(Socket &s, int direction) override + { + if (m_table.count(s.handle()) != 0) { + m_table.at(s.handle()).second &= ~(direction); + + // If no read, no write is requested, remove it + if (m_table.at(s.handle()).second == 0) { + m_table.erase(s.handle()); + } + } + } + + void remove(Socket &sc) override + { + m_table.erase(sc.handle()); + } + + void clear() override + { + m_table.clear(); + } + + SocketStatus select(int ms) override + { + auto result = selectMultiple(ms); + + if (result.size() == 0) { + throw SocketError(SocketError::System, "select", "No socket found"); + } + + return result[0]; + } + + std::vector selectMultiple(int ms) override + { + timeval maxwait, *towait; + fd_set readset; + fd_set writeset; + + FD_ZERO(&readset); + FD_ZERO(&writeset); + + Socket::Handle max = 0; + + for (auto &s : m_table) { + if (s.second.second & SocketListener::Read) { + FD_SET(s.first, &readset); + } + if (s.second.second & SocketListener::Write) { + FD_SET(s.first, &writeset); + } + + if (s.first > max) { + max = s.first; + } + } + + maxwait.tv_sec = 0; + maxwait.tv_usec = ms * 1000; + + // Set to nullptr for infinite timeout. + towait = (ms < 0) ? nullptr : &maxwait; + + auto error = ::select(max + 1, &readset, &writeset, nullptr, towait); + if (error == Socket::Error) { + throw SocketError(SocketError::System, "select"); + } + if (error == 0) { + throw SocketError(SocketError::Timeout, "select", "Timeout while listening"); + } + + std::vector sockets; + + for (auto &c : m_table) { + if (FD_ISSET(c.first, &readset)) { + sockets.push_back({ c.second.first, SocketListener::Read }); + } + if (FD_ISSET(c.first, &writeset)) { + sockets.push_back({ c.second.first, SocketListener::Write }); + } + } + + return sockets; + } }; -void SelectMethod::add(Socket s, int direction) -{ - if (m_table.count(s.handle()) > 0) - m_table[s.handle()].second |= direction; - else - m_table[s.handle()] = { s, direction }; -} - -void SelectMethod::remove(const Socket &s, int direction) -{ - if (m_table.count(s.handle()) != 0) { - m_table[s.handle()].second &= ~(direction); - - // If no read, no write is requested, remove it - if (m_table[s.handle()].second == 0) - m_table.erase(s.handle()); - } -} - -void SelectMethod::list(const SocketListener::MapFunc &func) -{ - for (auto &s : m_table) - func(s.second.first, s.second.second); -} - -void SelectMethod::clear() -{ - m_table.clear(); -} - -unsigned SelectMethod::size() const -{ - return m_table.size(); -} - -SocketStatus SelectMethod::select(int ms) -{ - auto result = selectMultiple(ms); - - if (result.size() == 0) - throw error::Error("select", "No socket found", 0); - - return result[0]; -} - -std::vector SelectMethod::selectMultiple(int ms) -{ - timeval maxwait, *towait; - fd_set readset; - fd_set writeset; - - FD_ZERO(&readset); - FD_ZERO(&writeset); - - Socket::Handle max = 0; - - for (auto &s : m_table) { - if (s.second.second & Read) - FD_SET(s.first, &readset); - if (s.second.second & Write) - FD_SET(s.first, &writeset); - - if (s.first > max) - max = s.first; - } - - maxwait.tv_sec = 0; - maxwait.tv_usec = ms * 1000; - - // Set to nullptr for infinite timeout. - towait = (ms < 0) ? nullptr : &maxwait; - - auto error = ::select(max + 1, &readset, &writeset, nullptr, towait); - if (error == SOCKET_ERROR) -#if defined(_WIN32) - throw error::Error("select", Socket::syserror(), WSAGetLastError()); -#else - throw error::Error("select", Socket::syserror(), errno); -#endif - if (error == 0) - throw error::Timeout("select"); - - std::vector sockets; - - for (auto &c : m_table) { - if (FD_ISSET(c.first, &readset)) - sockets.push_back({ c.second.first, Read }); - if (FD_ISSET(c.first, &writeset)) - sockets.push_back({ c.second.first, Write }); - } - - return sockets; -} +} // !namespace /* -------------------------------------------------------- * Poll implementation * -------------------------------------------------------- */ -#if defined(SOCKET_LISTENER_HAVE_POLL) +#if defined(SOCKET_HAVE_POLL) #if defined(_WIN32) # include @@ -159,18 +154,18 @@ namespace { -class PollMethod final : public SocketListener::Interface { +class PollMethod final : public SocketListenerInterface { private: std::vector m_fds; - std::map m_lookup; + std::map> m_lookup; inline short topoll(int direction) { short result(0); - if (direction & Read) + if (direction & SocketListener::Read) result |= POLLIN; - if (direction & Write) + if (direction & SocketListener::Write) result |= POLLOUT; return result; @@ -188,121 +183,120 @@ * return 0 so we mark the socket as readable. */ if ((event & POLLIN) || (event & POLLHUP)) - direction |= Read; + direction |= SocketListener::Read; if (event & POLLOUT) - direction |= Write; + direction |= SocketListener::Write; return direction; } public: - void add(Socket s, int direction) override; - void remove(const Socket &s, int direction) override; - void list(const SocketListener::MapFunc &func) override; - void clear() override; - unsigned size() const override; - SocketStatus select(int ms) override; - std::vector selectMultiple(int ms) override; -}; + void set(Socket &s, int direction) override + { + auto it = std::find_if(m_fds.begin(), m_fds.end(), [&] (const auto &pfd) { return pfd.fd == s.handle(); }); -void PollMethod::add(Socket s, int direction) -{ - auto it = std::find_if(m_fds.begin(), m_fds.end(), [&] (const auto &pfd) { return pfd.fd == s.handle(); }); + // If found, add the new direction, otherwise add a new socket + if (it != m_fds.end()) + it->events |= topoll(direction); + else { + m_lookup.insert({s.handle(), s}); + m_fds.push_back({ s.handle(), topoll(direction), 0 }); + } + } - // If found, add the new direction, otherwise add a new socket - if (it != m_fds.end()) - it->events |= topoll(direction); - else { - m_lookup[s.handle()] = s; - m_fds.push_back({ s.handle(), topoll(direction), 0 }); - } -} + void unset(Socket &s, int direction) override + { + for (auto i = m_fds.begin(); i != m_fds.end();) { + if (i->fd == s.handle()) { + i->events &= ~(topoll(direction)); -void PollMethod::remove(const Socket &s, int direction) -{ - for (auto i = m_fds.begin(); i != m_fds.end();) { - if (i->fd == s.handle()) { - i->events &= ~(topoll(direction)); - - if (i->events == 0) { - m_lookup.erase(i->fd); - i = m_fds.erase(i); - } else { + if (i->events == 0) { + m_lookup.erase(i->fd); + i = m_fds.erase(i); + } else { + ++i; + } + } else ++i; - } - } else - ++i; + } } -} -void PollMethod::list(const SocketListener::MapFunc &func) -{ - for (auto &fd : m_fds) - func(m_lookup[fd.fd], todirection(fd.events)); -} + void remove(Socket &s) override + { + auto it = std::find_if(m_fds.begin(), m_fds.end(), [&] (const auto &pfd) { return pfd.fd == s.handle(); }); + + if (it != m_fds.end()) { + m_fds.erase(it); + m_lookup.erase(s.handle()); + } + } -void PollMethod::clear() -{ - m_fds.clear(); - m_lookup.clear(); -} - -unsigned PollMethod::size() const -{ - return static_cast(m_fds.size()); -} + void clear() override + { + m_fds.clear(); + m_lookup.clear(); + } -SocketStatus PollMethod::select(int ms) -{ - auto result = poll(m_fds.data(), m_fds.size(), ms); - if (result == 0) - throw error::Timeout("select"); - if (result < 0) -#if defined(_WIN32) - throw error::Error("poll", Socket::syserror(), WSAGetLastError()); -#else - throw error::Error("poll", Socket::syserror(), errno); -#endif + SocketStatus select(int ms) override + { + auto result = poll(m_fds.data(), m_fds.size(), ms); + if (result == 0) + throw SocketError(SocketError::Timeout, "select", "Timeout while listening"); + if (result < 0) + throw SocketError(SocketError::System, "poll"); + + for (auto &fd : m_fds) { + if (fd.revents != 0) { + return { m_lookup.at(fd.fd), todirection(fd.revents) }; + } + } - for (auto &fd : m_fds) - if (fd.revents != 0) - return { m_lookup[fd.fd], todirection(fd.revents) }; - - throw error::Error("select", "No socket found", 0); -} + throw SocketError(SocketError::System, "select", "No socket found"); + } -std::vector PollMethod::selectMultiple(int ms) -{ - auto result = poll(m_fds.data(), m_fds.size(), ms); - if (result == 0) - throw error::Timeout("select"); - if (result < 0) -#if defined(_WIN32) - throw error::Error("poll", Socket::syserror(), WSAGetLastError()); -#else - throw error::Error("poll", Socket::syserror(), errno); -#endif + std::vector selectMultiple(int ms) override + { + auto result = poll(m_fds.data(), m_fds.size(), ms); + if (result == 0) { + throw SocketError(SocketError::Timeout, "select", "Timeout while listening"); + } + if (result < 0) { + throw SocketError(SocketError::System, "poll"); + } - std::vector sockets; - for (auto &fd : m_fds) - if (fd.revents != 0) - sockets.push_back({ m_lookup[fd.fd], todirection(fd.revents) }); + std::vector sockets; + for (auto &fd : m_fds) { + if (fd.revents != 0) { + sockets.push_back({ m_lookup.at(fd.fd), todirection(fd.revents) }); + } + } - return sockets; -} + return sockets; + } +}; } // !namespace -#endif // !_SOCKET_LISTENER_HAVE_POLL +#endif // !_SOCKET_HAVE_POLL /* -------------------------------------------------------- - * Socket listener + * SocketListener * -------------------------------------------------------- */ -SocketListener::SocketListener(int method) +const int SocketListener::Read{1 << 0}; +const int SocketListener::Write{1 << 1}; + +SocketListener::SocketListener(std::initializer_list, int>> list) + : SocketListener() { -#if defined(SOCKET_LISTENER_HAVE_POLL) - if (method == Poll) + for (const auto &p : list) + set(p.first, p.second); +} + +SocketListener::SocketListener(SocketMethod method) +{ +#if defined(SOCKET_HAVE_POLL) + if (method == SocketMethod::Poll) m_interface = std::make_unique(); else #endif @@ -311,9 +305,26 @@ (void)method; } -SocketListener::SocketListener(std::initializer_list> list, int method) - : SocketListener(method) +void SocketListener::set(Socket &sc, int flags) { - for (const auto &p : list) - add(p.first, p.second); + if (m_map.count(sc) > 0) { + m_map[sc] |= flags; + m_interface->set(sc, flags); + } else { + m_map.insert({sc, flags}); + m_interface->set(sc, flags); + } } + +void SocketListener::unset(Socket &sc, int flags) noexcept +{ + if (m_map.count(sc) > 0) { + m_map[sc] &= ~(flags); + m_interface->unset(sc, flags); + + // No more flags, remove it + if (m_map[sc] == 0) { + m_map.erase(sc); + } + } +} diff -r 9e223d1de96f -r 4f282297625b C++/SocketListener.h --- a/C++/SocketListener.h Fri Mar 06 22:04:38 2015 +0100 +++ b/C++/SocketListener.h Sun Mar 08 11:04:15 2015 +0100 @@ -16,12 +16,14 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#ifndef _SOCKET_LISTENER_H_ -#define _SOCKET_LISTENER_H_ +#ifndef _SOCKET_LISTENER_NG_H_ +#define _SOCKET_LISTENER_NG_H_ #include #include #include +#include +#include #include #include @@ -29,10 +31,10 @@ #if defined(_WIN32) # if _WIN32_WINNT >= 0x0600 -# define SOCKET_LISTENER_HAVE_POLL +# define SOCKET_HAVE_POLL # endif #else -# define SOCKET_LISTENER_HAVE_POLL +# define SOCKET_HAVE_POLL #endif /** @@ -42,7 +44,7 @@ * Select the method of polling. It is only a preferred method, for example if you * request for poll but it is not available, select will be used. */ -enum SocketMethod { +enum class SocketMethod { Select, //!< select(2) method, fallback Poll //!< poll(2), everywhere possible }; @@ -54,9 +56,70 @@ * Result of a select call, returns the first ready socket found with its * direction. */ -struct SocketStatus { - Socket socket; //!< which socket is ready - int direction; //!< the direction +class SocketStatus { +public: + Socket &socket; //!< which socket is ready + int direction; //!< the direction +}; + +/** + * @class SocketListenerInterface + * @brief Implement the polling method + */ +class SocketListenerInterface { +public: + /** + * Default destructor. + */ + virtual ~SocketListenerInterface() = default; + + /** + * Add a socket with a specified direction. + * + * @param s the socket + * @param direction the direction + */ + virtual void set(Socket &sc, int direction) = 0; + + /** + * Remove a socket with a specified direction. + * + * @param s the socket + * @param direction the direction + */ + virtual void unset(Socket &sc, int direction) = 0; + + /** + * Remove completely a socket. + * + * @param sc the socket to remove + */ + virtual void remove(Socket &sc) = 0; + + /** + * Remove all sockets. + */ + virtual void clear() = 0; + + /** + * Select one socket. + * + * @param ms the number of milliseconds to wait, -1 means forever + * @return the socket status + * @throw error::Failure on failure + * @throw error::Timeout on timeout + */ + virtual SocketStatus select(int ms) = 0; + + /** + * Select many sockets. + * + * @param ms the number of milliseconds to wait, -1 means forever + * @return a vector of ready sockets + * @throw error::Failure on failure + * @throw error::Timeout on timeout + */ + virtual std::vector selectMultiple(int ms) = 0; }; /** @@ -64,86 +127,33 @@ * @brief Synchronous multiplexing * * Convenient wrapper around the select() system call. + * + * This class is implemented using a bridge pattern to allow different uses + * of listener implementation. + * + * Currently, poll and select() are available. + * + * This wrappers takes abstract sockets as non-const reference but it does not + * own them so you must take care that sockets are still alive until the + * SocketListener is destroyed. */ class SocketListener final { public: - /** - * @brief Function for listing all sockets - */ - using MapFunc = std::function; - -#if defined(SOCKET_LISTENER_HAVE_POLL) +#if defined(SOCKET_HAVE_POLL) static constexpr const SocketMethod PreferredMethod = SocketMethod::Poll; #else static constexpr const SocketMethod PreferredMethod = SocketMethod::Select; #endif - /** - * @class Interface - * @brief Implement the polling method - */ - class Interface { - public: - /** - * Default destructor. - */ - virtual ~Interface() = default; - - /** - * List all sockets in the interface. - * - * @param func the function - */ - virtual void list(const MapFunc &func) = 0; - - /** - * Add a socket with a specified direction. - * - * @param s the socket - * @param direction the direction - */ - virtual void add(Socket s, int direction) = 0; + static const int Read; + static const int Write; - /** - * Remove a socket with a specified direction. - * - * @param s the socket - * @param direction the direction - */ - virtual void remove(const Socket &s, int direction) = 0; - - /** - * Remove all sockets. - */ - virtual void clear() = 0; - - /** - * Get the total number of sockets in the listener. - */ - virtual unsigned size() const = 0; + using Map = std::map, int>; + using Iface = std::unique_ptr; - /** - * Select one socket. - * - * @param ms the number of milliseconds to wait, -1 means forever - * @return the socket status - * @throw error::Failure on failure - * @throw error::Timeout on timeout - */ - virtual SocketStatus select(int ms) = 0; - - /** - * Select many sockets. - * - * @param ms the number of milliseconds to wait, -1 means forever - * @return a vector of ready sockets - * @throw error::Failure on failure - * @throw error::Timeout on timeout - */ - virtual std::vector selectMultiple(int ms) = 0; - }; - - std::unique_ptr m_interface; +private: + Map m_map; + Iface m_interface; public: /** @@ -166,54 +176,122 @@ * * @param method the preferred method */ - SocketListener(int method = PreferredMethod); + SocketListener(SocketMethod method = PreferredMethod); + + /** + * Create a listener from a list of sockets. + * + * @param list the list + */ + SocketListener(std::initializer_list, int>> list); + + /** + * Return an iterator to the beginning. + * + * @return the iterator + */ + inline auto begin() noexcept + { + return m_map.begin(); + } /** - * Createa listener with some sockets. + * Overloaded function. * - * @param list the initializer list - * @param method the preferred method + * @return the iterator */ - SocketListener(std::initializer_list> list, int method = PreferredMethod); + inline auto begin() const noexcept + { + return m_map.begin(); + } /** - * Add a socket to listen to. + * Overloaded function. * - * @param s the socket - * @param direction the direction + * @return the iterator */ - inline void add(Socket s, int direction) + inline auto cbegin() const noexcept { - m_interface->add(std::move(s), direction); + return m_map.cbegin(); + } + + /** + * Return an iterator to the end. + * + * @return the iterator + */ + inline auto end() noexcept + { + return m_map.end(); } /** - * Remove a socket from the list. + * Overloaded function. * - * @param s the socket - * @param direction the direction + * @return the iterator */ - inline void remove(const Socket &s, int direction) + inline auto end() const noexcept { - m_interface->remove(s, direction); + return m_map.end(); + } + + /** + * Overloaded function. + * + * @return the iterator + */ + inline auto cend() const noexcept + { + return m_map.cend(); } /** - * Remove every sockets in the listener. + * Add a socket to the listener. + * + * @param sc the socket + * @param direction (may be OR'ed) + */ + void set(Socket &sc, int direction); + + /** + * Unset a socket from the listener, only the direction is removed + * unless the two directions are requested. + * + * For example, if you added a socket for both reading and writing, + * unsetting the write direction will keep the socket for reading. + * + * @param sc the socket + * @param direction the direction (may be OR'ed) + * @see remove */ - inline void clear() + void unset(Socket &sc, int direction) noexcept; + + /** + * Remove completely the socket from the listener. + * + * @param sc the socket + */ + inline void remove(Socket &sc) noexcept { + m_map.erase(sc); + m_interface->remove(sc); + } + + /** + * Remove all sockets. + */ + inline void clear() noexcept + { + m_map.clear(); m_interface->clear(); } /** - * Get the number of clients in listener. - * - * @return the total number of sockets in the listener + * Get the number of sockets in the listener. */ - inline unsigned size() const + unsigned size() const noexcept { - return m_interface->size(); + return m_map.size(); } /** @@ -221,8 +299,6 @@ * * @param duration the duration * @return the socket ready - * @throw SocketError on error - * @throw SocketTimeout on timeout */ template inline SocketStatus select(const std::chrono::duration &duration) @@ -237,8 +313,6 @@ * * @param timeout the optional timeout in milliseconds * @return the socket ready - * @throw SocketError on error - * @throw SocketTimeout on timeout */ inline SocketStatus select(int timeout = -1) { @@ -250,8 +324,6 @@ * * @param duration the duration * @return the socket ready - * @throw SocketError on error - * @throw SocketTimeout on timeout */ template inline std::vector selectMultiple(const std::chrono::duration &duration) @@ -265,24 +337,11 @@ * Overload with milliseconds. * * @return the socket ready - * @throw SocketError on error - * @throw SocketTimeout on timeout */ inline std::vector selectMultiple(int timeout = -1) { return m_interface->selectMultiple(timeout); } - - /** - * List every socket in the listener. - * - * @param func the function to call - */ - template - inline void list(Func func) - { - m_interface->list(func); - } }; -#endif // !_SOCKET_LISTENER_H_ +#endif // !_SOCKET_LISTENER_NG_H_ diff -r 9e223d1de96f -r 4f282297625b C++/SocketSsl.cpp --- a/C++/SocketSsl.cpp Fri Mar 06 22:04:38 2015 +0100 +++ b/C++/SocketSsl.cpp Sun Mar 08 11:04:15 2015 +0100 @@ -16,14 +16,13 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ +#include "SocketAddress.h" #include "SocketListener.h" #include "SocketSsl.h" -using namespace direction; - namespace { -const SSL_METHOD *method(int mflags) +const SSL_METHOD *sslMethod(int mflags) { if (mflags & SocketSslOptions::All) return SSLv23_method(); @@ -40,69 +39,105 @@ return ERR_reason_error_string(error); } -} // !namespace +inline int toDirection(int error) +{ + if (error == SocketError::WouldBlockRead) + return SocketListener::Read; + if (error == SocketError::WouldBlockWrite) + return SocketListener::Write; -SocketSslInterface::SocketSslInterface(SSL_CTX *context, SSL *ssl, SocketSslOptions options) - : SocketStandard() - , m_context(context, SSL_CTX_free) - , m_ssl(ssl, SSL_free) - , m_options(std::move(options)) -{ + return 0; } -SocketSslInterface::SocketSslInterface(SocketSslOptions options) - : SocketStandard() - , m_options(std::move(options)) +} // !namespace + +std::mutex SocketSsl::s_sslMutex; +std::atomic SocketSsl::s_sslInitialized{false}; + +SocketSsl::SocketSsl(Socket::Handle handle, SSL_CTX *context, SSL *ssl) + : SocketAbstractTcp(handle) + , m_context(context, SSL_CTX_free) + , m_ssl(ssl, SSL_free) { +#if !defined(SOCKET_NO_SSL_INIT) + if (!s_sslInitialized) + sslInitialize(); +#endif } -void SocketSslInterface::connect(Socket &s, const SocketAddress &address) +SocketSsl::SocketSsl(int family, int protocol, SocketSslOptions options) + : SocketAbstractTcp(family, protocol) + , m_options(std::move(options)) { - SocketStandard::connect(s, address); +#if !defined(SOCKET_NO_SSL_INIT) + if (!s_sslInitialized) + sslInitialize(); +#endif +} + +void SocketSsl::connect(const SocketAddress &address) +{ + standardConnect(address); // Context first - auto context = SSL_CTX_new(method(m_options.method)); + auto context = SSL_CTX_new(sslMethod(m_options.method)); - m_context = SslContext(context, SSL_CTX_free); + m_context = ContextHandle(context, SSL_CTX_free); // SSL object then auto ssl = SSL_new(context); - m_ssl = Ssl(ssl, SSL_free); + m_ssl = SslHandle(ssl, SSL_free); - SSL_set_fd(ssl, s.handle()); + SSL_set_fd(ssl, m_handle); auto ret = SSL_connect(ssl); if (ret <= 0) { auto error = SSL_get_error(ssl, ret); - if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) - throw error::InProgress("connect", sslError(error), error, error); + if (error == SSL_ERROR_WANT_READ) { + throw SocketError(SocketError::WouldBlockRead, "connect", "Operation in progress"); + } else if (error == SSL_ERROR_WANT_WRITE) { + throw SocketError(SocketError::WouldBlockWrite, "connect", "Operation in progress"); + } else { + throw SocketError(SocketError::System, "connect", sslError(error)); + } + } - throw error::Error("accept", sslError(error), error); - } + m_state = SocketState::Connected; } -void SocketSslInterface::tryConnect(Socket &s, const SocketAddress &address, int timeout) +void SocketSsl::waitConnect(const SocketAddress &address, int timeout) { try { // Initial try - connect(s, address); - } catch (const error::InProgress &ipe) { - SocketListener listener{{s, ipe.direction()}}; + connect(address); + } catch (const SocketError &ex) { + if (ex.code() == SocketError::WouldBlockRead || ex.code() == SocketError::WouldBlockWrite) { + SocketListener listener{{*this, toDirection(ex.code())}}; - listener.select(timeout); + listener.select(timeout); - // Second try - connect(s, address); + // Second try + connect(address); + } else { + throw; + } } } -Socket SocketSslInterface::accept(Socket &s, SocketAddress &info) +SocketSsl SocketSsl::accept() { - auto client = SocketStandard::accept(s, info); - auto context = SSL_CTX_new(method(m_options.method)); + SocketAddress dummy; + + return accept(dummy); +} + +SocketSsl SocketSsl::accept(SocketAddress &info) +{ + auto client = standardAccept(info); + auto context = SSL_CTX_new(sslMethod(m_options.method)); if (m_options.certificate.size() > 0) SSL_CTX_use_certificate_file(context, m_options.certificate.c_str(), SSL_FILETYPE_PEM); @@ -110,7 +145,7 @@ SSL_CTX_use_PrivateKey_file(context, m_options.privateKey.c_str(), SSL_FILETYPE_PEM); if (m_options.verify && !SSL_CTX_check_private_key(context)) { client.close(); - throw error::Error("accept", "certificate failure", 0); + throw SocketError(SocketError::System, "accept", "certificate failure"); } // SSL object @@ -123,98 +158,71 @@ if (ret <= 0) { auto error = SSL_get_error(ssl, ret); - if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) - throw error::InProgress("accept", sslError(error), error, error); - - throw error::Error("accept", sslError(error), error); + if (error == SSL_ERROR_WANT_READ) { + throw SocketError(SocketError::WouldBlockRead, "accept", "Operation would block"); + } else if (error == SSL_ERROR_WANT_WRITE) { + throw SocketError(SocketError::WouldBlockWrite, "accept", "Operation would block"); + } else { + throw SocketError(SocketError::System, "accept", sslError(error)); + } } - return SocketSsl{client.handle(), std::make_shared(context, ssl)}; + return SocketSsl(client.handle(), context, ssl); } -unsigned SocketSslInterface::recv(Socket &, void *data, unsigned len) +unsigned SocketSsl::recv(void *data, unsigned len) { auto nbread = SSL_read(m_ssl.get(), data, len); if (nbread <= 0) { auto error = SSL_get_error(m_ssl.get(), nbread); - if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) - throw error::InProgress("accept", sslError(error), error, error); - - throw error::Error("recv", sslError(error), error); + if (error == SSL_ERROR_WANT_READ) { + throw SocketError(SocketError::WouldBlockRead, "recv", "Operation would block"); + } else if (error == SSL_ERROR_WANT_WRITE) { + throw SocketError(SocketError::WouldBlockWrite, "recv", "Operation would block"); + } else { + throw SocketError(SocketError::System, "recv", sslError(error)); + } } return nbread; } -unsigned SocketSslInterface::recvfrom(Socket &, void *, unsigned, SocketAddress &) +unsigned SocketSsl::waitRecv(void *data, unsigned len, int timeout) { - throw error::Error("recvfrom", "SSL socket is not UDP compatible", 0); -} - -unsigned SocketSslInterface::tryRecv(Socket &s, void *data, unsigned len, int timeout) -{ - SocketListener listener{{s, Read}}; + SocketListener listener{{*this, SocketListener::Read}}; listener.select(timeout); - return recv(s, data, len); + return recv(data, len); } -unsigned SocketSslInterface::tryRecvfrom(Socket &, void *, unsigned, SocketAddress &, int) -{ - throw error::Error("recvfrom", "SSL socket is not UDP compatible", 0); -} - -unsigned SocketSslInterface::send(Socket &, const void *data, unsigned len) +unsigned SocketSsl::send(const void *data, unsigned len) { auto nbread = SSL_write(m_ssl.get(), data, len); if (nbread <= 0) { auto error = SSL_get_error(m_ssl.get(), nbread); - if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) - throw error::InProgress("accept", sslError(error), error, error); - - throw error::Error("recv", sslError(error), error); + if (error == SSL_ERROR_WANT_READ) { + throw SocketError(SocketError::WouldBlockRead, "send", "Operation would block"); + } else if (error == SSL_ERROR_WANT_WRITE) { + throw SocketError(SocketError::WouldBlockWrite, "send", "Operation would block"); + } else { + throw SocketError(SocketError::System, "send", sslError(error)); + } } return nbread; } -unsigned SocketSslInterface::sendto(Socket &, const void *, unsigned, const SocketAddress &) +unsigned SocketSsl::waitSend(const void *data, unsigned len, int timeout) { - throw error::Error("sendto", "SSL socket is not UDP compatible", 0); -} - -unsigned SocketSslInterface::trySend(Socket &s, const void *data, unsigned len, int timeout) -{ - SocketListener listener{{s, Write}}; + SocketListener listener{{*this, SocketListener::Write}}; listener.select(timeout); - return send(s, data, len); -} - -unsigned SocketSslInterface::trySendto(Socket &, const void *, unsigned, const SocketAddress &, int) -{ - throw error::Error("sendto", "SSL socket is not UDP compatible", 0); + return send(data, len); } -void SocketSsl::init() -{ - SSL_library_init(); - SSL_load_error_strings(); -} - -void SocketSsl::finish() -{ - ERR_free_strings(); -} - -SocketSsl::SocketSsl(int family, SocketSslOptions options) - : Socket(family, SOCK_STREAM, 0) -{ - m_interface = std::make_shared(std::move(options)); -} diff -r 9e223d1de96f -r 4f282297625b C++/SocketSsl.h --- a/C++/SocketSsl.h Fri Mar 06 22:04:38 2015 +0100 +++ b/C++/SocketSsl.h Sun Mar 08 11:04:15 2015 +0100 @@ -16,30 +16,52 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#ifndef _SOCKET_SSL_H_ -#define _SOCKET_SSL_H_ +#ifndef _SOCKET_SSL_NG_H_ +#define _SOCKET_SSL_NG_H_ + +#include +#include #include #include #include -#include "Socket.h" +#include "SocketTcp.h" -struct SocketSslOptions { +/** + * @class SocketSslOptions + * @brief Options for SocketSsl + */ +class SocketSslOptions { +public: + /** + * @brief Method + */ enum { SSLv3 = (1 << 0), TLSv1 = (1 << 1), All = (0xf) }; - unsigned short method{All}; - std::string certificate; - std::string privateKey; - bool verify{false}; + int method{All}; //!< The method + std::string certificate; //!< The certificate path + std::string privateKey; //!< The private key file + bool verify{false}; //!< Verify or not + /** + * Default constructor. + */ SocketSslOptions() = default; - SocketSslOptions(unsigned short method, std::string certificate, std::string key, bool verify = false) + /** + * More advanced constructor. + * + * @param method the method requested + * @param certificate the certificate file + * @param key the key file + * @param verify set to true to verify + */ + SocketSslOptions(int method, std::string certificate, std::string key, bool verify = false) : method(method) , certificate(std::move(certificate)) , privateKey(std::move(key)) @@ -48,53 +70,64 @@ } }; -class SocketSslInterface : public SocketStandard { -private: - using Ssl = std::shared_ptr; - using SslContext = std::shared_ptr; - - SslContext m_context; - Ssl m_ssl; - SocketSslOptions m_options; - -public: - SocketSslInterface(SSL_CTX *context, SSL *ssl, SocketSslOptions options = {}); - SocketSslInterface(SocketSslOptions options = {}); - void connect(Socket &s, const SocketAddress &address) override; - void tryConnect(Socket &s, const SocketAddress &address, int timeout) override; - Socket accept(Socket &s, SocketAddress &info) override; - unsigned recv(Socket &s, void *data, unsigned len) override; - unsigned recvfrom(Socket &s, void *data, unsigned len, SocketAddress &info) override; - unsigned tryRecv(Socket &s, void *data, unsigned len, int timeout) override; - unsigned tryRecvfrom(Socket &s, void *data, unsigned len, SocketAddress &info, int timeout) override; - unsigned send(Socket &s, const void *data, unsigned len) override; - unsigned sendto(Socket &s, const void *data, unsigned len, const SocketAddress &info) override; - unsigned trySend(Socket &s, const void *data, unsigned len, int timeout) override; - unsigned trySendto(Socket &s, const void *data, unsigned len, const SocketAddress &info, int timeout) override; -}; - /** * @class SocketSsl * @brief SSL interface for sockets * - * This class derives from Socket and provide SSL support through OpenSSL. - * - * It is perfectly safe to cast SocketSsl to Socket and vice-versa. + * This class derives from SocketAbstractTcp and provide SSL support through OpenSSL. */ -class SocketSsl : public Socket { +class SocketSsl : public SocketAbstractTcp { +public: + using ContextHandle = std::unique_ptr; + using SslHandle = std::unique_ptr; + private: - using Socket::Socket; + static std::mutex s_sslMutex; + static std::atomic s_sslInitialized; + + ContextHandle m_context{nullptr, nullptr}; + SslHandle m_ssl{nullptr, nullptr}; + SocketSslOptions m_options; public: + using SocketAbstractTcp::recv; + using SocketAbstractTcp::waitRecv; + using SocketAbstractTcp::send; + using SocketAbstractTcp::waitSend; + /** - * Initialize SSL library. + * Close OpenSSL library. */ - static void init(); + static inline void sslTerminate() + { + ERR_free_strings(); + } /** - * Close SSL library. + * Open SSL library. */ - static void finish(); + static inline void sslInitialize() + { + std::lock_guard lock(s_sslMutex); + + if (!s_sslInitialized) { + s_sslInitialized = true; + + SSL_library_init(); + SSL_load_error_strings(); + + std::atexit(sslTerminate); + } + } + + /** + * Create a SocketSsl from an already created one. + * + * @param handle the native handle + * @param context the context + * @param ssl the ssl object + */ + SocketSsl(Socket::Handle handle, SSL_CTX *context, SSL *ssl); /** * Open a SSL socket with the specified family. Automatically @@ -103,7 +136,80 @@ * @param family the family * @param options the options */ - SocketSsl(int family, SocketSslOptions options = {}); + SocketSsl(int family, int protocol, SocketSslOptions options = {}); + + /** + * Accept a SSL TCP socket. + * + * @return the socket + * @throw SocketError on error + */ + SocketSsl accept(); + + /** + * Accept a SSL TCP socket. + * + * @param info the client information + * @return the socket + * @throw SocketError on error + */ + SocketSsl accept(SocketAddress &info); + + /** + * Accept a SSL TCP socket. + * + * @param timeout the maximum timeout in milliseconds + * @return the socket + * @throw SocketError on error + */ + SocketSsl waitAccept(int timeout); + + /** + * Accept a SSL TCP socket. + * + * @param info the client information + * @param timeout the maximum timeout in milliseconds + * @return the socket + * @throw SocketError on error + */ + SocketSsl waitAccept(SocketAddress &info, int timeout); + + /** + * Connect to an end point. + * + * @param address the address + * @throw SocketError on error + */ + void connect(const SocketAddress &address); + + /** + * Connect to an end point. + * + * @param timeout the maximum timeout in milliseconds + * @param address the address + * @throw SocketError on error + */ + void waitConnect(const SocketAddress &address, int timeout); + + /** + * @copydoc SocketAbstractTcp::recv + */ + unsigned recv(void *data, unsigned length) override; + + /** + * @copydoc SocketAbstractTcp::recv + */ + unsigned waitRecv(void *data, unsigned length, int timeout) override; + + /** + * @copydoc SocketAbstractTcp::recv + */ + unsigned send(const void *data, unsigned length) override; + + /** + * @copydoc SocketAbstractTcp::recv + */ + unsigned waitSend(const void *data, unsigned length, int timeout) override; }; -#endif // !_SOCKET_SSL_H_ +#endif // !_SOCKET_SSL_NG_H_ diff -r 9e223d1de96f -r 4f282297625b C++/SocketTcp.cpp --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/C++/SocketTcp.cpp Sun Mar 08 11:04:15 2015 +0100 @@ -0,0 +1,227 @@ +/* + * SocketTcp.cpp -- portable C++ socket wrappers + * + * Copyright (c) 2013, 2014 David Demelier + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include "SocketAddress.h" +#include "SocketListener.h" +#include "SocketTcp.h" + +/* -------------------------------------------------------- + * SocketAbstractTcp + * -------------------------------------------------------- */ + +void SocketAbstractTcp::listen(int max) +{ + if (::listen(m_handle, max) == Error) + throw SocketError(SocketError::System, "listen"); +} + +Socket SocketAbstractTcp::standardAccept(SocketAddress &info) +{ + Socket::Handle handle; + + // Store the information + sockaddr_storage address; + socklen_t addrlen; + + addrlen = sizeof (sockaddr_storage); + handle = ::accept(m_handle, reinterpret_cast(&address), &addrlen); + + if (handle == Invalid) { +#if defined(_WIN32) + int error = WSAGetLastError(); + + if (error == WSAEWOULDBLOCK) + throw SocketError(SocketError::WouldBlockRead, "accept", error); + + throw SocketError(SocketError::System, "accept", error); +#else + if (errno == EAGAIN || errno == EWOULDBLOCK) + throw SocketError(SocketError::WouldBlockRead, "accept"); + + throw SocketError(SocketError::System, "accept"); +#endif + } + + info = SocketAddress(address, addrlen); + + return Socket(handle); +} + +void SocketAbstractTcp::standardConnect(const SocketAddress &address) +{ + if (m_state == SocketState::Connected) + return; + + auto &sa = address.address(); + auto addrlen = address.length(); + + if (::connect(m_handle, reinterpret_cast(&sa), addrlen) == Error) { + /* + * Determine if the error comes from a non-blocking connect that cannot be + * accomplished yet. + */ +#if defined(_WIN32) + int error = WSAGetLastError(); + + if (error == WSAEWOULDBLOCK) + throw SocketError(SocketError::WouldBlockWrite, "connect", error); + + throw SocketError(SocketError::System, "connect", error); +#else + if (errno == EINPROGRESS) + throw SocketError(SocketError::WouldBlockWrite, "connect"); + + throw SocketError(SocketError::System, "connect"); +#endif + } + + m_state = SocketState::Connected; +} + +/* -------------------------------------------------------- + * SocketTcp + * -------------------------------------------------------- */ + +SocketTcp SocketTcp::accept() +{ + SocketAddress dummy; + + return accept(dummy); +} + +SocketTcp SocketTcp::accept(SocketAddress &info) +{ + return standardAccept(info); +} + +void SocketTcp::connect(const SocketAddress &address) +{ + return standardConnect(address); +} + +void SocketTcp::waitConnect(const SocketAddress &address, int timeout) +{ + if (m_state == SocketState::Connected) + return; + + // Initial try + try { + connect(address); + } catch (const SocketError &ex) { + if (ex.code() == SocketError::WouldBlockWrite) { + SocketListener listener{{*this, SocketListener::Write}}; + + listener.select(timeout); + + // Socket is writable? Check if there is an error + + int error = get(SOL_SOCKET, SO_ERROR); + + if (error) { + throw SocketError(SocketError::System, "connect", error); + } + } else { + throw; + } + } + + m_state = SocketState::Connected; +} + +SocketTcp SocketTcp::waitAccept(int timeout) +{ + SocketAddress dummy; + + return waitAccept(dummy, timeout); +} + +SocketTcp SocketTcp::waitAccept(SocketAddress &info, int timeout) +{ + SocketListener listener{{*this, SocketListener::Read}}; + + listener.select(timeout); + + return accept(info); +} + +unsigned SocketTcp::recv(void *data, unsigned dataLen) +{ + int nbread; + + nbread = ::recv(m_handle, (Socket::Arg)data, dataLen, 0); + if (nbread == Error) { +#if defined(_WIN32) + int error = WSAGetLastError(); + + if (error == WSAEWOULDBLOCK) + throw SocketError(SocketError::WouldBlockRead, "recv", error) + + throw SocketError(SocketError::System, "recv", error); +#else + if (errno == EAGAIN || errno == EWOULDBLOCK) + throw SocketError(SocketError::WouldBlockRead, "recv"); + + throw SocketError(SocketError::System, "recv"); +#endif + } else if (nbread == 0) + m_state = SocketState::Closed; + + return (unsigned)nbread; +} + +unsigned SocketTcp::waitRecv(void *data, unsigned length, int timeout) +{ + SocketListener listener{{*this, SocketListener::Read}}; + + listener.select(timeout); + + return recv(data, length); +} + +unsigned SocketTcp::send(const void *data, unsigned length) +{ + int nbsent; + + nbsent = ::send(m_handle, (Socket::ConstArg)data, length, 0); + if (nbsent == Error) { +#if defined(_WIN32) + int error = WSAGetLastError(); + + if (error == WSAEWOULDBLOCK) + throw SocketError(SocketError::WouldBlockWrite, "send", error); + + throw SocketError(SocketError::System, "send", error); +#else + if (errno == EAGAIN || errno == EWOULDBLOCK) + throw SocketError(SocketError::WouldBlockWrite, "send"); + + throw SocketError(SocketError::System, "send"); +#endif + } + + return (unsigned)nbsent; +} + +unsigned SocketTcp::waitSend(const void *data, unsigned length, int timeout) +{ + SocketListener listener{{*this, SocketListener::Write}}; + + listener.select(timeout); + + return send(data, length); +} diff -r 9e223d1de96f -r 4f282297625b C++/SocketTcp.h --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/C++/SocketTcp.h Sun Mar 08 11:04:15 2015 +0100 @@ -0,0 +1,261 @@ +/* + * SocketTcp.h -- portable C++ socket wrappers + * + * Copyright (c) 2013, 2014 David Demelier + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#ifndef _SOCKET_TCP_NG_H_ +#define _SOCKET_TCP_NG_H_ + +#include "Socket.h" + +/** + * @class SocketAbstractTcp + * @brief Base class for TCP sockets + * + * This abstract class provides standard TCP functions for both clear + * and SSL implementation. + * + * It does not contain default accept() and connect() because they varies too + * much between standard and SSL. Also, the accept() function return different + * types. + */ +class SocketAbstractTcp : public Socket { +protected: + Socket standardAccept(SocketAddress &address); + void standardConnect(const SocketAddress &address); + +public: + /** + * Construct an abstract socket from an already made socket. + * + * @param s the socket + */ + inline SocketAbstractTcp(Socket s) + : Socket(s) + { + } + + /** + * Construct a standard TCP socket. The type is automatically + * set to SOCK_STREAM. + * + * @param domain the domain + * @param protocol the protocol + * @throw SocketError on error + */ + inline SocketAbstractTcp(int domain, int protocol) + : Socket(domain, SOCK_STREAM, protocol) + { + } + + /** + * Listen for pending connection. + * + * @param max the maximum number + */ + void listen(int max = 128); + + /** + * Overloaded function. + * + * @param count the number of bytes to receive + * @return the string + * @throw SocketError on error + */ + inline std::string recv(unsigned count) + { + std::string result; + + result.resize(count); + auto n = recv(const_cast(result.data()), count); + result.resize(n); + + return result; + } + + /** + * Overloaded function. + * + * @param count the number of bytes to receive + * @param timeout the maximum timeout in milliseconds + * @return the string + * @throw SocketError on error + */ + inline std::string waitRecv(unsigned count, int timeout) + { + std::string result; + + result.resize(count); + auto n = waitRecv(const_cast(result.data()), count, timeout); + result.resize(n); + + return result; + } + + /** + * Overloaded function. + * + * @param data the string to send + * @return the number of bytes sent + * @throw SocketError on error + */ + inline unsigned send(const std::string &data) + { + return send(data.c_str(), data.size()); + } + + /** + * Overloaded function. + * + * @param data the string to send + * @param timeout the maximum timeout in milliseconds + * @return the number of bytes sent + * @throw SocketError on error + */ + inline unsigned waitSend(const std::string &data, int timeout) + { + return waitSend(data.c_str(), data.size(), timeout); + } + + /** + * Receive data. + * + * @param data the destination buffer + * @param length the buffer length + * @return the number of bytes received + * @throw SocketError on error + */ + virtual unsigned recv(void *data, unsigned length) = 0; + + /** + * Receive data. + * + * @param data the destination buffer + * @param length the buffer length + * @param timeout the maximum timeout in milliseconds + * @return the number of bytes received + * @throw SocketError on error + */ + virtual unsigned waitRecv(void *data, unsigned length, int timeout) = 0; + + /** + * Send data. + * + * @param data the buffer + * @param length the buffer length + * @return the number of bytes sent + * @throw SocketError on error + */ + virtual unsigned send(const void *data, unsigned length) = 0; + + /** + * Send data. + * + * @param data the buffer + * @param length the buffer length + * @return the number of bytes sent + * @throw SocketError on error + */ + virtual unsigned waitSend(const void *data, unsigned length, int timeout) = 0; +}; + +/** + * @class SocketTcp + * @brief End-user class for TCP sockets + */ +class SocketTcp : public SocketAbstractTcp { +public: + using SocketAbstractTcp::SocketAbstractTcp; + using SocketAbstractTcp::recv; + using SocketAbstractTcp::waitRecv; + using SocketAbstractTcp::send; + using SocketAbstractTcp::waitSend; + + /** + * Accept a clear TCP socket. + * + * @return the socket + * @throw SocketError on error + */ + SocketTcp accept(); + + /** + * Accept a clear TCP socket. + * + * @param info the client information + * @return the socket + * @throw SocketError on error + */ + SocketTcp accept(SocketAddress &info); + + /** + * Accept a clear TCP socket. + * + * @param timeout the maximum timeout in milliseconds + * @return the socket + * @throw SocketError on error + */ + SocketTcp waitAccept(int timeout); + + /** + * Accept a clear TCP socket. + * + * @param info the client information + * @param timeout the maximum timeout in milliseconds + * @return the socket + * @throw SocketError on error + */ + SocketTcp waitAccept(SocketAddress &info, int timeout); + + /** + * Connect to an end point. + * + * @param address the address + * @throw SocketError on error + */ + void connect(const SocketAddress &address); + + /** + * Connect to an end point. + * + * @param timeout the maximum timeout in milliseconds + * @param address the address + * @throw SocketError on error + */ + void waitConnect(const SocketAddress &address, int timeout); + + /** + * @copydoc SocketAbstractTcp::recv + */ + unsigned recv(void *data, unsigned length) override; + + /** + * @copydoc SocketAbstractTcp::waitRecv + */ + unsigned waitRecv(void *data, unsigned length, int timeout) override; + + /** + * @copydoc SocketAbstractTcp::send + */ + unsigned send(const void *data, unsigned length) override; + + /** + * @copydoc SocketAbstractTcp::waitSend + */ + unsigned waitSend(const void *data, unsigned length, int timeout) override; +}; + +#endif // !_SOCKET_TCP_NG_H_ diff -r 9e223d1de96f -r 4f282297625b C++/SocketUdp.cpp --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/C++/SocketUdp.cpp Sun Mar 08 11:04:15 2015 +0100 @@ -0,0 +1,81 @@ +/* + * SocketUdp.cpp -- portable C++ socket wrappers + * + * Copyright (c) 2013, 2014 David Demelier + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include "SocketAddress.h" +#include "SocketUdp.h" + +SocketUdp::SocketUdp(int domain, int protocol) + : Socket(domain, SOCK_DGRAM, protocol) +{ +} + +unsigned SocketUdp::recvfrom(void *data, unsigned length, SocketAddress &info) +{ + int nbread; + + // Store information + sockaddr_storage address; + socklen_t addrlen; + + addrlen = sizeof (struct sockaddr_storage); + nbread = ::recvfrom(m_handle, (Socket::Arg)data, length, 0, (sockaddr *)&address, &addrlen); + + info = SocketAddress(address, addrlen); + + if (nbread == Error) { +#if defined(_WIN32) + int error = WSAGetLastError(); + + if (error == WSAEWOULDBLOCK) + throw SocketError(SocketError::WouldBlockRead, "recvfrom", error); + + throw SocketError(SocketError::System, "recvfrom", error); +#else + if (errno == EAGAIN || errno == EWOULDBLOCK) + throw SocketError(SocketError::WouldBlockRead, "recvfrom"); + + throw SocketError(SocketError::System, "recvfrom"); +#endif + } + + return (unsigned)nbread; +} + +unsigned SocketUdp::sendto(const void *data, unsigned length, const SocketAddress &info) +{ + int nbsent; + + nbsent = ::sendto(m_handle, (Socket::ConstArg)data, length, 0, (const sockaddr *)&info.address(), info.length()); + if (nbsent == Error) { +#if defined(_WIN32) + int error = WSAGetLastError(); + + if (error == WSAEWOULDBLOCK) + throw SocketError(SocketError::WouldBlockWrite, "sendto", error); + + throw SocketError(SocketError::System, "sendto", error); +#else + if (errno == EAGAIN || errno == EWOULDBLOCK) + throw SocketError(SocketError::WouldBlockWrite, "sendto"); + + throw SocketError(SocketError::System, "sendto"); +#endif + } + + return (unsigned)nbsent; +} diff -r 9e223d1de96f -r 4f282297625b C++/SocketUdp.h --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/C++/SocketUdp.h Sun Mar 08 11:04:15 2015 +0100 @@ -0,0 +1,93 @@ +/* + * SocketUdp.h -- portable C++ socket wrappers + * + * Copyright (c) 2013, 2014 David Demelier + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#ifndef _SOCKET_UDP_NG_H_ +#define _SOCKET_UDP_NG_H_ + +#include "Socket.h" + +/** + * @class SocketUdp + * @brief UDP implementation for sockets + */ +class SocketUdp : public Socket { +public: + /** + * Construct a UDP socket. The type is automatically set to SOCK_DGRAM. + * + * @param domain the domain (e.g AF_INET) + * @param protocol the protocol (usually 0) + */ + SocketUdp(int domain, int protocol); + + /** + * Overloaded function. + * + * @param data the data + * @param address the address + * @return the number of bytes sent + * @throw SocketError on error + */ + inline unsigned sendto(const std::string &data, const SocketAddress &address) + { + return sendto(data.c_str(), data.length(), address); + } + + /** + * Overloaded function. + * + * @param data the data + * @param info the client information + * @return the string + * @throw SocketError on error + */ + inline std::string recvfrom(unsigned count, SocketAddress &info) + { + std::string result; + + result.resize(count); + auto n = recvfrom(const_cast(result.data()), count, info); + result.resize(n); + + return result; + } + + /** + * Receive data from an end point. + * + * @param data the destination buffer + * @param length the buffer length + * @param info the client information + * @return the number of bytes received + * @throw SocketError on error + */ + virtual unsigned recvfrom(void *data, unsigned length, SocketAddress &info); + + /** + * Send data to an end point. + * + * @param data the buffer + * @param length the buffer length + * @param address the client address + * @return the number of bytes sent + * @throw SocketError on error + */ + virtual unsigned sendto(const void *data, unsigned length, const SocketAddress &address); +}; + +#endif // !_SOCKET_UDP_NG_H_ diff -r 9e223d1de96f -r 4f282297625b C++/Tests/Sockets/CMakeLists.txt --- a/C++/Tests/Sockets/CMakeLists.txt Fri Mar 06 22:04:38 2015 +0100 +++ b/C++/Tests/Sockets/CMakeLists.txt Sun Mar 08 11:04:15 2015 +0100 @@ -22,6 +22,10 @@ SOURCES ${code_SOURCE_DIR}/C++/Socket.cpp ${code_SOURCE_DIR}/C++/Socket.h + ${code_SOURCE_DIR}/C++/SocketTcp.cpp + ${code_SOURCE_DIR}/C++/SocketTcp.h + ${code_SOURCE_DIR}/C++/SocketUdp.cpp + ${code_SOURCE_DIR}/C++/SocketUdp.h ${code_SOURCE_DIR}/C++/SocketAddress.cpp ${code_SOURCE_DIR}/C++/SocketAddress.h ${code_SOURCE_DIR}/C++/SocketListener.cpp diff -r 9e223d1de96f -r 4f282297625b C++/Tests/Sockets/main.cpp --- a/C++/Tests/Sockets/main.cpp Fri Mar 06 22:04:38 2015 +0100 +++ b/C++/Tests/Sockets/main.cpp Sun Mar 08 11:04:15 2015 +0100 @@ -24,1143 +24,686 @@ #include -#include -#include -#include +#include "Socket.h" +#include "SocketAddress.h" +#include "SocketListener.h" +#include "SocketSsl.h" +#include "SocketTcp.h" +#include "SocketUdp.h" +using namespace address; using namespace std::literals::chrono_literals; -using namespace error; -using namespace address; - /* -------------------------------------------------------- - * Miscellaneous + * TCP tests * -------------------------------------------------------- */ -TEST(Misc, set) -{ - Socket s; - - try { - s = { AF_INET6, SOCK_STREAM, 0 }; +class TcpServerTest : public testing::Test { +protected: + SocketTcp m_server{AF_INET, 0}; + SocketTcp m_client{AF_INET, 0}; - s.set(IPPROTO_IPV6, IPV6_V6ONLY, 0); - ASSERT_EQ(0, s.get(IPPROTO_IPV6, IPV6_V6ONLY)); + std::thread m_tserver; + std::thread m_tclient; - s.set(IPPROTO_IPV6, IPV6_V6ONLY, 1); - ASSERT_EQ(1, s.get(IPPROTO_IPV6, IPV6_V6ONLY)); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; +public: + TcpServerTest() + { + m_server.set(SOL_SOCKET, SO_REUSEADDR, 1); } - s.close(); + ~TcpServerTest() + { + if (m_tserver.joinable()) + m_tserver.join(); + if (m_tclient.joinable()) + m_tclient.join(); + } +}; + +TEST_F(TcpServerTest, connect) +{ + m_tserver = std::thread([this] () { + m_server.bind(Internet("*", 16000, AF_INET)); + + ASSERT_EQ(SocketState::Bound, m_server.state()); + + m_server.listen(); + m_server.accept(); + m_server.close(); + }); + + std::this_thread::sleep_for(100ms); + + m_tclient = std::thread([this] () { + m_client.connect(Internet("127.0.0.1", 16000, AF_INET)); + + ASSERT_EQ(SocketState::Connected, m_client.state()); + + m_client.close(); + }); } -TEST(Misc, initializer) +TEST_F(TcpServerTest, io) { - Socket s1, s2; + m_tserver = std::thread([this] () { + m_server.bind(Internet("*", 16000, AF_INET)); + m_server.listen(); - try { - s1 = { AF_INET6, SOCK_STREAM, 0 }; - s2 = { AF_INET6, SOCK_STREAM, 0 }; + auto client = m_server.accept(); + auto msg = client.recv(512); - SocketListener listener { - { s1, direction::Read }, - { s2, direction::Read }, - }; + ASSERT_EQ("hello world", msg); + + client.send(msg); + client.close(); - ASSERT_EQ(2UL, listener.size()); + m_server.close(); + }); + + std::this_thread::sleep_for(100ms); - listener.list([&] (const auto &so, auto direction) { - ASSERT_TRUE(so == s1 || so == s2); - ASSERT_EQ(direction::Read, direction); - }); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } + m_tclient = std::thread([this] () { + m_client.connect(Internet("127.0.0.1", 16000, AF_INET)); + m_client.send("hello world"); - s1.close(); - s2.close(); + ASSERT_EQ("hello world", m_client.recv(512)); + + m_client.close(); + }); } /* -------------------------------------------------------- - * Select tests + * UDP tests * -------------------------------------------------------- */ -TEST(ListenerMethodSelect, timeout) -{ - std::thread server([] () { - Socket s, client; - SocketListener listener(Select); - bool running = true; - int tries = 0; - - try { - s = { AF_INET, SOCK_STREAM, 0 }; - s.set(SOL_SOCKET, SO_REUSEADDR, 1); - s.bind(Internet{"*", 10000, AF_INET}); - s.listen(10); - - listener.add(s, direction::Read); - - while (running) { - try { - listener.select(500ms); - client = s.accept(); - running = false; - - // Abort if no client connected - if (tries >= 10) - running = false; - } catch (const Timeout &) { - } - } - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } +class UdpServerTest : public testing::Test { +protected: + SocketUdp m_server{AF_INET, 0}; + SocketUdp m_client{AF_INET, 0}; - s.close(); - client.close(); - }); - - std::thread client([] () { - std::this_thread::sleep_for(2s); - - Socket s; - - try { - s = { AF_INET, SOCK_STREAM, 0 }; - s.connect(Internet{"localhost", 10000, AF_INET}); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } - - s.close(); - }); + std::thread m_tserver; + std::thread m_tclient; - server.join(); - client.join(); -} - -TEST(ListenerMethodSelect, add) -{ - Socket s, s2; - - try { - s = { AF_INET, SOCK_STREAM, 0 }; - s2 = { AF_INET, SOCK_STREAM, 0 }; - SocketListener listener(Select); - - listener.add(s, direction::Read); - listener.add(s2, direction::Read); - - ASSERT_EQ(2UL, listener.size()); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; +public: + UdpServerTest() + { + m_server.set(SOL_SOCKET, SO_REUSEADDR, 1); } - s.close(); - s2.close(); -} - -TEST(ListenerMethodSelect, remove) -{ - Socket s, s2; - - try { - s = { AF_INET, SOCK_STREAM, 0 }; - s2 = { AF_INET, SOCK_STREAM, 0 }; - SocketListener listener(Select); - - listener.add(s, direction::Read); - listener.add(s2, direction::Read); - listener.remove(s, direction::Read); - listener.remove(s2, direction::Read); - - ASSERT_EQ(0UL, listener.size()); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; + ~UdpServerTest() + { + if (m_tserver.joinable()) + m_tserver.join(); + if (m_tclient.joinable()) + m_tclient.join(); } +}; - s.close(); - s2.close(); -} - -/* - * Add two sockets for both direction::Reading and writing, them remove only direction::Reading and then - * move only writing. - */ -TEST(ListenerMethodSelect, inOut) +TEST_F(UdpServerTest, io) { - Socket s, s2; + m_tserver = std::thread([this] () { + SocketAddress info; - try { - s = { AF_INET, SOCK_STREAM, 0 }; - s2 = { AF_INET, SOCK_STREAM, 0 }; - SocketListener listener(Select); - - listener.add(s, direction::Read | direction::Write); - listener.add(s2, direction::Read | direction::Write); + m_server.bind(Internet("*", 16000, AF_INET)); - listener.list([&] (Socket &si, int dir) { - ASSERT_TRUE(si == s || si == s2); - ASSERT_EQ(0x03, static_cast(dir)); - }); - - listener.remove(s, direction::Write); - listener.remove(s2, direction::Write); - - ASSERT_EQ(2UL, listener.size()); + auto msg = m_server.recvfrom(512, info); - listener.list([&] (Socket &si, int dir) { - ASSERT_TRUE(si == s || si == s2); - ASSERT_EQ(direction::Read, dir); - }); - - listener.remove(s, direction::Read); - listener.remove(s2, direction::Read); + ASSERT_EQ("hello world", msg); - ASSERT_EQ(0UL, listener.size()); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } + m_server.sendto(msg, info); + m_server.close(); + }); - s.close(); - s2.close(); -} - -TEST(ListenerMethodSelect, addSame) -{ - Socket s; - - try { - s = { AF_INET, SOCK_STREAM, 0 }; - SocketListener listener(Select); + std::this_thread::sleep_for(100ms); - listener.add(s, direction::Read); - ASSERT_EQ(1UL, listener.size()); - listener.list([&] (const Socket &si, int dir) { - ASSERT_TRUE(si == s); - ASSERT_EQ(direction::Read, dir); - }); + m_tclient = std::thread([this] () { + Internet info("127.0.0.1", 16000, AF_INET); - listener.add(s, direction::Write); - ASSERT_EQ(1UL, listener.size()); - listener.list([&] (const Socket &si, int dir) { - ASSERT_TRUE(si == s); - ASSERT_EQ(0x03, static_cast(dir)); - }); + m_client.sendto("hello world", info); - // Oops, added the same - listener.add(s, direction::Read); - listener.add(s, direction::Write); - listener.add(s, direction::Read | direction::Write); + ASSERT_EQ("hello world", m_client.recvfrom(512, info)); - ASSERT_EQ(1UL, listener.size()); - listener.list([&] (const Socket &si, int dir) { - ASSERT_TRUE(si == s); - ASSERT_EQ(0x03, static_cast(dir)); - }); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } - - s.close(); + m_client.close(); + }); } /* -------------------------------------------------------- - * Poll tests + * Listener tests (standard) * -------------------------------------------------------- */ -TEST(ListenerMethodPoll, timeout) -{ - std::thread server([] () { - Socket s, client; - SocketListener listener(Poll); - bool running = true; - int tries = 0; - - try { - s = { AF_INET, SOCK_STREAM, 0 }; - s.set(SOL_SOCKET, SO_REUSEADDR, 1); - s.bind(Internet{"*", 10000, AF_INET}); - s.listen(10); - - listener.add(s, direction::Read); +class ListenerTest : public testing::Test { +protected: + SocketListener m_listener; + SocketTcp socket1{AF_INET, 0}; + SocketUdp socket2{AF_INET, 0}; - while (running) { - try { - listener.select(500ms); - client = s.accept(); - running = false; +public: + ~ListenerTest() + { + socket1.close(); + socket2.close(); + } +}; - // Abort if no client connected - if (tries >= 10) - running = false; - } catch (const Timeout &) { - } - } - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } +TEST_F(ListenerTest, set) +{ + m_listener.set(socket1, SocketListener::Read); - s.close(); - client.close(); - }); + ASSERT_EQ(1, static_cast(m_listener.size())); + ASSERT_EQ(SocketListener::Read, m_listener.begin()->second); - std::thread client([] () { - std::this_thread::sleep_for(2s); - - Socket s; + m_listener.set(socket1, SocketListener::Write); - try { - s = { AF_INET, SOCK_STREAM, 0 }; - s.connect(Internet{"localhost", 10000, AF_INET}); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } + ASSERT_EQ(1, static_cast(m_listener.size())); + ASSERT_EQ(0x3, m_listener.begin()->second); - s.close(); - }); - - server.join(); - client.join(); -} + // Fake a re-insert of the same socket + m_listener.set(socket1, SocketListener::Write); -TEST(ListenerMethodPoll, add) -{ - Socket s, s2; + ASSERT_EQ(1, static_cast(m_listener.size())); + ASSERT_EQ(0x3, m_listener.begin()->second); - try { - s = { AF_INET, SOCK_STREAM, 0 }; - s2 = { AF_INET, SOCK_STREAM, 0 }; - SocketListener listener(Poll); + // Add an other socket now + m_listener.set(socket2, SocketListener::Read | SocketListener::Write); - listener.add(s, direction::Read); - listener.add(s2, direction::Read); + ASSERT_EQ(2, static_cast(m_listener.size())); - ASSERT_EQ(2UL, listener.size()); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; + for (auto &pair : m_listener) { + ASSERT_EQ(0x3, pair.second); + ASSERT_TRUE(pair.first == socket1 || pair.first == socket2); } - - s.close(); - s2.close(); } -TEST(ListenerMethodPoll, remove) +TEST_F(ListenerTest, unset) { - Socket s, s2; + m_listener.set(socket1, SocketListener::Read | SocketListener::Write); + m_listener.set(socket2, SocketListener::Read | SocketListener::Write); - try { - s = { AF_INET, SOCK_STREAM, 0 }; - s2 = { AF_INET, SOCK_STREAM, 0 }; - SocketListener listener(Poll); + m_listener.unset(socket1, SocketListener::Read); - listener.add(s, direction::Read); - listener.add(s2, direction::Read); - listener.remove(s, direction::Read); - listener.remove(s2, direction::Read); + ASSERT_EQ(2, static_cast(m_listener.size())); - ASSERT_EQ(0UL, listener.size()); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; + // Use a for loop since it can be ordered differently + for (auto &pair : m_listener) { + if (pair.first == socket1) { + ASSERT_EQ(0x2, pair.second); + } else if (pair.first == socket2) { + ASSERT_EQ(0x3, pair.second); + } } - s.close(); - s2.close(); + m_listener.unset(socket1, SocketListener::Write); + + ASSERT_EQ(1, static_cast(m_listener.size())); + ASSERT_EQ(0x3, m_listener.begin()->second); } -TEST(ListenerMethodPoll, inOut) +TEST_F(ListenerTest, remove) { - Socket s, s2; - - try { - s = { AF_INET, SOCK_STREAM, 0 }; - s2 = { AF_INET, SOCK_STREAM, 0 }; - SocketListener listener(Poll); - - listener.add(s, direction::Read | direction::Write); - listener.add(s2, direction::Read | direction::Write); - - listener.list([&] (Socket &si, int dir) { - ASSERT_TRUE(si == s || si == s2); - ASSERT_EQ(0x03, static_cast(dir)); - }); + m_listener.set(socket1, SocketListener::Read | SocketListener::Write); + m_listener.set(socket2, SocketListener::Read | SocketListener::Write); + m_listener.remove(socket1); - listener.remove(s, direction::Write); - listener.remove(s2, direction::Write); - - ASSERT_EQ(2UL, listener.size()); - - listener.list([&] (Socket &si, int dir) { - ASSERT_TRUE(si == s || si == s2); - ASSERT_EQ(direction::Read, dir); - }); - - listener.remove(s, direction::Read); - listener.remove(s2, direction::Read); - - ASSERT_EQ(0UL, listener.size()); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } - - s.close(); - s2.close(); + ASSERT_EQ(1, static_cast(m_listener.size())); + ASSERT_EQ(0x3, m_listener.begin()->second); } -TEST(ListenerMethodPoll, addSame) +TEST_F(ListenerTest, clear) { - Socket s; - - try { - s = { AF_INET, SOCK_STREAM, 0 }; - SocketListener listener(Poll); - - listener.add(s, direction::Read); - ASSERT_EQ(1UL, listener.size()); - listener.list([&] (const Socket &si, int dir) { - ASSERT_TRUE(si == s); - ASSERT_EQ(direction::Read, dir); - }); + m_listener.set(socket1, SocketListener::Read | SocketListener::Write); + m_listener.set(socket2, SocketListener::Read | SocketListener::Write); + m_listener.clear(); - listener.add(s, direction::Write); - ASSERT_EQ(1UL, listener.size()); - listener.list([&] (const Socket &si, int dir) { - ASSERT_TRUE(si == s); - ASSERT_EQ(0x03, static_cast(dir)); - }); - - // Oops, added the same - listener.add(s, direction::Read); - listener.add(s, direction::Write); - listener.add(s, direction::Read | direction::Write); - - ASSERT_EQ(1UL, listener.size()); - listener.list([&] (const Socket &si, int dir) { - ASSERT_TRUE(si == s); - ASSERT_EQ(0x03, static_cast(dir)); - }); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } - - s.close(); + ASSERT_EQ(0, static_cast(m_listener.size())); } /* -------------------------------------------------------- - * Socket listener class + * Listener: poll * -------------------------------------------------------- */ -TEST(Listener, connection) -{ - std::thread server([] () { - Socket s; - SocketListener listener; - - try { - s = { AF_INET, SOCK_STREAM, 0 }; - - s.set(SOL_SOCKET, SO_REUSEADDR, 1); - s.bind(Internet{"*", 10000, AF_INET}); - s.listen(8); - - listener.add(s, direction::Read); +#if defined(SOCKET_HAVE_POLL) - auto client = listener.select(10s); - - ASSERT_TRUE(client.direction == direction::Read); - ASSERT_TRUE(client.socket == s); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } - - s.close(); - }); - - std::thread client([] () { - Socket client; - - std::this_thread::sleep_for(500ms); +class ListenerPollTest : public testing::Test { +protected: + SocketListener m_listener{SocketMethod::Poll}; + SocketTcp m_masterTcp{AF_INET, 0}; + SocketTcp m_clientTcp{AF_INET, 0}; - try { - client = { AF_INET, SOCK_STREAM, 0 }; - client.connect(Internet{"localhost", 10000, AF_INET}); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } - - client.close(); - }); - - server.join(); - client.join(); -} - -TEST(Listener, connectionAndRead) -{ - std::thread server([] () { - Socket s; - SocketListener listener; + std::thread m_tserver; + std::thread m_tclient; - try { - s = { AF_INET, SOCK_STREAM, 0 }; - - s.set(SOL_SOCKET, SO_REUSEADDR, 1); - s.bind(Internet("*", 10000, AF_INET)); - s.listen(8); - - // direction::Read for master - listener.add(s, direction::Read); - - auto result = listener.select(10s); - - ASSERT_TRUE(result.direction == direction::Read); - ASSERT_TRUE(result.socket == s); - - // Wait for client - auto client = s.accept(); - listener.add(client, direction::Read); - - result = listener.select(10s); - - ASSERT_TRUE(result.direction == direction::Read); - ASSERT_TRUE(result.socket == client); +public: + ListenerPollTest() + { + m_masterTcp.set(SOL_SOCKET, SO_REUSEADDR, 1); + m_masterTcp.bind(Internet("*", 16000, AF_INET)); + m_masterTcp.listen(); + } - char data[512]; - auto nb = client.recv(data, sizeof (data) - 1); - - data[nb] = '\0'; - - client.close(); - ASSERT_STREQ("hello world", data); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; + ~ListenerPollTest() + { + if (m_tserver.joinable()) { + m_tserver.join(); } - - s.close(); - }); - - std::thread client([] () { - Socket client; + if (m_tclient.joinable()) { + m_tclient.join(); + } + } +}; - std::this_thread::sleep_for(500ms); - - try { - client = Socket(AF_INET, SOCK_STREAM, 0); - client.connect(Internet("localhost", 10000, AF_INET)); - client.send("hello world"); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } - - client.close(); - }); - - server.join(); - client.join(); -} - -TEST(Listener, bigData) +TEST_F(ListenerPollTest, accept) { - std::thread server([] () { - std::ostringstream out; - - Socket server; - SocketListener listener; - bool finished(false); - + m_tserver = std::thread([this] () { try { - server = { AF_INET, SOCK_STREAM, 0 }; - - server.set(SOL_SOCKET, SO_REUSEADDR, 1); - server.bind(Internet{"*", 10000, AF_INET}); - server.listen(10); - listener.add(server, direction::Read); - - while (!finished) { - auto s = listener.select(60s).socket; - - if (s == server) { - listener.add(s.accept(), direction::Read); - } else { - char data[512]; - auto nb = s.recv(data, sizeof (data) - 1); - - if (nb == 0) - finished = true; - else { - data[nb] = '\0'; - out << data; - } - } - } - - ASSERT_EQ(9000002UL, out.str().size()); + m_listener.set(m_masterTcp, SocketListener::Read); + m_listener.select(); + m_masterTcp.accept(); + m_masterTcp.close(); } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; + FAIL() << ex.what(); } - - server.close(); }); - std::thread client([] () { - std::string data; - - data.reserve(9000000); - for (int i = 0; i < 9000000; ++i) - data.push_back('a'); - - data.push_back('\r'); - data.push_back('\n'); - - std::this_thread::sleep_for(500ms); - - Socket client; - SocketListener listener; - - try { - client = { AF_INET, SOCK_STREAM, 0 }; + std::this_thread::sleep_for(100ms); - client.connect(Internet{"localhost", 10000, AF_INET}); - client.blockMode(false); - listener.add(client, direction::Write); - - while (data.size() > 0) { - auto s = listener.select(30s).socket; - auto nb = s.send(data.data(), data.size()); - data.erase(0, nb); - } - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } - - client.close(); + m_tclient = std::thread([this] () { + m_clientTcp.connect(Internet("127.0.0.1", 16000, AF_INET)); }); - - server.join(); - client.join(); } -/* -------------------------------------------------------- - * Multiple selection tests - * -------------------------------------------------------- */ - -TEST(MultipleSelection, select) +TEST_F(ListenerPollTest, recv) { - /* - * Normally, 3 sockets added for writing should be marked direction::Ready immediately - * as there are no data being currently queued to be sent. - */ - std::thread server([] () { - Socket master; - + m_tserver = std::thread([this] () { try { - SocketListener masterListener, clientListener; - master = { AF_INET, SOCK_STREAM, 0 }; + m_listener.set(m_masterTcp, SocketListener::Read); + m_listener.select(); - master.set(SOL_SOCKET, SO_REUSEADDR, 1); - master.bind(Internet{"*", 10000, AF_INET}); - master.listen(8); + auto sc = m_masterTcp.accept(); - masterListener.add(master, direction::Read); - - while (clientListener.size() != 3) { - masterListener.select(3s); - clientListener.add(master.accept(), direction::Write); - } + ASSERT_EQ("hello", sc.recv(512)); - // Now do the test of writing - auto result = clientListener.selectMultiple(3s); - ASSERT_EQ(3UL, result.size()); - - clientListener.list([] (auto s, auto) { - s.close(); - }); + m_masterTcp.close(); } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; + FAIL() << ex.what(); } - - master.close(); }); - std::thread client([] () { - Socket s1, s2, s3; - - try { - s1 = { AF_INET, SOCK_STREAM, 0 }; - s2 = { AF_INET, SOCK_STREAM, 0 }; - s3 = { AF_INET, SOCK_STREAM, 0 }; - - std::this_thread::sleep_for(1s); - - s1.connect(Internet{"localhost", 10000, AF_INET}); - s2.connect(Internet{"localhost", 10000, AF_INET}); - s3.connect(Internet{"localhost", 10000, AF_INET}); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } - - s1.close(); - s2.close(); - s3.close(); - }); - - server.join(); - client.join(); -} - -#if defined(SOCKET_LISTENER_HAVE_POLL) - -TEST(MultipleSelection, poll) -{ - /* - * Normally, 3 sockets added for writing should be marked direction::Ready immediately - * as there are no data being currently queued to be sent. - */ - std::thread server([] () { - Socket master; - - try { - SocketListener masterListener(Poll), clientListener(Poll); - master = { AF_INET, SOCK_STREAM, 0 }; - - master.set(SOL_SOCKET, SO_REUSEADDR, 1); - master.bind(Internet{"*", 10000, AF_INET}); - master.listen(8); + std::this_thread::sleep_for(100ms); - masterListener.add(master, direction::Read); - - while (clientListener.size() != 3) { - masterListener.select(3s); - clientListener.add(master.accept(), direction::Write); - } - - // Now do the test of writing - auto result = clientListener.selectMultiple(3s); - ASSERT_EQ(3UL, result.size()); - - clientListener.list([] (auto s, auto) { - s.close(); - }); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } - - master.close(); + m_tclient = std::thread([this] () { + m_clientTcp.connect(Internet("127.0.0.1", 16000, AF_INET)); + m_clientTcp.send("hello"); }); - - std::thread client([] () { - Socket s1, s2, s3; - - try { - s1 = { AF_INET, SOCK_STREAM, 0 }; - s2 = { AF_INET, SOCK_STREAM, 0 }; - s3 = { AF_INET, SOCK_STREAM, 0 }; - - std::this_thread::sleep_for(1s); - - s1.connect(Internet("localhost", 10000, AF_INET)); - s2.connect(Internet("localhost", 10000, AF_INET)); - s3.connect(Internet("localhost", 10000, AF_INET)); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } - - s1.close(); - s2.close(); - s3.close(); - }); - - server.join(); - client.join(); } #endif /* -------------------------------------------------------- - * Basic TCP tests + * Listener: select * -------------------------------------------------------- */ -TEST(BasicTcp, sendipv4) { - std::thread server([] () { - Socket server, client; +class ListenerSelectTest : public testing::Test { +protected: + SocketListener m_listener{SocketMethod::Select}; + SocketTcp m_masterTcp{AF_INET, 0}; + SocketTcp m_clientTcp{AF_INET, 0}; - try { - server = { AF_INET, SOCK_STREAM, 0 }; + std::thread m_tserver; + std::thread m_tclient; - server.set(SOL_SOCKET, SO_REUSEADDR, 1); - server.bind(Internet{"*", 10000, AF_INET}); - server.listen(8); - - client = server.accept(); +public: + ListenerSelectTest() + { + m_masterTcp.set(SOL_SOCKET, SO_REUSEADDR, 1); + m_masterTcp.bind(Internet("*", 16000, AF_INET)); + m_masterTcp.listen(); + } - char data[512]; - auto nb = client.recv(data, sizeof (data) - 1); - - data[nb] = '\0'; + ~ListenerSelectTest() + { + if (m_tserver.joinable()) { + m_tserver.join(); + } + if (m_tclient.joinable()) { + m_tclient.join(); + } + } +}; - ASSERT_STREQ("hello", data); +TEST_F(ListenerSelectTest, accept) +{ + m_tserver = std::thread([this] () { + try { + m_listener.set(m_masterTcp, SocketListener::Read); + m_listener.select(); + m_masterTcp.accept(); + m_masterTcp.close(); } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; + FAIL() << ex.what(); } - - server.close(); - client.close(); }); - std::thread client([] () { - Socket s; - - std::this_thread::sleep_for(500ms); - try { - s = { AF_INET, SOCK_STREAM, 0 }; + std::this_thread::sleep_for(100ms); - s.connect(Internet{"localhost", 10000, AF_INET}); - s.send("hello"); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } - - s.close(); + m_tclient = std::thread([this] () { + m_clientTcp.connect(Internet("127.0.0.1", 16000, AF_INET)); }); - - server.join(); - client.join(); } -TEST(BasicTcp, sendipv6) { - std::thread server([] () { - Socket server, client; - +TEST_F(ListenerSelectTest, recv) +{ + m_tserver = std::thread([this] () { try { - server = { AF_INET6, SOCK_STREAM, 0 }; - - server.set(SOL_SOCKET, SO_REUSEADDR, 1); - server.bind(Internet("*", 10000, AF_INET6)); - server.listen(8); - - client = server.accept(); + m_listener.set(m_masterTcp, SocketListener::Read); + m_listener.select(); - char data[512]; - auto nb = client.recv(data, sizeof (data) - 1); + auto sc = m_masterTcp.accept(); - data[nb] = '\0'; + ASSERT_EQ("hello", sc.recv(512)); - ASSERT_STREQ("hello", data); + m_masterTcp.close(); } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; + FAIL() << ex.what(); } - - server.close(); - client.close(); }); - std::thread client([] () { - Socket s; - - std::this_thread::sleep_for(500ms); - try { - s = { AF_INET6, SOCK_STREAM, 0 }; + std::this_thread::sleep_for(100ms); - s.connect(Internet{"ip6-localhost", 10000, AF_INET6}); - s.send("hello"); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } - - s.close(); + m_tclient = std::thread([this] () { + m_clientTcp.connect(Internet("127.0.0.1", 16000, AF_INET)); + m_clientTcp.send("hello"); }); - - server.join(); - client.join(); } -#if !defined(_WIN32) +/* -------------------------------------------------------- + * Non-blocking connect + * -------------------------------------------------------- */ -TEST(BasicTcp, sendunix) { - std::thread server([] () { - Socket server, client; +class NonBlockingConnectTest : public testing::Test { +protected: + SocketTcp m_server{AF_INET, 0}; + SocketTcp m_client{AF_INET, 0}; - try { - server = { AF_UNIX, SOCK_STREAM, 0 }; + std::thread m_tserver; + std::thread m_tclient; - server.set(SOL_SOCKET, SO_REUSEADDR, 1); - server.bind(Unix("/tmp/gtest-send-tcp-unix.sock", true)); - server.listen(8); +public: + NonBlockingConnectTest() + { + m_client.setBlockMode(false); + } - client = server.accept(); - - char data[512]; - auto nb = client.recv(data, sizeof (data) - 1); - - data[nb] = '\0'; + ~NonBlockingConnectTest() + { + if (m_tserver.joinable()) + m_tserver.join(); + if (m_tclient.joinable()) + m_tclient.join(); + } +}; - ASSERT_STREQ("hello", data); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } +TEST_F(NonBlockingConnectTest, success) +{ + m_server.set(SOL_SOCKET, SO_REUSEADDR, 1); + m_server.bind(Internet("*", 16000, AF_INET)); + m_server.listen(); - server.close(); + m_tserver = std::thread([this] () { + SocketTcp client = m_server.accept(); + + m_server.close(); client.close(); }); - std::thread client([] () { - Socket s; + std::this_thread::sleep_for(100ms); - std::this_thread::sleep_for(500ms); + m_tclient = std::thread([this] () { try { - s = { AF_UNIX, SOCK_STREAM, 0 }; - - s.connect(Unix{"/tmp/gtest-send-tcp-unix.sock"}); - s.send("hello"); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; + m_client.waitConnect(Internet("127.0.0.1", 16000, AF_INET), 3000); + } catch (const SocketError &error) { + FAIL() << error.what(); } - s.close(); - }); + ASSERT_EQ(SocketState::Connected, m_client.state()); - server.join(); - client.join(); + m_client.close(); + }); } -#endif +TEST_F(NonBlockingConnectTest, fail) +{ + /* + * /!\ If you find a way to test this locally please tell me /!\ + */ + m_tclient = std::thread([this] () { + try { + m_client.waitConnect(Internet("google.fr", 9000, AF_INET), 100); + + FAIL() << "Expected exception, got success"; + } catch (const SocketError &error) { + ASSERT_EQ(SocketError::Timeout, error.code()); + } + + m_client.close(); + }); +} /* -------------------------------------------------------- - * Basic UDP tests + * TCP accept * -------------------------------------------------------- */ -TEST(BasicUdp, sendipv4) { - std::thread server([] () { - Socket server; +class TcpAcceptTest : public testing::Test { +protected: + SocketTcp m_server{AF_INET, 0}; + SocketTcp m_client{AF_INET, 0}; + + std::thread m_tserver; + std::thread m_tclient; - try { - server = { AF_INET, SOCK_DGRAM, 0 }; +public: + TcpAcceptTest() + { + m_server.set(SOL_SOCKET, SO_REUSEADDR, 1); + m_server.bind(Internet("*", 16000, AF_INET)); + m_server.listen(); + } - server.set(SOL_SOCKET, SO_REUSEADDR, 1); - server.bind(Internet("*", 10000, AF_INET)); - - char data[512]; - auto nb = server.recvfrom(data, sizeof (data) - 1); + ~TcpAcceptTest() + { + if (m_tserver.joinable()) + m_tserver.join(); + if (m_tclient.joinable()) + m_tclient.join(); + } +}; - data[nb] = '\0'; - - ASSERT_STREQ("hello", data); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; +TEST_F(TcpAcceptTest, blockingWaitSuccess) +{ + m_tserver = std::thread([this] () { + try { + m_server.waitAccept(3000).close(); + } catch (const SocketError &error) { + FAIL() << error.what(); } - server.close(); + m_server.close(); }); - std::thread client([] () { - Socket s; - - std::this_thread::sleep_for(500ms); - try { - s = { AF_INET, SOCK_DGRAM, 0 }; + std::this_thread::sleep_for(100ms); - s.sendto("hello", Internet{"localhost", 10000, AF_INET}); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } - - s.close(); + m_tclient = std::thread([this] () { + m_client.connect(Internet("127.0.0.1", 16000, AF_INET)); + m_client.close(); }); - - server.join(); - client.join(); } -TEST(BasicUdp, sendipv6) { - std::thread server([] () { - Socket server; - +TEST_F(TcpAcceptTest, nonBlockingWaitSuccess) +{ + m_tserver = std::thread([this] () { try { - server = { AF_INET6, SOCK_DGRAM, 0 }; - - server.set(SOL_SOCKET, SO_REUSEADDR, 1); - server.set(IPPROTO_IPV6, IPV6_V6ONLY, 1); - server.bind(Internet{"*", 10000, AF_INET6}); - - char data[512]; - auto nb = server.recvfrom(data, sizeof (data) - 1); - - data[nb] = '\0'; - - ASSERT_STREQ("hello", data); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; + m_server.setBlockMode(false); + m_server.waitAccept(3000).close(); + } catch (const SocketError &error) { + FAIL() << error.what(); } - server.close(); + m_server.close(); }); - std::thread client([] () { - Socket s; + std::this_thread::sleep_for(100ms); - std::this_thread::sleep_for(500ms); - try { - s = { AF_INET6, SOCK_DGRAM, 0 }; + m_tclient = std::thread([this] () { + m_client.connect(Internet("127.0.0.1", 16000, AF_INET)); + m_client.close(); + }); +} - s.sendto("hello", Internet{"ip6-localhost", 10000, AF_INET6}); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } +TEST_F(TcpAcceptTest, nonBlockingWaitFail) +{ + // No client, no accept + try { + m_server.setBlockMode(false); + m_server.waitAccept(100).close(); - s.close(); - }); + FAIL() << "Expected exception, got success"; + } catch (const SocketError &error) { + ASSERT_EQ(SocketError::Timeout, error.code()); + } - server.join(); - client.join(); + m_server.close(); } -#if !defined(_WIN32) - -TEST(BasicUdp, sendunix) { - std::thread server([] () { - Socket server; - - try { - server = { AF_UNIX, SOCK_DGRAM, 0 }; - - server.set(SOL_SOCKET, SO_REUSEADDR, 1); - server.bind(Unix{"/tmp/gtest-send-udp-unix.sock", true}); - - char data[512]; - auto nb = server.recvfrom(data, sizeof (data) - 1); - - data[nb] = '\0'; - - ASSERT_STREQ("hello", data); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } - - server.close(); - }); - - std::thread client([] () { - Socket s; - - std::this_thread::sleep_for(500ms); - try { - s = { AF_UNIX, SOCK_DGRAM, 0 }; - - s.sendto("hello", Unix{"/tmp/gtest-send-udp-unix.sock"}); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - } - - s.close(); - }); - - server.join(); - client.join(); -} - -#endif - /* -------------------------------------------------------- - * Non-blocking functions failures + * TCP recv * -------------------------------------------------------- */ -TEST(NonBlockingFailures, connect) -{ - Socket s; +class TcpRecvTest : public testing::Test { +protected: + SocketTcp m_server{AF_INET, 0}; + SocketTcp m_client{AF_INET, 0}; - try { - s = {AF_INET, SOCK_STREAM, 0}; - s.blockMode(false); - } catch (const std::exception &ex) { - std::cerr << ex.what() << std::endl; - return; + std::thread m_tserver; + std::thread m_tclient; + +public: + TcpRecvTest() + { + m_server.set(SOL_SOCKET, SO_REUSEADDR, 1); + m_server.bind(Internet("*", 16000, AF_INET)); + m_server.listen(); } - auto time1 = std::chrono::system_clock::now(); + ~TcpRecvTest() + { + if (m_tserver.joinable()) + m_tserver.join(); + if (m_tclient.joinable()) + m_tclient.join(); + } +}; - try { - s.tryConnect(Internet{"google.fr", 9000, AF_INET}, 1000); +TEST_F(TcpRecvTest, blockingSuccess) +{ + m_tserver = std::thread([this] () { + SocketTcp client = m_server.accept(); - FAIL() << "unexpected code path"; - } catch (const std::exception &ex) {} + ASSERT_EQ("hello", client.recv(32)); + + client.close(); + m_server.close(); + }); + + std::this_thread::sleep_for(100ms); - auto time2 = std::chrono::system_clock::now(); - auto duration = std::chrono::duration_cast(time2 - time1).count(); + m_tclient = std::thread([this] () { + m_client.connect(Internet("127.0.0.1", 16000, AF_INET)); + m_client.send("hello"); + m_client.close(); + }); +} + +TEST_F(TcpRecvTest, blockingWaitSuccess) +{ + m_tserver = std::thread([this] () { + SocketTcp client = m_server.accept(); - // Assert between 0,9 and 1,1 seconds - ASSERT_GT(duration, 900); - ASSERT_LT(duration, 1100); + ASSERT_EQ("hello", client.waitRecv(32, 3000)); + + client.close(); + m_server.close(); + }); - s.close(); + std::this_thread::sleep_for(100ms); + + m_tclient = std::thread([this] () { + m_client.connect(Internet("127.0.0.1", 16000, AF_INET)); + m_client.send("hello"); + m_client.close(); + }); } -TEST(NonBlockingFailures, recv) +TEST_F(TcpRecvTest, nonBlockingWaitSuccess) { - Socket s; - Socket server; - bool ready{false}; + m_tserver = std::thread([this] () { + SocketTcp client = m_server.accept(); + + client.setBlockMode(false); - try { - s = {AF_INET, SOCK_STREAM, 0}; + ASSERT_EQ("hello", client.waitRecv(32, 3000)); + + client.close(); + m_server.close(); + }); + + std::this_thread::sleep_for(100ms); - server = {AF_INET, SOCK_STREAM, 0}; - server.set(SOL_SOCKET, SO_REUSEADDR, 1); - server.bind(Internet{"localhost", 10000, AF_INET}); - server.listen(10); + m_tclient = std::thread([this] () { + m_client.connect(Internet("127.0.0.1", 16000, AF_INET)); + m_client.send("hello"); + m_client.close(); + }); +} - ready = true; - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; - return; - } +/* -------------------------------------------------------- + * Socket SSL + * -------------------------------------------------------- */ - if (!ready) { - s.close(); - server.close(); - } +class SslTest : public testing::Test { +protected: + SocketSsl client{AF_INET, 0}; +}; +TEST_F(SslTest, connect) +{ try { - s.connect(Internet{"localhost", 10000, AF_INET}); - s.blockMode(false); - } catch (const std::exception &ex) { - std::cerr << "warning: " << ex.what() << std::endl; + client.connect(Internet("google.fr", 443, AF_INET)); + client.close(); + } catch (const SocketError &error) { + FAIL() << error.what(); + } +} - s.close(); - server.close(); - - return; - } - - auto time1 = std::chrono::system_clock::now(); - +TEST_F(SslTest, recv) +{ try { - s.tryRecv(100, 1000); + client.connect(Internet("google.fr", 443, AF_INET)); + client.send("GET / HTTP/1.0\r\n\r\n"); - FAIL() << "unexpected code path"; - } catch (const std::exception &) {} + std::string msg = client.recv(512); + std::string content = msg.substr(0, 18); - auto time2 = std::chrono::system_clock::now(); - auto duration = std::chrono::duration_cast(time2 - time1).count(); + ASSERT_EQ("HTTP/1.0 302 Found", content); - // Assert between 0,9 and 1,1 seconds - ASSERT_GT(duration, 900); - ASSERT_LT(duration, 1100); - - s.close(); + client.close(); + } catch (const SocketError &error) { + FAIL() << error.what(); + } } int main(int argc, char **argv) { - Socket::init(); - testing::InitGoogleTest(&argc, argv); - auto ret = RUN_ALL_TESTS(); - - Socket::finish(); - - return ret; + return RUN_ALL_TESTS(); }