changeset 461:41d1a36cc461

Socket: - Add more examples, - Implement new accept function.
author David Demelier <markand@malikania.fr>
date Tue, 03 Nov 2015 21:48:14 +0100
parents f6b4c7491d18
children 9d53c536372e
files C++/examples/Socket/blocking-accept.cpp C++/examples/Socket/blocking-connect.cpp C++/examples/Socket/non-blocking-accept.cpp C++/examples/Socket/non-blocking-connect.cpp C++/modules/Socket/Sockets.h C++/tests/Socket/main.cpp
diffstat 6 files changed, 381 insertions(+), 106 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/C++/examples/Socket/blocking-accept.cpp	Tue Nov 03 21:48:14 2015 +0100
@@ -0,0 +1,51 @@
+/*
+ * blocking-accept.cpp -- example of blocking accept
+ *
+ * Options:
+ *   - WITH_PORT (int), the port to use (default: 16000)
+ *   - WITH_TIMEOUT (int), number of seconds before giving up (default: 60)
+ *   - WITH_SSL (bool), true to test with SSL (default: false)
+ */
+
+#include <iostream>
+
+#include "Sockets.h"
+
+#if !defined(WITH_PORT)
+#  define WITH_PORT 16000
+#endif
+
+#if !defined(WITH_TIMEOUT)
+#  define WITH_TIMEOUT 60
+#endif
+
+int main()
+{
+#if defined(WITH_SSL)
+	net::SocketTls<net::Ipv4> master;
+	net::SocketTls<net::Ipv4> client{net::Invalid};
+#else
+	net::SocketTcp<net::Ipv4> master;
+	net::SocketTcp<net::Ipv4> client{net::Invalid};
+#endif
+
+	net::Listener<> listener;
+
+	try {
+		master.bind(net::Ipv4{"*", WITH_PORT});
+		master.listen();
+
+		listener.set(master.handle(), net::FlagRead);
+		listener.wait(std::chrono::seconds(WITH_TIMEOUT));
+
+		client = master.accept(nullptr);
+	} catch (const net::Error &error) {
+		std::cerr << "error: " << error.what() << std::endl;
+		std::exit(1);
+	}
+
+	std::cout << "Client successfully accepted!" << std::endl;
+
+	return 0;	
+}
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/C++/examples/Socket/blocking-connect.cpp	Tue Nov 03 21:48:14 2015 +0100
@@ -0,0 +1,50 @@
+/*
+ * blocking-connect.cpp -- example of blocking connect
+ *
+ * Options:
+ *   - WITH_HOST (string literal), the host to try (default: "malikania.fr")
+ *   - WITH_PORT (int), the port to use (default: 80)
+ *   - WITH_TIMEOUT (int), number of seconds before giving up (default: 30)
+ *   - WITH_SSL (bool), true to test with SSL (default: false)
+ */
+
+#include <iostream>
+
+#if !defined(WITH_HOST)
+#  define WITH_HOST "malikania.fr"
+#endif
+
+#if !defined(WITH_PORT)
+#  define WITH_PORT 80
+#endif
+
+#if !defined(WITH_TIMEOUT)
+#  define WITH_TIMEOUT 30
+#endif
+
+#include "ElapsedTimer.h"
+#include "Sockets.h"
+
+int main()
+{
+#if defined(WITH_SSL)
+	net::SocketTls<net::Ipv4> socket;
+#else
+	net::SocketTcp<net::Ipv4> socket;
+#endif
+
+	try {
+		std::cout << "Trying to connect to " << WITH_HOST << ":" << WITH_PORT << std::endl;
+		socket.connect(net::Ipv4{WITH_HOST, WITH_PORT});
+	} catch (const net::Error &error) {
+		std::cerr << "error: " << error.what() << std::endl;
+		std::exit(1);
+	}
+
+	if (socket.state() == net::State::Connected) {
+		std::cout << "Successfully connected!" << std::endl;
+	}
+
+	return 0;
+}
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/C++/examples/Socket/non-blocking-accept.cpp	Tue Nov 03 21:48:14 2015 +0100
@@ -0,0 +1,79 @@
+/*
+ * non-blocking-accept.cpp -- example of total non-blocking accept
+ *
+ * Options:
+ *   - WITH_PORT (int), the port to use (default: 16000)
+ *   - WITH_TIMEOUT (int), number of seconds before giving up (default: 60)
+ *   - WITH_SSL (bool), true to test with SSL (default: false)
+ */
+
+#include <iostream>
+
+#include "ElapsedTimer.h"
+#include "Sockets.h"
+
+#if !defined(WITH_PORT)
+#  define WITH_PORT 16000
+#endif
+
+#if !defined(WITH_TIMEOUT)
+#  define WITH_TIMEOUT 60
+#endif
+
+int main()
+{
+#if defined(WITH_SSL)
+	net::SocketTls<net::Ipv4> master;
+	net::SocketTls<net::Ipv4> client{net::Invalid};
+#else
+	net::SocketTcp<net::Ipv4> master;
+	net::SocketTcp<net::Ipv4> client{net::Invalid};
+#endif
+
+	net::Listener<> listener;
+	ElapsedTimer timer;
+
+	// 1. Create the master socket for listening.
+	try {
+		master.bind(net::Ipv4{"*", WITH_PORT});
+		master.listen();
+
+		// Usually never needed, but for the example put everything as non-blocking.
+		master.setBlockMode(false);
+	} catch (const net::Error &error) {
+		std::cerr << "error: " << error.what() << std::endl;
+		std::exit(1);
+	}
+
+	while (client.state() != net::State::Accepted && timer.elapsed() < (WITH_TIMEOUT * 1000)) {
+		try {
+			if (client.state() == net::State::Closed) {
+				// 1. Wait for a pre-accept process.
+				listener.set(master.handle(), net::FlagRead);
+				listener.wait(std::chrono::seconds(WITH_TIMEOUT));
+				client = master.accept(nullptr);
+				client.setBlockMode(false);
+				listener.remove(master.handle());
+			} else {
+				// 2. Wait for the accept process to complete.
+				listener.remove(client.handle());
+
+				if (client.state() == net::State::AcceptingRead) {
+					listener.set(client.handle(), net::FlagRead);
+				} else if (client.state() == net::State::AcceptingWrite) {
+					listener.set(client.handle(), net::FlagWrite);
+				}
+
+				listener.wait(std::chrono::seconds(WITH_TIMEOUT));
+				client.accept();
+			}
+		} catch (const net::Error &error) {
+			std::cerr << error.function() << ": " << error.what() << std::endl;
+			std::exit(1);
+		}
+	}
+
+	std::cout << "Client successfully accepted!" << std::endl;
+
+	return 0;	
+}
--- a/C++/examples/Socket/non-blocking-connect.cpp	Tue Nov 03 19:57:00 2015 +0100
+++ b/C++/examples/Socket/non-blocking-connect.cpp	Tue Nov 03 21:48:14 2015 +0100
@@ -22,8 +22,8 @@
 #  define WITH_TIMEOUT 30
 #endif
 
