changeset 457:060acd5945a3

Socket: - New net::State to describe the socket state (experimental), - The connect() function can be safely called again for non-blocking operations.
author David Demelier <markand@malikania.fr>
date Tue, 03 Nov 2015 14:07:10 +0100
parents a1cec0345d76
children db738947f359
files C++/modules/Socket/Sockets.h
diffstat 1 files changed, 190 insertions(+), 109 deletions(-) [+]
line wrap: on
line diff
--- a/C++/modules/Socket/Sockets.h	Tue Nov 03 11:34:46 2015 +0100
+++ b/C++/modules/Socket/Sockets.h	Tue Nov 03 14:07:10 2015 +0100
@@ -129,6 +129,7 @@
 #  include <openssl/ssl.h>
 #endif
 
+#include <cassert>
 #include <chrono>
 #include <cstdlib>
 #include <cstring>
@@ -385,6 +386,31 @@
 /* }}} */
 
 /*
+ * State class
+ * ------------------------------------------------------------------
+ *
+ * To facilitate higher-level stuff, the socket has a state.
+ */
+
+/* {{{ State */
+
+/**
+ * @enum State
+ * @brief Current socket state
+ */
+enum class State {
+	Open,		//!< Socket is open
+	Bound,		//!< Socket is bound to an address
+	Connecting,	//!< Connection is in progress
+	Connected,	//!< Connection is complete
+	Accepting,	//!< The socket is being accepted (client)
+	Accepted,	//!< Socket has been accepted (client)
+	Closed,		//!< The socket has been closed
+};
+
+/* }}} */
+
+/*
  * Base Socket class
  * ------------------------------------------------------------------
  *
@@ -402,6 +428,7 @@
 class Socket {
 private:
 	Type m_type;
+	State m_state{State::Closed};
 
 protected:
 	/**
@@ -411,33 +438,10 @@
 
 public:
 	/**
-	 * This tries to create a socket.
+	 * Create a socket handle.
 	 *
-	 * @param address which type of address
-	 * @param type the type instance
-	 */
-	explicit inline Socket(const Address &address = {}, Type type = Type{})
-		: Socket{address.domain(), type.type(), 0}
-	{
-		/* Some implementation requires more things */
-		m_type = std::move(type);
-		m_type.create(*this);
-	}
-
-	/**
-	 * Construct a socket with an already created descriptor.
-	 *
-	 * @param handle the native descriptor
-	 * @param type the type of socket implementation
-	 */
-	explicit inline Socket(Handle handle, Type type = Type{}) noexcept
-		: m_type(std::move(type))
-		, m_handle{handle}
-	{
-	}
-
-	/**
-	 * Create a socket handle.
+	 * This is the primary function and the only one that creates the socket handle, all other constructors
+	 * are just overloaded functions.
 	 *
 	 * @param domain the domain AF_*
 	 * @param type the type SOCK_*
@@ -458,6 +462,33 @@
 		}
 
 		m_type.create(*this);
+		m_state = State::Open;
+	}
+	/**
+	 * This tries to create a socket.
+	 *
+	 * Domain and type are determined by the Address and Type object.
+	 *
+	 * @param address which type of address
+	 * @param type the type instance
+	 */
+	explicit inline Socket(const Address &address = {}, Type type = Type{})
+		: Socket{address.domain(), type.type(), 0, std::move(type)}
+	{
+	}
+
+	/**
+	 * Construct a socket with an already created descriptor.
+	 *
+	 * @param handle the native descriptor
+	 * @param state specify the socket state
+	 * @param type the type of socket implementation
+	 */
+	explicit inline Socket(Handle handle, State state = State::Open, Type type = Type{}) noexcept
+		: m_type(std::move(type))
+		, m_handle{handle}
+		, m_state{state}
+	{
 	}
 
 	/**
@@ -473,9 +504,11 @@
 	inline Socket(Socket &&other) noexcept
 		: m_handle{other.m_handle}
 		, m_type{std::move(other.m_type)}
+		, m_state{other.m_state}
 	{
 		/* Invalidate other */
 		other.m_handle = -1;
+		other.m_state = State::Closed;
 	}
 
 	/**
@@ -487,6 +520,27 @@
 	}
 
 	/**
+	 * Get the current socket state.
+	 *
+	 * @return the state
+	 */
+	inline State state() const noexcept
+	{
+		return m_state;
+	}
+
+	/**
+	 * Change the current socket state.
+	 *
+	 * @param state the new state
+	 * @warning only implementations should call this function
+	 */
+	inline void setState(State state) noexcept
+	{
+		m_state = state;
+	}
+
+	/**
 	 * Set an option for the socket.
 	 *
 	 * @param level the setting level
@@ -580,12 +634,17 @@
 	 *
 	 * @param address the address
 	 * @param length the size
+	 * @pre state() must not be State::Bound
 	 */
