Mercurial > code
changeset 508:8fc7fe1ec915
Sockets: pass Condition by reference, fix Tls
author | David Demelier <markand@malikania.fr> |
---|---|
date | Tue, 23 Feb 2016 12:04:40 +0100 |
parents | b2b2442e3291 |
children | 36e81ef34ed5 |
files | modules/sockets/CMakeLists.txt modules/sockets/sockets.h modules/sockets/test/main.cpp |
diffstat | 3 files changed, 306 insertions(+), 410 deletions(-) [+] |
line wrap: on
line diff
--- a/modules/sockets/CMakeLists.txt Mon Feb 22 20:19:54 2016 +0100 +++ b/modules/sockets/CMakeLists.txt Tue Feb 23 12:04:40 2016 +0100 @@ -26,7 +26,6 @@ code_define_module( NAME sockets INCLUDES ${OPENSSL_INCLUDE_DIR} - FLAGS -DSOCKET_NO_SSL LIBRARIES ${LIBRARIES} ${OPENSSL_LIBRARIES}
--- a/modules/sockets/sockets.h Mon Feb 22 20:19:54 2016 +0100 +++ b/modules/sockets/sockets.h Tue Feb 23 12:04:40 2016 +0100 @@ -570,6 +570,12 @@ private: Protocol m_proto; + inline void reset(Condition *condition) const noexcept + { + if (condition) + *condition = Condition::None; + } + protected: /** * The native handle. @@ -841,65 +847,6 @@ } /** - * - * @pre isOpen() - * @param address the address - * @param length the the address length - * @param condition the condition to wait (Optional) - * @throw net::Error on errors - */ - inline void connect(const sockaddr *address, socklen_t length, Condition *condition = nullptr) - { - assert(m_handle != Invalid); - - m_proto.connect(*this, address, length, condition); - - } - - /** - * Overloaded function. - * - * Effectively call connect(address.address(), address.length(), condition); - * - * @pre isOpen() - * @param address the address - * @param condition the condition to wait (Optional) - */ - inline void connect(const Address &address, Condition *condition = nullptr) - { - assert(m_handle != Invalid); - - connect(address.address(), address.length(), condition); - } - - inline void connect(Condition &condition) - { - m_proto.connect(*this, &condition); - } - - Socket<Address, Protocol> accept(Address *info, Condition *condition = nullptr) - { - assert(m_handle != Invalid); - - sockaddr_storage storage; - socklen_t length = sizeof (storage); - - Socket<Address, Protocol> sc = m_proto.accept(*this, reinterpret_cast<sockaddr *>(&storage), &length, condition); - - if (info) - *info = Address{&storage, length}; - - return sc; - } - - inline void accept(Condition &condition) - { - assert(m_handle != Invalid); - - m_proto.accept(*this, &condition); - } - - /** * Get the local name. This is a wrapper of getsockname(). * * @pre isOpen() @@ -920,108 +867,220 @@ return Address(&ss, length); } - /** - * - * - * @pre isOpen() - * @param data the destination buffer - * @param length the buffer length - * @return the number of bytes received or 0 - * @throw Error on error - */ - inline unsigned recv(void *data, unsigned length, Condition *condition = nullptr) + + + + + void connect(const sockaddr *address, socklen_t length, Condition *cond = nullptr) + { + assert(m_handle != Invalid); + + Condition dummy; + + reset(cond); + m_proto.connect(*this, address, length, cond ? *cond : dummy); + } + + inline void connect(const Address &address, Condition *cond = nullptr) + { + connect(address.address(), address.length(), cond); + } + + + + + + + + + + + + void connect(Condition *cond = nullptr) + { + assert(m_handle != Invalid); + + Condition dummy; + + reset(cond); + m_proto.connect(*this, cond ? *cond : dummy); + } + + + + + + + void accept(Socket<Address, Protocol> &client, Address *address = nullptr, Condition *cond = nullptr) { assert(m_handle != Invalid); - return m_proto.recv(*this, data, length, condition); + reset(cond); + + Condition dummy; + sockaddr_storage storage; + socklen_t length = sizeof (storage); + + reset(cond); + m_proto.accept(*this, client, reinterpret_cast<sockaddr *>(&storage), &length, cond ? *cond : dummy); + + if (address) + *address = Address(&storage, length); + } + + void accept(Condition *cond = nullptr) + { + assert(m_handle != Invalid); + + Condition dummy; + + reset(cond); + m_proto.accept(*this, cond ? *cond : dummy); } - /** - * Overloaded function. - * - * @pre isOpen() - * @param count the number of bytes to receive - * @return the string - * @throw Error on error - */ - std::string recv(unsigned count, Condition *condition = nullptr) + + + + + + + + + + + + + + + + + + + unsigned recv(void *data, unsigned length, Condition *cond = nullptr) + { + assert(m_handle != Invalid); + + Condition dummy; + + reset(cond); + + return m_proto.recv(*this, data, length, cond ? *cond : dummy); + } + + std::string recv(unsigned count, Condition *cond = nullptr) { assert(m_handle != Invalid); std::string result; result.resize(count); - auto n = recv(const_cast<char *>(result.data()), count, condition); + auto n = recv(const_cast<char *>(result.data()), count, cond); result.resize(n); return result; } - /** - * - * - * @pre isOpen() - * @param data the data buffer - * @param length the buffer length - * @return the number of bytes sent or 0 - * @throw Error on error - */ - unsigned send(const void *data, unsigned length, Condition *condition = nullptr) + + + + + + + + unsigned send(const void *data, unsigned length, Condition *cond = nullptr) + { + assert(m_handle != Invalid); + + Condition dummy; + + reset(cond); + + return m_proto.send(*this, data, length, cond ? *cond : dummy); + } + + inline unsigned send(const std::string &data, Condition *cond = nullptr) + { + return send(data.c_str(), data.length(), cond); + } + + + + + + + + + + + unsigned sendto(const void *data, unsigned length, const sockaddr *address, socklen_t addrlen, Condition *cond = nullptr) { assert(m_handle != Invalid); - return m_proto.send(*this, data, length, condition); - } - - inline unsigned send(const std::string &data, Condition *condition = nullptr) - { - return m_proto.send(*this, data.c_str(), data.size(), condition); + Condition dummy; + + reset(cond); + + return m_proto.sendto(*this, data, length, address, addrlen, cond ? *cond : dummy); } - inline unsigned sendto(const void *data, unsigned length, const sockaddr *address, socklen_t addrlen, Condition *condition = nullptr) + inline unsigned sendto(const void *data, unsigned length, const Address &address, Condition *cond = nullptr) { - return m_proto.sendto(*this, data, length, address, addrlen); + return sendto(data, length, address.address(), address.length(), cond); } - inline unsigned sendto(const void *data, unsigned length, const Address &address, Condition *condition = nullptr) + inline unsigned sendto(const std::string &data, const Address &address, Condition *cond = nullptr) { - return m_proto.sendto(data, length, address.address(), address.length(), condition); - } - - inline unsigned sendto(const std::string &data, const Address &address, Condition *condition = nullptr) - { - return m_proto.sendto(*this, data.c_str(), data.length(), address.address(), address.length(), condition); + return sendto(data.c_str(), data.length(), address.address(), address.length(), cond); } - inline unsigned recvfrom(void *data, unsigned length, sockaddr *address, socklen_t *addrlen, Condition *condition = nullptr) + + + + + + + + unsigned recvfrom(void *data, unsigned length, sockaddr *address, socklen_t *addrlen, Condition *cond = nullptr) { - return m_proto.recvfrom(*this, data, length, address, addrlen); + assert(m_handle != Invalid); + + Condition dummy; + + reset(cond); + + return m_proto.recvfrom(*this, data, length, address, addrlen, cond ? *cond : dummy); } - inline unsigned recvfrom(void *data, unsigned length, Address *info = nullptr, Condition *condition = nullptr) + unsigned recvfrom(void *data, unsigned length, Address *address = nullptr, Condition *cond = nullptr) { sockaddr_storage storage; socklen_t addrlen = sizeof (sockaddr_storage); - auto n = m_proto.recvfrom(*this, data, length, reinterpret_cast<sockaddr *>(&storage), &addrlen, condition); - - if (info && n != 0) - *info = Address{&storage, addrlen}; + auto n = recvfrom(data, length, reinterpret_cast<sockaddr *>(&storage), &addrlen, cond); + + if (address && n != 0) + *address = Address(&storage, addrlen); return n; } - std::string recvfrom(unsigned count, Address *info = nullptr, Condition *condition = nullptr) + std::string recvfrom(unsigned count, Address *info = nullptr, Condition *cond = nullptr) { std::string result; result.resize(count); - auto n = recvfrom(const_cast<char *>(result.data()), count, info, condition); + auto n = recvfrom(const_cast<char *>(result.data()), count, info, cond); result.resize(n); return result; } + + + + + + /** * Close the socket. * @@ -1505,7 +1564,7 @@ std::memset(result, 0, sizeof (result)); - if (!inet_ntop(AF_INET, &m_sin.sin_addr, result, sizeof (result))) + if (!inet_ntop(AF_INET, const_cast<in_addr *>(&m_sin.sin_addr), result, sizeof (result))) throw Error(Error::System, "inet_ntop"); return result; @@ -1563,7 +1622,7 @@ std::memset(result, 0, sizeof (result)); - if (!inet_ntop(AF_INET6, &m_sin6.sin6_addr, result, sizeof (result))) + if (!inet_ntop(AF_INET6, const_cast<in6_addr *>(&m_sin6.sin6_addr), result, sizeof (result))) throw Error(Error::System, "inet_ntop"); return result; @@ -1696,11 +1755,8 @@ } template <typename Address, typename Protocol> - void connect(Socket<Address, Protocol> &sc, const sockaddr *address, socklen_t length, Condition *condition) + void connect(Socket<Address, Protocol> &sc, const sockaddr *address, socklen_t length, Condition &cond) { - if (condition) - *condition = Condition::None; - if (::connect(sc.handle(), address, length) == Failure) { /* * Determine if the error comes from a non-blocking connect that cannot be @@ -1709,29 +1765,22 @@ #if defined(_WIN32) int error = WSAGetLastError(); - if (error == WSAEWOULDBLOCK) { - if (condition) - *condition = Condition::Writable; - } else { + if (error == WSAEWOULDBLOCK) + cond = Condition::Writable; + else throw Error(Error::System, "connect", error); - } #else - if (errno == EINPROGRESS) { - if (condition) - *condition = Condition::Writable; - } else { + if (errno == EINPROGRESS) + cond = Condition::Writable; + else throw Error(Error::System, "connect"); - } #endif } } template <typename Address, typename Protocol> - void connect(Socket<Address, Protocol> &sc, Condition *condition = nullptr) + void connect(Socket<Address, Protocol> &sc, Condition &) { - if (condition) - *condition = Condition::None; - int error = sc.template get<int>(SOL_SOCKET, SO_ERROR); if (error != 0) @@ -1739,32 +1788,25 @@ } template <typename Address, typename Protocol> - Socket<Address, Protocol> accept(Socket<Address, Protocol> &sc, sockaddr *address, socklen_t *length, Condition *condition = nullptr) + void accept(Socket<Address, Protocol> &sc, Socket<Address, Protocol> &client, sockaddr *address, socklen_t *length, Condition &) { - if (condition) - *condition = Condition::None; - Handle handle = ::accept(sc.handle(), address, length); if (handle == Invalid) - return Socket<Address, Protocol>(); - - return Socket<Address, Protocol>(handle); + client = Socket<Address, Protocol>(); + + client = Socket<Address, Protocol>(handle); } template <typename Address, typename Protocol> - inline void accept(Socket<Address, Protocol> &, Condition *condition) const noexcept + inline void accept(Socket<Address, Protocol> &, Condition &) const noexcept { - if (condition) - *condition = Condition::None; + /* no op */ } template <typename Address> - unsigned recv(Socket<Address, Tcp> &sc, void *data, unsigned length, Condition *condition) + unsigned recv(Socket<Address, Tcp> &sc, void *data, unsigned length, Condition &cond) { - if (condition) - *condition = Condition::None; - int nbread = ::recv(sc.handle(), (Arg)data, length, 0); if (nbread == Failure) { @@ -1773,18 +1815,14 @@ if (error == WSAEWOULDBLOCK) { nbread = 0; - - if (condition) - condition = Condition::Readable; + cond = Condition::Readable; } else { throw Error(Error::System, "recv", error); } #else if (errno == EAGAIN || errno == EWOULDBLOCK) { nbread = 0; - - if (condition) - *condition = Condition::Readable; + cond = Condition::Readable; } else { throw Error(Error::System, "recv"); } @@ -1795,11 +1833,8 @@ } template <typename Address> - unsigned send(Socket<Address, Tcp> &sc, const void *data, unsigned length, Condition *condition) + unsigned send(Socket<Address, Tcp> &sc, const void *data, unsigned length, Condition &cond) { - if (condition) - *condition = Condition::None; - int nbsent = ::send(sc.handle(), (ConstArg)data, length, 0); if (nbsent == Failure) { @@ -1808,18 +1843,14 @@ if (error == WSAEWOULDBLOCK) { nbsent = 0; - - if (condition) - *condition = Condition::Writable; + cond = Condition::Writable; } else { throw Error(Error::System, "send", error); } #else if (errno == EAGAIN || errno == EWOULDBLOCK) { nbsent = 0; - - if (condition) - *condition = Condition::Writable; + cond = Condition::Writable; } else { throw Error(Error::System, "send"); } @@ -1862,11 +1893,8 @@ } template <typename Address> - unsigned recvfrom(Socket<Address, Udp> &sc, void *data, unsigned length, sockaddr *address, socklen_t *addrlen, Condition *condition) + unsigned recvfrom(Socket<Address, Udp> &sc, void *data, unsigned length, sockaddr *address, socklen_t *addrlen, Condition &cond) { - if (condition) - *condition = Condition::Readable; - int nbread; nbread = ::recvfrom(sc.handle(), (Arg)data, length, 0, address, addrlen); @@ -1877,18 +1905,14 @@ if (error == WSAEWOULDBLOCK) { nbread = 0; - - if (condition) - *condition = Condition::Writable; + cond = Condition::Writable; } else { throw Error(Error::System, "recvfrom"); } #else if (errno == EAGAIN || errno == EWOULDBLOCK) { nbread = 0; - - if (condition) - *condition = Condition::Writable; + cond = Condition::Writable; } else { throw Error(Error::System, "recvfrom"); } @@ -1899,11 +1923,8 @@ } template <typename Address> - unsigned sendto(Socket<Address, Udp> &sc, const void *data, unsigned length, const sockaddr *address, socklen_t addrlen, Condition *condition) + unsigned sendto(Socket<Address, Udp> &sc, const void *data, unsigned length, const sockaddr *address, socklen_t addrlen, Condition &cond) { - if (condition) - *condition = Condition::None; - int nbsent; nbsent = ::sendto(sc.handle(), (ConstArg)data, length, 0, address, addrlen); @@ -1913,18 +1934,14 @@ if (error == WSAEWOULDBLOCK) { nbsent = 0; - - if (condition) - *condition = Condition::Writable; + cond = Condition::Writable; } else { throw Error(Error::System, "sendto", error); } #else if (errno == EAGAIN || errno == EWOULDBLOCK) { nbsent = 0; - - if (condition) - *condition = Condition::Writable; + cond = Condition::Writable; } else { throw Error(Error::System, "sendto"); } @@ -1941,19 +1958,6 @@ #if !defined(SOCKET_NO_SSL) -/** - * @class Tls - * @brief OpenSSL secure layer for TCP. - * - * **Note:** This protocol is much more difficult to use with non-blocking sockets, if some operations would block, the - * user is responsible of calling the function again by waiting for the appropriate condition. See the functions for - * more details. - * - * @see Tls::accept - * @see Tls::connect - * @see Tls::recv - * @see Tls::send - */ class Tls : private Tcp { private: using Context = std::shared_ptr<SSL_CTX>; @@ -1993,66 +1997,41 @@ return msg == nullptr ? "" : msg; } - /* - * Update the states after an uncompleted operation. - */ - template <typename Address, typename Protocol> - inline void updateStates(Socket<Address, Protocol> &sc, State state, Action action, int code) + template <typename Function> + void wrap(const std::string &func, Condition &cond, Function &&function) { - assert(code == SSL_ERROR_WANT_READ || code == SSL_ERROR_WANT_WRITE); - - sc.setState(state); - sc.setAction(action); - - if (code == SSL_ERROR_WANT_READ) { - sc.setCondition(Condition::Readable); - } else { - sc.setCondition(Condition::Writable); - } - } - - /* - * Continue the connect operation. - */ - template <typename Address, typename Protocol> - void processConnect(Socket<Address, Protocol> &sc) - { - int ret = SSL_connect(m_ssl.get()); + auto ret = function(); if (ret <= 0) { int no = SSL_get_error(m_ssl.get(), ret); - if (no == SSL_ERROR_WANT_READ || no == SSL_ERROR_WANT_WRITE) { - updateStates(sc, State::Connecting, Action::Connect, no); - } else { - sc.setState(State::Disconnected); - throw Error{Error::System, "connect", error(no)}; + switch (no) { + case SSL_ERROR_WANT_READ: + cond = Condition::Readable; + break; + case SSL_ERROR_WANT_WRITE: + cond = Condition::Writable; + break; + default: + throw Error(Error::System, func, error(no)); } - } else { - sc.setState(State::Connected); } } - /* - * Continue accept. - */ template <typename Address, typename Protocol> - void processAccept(Socket<Address, Protocol> &sc) + void doConnect(Socket<Address, Protocol> &sc, Condition &cond) { - int ret = SSL_accept(m_ssl.get()); - - if (ret <= 0) { - int no = SSL_get_error(m_ssl.get(), ret); - - if (no == SSL_ERROR_WANT_READ || no == SSL_ERROR_WANT_WRITE) { - updateStates(sc, State::Accepting, Action::Accept, no); - } else { - sc.setState(State::Disconnected); - throw Error(Error::System, "accept", error(no)); - } - } else { - sc.setState(State::Accepted); - } + wrap("connect", cond, [&] () -> int { + return SSL_connect(m_ssl.get()); + }); + } + + template <typename Address, typename Protocol> + void doAccept(Socket<Address, Protocol> &sc, Condition &cond) + { + wrap("accept", cond, [&] () -> int { + return SSL_accept(m_ssl.get()); + }); } public: @@ -2129,190 +2108,97 @@ { auto method = (m_method == ssl::Tlsv1) ? TLSv1_method() : SSLv3_method(); - m_context = {SSL_CTX_new(method), SSL_CTX_free}; - m_ssl = {SSL_new(m_context.get()), SSL_free}; + m_context = Context(SSL_CTX_new(method), SSL_CTX_free); + m_ssl = Ssl(SSL_new(m_context.get()), SSL_free); SSL_set_fd(m_ssl.get(), sc.handle()); /* Load certificates */ - if (m_certificate.size() > 0) { + if (m_certificate.size() > 0) SSL_CTX_use_certificate_file(m_context.get(), m_certificate.c_str(), SSL_FILETYPE_PEM); - } - if (m_key.size() > 0) { + if (m_key.size() > 0) SSL_CTX_use_PrivateKey_file(m_context.get(), m_key.c_str(), SSL_FILETYPE_PEM); - } - if (m_verify && !SSL_CTX_check_private_key(m_context.get())) { - throw Error{Error::System, "(openssl)", "unable to verify key"}; + if (m_verify && !SSL_CTX_check_private_key(m_context.get())) + throw Error(Error::System, "(openssl)", "unable to verify key"); + } + + template <typename Address, typename Protocol> + void connect(Socket<Address, Protocol> &sc, const sockaddr *address, socklen_t length, Condition &cond) + { + /* 1. Connect using raw TCP */ + Tcp::connect(sc, address, length, cond); + + /* 2. If the connection is complete (e.g. non-blocking), try handshake */ + if (cond == Condition::None) { + m_tcpconnected = true; + doConnect(sc, cond); } } - /** - * Connect to a secure host. - * - * If the socket is marked non-blocking and the connection cannot be established yet, then the state is set - * to State::Connecting, the condition is set to Condition::Readable or Condition::Writable, the user must - * wait for the appropriate condition before calling the overload connect which takes 0 argument. - * - * If the socket is blocking, this functions blocks until the connection is complete. - * - * If the connection was completed correctly the state is set to State::Connected. - * - * @param sc the socket - * @param address the address - * @param length the address length - * @throw net::Error on errors - */ template <typename Address, typename Protocol> - void connect(Socket<Address, Protocol> &sc, const sockaddr *address, socklen_t length) - { - /* 1. Connect using raw TCP */ - Tcp::connect(sc, address, length); - - /* 2. If the connection is complete (e.g. non-blocking), try handshake */ - if (sc.state() == State::Connected) { - m_tcpconnected = true; - processConnect(sc); - } - } - - /** - * Continue the connection. - * - * This function must be called when the socket is ready for reading or writing (check with Socket::condition), - * the state may change exactly like the initial connect call. - * - * @param sc the socket - * @throw net::Error on errors - */ - template <typename Address, typename Protocol> - void connect(Socket<Address, Protocol> &sc) + void connect(Socket<Address, Protocol> &sc, Condition &cond) { /* 1. Be sure to complete standard connect before */ if (!m_tcpconnected) { - Tcp::connect(sc); - m_tcpconnected = sc.state() == State::Connected; + Tcp::connect(sc, cond); + m_tcpconnected = cond = Condition::None; } - if (m_tcpconnected) { - processConnect(sc); + if (m_tcpconnected) + doConnect(sc, cond); + } + + template <typename Address> + void accept(Socket<Address, Tls> &sc, Socket<Address, Tls> &client, sockaddr *address, socklen_t *length, Condition &cond) + { + /* TCP sets empty client if no pending connection is available */ + Tcp::accept(sc, client, address, length, cond); + + if (client.isOpen()) { + Tls &proto = client.protocol(); + + /* 1. Share the context */ + proto.m_context = m_context; + + /* 2. Create new SSL instance */ + proto.m_ssl = Ssl(SSL_new(m_context.get()), SSL_free); + + SSL_set_fd(proto.m_ssl.get(), client.handle()); + + /* 3. Try accept process on the **new** client */ + proto.doAccept(client, cond); } } - /** - * Accept a secure client. - * - * Because SSL needs several round-trips, if the socket is marked non-blocking and the connection is not - * completed yet, a new socket is returned but with the State::Accepting state. Its condition is set to - * Condition::Readable or Condition::Writable, the user is responsible of calling accept overload which takes - * 0 arguments on the returned socket when the condition is met. - * - * If the socket is blocking, this function blocks until the client is accepted and returned. - * - * If the client is accepted correctly, its state is set to State::Accepted. This instance does not change. - * - * @param sc the socket - * @param address the address destination - * @param length the address length - * @return the client - * @throw net::Error on errors - */ - template <typename Address> - Socket<Address, Tls> accept(Socket<Address, Tls> &sc, sockaddr *address, socklen_t *length) + template <typename Address, typename Protocol> + inline void accept(Socket<Address, Protocol> &sc, Condition &cond) { - Socket<Address, Tls> client = Tcp::accept(sc, address, length); - Tls &proto = client.protocol(); - - /* 1. Share the context */ - proto.m_context = m_context; - - /* 2. Create new SSL instance */ - proto.m_ssl = Ssl{SSL_new(m_context.get()), SSL_free}; - SSL_set_fd(proto.m_ssl.get(), client.handle()); - - /* 3. Try accept process on the **new** client */ - proto.processAccept(client); - - return client; - } - - /** - * Continue accept. - * - * This function must be called on the client that is being accepted. - * - * Like accept or connect, user is responsible of calling this function until the connection is complete. - * - * @param sc the socket - * @throw net::Error on errors - */ - template <typename Address, typename Protocol> - inline void accept(Socket<Address, Protocol> &sc) - { - processAccept(sc); + doAccept(sc, cond); } - /** - * Receive some secure data. - * - * If the socket is marked non-blocking, 0 is returned if no data is available yet or if the connection - * needs renegociation. If renegociation is required case, the action is set to Action::Receive and condition - * is set to Condition::Readable or Condition::Writable. The user must wait that the condition is met and - * call this function again. - * - * @param sc the socket - * @param data the destination - * @param len the buffer length - * @return the number of bytes read - * @throw net::Error on errors - */ + template <typename Address> - unsigned recv(Socket<Address, Tls> &sc, void *data, unsigned len) + unsigned recv(Socket<Address, Tls> &sc, void *data, unsigned len, Condition &cond) { - auto nbread = SSL_read(m_ssl.get(), data, len); - - if (nbread <= 0) { - auto no = SSL_get_error(m_ssl.get(), nbread); - - if (no == SSL_ERROR_WANT_READ || no == SSL_ERROR_WANT_WRITE) { - nbread = 0; - updateStates(sc, sc.state(), Action::Receive, no); - } else { - throw Error{Error::System, "recv", error(no)}; - } - } - - return nbread; + int nbread = 0; + + wrap("recv", cond, [&] () -> int { + return (nbread = SSL_read(m_ssl.get(), data, len)); + }); + + return static_cast<unsigned>(nbread < 0 ? 0 : nbread); } - /** - * Send some data. - * - * Like recv, if the socket is marked non-blocking and no data can be sent or a negociation is required, - * condition and action are set. See receive for more details - * - * @param sc the socket - * @param data the data to send - * @param len the buffer length - * @return the number of bytes sent - * @throw net::Error on errors - */ template <typename Address> - unsigned send(Socket<Address, Tls> &sc, const void *data, unsigned len) + unsigned send(Socket<Address, Tls> &sc, const void *data, unsigned len, Condition &cond) { - auto nbsent = SSL_write(m_ssl.get(), data, len); - - if (nbsent <= 0) { - auto no = SSL_get_error(m_ssl.get(), nbsent); - - if (no == SSL_ERROR_WANT_READ || no == SSL_ERROR_WANT_WRITE) { - nbsent = 0; - updateStates(sc, sc.state(), Action::Send, no); - } else { - throw Error{Error::System, "send", error(no)}; - } - } - - return nbsent; + int nbsent = 0; + + wrap("send", cond, [&] () -> int { + return (nbsent = SSL_write(m_ssl.get(), data, len)); + }); + + return static_cast<unsigned>(nbsent < 0 ? 0 : nbsent); } }; @@ -2387,7 +2273,7 @@ /** * Helper to create OpenSSL TCP/Ip sockets. */ -using SocketTlsIp = Socket<address::Ip, protocol::Tls>; +using SocketTlsIp = Socket<address::Ipv4, protocol::Tls>; #endif // !SOCKET_NO_SSL
--- a/modules/sockets/test/main.cpp Mon Feb 22 20:19:54 2016 +0100 +++ b/modules/sockets/test/main.cpp Tue Feb 23 12:04:40 2016 +0100 @@ -113,9 +113,11 @@ TEST_F(TcpServerTest, connect) { m_tserver = std::thread([this] () { + SocketTcp<Ipv4> sc; + m_server.bind(Ipv4("*", 16000)); m_server.listen(); - m_server.accept(nullptr); + m_server.accept(sc); m_server.close(); }); @@ -133,7 +135,10 @@ m_server.bind(Ipv4("*", 16000)); m_server.listen(); - auto client = m_server.accept(nullptr); + SocketTcp<Ipv4> client; + + m_server.accept(client); + auto msg = client.recv(512); ASSERT_EQ("hello world", msg); @@ -500,7 +505,9 @@ m_listener.set(m_masterTcp.handle(), Condition::Readable); m_listener.wait(); - auto sc = m_masterTcp.accept(nullptr); + SocketTcp<Ipv4> sc; + + m_masterTcp.accept(sc); ASSERT_EQ("hello", sc.recv(512)); } catch (const std::exception &ex) { @@ -604,7 +611,9 @@ TEST_F(TcpRecvTest, blockingSuccess) { m_tserver = std::thread([this] () { - auto client = m_server.accept(nullptr); + SocketTcp<Ipv4> client; + + m_server.accept(client); ASSERT_EQ("hello", client.recv(32)); }); @@ -660,7 +669,9 @@ { m_tserver = std::thread([this] () { try { - auto client = m_server.accept(nullptr); + SocketTls<Ipv4> client; + + m_server.accept(client); ASSERT_EQ("hello", client.recv(32)); } catch (const net::Error &ex) {