Mercurial > code
changeset 270:46ccfbee84d9
Socket:
* Add macro SOCKET_LISTENER_HAVE_POLL to enable / disable poll(2) method
if it is not available. Fallback safely to select(2) if needed.
* Add new function selectMultiple() which returns a std::vector of all ready
sockets.
* Add tests for multiple selection.
author | David Demelier <markand@malikania.fr> |
---|---|
date | Tue, 21 Oct 2014 10:13:33 +0200 |
parents | 44dcc198bf0c |
children | f7000cc599d0 |
files | C++/SocketListener.cpp C++/SocketListener.h C++/Tests/Sockets/main.cpp |
diffstat | 3 files changed, 202 insertions(+), 14 deletions(-) [+] |
line wrap: on
line diff
--- a/C++/SocketListener.cpp Fri Oct 17 14:20:00 2014 +0200 +++ b/C++/SocketListener.cpp Tue Oct 21 10:13:33 2014 +0200 @@ -53,6 +53,7 @@ void clear() override; unsigned size() const override; SocketStatus select(int ms) override; + std::vector<SocketStatus> selectMultiple(int ms) override; }; SelectMethod::SelectMethod() @@ -129,6 +130,16 @@ SocketStatus SelectMethod::select(int ms) { + auto result = selectMultiple(ms); + + if (result.size() == 0) + throw error::Failure("No socket found"); + + return result[0]; +} + +std::vector<SocketStatus> SelectMethod::selectMultiple(int ms) +{ timeval maxwait, *towait; maxwait.tv_sec = 0; @@ -143,20 +154,23 @@ if (error == 0) throw error::Timeout(); + std::vector<SocketStatus> sockets; for (auto &c : m_lookup) if (FD_ISSET(c.first, &m_readset)) - return { std::get<0>(c.second), SocketDirection::Read }; + sockets.push_back({ std::get<0>(c.second), SocketDirection::Read }); for (auto &c : m_lookup) if (FD_ISSET(c.first, &m_writeset)) - return { std::get<0>(c.second), SocketDirection::Write }; + sockets.push_back({ std::get<0>(c.second), SocketDirection::Write }); - throw error::Failure("No socket found"); + return sockets; } /* -------------------------------------------------------- * Poll implementation * -------------------------------------------------------- */ +#if defined(SOCKET_LISTENER_HAVE_POLL) + #if defined(_WIN32) # include <Winsock2.h> # define poll WSAPoll @@ -202,6 +216,7 @@ void clear() override; unsigned size() const override; SocketStatus select(int ms) override; + std::vector<SocketStatus> selectMultiple(int ms) override; }; void PollMethod::add(Socket &&s, SocketDirection direction) @@ -259,16 +274,37 @@ throw error::Failure("no socket found"); } +std::vector<SocketStatus> PollMethod::selectMultiple(int ms) +{ + auto result = poll(m_fds.data(), m_fds.size(), ms); + if (result == 0) + throw error::Timeout(); + if (result < 0) + throw error::Failure(Socket::syserror()); + + std::vector<SocketStatus> sockets; + for (auto &fd : m_fds) + if (fd.revents != 0) + sockets.push_back({ m_lookup[fd.fd], todirection(fd.revents) }); + + return sockets; +} + } // !namespace +#endif // !_SOCKET_LISTENER_HAVE_POLL + /* -------------------------------------------------------- * Socket listener * -------------------------------------------------------- */ SocketListener::SocketListener(SocketMethod method) { +#if defined(SOCKET_LISTENER_HAVE_POLL) if (method == SocketMethod::Poll) m_interface = std::make_unique<PollMethod>(); else +#endif m_interface = std::make_unique<SelectMethod>(); + }
--- a/C++/SocketListener.h Fri Oct 17 14:20:00 2014 +0200 +++ b/C++/SocketListener.h Tue Oct 21 10:13:33 2014 +0200 @@ -21,9 +21,18 @@ #include <chrono> #include <functional> +#include <vector> #include "Socket.h" +#if defined(_WIN32) +# if _WIN32_WINNT >= 0x0600 +# define SOCKET_LISTENER_HAVE_POLL +# endif +#else +# define SOCKET_LISTENER_HAVE_POLL +#endif + /** * @enum SocketDirection * @brief The SocketDirection enum @@ -113,6 +122,12 @@ */ using MapFunc = std::function<void (Socket &, SocketDirection)>; +#if defined(SOCKET_LISTENER_HAVE_POLL) + static constexpr const SocketMethod PreferredMethod = SocketMethod::Poll; +#else + static constexpr const SocketMethod PreferredMethod = SocketMethod::Select; +#endif + /** * @class Interface * @brief Implement the polling method @@ -158,7 +173,7 @@ virtual unsigned size() const = 0; /** - * Select a socket. + * Select one socket. * * @param ms the number of milliseconds to wait, -1 means forever * @return the socket status @@ -166,20 +181,20 @@ * @throw error::Timeout on timeout */ virtual SocketStatus select(int ms) = 0; + + /** + * Select many sockets. + * + * @param ms the number of milliseconds to wait, -1 means forever + * @return a vector of ready sockets + * @throw error::Failure on failure + * @throw error::Timeout on timeout + */ + virtual std::vector<SocketStatus> selectMultiple(int ms) = 0; }; std::unique_ptr<Interface> m_interface; -#if defined(_WIN32) -# if _WIN32_WINNT >= 0x0600 - static constexpr const SocketMethod PreferredMethod = SocketMethod::Poll; -# else - static constexpr const SocketMethod PreferredMethod = SocketMethod::Select; -# endif -#else - static constexpr const SocketMethod PreferredMethod = SocketMethod::Poll; -#endif - public: /** * Create a socket listener. @@ -231,6 +246,7 @@ /** * Select a socket. Waits for a specific amount of time specified as the duration. * + * @param duration the duration * @return the socket ready * @throw SocketError on error * @throw SocketTimeout on timeout @@ -256,6 +272,34 @@ } /** + * Select multiple sockets. + * + * @param duration the duration + * @return the socket ready + * @throw SocketError on error + * @throw SocketTimeout on timeout + */ + template <typename Rep, typename Ratio> + inline std::vector<SocketStatus> selectMultiple(const std::chrono::duration<Rep, Ratio> &duration) + { + auto cvt = std::chrono::duration_cast<std::chrono::milliseconds>(duration); + + return m_interface->selectMultiple(cvt.count()); + } + + /** + * Overload that waits indefinitely. + * + * @return the socket ready + * @throw SocketError on error + * @throw SocketTimeout on timeout + */ + inline std::vector<SocketStatus> selectMultiple() + { + return m_interface->selectMultiple(-1); + } + + /** * List every socket in the listener. * * @param func the function to call
--- a/C++/Tests/Sockets/main.cpp Fri Oct 17 14:20:00 2014 +0200 +++ b/C++/Tests/Sockets/main.cpp Tue Oct 21 10:13:33 2014 +0200 @@ -362,6 +362,114 @@ } /* -------------------------------------------------------- + * Multiple selection tests + * -------------------------------------------------------- */ + +TEST(MultipleSelection, select) +{ + /* + * Normally, 3 sockets added for writing should be marked ready immediately + * as there are no data being currently queued to be sent. + */ + std::thread tester([] () { + try { + SocketListener masterListener, clientListener; + Socket master(AF_INET, SOCK_STREAM, 0); + + master.bind(Internet("*", 10000, AF_INET)); + master.listen(8); + + masterListener.add(master, SocketDirection::Read); + + while (clientListener.size() != 3) { + masterListener.select(3s); + clientListener.add(master.accept(), SocketDirection::Write); + } + + // Now do the test of writing + auto result = clientListener.selectMultiple(3s); + ASSERT_EQ(3, result.size()); + + clientListener.list([] (auto s, auto direction) { + s.close(); + }); + + master.close(); + } catch (const std::exception &ex) { + FAIL() << ex.what(); + } + }); + + Socket s1(AF_INET, SOCK_STREAM, 0); + Socket s2(AF_INET, SOCK_STREAM, 0); + Socket s3(AF_INET, SOCK_STREAM, 0); + + s1.connect(Internet("localhost", 10000, AF_INET)); + s2.connect(Internet("localhost", 10000, AF_INET)); + s3.connect(Internet("localhost", 10000, AF_INET)); + + s1.close(); + s2.close(); + s3.close(); + + tester.join(); +} + +#if defined(SOCKET_LISTENER_HAVE_POLL) + +TEST(MultipleSelection, poll) +{ + /* + * Normally, 3 sockets added for writing should be marked ready immediately + * as there are no data being currently queued to be sent. + */ + std::thread tester([] () { + try { + SocketListener masterListener(SocketMethod::Poll), clientListener(SocketMethod::Poll); + Socket master(AF_INET, SOCK_STREAM, 0); + + master.bind(Internet("*", 10000, AF_INET)); + master.listen(8); + + masterListener.add(master, SocketDirection::Read); + + while (clientListener.size() != 3) { + masterListener.select(3s); + clientListener.add(master.accept(), SocketDirection::Write); + } + + // Now do the test of writing + auto result = clientListener.selectMultiple(3s); + ASSERT_EQ(3, result.size()); + + clientListener.list([] (auto s, auto direction) { + s.close(); + }); + + master.close(); + } catch (const std::exception &ex) { + FAIL() << ex.what(); + } + }); + + Socket s1(AF_INET, SOCK_STREAM, 0); + Socket s2(AF_INET, SOCK_STREAM, 0); + Socket s3(AF_INET, SOCK_STREAM, 0); + + s1.connect(Internet("localhost", 10000, AF_INET)); + s2.connect(Internet("localhost", 10000, AF_INET)); + s3.connect(Internet("localhost", 10000, AF_INET)); + + s1.close(); + s2.close(); + s3.close(); + + tester.join(); +} + +#endif + +/* -------------------------------------------------------- * Basic TCP tests * -------------------------------------------------------- */