+#include "ElapsedTimer.h"
 #include "Sockets.h"
-#include "ElapsedTimer.h"
 
 int main()
 {
@@ -64,7 +64,9 @@
 		std::exit(1);
 	}
 
-	std::cout << "Successfully connected!" << std::endl;
+	if (socket.state() == net::State::Connected) {
+		std::cout << "Successfully connected!" << std::endl;
+	}
 
 	return 0;
 }
--- a/C++/modules/Socket/Sockets.h	Tue Nov 03 19:57:00 2015 +0100
+++ b/C++/modules/Socket/Sockets.h	Tue Nov 03 21:48:14 2015 +0100
@@ -401,10 +401,11 @@
 enum class State {
 	Open,			//!< Socket is open
 	Bound,			//!< Socket is bound to an address
-	ConnectingRead,		//!< Connection is in progress but requires to bereadable
-	ConnectingWrite,	//!< Connection is in progress and requires to be writable
+	ConnectingRead,		//!< Connection is in progress but requires to be readable
+	ConnectingWrite,	//!< Connection is in progress but requires to be writable
 	Connected,		//!< Connection is complete
-	Accepting,		//!< The socket is being accepted (client)
+	AcceptingRead,		//!< The socket is being accepted but requires to be readable
+	AcceptingWrite,		//!< The socket is being accepted but requires to be writable
 	Accepted,		//!< Socket has been accepted (client)
 	Closed,			//!< The socket has been closed
 };
@@ -485,10 +486,10 @@
 	 * @param state specify the socket state
 	 * @param type the type of socket implementation
 	 */
-	explicit inline Socket(Handle handle, State state = State::Open, Type type = Type{}) noexcept
+	explicit inline Socket(Handle handle, State state = State::Closed, Type type = Type{}) noexcept
 		: m_type(std::move(type))
+		, m_state{state}
 		, m_handle{handle}
-		, m_state{state}
 	{
 	}
 
@@ -503,9 +504,9 @@
 	 * @param other the other socket
 	 */
 	inline Socket(Socket &&other) noexcept
