# HG changeset patch # User David Demelier # Date 1435670221 -7200 # Node ID e292f0fa53959b8b81106cb1f6c30f8629bcff7c # Parent 131ecc6ce7b5be0d290891dae6b6058cd6a975b3 Socket: resurrect SSL diff -r 131ecc6ce7b5 -r e292f0fa5395 C++/modules/Socket/Socket.h --- a/C++/modules/Socket/Socket.h Tue Jun 30 13:53:40 2015 +0200 +++ b/C++/modules/Socket/Socket.h Tue Jun 30 15:17:01 2015 +0200 @@ -29,8 +29,6 @@ * automatically calls WSAStartup() when creating sockets. Otherwise, you will need to call * SocketAbstract::init, SocketAbstract::finish yourself. * - * - **SOCKET_NO_SSL_INIT**: (bool) Set to false if you don't want OpenSSL to be - * initialized when the first SocketSsl object is created. */ #include @@ -536,6 +534,24 @@ */ template class SocketAbstractTcp : public Socket
{ +protected: + /** + * Do standard accept. + * + * @param info the address + * @return the connected handle + * @throw SocketError on error + */ + SocketAbstract::Handle standardAccept(Address &info); + + /** + * Do standard connect. + * + * @param info the address + * @throw SocketError on error + */ + void standardConnect(const Address &info); + public: /** * Inherited constructors. @@ -621,75 +637,44 @@ } }; -/** - * @class SocketTcp - * @brief Standard implementation of TCP sockets - * - * This class is the basic implementation of TCP sockets. - */ template -class SocketTcp : public SocketAbstractTcp
{ -public: - /** - * Inherited constructors. - */ - using SocketAbstractTcp
::SocketAbstractTcp; +SocketAbstract::Handle SocketAbstractTcp
::standardAccept(Address &info) +{ + SocketAbstract::Handle handle; + + // Store the information + sockaddr_storage address; + socklen_t addrlen; + + addrlen = sizeof (sockaddr_storage); + handle = ::accept(SocketAbstract::m_handle, reinterpret_cast(&address), &addrlen); - /** - * Default constructor. - */ - SocketTcp() = default; + if (handle == SocketAbstract::Invalid) { +#if defined(_WIN32) + int error = WSAGetLastError(); + + if (error == WSAEWOULDBLOCK) { + throw SocketError{SocketError::WouldBlockRead, "accept", error}; + } - /** - * Connect to an end point. - * - * @param address the address - * @throw SocketError on error - */ - void connect(const Address &address); + throw SocketError{SocketError::System, "accept", error}; +#else + if (errno == EAGAIN || errno == EWOULDBLOCK) { + throw SocketError{SocketError::WouldBlockRead, "accept"}; + } - /** - * Overloaded function. - */ - inline SocketTcp accept() - { - Address dummy; - - return accept(dummy); + throw SocketError{SocketError::System, "accept"}; +#endif } - /** - * Accept a clear TCP socket. - * - * @param info the client information - * @return the socket - * @throw SocketError on error - */ - SocketTcp accept(Address &info); - - /** - * @copydoc SocketAbstractTcp
::recv - */ - using SocketAbstractTcp
::recv; + info = Address{address, addrlen}; - /** - * @copydoc SocketAbstractTcp
::send - */ - using SocketAbstractTcp
::send; - - /** - * @copydoc SocketAbstractTcp
::recv - */ - unsigned recv(void *data, unsigned length) override; - - /** - * @copydoc SocketAbstractTcp
::send - */ - unsigned send(const void *data, unsigned length) override; -}; + //return SocketTcp{handle}; + return handle; +} template -void SocketTcp
::connect(const Address &address) +void SocketAbstractTcp
::standardConnect(const Address &address) { if (SocketAbstract::m_state == SocketState::Connected) { return; @@ -723,6 +708,79 @@ SocketAbstract::m_state = SocketState::Connected; } +/** + * @class SocketTcp + * @brief Standard implementation of TCP sockets + * + * This class is the basic implementation of TCP sockets. + */ +template +class SocketTcp : public SocketAbstractTcp
{ +public: + /** + * Inherited constructors. + */ + using SocketAbstractTcp
::SocketAbstractTcp; + + /** + * Default constructor. + */ + SocketTcp() = default; + + /** + * Connect to an end point. + * + * @param address the address + * @throw SocketError on error + */ + inline void connect(const Address &address) + { + SocketAbstractTcp
::standardConnect(address); + } + + /** + * Overloaded function. + */ + inline SocketTcp accept() + { + Address dummy; + + return accept(dummy); + } + + /** + * Accept a clear TCP socket. + * + * @param info the client information + * @return the socket + * @throw SocketError on error + */ + inline SocketTcp accept(Address &info) + { + return SocketTcp{SocketAbstractTcp
::standardAccept(info)}; + } + + /** + * @copydoc SocketAbstractTcp
::recv + */ + using SocketAbstractTcp
::recv; + + /** + * @copydoc SocketAbstractTcp
::send + */ + using SocketAbstractTcp
::send; + + /** + * @copydoc SocketAbstractTcp
::recv + */ + unsigned recv(void *data, unsigned length) override; + + /** + * @copydoc SocketAbstractTcp
::send + */ + unsigned send(const void *data, unsigned length) override; +}; + template unsigned SocketTcp
::recv(void *data, unsigned dataLen) { @@ -779,41 +837,6 @@ return static_cast(nbsent); } -template -SocketTcp
SocketTcp
::accept(Address &info) -{ - SocketAbstract::Handle handle; - - // Store the information - sockaddr_storage address; - socklen_t addrlen; - - addrlen = sizeof (sockaddr_storage); - handle = ::accept(SocketAbstract::m_handle, reinterpret_cast(&address), &addrlen); - - if (handle == SocketAbstract::Invalid) { -#if defined(_WIN32) - int error = WSAGetLastError(); - - if (error == WSAEWOULDBLOCK) { - throw SocketError{SocketError::WouldBlockRead, "accept", error}; - } - - throw SocketError{SocketError::System, "accept", error}; -#else - if (errno == EAGAIN || errno == EWOULDBLOCK) { - throw SocketError{SocketError::WouldBlockRead, "accept"}; - } - - throw SocketError{SocketError::System, "accept"}; -#endif - } - - info = Address{address, addrlen}; - - return SocketTcp{handle}; -} - /* -------------------------------------------------------- * UDP Sockets * -------------------------------------------------------- */ diff -r 131ecc6ce7b5 -r e292f0fa5395 C++/modules/Socket/SocketSsl.cpp --- a/C++/modules/Socket/SocketSsl.cpp Tue Jun 30 13:53:40 2015 +0200 +++ b/C++/modules/Socket/SocketSsl.cpp Tue Jun 30 15:17:01 2015 +0200 @@ -16,156 +16,31 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#if 0 - #include "SocketAddress.h" #include "SocketSsl.h" -namespace { - -inline auto sslMethod(int type) noexcept -{ - if (type == SocketSslOptions::SSLv3) - return SSLv3_method(); - if (type == SocketSslOptions::TLSv1) - return TLSv1_method(); - - throw std::invalid_argument("unknown method selected"); -} - -inline std::string sslError(int error) -{ - return ERR_reason_error_string(error); -} - -} // !namespace - -std::mutex SocketAbstractSsl::s_sslMutex; -std::atomic SocketAbstractSsl::s_sslInitialized{false}; - -SocketSsl::SocketSsl(SocketTcp
sc, SSL_CTX *context, SSL *ssl) - : SocketTcp{std::move(*sc)} - , m_context{context, SSL_CTX_free} - , m_ssl{ssl, SSL_free} -{ -#if !defined(SOCKET_NO_SSL_INIT) - if (!s_sslInitialized) { - sslInitialize(); - } -#endif -} +namespace detail { -SocketSsl::SocketSsl(int family, int protocol, SocketSslOptions options) - : SocketTcp{family, protocol} - , m_context{nullptr, nullptr} - , m_ssl{nullptr, nullptr} - , m_options{std::move(options)} -{ -#if !defined(SOCKET_NO_SSL_INIT) - if (!s_sslInitialized) { - sslInitialize(); - } -#endif - m_context = ContextHandle{SSL_CTX_new(sslMethod(m_options.method)), SSL_CTX_free}; - m_ssl = SslHandle{SSL_new(m_context.get()), SSL_free}; +std::mutex mutex; +std::atomic initialized{false}; - SSL_set_fd(m_ssl.get(), m_handle); -} - -void SocketSsl::connect(const std::unique_ptr &address) +void terminate() { - // 1. Standard connect - SocketTcp::connect(address); - - // 2. OpenSSL handshake - auto ret = SSL_connect(m_ssl.get()); - - if (ret <= 0) { - auto error = SSL_get_error(m_ssl.get(), ret); - - if (error == SSL_ERROR_WANT_READ) { - throw SocketError(SocketError::WouldBlockRead, "connect", "Operation in progress"); - } else if (error == SSL_ERROR_WANT_WRITE) { - throw SocketError(SocketError::WouldBlockWrite, "connect", "Operation in progress"); - } else { - throw SocketError(SocketError::System, "connect", sslError(error)); - } - } - - m_state = SocketState::Connected; + ERR_free_strings(); } -std::unique_ptr SocketSsl::accept(std::unique_ptr &info) +void initialize() { - auto client = SocketTcp::accept(info); - auto context = SSL_CTX_new(sslMethod(m_options.method)); - - if (m_options.certificate.size() > 0) - SSL_CTX_use_certificate_file(context, m_options.certificate.c_str(), SSL_FILETYPE_PEM); - if (m_options.privateKey.size() > 0) - SSL_CTX_use_PrivateKey_file(context, m_options.privateKey.c_str(), SSL_FILETYPE_PEM); - if (m_options.verify && !SSL_CTX_check_private_key(context)) { - throw SocketError(SocketError::System, "accept", "certificate failure"); - } - - // SSL object - auto ssl = SSL_new(context); + std::lock_guard lock(mutex); - SSL_set_fd(ssl, client->handle()); - - auto ret = SSL_accept(ssl); - - if (ret <= 0) { - auto error = SSL_get_error(ssl, ret); + if (!initialized) { + initialized = true; - if (error == SSL_ERROR_WANT_READ) { - throw SocketError(SocketError::WouldBlockRead, "accept", "Operation would block"); - } else if (error == SSL_ERROR_WANT_WRITE) { - throw SocketError(SocketError::WouldBlockWrite, "accept", "Operation would block"); - } else { - throw SocketError(SocketError::System, "accept", sslError(error)); - } + SSL_library_init(); + SSL_load_error_strings(); + + atexit(terminate); } - - return std::make_unique(std::move(client), context, ssl); } -unsigned SocketSsl::recv(void *data, unsigned len) -{ - auto nbread = SSL_read(m_ssl.get(), data, len); - - if (nbread <= 0) { - auto error = SSL_get_error(m_ssl.get(), nbread); - - if (error == SSL_ERROR_WANT_READ) { - throw SocketError(SocketError::WouldBlockRead, "recv", "Operation would block"); - } else if (error == SSL_ERROR_WANT_WRITE) { - throw SocketError(SocketError::WouldBlockWrite, "recv", "Operation would block"); - } else { - throw SocketError(SocketError::System, "recv", sslError(error)); - } - } - - return nbread; -} - -unsigned SocketSsl::send(const void *data, unsigned len) -{ - auto nbread = SSL_write(m_ssl.get(), data, len); - - if (nbread <= 0) { - auto error = SSL_get_error(m_ssl.get(), nbread); - - if (error == SSL_ERROR_WANT_READ) { - throw SocketError(SocketError::WouldBlockRead, "send", "Operation would block"); - } else if (error == SSL_ERROR_WANT_WRITE) { - throw SocketError(SocketError::WouldBlockWrite, "send", "Operation would block"); - } else { - throw SocketError(SocketError::System, "send", sslError(error)); - } - } - - return nbread; -} - -#endif +} // !detail diff -r 131ecc6ce7b5 -r e292f0fa5395 C++/modules/Socket/SocketSsl.h --- a/C++/modules/Socket/SocketSsl.h Tue Jun 30 13:53:40 2015 +0200 +++ b/C++/modules/Socket/SocketSsl.h Tue Jun 30 15:17:01 2015 +0200 @@ -16,11 +16,20 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#if 0 - #ifndef _SOCKET_SSL_NG_H_ #define _SOCKET_SSL_NG_H_ +/** + * @file SocketSsl.h + * @brief Bring SSL support to socket module + * @note This code is considered experimental + * + * User may set the following variables before compiling these files: + * + * - **SOCKET_NO_SSL_INIT**: (bool) Set to false if you don't want OpenSSL to be + * initialized when the first SocketSsl object is created. + */ + #include #include #include @@ -30,16 +39,7 @@ #include #include -#include "SocketTcp.h" - -/** - * This namespace is private. - */ -namespace ssl { - - - -} // !ssl +#include "Socket.h" /** * @class SocketSslOptions @@ -57,10 +57,10 @@ TLSv1 }; - int method{TLSv1}; //!< The method - std::string certificate; //!< The certificate path - std::string privateKey; //!< The private key file - bool verify{false}; //!< Verify or not + int method{TLSv1}; //!< The method + std::string certificate; //!< The certificate path + std::string privateKey; //!< The private key file + bool verify{false}; //!< Verify or not /** * Default constructor. @@ -85,6 +85,54 @@ }; /** + * This namespace is private. + */ +namespace detail { + +/** + * Mutex for thread-safe initialization. + */ +extern std::mutex mutex; + +/** + * Boolean that marks the SSL module initialized. + */ +extern std::atomic initialized; + +/** + * Get the appropriate method. + */ +inline auto method(int type) noexcept +{ + if (type == SocketSslOptions::SSLv3) + return SSLv3_method(); + if (type == SocketSslOptions::TLSv1) + return TLSv1_method(); + + throw std::invalid_argument("unknown method selected"); +} + +/** + * Get the error + */ +inline std::string error(int error) +{ + return ERR_reason_error_string(error); +} + +/** + * Close OpenSSL library. + */ +void terminate(); + +/** + * Open SSL library. + */ +void initialize(); + +} // !ssl + +/** * @class SocketSsl * @brief SSL interface for sockets * @@ -96,53 +144,26 @@ using ContextHandle = std::unique_ptr; using SslHandle = std::unique_ptr; - static std::mutex s_sslMutex; - static std::atomic s_sslInitialized; - ContextHandle m_context{nullptr, nullptr}; SslHandle m_ssl{nullptr, nullptr}; SocketSslOptions m_options; public: /** - * Close OpenSSL library. - */ - static inline void sslTerminate() - { - ERR_free_strings(); - } - - /** - * Open SSL library. - */ - static inline void sslInitialize() - { - std::lock_guard lock(s_sslMutex); - - if (!s_sslInitialized) { - s_sslInitialized = true; - - SSL_library_init(); - SSL_load_error_strings(); - - atexit(sslTerminate); - } - } - - /** * Create a SocketSsl from an already created one. * * @param sc the standard TCP socket * @param context the context * @param ssl the ssl object */ - SocketSsl(SocketTcp
sc, SSL_CTX *context, SSL *ssl); + explicit SocketSsl(SocketTcp
sc, SSL_CTX *context, SSL *ssl); /** * Open a SSL socket with the specified family. Automatically * use SOCK_STREAM as the type. * * @param family the family + * @param protocol the protocol * @param options the options */ SocketSsl(int family, int protocol, SocketSslOptions options = {}); @@ -154,7 +175,7 @@ * @return the socket * @throw SocketError on error */ - std::unique_ptr accept(std::unique_ptr &info) override; + SocketSsl accept(Address &info); /** * Connect to an end point. @@ -162,7 +183,7 @@ * @param address the address * @throw SocketError on error */ - void connect(const std::unique_ptr &address) override; + void connect(const Address &address); /** * @copydoc SocketTcp::recv @@ -177,14 +198,148 @@ /** * Bring back send overloads. */ - using SocketTcp::send; + using SocketAbstractTcp
::send; /** * Bring back recv overloads; */ - using SocketTcp::recv; + using SocketAbstractTcp
::recv; }; -#endif // !_SOCKET_SSL_NG_H_ +template +SocketSsl
::SocketSsl(SocketTcp
sc, SSL_CTX *context, SSL *ssl) + : SocketAbstractTcp
{sc.handle()} + , m_context{context, SSL_CTX_free} + , m_ssl{ssl, SSL_free} +{ +#if !defined(SOCKET_NO_SSL_INIT) + if (!detail::initialized) { + detail::initialize(); + } +#endif + + // Invalid other + sc.m_handle = -1; + sc.m_state = SocketState::Closed; +} + +template +SocketSsl
::SocketSsl(int family, int protocol, SocketSslOptions options) + : SocketAbstractTcp
{family, protocol} + , m_context{nullptr, nullptr} + , m_ssl{nullptr, nullptr} + , m_options{std::move(options)} +{ +#if !defined(SOCKET_NO_SSL_INIT) + if (!detail::initialized) { + detail::initialize(); + } +#endif + m_context = ContextHandle{SSL_CTX_new(detail::method(m_options.method)), SSL_CTX_free}; + m_ssl = SslHandle{SSL_new(m_context.get()), SSL_free}; + + SSL_set_fd(m_ssl.get(), SocketAbstract::m_handle); +} + + +template +void SocketSsl
::connect(const Address &address) +{ + // 1. Standard connect + SocketAbstractTcp
::standardConnect(address); + + // 2. OpenSSL handshake + auto ret = SSL_connect(m_ssl.get()); + + if (ret <= 0) { + auto error = SSL_get_error(m_ssl.get(), ret); + + if (error == SSL_ERROR_WANT_READ) { + throw SocketError{SocketError::WouldBlockRead, "connect", "Operation in progress"}; + } else if (error == SSL_ERROR_WANT_WRITE) { + throw SocketError{SocketError::WouldBlockWrite, "connect", "Operation in progress"}; + } else { + throw SocketError{SocketError::System, "connect", detail::error(error)}; + } + } + + SocketAbstract::m_state = SocketState::Connected; +} + +template +SocketSsl
SocketSsl
::accept(Address &info) +{ + auto client = SocketAbstractTcp
::standardAccept(info); + auto context = SSL_CTX_new(detail::method(m_options.method)); -#endif + if (m_options.certificate.size() > 0) + SSL_CTX_use_certificate_file(context, m_options.certificate.c_str(), SSL_FILETYPE_PEM); + if (m_options.privateKey.size() > 0) + SSL_CTX_use_PrivateKey_file(context, m_options.privateKey.c_str(), SSL_FILETYPE_PEM); + if (m_options.verify && !SSL_CTX_check_private_key(context)) { + throw SocketError(SocketError::System, "accept", "certificate failure"); + } + + // SSL object + auto ssl = SSL_new(context); + + SSL_set_fd(ssl, client->handle()); + + auto ret = SSL_accept(ssl); + + if (ret <= 0) { + auto error = SSL_get_error(ssl, ret); + + if (error == SSL_ERROR_WANT_READ) { + throw SocketError(SocketError::WouldBlockRead, "accept", "Operation would block"); + } else if (error == SSL_ERROR_WANT_WRITE) { + throw SocketError(SocketError::WouldBlockWrite, "accept", "Operation would block"); + } else { + throw SocketError(SocketError::System, "accept", detail::error(error)); + } + } + + return SocketSsl(std::move(client), context, ssl); +} + +template +unsigned SocketSsl
::recv(void *data, unsigned len) +{ + auto nbread = SSL_read(m_ssl.get(), data, len); + + if (nbread <= 0) { + auto error = SSL_get_error(m_ssl.get(), nbread); + + if (error == SSL_ERROR_WANT_READ) { + throw SocketError{SocketError::WouldBlockRead, "recv", "Operation would block"}; + } else if (error == SSL_ERROR_WANT_WRITE) { + throw SocketError{SocketError::WouldBlockWrite, "recv", "Operation would block"}; + } else { + throw SocketError{SocketError::System, "recv", detail::error(error)}; + } + } + + return nbread; +} + +template +unsigned SocketSsl
::send(const void *data, unsigned len) +{ + auto nbread = SSL_write(m_ssl.get(), data, len); + + if (nbread <= 0) { + auto error = SSL_get_error(m_ssl.get(), nbread); + + if (error == SSL_ERROR_WANT_READ) { + throw SocketError{SocketError::WouldBlockRead, "send", "Operation would block"}; + } else if (error == SSL_ERROR_WANT_WRITE) { + throw SocketError{SocketError::WouldBlockWrite, "send", "Operation would block"}; + } else { + throw SocketError{SocketError::System, "send", detail::error(error)}; + } + } + + return nbread; +} + +#endif // !_SOCKET_SSL_NG_H_ diff -r 131ecc6ce7b5 -r e292f0fa5395 C++/tests/Socket/main.cpp --- a/C++/tests/Socket/main.cpp Tue Jun 30 13:53:40 2015 +0200 +++ b/C++/tests/Socket/main.cpp Tue Jun 30 15:17:01 2015 +0200 @@ -580,17 +580,15 @@ * Socket SSL * -------------------------------------------------------- */ -#if 0 - class SslTest : public testing::Test { protected: - SocketSsl client{AF_INET, 0}; + SocketSsl client{AF_INET, 0}; }; TEST_F(SslTest, connect) { try { - client.connect(std::make_unique("google.fr", 443, AF_INET)); + client.connect(Ipv4{"google.fr", 443}); client.close(); } catch (const SocketError &error) { FAIL() << error.what(); @@ -600,7 +598,7 @@ TEST_F(SslTest, recv) { try { - client.connect(std::make_unique("google.fr", 443, AF_INET)); + client.connect(Ipv4{"google.fr", 443}); client.send("GET / HTTP/1.0\r\n\r\n"); std::string msg = client.recv(512); @@ -614,8 +612,6 @@ } } -#endif - /* -------------------------------------------------------- * Operators * -------------------------------------------------------- */