Mercurial > code
view C++/SocketSsl.cpp @ 319:cba77da58496
* Finalize SocketSsl
* Add documentation
author | David Demelier <markand@malikania.fr> |
---|---|
date | Sun, 08 Mar 2015 11:04:01 +0100 |
parents | c9356cb38c86 |
children |
line wrap: on
line source
/*
 * SocketSsl.cpp -- OpenSSL extension for sockets
 *
 * Copyright (c) 2013, David Demelier <markand@malikania.fr>
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

#include <memory>
#include <string>

#include "SocketAddress.h"
#include "SocketListener.h"
#include "SocketSsl.h"

namespace {

/*
 * Select the OpenSSL method table matching the requested option flags.
 *
 * Falls back to SSLv23_method() when no known flag is set.
 *
 * NOTE(review): the All flag is tested first; if All is defined as the union
 * of the other flags the SSLv3 / TLSv1 branches can never be taken -- confirm
 * against the SocketSslOptions declaration.
 */
const SSL_METHOD *sslMethod(int mflags)
{
	if (mflags & SocketSslOptions::All)
		return SSLv23_method();
	if (mflags & SocketSslOptions::SSLv3)
		return SSLv3_method();
	if (mflags & SocketSslOptions::TLSv1)
		return TLSv1_method();

	return SSLv23_method();
}

/*
 * Build a human readable message for a failed OpenSSL call.
 *
 * The value given by callers is the result of SSL_get_error(), which is not
 * an OpenSSL error-queue code, so the error queue is consulted first. The
 * given value is only used as a fallback lookup.
 *
 * ERR_reason_error_string() may return a null pointer; a fixed message is
 * returned in that case instead of constructing a std::string from null.
 */
inline std::string sslError(int error)
{
	const unsigned long queued = ERR_get_error();

	if (queued != 0) {
		if (const char *reason = ERR_reason_error_string(queued))
			return reason;
	}

	if (const char *reason = ERR_reason_error_string(static_cast<unsigned long>(error)))
		return reason;

	return "unknown SSL error";
}

/*
 * Convert a would-block SocketError code to the listener direction that must
 * be waited for. Returns 0 for any other code.
 */
inline int toDirection(int error)
{
	if (error == SocketError::WouldBlockRead)
		return SocketListener::Read;
	if (error == SocketError::WouldBlockWrite)
		return SocketListener::Write;

	return 0;
}

/*
 * Translate the failure of an SSL I/O call into a SocketError and throw it.
 *
 * ssl      -- the SSL object the call was made on
 * ret      -- the value returned by the failed call (<= 0)
 * function -- name of the public function reported in the exception
 * pending  -- message used for the two would-block cases
 *
 * SSL_ERROR_WANT_READ and SSL_ERROR_WANT_WRITE are reported as
 * WouldBlockRead / WouldBlockWrite, anything else as System.
 */
[[noreturn]] void throwSslFailure(SSL *ssl, int ret, const char *function, const char *pending)
{
	const int error = SSL_get_error(ssl, ret);

	switch (error) {
	case SSL_ERROR_WANT_READ:
		throw SocketError(SocketError::WouldBlockRead, function, pending);
	case SSL_ERROR_WANT_WRITE:
		throw SocketError(SocketError::WouldBlockWrite, function, pending);
	default:
		throw SocketError(SocketError::System, function, sslError(error));
	}
}

} // !namespace

std::mutex SocketSsl::s_sslMutex;
std::atomic<bool> SocketSsl::s_sslInitialized{false};

/*
 * Wrap an already established connection. Takes ownership of context and ssl,
 * which are released with SSL_CTX_free / SSL_free.
 */
SocketSsl::SocketSsl(Socket::Handle handle, SSL_CTX *context, SSL *ssl)
	: SocketAbstractTcp(handle)
	, m_context(context, SSL_CTX_free)
	, m_ssl(ssl, SSL_free)
{
#if !defined(SOCKET_NO_SSL_INIT)
	if (!s_sslInitialized)
		sslInitialize();
#endif
}

/*
 * Create a new socket for the given family and protocol, remembering the SSL
 * options for the later connect() / accept() calls.
 */
SocketSsl::SocketSsl(int family, int protocol, SocketSslOptions options)
	: SocketAbstractTcp(family, protocol)
	, m_options(std::move(options))
{
#if !defined(SOCKET_NO_SSL_INIT)
	if (!s_sslInitialized)
		sslInitialize();
#endif
}

/*
 * Connect to the address and perform the SSL handshake.
 *
 * Throws SocketError with WouldBlockRead / WouldBlockWrite when the handshake
 * needs more I/O, or System on any other failure (including a failure to
 * allocate the SSL objects).
 */