-		: m_handle{other.m_handle}
-		, m_type{std::move(other.m_type)}
+		: m_type(std::move(other.m_type))
 		, m_state{other.m_state}
+		, m_handle{other.m_handle}
 	{
 		/* Invalidate other */
 		other.m_handle = -1;
@@ -521,6 +522,27 @@
 	}
 
 	/**
+	 * Access the implementation.
+	 *
+	 * @return the implementation
+	 * @warning use this function with care
+	 */
+	inline const Type &type() const noexcept
+	{
+		return m_type;
+	}
+
+	/**
+	 * Overloaded function.
+	 *
+	 * @return the implementation
+	 */
+	inline Type &type() noexcept
+	{
+		return m_type;
+	}
+
+	/**
 	 * Get the current socket state.
 	 *
 	 * @return the state
@@ -719,17 +741,45 @@
 	}
 
 	/**
-	 * Accept a pending connection.
+	 * Accept a pending connection. If the socket is marked non-blocking and has no pending connection then an
+	 * error is thrown with the WouldBlockRead code.
+	 *
+	 * If the client is accepted but not yet ready, its state may be set to AcceptingRead or AcceptingWrite
+	 * and the user is responsible of waiting the socket to be readable or writable and then call the
+	 * overloaded accept() (with 0 arguments) until the connection is complete.
 	 *
 	 * @param info the address where to store client's information (optional)
 	 * @return the new socket
+	 * @pre state() must be State::Bound
 	 * @throw Error on errors
 	 */
-	inline Socket<Address, Type> accept(Address *info = nullptr)
+	inline Socket<Address, Type> accept(Address *info)
 	{
-		Address dummy;
-
-		return m_type.accept(*this, info == nullptr ? dummy : *info);
+		assert(m_state == State::Bound);
+
+		sockaddr_storage storage;
+		socklen_t length = sizeof (storage);
+
+		Socket<Address, Type> sc = m_type.accept(*this, reinterpret_cast<sockaddr *>(&storage), &length);
+
+		if (info) {
+			*info = Address{&storage, length};
+		}
+
+		return sc;
+	}
+
+	/**
+	 * Continue the accept process on this client.
+	 *
+	 * @pre state() must be State::AcceptingRead or State::AcceptingWrite
+	 * @throw Error on errors
+	 */
+	inline void accept()
+	{
+		assert(m_state == State::AcceptingRead || m_state == State::AcceptingWrite);
+
+		m_type.accept(*this);
 	}
 
 	/**
@@ -1266,41 +1316,6 @@
  * @brief Clear TCP implementation.
  */
 class Tcp {
-protected:
-	/**
-	 * Standard accept.
-	 *
-	 * @param sc the socket
-	 * @param address the address destination
-	 * @param length the address initial length
-	 * @return the client handle
-	 * @throw Error on errors
-	 */
-	Handle accept(Handle sc, sockaddr *address, socklen_t *length)
-	{
-		Handle client = ::accept(sc, address, length);
-
-		if (client == Invalid) {
-#if defined(_WIN32)
-			int error = WSAGetLastError();
-
-			if (error == WSAEWOULDBLOCK) {
-				throw Error{Error::WouldBlockRead, "accept", error};
-			}
-
-			throw Error{Error::System, "accept", error};
-#else
-			if (errno == EAGAIN || errno == EWOULDBLOCK) {
-				throw Error{Error::WouldBlockRead, "accept"};
-			}
-
-			throw Error{Error::System, "accept"};
-#endif
-		}
-
-		return client;
-	}
-
 public:
 	/**
 	 * Socket type.
@@ -1382,15 +1397,39 @@
 	 * @return the socket
 	 * @throw Error on errors
 	 */
-	template <typename Address>
-	Socket<Address, Tcp> accept(Socket<Address, Tcp> &sc, Address &address)
+	template <typename Address, typename Type>
+	Socket<Address, Type> accept(Socket<Address, Type> &sc, sockaddr *address, socklen_t *length)
 	{
-		sockaddr_storage ss;
-		socklen_t length = sizeof (sockaddr_storage);
-		Handle handle = accept(sc.handle(), reinterpret_cast<sockaddr *>(&ss), &length);
-		address = Address{&ss, length};
-
-		return Socket<Address, Tcp>{handle};
+		Handle handle = ::accept(sc.handle(), address, length);
+
+		if (handle == Invalid) {
+#if defined(_WIN32)
+			int error = WSAGetLastError();
+
+			if (error == WSAEWOULDBLOCK) {
+				throw Error{Error::WouldBlockRead, "accept", error};
+			}
+
+			throw Error{Error::System, "accept", error};
+#else
+			if (errno == EAGAIN || errno == EWOULDBLOCK) {
+				throw Error{Error::WouldBlockRead, "accept"};
+			}
+
+			throw Error{Error::System, "accept"};
+#endif
+		}
+
+		return Socket<Address, Type>{handle, State::Accepted};
+	}
+
+	/**
+	 * Continue accept.
+	 */
+	template <typename Address, typename Type>
+	inline void accept(Socket<Address, Type> &) noexcept
+	{
+		/* no-op */
 	}
 
 	/**
@@ -1599,11 +1638,11 @@
 	};
 
 private:
-	using Context = std::unique_ptr<SSL_CTX, void (*)(SSL_CTX *)>;
+	using Context = std::shared_ptr<SSL_CTX>;
 	using Ssl = std::unique_ptr<SSL, void (*)(SSL *)>;
 
 	/* OpenSSL objects */
-	Context m_context{nullptr, nullptr};
+	Context m_context;
 	Ssl m_ssl{nullptr, nullptr};
 
 	/* Parameters */
@@ -1652,7 +1691,29 @@
 		} else {
 			sc.setState(State::Connected);
 		}
