Mercurial > code
changeset 457:060acd5945a3
Socket:
- New net::State to describe the socket state (experimental),
- The connect() function can be safely called again for non-blocking
operations.
author | David Demelier <markand@malikania.fr> |
---|---|
date | Tue, 03 Nov 2015 14:07:10 +0100 |
parents | a1cec0345d76 |
children | db738947f359 |
files | C++/modules/Socket/Sockets.h |
diffstat | 1 files changed, 190 insertions(+), 109 deletions(-) [+] |
line wrap: on
line diff
--- a/C++/modules/Socket/Sockets.h Tue Nov 03 11:34:46 2015 +0100 +++ b/C++/modules/Socket/Sockets.h Tue Nov 03 14:07:10 2015 +0100 @@ -129,6 +129,7 @@ # include <openssl/ssl.h> #endif +#include <cassert> #include <chrono> #include <cstdlib> #include <cstring> @@ -385,6 +386,31 @@ /* }}} */ /* + * State class + * ------------------------------------------------------------------ + * + * To facilitate higher-level stuff, the socket has a state. + */ + +/* {{{ State */ + +/** + * @enum State + * @brief Current socket state + */ +enum class State { + Open, //!< Socket is open + Bound, //!< Socket is bound to an address + Connecting, //!< Connection is in progress + Connected, //!< Connection is complete + Accepting, //!< The socket is being accepted (client) + Accepted, //!< Socket has been accepted (client) + Closed, //!< The socket has been closed +}; + +/* }}} */ + +/* * Base Socket class * ------------------------------------------------------------------ * @@ -402,6 +428,7 @@ class Socket { private: Type m_type; + State m_state{State::Closed}; protected: /** @@ -411,33 +438,10 @@ public: /** - * This tries to create a socket. + * Create a socket handle. * - * @param address which type of address - * @param type the type instance - */ - explicit inline Socket(const Address &address = {}, Type type = Type{}) - : Socket{address.domain(), type.type(), 0} - { - /* Some implementation requires more things */ - m_type = std::move(type); - m_type.create(*this); - } - - /** - * Construct a socket with an already created descriptor. - * - * @param handle the native descriptor - * @param type the type of socket implementation - */ - explicit inline Socket(Handle handle, Type type = Type{}) noexcept - : m_type(std::move(type)) - , m_handle{handle} - { - } - - /** - * Create a socket handle. + * This is the primary function and the only one that creates the socket handle, all other constructors + * are just overloaded functions. * * @param domain the domain AF_* * @param type the type SOCK_* @@ -458,6 +462,33 @@ } m_type.create(*this); + m_state = State::Open; + } + /** + * This tries to create a socket. + * + * Domain and type are determined by the Address and Type object. + * + * @param address which type of address + * @param type the type instance + */ + explicit inline Socket(const Address &address = {}, Type type = Type{}) + : Socket{address.domain(), type.type(), 0, std::move(type)} + { + } + + /** + * Construct a socket with an already created descriptor. + * + * @param handle the native descriptor + * @param state specify the socket state + * @param type the type of socket implementation + */ + explicit inline Socket(Handle handle, State state = State::Open, Type type = Type{}) noexcept + : m_type(std::move(type)) + , m_handle{handle} + , m_state{state} + { } /** @@ -473,9 +504,11 @@ inline Socket(Socket &&other) noexcept : m_handle{other.m_handle} , m_type{std::move(other.m_type)} + , m_state{other.m_state} { /* Invalidate other */ other.m_handle = -1; + other.m_state = State::Closed; } /** @@ -487,6 +520,27 @@ } /** + * Get the current socket state. + * + * @return the state + */ + inline State state() const noexcept + { + return m_state; + } + + /** + * Change the current socket state. + * + * @param state the new state + * @warning only implementations should call this function + */ + inline void setState(State state) noexcept + { + m_state = state; + } + + /** * Set an option for the socket. * * @param level the setting level @@ -580,12 +634,17 @@ * * @param address the address * @param length the size + * @pre state() must not be State::Bound */ - inline void bind(const sockaddr *address, socklen_t length) + void bind(const sockaddr *address, socklen_t length) { + assert(m_state != State::Bound); + if (::bind(m_handle, address, length) == Failure) { throw Error{Error::System, "bind"}; } + + m_state = State::Bound; } /** @@ -603,23 +662,42 @@ * Listen for pending connection. * * @param max the maximum number + * @pre state() must be Bound */ inline void listen(int max = 128) { + assert(m_state == State::Bound); + if (::listen(this->m_handle, max) == Failure) { throw Error{Error::System, "listen"}; } } /** - * Connect to an address. + * Connect or continue the connection to the specified address. * * @param address the address + * @param length the address length + * @pre state() must be State::Connecting or State::Open * @throw Error on errors */ + void connect(const sockaddr *address, socklen_t length) + { + assert(m_state == State::Connecting || m_state == State::Open); + + m_type.connect(*this, address, length); + } + + /** + * Overloaded function. + * + * Calls connect(address.address(), address.length()); + * + * @param address the address + */ inline void connect(const Address &address) { - m_type.connect(*this, address); + connect(address.address(), address.length()); } /** @@ -641,9 +719,12 @@ * * @return the address * @throw Error on failures + * @pre state() must not be State::Closed */ Address address() const { + assert(m_state != State::Closed); + sockaddr_storage ss; socklen_t length = sizeof (sockaddr_storage); @@ -775,7 +856,7 @@ * * Automatically called from the destructor. */ - virtual void close() + void close() { if (m_handle != Invalid) { #if defined(_WIN32) @@ -784,6 +865,7 @@ ::close(m_handle); #endif m_handle = Invalid; + m_state = State::Closed; } } @@ -805,9 +887,11 @@ { m_handle = other.m_handle; m_type = std::move(other.m_type); + m_state = other.m_state; /* Invalidate other */ other.m_handle = Invalid; + other.m_state = State::Closed; return *this; } @@ -1164,13 +1248,6 @@ * @brief Clear TCP implementation. */ class Tcp { -private: - enum { - Undefined, - Connecting, - Connected - } m_state{Undefined}; - protected: /** * Standard accept. @@ -1206,42 +1283,6 @@ return client; } - /** - * Standard connect. - * - * @param sc the socket - * @param address the address - * @param length the length - */ - void connect(Handle sc, const sockaddr *address, socklen_t length) - { - if (::connect(sc, address, length) == Failure) { - /* - * Determine if the error comes from a non-blocking connect that cannot be - * accomplished yet. - */ -#if defined(_WIN32) - int error = WSAGetLastError(); - - if (error == WSAEWOULDBLOCK) { - m_state = Connecting; - throw Error{Error::WouldBlockWrite, "connect", error}; - } - - throw Error{Error::System, "connect", error}; -#else - if (errno == EINPROGRESS) { - m_state = Connecting; - throw Error{Error::WouldBlockWrite, "connect"}; - } - - throw Error{Error::System, "connect"}; -#endif - } else { - m_state = Connected; - } - } - public: /** * Socket type. @@ -1263,6 +1304,53 @@ } /** + * Standard connect. Wrapper of connect(2) + * + * @param sc the socket + * @param address the address + * @param length the length + */ + template <typename Address, typename Type> + void connect(Socket<Address, Type> &sc, const sockaddr *address, socklen_t length) + { + if (sc.state() == State::Open) { + if (::connect(sc.handle(), address, length) == Failure) { + /* + * Determine if the error comes from a non-blocking connect that cannot be + * accomplished yet. + */ +#if defined(_WIN32) + int error = WSAGetLastError(); + + if (error == WSAEWOULDBLOCK) { + sc.setState(State::Connecting); + throw Error{Error::WouldBlockWrite, "connect", error}; + } + + throw Error{Error::System, "connect", error}; +#else + if (errno == EINPROGRESS) { + sc.setState(State::Connecting); + throw Error{Error::WouldBlockWrite, "connect"}; + } + + throw Error{Error::System, "connect"}; +#endif + } else { + sc.setState(State::Connected); + } + } else if (sc.state() == State::Connecting) { + int error = sc.template get<int>(SOL_SOCKET, SO_ERROR); + + if (error == Failure) { + throw Error{Error::System, "connect", error}; + } + + sc.setState(State::Connected); + } + } + + /** * Accept a clear client. Wrapper of accept(2). * * @param sc the socket @@ -1281,26 +1369,6 @@ return Socket<Address, Tcp>{handle}; } - /** - * Connect to the end point. Wrapper for connect(2). - * - * @param sc the socket - * @param address the address - * @throw Error on errors - */ - template <typename Address> - void connect(Socket<Address, Tcp> &sc, const Address &address) - { - if (m_state == Undefined) { - connect(sc.handle(), address.address(), address.length()); - } else if (m_state == Connecting) { - int error = sc.template get<int>(SOL_SOCKET, SO_ERROR); - - if (error == Failure) { - throw Error{Error::System, "connect", error}; - } - } - } /** * Receive data. Wrapper of recv(2). @@ -1521,6 +1589,9 @@ std::string m_certificate; bool m_verify{false}; + /* Status */ + bool m_tcpconnected{false}; + Tls(Context context, Ssl ssl) : m_context{std::move(context)} , m_ssl{std::move(ssl)} @@ -1604,25 +1675,35 @@ * @param address the address * @throw Error on errors */ - template <typename Address> - void connect(Socket<Address, Tls> &sc, const Address &address) + template <typename Address, typename Type> + void connect(Socket<Address, Type> &sc, const sockaddr *address, socklen_t length) { /* 1. Standard connect */ - Tcp::connect(sc, address); - - /* 2. OpenSSL handshake */ - auto ret = SSL_connect(m_ssl.get()); - - if (ret <= 0) { - auto no = SSL_get_error(m_ssl.get(), ret); - - if (no == SSL_ERROR_WANT_READ) { - throw Error{Error::WouldBlockRead, "connect", "Operation in progress"}; - } else if (no == SSL_ERROR_WANT_WRITE) { - throw Error{Error::WouldBlockWrite, "connect", "Operation in progress"}; - } else { - throw Error{Error::System, "connect", error(no)}; + if (!m_tcpconnected) { + Tcp::connect(sc, address, length); + + /* Standard connection is done */ + m_tcpconnected = true; + + /* Reset because the TCP is connected but not OpenSSL */ + sc.setState(State::Connecting); + } else { + /* 2. OpenSSL handshake */ + auto ret = SSL_connect(m_ssl.get()); + + if (ret <= 0) { + auto no = SSL_get_error(m_ssl.get(), ret); + + if (no == SSL_ERROR_WANT_READ) { + throw Error{Error::WouldBlockRead, "connect", "Operation in progress"}; + } else if (no == SSL_ERROR_WANT_WRITE) { + throw Error{Error::WouldBlockWrite, "connect", "Operation in progress"}; + } else { + throw Error{Error::System, "connect", error(no)}; + } } + + sc.setState(State::Connected); } }