view C++/SocketListener.cpp @ 279:af630354610f

Socket: silent warning if poll is not available
author David Demelier <markand@malikania.fr>
date Fri, 24 Oct 2014 10:28:17 +0200
parents adcae2bde2f0
children 91eb0583df52
line wrap: on
line source

/*
 * SocketListener.cpp -- portable select() wrapper
 *
 * Copyright (c) 2013, David Demelier <markand@malikania.fr>
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

#include <algorithm>
#include <map>
#include <set>
#include <tuple>
#include <vector>

#include "SocketListener.h"

/* --------------------------------------------------------
 * Select implementation
 * -------------------------------------------------------- */

/**
 * @class SelectMethod
 * @brief Implements select(2)
 *
 * This class is the fallback of any other method, it is not preferred at all for many reasons.
 */
class SelectMethod final : public SocketListener::Interface {
private:
	std::set<Socket> m_rsockets;
	std::set<Socket> m_wsockets;
	std::map<Socket::Handle, std::tuple<Socket, int>> m_lookup;

	fd_set m_readset;
	fd_set m_writeset;

	Socket::Handle m_max { 0 };

public:
	SelectMethod();
	void add(Socket s, int direction) override;
	void remove(const Socket &s, int direction) override;
	void list(const SocketListener::MapFunc &func) override;
	void clear() override;
	unsigned size() const override;
	SocketStatus select(int ms) override;
	std::vector<SocketStatus> selectMultiple(int ms) override;
};

/**
 * Start with both descriptor sets empty.
 */
SelectMethod::SelectMethod()
{
	// fd_set has no constructor, so both sets must be cleared explicitly.
	FD_ZERO(&m_writeset);
	FD_ZERO(&m_readset);
}

/**
 * Register a socket for the given direction(s).
 *
 * If the socket is already known, the new direction flags are merged into
 * the existing ones.
 *
 * @param s the socket
 * @param direction bitmask of Read and/or Write
 */
void SelectMethod::add(Socket s, int direction)
{
	const auto handle = s.handle();
	auto known = m_lookup.find(handle);

	if (known == m_lookup.end())
		m_lookup[handle] = std::make_tuple(s, direction);
	else
		std::get<1>(known->second) |= direction;

	if (direction & Read) {
		m_rsockets.insert(s);
		FD_SET(handle, &m_readset);
	}
	if (direction & Write) {
		m_wsockets.insert(s);
		FD_SET(handle, &m_writeset);
	}

	// Track the highest handle, select() needs it as its first argument.
	if (handle > m_max)
		m_max = handle;
}

/**
 * Stop watching a socket for the given direction(s).
 *
 * The socket is dropped from the lookup table once no direction remains
 * registered for it. Removing a socket that was never added is harmless.
 *
 * @param s the socket
 * @param direction bitmask of Read and/or Write to remove
 */
void SelectMethod::remove(const Socket &s, int direction)
{
	const auto handle = s.handle();

	// Use find() so that an unknown socket does not create a temporary entry.
	auto known = m_lookup.find(handle);
	if (known != m_lookup.end()) {
		auto &flags = std::get<1>(known->second);

		flags &= ~direction;
		if (flags == 0)
			m_lookup.erase(known);
	}

	// The sets are keyed by Socket, so erase by socket rather than by handle.
	if (direction & Read) {
		m_rsockets.erase(s);
		FD_CLR(handle, &m_readset);
	}
	if (direction & Write) {
		m_wsockets.erase(s);
		FD_CLR(handle, &m_writeset);
	}

	// The map is ordered by handle, so its last key is the highest descriptor.
	if (m_lookup.empty())
		m_max = 0;
	else
		m_max = m_lookup.rbegin()->first;
}

/**
 * Call a function for every registered socket.
 *
 * @param func the callback, invoked with the socket and its direction flags
 */
void SelectMethod::list(const SocketListener::MapFunc &func)
{
	for (auto it = m_lookup.begin(), last = m_lookup.end(); it != last; ++it) {
		auto &entry = it->second;

		func(std::get<0>(entry), std::get<1>(entry));
	}
}

/**
 * Forget every registered socket and reset the descriptor sets.
 */
void SelectMethod::clear()
{
	// Descriptor sets and highest handle first...
	FD_ZERO(&m_readset);
	FD_ZERO(&m_writeset);
	m_max = 0;

	// ...then the bookkeeping containers.
	m_lookup.clear();
	m_rsockets.clear();
	m_wsockets.clear();
}

/**
 * @return the number of distinct sockets currently registered
 */
unsigned SelectMethod::size() const
{
	const auto count = m_lookup.size();

	return static_cast<unsigned>(count);
}

/**
 * Wait for activity and return the first ready socket.
 *
 * @param ms the timeout in milliseconds, forwarded to selectMultiple()
 * @return the first status reported by selectMultiple()
 * @throw error::Failure if no socket was reported
 */
SocketStatus SelectMethod::select(int ms)
{
	auto ready = selectMultiple(ms);

	if (!ready.empty())
		return ready.front();

	throw error::Failure("select", "No socket found");
}

std::vector<SocketStatus> SelectMethod::selectMultiple(int ms)
{
	timeval maxwait, *towait;

	maxwait.tv_sec = 0;
	maxwait.tv_usec = ms * 1000;

	// Set to nullptr for infinite timeout.
	towait = (ms <= 0) ? nullptr : &maxwait;

	auto error = ::select(m_max + 1, &m_readset, &m_writeset, NULL, towait);
	if (error == SOCKET_ERROR)
		throw error::Failure("select", Socket::syserror());
	if (error == 0)
		throw error::Timeout("select");

	std::vector<SocketStatus> sockets;
	for (auto &c : m_lookup)
		if (FD_ISSET(c.first, &m_readset))
			sockets.push_back({ std::get<0>(c.second), Read });
	for (auto &c : m_lookup)
		if (FD_ISSET(c.first, &m_writeset))
			sockets.push_back({ std::get<0>(c.second), Write });

	return sockets;
}

