Mercurial > code
changeset 378:92457ea8f7e2
Socket: switch to more OO
author | David Demelier <markand@malikania.fr> |
---|---|
date | Fri, 19 Jun 2015 10:52:54 +0200 |
parents | 4885cd4dfa33 |
children | 57ce1a6293b9 |
files | C++/modules/Socket/Socket.cpp C++/modules/Socket/Socket.h C++/modules/Socket/SocketListener.cpp C++/modules/Socket/SocketListener.h C++/modules/Socket/SocketSsl.cpp C++/modules/Socket/SocketSsl.h C++/modules/Socket/SocketTcp.cpp C++/modules/Socket/SocketTcp.h C++/tests/Socket/main.cpp |
diffstat | 9 files changed, 263 insertions(+), 640 deletions(-) [+] |
line wrap: on
line diff
--- a/C++/modules/Socket/Socket.cpp Thu Jun 18 14:24:01 2015 +0200 +++ b/C++/modules/Socket/Socket.cpp Fri Jun 19 10:52:54 2015 +0200 @@ -119,6 +119,21 @@ m_state = SocketState::Opened; } +Socket::Socket(Socket &&other) noexcept +{ + m_handle = other.m_handle; + m_state = other.m_state; + + // Invalidate other + other.m_handle = -1; + other.m_state = SocketState::Closed; +} + +Socket::~Socket() +{ + close(); +} + SocketAddress Socket::address() const { #if defined(_WIN32) @@ -149,13 +164,15 @@ void Socket::close() { + if (m_state != SocketState::Closed) { #if defined(_WIN32) - ::closesocket(m_handle); + ::closesocket(m_handle); #else - ::close(m_handle); + ::close(m_handle); #endif - - m_state = SocketState::Closed; + m_handle = -1; + m_state = SocketState::Closed; + } } void Socket::setBlockMode(bool block) @@ -185,6 +202,18 @@ #endif } +Socket &Socket::operator=(Socket &&other) noexcept +{ + m_handle = other.m_handle; + m_state = other.m_state; + + // Invalidate other + other.m_handle = -1; + other.m_state = SocketState::Closed; + + return *this; +} + bool operator==(const Socket &s1, const Socket &s2) { return s1.handle() == s2.handle();
--- a/C++/modules/Socket/Socket.h Thu Jun 18 14:24:01 2015 +0200 +++ b/C++/modules/Socket/Socket.h Fri Jun 19 10:52:54 2015 +0200 @@ -269,8 +269,8 @@ * This create an invalid socket. */ inline Socket() noexcept - : m_handle(Invalid) - , m_state(SocketState::Closed) + : m_handle{Invalid} + , m_state{SocketState::Closed} { } @@ -280,8 +280,8 @@ * @param handle the native descriptor */ inline Socket(Handle handle) - : m_handle(handle) - , m_state(SocketState::Opened) + : m_handle{handle} + , m_state{SocketState::Opened} { } @@ -296,9 +296,21 @@ Socket(int domain, int type, int protocol); /** + * Copy constructor deleted. + */ + Socket(const Socket &) = delete; + + /** + * Transfer ownership from other to this. + * + * @param other the other socket + */ + Socket(Socket &&other) noexcept; + + /** * Default destructor. */ - virtual ~Socket() = default; + virtual ~Socket(); /** * Get the local name. This is a wrapper of getsockname(). @@ -392,8 +404,26 @@ /** * Close the socket. + * + * Automatically called from the destructor. */ virtual void close(); + + /** + * Assignment operator forbidden. + * + * @return *this + */ + Socket &operator=(const Socket &) = delete; + + /** + * Transfer ownership from other to this. The other socket is left + * invalid and will not be closed. + * + * @param other the other socket + * @return this + */ + Socket &operator=(Socket &&other) noexcept; }; /**
--- a/C++/modules/Socket/SocketListener.cpp Thu Jun 18 14:24:01 2015 +0200 +++ b/C++/modules/Socket/SocketListener.cpp Fri Jun 19 10:52:54 2015 +0200 @@ -127,23 +127,23 @@ return flags; } -void Poll::set(const SocketTable &, Socket s, int flags, bool add) +void Poll::set(const SocketTable &, const std::shared_ptr<Socket> s, int flags, bool add) { if (add) { - m_fds.push_back(pollfd{s.handle(), topoll(flags), 0}); + m_fds.push_back(pollfd{s->handle(), topoll(flags), 0}); } else { auto it = std::find_if(m_fds.begin(), m_fds.end(), [&] (const struct pollfd &pfd) { - return pfd.fd == s.handle(); + return pfd.fd == s->handle(); }); it->events |= topoll(flags); } } -void Poll::unset(const SocketTable &, Socket s, int flags, bool remove) +void Poll::unset(const SocketTable &, const std::shared_ptr<Socket> s, int flags, bool remove) { auto it = std::find_if(m_fds.begin(), m_fds.end(), [&] (const struct pollfd &pfd) { - return pfd.fd == s.handle(); + return pfd.fd == s->handle(); }); if (remove) { @@ -209,17 +209,17 @@ return flags; } -void Epoll::update(Socket &sc, int op, int flags) +void Epoll::update(const std::shared_ptr<Socket> &sc, int op, int flags) { struct epoll_event ev; std::memset(&ev, 0, sizeof (struct epoll_event)); ev.events = flags; - ev.data.fd = sc.handle(); + ev.data.fd = sc->handle(); - if (epoll_ctl(m_handle, op, sc.handle(), &ev) < 0) { - throw SocketError(SocketError::System, "epoll_ctl"); + if (epoll_ctl(m_handle, op, sc->handle(), &ev) < 0) { + throw SocketError{SocketError::System, "epoll_ctl"}; } } @@ -239,7 +239,7 @@ /* * Add a new epoll_event or just update it. */ -void Epoll::set(const SocketTable &, Socket &sc, int flags, bool add) +void Epoll::set(const SocketTable &, const std::shared_ptr<Socket> &sc, int flags, bool add) { update(sc, add ? EPOLL_CTL_ADD : EPOLL_CTL_MOD, toepoll(flags)); @@ -256,13 +256,13 @@ * So we put the same flags that are currently effective and remove the * requested one. */ -void Epoll::unset(const SocketTable &table, Socket &sc, int flags, bool remove) +void Epoll::unset(const SocketTable &table, const std::shared_ptr<Socket> &sc, int flags, bool remove) { if (remove) { update(sc, EPOLL_CTL_DEL, 0); m_events.resize(m_events.size() - 1); } else { - update(sc, EPOLL_CTL_MOD, table.at(sc.handle()).second & ~(toepoll(flags))); + update(sc, EPOLL_CTL_MOD, table.at(sc->handle()).second & ~(toepoll(flags))); } } @@ -306,18 +306,18 @@ close(m_handle); } -void Kqueue::update(const Socket &sc, int filter, int flags) +void Kqueue::update(const std::shared_ptr<Socket> &sc, int filter, int flags) { struct kevent ev; - EV_SET(&ev, sc.handle(), filter, flags, 0, 0, nullptr); + EV_SET(&ev, sc->handle(), filter, flags, 0, 0, nullptr); if (kevent(m_handle, &ev, 1, nullptr, 0, nullptr) < 0) { throw SocketError(SocketError::System, "kevent"); } } -void Kqueue::set(const SocketTable &, const Socket &sc, int flags, bool add) +void Kqueue::set(const SocketTable &, const std::shared_ptr<Socket> &sc, int flags, bool add) { if (flags & SocketListener::Read) { update(sc, EVFILT_READ, EV_ADD | EV_ENABLE); @@ -331,7 +331,7 @@ } } -void Kqueue::unset(const SocketTable &, const Socket &sc, int flags, bool remove) +void Kqueue::unset(const SocketTable &, const std::shared_ptr<Socket> &sc, int flags, bool remove) { if (flags & SocketListener::Read) { update(sc, EVFILT_READ, EV_DELETE); @@ -364,7 +364,7 @@ } for (int i = 0; i < nevents; ++i) { - Socket sc = table.at(m_result[i].ident).first; + std::shared_ptr<Socket> sc = table.at(m_result[i].ident).first; int flags = m_result[i].filter == EVFILT_READ ? SocketListener::Read : SocketListener::Write; sockets.push_back(SocketStatus{sc, flags});
--- a/C++/modules/Socket/SocketListener.h Thu Jun 18 14:24:01 2015 +0200 +++ b/C++/modules/Socket/SocketListener.h Fri Jun 19 10:52:54 2015 +0200 @@ -105,15 +105,15 @@ */ class SocketStatus { public: - Socket socket; //!< which socket is ready - int flags; //!< the flags + std::shared_ptr<Socket> socket; //!< which socket is ready + int flags; //!< the flags }; /** * Table used in the socket listener to store which sockets have been * set in which directions. */ -using SocketTable = std::map<Socket::Handle, std::pair<Socket, int>>; +using SocketTable = std::map<Socket::Handle, std::pair<std::shared_ptr<Socket>, int>>; namespace backend { @@ -136,12 +136,12 @@ /** * No-op, uses the SocketTable directly. */ - inline void set(const SocketTable &, const Socket &, int, bool) noexcept {} + inline void set(const SocketTable &, const std::shared_ptr<Socket> &, int, bool) noexcept {} /** * No-op, uses the SocketTable directly. */ - inline void unset(const SocketTable &, const Socket &, int, bool) noexcept {} + inline void unset(const SocketTable &, const std::shared_ptr<Socket> &, int, bool) noexcept {} std::vector<SocketStatus> wait(const SocketTable &table, int ms); }; @@ -163,8 +163,8 @@ int toflags(short &event) const noexcept; public: - void set(const SocketTable &, Socket sc, int flags, bool add); - void unset(const SocketTable &, Socket sc, int flags, bool remove); + void set(const SocketTable &, const std::shared_ptr<Socket> sc, int flags, bool add); + void unset(const SocketTable &, const std::shared_ptr<Socket> sc, int flags, bool remove); std::vector<SocketStatus> wait(const SocketTable &, int ms); /** @@ -192,13 +192,13 @@ uint32_t toepoll(int flags) const noexcept; int toflags(uint32_t events) const noexcept; - void update(Socket &sc, int op, int flags); + void update(const std::shared_ptr<Socket> &sc, int op, int flags); public: Epoll(); ~Epoll(); - void set(const SocketTable &, Socket &sc, int flags, bool add); - void unset(const SocketTable &, Socket &sc, int flags, bool remove); + void set(const SocketTable &, const std::shared_ptr<Socket> &sc, int flags, bool add); + void unset(const SocketTable &, const std::shared_ptr<Socket> &sc, int flags, bool remove); std::vector<SocketStatus> wait(const SocketTable &table, int ms); /** @@ -237,8 +237,8 @@ Kqueue(); ~Kqueue(); - void set(const SocketTable &, const Socket &sc, int flags, bool add); - void unset(const SocketTable &, const Socket &sc, int flags, bool remove); + void set(const SocketTable &, const std::shared_ptr<Socket> &sc, int flags, bool add); + void unset(const SocketTable &, const std::shared_ptr<Socket> &sc, int flags, bool remove); std::vector<SocketStatus> wait(const SocketTable &, int ms); /** @@ -267,14 +267,14 @@ * main loop as it can be extremely costly. Instead use the same listener that * you can safely modify on the fly. * - * Currently, poll, select and kqueue are available. + * Currently, poll, epoll, select and kqueue are available. * * To implement the backend, the following functions must be available: * * # Set * * @code - * void set(Socket sc, int flags, bool add); + * void set(const SocketTable &, const std::shared_ptr<Socket> &sc, int flags, bool add); * @endcode * * This function, takes the socket to be added and the flags. The flags are @@ -288,7 +288,7 @@ * # Unset * * @code - * void unset(Socket sc, int flags, bool remove); + * void unset(const SocketTable &, const std::shared_ptr<Socket> &sc, int flags, bool remove); * @endcode * * Like set, this function is only called if the flags are actually set and will @@ -296,11 +296,35 @@ * * Also like set, an optional remove argument is set if the socket is being * completely removed (e.g no more flags are set for this socket). + * + * # Wait + * + * @code + * std::vector<SocketStatus> wait(const SocketTable &, int ms); + * @encode + * + * Wait for the sockets to be ready with the specified milliseconds. Must return a list of SocketStatus, + * may throw any exceptions. + * + * # Name + * + * @code + * inline const char *name() const noexcept + * @endcode + * + * Returns the backend name. Usually the class in lower case. */ template <typename Backend = SOCKET_DEFAULT_BACKEND> class SocketListenerBase final { public: + /** + * Mark the socket for read operation. + */ static const int Read; + + /** + * Mark the socket for write operation. + */ static const int Write; private: @@ -311,7 +335,7 @@ /** * Construct an empty listener. */ - inline SocketListenerBase() + inline SocketListenerBase() noexcept { } @@ -320,7 +344,7 @@ * * @param list the list */ - inline SocketListenerBase(std::initializer_list<std::pair<Socket, int>> list) + inline SocketListenerBase(std::initializer_list<std::pair<std::shared_ptr<Socket>, int>> list) { for (const auto &p : list) { set(p.first, p.second); @@ -352,7 +376,7 @@ * * @return the iterator */ - inline auto begin() const noexcept + inline SocketTable::const_iterator begin() const noexcept { return m_table.begin(); } @@ -362,7 +386,7 @@ * * @return the iterator */ - inline auto cbegin() const noexcept + inline SocketTable::const_iterator cbegin() const noexcept { return m_table.cbegin(); } @@ -372,7 +396,7 @@ * * @return the iterator */ - inline auto end() const noexcept + inline SocketTable::const_iterator end() const noexcept { return m_table.end(); } @@ -382,7 +406,7 @@ * * @return the iterator */ - inline auto cend() const noexcept + inline SocketTable::const_iterator cend() const noexcept { return m_table.cend(); } @@ -399,7 +423,7 @@ * @param flags (may be OR'ed) * @throw SocketError if the backend failed to set */ - void set(Socket sc, int flags); + void set(const std::shared_ptr<Socket> &sc, int flags); /** * Unset a socket from the listener, only the flags is removed @@ -412,7 +436,7 @@ * @param flags the flags (may be OR'ed) * @see remove */ - void unset(Socket sc, int flags); + void unset(const std::shared_ptr<Socket> &sc, int flags); /** * Remove completely the socket from the listener. @@ -421,7 +445,7 @@ * * @param sc the socket */ - inline void remove(Socket sc) + inline void remove(std::shared_ptr<Socket> sc) { unset(sc, Read | Write); } @@ -495,13 +519,13 @@ }; template <typename Backend> -void SocketListenerBase<Backend>::set(Socket sc, int flags) +void SocketListenerBase<Backend>::set(const std::shared_ptr<Socket> &sc, int flags) { /* Invalid or useless flags */ if (flags == 0 || flags > 0x3) return; - auto it = m_table.find(sc.handle()); + auto it = m_table.find(sc->handle()); /* * Do not update the table if the backend failed to add @@ -509,7 +533,7 @@ */ if (it == m_table.end()) { m_backend.set(m_table, sc, flags, true); - m_table.emplace(sc.handle(), std::make_pair(sc, flags)); + m_table.emplace(sc->handle(), std::make_pair(sc, flags)); } else { if ((flags & Read) && (it->second.second & Read)) { flags &= ~(Read); @@ -527,9 +551,9 @@ } template <typename Backend> -void SocketListenerBase<Backend>::unset(Socket sc, int flags) +void SocketListenerBase<Backend>::unset(const std::shared_ptr<Socket> &sc, int flags) { - auto it = m_table.find(sc.handle()); + auto it = m_table.find(sc->handle()); /* Invalid or useless flags */ if (flags == 0 || flags > 0x3 || it == m_table.end())
--- a/C++/modules/Socket/SocketSsl.cpp Thu Jun 18 14:24:01 2015 +0200 +++ b/C++/modules/Socket/SocketSsl.cpp Fri Jun 19 10:52:54 2015 +0200 @@ -17,7 +17,6 @@ */ #include "SocketAddress.h" -#include "SocketListener.h" #include "SocketSsl.h" namespace { @@ -37,23 +36,13 @@ return ERR_reason_error_string(error); } -inline int toDirection(int error) -{ - if (error == SocketError::WouldBlockRead) - return SocketListener::Read; - if (error == SocketError::WouldBlockWrite) - return SocketListener::Write; - - return 0; -} - } // !namespace std::mutex SocketSsl::s_sslMutex; std::atomic<bool> SocketSsl::s_sslInitialized{false}; SocketSsl::SocketSsl(Socket::Handle handle, SSL_CTX *context, SSL *ssl) - : SocketAbstractTcp(handle) + : SocketTcp(handle) , m_context(context, SSL_CTX_free) , m_ssl(ssl, SSL_free) { @@ -65,7 +54,7 @@ } SocketSsl::SocketSsl(int family, int protocol, SocketSslOptions options) - : SocketAbstractTcp(family, protocol) + : SocketTcp(family, protocol) , m_options(std::move(options)) { #if !defined(SOCKET_NO_SSL_INIT) @@ -75,9 +64,10 @@ #endif } -void SocketSsl::connect(const SocketAddress &address) +void SocketSsl::connect(const std::unique_ptr<SocketAddress> &address) { - standardConnect(address); +#if 0 + std::unique_ptr<SocketTcp> standard = SocketTcp::connect(address); // Context first auto context = SSL_CTX_new(sslMethod(m_options.method)); @@ -106,36 +96,12 @@ } m_state = SocketState::Connected; +#endif } -void SocketSsl::waitConnect(const SocketAddress &address, int timeout) +std::unique_ptr<SocketTcp> SocketSsl::accept(std::unique_ptr<SocketAddress> &info) { - try { - // Initial try - connect(address); - } catch (const SocketError &ex) { - if (ex.code() == SocketError::WouldBlockRead || ex.code() == SocketError::WouldBlockWrite) { - SocketListener listener{{*this, toDirection(ex.code())}}; - - listener.wait(timeout); - - // Second try - connect(address); - } else { - throw; - } - } -} - -SocketSsl SocketSsl::accept() -{ - SocketAddress dummy; - - return accept(dummy); -} - -SocketSsl SocketSsl::accept(SocketAddress &info) -{ +#if 0 auto client = standardAccept(info); auto context = SSL_CTX_new(sslMethod(m_options.method)); @@ -168,6 +134,7 @@ } return SocketSsl(client.handle(), context, ssl); +#endif } unsigned SocketSsl::recv(void *data, unsigned len) @@ -189,15 +156,6 @@ return nbread; } -unsigned SocketSsl::waitRecv(void *data, unsigned len, int timeout) -{ - SocketListener listener{{*this, SocketListener::Read}}; - - listener.wait(timeout); - - return recv(data, len); -} - unsigned SocketSsl::send(const void *data, unsigned len) { auto nbread = SSL_write(m_ssl.get(), data, len); @@ -216,13 +174,3 @@ return nbread; } - -unsigned SocketSsl::waitSend(const void *data, unsigned len, int timeout) -{ - SocketListener listener{{*this, SocketListener::Write}}; - - listener.wait(timeout); - - return send(data, len); -} -
--- a/C++/modules/Socket/SocketSsl.h Thu Jun 18 14:24:01 2015 +0200 +++ b/C++/modules/Socket/SocketSsl.h Fri Jun 19 10:52:54 2015 +0200 @@ -79,12 +79,11 @@ * * This class derives from SocketAbstractTcp and provide SSL support through OpenSSL. */ -class SocketSsl : public SocketAbstractTcp { -public: +class SocketSsl : public SocketTcp { +private: using ContextHandle = std::unique_ptr<SSL_CTX, void (*)(SSL_CTX *)>; using SslHandle = std::unique_ptr<SSL, void (*)(SSL *)>; -private: static std::mutex s_sslMutex; static std::atomic<bool> s_sslInitialized; @@ -93,11 +92,6 @@ SocketSslOptions m_options; public: - using SocketAbstractTcp::recv; - using SocketAbstractTcp::waitRecv; - using SocketAbstractTcp::send; - using SocketAbstractTcp::waitSend; - /** * Close OpenSSL library. */ @@ -144,38 +138,11 @@ /** * Accept a SSL TCP socket. * - * @return the socket - * @throw SocketError on error - */ - SocketSsl accept(); - - /** - * Accept a SSL TCP socket. - * * @param info the client information * @return the socket * @throw SocketError on error */ - SocketSsl accept(SocketAddress &info); - - /** - * Accept a SSL TCP socket. - * - * @param timeout the maximum timeout in milliseconds - * @return the socket - * @throw SocketError on error - */ - SocketSsl waitAccept(int timeout); - - /** - * Accept a SSL TCP socket. - * - * @param info the client information - * @param timeout the maximum timeout in milliseconds - * @return the socket - * @throw SocketError on error - */ - SocketSsl waitAccept(SocketAddress &info, int timeout); + std::unique_ptr<SocketTcp> accept(std::unique_ptr<SocketAddress> &info) override; /** * Connect to an end point. @@ -183,36 +150,27 @@ * @param address the address * @throw SocketError on error */ - void connect(const SocketAddress &address); + void connect(const std::unique_ptr<SocketAddress> &address) override; /** - * Connect to an end point. - * - * @param timeout the maximum timeout in milliseconds - * @param address the address - * @throw SocketError on error - */ - void waitConnect(const SocketAddress &address, int timeout); - - /** - * @copydoc SocketAbstractTcp::recv + * @copydoc SocketTcp::recv */ unsigned recv(void *data, unsigned length) override; /** - * @copydoc SocketAbstractTcp::recv - */ - unsigned waitRecv(void *data, unsigned length, int timeout) override; - - /** - * @copydoc SocketAbstractTcp::recv + * @copydoc SocketTcp::recv */ unsigned send(const void *data, unsigned length) override; /** - * @copydoc SocketAbstractTcp::recv + * Bring back send overloads. */ - unsigned waitSend(const void *data, unsigned length, int timeout) override; + using SocketTcp::send; + + /** + * Bring back recv overloads; + */ + using SocketTcp::recv; }; #endif // !_SOCKET_SSL_NG_H_
--- a/C++/modules/Socket/SocketTcp.cpp Thu Jun 18 14:24:01 2015 +0200 +++ b/C++/modules/Socket/SocketTcp.cpp Fri Jun 19 10:52:54 2015 +0200 @@ -17,21 +17,23 @@ */ #include "SocketAddress.h" -#include "SocketListener.h" #include "SocketTcp.h" -/* -------------------------------------------------------- - * SocketAbstractTcp - * -------------------------------------------------------- */ - -void SocketAbstractTcp::listen(int max) +void SocketTcp::listen(int max) { if (::listen(m_handle, max) == Error) { throw SocketError(SocketError::System, "listen"); } } -Socket SocketAbstractTcp::standardAccept(SocketAddress &info) +std::unique_ptr<SocketTcp> SocketTcp::accept() +{ + std::unique_ptr<SocketAddress> dummy; + + return accept(dummy); +} + +std::unique_ptr<SocketTcp> SocketTcp::accept(std::unique_ptr<SocketAddress> &info) { Socket::Handle handle; @@ -60,19 +62,20 @@ #endif } - info = SocketAddress(address, addrlen); + // TODO: add it + //info = SocketAddress(address, addrlen); - return Socket(handle); + return std::make_unique<SocketTcp>(handle); } -void SocketAbstractTcp::standardConnect(const SocketAddress &address) +void SocketTcp::connect(const std::unique_ptr<SocketAddress> &address) { if (m_state == SocketState::Connected) { return; } - auto &sa = address.address(); - auto addrlen = address.length(); + auto &sa = address->address(); + auto addrlen = address->length(); if (::connect(m_handle, reinterpret_cast<const sockaddr *>(&sa), addrlen) == Error) { /* @@ -99,72 +102,6 @@ m_state = SocketState::Connected; } -/* -------------------------------------------------------- - * SocketTcp - * -------------------------------------------------------- */ - -SocketTcp SocketTcp::accept() -{ - SocketAddress dummy; - - return accept(dummy); -} - -SocketTcp SocketTcp::accept(SocketAddress &info) -{ - return standardAccept(info); -} - -void SocketTcp::connect(const SocketAddress &address) -{ - return standardConnect(address); -} - -void SocketTcp::waitConnect(const SocketAddress &address, int timeout) -{ - if (m_state == SocketState::Connected) { - return; - } - - // Initial try - try { - connect(address); - } catch (const SocketError &ex) { - if (ex.code() == SocketError::WouldBlockWrite) { - SocketListener listener{{*this, SocketListener::Write}}; - - listener.wait(timeout); - - // Socket is writable? Check if there is an error - int error = get<int>(SOL_SOCKET, SO_ERROR); - - if (error) { - throw SocketError(SocketError::System, "connect", error); - } - } else { - throw; - } - } - - m_state = SocketState::Connected; -} - -SocketTcp SocketTcp::waitAccept(int timeout) -{ - SocketAddress dummy; - - return waitAccept(dummy, timeout); -} - -SocketTcp SocketTcp::waitAccept(SocketAddress &info, int timeout) -{ - SocketListener listener{{*this, SocketListener::Read}}; - - listener.wait(timeout); - - return accept(info); -} - unsigned SocketTcp::recv(void *data, unsigned dataLen) { int nbread; @@ -193,15 +130,6 @@ return (unsigned)nbread; } -unsigned SocketTcp::waitRecv(void *data, unsigned length, int timeout) -{ - SocketListener listener{{*this, SocketListener::Read}}; - - listener.wait(timeout); - - return recv(data, length); -} - unsigned SocketTcp::send(const void *data, unsigned length) { int nbsent; @@ -227,12 +155,3 @@ return (unsigned)nbsent; } - -unsigned SocketTcp::waitSend(const void *data, unsigned length, int timeout) -{ - SocketListener listener{{*this, SocketListener::Write}}; - - listener.wait(timeout); - - return send(data, length); -}
--- a/C++/modules/Socket/SocketTcp.h Thu Jun 18 14:24:01 2015 +0200 +++ b/C++/modules/Socket/SocketTcp.h Fri Jun 19 10:52:54 2015 +0200 @@ -19,34 +19,20 @@ #ifndef _SOCKET_TCP_NG_H_ #define _SOCKET_TCP_NG_H_ +#include <memory> + #include "Socket.h" /** - * @class SocketAbstractTcp - * @brief Base class for TCP sockets - * - * This abstract class provides standard TCP functions for both clear - * and SSL implementation. - * - * It does not contain default accept() and connect() because they varies too - * much between standard and SSL. Also, the accept() function return different - * types. + * @class SocketTcp + * @brief End-user class for TCP sockets */ -class SocketAbstractTcp : public Socket { -protected: - Socket standardAccept(SocketAddress &address); - void standardConnect(const SocketAddress &address); - +class SocketTcp : public Socket { public: /** - * Construct an abstract socket from an already made socket. - * - * @param s the socket + * Inherited constructors. */ - inline SocketAbstractTcp(Socket s) - : Socket(s) - { - } + using Socket::Socket; /** * Construct a standard TCP socket. The type is automatically @@ -56,7 +42,7 @@ * @param protocol the protocol * @throw SocketError on error */ - inline SocketAbstractTcp(int domain, int protocol) + inline SocketTcp(int domain, int protocol) : Socket(domain, SOCK_STREAM, protocol) { } @@ -66,7 +52,41 @@ * * @param max the maximum number */ - void listen(int max = 128); + virtual void listen(int max = 128); + + /** + * Accept a clear TCP socket without its address + * + * @return the socket + * @throw SocketError on error + */ + std::unique_ptr<SocketTcp> accept(); + + /** + * Accept a clear TCP socket. + * + * @param info the client information + * @return the socket + * @throw SocketError on error + */ + virtual std::unique_ptr<SocketTcp> accept(std::unique_ptr<SocketAddress> &info); + + /** + * Connect to an end point. + * + * @param address the address + * @throw SocketError on error + */ + virtual void connect(const std::unique_ptr<SocketAddress> &address); + + /** + * Receive some data. + * + * @param data the destination buffer + * @param length the buffer length + * @throw SocketError on error + */ + virtual unsigned recv(void *data, unsigned length); /** * Overloaded function. @@ -87,23 +107,12 @@ } /** - * Overloaded function. + * Send some data. * - * @param count the number of bytes to receive - * @param timeout the maximum timeout in milliseconds - * @return the string - * @throw SocketError on error + * @param data the data to send + * @param length the data length */ - inline std::string waitRecv(unsigned count, int timeout) - { - std::string result; - - result.resize(count); - auto n = waitRecv(const_cast<char *>(result.data()), count, timeout); - result.resize(n); - - return result; - } + virtual unsigned send(const void *data, unsigned length); /** * Overloaded function. @@ -116,146 +125,6 @@ { return send(data.c_str(), data.size()); } - - /** - * Overloaded function. - * - * @param data the string to send - * @param timeout the maximum timeout in milliseconds - * @return the number of bytes sent - * @throw SocketError on error - */ - inline unsigned waitSend(const std::string &data, int timeout) - { - return waitSend(data.c_str(), data.size(), timeout); - } - - /** - * Receive data. - * - * @param data the destination buffer - * @param length the buffer length - * @return the number of bytes received - * @throw SocketError on error - */ - virtual unsigned recv(void *data, unsigned length) = 0; - - /** - * Receive data. - * - * @param data the destination buffer - * @param length the buffer length - * @param timeout the maximum timeout in milliseconds - * @return the number of bytes received - * @throw SocketError on error - */ - virtual unsigned waitRecv(void *data, unsigned length, int timeout) = 0; - - /** - * Send data. - * - * @param data the buffer - * @param length the buffer length - * @return the number of bytes sent - * @throw SocketError on error - */ - virtual unsigned send(const void *data, unsigned length) = 0; - - /** - * Send data. - * - * @param data the buffer - * @param length the buffer length - * @return the number of bytes sent - * @throw SocketError on error - */ - virtual unsigned waitSend(const void *data, unsigned length, int timeout) = 0; -}; - -/** - * @class SocketTcp - * @brief End-user class for TCP sockets - */ -class SocketTcp : public SocketAbstractTcp { -public: - using SocketAbstractTcp::SocketAbstractTcp; - using SocketAbstractTcp::recv; - using SocketAbstractTcp::waitRecv; - using SocketAbstractTcp::send; - using SocketAbstractTcp::waitSend; - - /** - * Accept a clear TCP socket. - * - * @return the socket - * @throw SocketError on error - */ - SocketTcp accept(); - - /** - * Accept a clear TCP socket. - * - * @param info the client information - * @return the socket - * @throw SocketError on error - */ - SocketTcp accept(SocketAddress &info); - - /** - * Accept a clear TCP socket. - * - * @param timeout the maximum timeout in milliseconds - * @return the socket - * @throw SocketError on error - */ - SocketTcp waitAccept(int timeout); - - /** - * Accept a clear TCP socket. - * - * @param info the client information - * @param timeout the maximum timeout in milliseconds - * @return the socket - * @throw SocketError on error - */ - SocketTcp waitAccept(SocketAddress &info, int timeout); - - /** - * Connect to an end point. - * - * @param address the address - * @throw SocketError on error - */ - void connect(const SocketAddress &address); - - /** - * Connect to an end point. - * - * @param timeout the maximum timeout in milliseconds - * @param address the address - * @throw SocketError on error - */ - void waitConnect(const SocketAddress &address, int timeout); - - /** - * @copydoc SocketAbstractTcp::recv - */ - unsigned recv(void *data, unsigned length) override; - - /** - * @copydoc SocketAbstractTcp::waitRecv - */ - unsigned waitRecv(void *data, unsigned length, int timeout) override; - - /** - * @copydoc SocketAbstractTcp::send - */ - unsigned send(const void *data, unsigned length) override; - - /** - * @copydoc SocketAbstractTcp::waitSend - */ - unsigned waitSend(const void *data, unsigned length, int timeout) override; }; #endif // !_SOCKET_TCP_NG_H_
--- a/C++/tests/Socket/main.cpp Thu Jun 18 14:24:01 2015 +0200 +++ b/C++/tests/Socket/main.cpp Fri Jun 19 10:52:54 2015 +0200 @@ -76,7 +76,7 @@ std::this_thread::sleep_for(100ms); m_tclient = std::thread([this] () { - m_client.connect(Internet("127.0.0.1", 16000, AF_INET)); + m_client.connect(std::make_unique<Internet>("127.0.0.1", 16000, AF_INET)); ASSERT_EQ(SocketState::Connected, m_client.state()); @@ -91,25 +91,20 @@ m_server.listen(); auto client = m_server.accept(); - auto msg = client.recv(512); + auto msg = client->recv(512); ASSERT_EQ("hello world", msg); - client.send(msg); - client.close(); - - m_server.close(); + client->send(msg); }); std::this_thread::sleep_for(100ms); m_tclient = std::thread([this] () { - m_client.connect(Internet("127.0.0.1", 16000, AF_INET)); + m_client.connect(std::make_unique<Internet>("127.0.0.1", 16000, AF_INET)); m_client.send("hello world"); ASSERT_EQ("hello world", m_client.recv(512)); - - m_client.close(); }); } @@ -178,25 +173,25 @@ bool m_added{false}; int m_flags{0}; - inline void set(const SocketTable &, Socket sc, int flags, bool add) noexcept + inline void set(const SocketTable &, const std::shared_ptr<Socket> &sc, int flags, bool add) noexcept { m_callcount ++; m_added = add; m_flags |= flags; } - inline void unset(const SocketTable &, Socket, int, bool) noexcept {} + inline void unset(const SocketTable &, const std::shared_ptr<Socket> &, int, bool) noexcept {} std::vector<SocketStatus> wait(const SocketTable &table, int ms) {} }; class TestBackendSetFail { public: - inline void set(const SocketTable &, Socket, int, bool) + inline void set(const SocketTable &, const std::shared_ptr<Socket> &, int, bool) { throw "fail"; } - inline void unset(const SocketTable &, Socket, int, bool) noexcept {} + inline void unset(const SocketTable &, const std::shared_ptr<Socket> &, int, bool) noexcept {} std::vector<SocketStatus> wait(const SocketTable &table, int ms) {} }; @@ -204,7 +199,7 @@ { SocketListenerBase<TestBackendSet> listener; - listener.set(Socket(0), SocketListener::Read); + listener.set(std::make_shared<Socket>(0), SocketListener::Read); ASSERT_EQ(1U, listener.size()); ASSERT_EQ(1, listener.backend().m_callcount); @@ -215,7 +210,7 @@ TEST(ListenerSet, readThenWrite) { SocketListenerBase<TestBackendSet> listener; - Socket sc(0); + std::shared_ptr<Socket> sc = std::make_shared<Socket>(0); listener.set(sc, SocketListener::Read); listener.set(sc, SocketListener::Write); @@ -229,7 +224,7 @@ TEST(ListenerSet, allOneShot) { SocketListenerBase<TestBackendSet> listener; - Socket sc(0); + std::shared_ptr<Socket> sc = std::make_shared<Socket>(0); listener.set(sc, SocketListener::Read | SocketListener::Write); @@ -242,7 +237,7 @@ TEST(ListenerSet, readTwice) { SocketListenerBase<TestBackendSet> listener; - Socket sc(0); + std::shared_ptr<Socket> sc = std::make_shared<Socket>(0); listener.set(sc, SocketListener::Read); listener.set(sc, SocketListener::Read); @@ -258,7 +253,7 @@ SocketListenerBase<TestBackendSetFail> listener; try { - listener.set(Socket(0), SocketListener::Read); + listener.set(std::make_shared<Socket>(0), SocketListener::Read); FAIL() << "exception expected"; } catch (...) { } @@ -277,13 +272,13 @@ int m_flags{0}; bool m_removal{false}; - inline void set(const SocketTable &, Socket, int flags, bool) noexcept + inline void set(const SocketTable &, const std::shared_ptr<Socket> &, int flags, bool) noexcept { m_isset = true; m_flags |= flags; } - inline void unset(const SocketTable &, Socket, int flags, bool remove) noexcept + inline void unset(const SocketTable &, const std::shared_ptr<Socket> &, int flags, bool remove) noexcept { m_isunset = true; m_flags &= ~(flags); @@ -295,9 +290,9 @@ class TestBackendUnsetFail { public: - inline void set(const SocketTable &, Socket, int, bool) noexcept {} + inline void set(const SocketTable &, const std::shared_ptr<Socket> &, int, bool) noexcept {} - inline void unset(const SocketTable &, Socket, int, bool) + inline void unset(const SocketTable &, const std::shared_ptr<Socket> &, int, bool) { throw "fail"; } @@ -308,7 +303,7 @@ TEST(ListenerUnsetRemove, unset) { SocketListenerBase<TestBackendUnset> listener; - Socket sc(0); + std::shared_ptr<Socket> sc = std::make_shared<Socket>(0); listener.set(sc, SocketListener::Read); listener.unset(sc, SocketListener::Read); @@ -323,7 +318,7 @@ TEST(ListenerUnsetRemove, unsetOne) { SocketListenerBase<TestBackendUnset> listener; - Socket sc(0); + std::shared_ptr<Socket> sc = std::make_shared<Socket>(0); listener.set(sc, SocketListener::Read | SocketListener::Write); listener.unset(sc, SocketListener::Read); @@ -338,7 +333,7 @@ TEST(ListenerUnsetRemove, unsetAll) { SocketListenerBase<TestBackendUnset> listener; - Socket sc(0); + std::shared_ptr<Socket> sc = std::make_shared<Socket>(0); listener.set(sc, SocketListener::Read | SocketListener::Write); listener.unset(sc, SocketListener::Read); @@ -354,7 +349,7 @@ TEST(ListenerUnsetRemove, remove) { SocketListenerBase<TestBackendUnset> listener; - Socket sc(0); + std::shared_ptr<Socket> sc = std::make_shared<Socket>(0); listener.set(sc, SocketListener::Read | SocketListener::Write); listener.remove(sc); @@ -369,7 +364,7 @@ TEST(ListenerUnsetRemove, failure) { SocketListenerBase<TestBackendUnsetFail> listener; - Socket sc(0); + std::shared_ptr<Socket> sc = std::make_shared<Socket>(0); listener.set(sc, SocketListener::Read | SocketListener::Write); @@ -390,18 +385,20 @@ class ListenerTest : public testing::Test { protected: SocketListenerBase<backend::Select> m_listener; - SocketTcp m_masterTcp{AF_INET, 0}; - SocketTcp m_clientTcp{AF_INET, 0}; + std::shared_ptr<SocketTcp> m_masterTcp; + std::shared_ptr<SocketTcp> m_clientTcp; std::thread m_tserver; std::thread m_tclient; public: ListenerTest() + : m_masterTcp{std::make_shared<SocketTcp>(AF_INET, 0)} + , m_clientTcp{std::make_shared<SocketTcp>(AF_INET, 0)} { - m_masterTcp.set(SOL_SOCKET, SO_REUSEADDR, 1); - m_masterTcp.bind(Internet("*", 16000, AF_INET)); - m_masterTcp.listen(); + m_masterTcp->set(SOL_SOCKET, SO_REUSEADDR, 1); + m_masterTcp->bind(Internet("*", 16000, AF_INET)); + m_masterTcp->listen(); } ~ListenerTest() @@ -421,8 +418,8 @@ try { m_listener.set(m_masterTcp, SocketListener::Read); m_listener.wait(); - m_masterTcp.accept(); - m_masterTcp.close(); + m_masterTcp->accept(); + m_masterTcp->close(); } catch (const std::exception &ex) { FAIL() << ex.what(); } @@ -431,7 +428,7 @@ std::this_thread::sleep_for(100ms); m_tclient = std::thread([this] () { - m_clientTcp.connect(Internet("127.0.0.1", 16000, AF_INET)); + m_clientTcp->connect(std::make_unique<Internet>("127.0.0.1", 16000, AF_INET)); }); } @@ -442,11 +439,9 @@ m_listener.set(m_masterTcp, SocketListener::Read); m_listener.wait(); - auto sc = m_masterTcp.accept(); + auto sc = m_masterTcp->accept(); - ASSERT_EQ("hello", sc.recv(512)); - - m_masterTcp.close(); + ASSERT_EQ("hello", sc->recv(512)); } catch (const std::exception &ex) { FAIL() << ex.what(); } @@ -455,8 +450,8 @@ std::this_thread::sleep_for(100ms); m_tclient = std::thread([this] () { - m_clientTcp.connect(Internet("127.0.0.1", 16000, AF_INET)); - m_clientTcp.send("hello"); + m_clientTcp->connect(std::make_unique<Internet>("127.0.0.1", 16000, AF_INET)); + m_clientTcp->send("hello"); }); } @@ -487,54 +482,6 @@ } }; -TEST_F(NonBlockingConnectTest, success) -{ - m_server.set(SOL_SOCKET, SO_REUSEADDR, 1); - m_server.bind(Internet("*", 16000, AF_INET)); - m_server.listen(); - - m_tserver = std::thread([this] () { - SocketTcp client = m_server.accept(); - - std::this_thread::sleep_for(100ms); - - m_server.close(); - client.close(); - }); - - std::this_thread::sleep_for(100ms); - - m_tclient = std::thread([this] () { - try { - m_client.waitConnect(Internet("127.0.0.1", 16000, AF_INET), 3000); - } catch (const SocketError &error) { - FAIL() << error.function() << ": " << error.what(); - } - - ASSERT_EQ(SocketState::Connected, m_client.state()); - - m_client.close(); - }); -} - -TEST_F(NonBlockingConnectTest, fail) -{ - /* - * /!\ If you find a way to test this locally please tell me /!\ - */ - m_tclient = std::thread([this] () { - try { - m_client.waitConnect(Internet("google.fr", 9000, AF_INET), 100); - - FAIL() << "Expected exception, got success"; - } catch (const SocketError &error) { - ASSERT_EQ(SocketError::Timeout, error.code()); - } - - m_client.close(); - }); -} - /* -------------------------------------------------------- * TCP accept * -------------------------------------------------------- */ @@ -564,62 +511,6 @@ } }; -TEST_F(TcpAcceptTest, blockingWaitSuccess) -{ - m_tserver = std::thread([this] () { - try { - m_server.waitAccept(3000).close(); - } catch (const SocketError &error) { - FAIL() << error.what(); - } - - m_server.close(); - }); - - std::this_thread::sleep_for(100ms); - - m_tclient = std::thread([this] () { - m_client.connect(Internet("127.0.0.1", 16000, AF_INET)); - m_client.close(); - }); -} - -TEST_F(TcpAcceptTest, nonBlockingWaitSuccess) -{ - m_tserver = std::thread([this] () { - try { - m_server.setBlockMode(false); - m_server.waitAccept(3000).close(); - } catch (const SocketError &error) { - FAIL() << error.what(); - } - - m_server.close(); - }); - - std::this_thread::sleep_for(100ms); - - m_tclient = std::thread([this] () { - m_client.connect(Internet("127.0.0.1", 16000, AF_INET)); - m_client.close(); - }); -} - -TEST_F(TcpAcceptTest, nonBlockingWaitFail) -{ - // No client, no accept - try { - m_server.setBlockMode(false); - m_server.waitAccept(100).close(); - - FAIL() << "Expected exception, got success"; - } catch (const SocketError &error) { - ASSERT_EQ(SocketError::Timeout, error.code()); - } - - m_server.close(); -} - /* -------------------------------------------------------- * TCP recv * -------------------------------------------------------- */ @@ -652,60 +543,15 @@ TEST_F(TcpRecvTest, blockingSuccess) { m_tserver = std::thread([this] () { - SocketTcp client = m_server.accept(); + auto client = m_server.accept(); - ASSERT_EQ("hello", client.recv(32)); - - client.close(); - m_server.close(); + ASSERT_EQ("hello", client->recv(32)); }); std::this_thread::sleep_for(100ms); m_tclient = std::thread([this] () { - m_client.connect(Internet("127.0.0.1", 16000, AF_INET)); - m_client.send("hello"); - m_client.close(); - }); -} - -TEST_F(TcpRecvTest, blockingWaitSuccess) -{ - m_tserver = std::thread([this] () { - SocketTcp client = m_server.accept(); - - ASSERT_EQ("hello", client.waitRecv(32, 3000)); - - client.close(); - m_server.close(); - }); - - std::this_thread::sleep_for(100ms); - - m_tclient = std::thread([this] () { - m_client.connect(Internet("127.0.0.1", 16000, AF_INET)); - m_client.send("hello"); - m_client.close(); - }); -} - -TEST_F(TcpRecvTest, nonBlockingWaitSuccess) -{ - m_tserver = std::thread([this] () { - SocketTcp client = m_server.accept(); - - client.setBlockMode(false); - - ASSERT_EQ("hello", client.waitRecv(32, 3000)); - - client.close(); - m_server.close(); - }); - - std::this_thread::sleep_for(100ms); - - m_tclient = std::thread([this] () { - m_client.connect(Internet("127.0.0.1", 16000, AF_INET)); + m_client.connect(std::make_unique<Internet>("127.0.0.1", 16000, AF_INET)); m_client.send("hello"); m_client.close(); }); @@ -723,7 +569,7 @@ TEST_F(SslTest, connect) { try { - client.connect(Internet("google.fr", 443, AF_INET)); + client.connect(std::make_unique<Internet>("google.fr", 443, AF_INET)); client.close(); } catch (const SocketError &error) { FAIL() << error.what(); @@ -733,7 +579,7 @@ TEST_F(SslTest, recv) { try { - client.connect(Internet("google.fr", 443, AF_INET)); + client.connect(std::make_unique<Internet>("google.fr", 443, AF_INET)); client.send("GET / HTTP/1.0\r\n\r\n"); std::string msg = client.recv(512);