void SocketSsl::connect(const SocketAddress &address)
{
	standardConnect(address);

	// Context first
	auto context = SSL_CTX_new(sslMethod(m_options.method));

	if (context == nullptr)
		throw SocketError(SocketError::System, "connect", sslError(SSL_ERROR_SSL));

	m_context = ContextHandle(context, SSL_CTX_free);

	// SSL object then
	auto ssl = SSL_new(context);

	if (ssl == nullptr)
		throw SocketError(SocketError::System, "connect", sslError(SSL_ERROR_SSL));

	m_ssl = SslHandle(ssl, SSL_free);

	if (SSL_set_fd(ssl, m_handle) != 1)
		throw SocketError(SocketError::System, "connect", sslError(SSL_ERROR_SSL));

	const auto ret = SSL_connect(ssl);

	if (ret <= 0)
		throwSslFailure(ssl, ret, "connect", "Operation in progress");

	m_state = SocketState::Connected;
}

/*
 * Try to connect; if the operation would block, wait up to timeout on the
 * required direction and try once more.
 *
 * NOTE(review): the second attempt calls connect() again, which runs
 * standardConnect() and recreates the SSL context and SSL object instead of
 * resuming the pending handshake -- confirm this is intended.
 */
void SocketSsl::waitConnect(const SocketAddress &address, int timeout)
{
	try {
		// Initial try
		connect(address);
	} catch (const SocketError &ex) {
		if (ex.code() == SocketError::WouldBlockRead || ex.code() == SocketError::WouldBlockWrite) {
			SocketListener listener{{*this, toDirection(ex.code())}};

			listener.select(timeout);

			// Second try
			connect(address);
		} else {
			throw;
		}
	}
}

/*
 * Accept a client, discarding its address.
 */
SocketSsl SocketSsl::accept()
{
	SocketAddress dummy;

	return accept(dummy);
}

/*
 * Accept a client, store its address in info and perform the server side SSL
 * handshake.
 *
 * On any failure the accepted client is closed and the SSL objects created
 * for it are freed before the SocketError is propagated.
 */
SocketSsl SocketSsl::accept(SocketAddress &info)
{
	auto client = standardAccept(info);

	try {
		// Owned locally until handed over to the returned SocketSsl
		std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> context(
			SSL_CTX_new(sslMethod(m_options.method)), SSL_CTX_free);

		if (!context)
			throw SocketError(SocketError::System, "accept", sslError(SSL_ERROR_SSL));

		// NOTE(review): load results are ignored, only the optional check below reports a problem
		if (m_options.certificate.size() > 0)
			SSL_CTX_use_certificate_file(context.get(), m_options.certificate.c_str(), SSL_FILETYPE_PEM);
		if (m_options.privateKey.size() > 0)
			SSL_CTX_use_PrivateKey_file(context.get(), m_options.privateKey.c_str(), SSL_FILETYPE_PEM);

		if (m_options.verify && !SSL_CTX_check_private_key(context.get()))
			throw SocketError(SocketError::System, "accept", "certificate failure");

		// SSL object
		std::unique_ptr<SSL, decltype(&SSL_free)> ssl(SSL_new(context.get()), SSL_free);

		if (!ssl)
			throw SocketError(SocketError::System, "accept", sslError(SSL_ERROR_SSL));

		if (SSL_set_fd(ssl.get(), client.handle()) != 1)
			throw SocketError(SocketError::System, "accept", sslError(SSL_ERROR_SSL));

		const auto ret = SSL_accept(ssl.get());

		if (ret <= 0)
			throwSslFailure(ssl.get(), ret, "accept", "Operation would block");

		// Ownership of both objects moves to the new socket
		return SocketSsl(client.handle(), context.release(), ssl.release());
	} catch (...) {
		// The client handle is not reachable by the caller any more
		client.close();
		throw;
	}
}

/*
 * Read up to len bytes into data and return the number of bytes read.
 *
 * Throws SocketError (WouldBlockRead, WouldBlockWrite or System) when
 * SSL_read() returns a value <= 0.
 */
unsigned SocketSsl::recv(void *data, unsigned len)
{
	const auto nbread = SSL_read(m_ssl.get(), data, len);

	if (nbread <= 0)
		throwSslFailure(m_ssl.get(), nbread, "recv", "Operation would block");

	return nbread;
}

/*
 * Wait up to timeout for the socket to be readable, then recv().
 */
unsigned SocketSsl::waitRecv(void *data, unsigned len, int timeout)
{
	SocketListener listener{{*this, SocketListener::Read}};

	listener.select(timeout);

	return recv(data, len);
}

/*
 * Write up to len bytes from data and return the number of bytes written.
 *
 * Throws SocketError (WouldBlockRead, WouldBlockWrite or System) when
 * SSL_write() returns a value <= 0.
 */
unsigned SocketSsl::send(const void *data, unsigned len)
{
	const auto nbsent = SSL_write(m_ssl.get(), data, len);

	if (nbsent <= 0)
		throwSslFailure(m_ssl.get(), nbsent, "send", "Operation would block");

	return nbsent;
}

/*
 * Wait up to timeout for the socket to be writable, then send().
 */
unsigned SocketSsl::waitSend(const void *data, unsigned len, int timeout)
{
	SocketListener listener{{*this, SocketListener::Write}};

	listener.select(timeout);

	return send(data, len);
}