/* --------------------------------------------------------
 * Poll implementation
 * -------------------------------------------------------- */

#if defined(SOCKET_LISTENER_HAVE_POLL)

#if defined(_WIN32)
#  include <Winsock2.h>
#  define poll WSAPoll
#else
#  include <poll.h>
#endif

namespace {

class PollMethod final : public SocketListener::Interface {
private:
	std::vector<pollfd> m_fds;
	std::map<Socket::Handle, Socket> m_lookup;

	inline short topoll(int direction)
	{
		short result(0);

		if (direction & Read)
			result |= POLLIN;
		if (direction & Write)
			result |= POLLOUT;

		return result;
	}

	inline int todirection(short event)
	{
		int direction{};

		if (event & POLLIN)
			direction |= Read;
		if (event & POLLOUT)
			direction |= Write;

		return direction;
	}

public:
	void add(Socket s, int direction) override;
	void remove(const Socket &s, int direction) override;
	void list(const SocketListener::MapFunc &func) override;
	void clear() override;
	unsigned size() const override;
	SocketStatus select(int ms) override;
	std::vector<SocketStatus> selectMultiple(int ms) override;
};

/**
 * Register a socket for the given direction(s).
 *
 * If the socket is already registered, the new events are merged into its
 * existing entry.
 *
 * @param s the socket
 * @param direction bitmask of Read and/or Write
 */
void PollMethod::add(Socket s, int direction)
{
	const auto handle = s.handle();
	const short events = topoll(direction);

	// Already registered: only extend the requested events.
	for (auto &pfd : m_fds) {
		if (pfd.fd == handle) {
			pfd.events |= events;
			return;
		}
	}

	// First time this socket is seen.
	m_lookup[handle] = s;
	m_fds.push_back({ handle, events, 0 });
}

void PollMethod::remove(const Socket &s, int direction)
{
	for (auto i = m_fds.begin(); i != m_fds.end();) {
		if (i->fd == s) {
			i->events &= ~(topoll(direction));

			if (i->events == 0) {
				m_lookup.erase(i->fd);
				i = m_fds.erase(i);
			} else {
				++i;
			}
		} else
			++i;
	}
}

/**
 * Call a function for every registered socket.
 *
 * @param func the callback, invoked with the socket and its direction flags
 */
void PollMethod::list(const SocketListener::MapFunc &func)
{
	for (auto it = m_fds.begin(), last = m_fds.end(); it != last; ++it)
		func(m_lookup[it->fd], todirection(it->events));
}

/**
 * Forget every registered socket.
 */
void PollMethod::clear()
{
	m_lookup.clear();
	m_fds.clear();
}

/**
 * @return the number of sockets currently registered
 */
unsigned PollMethod::size() const
{
	const auto count = m_fds.size();

	return static_cast<unsigned>(count);
}

/**
 * Wait for activity and return the first socket that reported an event.
 *
 * @param ms the timeout in milliseconds, passed unchanged to poll()
 * @return the first socket whose revents is non-zero, with its direction
 * @throw error::Timeout if poll() returns 0
 * @throw error::Failure if poll() fails or no entry reported an event
 */
SocketStatus PollMethod::select(int ms)
{
	const auto nready = poll(m_fds.data(), m_fds.size(), ms);

	if (nready < 0)
		throw error::Failure("select", Socket::syserror());
	if (nready == 0)
		throw error::Timeout("select");

	// Entries are scanned in registration order.
	auto hit = std::find_if(m_fds.begin(), m_fds.end(), [] (const auto &pfd) { return pfd.revents != 0; });

	if (hit == m_fds.end())
		throw error::Failure("select", "No socket found");

	return { m_lookup[hit->fd], todirection(hit->revents) };
}

std::vector<SocketStatus> PollMethod::selectMultiple(int ms)
{
	auto result = poll(m_fds.data(), m_fds.size(), ms);
	if (result == 0)
		throw error::Timeout("select");
	if (result < 0)
		throw error::Failure("select", Socket::syserror());

	std::vector<SocketStatus> sockets;
	for (auto &fd : m_fds)
		if (fd.revents != 0)
			sockets.push_back({ m_lookup[fd.fd], todirection(fd.revents) });

	return sockets;
}

} // !namespace

#endif // !SOCKET_LISTENER_HAVE_POLL

/* --------------------------------------------------------
 * Socket listener
 * -------------------------------------------------------- */

/**
 * Create a listener using the requested backend.
 *
 * The poll backend is used only when it is compiled in and explicitly
 * requested; every other case uses the select backend.
 *
 * @param method the requested method
 */
SocketListener::SocketListener(int method)
{
#if defined(SOCKET_LISTENER_HAVE_POLL)
	if (method == Poll) {
		m_interface = std::make_unique<PollMethod>();
		return;
	}
#else
	// Without poll support the parameter is unused; silence the warning.
	(void)method;
#endif

	m_interface = std::make_unique<SelectMethod>();
}

/**
 * Create a listener with the requested backend and register an initial
 * list of sockets.
 *
 * @param list pairs of socket and direction flags, each one passed to add()
 * @param method the requested method
 */
SocketListener::SocketListener(std::initializer_list<std::pair<Socket, int>> list, int method)
	: SocketListener(method)
{
	for (auto it = list.begin(); it != list.end(); ++it)
		add(it->first, it->second);
}