# HG changeset patch # User David Demelier # Date 1416053970 -3600 # Node ID bae4af872cdefe56062227f0356e281dd3a6673e # Parent c019f194475a59307c4223503e17492828a1286a# Parent 4e17193db14188b51da92a5f46111db8579760dc MFS diff -r 4e17193db141 -r bae4af872cde C++/Socket.cpp --- a/C++/Socket.cpp Sat Nov 15 13:19:21 2014 +0100 +++ b/C++/Socket.cpp Sat Nov 15 13:19:30 2014 +0100 @@ -20,6 +20,9 @@ #include "Socket.h" #include "SocketAddress.h" +#include "SocketListener.h" + +using namespace direction; /* -------------------------------------------------------- * Socket exceptions @@ -27,6 +30,72 @@ namespace error { +/* -------------------------------------------------------- + * Error + * -------------------------------------------------------- */ + +Error::Error(std::string function, std::string error, int code) + : m_function(std::move(function)) + , m_error(std::move(error)) + , m_code(std::move(code)) +{ + m_shortcut = m_function + ": " + m_error; +} + +const std::string &Error::function() const noexcept +{ + return m_function; +} + +const std::string &Error::error() const noexcept +{ + return m_error; +} + +int Error::code() const noexcept +{ + return m_code; +} + +const char *Error::what() const noexcept +{ + return m_shortcut.c_str(); +} + +/* -------------------------------------------------------- + * InProgress + * -------------------------------------------------------- */ + +InProgress::InProgress(std::string func, std::string reason, int code, int direction) + : Error(std::move(func), std::move(reason), code) + , m_direction(direction) +{ +} + +int InProgress::direction() const noexcept +{ + return m_direction; +} + +/* -------------------------------------------------------- + * WouldBlock + * -------------------------------------------------------- */ + +WouldBlock::WouldBlock(std::string func, std::string reason, int code, int direction) + : Error(std::move(func), std::move(reason), code) + , m_direction(direction) +{ +} + +int WouldBlock::direction() const noexcept +{ + return m_direction; +} + +/* -------------------------------------------------------- + * Timeout + * -------------------------------------------------------- */ + Timeout::Timeout(std::string func) : m_error(func + ": Timeout exception") { @@ -37,36 +106,6 @@ return m_error.c_str(); } -InProgress::InProgress(std::string func) - : m_error(func + ": Operation in progress") -{ -} - -const char *InProgress::what() const noexcept -{ - return m_error.c_str(); -} - -WouldBlock::WouldBlock(std::string func) - : m_error(func + ": Operation would block") -{ -} - -const char *WouldBlock::what() const noexcept -{ - return m_error.c_str(); -} - -Failure::Failure(std::string func, std::string message) - : m_error(func + ": " + message) -{ -} - -const char *Failure::what() const noexcept -{ - return m_error.c_str(); -} - } // !error /* -------------------------------------------------------- @@ -136,33 +175,23 @@ } /* -------------------------------------------------------- - * Standard clear implementation + * SocketStandard clear implementation * -------------------------------------------------------- */ -class Standard final : public SocketInterface { -public: - void bind(Socket &s, const SocketAddress &address) override; - void close(Socket &s) override; - void connect(Socket &s, const SocketAddress &address) override; - Socket accept(Socket &s, SocketAddress &info) override; - void listen(Socket &s, int max) override; - unsigned recv(Socket &s, void *data, unsigned len) override; - unsigned recvfrom(Socket &s, void *data, unsigned len, SocketAddress &info) override ; - unsigned send(Socket &s, const void *data, unsigned len) override; - unsigned sendto(Socket &s, const void *data, unsigned len, const SocketAddress &info) override; -}; - -void Standard::bind(Socket &s, const SocketAddress &addr) +void SocketStandard::bind(Socket &s, const SocketAddress &addr) { auto &sa = addr.address(); auto addrlen = addr.length(); if (::bind(s.handle(), (sockaddr *)&sa, addrlen) == SOCKET_ERROR) - throw error::Failure("bind", Socket::syserror()); + throw error::Error("bind", Socket::syserror(), errno); } -void Standard::connect(Socket &s, const SocketAddress &addr) +void SocketStandard::connect(Socket &s, const SocketAddress &addr) { + if (m_connected) + return; + auto &sa = addr.address(); auto addrlen = addr.length(); @@ -173,19 +202,44 @@ */ #if defined(_WIN32) if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::InProgress("connect"); + throw error::InProgress("connect", Socket::syserror(WSAEWOULDBLOCK), WSAEWOULDBLOCK, Write); - throw error::Failure("connect", Socket::syserror()); + throw error::Error("connect", Socket::syserror(WSAEWOULDBLOCK), WSAGetLastError()); #else if (errno == EINPROGRESS) - throw error::InProgress("connect"); + throw error::InProgress("connect", Socket::syserror(EINPROGRESS), EINPROGRESS, Write); - throw error::Failure("connect", Socket::syserror()); + throw error::Error("connect", Socket::syserror(), errno); #endif } + + m_connected = true; } -Socket Standard::accept(Socket &s, SocketAddress &info) +void SocketStandard::tryConnect(Socket &s, const SocketAddress &address, int timeout) +{ + if (m_connected) + return; + + // Initial try + try { + connect(s, address); + } catch (const error::InProgress &) { + SocketListener listener{{s, Write}}; + + listener.select(timeout); + + // Socket is writable? Check if there is an error + auto error = s.get(SOL_SOCKET, SO_ERROR); + + if (error) + throw error::Error("connect", Socket::syserror(error), error); + } + + m_connected = true; +} + +Socket SocketStandard::accept(Socket &s, SocketAddress &info) { Socket::Handle handle; @@ -199,30 +253,39 @@ if (handle == INVALID_SOCKET) { #if defined(_WIN32) if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::WouldBlock("accept"); + throw error::WouldBlock("accept", Socket::syserror(WSAEWOULDBLOCK), WSAEWOULDBLOCK, Read); - throw error::Failure("accept", Socket::syserror()); + throw error::Error("accept", Socket::syserror(), WSAGetLastError()); #else if (errno == EAGAIN || errno == EWOULDBLOCK) - throw error::WouldBlock("accept"); + throw error::WouldBlock("accept", Socket::syserror(EWOULDBLOCK), EWOULDBLOCK, Read); - throw error::Failure("accept", Socket::syserror()); + throw error::Error("accept", Socket::syserror(), errno); #endif } // Usually accept works only with SOCK_STREAM info = SocketAddress(address, addrlen); - return Socket(handle, std::make_shared()); + return Socket(handle, std::make_shared()); } -void Standard::listen(Socket &s, int max) +Socket SocketStandard::tryAccept(Socket &s, SocketAddress &info, int timeout) +{ + SocketListener listener{{s, Read}}; + + listener.select(timeout); + + return accept(s, info); +} + +void SocketStandard::listen(Socket &s, int max) { if (::listen(s.handle(), max) == SOCKET_ERROR) - throw error::Failure("listen", Socket::syserror()); + throw error::Error("listen", Socket::syserror(), errno); } -unsigned Standard::recv(Socket &s, void *data, unsigned dataLen) +unsigned SocketStandard::recv(Socket &s, void *data, unsigned dataLen) { int nbread; @@ -230,21 +293,30 @@ if (nbread == SOCKET_ERROR) { #if defined(_WIN32) if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::WouldBlock("recv"); + throw error::WouldBlock("recv", Socket::syserror(), WSAEWOULDBLOCK, Read); - throw error::Failure("recv", Socket::syserror()); + throw error::Error("recv", Socket::syserror(), WSAGetLastError()); #else if (errno == EAGAIN || errno == EWOULDBLOCK) - throw error::WouldBlock("recv"); + throw error::WouldBlock("recv", Socket::syserror(), errno, Read); - throw error::Failure("recv", Socket::syserror()); + throw error::Error("recv", Socket::syserror(), errno); #endif } return (unsigned)nbread; } -unsigned Standard::send(Socket &s, const void *data, unsigned dataLen) +unsigned SocketStandard::tryRecv(Socket &s, void *data, unsigned len, int timeout) +{ + SocketListener listener{{s, Read}}; + + listener.select(timeout); + + return recv(s, data, len); +} + +unsigned SocketStandard::send(Socket &s, const void *data, unsigned dataLen) { int nbsent; @@ -252,21 +324,30 @@ if (nbsent == SOCKET_ERROR) { #if defined(_WIN32) if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::WouldBlock("send"); + throw error::WouldBlock("send", Socket::syserror(), WSAEWOULDBLOCK, Write); - throw error::Failure("send", Socket::syserror()); + throw error::Error("send", Socket::syserror(), WSAGetLastError()); #else if (errno == EAGAIN || errno == EWOULDBLOCK) - throw error::WouldBlock("send"); + throw error::WouldBlock("send", Socket::syserror(), errno, Write); - throw error::Failure("send", Socket::syserror()); + throw error::Error("send", Socket::syserror(), errno); #endif } return (unsigned)nbsent; } -unsigned Standard::recvfrom(Socket &s, void *data, unsigned dataLen, SocketAddress &info) +unsigned SocketStandard::trySend(Socket &s, const void *data, unsigned len, int timeout) +{ + SocketListener listener{{s, Write}}; + + listener.select(timeout); + + return send(s, data, len); +} + +unsigned SocketStandard::recvfrom(Socket &s, void *data, unsigned dataLen, SocketAddress &info) { int nbread; @@ -282,21 +363,30 @@ if (nbread == SOCKET_ERROR) { #if defined(_WIN32) if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::WouldBlock("recvfrom"); + throw error::WouldBlock("recvfrom", Socket::syserror(), WSAEWOULDBLOCK, Read); - throw error::Failure("recvfrom", Socket::syserror()); + throw error::Error("recvfrom", Socket::syserror(), WSAGetLastError()); #else if (errno == EAGAIN || errno == EWOULDBLOCK) - throw error::WouldBlock("recvfrom"); + throw error::WouldBlock("recvfrom", Socket::syserror(), errno, Read); - throw error::Failure("recvfrom", Socket::syserror()); + throw error::Error("recvfrom", Socket::syserror(), errno); #endif } return (unsigned)nbread; } -unsigned Standard::sendto(Socket &s, const void *data, unsigned dataLen, const SocketAddress &info) +unsigned SocketStandard::tryRecvfrom(Socket &s, void *data, unsigned len, SocketAddress &info, int timeout) +{ + SocketListener listener{{s, Read}}; + + listener.select(timeout); + + return recvfrom(s, data, len, info); +} + +unsigned SocketStandard::sendto(Socket &s, const void *data, unsigned dataLen, const SocketAddress &info) { int nbsent; @@ -304,23 +394,34 @@ if (nbsent == SOCKET_ERROR) { #if defined(_WIN32) if (WSAGetLastError() == WSAEWOULDBLOCK) - throw error::WouldBlock("sendto"); + throw error::WouldBlock("sendto", Socket::syserror(), WSAEWOULDBLOCK, Write); - throw error::Failure("sendto", Socket::syserror()); + throw error::Error("sendto", Socket::syserror(), errno); #else if (errno == EAGAIN || errno == EWOULDBLOCK) - throw error::WouldBlock("sendto"); + throw error::WouldBlock("sendto", Socket::syserror(), errno, Write); - throw error::Failure("sendto", Socket::syserror()); + throw error::Error("sendto", Socket::syserror(), errno); #endif } return (unsigned)nbsent; } -void Standard::close(Socket &s) +unsigned SocketStandard::trySendto(Socket &s, const void *data, unsigned len, const SocketAddress &info, int timeout) +{ + SocketListener listener{{s, Write}}; + + listener.select(timeout); + + return sendto(s, data, len, info); +} + +void SocketStandard::close(Socket &s) { (void)closesocket(s.handle()); + + m_connected = false; } /* -------------------------------------------------------- @@ -328,7 +429,7 @@ * -------------------------------------------------------- */ Socket::Socket() - : m_interface(std::make_shared()) + : m_interface(std::make_shared()) { } @@ -338,7 +439,7 @@ m_handle = socket(domain, type, protocol); if (m_handle == INVALID_SOCKET) - throw error::Failure("socket", syserror()); + throw error::Error("socket", syserror(), errno); } Socket::Socket(Handle handle, std::shared_ptr iface) @@ -366,39 +467,15 @@ flags |= O_NONBLOCK; if (fcntl(m_handle, F_SETFL, flags) == -1) - throw error::Failure("blockMode", Socket::syserror()); + throw error::Error("blockMode", Socket::syserror(), errno); #else unsigned long flags = (block) ? 0 : 1; if (ioctlsocket(m_handle, FIONBIO, &flags) == SOCKET_ERROR) - throw error::Failure("blockMode", Socket::syserror()); + throw error::Error("blockMode", Socket::syserror(), WSAGetLastError()); #endif } -Socket Socket::accept() -{ - SocketAddress dummy; - - return m_interface->accept(*this, dummy); -} - -unsigned Socket::send(const std::string &message) -{ - return Socket::send(message.c_str(), message.length()); -} - -unsigned Socket::recvfrom(void *data, unsigned dataLen) -{ - SocketAddress dummy; - - return m_interface->recvfrom(*this, data, dataLen, dummy); -} - -unsigned Socket::sendto(const std::string &message, const SocketAddress &info) -{ - return sendto(message.c_str(), message.length(), info); -} - bool operator==(const Socket &s1, const Socket &s2) { return s1.handle() == s2.handle(); @@ -408,3 +485,35 @@ { return s1.handle() < s2.handle(); } + +/* + * + */ + +Socket Socket::tryAccept(int timeout) +{ + SocketAddress dummy; + + return tryAccept(dummy, timeout); +} + +Socket Socket::accept() +{ + SocketAddress dummy; + + return accept(dummy); +} + +unsigned Socket::recvfrom(void *data, unsigned dataLen) +{ + SocketAddress dummy; + + return m_interface->recvfrom(*this, data, dataLen, dummy); +} + +std::string Socket::recvfrom(unsigned count) +{ + SocketAddress dummy; + + return recvfrom(count, dummy); +} diff -r 4e17193db141 -r bae4af872cde C++/Socket.h --- a/C++/Socket.h Sat Nov 15 13:19:21 2014 +0100 +++ b/C++/Socket.h Sat Nov 15 13:19:30 2014 +0100 @@ -16,8 +16,8 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#ifndef _SOCKET_H_ -#define _SOCKET_H_ +#ifndef _SOCKET_NG_H_ +#define _SOCKET_NG_H_ #include #include @@ -28,6 +28,8 @@ # include # include #else +# include + # include # include # include @@ -52,9 +54,138 @@ class Socket; class SocketAddress; +/** + * Namespace for listener flags. + */ +namespace direction { + +constexpr const int Read = (1 << 0); //!< Wants to read +constexpr const int Write = (1 << 1); //!< Wants to write + +} // !direction + +/** + * Various errors. + */ namespace error { /** + * Base error class, contains + */ +class Error : public std::exception { +private: + std::string m_function; + std::string m_error; + int m_code; + std::string m_shortcut; + +public: + /** + * Construct a full error. + * + * @param function which function + * @param error the error + * @param code the native code + */ + Error(std::string function, std::string error, int code); + + /** + * Get the function which thrown an exception. + * + * @return the function name + */ + const std::string &function() const noexcept; + + /** + * Get the error string. + * + * @return the error + */ + const std::string &error() const noexcept; + + /** + * Get the native code. Use with care because it varies from the system, + * the type of socket and such. + * + * @return the code + */ + int code() const noexcept; + + /** + * Get a brief error message + * + * @return a message + */ + const char *what() const noexcept override; +}; + +/** + * @class InProgress + * @brief Operation cannot be accomplished now + * + * Usually thrown in a non-blocking connect call. + */ +class InProgress final : public Error { +private: + int m_direction; + +public: + /** + * Operation is in progress, contains the direction needed to complete + * the operation. + * + * The direction may be different from the requested operation because + * some sockets (e.g SSL) requires reading even when writing! + * + * @param func the function + * @param reason the reason + * @param code the native code + * @param direction the direction + */ + InProgress(std::string func, std::string reason, int code, int direction); + + /** + * Get the required direction for listening operation requires. + * + * @return the direction required + */ + int direction() const noexcept; +}; + +/** + * @class WouldBlock + * @brief The operation would block + * + * Usually thrown in a non-blocking connect send or receive. + */ +class WouldBlock final : public Error { +private: + int m_direction; + +public: + /** + * Operation would block, contains the direction needed to complete the + * operation. + * + * The direction may be different from the requested operation because + * some sockets (e.g SSL) requires reading even when writing! + * + * @param func the function + * @param reason the reason + * @param code the native code + * @param direction the direction + */ + WouldBlock(std::string func, std::string reason, int code, int direction); + + /** + * Get the required direction for listening operation requires. + * + * @return the direction required + */ + int direction() const noexcept; +}; + +/** * @class Timeout * @brief Describe a timeout expiration * @@ -65,52 +196,16 @@ std::string m_error; public: + /** + * Timeout exception. + * + * @param func the function name + */ Timeout(std::string func); - const char *what() const noexcept override; -}; - -/** - * @class InProgress - * @brief Operation cannot be accomplished now - * - * Usually thrown in a non-blocking connect call. - */ -class InProgress final : public std::exception { -private: - std::string m_error; - -public: - InProgress(std::string func); - const char *what() const noexcept override; -}; -/** - * @class WouldBlock - * @brief The operation would block - * - * Usually thrown in a non-blocking connect send or receive. - */ -class WouldBlock final : public std::exception { -private: - std::string m_error; - -public: - WouldBlock(std::string func); - const char *what() const noexcept override; -}; - -/** - * @class Failure - * @brief General socket failure - * - * An operation failed. - */ -class Failure final : public std::exception { -private: - std::string m_error; - -public: - Failure(std::string func, std::string message); + /** + * The error message. + */ const char *what() const noexcept override; }; @@ -125,11 +220,16 @@ class SocketInterface { public: /** + * Default destructor. + */ + virtual ~SocketInterface() = default; + + /** * Bind the socket. * * @param s the socket * @param address the address - * @throw SocketError error + * @throw error::Failure on error */ virtual void bind(Socket &s, const SocketAddress &address) = 0; @@ -144,23 +244,42 @@ * Try to connect to the specific address * * @param s the socket - * @param addr the address + * @param address the address * @throw error::Failure on error * @throw error::InProgress if the socket is marked non-blocking and connection cannot be established yet */ virtual void connect(Socket &s, const SocketAddress &address) = 0; /** + * Try to connect without blocking. + * + * @param s the socket + * @param address the address + * @param timeout the timeout in milliseconds + * @warning This function won't block only if the socket is marked non blocking + */ + virtual void tryConnect(Socket &s, const SocketAddress &address, int timeout) = 0; + + /** * Accept a client. * * @param s the socket - * @param info the optional client info + * @param info the client information * @return a client ready to use * @throw error::Failure on error */ virtual Socket accept(Socket &s, SocketAddress &info) = 0; /** + * Try to accept a client with a timeout. + * + * @param s the socket + * @param info the client information + * @param timeout the timeout in milliseconds + */ + virtual Socket tryAccept(Socket &s, SocketAddress &info, int timeout) = 0; + + /** * Listen to a specific number of pending connections. * * @param s the socket @@ -174,7 +293,7 @@ * * @param s the socket * @param data the destination pointer - * @param dataLen max length to receive + * @param len max length to receive * @return the number of bytes received * @throw error::Failure on error * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block @@ -182,13 +301,24 @@ virtual unsigned recv(Socket &s, void *data, unsigned len) = 0; /** - * Receive from a connection-less socket and get the client - * information. + * Try to receive data without blocking. + * + * @param s the socket + * @param data the destination pointer + * @param len max length to receive + * @param timeout the timeout in milliseconds + * @return the number of bytes received + * @throw error::Failure on error + */ + virtual unsigned tryRecv(Socket &s, void *data, unsigned len, int timeout) = 0; + + /** + * Receive from a connection-less socket. * * @param s the socket * @param data the destination pointer * @param dataLen max length to receive - * @param info the client info + * @param info the client information * @return the number of bytes received * @throw error::Failure on error * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block @@ -196,6 +326,19 @@ virtual unsigned recvfrom(Socket &s, void *data, unsigned len, SocketAddress &info) = 0; /** + * Try to receive data from a connection-less socket without blocking. + * + * @param s the socket + * @param data the data + * @param len max length to receive + * @param info the client information + * @param timeout the optional timeout in milliseconds + * @return the number of bytes received + * @throw error::Failure on error + */ + virtual unsigned tryRecvfrom(Socket &s, void *data, unsigned len, SocketAddress &info, int timeout) = 0; + + /** * Send some data. * * @param s the socket @@ -208,6 +351,20 @@ virtual unsigned send(Socket &s, const void *data, unsigned len) = 0; /** + * Try to send some data without blocking. + * + * @param s the socket + * @param data the data to send + * @param dataLen the data length + * @param timeout the optional timeout in milliseconds + * @return the number of bytes sent + * @throw error::Failure on error + * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block + * @warning This function won't block only if the socket is marked non blocking + */ + virtual unsigned trySend(Socket &s, const void *data, unsigned len, int timeout) = 0; + + /** * Send some data to a connection-less socket. * * @param s the socket @@ -219,6 +376,57 @@ * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block */ virtual unsigned sendto(Socket &s, const void *data, unsigned len, const SocketAddress &info) = 0; + + /** + * Try to send some data to a connection-less without blocking. + * + * @param s the socket + * @param data the data to send + * @param dataLen the data length + * @param address the address + * @return the number of bytes sent + * @throw error::Failure on error + * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block + * @warning This function won't block only if the socket is marked non blocking + */ + virtual unsigned trySendto(Socket &s, const void *data, unsigned len, const SocketAddress &info, int timeout) = 0; +}; + +/** + * Standard interface for sockets. + * + * This will use standard clear functions: + * + * bind(2) + * connect(2) + * close(2), closesocket on Windows, + * accept(2), + * listen(2), + * recv(2), + * recvfrom(2), + * send(2), + * sendto(2) + */ +class SocketStandard : public SocketInterface { +protected: + bool m_connected{false}; + +public: + void bind(Socket &s, const SocketAddress &address) override; + void close(Socket &s) override; + void connect(Socket &s, const SocketAddress &address) override; + void tryConnect(Socket &s, const SocketAddress &address, int timeout) override; + Socket accept(Socket &s, SocketAddress &info) override; + Socket tryAccept(Socket &s, SocketAddress &info, int timeout) override; + void listen(Socket &s, int max) override; + unsigned recv(Socket &s, void *data, unsigned len) override; + unsigned tryRecv(Socket &s, void *data, unsigned len, int timeout) override; + unsigned recvfrom(Socket &s, void *data, unsigned len, SocketAddress &info) override; + unsigned tryRecvfrom(Socket &s, void *data, unsigned len, SocketAddress &info, int timeout) override; + unsigned send(Socket &s, const void *data, unsigned len) override; + unsigned trySend(Socket &s, const void *data, unsigned len, int timeout) override; + unsigned sendto(Socket &s, const void *data, unsigned len, const SocketAddress &info) override; + unsigned trySendto(Socket &s, const void *data, unsigned len, const SocketAddress &info, int timeout) override; }; /** @@ -247,7 +455,7 @@ protected: Iface m_interface; //!< the interface - Handle m_handle { INVALID_SOCKET }; //!< the socket shared pointer + Handle m_handle{INVALID_SOCKET}; //!< the socket shared pointer public: /** @@ -323,7 +531,7 @@ void set(int level, int name, const Argument &arg) { if (setsockopt(m_handle, level, name, (Socket::ConstArg)&arg, sizeof (arg)) == SOCKET_ERROR) - throw error::Failure("set", syserror()); + throw error::Error("set", syserror(), errno); } /** @@ -340,7 +548,7 @@ socklen_t size = sizeof (result); if (getsockopt(m_handle, level, name, (Socket::Arg)&desired, &size) == SOCKET_ERROR) - throw error::Failure("get", syserror()); + throw error::Error("get", syserror(), errno); std::memcpy(&result, &desired, size); @@ -380,6 +588,32 @@ } /** + * @copydoc SocketInterface::tryConnect + */ + inline void tryConnect(const SocketAddress &address, int timeout) + { + m_interface->tryConnect(*this, address, timeout); + } + + /** + * @copydoc SocketInterface::tryAccept + */ + inline Socket tryAccept(SocketAddress &address, int timeout) + { + return m_interface->tryAccept(*this, address, timeout); + } + + /** + * Overload without client information. + * + * @param timeout the timeout + * @return the socket + * @throw error::Error on errors + * @throw error::Timeout if operation has timeout + */ + Socket tryAccept(int timeout); + + /** * Accept a client without getting its info. * * @return a client ready to use @@ -412,16 +646,50 @@ } /** - * Overload for char array. + * Overload for strings. * - * @param data the destination buffer - * @throw error::Failure on error - * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block + * @param count number of bytes to recv + * @return the string + * @throw error::Error on failures + * @throw error::WouldBlock if operation would block + */ + inline std::string recv(unsigned count) + { + std::string result; + + result.resize(count); + auto n = recv(const_cast(result.data()), count); + result.resize(n); + + return result; + } + + /** + * @copydoc SocketInterface::tryRecv */ - template - inline unsigned recv(char (&data)[Size]) + inline unsigned tryRecv(void *data, unsigned dataLen, int timeout = 0) { - return recv(data, sizeof (data)); + return m_interface->tryRecv(*this, data, dataLen, timeout); + } + + /** + * Overload for string. + * + * @param count + * @param timeout + * @return the string + * @throw error::Error on failures + * @throw error::Timeout if operation has timeout + */ + inline std::string tryRecv(unsigned count, int timeout = 0) + { + std::string result; + + result.resize(count); + auto n = tryRecv(const_cast(result.data()), count, timeout); + result.resize(n); + + return result; } /** @@ -444,29 +712,32 @@ } /** - * Overload for char array. + * Overload for string. * - * @param data the destination buffer - * @throw error::Failure on error - * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block + * @param count the number of bytes to receive + * @return the string + * @throw error::Error on failures + * @throw error::WouldBlock if operation would block */ - template - inline unsigned recvfrom(char (&data)[Size]) - { - return recvfrom(data, sizeof (data)); - } + std::string recvfrom(unsigned count); /** - * Overload for char array. + * Overload with client information. * - * @param data the destination buffer - * @throw error::Failure on error - * @throw error::WouldBlock if the socket is marked non-blocking and the operation would block + * @param count the number of bytes to receive + * @return the string + * @throw error::Error on failures + * @throw error::WouldBlock if operation would block */ - template - inline unsigned recvfrom(char (&data)[Size], SocketAddress &info) + inline std::string recvfrom(unsigned count, SocketAddress &info) { - return recvfrom(data, sizeof (data), info); + std::string result; + + result.resize(count); + auto n = recvfrom(const_cast(result.data()), count, info); + result.resize(n); + + return result; } /** @@ -484,7 +755,10 @@ * @return the number of bytes sent * @throw SocketError on error */ - unsigned send(const std::string &message); + inline unsigned send(const std::string &message) + { + return send(message.data(), message.size()); + } /** * @copydoc SocketInterface::sendto @@ -502,11 +776,14 @@ * @return the number of bytes sent * @throw SocketError on error */ - unsigned sendto(const std::string &message, const SocketAddress &info); + inline unsigned sendto(const std::string &message, const SocketAddress &info) + { + return sendto(message.data(), message.size(), info); + } }; bool operator==(const Socket &s1, const Socket &s2); bool operator<(const Socket &s, const Socket &s2); -#endif // !_SOCKET_H_ +#endif // !_SOCKET_NG_H_ diff -r 4e17193db141 -r bae4af872cde C++/SocketAddress.cpp --- a/C++/SocketAddress.cpp Sat Nov 15 13:19:21 2014 +0100 +++ b/C++/SocketAddress.cpp Sat Nov 15 13:19:30 2014 +0100 @@ -56,7 +56,7 @@ auto error = getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &res); if (error != 0) - throw error::Failure("getaddrinfo", gai_strerror(error)); + throw error::Error("getaddrinfo", gai_strerror(error), error); std::memcpy(&m_addr, res->ai_addr, res->ai_addrlen); m_addrlen = res->ai_addrlen; diff -r 4e17193db141 -r bae4af872cde C++/SocketAddress.h --- a/C++/SocketAddress.h Sat Nov 15 13:19:21 2014 +0100 +++ b/C++/SocketAddress.h Sat Nov 15 13:19:30 2014 +0100 @@ -104,7 +104,7 @@ * Create a connect address for internet protocol, * using getaddrinfo(3). */ -class Internet final : public SocketAddress { +class Internet : public SocketAddress { public: /** * Create an IPv4 or IPV6 end point. @@ -125,7 +125,7 @@ * * Create an address to a specific path. Only available on Unix. */ -class Unix final : public SocketAddress { +class Unix : public SocketAddress { public: /** * Construct an address to a path. diff -r 4e17193db141 -r bae4af872cde C++/SocketListener.cpp --- a/C++/SocketListener.cpp Sat Nov 15 13:19:21 2014 +0100 +++ b/C++/SocketListener.cpp Sat Nov 15 13:19:30 2014 +0100 @@ -24,6 +24,8 @@ #include "SocketListener.h" +using namespace direction; + /* -------------------------------------------------------- * Select implementation * -------------------------------------------------------- */ @@ -88,7 +90,7 @@ auto result = selectMultiple(ms); if (result.size() == 0) - throw error::Failure("select", "No socket found"); + throw error::Error("select", "No socket found", 0); return result[0]; } @@ -118,11 +120,15 @@ maxwait.tv_usec = ms * 1000; // Set to nullptr for infinite timeout. - towait = (ms <= 0) ? nullptr : &maxwait; + towait = (ms < 0) ? nullptr : &maxwait; auto error = ::select(max + 1, &readset, &writeset, nullptr, towait); if (error == SOCKET_ERROR) - throw error::Failure("select", Socket::syserror()); +#if defined(_WIN32) + throw error::Error("select", Socket::syserror(), WSAGetLastError()); +#else + throw error::Error("select", Socket::syserror(), errno); +#endif if (error == 0) throw error::Timeout("select"); @@ -252,13 +258,17 @@ if (result == 0) throw error::Timeout("select"); if (result < 0) - throw error::Failure("select", Socket::syserror()); +#if defined(_WIN32) + throw error::Error("poll", Socket::syserror(), WSAGetLastError()); +#else + throw error::Error("poll", Socket::syserror(), errno); +#endif for (auto &fd : m_fds) if (fd.revents != 0) return { m_lookup[fd.fd], todirection(fd.revents) }; - throw error::Failure("select", "No socket found"); + throw error::Error("select", "No socket found", 0); } std::vector PollMethod::selectMultiple(int ms) @@ -267,7 +277,11 @@ if (result == 0) throw error::Timeout("select"); if (result < 0) - throw error::Failure("select", Socket::syserror()); +#if defined(_WIN32) + throw error::Error("poll", Socket::syserror(), WSAGetLastError()); +#else + throw error::Error("poll", Socket::syserror(), errno); +#endif std::vector sockets; for (auto &fd : m_fds) diff -r 4e17193db141 -r bae4af872cde C++/SocketListener.h --- a/C++/SocketListener.h Sat Nov 15 13:19:21 2014 +0100 +++ b/C++/SocketListener.h Sat Nov 15 13:19:30 2014 +0100 @@ -36,17 +36,6 @@ #endif /** - * @enum SocketDirection - * @brief The SocketDirection enum - * - * Bitmask that can be set to both reading and writing. - */ -enum SocketDirection { - Read = (1 << 0), //!< only for receive - Write = (1 << 1) //!< only for sending -}; - -/** * @enum SocketMethod * @brief The SocketMethod enum * @@ -244,15 +233,16 @@ } /** - * Overload that waits indefinitely. + * Overload with milliseconds. * + * @param timeout the optional timeout in milliseconds * @return the socket ready * @throw SocketError on error * @throw SocketTimeout on timeout */ - inline SocketStatus select() + inline SocketStatus select(int timeout = -1) { - return m_interface->select(-1); + return m_interface->select(timeout); } /** @@ -272,15 +262,15 @@ } /** - * Overload that waits indefinitely. + * Overload with milliseconds. * * @return the socket ready * @throw SocketError on error * @throw SocketTimeout on timeout */ - inline std::vector selectMultiple() + inline std::vector selectMultiple(int timeout = -1) { - return m_interface->selectMultiple(-1); + return m_interface->selectMultiple(timeout); } /** diff -r 4e17193db141 -r bae4af872cde C++/SocketSsl.cpp --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/C++/SocketSsl.cpp Sat Nov 15 13:19:30 2014 +0100 @@ -0,0 +1,220 @@ +/* + * SocketSsl.cpp -- OpenSSL extension for sockets + * + * Copyright (c) 2013, David Demelier + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include "SocketListener.h" +#include "SocketSsl.h" + +using namespace direction; + +namespace { + +const SSL_METHOD *method(int mflags) +{ + if (mflags & SocketSslOptions::All) + return SSLv23_method(); + if (mflags & SocketSslOptions::SSLv3) + return SSLv3_method(); + if (mflags & SocketSslOptions::TLSv1) + return TLSv1_method(); + + return SSLv23_method(); +} + +inline std::string sslError(int error) +{ + return ERR_reason_error_string(error); +} + +} // !namespace + +SocketSslInterface::SocketSslInterface(SSL_CTX *context, SSL *ssl, SocketSslOptions options) + : SocketStandard() + , m_context(context, SSL_CTX_free) + , m_ssl(ssl, SSL_free) + , m_options(std::move(options)) +{ +} + +SocketSslInterface::SocketSslInterface(SocketSslOptions options) + : SocketStandard() + , m_options(std::move(options)) +{ +} + +void SocketSslInterface::connect(Socket &s, const SocketAddress &address) +{ + SocketStandard::connect(s, address); + + // Context first + auto context = SSL_CTX_new(method(m_options.method)); + + m_context = SslContext(context, SSL_CTX_free); + + // SSL object then + auto ssl = SSL_new(context); + + m_ssl = Ssl(ssl, SSL_free); + + SSL_set_fd(ssl, s.handle()); + + auto ret = SSL_connect(ssl); + + if (ret <= 0) { + auto error = SSL_get_error(ssl, ret); + + if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) + throw error::InProgress("connect", sslError(error), error, error); + + throw error::Error("accept", sslError(error), error); + } +} + +void SocketSslInterface::tryConnect(Socket &s, const SocketAddress &address, int timeout) +{ + try { + // Initial try + connect(s, address); + } catch (const error::InProgress &ipe) { + SocketListener listener{{s, ipe.direction()}}; + + listener.select(timeout); + + // Second try + connect(s, address); + } +} + +Socket SocketSslInterface::accept(Socket &s, SocketAddress &info) +{ + auto client = SocketStandard::accept(s, info); + auto context = SSL_CTX_new(method(m_options.method)); + + if (m_options.certificate.size() > 0) + SSL_CTX_use_certificate_file(context, m_options.certificate.c_str(), SSL_FILETYPE_PEM); + if (m_options.privateKey.size() > 0) + SSL_CTX_use_PrivateKey_file(context, m_options.privateKey.c_str(), SSL_FILETYPE_PEM); + if (m_options.verify && !SSL_CTX_check_private_key(context)) { + client.close(); + throw error::Error("accept", "certificate failure", 0); + } + + // SSL object + auto ssl = SSL_new(context); + + SSL_set_fd(ssl, client.handle()); + + auto ret = SSL_accept(ssl); + + if (ret <= 0) { + auto error = SSL_get_error(ssl, ret); + + if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) + throw error::InProgress("accept", sslError(error), error, error); + + throw error::Error("accept", sslError(error), error); + } + + return SocketSsl{client.handle(), std::make_shared(context, ssl)}; +} + +unsigned SocketSslInterface::recv(Socket &, void *data, unsigned len) +{ + auto nbread = SSL_read(m_ssl.get(), data, len); + + if (nbread <= 0) { + auto error = SSL_get_error(m_ssl.get(), nbread); + + if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) + throw error::InProgress("accept", sslError(error), error, error); + + throw error::Error("recv", sslError(error), error); + } + + return nbread; +} + +unsigned SocketSslInterface::recvfrom(Socket &, void *, unsigned, SocketAddress &) +{ + throw error::Error("recvfrom", "SSL socket is not UDP compatible", 0); +} + +unsigned SocketSslInterface::tryRecv(Socket &s, void *data, unsigned len, int timeout) +{ + SocketListener listener{{s, Read}}; + + listener.select(timeout); + + return recv(s, data, len); +} + +unsigned SocketSslInterface::tryRecvfrom(Socket &, void *, unsigned, SocketAddress &, int) +{ + throw error::Error("recvfrom", "SSL socket is not UDP compatible", 0); +} + +unsigned SocketSslInterface::send(Socket &, const void *data, unsigned len) +{ + auto nbread = SSL_write(m_ssl.get(), data, len); + + if (nbread <= 0) { + auto error = SSL_get_error(m_ssl.get(), nbread); + + if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) + throw error::InProgress("accept", sslError(error), error, error); + + throw error::Error("recv", sslError(error), error); + } + + return nbread; +} + +unsigned SocketSslInterface::sendto(Socket &, const void *, unsigned, const SocketAddress &) +{ + throw error::Error("sendto", "SSL socket is not UDP compatible", 0); +} + +unsigned SocketSslInterface::trySend(Socket &s, const void *data, unsigned len, int timeout) +{ + SocketListener listener{{s, Write}}; + + listener.select(timeout); + + return send(s, data, len); +} + +unsigned SocketSslInterface::trySendto(Socket &, const void *, unsigned, const SocketAddress &, int) +{ + throw error::Error("sendto", "SSL socket is not UDP compatible", 0); +} + +void SocketSsl::init() +{ + SSL_library_init(); + SSL_load_error_strings(); +} + +void SocketSsl::finish() +{ + ERR_free_strings(); +} + +SocketSsl::SocketSsl(int family, SocketSslOptions options) + : Socket(family, SOCK_STREAM, 0) +{ + m_interface = std::make_shared(std::move(options)); +} diff -r 4e17193db141 -r bae4af872cde C++/SocketSsl.h --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/C++/SocketSsl.h Sat Nov 15 13:19:30 2014 +0100 @@ -0,0 +1,109 @@ +/* + * SocketSsl.h -- OpenSSL extension for sockets + * + * Copyright (c) 2013, David Demelier + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#ifndef _SOCKET_SSL_H_ +#define _SOCKET_SSL_H_ + +#include +#include +#include + +#include "Socket.h" + +struct SocketSslOptions { + enum { + SSLv3 = (1 << 0), + TLSv1 = (1 << 1), + All = (0xf) + }; + + unsigned short method{All}; + std::string certificate; + std::string privateKey; + bool verify{false}; + + SocketSslOptions() = default; + + SocketSslOptions(unsigned short method, std::string certificate, std::string key, bool verify = false) + : method(method) + , certificate(std::move(certificate)) + , privateKey(std::move(key)) + , verify(verify) + { + } +}; + +class SocketSslInterface : public SocketStandard { +private: + using Ssl = std::shared_ptr; + using SslContext = std::shared_ptr; + + SslContext m_context; + Ssl m_ssl; + SocketSslOptions m_options; + +public: + SocketSslInterface(SSL_CTX *context, SSL *ssl, SocketSslOptions options = {}); + SocketSslInterface(SocketSslOptions options = {}); + void connect(Socket &s, const SocketAddress &address) override; + void tryConnect(Socket &s, const SocketAddress &address, int timeout) override; + Socket accept(Socket &s, SocketAddress &info) override; + unsigned recv(Socket &s, void *data, unsigned len) override; + unsigned recvfrom(Socket &s, void *data, unsigned len, SocketAddress &info) override; + unsigned tryRecv(Socket &s, void *data, unsigned len, int timeout) override; + unsigned tryRecvfrom(Socket &s, void *data, unsigned len, SocketAddress &info, int timeout) override; + unsigned send(Socket &s, const void *data, unsigned len) override; + unsigned sendto(Socket &s, const void *data, unsigned len, const SocketAddress &info) override; + unsigned trySend(Socket &s, const void *data, unsigned len, int timeout) override; + unsigned trySendto(Socket &s, const void *data, unsigned len, const SocketAddress &info, int timeout) override; +}; + +/** + * @class SocketSsl + * @brief SSL interface for sockets + * + * This class derives from Socket and provide SSL support through OpenSSL. + * + * It is perfectly safe to cast SocketSsl to Socket and vice-versa. + */ +class SocketSsl : public Socket { +private: + using Socket::Socket; + +public: + /** + * Initialize SSL library. + */ + static void init(); + + /** + * Close SSL library. + */ + static void finish(); + + /** + * Open a SSL socket with the specified family. Automatically + * use SOCK_STREAM as the type. + * + * @param family the family + * @param options the options + */ + SocketSsl(int family, SocketSslOptions options = {}); +}; + +#endif // !_SOCKET_SSL_H_ diff -r 4e17193db141 -r bae4af872cde C++/Tests/Sockets/CMakeLists.txt --- a/C++/Tests/Sockets/CMakeLists.txt Sat Nov 15 13:19:21 2014 +0100 +++ b/C++/Tests/Sockets/CMakeLists.txt Sat Nov 15 13:19:30 2014 +0100 @@ -16,6 +16,8 @@ # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # +find_package(OpenSSL REQUIRED) + set( SOURCES ${code_SOURCE_DIR}/C++/Socket.cpp @@ -24,6 +26,8 @@ ${code_SOURCE_DIR}/C++/SocketAddress.h ${code_SOURCE_DIR}/C++/SocketListener.cpp ${code_SOURCE_DIR}/C++/SocketListener.h + ${code_SOURCE_DIR}/C++/SocketSsl.cpp + ${code_SOURCE_DIR}/C++/SocketSsl.h main.cpp ) @@ -32,3 +36,6 @@ if (WIN32) target_link_libraries(socket ws2_32) endif () + +target_include_directories(socket PRIVATE ${OPENSSL_INCLUDE_DIR}) +target_link_libraries(socket ${OPENSSL_LIBRARIES}) diff -r 4e17193db141 -r bae4af872cde C++/Tests/Sockets/main.cpp --- a/C++/Tests/Sockets/main.cpp Sat Nov 15 13:19:21 2014 +0100 +++ b/C++/Tests/Sockets/main.cpp Sat Nov 15 13:19:30 2014 +0100 @@ -65,15 +65,15 @@ s2 = { AF_INET6, SOCK_STREAM, 0 }; SocketListener listener { - { s1, Read }, - { s2, Read }, + { s1, direction::Read }, + { s2, direction::Read }, }; ASSERT_EQ(2UL, listener.size()); listener.list([&] (const auto &so, auto direction) { ASSERT_TRUE(so == s1 || so == s2); - ASSERT_EQ(Read, direction); + ASSERT_EQ(direction::Read, direction); }); } catch (const std::exception &ex) { std::cerr << "warning: " << ex.what() << std::endl; @@ -101,7 +101,7 @@ s.bind(Internet{"*", 10000, AF_INET}); s.listen(10); - listener.add(s, Read); + listener.add(s, direction::Read); while (running) { try { @@ -151,8 +151,8 @@ s2 = { AF_INET, SOCK_STREAM, 0 }; SocketListener listener(Select); - listener.add(s, Read); - listener.add(s2, Read); + listener.add(s, direction::Read); + listener.add(s2, direction::Read); ASSERT_EQ(2UL, listener.size()); } catch (const std::exception &ex) { @@ -172,10 +172,10 @@ s2 = { AF_INET, SOCK_STREAM, 0 }; SocketListener listener(Select); - listener.add(s, Read); - listener.add(s2, Read); - listener.remove(s, Read); - listener.remove(s2, Read); + listener.add(s, direction::Read); + listener.add(s2, direction::Read); + listener.remove(s, direction::Read); + listener.remove(s2, direction::Read); ASSERT_EQ(0UL, listener.size()); } catch (const std::exception &ex) { @@ -187,7 +187,7 @@ } /* - * Add two sockets for both reading and writing, them remove only reading and then + * Add two sockets for both direction::Reading and writing, them remove only direction::Reading and then * move only writing. */ TEST(ListenerMethodSelect, inOut) @@ -199,26 +199,26 @@ s2 = { AF_INET, SOCK_STREAM, 0 }; SocketListener listener(Select); - listener.add(s, Read | Write); - listener.add(s2, Read | Write); + listener.add(s, direction::Read | direction::Write); + listener.add(s2, direction::Read | direction::Write); listener.list([&] (Socket &si, int dir) { ASSERT_TRUE(si == s || si == s2); ASSERT_EQ(0x03, static_cast(dir)); }); - listener.remove(s, Write); - listener.remove(s2, Write); + listener.remove(s, direction::Write); + listener.remove(s2, direction::Write); ASSERT_EQ(2UL, listener.size()); listener.list([&] (Socket &si, int dir) { ASSERT_TRUE(si == s || si == s2); - ASSERT_EQ(Read, dir); + ASSERT_EQ(direction::Read, dir); }); - listener.remove(s, Read); - listener.remove(s2, Read); + listener.remove(s, direction::Read); + listener.remove(s2, direction::Read); ASSERT_EQ(0UL, listener.size()); } catch (const std::exception &ex) { @@ -237,14 +237,14 @@ s = { AF_INET, SOCK_STREAM, 0 }; SocketListener listener(Select); - listener.add(s, Read); + listener.add(s, direction::Read); ASSERT_EQ(1UL, listener.size()); listener.list([&] (const Socket &si, int dir) { ASSERT_TRUE(si == s); - ASSERT_EQ(Read, dir); + ASSERT_EQ(direction::Read, dir); }); - listener.add(s, Write); + listener.add(s, direction::Write); ASSERT_EQ(1UL, listener.size()); listener.list([&] (const Socket &si, int dir) { ASSERT_TRUE(si == s); @@ -252,9 +252,9 @@ }); // Oops, added the same - listener.add(s, Read); - listener.add(s, Write); - listener.add(s, Read | Write); + listener.add(s, direction::Read); + listener.add(s, direction::Write); + listener.add(s, direction::Read | direction::Write); ASSERT_EQ(1UL, listener.size()); listener.list([&] (const Socket &si, int dir) { @@ -286,7 +286,7 @@ s.bind(Internet{"*", 10000, AF_INET}); s.listen(10); - listener.add(s, Read); + listener.add(s, direction::Read); while (running) { try { @@ -336,8 +336,8 @@ s2 = { AF_INET, SOCK_STREAM, 0 }; SocketListener listener(Poll); - listener.add(s, Read); - listener.add(s2, Read); + listener.add(s, direction::Read); + listener.add(s2, direction::Read); ASSERT_EQ(2UL, listener.size()); } catch (const std::exception &ex) { @@ -357,10 +357,10 @@ s2 = { AF_INET, SOCK_STREAM, 0 }; SocketListener listener(Poll); - listener.add(s, Read); - listener.add(s2, Read); - listener.remove(s, Read); - listener.remove(s2, Read); + listener.add(s, direction::Read); + listener.add(s2, direction::Read); + listener.remove(s, direction::Read); + listener.remove(s2, direction::Read); ASSERT_EQ(0UL, listener.size()); } catch (const std::exception &ex) { @@ -380,26 +380,26 @@ s2 = { AF_INET, SOCK_STREAM, 0 }; SocketListener listener(Poll); - listener.add(s, Read | Write); - listener.add(s2, Read | Write); + listener.add(s, direction::Read | direction::Write); + listener.add(s2, direction::Read | direction::Write); listener.list([&] (Socket &si, int dir) { ASSERT_TRUE(si == s || si == s2); ASSERT_EQ(0x03, static_cast(dir)); }); - listener.remove(s, Write); - listener.remove(s2, Write); + listener.remove(s, direction::Write); + listener.remove(s2, direction::Write); ASSERT_EQ(2UL, listener.size()); listener.list([&] (Socket &si, int dir) { ASSERT_TRUE(si == s || si == s2); - ASSERT_EQ(Read, dir); + ASSERT_EQ(direction::Read, dir); }); - listener.remove(s, Read); - listener.remove(s2, Read); + listener.remove(s, direction::Read); + listener.remove(s2, direction::Read); ASSERT_EQ(0UL, listener.size()); } catch (const std::exception &ex) { @@ -418,14 +418,14 @@ s = { AF_INET, SOCK_STREAM, 0 }; SocketListener listener(Poll); - listener.add(s, Read); + listener.add(s, direction::Read); ASSERT_EQ(1UL, listener.size()); listener.list([&] (const Socket &si, int dir) { ASSERT_TRUE(si == s); - ASSERT_EQ(Read, dir); + ASSERT_EQ(direction::Read, dir); }); - listener.add(s, Write); + listener.add(s, direction::Write); ASSERT_EQ(1UL, listener.size()); listener.list([&] (const Socket &si, int dir) { ASSERT_TRUE(si == s); @@ -433,9 +433,9 @@ }); // Oops, added the same - listener.add(s, Read); - listener.add(s, Write); - listener.add(s, Read | Write); + listener.add(s, direction::Read); + listener.add(s, direction::Write); + listener.add(s, direction::Read | direction::Write); ASSERT_EQ(1UL, listener.size()); listener.list([&] (const Socket &si, int dir) { @@ -466,11 +466,11 @@ s.bind(Internet{"*", 10000, AF_INET}); s.listen(8); - listener.add(s, Read); + listener.add(s, direction::Read); auto client = listener.select(10s); - ASSERT_TRUE(client.direction == Read); + ASSERT_TRUE(client.direction == direction::Read); ASSERT_TRUE(client.socket == s); } catch (const std::exception &ex) { std::cerr << "warning: " << ex.what() << std::endl; @@ -511,21 +511,21 @@ s.bind(Internet("*", 10000, AF_INET)); s.listen(8); - // Read for master - listener.add(s, Read); + // direction::Read for master + listener.add(s, direction::Read); auto result = listener.select(10s); - ASSERT_TRUE(result.direction == Read); + ASSERT_TRUE(result.direction == direction::Read); ASSERT_TRUE(result.socket == s); // Wait for client auto client = s.accept(); - listener.add(client, Read); + listener.add(client, direction::Read); result = listener.select(10s); - ASSERT_TRUE(result.direction == Read); + ASSERT_TRUE(result.direction == direction::Read); ASSERT_TRUE(result.socket == client); char data[512]; @@ -577,13 +577,13 @@ server.set(SOL_SOCKET, SO_REUSEADDR, 1); server.bind(Internet{"*", 10000, AF_INET}); server.listen(10); - listener.add(server, Read); + listener.add(server, direction::Read); while (!finished) { auto s = listener.select(60s).socket; if (s == server) { - listener.add(s.accept(), Read); + listener.add(s.accept(), direction::Read); } else { char data[512]; auto nb = s.recv(data, sizeof (data) - 1); @@ -625,7 +625,7 @@ client.connect(Internet{"localhost", 10000, AF_INET}); client.blockMode(false); - listener.add(client, Write); + listener.add(client, direction::Write); while (data.size() > 0) { auto s = listener.select(30s).socket; @@ -650,7 +650,7 @@ TEST(MultipleSelection, select) { /* - * Normally, 3 sockets added for writing should be marked ready immediately + * Normally, 3 sockets added for writing should be marked direction::Ready immediately * as there are no data being currently queued to be sent. */ std::thread server([] () { @@ -664,11 +664,11 @@ master.bind(Internet{"*", 10000, AF_INET}); master.listen(8); - masterListener.add(master, Read); + masterListener.add(master, direction::Read); while (clientListener.size() != 3) { masterListener.select(3s); - clientListener.add(master.accept(), Write); + clientListener.add(master.accept(), direction::Write); } // Now do the test of writing @@ -716,7 +716,7 @@ TEST(MultipleSelection, poll) { /* - * Normally, 3 sockets added for writing should be marked ready immediately + * Normally, 3 sockets added for writing should be marked direction::Ready immediately * as there are no data being currently queued to be sent. */ std::thread server([] () { @@ -730,11 +730,11 @@ master.bind(Internet{"*", 10000, AF_INET}); master.listen(8); - masterListener.add(master, Read); + masterListener.add(master, direction::Read); while (clientListener.size() != 3) { masterListener.select(3s); - clientListener.add(master.accept(), Write); + clientListener.add(master.accept(), direction::Write); } // Now do the test of writing @@ -1063,6 +1063,95 @@ #endif +/* -------------------------------------------------------- + * Non-blocking functions failures + * -------------------------------------------------------- */ + +TEST(NonBlockingFailures, connect) +{ + Socket s; + + try { + s = {AF_INET, SOCK_STREAM, 0}; + s.blockMode(false); + } catch (const std::exception &ex) { + std::cerr << ex.what() << std::endl; + return; + } + + auto time1 = std::chrono::system_clock::now(); + + try { + s.tryConnect(Internet{"google.fr", 9000, AF_INET}, 1000); + + FAIL() << "unexpected code path"; + } catch (const std::exception &ex) {} + + auto time2 = std::chrono::system_clock::now(); + auto duration = std::chrono::duration_cast(time2 - time1).count(); + + // Assert between 0,9 and 1,1 seconds + ASSERT_GT(duration, 900); + ASSERT_LT(duration, 1100); + + s.close(); +} + +TEST(NonBlockingFailures, recv) +{ + Socket s; + Socket server; + bool ready{false}; + + try { + s = {AF_INET, SOCK_STREAM, 0}; + + server = {AF_INET, SOCK_STREAM, 0}; + server.set(SOL_SOCKET, SO_REUSEADDR, 1); + server.bind(Internet{"localhost", 10000, AF_INET}); + server.listen(10); + + ready = true; + } catch (const std::exception &ex) { + std::cerr << "warning: " << ex.what() << std::endl; + return; + } + + if (!ready) { + s.close(); + server.close(); + } + + try { + s.connect(Internet{"localhost", 10000, AF_INET}); + s.blockMode(false); + } catch (const std::exception &ex) { + std::cerr << "warning: " << ex.what() << std::endl; + + s.close(); + server.close(); + + return; + } + + auto time1 = std::chrono::system_clock::now(); + + try { + s.tryRecv(100, 1000); + + FAIL() << "unexpected code path"; + } catch (const std::exception &) {} + + auto time2 = std::chrono::system_clock::now(); + auto duration = std::chrono::duration_cast(time2 - time1).count(); + + // Assert between 0,9 and 1,1 seconds + ASSERT_GT(duration, 900); + ASSERT_LT(duration, 1100); + + s.close(); +} + int main(int argc, char **argv) { Socket::init();