Mercurial > code
changeset 319:cba77da58496
* Finalize SocketSsl
* Add documentation
author | David Demelier <markand@malikania.fr> |
---|---|
date | Sun, 08 Mar 2015 11:04:01 +0100 |
parents | 68ae6d7dea1f |
children | 4f282297625b |
files | C++/Socket.cpp C++/Socket.h C++/SocketAddress.h C++/SocketListener.h C++/SocketSsl.cpp C++/SocketSsl.h C++/SocketTcp.cpp C++/SocketTcp.h C++/SocketUdp.h C++/Tests/Sockets/CMakeLists.txt C++/Tests/Sockets/main.cpp |
diffstat | 11 files changed, 642 insertions(+), 207 deletions(-) [+] |
line wrap: on
line diff
--- a/C++/Socket.cpp Fri Mar 06 17:30:16 2015 +0100 +++ b/C++/Socket.cpp Sun Mar 08 11:04:01 2015 +0100 @@ -97,6 +97,11 @@ * Socket class * -------------------------------------------------------- */ +#if defined(_WIN32) +std::mutex Socket::s_mutex; +std::atomic<bool> Socket::s_initialized{false}; +#endif + Socket::Socket(int domain, int type, int protocol) { #if defined(_WIN32) && !defined(SOCKET_NO_WSA_INIT) @@ -108,6 +113,8 @@ if (m_handle == Invalid) throw SocketError(SocketError::System, "socket"); + + m_state = SocketState::Opened; } void Socket::bind(const SocketAddress &address) @@ -164,48 +171,3 @@ { return s1.handle() < s2.handle(); } - - - - - - - - - - - - - -#if 0 -void SocketStandard::tryConnect(Socket &s, const SocketAddress &address, int timeout) -{ - -} - -unsigned SocketStandard::tryRecvfrom(Socket &s, void *data, unsigned len, SocketAddress &info, int timeout) -{ - SocketListener listener{{s, Read}}; - - listener.select(timeout); - - return recvfrom(s, data, len, info); -} - -unsigned SocketStandard::trySendto(Socket &s, const void *data, unsigned len, const SocketAddress &info, int timeout) -{ - SocketListener listener{{s, Write}}; - - listener.select(timeout); - - return sendto(s, data, len, info); -} - -Socket Socket::tryAccept(int timeout) -{ - SocketAddress dummy; - - return tryAccept(dummy, timeout); -} - -#endif
--- a/C++/Socket.h Fri Mar 06 17:30:16 2015 +0100 +++ b/C++/Socket.h Sun Mar 08 11:04:01 2015 +0100 @@ -16,8 +16,8 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#ifndef _SOCKET_NG_3_H_ -#define _SOCKET_NG_3_H_ +#ifndef _SOCKET_NG_H_ +#define _SOCKET_NG_H_ /** * @file Socket.h @@ -78,7 +78,6 @@ enum Code { WouldBlockRead, ///!< The operation would block for reading WouldBlockWrite, ///!< The operation would block for writing - InProgress, ///!< The operation is in progress Timeout, ///!< The action did timeout System ///!< There is a system error }; @@ -113,16 +112,31 @@ */ SocketError(Code code, std::string function, std::string error); + /** + * Get which function has triggered the error. + * + * @return the function name (e.g connect) + */ inline const std::string &function() const noexcept { return m_function; } + /** + * The error code. + * + * @return the code + */ inline Code code() const noexcept { return m_code; } + /** + * Get the error (only the error content). + * + * @return the error + */ const char *what() const noexcept { return m_error.c_str(); @@ -191,8 +205,8 @@ */ #if defined(_WIN32) private: - static std::mutex s_mutex{}; - static std::atomic<bool> s_initialized{false}; + static std::mutex s_mutex; + static std::atomic<bool> s_initialized; public: static inline void finish() noexcept @@ -238,11 +252,6 @@ Handle m_handle; SocketState m_state{SocketState::Opened}; - inline Socket(Handle handle) - : m_handle(handle) - { - } - public: /** * Get the last socket system error. The error is set from errno or from @@ -261,11 +270,23 @@ static std::string syserror(int errn); /** + * Construct a socket with an already created descriptor. + * + * @param handle the native descriptor + */ + inline Socket(Handle handle) + : m_handle(handle) + , m_state(SocketState::Opened) + { + } + + /** * Create a socket handle. * * @param domain the domain AF_* * @param type the type SOCK_* * @param protocol the protocol + * @throw SocketError on failures */ Socket(int domain, int type, int protocol); @@ -280,7 +301,7 @@ * @param level the setting level * @param name the name * @param arg the value - * @throw error::Failure on error + * @throw SocketError on error */ template <typename Argument> inline void set(int level, int name, const Argument &arg) @@ -298,7 +319,7 @@ * * @param level the setting level * @param name the name - * @throw error::Failure on error + * @throw SocketError on error */ template <typename Argument> inline Argument get(int level, int name) @@ -329,6 +350,11 @@ return m_handle; } + /** + * Get the socket state. + * + * @return + */ inline SocketState state() const noexcept { return m_state; @@ -357,8 +383,22 @@ virtual void close(); }; +/** + * Compare two sockets. + * + * @param s1 the first socket + * @param s2 the second socket + * @return true if they equals + */ bool operator==(const Socket &s1, const Socket &s2); -bool operator<(const Socket &s, const Socket &s2); +/** + * Compare two sockets, ideal for putting in a std::map. + * + * @param s1 the first socket + * @param s2 the second socket + * @return true if s1 < s2 + */ +bool operator<(const Socket &s1, const Socket &s2); -#endif // !_SOCKET_NG_3_H_ +#endif // !_SOCKET_NG_H_
--- a/C++/SocketAddress.h Fri Mar 06 17:30:16 2015 +0100 +++ b/C++/SocketAddress.h Sun Mar 08 11:04:01 2015 +0100 @@ -16,8 +16,8 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#ifndef _SOCKET_ADDRESS_H_ -#define _SOCKET_ADDRESS_H_ +#ifndef _SOCKET_ADDRESS_NG_H_ +#define _SOCKET_ADDRESS_NG_H_ #include <string> @@ -140,4 +140,4 @@ } // !address -#endif // !_SOCKET_ADDRESS_H_ +#endif // !_SOCKET_ADDRESS_NG_H_
--- a/C++/SocketListener.h Fri Mar 06 17:30:16 2015 +0100 +++ b/C++/SocketListener.h Sun Mar 08 11:04:01 2015 +0100 @@ -16,8 +16,8 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#ifndef _SOCKET_LISTENER_H_ -#define _SOCKET_LISTENER_H_ +#ifndef _SOCKET_LISTENER_NG_H_ +#define _SOCKET_LISTENER_NG_H_ #include <chrono> #include <functional> @@ -128,6 +128,11 @@ * * Convenient wrapper around the select() system call. * + * This class is implemented using a bridge pattern to allow different uses + * of listener implementation. + * + * Currently, poll and select() are available. + * * This wrappers takes abstract sockets as non-const reference but it does not * own them so you must take care that sockets are still alive until the * SocketListener is destroyed. @@ -173,54 +178,117 @@ */ SocketListener(SocketMethod method = PreferredMethod); + /** + * Create a listener from a list of sockets. + * + * @param list the list + */ SocketListener(std::initializer_list<std::pair<std::reference_wrapper<Socket>, int>> list); + /** + * Return an iterator to the beginning. + * + * @return the iterator + */ inline auto begin() noexcept { return m_map.begin(); } + /** + * Overloaded function. + * + * @return the iterator + */ inline auto begin() const noexcept { return m_map.begin(); } + /** + * Overloaded function. + * + * @return the iterator + */ inline auto cbegin() const noexcept { return m_map.cbegin(); } + /** + * Return an iterator to the end. + * + * @return the iterator + */ inline auto end() noexcept { return m_map.end(); } + /** + * Overloaded function. + * + * @return the iterator + */ inline auto end() const noexcept { return m_map.end(); } + /** + * Overloaded function. + * + * @return the iterator + */ inline auto cend() const noexcept { return m_map.cend(); } - void set(Socket &sc, int flags); + /** + * Add a socket to the listener. + * + * @param sc the socket + * @param direction (may be OR'ed) + */ + void set(Socket &sc, int direction); - void unset(Socket &sc, int flags) noexcept; + /** + * Unset a socket from the listener, only the direction is removed + * unless the two directions are requested. + * + * For example, if you added a socket for both reading and writing, + * unsetting the write direction will keep the socket for reading. + * + * @param sc the socket + * @param direction the direction (may be OR'ed) + * @see remove + */ + void unset(Socket &sc, int direction) noexcept; + /** + * Remove completely the socket from the listener. + * + * @param sc the socket + */ inline void remove(Socket &sc) noexcept { m_map.erase(sc); m_interface->remove(sc); } + /** + * Remove all sockets. + */ inline void clear() noexcept { m_map.clear(); m_interface->clear(); } + /** + * Get the number of sockets in the listener. + */ unsigned size() const noexcept { return m_map.size(); @@ -276,4 +344,4 @@ } }; -#endif // !_SOCKET_LISTENER_H_ +#endif // !_SOCKET_LISTENER_NG_H_
--- a/C++/SocketSsl.cpp Fri Mar 06 17:30:16 2015 +0100 +++ b/C++/SocketSsl.cpp Sun Mar 08 11:04:01 2015 +0100 @@ -16,14 +16,13 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ +#include "SocketAddress.h" #include "SocketListener.h" #include "SocketSsl.h" -using namespace direction; - namespace { -const SSL_METHOD *method(int mflags) +const SSL_METHOD *sslMethod(int mflags) { if (mflags & SocketSslOptions::All) return SSLv23_method(); @@ -40,69 +39,105 @@ return ERR_reason_error_string(error); } -} // !namespace +inline int toDirection(int error) +{ + if (error == SocketError::WouldBlockRead) + return SocketListener::Read; + if (error == SocketError::WouldBlockWrite) + return SocketListener::Write; -SocketSslInterface::SocketSslInterface(SSL_CTX *context, SSL *ssl, SocketSslOptions options) - : SocketStandard() - , m_context(context, SSL_CTX_free) - , m_ssl(ssl, SSL_free) - , m_options(std::move(options)) -{ + return 0; } -SocketSslInterface::SocketSslInterface(SocketSslOptions options) - : SocketStandard() - , m_options(std::move(options)) +} // !namespace + +std::mutex SocketSsl::s_sslMutex; +std::atomic<bool> SocketSsl::s_sslInitialized{false}; + +SocketSsl::SocketSsl(Socket::Handle handle, SSL_CTX *context, SSL *ssl) + : SocketAbstractTcp(handle) + , m_context(context, SSL_CTX_free) + , m_ssl(ssl, SSL_free) { +#if !defined(SOCKET_NO_SSL_INIT) + if (!s_sslInitialized) + sslInitialize(); +#endif } -void SocketSslInterface::connect(Socket &s, const SocketAddress &address) +SocketSsl::SocketSsl(int family, int protocol, SocketSslOptions options) + : SocketAbstractTcp(family, protocol) + , m_options(std::move(options)) { - SocketStandard::connect(s, address); +#if !defined(SOCKET_NO_SSL_INIT) + if (!s_sslInitialized) + sslInitialize(); +#endif +} + +void SocketSsl::connect(const SocketAddress &address) +{ + standardConnect(address); // Context first - auto context = SSL_CTX_new(method(m_options.method)); + auto context = SSL_CTX_new(sslMethod(m_options.method)); - m_context = SslContext(context, SSL_CTX_free); + m_context = ContextHandle(context, SSL_CTX_free); // SSL object then auto ssl = SSL_new(context); - m_ssl = Ssl(ssl, SSL_free); + m_ssl = SslHandle(ssl, SSL_free); - SSL_set_fd(ssl, s.handle()); + SSL_set_fd(ssl, m_handle); auto ret = SSL_connect(ssl); if (ret <= 0) { auto error = SSL_get_error(ssl, ret); - if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) - throw error::InProgress("connect", sslError(error), error, error); + if (error == SSL_ERROR_WANT_READ) { + throw SocketError(SocketError::WouldBlockRead, "connect", "Operation in progress"); + } else if (error == SSL_ERROR_WANT_WRITE) { + throw SocketError(SocketError::WouldBlockWrite, "connect", "Operation in progress"); + } else { + throw SocketError(SocketError::System, "connect", sslError(error)); + } + } - throw error::Error("accept", sslError(error), error); - } + m_state = SocketState::Connected; } -void SocketSslInterface::tryConnect(Socket &s, const SocketAddress &address, int timeout) +void SocketSsl::waitConnect(const SocketAddress &address, int timeout) { try { // Initial try - connect(s, address); - } catch (const error::InProgress &ipe) { - SocketListener listener{{s, ipe.direction()}}; + connect(address); + } catch (const SocketError &ex) { + if (ex.code() == SocketError::WouldBlockRead || ex.code() == SocketError::WouldBlockWrite) { + SocketListener listener{{*this, toDirection(ex.code())}}; - listener.select(timeout); + listener.select(timeout); - // Second try - connect(s, address); + // Second try + connect(address); + } else { + throw; + } } } -Socket SocketSslInterface::accept(Socket &s, SocketAddress &info) +SocketSsl SocketSsl::accept() { - auto client = SocketStandard::accept(s, info); - auto context = SSL_CTX_new(method(m_options.method)); + SocketAddress dummy; + + return accept(dummy); +} + +SocketSsl SocketSsl::accept(SocketAddress &info) +{ + auto client = standardAccept(info); + auto context = SSL_CTX_new(sslMethod(m_options.method)); if (m_options.certificate.size() > 0) SSL_CTX_use_certificate_file(context, m_options.certificate.c_str(), SSL_FILETYPE_PEM); @@ -110,7 +145,7 @@ SSL_CTX_use_PrivateKey_file(context, m_options.privateKey.c_str(), SSL_FILETYPE_PEM); if (m_options.verify && !SSL_CTX_check_private_key(context)) { client.close(); - throw error::Error("accept", "certificate failure", 0); + throw SocketError(SocketError::System, "accept", "certificate failure"); } // SSL object @@ -123,87 +158,71 @@ if (ret <= 0) { auto error = SSL_get_error(ssl, ret); - if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) - throw error::InProgress("accept", sslError(error), error, error); - - throw error::Error("accept", sslError(error), error); + if (error == SSL_ERROR_WANT_READ) { + throw SocketError(SocketError::WouldBlockRead, "accept", "Operation would block"); + } else if (error == SSL_ERROR_WANT_WRITE) { + throw SocketError(SocketError::WouldBlockWrite, "accept", "Operation would block"); + } else { + throw SocketError(SocketError::System, "accept", sslError(error)); + } } - return SocketSsl{client.handle(), std::make_shared<SocketSslInterface>(context, ssl)}; + return SocketSsl(client.handle(), context, ssl); } -unsigned SocketSslInterface::recv(Socket &, void *data, unsigned len) +unsigned SocketSsl::recv(void *data, unsigned len) { auto nbread = SSL_read(m_ssl.get(), data, len); if (nbread <= 0) { auto error = SSL_get_error(m_ssl.get(), nbread); - if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) - throw error::InProgress("accept", sslError(error), error, error); - - throw error::Error("recv", sslError(error), error); + if (error == SSL_ERROR_WANT_READ) { + throw SocketError(SocketError::WouldBlockRead, "recv", "Operation would block"); + } else if (error == SSL_ERROR_WANT_WRITE) { + throw SocketError(SocketError::WouldBlockWrite, "recv", "Operation would block"); + } else { + throw SocketError(SocketError::System, "recv", sslError(error)); + } } return nbread; } -unsigned SocketSslInterface::recvfrom(Socket &, void *, unsigned, SocketAddress &) +unsigned SocketSsl::waitRecv(void *data, unsigned len, int timeout) { - throw error::Error("recvfrom", "SSL socket is not UDP compatible", 0); -} - -unsigned SocketSslInterface::tryRecv(Socket &s, void *data, unsigned len, int timeout) -{ - SocketListener listener{{s, Read}}; + SocketListener listener{{*this, SocketListener::Read}}; listener.select(timeout); - return recv(s, data, len); + return recv(data, len); } -unsigned SocketSslInterface::tryRecvfrom(Socket &, void *, unsigned, SocketAddress &, int) -{ - throw error::Error("recvfrom", "SSL socket is not UDP compatible", 0); -} - -unsigned SocketSslInterface::send(Socket &, const void *data, unsigned len) +unsigned SocketSsl::send(const void *data, unsigned len) { auto nbread = SSL_write(m_ssl.get(), data, len); if (nbread <= 0) { auto error = SSL_get_error(m_ssl.get(), nbread); - if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) - throw error::InProgress("accept", sslError(error), error, error); - - throw error::Error("recv", sslError(error), error); + if (error == SSL_ERROR_WANT_READ) { + throw SocketError(SocketError::WouldBlockRead, "send", "Operation would block"); + } else if (error == SSL_ERROR_WANT_WRITE) { + throw SocketError(SocketError::WouldBlockWrite, "send", "Operation would block"); + } else { + throw SocketError(SocketError::System, "send", sslError(error)); + } } return nbread; } -unsigned SocketSslInterface::sendto(Socket &, const void *, unsigned, const SocketAddress &) +unsigned SocketSsl::waitSend(const void *data, unsigned len, int timeout) { - throw error::Error("sendto", "SSL socket is not UDP compatible", 0); -} - -unsigned SocketSslInterface::trySend(Socket &s, const void *data, unsigned len, int timeout) -{ - SocketListener listener{{s, Write}}; + SocketListener listener{{*this, SocketListener::Write}}; listener.select(timeout); - return send(s, data, len); + return send(data, len); } -unsigned SocketSslInterface::trySendto(Socket &, const void *, unsigned, const SocketAddress &, int) -{ - throw error::Error("sendto", "SSL socket is not UDP compatible", 0); -} - -SocketSsl::SocketSsl(int family, SocketSslOptions options) - : Socket(family, SOCK_STREAM, 0) -{ - m_interface = std::make_shared<SocketSslInterface>(std::move(options)); -}
--- a/C++/SocketSsl.h Fri Mar 06 17:30:16 2015 +0100 +++ b/C++/SocketSsl.h Sun Mar 08 11:04:01 2015 +0100 @@ -16,8 +16,11 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#ifndef _SOCKET_SSL_NG_3_H_ -#define _SOCKET_SSL_NG_3_H_ +#ifndef _SOCKET_SSL_NG_H_ +#define _SOCKET_SSL_NG_H_ + +#include <atomic> +#include <mutex> #include <openssl/err.h> #include <openssl/evp.h> @@ -25,22 +28,40 @@ #include "SocketTcp.h" +/** + * @class SocketSslOptions + * @brief Options for SocketSsl + */ class SocketSslOptions { public: + /** + * @brief Method + */ enum { SSLv3 = (1 << 0), TLSv1 = (1 << 1), All = (0xf) }; - int method{All}; - std::string certificate; - std::string privateKey; - bool verify{false}; + int method{All}; //!< The method + std::string certificate; //!< The certificate path + std::string privateKey; //!< The private key file + bool verify{false}; //!< Verify or not + /** + * Default constructor. + */ SocketSslOptions() = default; - SocketSslOptions(unsigned short method, std::string certificate, std::string key, bool verify = false) + /** + * More advanced constructor. + * + * @param method the method requested + * @param certificate the certificate file + * @param key the key file + * @param verify set to true to verify + */ + SocketSslOptions(int method, std::string certificate, std::string key, bool verify = false) : method(method) , certificate(std::move(certificate)) , privateKey(std::move(key)) @@ -56,41 +77,57 @@ * This class derives from SocketAbstractTcp and provide SSL support through OpenSSL. */ class SocketSsl : public SocketAbstractTcp { +public: + using ContextHandle = std::unique_ptr<SSL_CTX, void (*)(SSL_CTX *)>; + using SslHandle = std::unique_ptr<SSL, void (*)(SSL *)>; + private: -#if defined(_WIN32) && !defined(SOCKET_NO_WSA_INIT) - static std::mutex s_mutex{}; - static std::atomic<bool> s_initialized{false}; + static std::mutex s_sslMutex; + static std::atomic<bool> s_sslInitialized; + + ContextHandle m_context{nullptr, nullptr}; + SslHandle m_ssl{nullptr, nullptr}; + SocketSslOptions m_options; - static inline void terminateSsl() +public: + using SocketAbstractTcp::recv; + using SocketAbstractTcp::waitRecv; + using SocketAbstractTcp::send; + using SocketAbstractTcp::waitSend; + + /** + * Close OpenSSL library. + */ + static inline void sslTerminate() { ERR_free_strings(); } - static inline void initializeSsl() + /** + * Open SSL library. + */ + static inline void sslInitialize() { - std::lock_guard<std::mutex> lock(s_mutex); + std::lock_guard<std::mutex> lock(s_sslMutex); - if (!s_initialized) { - s_initialized = true; + if (!s_sslInitialized) { + s_sslInitialized = true; SSL_library_init(); SSL_load_error_strings(); - std::atexit(terminate); + std::atexit(sslTerminate); } } -#endif - -public: - /** - * Initialize SSL library. - */ - static void init(); /** - * Close SSL library. + * Create a SocketSsl from an already created one. + * + * @param handle the native handle + * @param context the context + * @param ssl the ssl object */ - static void finish(); + SocketSsl(Socket::Handle handle, SSL_CTX *context, SSL *ssl); /** * Open a SSL socket with the specified family. Automatically @@ -99,7 +136,80 @@ * @param family the family * @param options the options */ - SocketSsl(int family, SocketSslOptions options = {}); + SocketSsl(int family, int protocol, SocketSslOptions options = {}); + + /** + * Accept a SSL TCP socket. + * + * @return the socket + * @throw SocketError on error + */ + SocketSsl accept(); + + /** + * Accept a SSL TCP socket. + * + * @param info the client information + * @return the socket + * @throw SocketError on error + */ + SocketSsl accept(SocketAddress &info); + + /** + * Accept a SSL TCP socket. + * + * @param timeout the maximum timeout in milliseconds + * @return the socket + * @throw SocketError on error + */ + SocketSsl waitAccept(int timeout); + + /** + * Accept a SSL TCP socket. + * + * @param info the client information + * @param timeout the maximum timeout in milliseconds + * @return the socket + * @throw SocketError on error + */ + SocketSsl waitAccept(SocketAddress &info, int timeout); + + /** + * Connect to an end point. + * + * @param address the address + * @throw SocketError on error + */ + void connect(const SocketAddress &address); + + /** + * Connect to an end point. + * + * @param timeout the maximum timeout in milliseconds + * @param address the address + * @throw SocketError on error + */ + void waitConnect(const SocketAddress &address, int timeout); + + /** + * @copydoc SocketAbstractTcp::recv + */ + unsigned recv(void *data, unsigned length) override; + + /** + * @copydoc SocketAbstractTcp::recv + */ + unsigned waitRecv(void *data, unsigned length, int timeout) override; + + /** + * @copydoc SocketAbstractTcp::recv + */ + unsigned send(const void *data, unsigned length) override; + + /** + * @copydoc SocketAbstractTcp::recv + */ + unsigned waitSend(const void *data, unsigned length, int timeout) override; }; -#endif // !_SOCKET_SSL_NG_3_H_ +#endif // !_SOCKET_SSL_NG_H_
--- a/C++/SocketTcp.cpp Fri Mar 06 17:30:16 2015 +0100 +++ b/C++/SocketTcp.cpp Sun Mar 08 11:04:01 2015 +0100 @@ -30,18 +30,7 @@ throw SocketError(SocketError::System, "listen"); } -/* -------------------------------------------------------- - * SocketTcp - * -------------------------------------------------------- */ - -SocketTcp SocketTcp::accept() -{ - SocketAddress dummy; - - return accept(dummy); -} - -SocketTcp SocketTcp::accept(SocketAddress &info) +Socket SocketAbstractTcp::standardAccept(SocketAddress &info) { Socket::Handle handle; @@ -68,13 +57,12 @@ #endif } - // Usually accept works only with SOCK_STREAM info = SocketAddress(address, addrlen); - return SocketTcp(handle); + return Socket(handle); } -void SocketTcp::connect(const SocketAddress &address) +void SocketAbstractTcp::standardConnect(const SocketAddress &address) { if (m_state == SocketState::Connected) return; @@ -91,12 +79,12 @@ int error = WSAGetLastError(); if (error == WSAEWOULDBLOCK) - throw SocketError(SocketError::InProgress, "connect", error); + throw SocketError(SocketError::WouldBlockWrite, "connect", error); throw SocketError(SocketError::System, "connect", error); #else if (errno == EINPROGRESS) - throw SocketError(SocketError::InProgress, "connect"); + throw SocketError(SocketError::WouldBlockWrite, "connect"); throw SocketError(SocketError::System, "connect"); #endif @@ -105,6 +93,27 @@ m_state = SocketState::Connected; } +/* -------------------------------------------------------- + * SocketTcp + * -------------------------------------------------------- */ + +SocketTcp SocketTcp::accept() +{ + SocketAddress dummy; + + return accept(dummy); +} + +SocketTcp SocketTcp::accept(SocketAddress &info) +{ + return standardAccept(info); +} + +void SocketTcp::connect(const SocketAddress &address) +{ + return standardConnect(address); +} + void SocketTcp::waitConnect(const SocketAddress &address, int timeout) { if (m_state == SocketState::Connected) @@ -114,7 +123,7 @@ try { connect(address); } catch (const SocketError &ex) { - if (ex.code() == SocketError::InProgress) { + if (ex.code() == SocketError::WouldBlockWrite) { SocketListener listener{{*this, SocketListener::Write}}; listener.select(timeout);
--- a/C++/SocketTcp.h Fri Mar 06 17:30:16 2015 +0100 +++ b/C++/SocketTcp.h Sun Mar 08 11:04:01 2015 +0100 @@ -16,26 +16,65 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#ifndef _SOCKET_TCP_NG_3_H_ -#define _SOCKET_TCP_NG_3_H_ +#ifndef _SOCKET_TCP_NG_H_ +#define _SOCKET_TCP_NG_H_ #include "Socket.h" /** * @class SocketAbstractTcp * @brief Base class for TCP sockets + * + * This abstract class provides standard TCP functions for both clear + * and SSL implementation. + * + * It does not contain default accept() and connect() because they varies too + * much between standard and SSL. Also, the accept() function return different + * types. */ class SocketAbstractTcp : public Socket { +protected: + Socket standardAccept(SocketAddress &address); + void standardConnect(const SocketAddress &address); + public: - using Socket::Socket; + /** + * Construct an abstract socket from an already made socket. + * + * @param s the socket + */ + inline SocketAbstractTcp(Socket s) + : Socket(s) + { + } + /** + * Construct a standard TCP socket. The type is automatically + * set to SOCK_STREAM. + * + * @param domain the domain + * @param protocol the protocol + * @throw SocketError on error + */ inline SocketAbstractTcp(int domain, int protocol) : Socket(domain, SOCK_STREAM, protocol) { } + /** + * Listen for pending connection. + * + * @param max the maximum number + */ void listen(int max = 128); + /** + * Overloaded function. + * + * @param count the number of bytes to receive + * @return the string + * @throw SocketError on error + */ inline std::string recv(unsigned count) { std::string result; @@ -47,6 +86,14 @@ return result; } + /** + * Overloaded function. + * + * @param count the number of bytes to receive + * @param timeout the maximum timeout in milliseconds + * @return the string + * @throw SocketError on error + */ inline std::string waitRecv(unsigned count, int timeout) { std::string result; @@ -58,22 +105,70 @@ return result; } + /** + * Overloaded function. + * + * @param data the string to send + * @return the number of bytes sent + * @throw SocketError on error + */ inline unsigned send(const std::string &data) { return send(data.c_str(), data.size()); } + /** + * Overloaded function. + * + * @param data the string to send + * @param timeout the maximum timeout in milliseconds + * @return the number of bytes sent + * @throw SocketError on error + */ inline unsigned waitSend(const std::string &data, int timeout) { return waitSend(data.c_str(), data.size(), timeout); } + /** + * Receive data. + * + * @param data the destination buffer + * @param length the buffer length + * @return the number of bytes received + * @throw SocketError on error + */ virtual unsigned recv(void *data, unsigned length) = 0; + /** + * Receive data. + * + * @param data the destination buffer + * @param length the buffer length + * @param timeout the maximum timeout in milliseconds + * @return the number of bytes received + * @throw SocketError on error + */ virtual unsigned waitRecv(void *data, unsigned length, int timeout) = 0; + /** + * Send data. + * + * @param data the buffer + * @param length the buffer length + * @return the number of bytes sent + * @throw SocketError on error + */ virtual unsigned send(const void *data, unsigned length) = 0; + /** + * Send data. + * + * @param data the buffer + * @param length the buffer length + * @return the number of bytes sent + * @throw SocketError on error + */ virtual unsigned waitSend(const void *data, unsigned length, int timeout) = 0; }; @@ -81,7 +176,7 @@ * @class SocketTcp * @brief End-user class for TCP sockets */ -class SocketTcp final : public SocketAbstractTcp { +class SocketTcp : public SocketAbstractTcp { public: using SocketAbstractTcp::SocketAbstractTcp; using SocketAbstractTcp::recv; @@ -89,25 +184,78 @@ using SocketAbstractTcp::send; using SocketAbstractTcp::waitSend; + /** + * Accept a clear TCP socket. + * + * @return the socket + * @throw SocketError on error + */ SocketTcp accept(); + /** + * Accept a clear TCP socket. + * + * @param info the client information + * @return the socket + * @throw SocketError on error + */ SocketTcp accept(SocketAddress &info); + /** + * Accept a clear TCP socket. + * + * @param timeout the maximum timeout in milliseconds + * @return the socket + * @throw SocketError on error + */ SocketTcp waitAccept(int timeout); + /** + * Accept a clear TCP socket. + * + * @param info the client information + * @param timeout the maximum timeout in milliseconds + * @return the socket + * @throw SocketError on error + */ SocketTcp waitAccept(SocketAddress &info, int timeout); + /** + * Connect to an end point. + * + * @param address the address + * @throw SocketError on error + */ void connect(const SocketAddress &address); + /** + * Connect to an end point. + * + * @param timeout the maximum timeout in milliseconds + * @param address the address + * @throw SocketError on error + */ void waitConnect(const SocketAddress &address, int timeout); + /** + * @copydoc SocketAbstractTcp::recv + */ unsigned recv(void *data, unsigned length) override; + /** + * @copydoc SocketAbstractTcp::waitRecv + */ unsigned waitRecv(void *data, unsigned length, int timeout) override; + /** + * @copydoc SocketAbstractTcp::send + */ unsigned send(const void *data, unsigned length) override; + /** + * @copydoc SocketAbstractTcp::waitSend + */ unsigned waitSend(const void *data, unsigned length, int timeout) override; }; -#endif // !_SOCKET_TCP_NG_3_H_ +#endif // !_SOCKET_TCP_NG_H_
--- a/C++/SocketUdp.h Fri Mar 06 17:30:16 2015 +0100 +++ b/C++/SocketUdp.h Sun Mar 08 11:04:01 2015 +0100 @@ -16,22 +16,46 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#ifndef _SOCKET_UDP_NG_3_H_ -#define _SOCKET_UDP_NG_3_H_ +#ifndef _SOCKET_UDP_NG_H_ +#define _SOCKET_UDP_NG_H_ #include "Socket.h" +/** + * @class SocketUdp + * @brief UDP implementation for sockets + */ class SocketUdp : public Socket { public: - using Socket::Socket; - + /** + * Construct a UDP socket. The type is automatically set to SOCK_DGRAM. + * + * @param domain the domain (e.g AF_INET) + * @param protocol the protocol (usually 0) + */ SocketUdp(int domain, int protocol); - inline unsigned sendto(const std::string &data, const SocketAddress &info) + /** + * Overloaded function. + * + * @param data the data + * @param address the address + * @return the number of bytes sent + * @throw SocketError on error + */ + inline unsigned sendto(const std::string &data, const SocketAddress &address) { - return sendto(data.c_str(), data.length(), info); + return sendto(data.c_str(), data.length(), address); } + /** + * Overloaded function. + * + * @param data the data + * @param info the client information + * @return the string + * @throw SocketError on error + */ inline std::string recvfrom(unsigned count, SocketAddress &info) { std::string result; @@ -43,9 +67,27 @@ return result; } + /** + * Receive data from an end point. + * + * @param data the destination buffer + * @param length the buffer length + * @param info the client information + * @return the number of bytes received + * @throw SocketError on error + */ virtual unsigned recvfrom(void *data, unsigned length, SocketAddress &info); - virtual unsigned sendto(const void *data, unsigned length, const SocketAddress &info); + /** + * Send data to an end point. + * + * @param data the buffer + * @param length the buffer length + * @param address the client address + * @return the number of bytes sent + * @throw SocketError on error + */ + virtual unsigned sendto(const void *data, unsigned length, const SocketAddress &address); }; -#endif // !_SOCKET_UDP_NG_3_H_ +#endif // !_SOCKET_UDP_NG_H_
--- a/C++/Tests/Sockets/CMakeLists.txt Fri Mar 06 17:30:16 2015 +0100 +++ b/C++/Tests/Sockets/CMakeLists.txt Sun Mar 08 11:04:01 2015 +0100 @@ -30,8 +30,8 @@ ${code_SOURCE_DIR}/C++/SocketAddress.h ${code_SOURCE_DIR}/C++/SocketListener.cpp ${code_SOURCE_DIR}/C++/SocketListener.h - #${code_SOURCE_DIR}/C++/SocketSsl.cpp - #${code_SOURCE_DIR}/C++/SocketSsl.h + ${code_SOURCE_DIR}/C++/SocketSsl.cpp + ${code_SOURCE_DIR}/C++/SocketSsl.h main.cpp )
--- a/C++/Tests/Sockets/main.cpp Fri Mar 06 17:30:16 2015 +0100 +++ b/C++/Tests/Sockets/main.cpp Sun Mar 08 11:04:01 2015 +0100 @@ -27,6 +27,7 @@ #include "Socket.h" #include "SocketAddress.h" #include "SocketListener.h" +#include "SocketSsl.h" #include "SocketTcp.h" #include "SocketUdp.h" @@ -664,6 +665,42 @@ }); } +/* -------------------------------------------------------- + * Socket SSL + * -------------------------------------------------------- */ + +class SslTest : public testing::Test { +protected: + SocketSsl client{AF_INET, 0}; +}; + +TEST_F(SslTest, connect) +{ + try { + client.connect(Internet("google.fr", 443, AF_INET)); + client.close(); + } catch (const SocketError &error) { + FAIL() << error.what(); + } +} + +TEST_F(SslTest, recv) +{ + try { + client.connect(Internet("google.fr", 443, AF_INET)); + client.send("GET / HTTP/1.0\r\n\r\n"); + + std::string msg = client.recv(512); + std::string content = msg.substr(0, 18); + + ASSERT_EQ("HTTP/1.0 302 Found", content); + + client.close(); + } catch (const SocketError &error) { + FAIL() << error.what(); + } +} + int main(int argc, char **argv) { testing::InitGoogleTest(&argc, argv);