changeset 270:46ccfbee84d9

Socket: * Add macro SOCKET_LISTENER_HAVE_POLL to enable / disable poll(2) method if it is not available. Fallback safely to select(2) if needed. * Add new function selectMultiple() which returns a std::vector of all ready sockets. * Add tests for multiple selection.
author David Demelier <markand@malikania.fr>
date Tue, 21 Oct 2014 10:13:33 +0200
parents 44dcc198bf0c
children f7000cc599d0
files C++/SocketListener.cpp C++/SocketListener.h C++/Tests/Sockets/main.cpp
diffstat 3 files changed, 202 insertions(+), 14 deletions(-) [+]
line wrap: on
line diff
--- a/C++/SocketListener.cpp	Fri Oct 17 14:20:00 2014 +0200
+++ b/C++/SocketListener.cpp	Tue Oct 21 10:13:33 2014 +0200
@@ -53,6 +53,7 @@
 	void clear() override;
 	unsigned size() const override;
 	SocketStatus select(int ms) override;
+	std::vector<SocketStatus> selectMultiple(int ms) override;
 };
 
 SelectMethod::SelectMethod()
@@ -129,6 +130,16 @@
 
 SocketStatus SelectMethod::select(int ms)
 {
+	auto result = selectMultiple(ms);
+
+	if (result.size() == 0)
+		throw error::Failure("No socket found");
+
+	return result[0];
+}
+
+std::vector<SocketStatus> SelectMethod::selectMultiple(int ms)
+{
 	timeval maxwait, *towait;
 
 	maxwait.tv_sec = 0;
@@ -143,20 +154,23 @@
 	if (error == 0)
 		throw error::Timeout();
 
+	std::vector<SocketStatus> sockets;
 	for (auto &c : m_lookup)
 		if (FD_ISSET(c.first, &m_readset))
-			return { std::get<0>(c.second), SocketDirection::Read };
+			sockets.push_back({ std::get<0>(c.second), SocketDirection::Read });
 	for (auto &c : m_lookup)
 		if (FD_ISSET(c.first, &m_writeset))
-			return { std::get<0>(c.second), SocketDirection::Write };
+			sockets.push_back({ std::get<0>(c.second), SocketDirection::Write });
 
-	throw error::Failure("No socket found");
+	return sockets;
 }
 
 /* --------------------------------------------------------
  * Poll implementation
  * -------------------------------------------------------- */
 
+#if defined(SOCKET_LISTENER_HAVE_POLL)
+
 #if defined(_WIN32)
 #  include <Winsock2.h>
 #  define poll WSAPoll
@@ -202,6 +216,7 @@
 	void clear() override;
 	unsigned size() const override;
 	SocketStatus select(int ms) override;
+	std::vector<SocketStatus> selectMultiple(int ms) override;
 };
 
 void PollMethod::add(Socket &&s, SocketDirection direction)
@@ -259,16 +274,37 @@
 	throw error::Failure("no socket found");
 }
 
+std::vector<SocketStatus> PollMethod::selectMultiple(int ms)
+{
+	auto result = poll(m_fds.data(), m_fds.size(), ms);
+	if (result == 0)
+		throw error::Timeout();
+	if (result < 0)
+		throw error::Failure(Socket::syserror());
+
+	std::vector<SocketStatus> sockets;
+	for (auto &fd : m_fds)
+		if (fd.revents != 0)
+			sockets.push_back({ m_lookup[fd.fd], todirection(fd.revents) });
+
+	return sockets;
+}
+
 } // !namespace
 
+#endif // !_SOCKET_LISTENER_HAVE_POLL
+
 /* --------------------------------------------------------
  * Socket listener
  * -------------------------------------------------------- */
 
 SocketListener::SocketListener(SocketMethod method)
 {
+#if defined(SOCKET_LISTENER_HAVE_POLL)
 	if (method == SocketMethod::Poll)
 		m_interface = std::make_unique<PollMethod>();
 	else
+#endif
 		m_interface = std::make_unique<SelectMethod>();
+
 }
--- a/C++/SocketListener.h	Fri Oct 17 14:20:00 2014 +0200
+++ b/C++/SocketListener.h	Tue Oct 21 10:13:33 2014 +0200
@@ -21,9 +21,18 @@
 
 #include <chrono>
 #include <functional>
+#include <vector>
 
 #include "Socket.h"
 
+#if defined(_WIN32)
+#  if _WIN32_WINNT >= 0x0600
+#    define SOCKET_LISTENER_HAVE_POLL
+#  endif
+#else
+#  define SOCKET_LISTENER_HAVE_POLL
+#endif
+
 /**
  * @enum SocketDirection
  * @brief The SocketDirection enum
@@ -113,6 +122,12 @@
 	 */
 	using MapFunc	= std::function<void (Socket &, SocketDirection)>;
 
+#if defined(SOCKET_LISTENER_HAVE_POLL)
+	static constexpr const SocketMethod PreferredMethod = SocketMethod::Poll;
+#else
+	static constexpr const SocketMethod PreferredMethod = SocketMethod::Select;
+#endif
+
 	/**
 	 * @class Interface
 	 * @brief Implement the polling method
@@ -158,7 +173,7 @@
 		virtual unsigned size() const = 0;
 
 		/**
-		 * Select a socket.
+		 * Select one socket.
 		 *
 		 * @param ms the number of milliseconds to wait, -1 means forever
 		 * @return the socket status
@@ -166,20 +181,20 @@
 		 * @throw error::Timeout on timeout
 		 */
 		virtual SocketStatus select(int ms) = 0;