-
+	}
+
+	/*
+	 * Continue accept.
+	 */
+	template <typename Address, typename Type>
+	void processAccept(Socket<Address, Type> &sc)
+	{
+		int ret = SSL_accept(m_ssl.get());
+
+		if (ret <= 0) {
+			int no = SSL_get_error(m_ssl.get(), ret);
+
+			if (no == SSL_ERROR_WANT_READ) {
+				sc.setState(State::AcceptingRead);
+			} else if (no == SSL_ERROR_WANT_WRITE) {
+				sc.setState(State::AcceptingWrite);
+			} else {
+				throw Error(Error::System, "accept", error(no));
+			}
+		} else {
+			sc.setState(State::Accepted);
+		}
 	}
 
 public:
@@ -1765,39 +1826,33 @@
 	 * @return the client
 	 */
 	template <typename Address>
-	Socket<Address, Tls> accept(Socket<Address, Tls> &sc, Address &address)
+	Socket<Address, Tls> accept(Socket<Address, Tls> &sc, sockaddr *address, socklen_t *length)
 	{
-		/* 0. We need to use this context */
-		if (!m_context) {
-			throw Error{Error::Other, "accept", "socket not prepared for accept"};
-		}
-
-		/* 1. Do standard accept */
-		sockaddr_storage ss;
-		socklen_t length = sizeof (sockaddr_storage);
-		Handle handle = Tcp::accept(sc.handle(), reinterpret_cast<sockaddr *>(&ss), &length);
-		address = Address{&ss, length};
-
-		/* 2. Create OpenSSL related stuff */
-		auto ssl = Ssl{SSL_new(m_context.get()), SSL_free};
-		SSL_set_fd(ssl.get(), handle);
-
-		/* 3. Do the OpenSSL accept */
-		auto ret = SSL_accept(ssl.get());
-
-		if (ret <= 0) {
-			auto no = SSL_get_error(ssl.get(), ret);
-
-			if (no == SSL_ERROR_WANT_READ) {
-				throw Error(Error::WouldBlockRead, "accept", "Operation would block");
-			} else if (no == SSL_ERROR_WANT_WRITE) {
-				throw Error(Error::WouldBlockWrite, "accept", "Operation would block");
-			} else {
-				throw Error(Error::System, "accept", error(no));
-			}
-		}
-
-		return Socket<Address, Tls>{handle, Tls{{nullptr, nullptr}, std::move(ssl)}};
+		Socket<Address, Tls> client = Tcp::accept(sc, address, length);
+
+		/* 1. Share the context */
+		client.type().m_context = sc.type().m_context;
+	
+		/* 2. Create new SSL instance */
+		client.type().m_ssl = Ssl{SSL_new(m_context.get()), SSL_free};
+		SSL_set_fd(client.type().m_ssl.get(), client.handle());
+
+		/* 3. Try accept process */
+		processAccept(sc);
+
+		return client;
+	}
+
+	/**
+	 * Continue accept.
+	 *
+	 * @param sc the socket
+	 * @throw Error on errors
+	 */
+	template <typename Address, typename Type>
+	inline void accept(Socket<Address, Type> &sc)
+	{
+		processAccept(sc);
 	}
 
 	/**
@@ -2377,6 +2432,15 @@
 
 /* }}} */
 