-	inline void bind(const sockaddr *address, socklen_t length)
+	void bind(const sockaddr *address, socklen_t length)
 	{
+		assert(m_state != State::Bound);
+
 		if (::bind(m_handle, address, length) == Failure) {
 			throw Error{Error::System, "bind"};
 		}
+
+		m_state = State::Bound;
 	}
 
 	/**
@@ -603,23 +662,42 @@
 	 * Listen for pending connection.
 	 *
 	 * @param max the maximum number
+	 * @pre state() must be Bound
 	 */
 	inline void listen(int max = 128)
 	{
+		assert(m_state == State::Bound);
+
 		if (::listen(this->m_handle, max) == Failure) {
 			throw Error{Error::System, "listen"};
 		}
 	}
 
 	/**
-	 * Connect to an address.
+	 * Connect or continue the connection to the specified address.
 	 *
 	 * @param address the address
+	 * @param length the address length
+	 * @pre state() must be State::Connecting or State::Open
 	 * @throw Error on errors
 	 */
+	void connect(const sockaddr *address, socklen_t length)
+	{
+		assert(m_state == State::Connecting || m_state == State::Open);	
+
+		m_type.connect(*this, address, length);
+	}
+
+	/**
+	 * Overloaded function.
+	 *
+	 * Calls connect(address.address(), address.length());
+	 *
+	 * @param address the address
+	 */
 	inline void connect(const Address &address)
 	{
-		m_type.connect(*this, address);
+		connect(address.address(), address.length());
 	}
 
 	/**
@@ -641,9 +719,12 @@
 	 *
 	 * @return the address
 	 * @throw Error on failures
+	 * @pre state() must not be State::Closed
 	 */
 	Address address() const
 	{
+		assert(m_state != State::Closed);
+
 		sockaddr_storage ss;
 		socklen_t length = sizeof (sockaddr_storage);
 
@@ -775,7 +856,7 @@
 	 *
 	 * Automatically called from the destructor.
 	 */
-	virtual void close()
+	void close()
 	{
 		if (m_handle != Invalid) {
 #if defined(_WIN32)
@@ -784,6 +865,7 @@
 			::close(m_handle);
 #endif
 			m_handle = Invalid;
+			m_state = State::Closed;
 		}
 	}
 
@@ -805,9 +887,11 @@
 	{
 		m_handle = other.m_handle;
 		m_type = std::move(other.m_type);
+		m_state = other.m_state;
 
 		/* Invalidate other */
 		other.m_handle = Invalid;
+		other.m_state = State::Closed;
 
 		return *this;
 	}
@@ -1164,13 +1248,6 @@
  * @brief Clear TCP implementation.
  */
 class Tcp {
-private:
-	enum {
-		Undefined,
-		Connecting,
-		Connected
-	} m_state{Undefined};
-
 protected:
 	/**
 	 * Standard accept.
@@ -1206,42 +1283,6 @@
 		return client;
 	}
 
-	/**
-	 * Standard connect.
-	 *
-	 * @param sc the socket
-	 * @param address the address
-	 * @param length the length
-	 */
-	void connect(Handle sc, const sockaddr *address, socklen_t length)
-	{
-		if (::connect(sc, address, length) == Failure) {
-			/*
-			 * Determine if the error comes from a non-blocking connect that cannot be
-			 * accomplished yet.
-			 */
-#if defined(_WIN32)
-			int error = WSAGetLastError();
-
-			if (error == WSAEWOULDBLOCK) {
-				m_state = Connecting;
-				throw Error{Error::WouldBlockWrite, "connect", error};
-			}
-
-			throw Error{Error::System, "connect", error};
-#else
-			if (errno == EINPROGRESS) {
-				m_state = Connecting;
-				throw Error{Error::WouldBlockWrite, "connect"};
-			}
-
-			throw Error{Error::System, "connect"};
-#endif
-		} else {
-			m_state = Connected;
-		}
-	}
-
 public:
 	/**
 	 * Socket type.
@@ -1263,6 +1304,53 @@
 	}
 
 	/**
+	 * Standard connect. Wrapper of  connect(2)
+	 *
+	 * @param sc the socket
+	 * @param address the address
+	 * @param length the length
+	 */
+	template <typename Address, typename Type>
+	void connect(Socket<Address, Type> &sc, const sockaddr *address, socklen_t length)
+	{
+		if (sc.state() == State::Open) {
+			if (::connect(sc.handle(), address, length) == Failure) {
+				/*
+				 * Determine if the error comes from a non-blocking connect that cannot be
+				 * accomplished yet.
+				 */
+#if defined(_WIN32)
+				int error = WSAGetLastError();
+
+				if (error == WSAEWOULDBLOCK) {
+					sc.setState(State::Connecting);
+					throw Error{Error::WouldBlockWrite, "connect", error};
+				}
+
+				throw Error{Error::System, "connect", error};
+#else
+				if (errno == EINPROGRESS) {
+					sc.setState(State::Connecting);
+					throw Error{Error::WouldBlockWrite, "connect"};
+				}
+
+				throw Error{Error::System, "connect"};
+#endif
+			} else {
+				sc.setState(State::Connected);
+			}
+		} else if (sc.state() == State::Connecting) {
+			int error = sc.template get<int>(SOL_SOCKET, SO_ERROR);
+
+			if (error == Failure) {
+				throw Error{Error::System, "connect", error};
+			}
+
+			sc.setState(State::Connected);
+		}
+	}
+
+	/**
 	 * Accept a clear client. Wrapper of accept(2).
 	 *
 	 * @param sc the socket
@@ -1281,26 +1369,6 @@
 		return Socket<Address, Tcp>{handle};
 	}
 
-	/**
-	 * Connect to the end point. Wrapper for connect(2).
-	 *
-	 * @param sc the socket
-	 * @param address the address
-	 * @throw Error on errors
-	 */
-	template <typename Address>
-	void connect(Socket<Address, Tcp> &sc, const Address &address)
-	{
-		if (m_state == Undefined) {
-			connect(sc.handle(), address.address(), address.length());
-		} else if (m_state == Connecting) {
-			int error = sc.template get<int>(SOL_SOCKET, SO_ERROR);
-
-			if (error == Failure) {
-				throw Error{Error::System, "connect", error};
-			}
-		}
-	}
 
 	/**
 	 * Receive data. Wrapper of recv(2).
@@ -1521,6 +1589,9 @@
 	std::string m_certificate;
 	bool m_verify{false};
 
+	/* Status */
+	bool m_tcpconnected{false};
+
 	Tls(Context context, Ssl ssl)
 		: m_context{std::move(context)}
 		, m_ssl{std::move(ssl)}
@@ -1604,25 +1675,35 @@
 	 * @param address the address
 	 * @throw Error on errors
 	 */
-	template <typename Address>
-	void connect(Socket<Address, Tls> &sc, const Address &address)
+	template <typename Address, typename Type>
+	void connect(Socket<Address, Type> &sc, const sockaddr *address, socklen_t length)
 	{
 		/* 1. Standard connect */
-		Tcp::connect(sc, address);
-
-		/* 2. OpenSSL handshake */
-		auto ret = SSL_connect(m_ssl.get());
-
-		if (ret <= 0) {
-			auto no = SSL_get_error(m_ssl.get(), ret);
-
-			if (no == SSL_ERROR_WANT_READ) {
-				throw Error{Error::WouldBlockRead, "connect", "Operation in progress"};
-			} else if (no == SSL_ERROR_WANT_WRITE) {
-				throw Error{Error::WouldBlockWrite, "connect", "Operation in progress"};
-			} else {
-				throw Error{Error::System, "connect", error(no)};
+		if (!m_tcpconnected) {
+			Tcp::connect(sc, address, length);
+	
+			/* Standard connection is done */
+			m_tcpconnected = true;
+
+			/* Reset because the TCP is connected but not OpenSSL */
+			sc.setState(State::Connecting);
+		} else {
+			/* 2. OpenSSL handshake */
+			auto ret = SSL_connect(m_ssl.get());
+
+			if (ret <= 0) {
+				auto no = SSL_get_error(m_ssl.get(), ret);
+
+				if (no == SSL_ERROR_WANT_READ) {
+					throw Error{Error::WouldBlockRead, "connect", "Operation in progress"};
+				} else if (no == SSL_ERROR_WANT_WRITE) {
+					throw Error{Error::WouldBlockWrite, "connect", "Operation in progress"};
+				} else {
+					throw Error{Error::System, "connect", error(no)};
+				}
 			}
+
+			sc.setState(State::Connected);
 		}
 	}