+
+		/**
+		 * Select many sockets.
+		 *
+		 * @param ms the number of milliseconds to wait, -1 means forever
+		 * @return a vector of ready sockets
+		 * @throw error::Failure on failure
+		 * @throw error::Timeout on timeout
+		 */
+		virtual std::vector<SocketStatus> selectMultiple(int ms) = 0;
 	};
 
 	std::unique_ptr<Interface> m_interface;
 
-#if defined(_WIN32)
-#  if _WIN32_WINNT >= 0x0600
-	static constexpr const SocketMethod PreferredMethod = SocketMethod::Poll;
-#  else
-	static constexpr const SocketMethod PreferredMethod = SocketMethod::Select;
-#  endif
-#else
-	static constexpr const SocketMethod PreferredMethod = SocketMethod::Poll;
-#endif
-
 public:
 	/**
 	 * Create a socket listener.
@@ -231,6 +246,7 @@
 	/**
 	 * Select a socket. Waits for a specific amount of time specified as the duration.
 	 *
+	 * @param duration the duration
 	 * @return the socket ready
 	 * @throw SocketError on error
 	 * @throw SocketTimeout on timeout
@@ -256,6 +272,34 @@
 	}
 
 	/**
+	 * Select multiple sockets.
+	 *
+	 * @param duration the duration
+	 * @return the socket ready
+	 * @throw SocketError on error
+	 * @throw SocketTimeout on timeout
+	 */
+	template <typename Rep, typename Ratio>
+	inline std::vector<SocketStatus> selectMultiple(const std::chrono::duration<Rep, Ratio> &duration)
+	{
+		auto cvt = std::chrono::duration_cast<std::chrono::milliseconds>(duration);
+
+		return m_interface->selectMultiple(cvt.count());
+	}
+
+	/**
+	 * Overload that waits indefinitely.
+	 *
+	 * @return the socket ready
+	 * @throw SocketError on error
+	 * @throw SocketTimeout on timeout
+	 */
+	inline std::vector<SocketStatus> selectMultiple()
+	{
+		return m_interface->selectMultiple(-1);
+	}
+
+	/**
 	 * List every socket in the listener.
 	 *
 	 * @param func the function to call
--- a/C++/Tests/Sockets/main.cpp	Fri Oct 17 14:20:00 2014 +0200
+++ b/C++/Tests/Sockets/main.cpp	Tue Oct 21 10:13:33 2014 +0200
@@ -362,6 +362,114 @@
 }
 
 /* --------------------------------------------------------
+ * Multiple selection tests
+ * -------------------------------------------------------- */
+
+TEST(MultipleSelection, select)
+{
+	/*
+	 * Normally, 3 sockets added for writing should be marked ready immediately
+	 * as there are no data being currently queued to be sent.
+	 */
+	std::thread tester([] () {
+		try {
+			SocketListener masterListener, clientListener;
+			Socket master(AF_INET, SOCK_STREAM, 0);
+
+			master.bind(Internet("*", 10000, AF_INET));
+			master.listen(8);
+
+			masterListener.add(master, SocketDirection::Read);
+
+			while (clientListener.size() != 3) {
+				masterListener.select(3s);
+				clientListener.add(master.accept(), SocketDirection::Write);
+			}
+
+			// Now do the test of writing
+			auto result = clientListener.selectMultiple(3s);
+			ASSERT_EQ(3, result.size());
+
+			clientListener.list([] (auto s, auto direction) {
+				s.close();
+			});
+	
+			master.close();
+		} catch (const std::exception &ex) {
+			FAIL() << ex.what();
+		}
+	});
+	
+	Socket s1(AF_INET, SOCK_STREAM, 0);
+	Socket s2(AF_INET, SOCK_STREAM, 0);
+	Socket s3(AF_INET, SOCK_STREAM, 0);
+
+	s1.connect(Internet("localhost", 10000, AF_INET));
+	s2.connect(Internet("localhost", 10000, AF_INET));
+	s3.connect(Internet("localhost", 10000, AF_INET));
+	
+	s1.close();
+	s2.close();
+	s3.close();
+
+	tester.join();
+}
+
+#if defined(SOCKET_LISTENER_HAVE_POLL)
+
+TEST(MultipleSelection, poll)
+{
+	/*
+	 * Normally, 3 sockets added for writing should be marked ready immediately
+	 * as there are no data being currently queued to be sent.
+	 */
+	std::thread tester([] () {
+		try {
+			SocketListener masterListener(SocketMethod::Poll), clientListener(SocketMethod::Poll);
+			Socket master(AF_INET, SOCK_STREAM, 0);
+
+			master.bind(Internet("*", 10000, AF_INET));
+			master.listen(8);
+
+			masterListener.add(master, SocketDirection::Read);
+
+			while (clientListener.size() != 3) {
+				masterListener.select(3s);
+				clientListener.add(master.accept(), SocketDirection::Write);
+			}
+
+			// Now do the test of writing
+			auto result = clientListener.selectMultiple(3s);
+			ASSERT_EQ(3, result.size());
+
+			clientListener.list([] (auto s, auto direction) {
+				s.close();
+			});
+	
+			master.close();
+		} catch (const std::exception &ex) {
+			FAIL() << ex.what();
+		}
+	});
+	
+	Socket s1(AF_INET, SOCK_STREAM, 0);
+	Socket s2(AF_INET, SOCK_STREAM, 0);
+	Socket s3(AF_INET, SOCK_STREAM, 0);
+
+	s1.connect(Internet("localhost", 10000, AF_INET));
+	s2.connect(Internet("localhost", 10000, AF_INET));
+	s3.connect(Internet("localhost", 10000, AF_INET));
+	
+	s1.close();
+	s2.close();
+	s3.close();
+
+	tester.join();
+}
+
+#endif
+
+/* --------------------------------------------------------
  * Basic TCP tests
  * -------------------------------------------------------- */