+/*
+ * Callback
+ * ------------------------------------------------------------------
+ *
+ * Function owner with tests.
+ */
+
+/* {{{ Callback */
+
 /**
  * @class Callback
  * @brief Convenient signal owner that checks if the target is valid.
@@ -2405,6 +2469,17 @@
 	}
 };
 
+/* }}} */
+
+/*
+ * StreamConnection
+ * ------------------------------------------------------------------
+ *
+ * Client connected on the server side.
+ */
+
+/* {{{ */
+
 /**
  * @class StreamConnection
  * @brief Connected client on the server side.
@@ -2501,6 +2576,17 @@
 	}
 };
 
+/* }}} */
+
+/*
+ * StreamServer
+ * ------------------------------------------------------------------
+ *
+ * Convenient stream oriented server.
+ */
+
+/* {{{ */
+
 /**
  * @class StreamServer
  * @brief Convenient stream server for TCP and TLS.
@@ -2888,6 +2974,8 @@
 	}
 };
 
+/* }}} */
+
 } // !net
 
 #endif // !_SOCKETS_H_
--- a/C++/tests/Socket/main.cpp	Tue Nov 03 19:57:00 2015 +0100
+++ b/C++/tests/Socket/main.cpp	Tue Nov 03 21:48:14 2015 +0100
@@ -61,7 +61,7 @@
 	m_tserver = std::thread([this] () {
 		m_server.bind(Ipv4{"*", 16000});
 		m_server.listen();
-		m_server.accept();
+		m_server.accept(nullptr);
 		m_server.close();
 	});
 
@@ -79,7 +79,7 @@
 		m_server.bind(Ipv4{"*", 16000});
 		m_server.listen();
 
-		auto client = m_server.accept();
+		auto client = m_server.accept(nullptr);
 		auto msg = client.recv(512);
 
 		ASSERT_EQ("hello world", msg);
@@ -427,7 +427,7 @@
 		try {
 			m_listener.set(m_masterTcp.handle(), FlagRead);
 			m_listener.wait();
-			m_masterTcp.accept();
+			m_masterTcp.accept(nullptr);
 			m_masterTcp.close();
 		} catch (const std::exception &ex) {
 			FAIL() << ex.what();
@@ -448,7 +448,7 @@
 			m_listener.set(m_masterTcp.handle(), FlagRead);
 			m_listener.wait();
 
-			auto sc = m_masterTcp.accept();
+			auto sc = m_masterTcp.accept(nullptr);
 
 			ASSERT_EQ("hello", sc.recv(512));
 		} catch (const std::exception &ex) {
@@ -552,7 +552,7 @@
 TEST_F(TcpRecvTest, blockingSuccess)
 {
 	m_tserver = std::thread([this] () {
-		auto client = m_server.accept();
+		auto client = m_server.accept(nullptr);
 
 		ASSERT_EQ("hello", client.recv(32));
 	});
@@ -581,7 +581,7 @@
 public:
 	TlsRecvTest()
 		: m_server{Ipv4{}, Tls{Tls::Tlsv1, false, "Socket/test.key", "Socket/test.crt"}}
-		, m_client{Ipv4{}, Tls{Tls::Tlsv1, false, "", "Socket/test.crt"}}
+		, m_client{Ipv4{}, Tls{Tls::Tlsv1, false}}
 	{
 		m_server.set(SOL_SOCKET, SO_REUSEADDR, 1);
 		m_server.bind(Ipv4{"*", 16000});
@@ -601,25 +601,30 @@
 {
 	m_tserver = std::thread([this] () {
 		try {
-			auto client = m_server.accept();
+			auto client = m_server.accept(nullptr);
+
+			printf("state: %d\n", client.state());
 
 			ASSERT_EQ("hello", client.recv(32));
-		} catch (const std::exception &ex) {
-			FAIL() << ex.what();
+			std::this_thread::sleep_for(5s);
+		} catch (const net::Error &ex) {
+			FAIL() << ex.function() << ": " << ex.what();
 		}
 	});
 
-	std::this_thread::sleep_for(100ms);
+	std::this_thread::sleep_for(250ms);
+
+#if 0
 
 	m_tclient = std::thread([this] () {
 		try {
 			m_client.connect(Ipv4{"127.0.0.1", 16000});
-			m_client.send("hello");
-			m_client.close();
-		} catch (const std::exception &ex) {
-			FAIL() << ex.what();
+			//m_client.send("hello");
+		} catch (const net::Error &ex) {
+			FAIL() << ex.function() << ": " << ex.what();
 		}
 	});
+#endif
 }
 
 int main(int argc, char **argv)