view C++/tests/Socket/main.cpp @ 370:2c6a4f468499

Socket: add epoll method
author David Demelier <markand@malikania.fr>
date Wed, 29 Apr 2015 17:17:35 +0200
parents f3c762579073
children 92457ea8f7e2
line wrap: on
line source

/*
 * main.cpp -- test sockets
 *
 * Copyright (c) 2013, 2014 David Demelier <markand@malikania.fr>
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

#include <chrono>
#include <iostream>
#include <sstream>
#include <string>
#include <thread>

#include <gtest/gtest.h>

#include <Socket.h>
#include <SocketAddress.h>
#include <SocketListener.h>
#include <SocketSsl.h>
#include <SocketTcp.h>
#include <SocketUdp.h>

using namespace address;
using namespace std::literals::chrono_literals;

/* --------------------------------------------------------
 * TCP tests
 * -------------------------------------------------------- */

class TcpServerTest : public testing::Test {
protected:
	/* Both endpoints are IPv4 TCP sockets created with protocol 0. */
	SocketTcp m_server{AF_INET, 0};
	SocketTcp m_client{AF_INET, 0};

	/* Threads started by the tests; joined (server first) on teardown. */
	std::thread m_tserver;
	std::thread m_tclient;

public:
	TcpServerTest()
	{
		/* Let consecutive tests rebind the same fixed port quickly. */
		m_server.set(SOL_SOCKET, SO_REUSEADDR, 1);
	}

	~TcpServerTest()
	{
		auto joinIfRunning = [] (std::thread &worker) {
			if (worker.joinable())
				worker.join();
		};

		joinIfRunning(m_tserver);
		joinIfRunning(m_tclient);
	}
};

/*
 * A client connecting to a bound, listening server must end up in the
 * Connected state; the server must be in the Bound state after bind().
 */
TEST_F(TcpServerTest, connect)
{
	m_tserver = std::thread([this] () {
		m_server.bind(Internet("*", 16000, AF_INET));

		ASSERT_EQ(SocketState::Bound, m_server.state());

		m_server.listen();

		/*
		 * Close the accepted peer explicitly rather than discarding it; the
		 * rest of this file always closes accepted sockets by hand
		 * (e.g. waitAccept(...).close()).
		 */
		m_server.accept().close();
		m_server.close();
	});

	/* Give the server thread time to bind and listen first. */
	std::this_thread::sleep_for(100ms);

	m_tclient = std::thread([this] () {
		m_client.connect(Internet("127.0.0.1", 16000, AF_INET));

		ASSERT_EQ(SocketState::Connected, m_client.state());

		m_client.close();
	});
}

/*
 * Echo round-trip over TCP: the client sends "hello world", the server checks
 * it and sends it back, and the client checks the echoed text.
 */
TEST_F(TcpServerTest, io)
{
	auto serverSide = [this] () {
		m_server.bind(Internet("*", 16000, AF_INET));
		m_server.listen();

		auto peer = m_server.accept();
		auto received = peer.recv(512);

		ASSERT_EQ("hello world", received);

		peer.send(received);
		peer.close();

		m_server.close();
	};

	auto clientSide = [this] () {
		m_client.connect(Internet("127.0.0.1", 16000, AF_INET));
		m_client.send("hello world");

		ASSERT_EQ("hello world", m_client.recv(512));

		m_client.close();
	};

	m_tserver = std::thread(serverSide);

	/* Let the server bind and listen before the client connects. */
	std::this_thread::sleep_for(100ms);

	m_tclient = std::thread(clientSide);
}

/* --------------------------------------------------------
 * UDP tests
 * -------------------------------------------------------- */

class UdpServerTest : public testing::Test {
protected:
	/* Both endpoints are IPv4 UDP sockets created with protocol 0. */
	SocketUdp m_server{AF_INET, 0};
	SocketUdp m_client{AF_INET, 0};

	/* Threads started by the tests; joined (server first) on teardown. */
	std::thread m_tserver;
	std::thread m_tclient;

public:
	UdpServerTest()
	{
		/* Let consecutive tests rebind the same fixed port quickly. */
		m_server.set(SOL_SOCKET, SO_REUSEADDR, 1);
	}

	~UdpServerTest()
	{
		auto joinIfRunning = [] (std::thread &worker) {
			if (worker.joinable())
				worker.join();
		};

		joinIfRunning(m_tserver);
		joinIfRunning(m_tclient);
	}
};

/*
 * Echo round-trip over UDP: the server replies to whatever address the
 * datagram came from, and the client checks the echoed text.
 */
TEST_F(UdpServerTest, io)
{
	auto serverSide = [this] () {
		SocketAddress sender;

		m_server.bind(Internet("*", 16000, AF_INET));

		auto datagram = m_server.recvfrom(512, sender);

		ASSERT_EQ("hello world", datagram);

		m_server.sendto(datagram, sender);
		m_server.close();
	};

	auto clientSide = [this] () {
		Internet target("127.0.0.1", 16000, AF_INET);

		m_client.sendto("hello world", target);

		ASSERT_EQ("hello world", m_client.recvfrom(512, target));

		m_client.close();
	};

	m_tserver = std::thread(serverSide);

	/* Let the server bind before the client sends its datagram. */
	std::this_thread::sleep_for(100ms);

	m_tclient = std::thread(clientSide);
}

/* --------------------------------------------------------
 * Listener: set function
 * -------------------------------------------------------- */

/*
 * Fake listener backend that records how SocketListenerBase calls set().
 *
 * m_callcount counts set() calls, m_added keeps the last `add` argument and
 * m_flags accumulates (bitwise OR) every flag value passed to set().
 */
class TestBackendSet {
public:
	int m_callcount{0};
	bool m_added{false};
	int m_flags{0};

	inline void set(const SocketTable &, Socket, int flags, bool add) noexcept
	{
		m_callcount ++;
		m_added = add;
		m_flags |= flags;
	}

	inline void unset(const SocketTable &, Socket, int, bool) noexcept {}

	/*
	 * Not exercised by the ListenerSet tests. Return an empty result: the
	 * previous empty body fell off the end of a value-returning function,
	 * which is undefined behavior.
	 */
	std::vector<SocketStatus> wait(const SocketTable &, int)
	{
		return {};
	}
};

/*
 * Fake listener backend whose set() always throws, used to check that a
 * failed registration leaves the listener table unchanged.
 */
class TestBackendSetFail {
public:
	inline void set(const SocketTable &, Socket, int, bool)
	{
		throw "fail";
	}

	inline void unset(const SocketTable &, Socket, int, bool) noexcept {}

	/*
	 * Not exercised by these tests. Return an empty result instead of
	 * falling off the end of a value-returning function (undefined behavior).
	 */
	std::vector<SocketStatus> wait(const SocketTable &, int)
	{
		return {};
	}
};

/* The first set() on an unknown socket must go to the backend as an addition. */
TEST(ListenerSet, initialAdd)
{
	SocketListenerBase<TestBackendSet> listener;

	listener.set(Socket(0), SocketListener::Read);

	const auto &backend = listener.backend();

	ASSERT_EQ(1U, listener.size());
	ASSERT_EQ(1, backend.m_callcount);
	ASSERT_TRUE(backend.m_added);
	ASSERT_TRUE(backend.m_flags == SocketListener::Read);
}

/*
 * Adding Write to a socket already registered for Read keeps a single entry.
 * The second backend call is a modification, not an addition.
 */
TEST(ListenerSet, readThenWrite)
{
	SocketListenerBase<TestBackendSet> listener;
	Socket sock(0);

	listener.set(sock, SocketListener::Read);
	listener.set(sock, SocketListener::Write);

	const auto &backend = listener.backend();

	ASSERT_EQ(1U, listener.size());
	ASSERT_EQ(2, backend.m_callcount);
	ASSERT_FALSE(backend.m_added);
	ASSERT_TRUE(backend.m_flags == 0x3);
}

/* Registering Read and Write together must produce a single backend addition. */
TEST(ListenerSet, allOneShot)
{
	SocketListenerBase<TestBackendSet> listener;
	Socket sock(0);

	listener.set(sock, SocketListener::Read | SocketListener::Write);

	const auto &backend = listener.backend();

	ASSERT_EQ(1U, listener.size());
	ASSERT_EQ(1, backend.m_callcount);
	ASSERT_TRUE(backend.m_added);
	ASSERT_TRUE(backend.m_flags == 0x3);
}

/* Setting the same flag twice must not call the backend a second time. */
TEST(ListenerSet, readTwice)
{
	SocketListenerBase<TestBackendSet> listener;
	Socket sock(0);

	for (int round = 0; round < 2; ++round)
		listener.set(sock, SocketListener::Read);

	const auto &backend = listener.backend();

	ASSERT_EQ(1U, listener.size());
	ASSERT_EQ(1, backend.m_callcount);
	ASSERT_TRUE(backend.m_added);
	ASSERT_TRUE(backend.m_flags == SocketListener::Read);
}

/* If the backend throws during set(), the socket must not be kept in the table. */
TEST(ListenerSet, failure)
{
	SocketListenerBase<TestBackendSetFail> listener;
	bool threw = false;

	try {
		listener.set(Socket(0), SocketListener::Read);
	} catch (...) {
		threw = true;
	}

	if (!threw)
		FAIL() << "exception expected";

	ASSERT_EQ(0U, listener.size());
}

/* --------------------------------------------------------
 * Listener: unset / remove functions
 * -------------------------------------------------------- */

/*
 * Fake listener backend that records set()/unset() calls.
 *
 * m_flags is OR-ed on set() and has the given bits cleared on unset().
 * m_removal keeps the last `remove` argument passed to unset().
 */
class TestBackendUnset {
public:
	bool m_isset{false};
	bool m_isunset{false};
	int m_flags{0};
	bool m_removal{false};

	inline void set(const SocketTable &, Socket, int flags, bool) noexcept
	{
		m_isset = true;
		m_flags |= flags;
	}

	inline void unset(const SocketTable &, Socket, int flags, bool remove) noexcept
	{
		m_isunset = true;
		m_flags &= ~(flags);
		m_removal = remove;
	}

	/*
	 * Not exercised by the ListenerUnsetRemove tests. Return an empty result
	 * instead of falling off the end of a value-returning function
	 * (undefined behavior).
	 */
	std::vector<SocketStatus> wait(const SocketTable &, int)
	{
		return {};
	}
};

/*
 * Fake listener backend whose unset() always throws, used to check that a
 * failed removal keeps the socket in the listener table.
 */
class TestBackendUnsetFail {
public:
	inline void set(const SocketTable &, Socket, int, bool) noexcept {}

	inline void unset(const SocketTable &, Socket, int, bool)
	{
		throw "fail";
	}

	/*
	 * Not exercised by these tests. Return an empty result instead of
	 * falling off the end of a value-returning function (undefined behavior).
	 */
	std::vector<SocketStatus> wait(const SocketTable &, int)
	{
		return {};
	}
};

/* Unsetting the only registered flag must remove the socket entirely. */
TEST(ListenerUnsetRemove, unset)
{
	SocketListenerBase<TestBackendUnset> listener;
	Socket sock(0);

	listener.set(sock, SocketListener::Read);
	listener.unset(sock, SocketListener::Read);

	const auto &backend = listener.backend();

	ASSERT_EQ(0U, listener.size());
	ASSERT_TRUE(backend.m_isset);
	ASSERT_TRUE(backend.m_isunset);
	ASSERT_TRUE(backend.m_flags == 0);
	ASSERT_TRUE(backend.m_removal);
}

/* Unsetting one of two flags keeps the socket registered for the remaining one. */
TEST(ListenerUnsetRemove, unsetOne)
{
	SocketListenerBase<TestBackendUnset> listener;
	Socket sock(0);

	listener.set(sock, SocketListener::Read | SocketListener::Write);
	listener.unset(sock, SocketListener::Read);

	const auto &backend = listener.backend();

	ASSERT_EQ(1U, listener.size());
	ASSERT_TRUE(backend.m_isset);
	ASSERT_TRUE(backend.m_isunset);
	ASSERT_TRUE(backend.m_flags == SocketListener::Write);
	ASSERT_FALSE(backend.m_removal);
}

/* Unsetting both flags one after the other ends with the socket removed. */
TEST(ListenerUnsetRemove, unsetAll)
{
	SocketListenerBase<TestBackendUnset> listener;
	Socket sock(0);

	listener.set(sock, SocketListener::Read | SocketListener::Write);

	for (auto flag : {SocketListener::Read, SocketListener::Write})
		listener.unset(sock, flag);

	const auto &backend = listener.backend();

	ASSERT_EQ(0U, listener.size());
	ASSERT_TRUE(backend.m_isset);
	ASSERT_TRUE(backend.m_isunset);
	ASSERT_TRUE(backend.m_flags == 0);
	ASSERT_TRUE(backend.m_removal);
}

/* remove() drops a socket whatever flags it was registered with. */
TEST(ListenerUnsetRemove, remove)
{
	SocketListenerBase<TestBackendUnset> listener;
	Socket sock(0);

	listener.set(sock, SocketListener::Read | SocketListener::Write);
	listener.remove(sock);

	const auto &backend = listener.backend();

	ASSERT_EQ(0U, listener.size());
	ASSERT_TRUE(backend.m_isset);
	ASSERT_TRUE(backend.m_isunset);
	ASSERT_TRUE(backend.m_flags == 0);
	ASSERT_TRUE(backend.m_removal);
}

/* If the backend throws during remove(), the socket stays in the table. */
TEST(ListenerUnsetRemove, failure)
{
	SocketListenerBase<TestBackendUnsetFail> listener;
	Socket sock(0);
	bool threw = false;

	listener.set(sock, SocketListener::Read | SocketListener::Write);

	try {
		listener.remove(sock);
	} catch (...) {
		threw = true;
	}

	if (!threw)
		FAIL() << "exception expected";

	ASSERT_EQ(1U, listener.size());
}

/* --------------------------------------------------------
 * Listener: system
 * -------------------------------------------------------- */

class ListenerTest : public testing::Test {
protected:
	/* Listener under test; backend::Select is presumably select(2)-based, per its name. */
	SocketListenerBase<backend::Select> m_listener;

	/* m_masterTcp is bound to port 16000 and listening once the fixture is built. */
	SocketTcp m_masterTcp{AF_INET, 0};
	SocketTcp m_clientTcp{AF_INET, 0};

	/* Threads started by the tests; joined (server first) on teardown. */
	std::thread m_tserver;
	std::thread m_tclient;

public:
	ListenerTest()
	{
		/* Rebinding the fixed port right after a previous test must succeed. */
		m_masterTcp.set(SOL_SOCKET, SO_REUSEADDR, 1);
		m_masterTcp.bind(Internet("*", 16000, AF_INET));
		m_masterTcp.listen();
	}

	~ListenerTest()
	{
		auto joinIfRunning = [] (std::thread &worker) {
			if (worker.joinable())
				worker.join();
		};

		joinIfRunning(m_tserver);
		joinIfRunning(m_tclient);
	}
};

/*
 * The listener must report the master socket as readable when a client
 * connects, after which accept() succeeds.
 */
TEST_F(ListenerTest, accept)
{
	m_tserver = std::thread([this] () {
		try {
			m_listener.set(m_masterTcp, SocketListener::Read);
			m_listener.wait();

			/*
			 * Close the accepted peer explicitly instead of discarding it,
			 * consistent with the rest of this file.
			 */
			m_masterTcp.accept().close();
			m_masterTcp.close();
		} catch (const std::exception &ex) {
			FAIL() << ex.what();
		}
	});

	/* Give the server thread time to start waiting on the listener. */
	std::this_thread::sleep_for(100ms);

	m_tclient = std::thread([this] () {
		m_clientTcp.connect(Internet("127.0.0.1", 16000, AF_INET));

		/* Release the client socket; it was previously never closed. */
		m_clientTcp.close();
	});
}

/*
 * After the listener wakes up on the master socket, the accepted connection
 * must deliver the data sent by the client.
 */
TEST_F(ListenerTest, recv)
{
	m_tserver = std::thread([this] () {
		try {
			m_listener.set(m_masterTcp, SocketListener::Read);
			m_listener.wait();

			auto sc = m_masterTcp.accept();
			auto msg = sc.recv(512);

			/*
			 * Close both sockets before asserting: the accepted socket was
			 * previously never closed, and a failing ASSERT returns early.
			 */
			sc.close();
			m_masterTcp.close();

			ASSERT_EQ("hello", msg);
		} catch (const std::exception &ex) {
			FAIL() << ex.what();
		}
	});

	/* Give the server thread time to start waiting on the listener. */
	std::this_thread::sleep_for(100ms);

	m_tclient = std::thread([this] () {
		m_clientTcp.connect(Internet("127.0.0.1", 16000, AF_INET));
		m_clientTcp.send("hello");

		/* Data already sent stays readable by the peer after close(). */
		m_clientTcp.close();
	});
}

/* --------------------------------------------------------
 * Non-blocking connect
 * -------------------------------------------------------- */

class NonBlockingConnectTest : public testing::Test {
protected:
	/* IPv4 TCP endpoints; the client is switched to non-blocking mode. */
	SocketTcp m_server{AF_INET, 0};
	SocketTcp m_client{AF_INET, 0};

	/* Threads started by the tests; joined (server first) on teardown. */
	std::thread m_tserver;
	std::thread m_tclient;

public:
	NonBlockingConnectTest()
	{
		m_client.setBlockMode(false);
	}

	~NonBlockingConnectTest()
	{
		auto joinIfRunning = [] (std::thread &worker) {
			if (worker.joinable())
				worker.join();
		};

		joinIfRunning(m_tserver);
		joinIfRunning(m_tclient);
	}
};

/* waitConnect() on a non-blocking client must reach the Connected state. */
TEST_F(NonBlockingConnectTest, success)
{
	/* The server is bound and listening before any thread starts. */
	m_server.set(SOL_SOCKET, SO_REUSEADDR, 1);
	m_server.bind(Internet("*", 16000, AF_INET));
	m_server.listen();

	auto serverSide = [this] () {
		SocketTcp peer = m_server.accept();

		/* Presumably keeps the connection alive while the client checks its state. */
		std::this_thread::sleep_for(100ms);

		m_server.close();
		peer.close();
	};

	auto clientSide = [this] () {
		try {
			m_client.waitConnect(Internet("127.0.0.1", 16000, AF_INET), 3000);
		} catch (const SocketError &error) {
			FAIL() << error.function() << ": " << error.what();
		}

		ASSERT_EQ(SocketState::Connected, m_client.state());

		m_client.close();
	};

	m_tserver = std::thread(serverSide);
	std::this_thread::sleep_for(100ms);
	m_tclient = std::thread(clientSide);
}

/*
 * waitConnect() with a short timeout towards an unanswered remote port must
 * raise a SocketError with the Timeout code.
 *
 * /!\ If you find a way to test this locally please tell me /!\
 */
TEST_F(NonBlockingConnectTest, fail)
{
	auto clientSide = [this] () {
		try {
			m_client.waitConnect(Internet("google.fr", 9000, AF_INET), 100);

			FAIL() << "Expected exception, got success";
		} catch (const SocketError &failure) {
			ASSERT_EQ(SocketError::Timeout, failure.code());
		}

		m_client.close();
	};

	m_tclient = std::thread(clientSide);
}

/* --------------------------------------------------------
 * TCP accept
 * -------------------------------------------------------- */

class TcpAcceptTest : public testing::Test {
protected:
	/* m_server is bound to port 16000 and listening once the fixture is built. */
	SocketTcp m_server{AF_INET, 0};
	SocketTcp m_client{AF_INET, 0};

	/* Threads started by the tests; joined (server first) on teardown. */
	std::thread m_tserver;
	std::thread m_tclient;

public:
	TcpAcceptTest()
	{
		/* Rebinding the fixed port right after a previous test must succeed. */
		m_server.set(SOL_SOCKET, SO_REUSEADDR, 1);
		m_server.bind(Internet("*", 16000, AF_INET));
		m_server.listen();
	}

	~TcpAcceptTest()
	{
		auto joinIfRunning = [] (std::thread &worker) {
			if (worker.joinable())
				worker.join();
		};

		joinIfRunning(m_tserver);
		joinIfRunning(m_tclient);
	}
};

/* waitAccept() on a blocking server returns once a client connects. */
TEST_F(TcpAcceptTest, blockingWaitSuccess)
{
	auto serverSide = [this] () {
		try {
			m_server.waitAccept(3000).close();
		} catch (const SocketError &failure) {
			FAIL() << failure.what();
		}

		m_server.close();
	};

	auto clientSide = [this] () {
		m_client.connect(Internet("127.0.0.1", 16000, AF_INET));
		m_client.close();
	};

	m_tserver = std::thread(serverSide);
	std::this_thread::sleep_for(100ms);
	m_tclient = std::thread(clientSide);
}

/* waitAccept() on a non-blocking server also returns once a client connects. */
TEST_F(TcpAcceptTest, nonBlockingWaitSuccess)
{
	auto serverSide = [this] () {
		try {
			m_server.setBlockMode(false);
			m_server.waitAccept(3000).close();
		} catch (const SocketError &failure) {
			FAIL() << failure.what();
		}

		m_server.close();
	};

	auto clientSide = [this] () {
		m_client.connect(Internet("127.0.0.1", 16000, AF_INET));
		m_client.close();
	};

	m_tserver = std::thread(serverSide);
	std::this_thread::sleep_for(100ms);
	m_tclient = std::thread(clientSide);
}

/* With no client at all, waitAccept() must fail with a Timeout error. */
TEST_F(TcpAcceptTest, nonBlockingWaitFail)
{
	try {
		m_server.setBlockMode(false);
		m_server.waitAccept(100).close();

		FAIL() << "Expected exception, got success";
	} catch (const SocketError &failure) {
		ASSERT_EQ(SocketError::Timeout, failure.code());
	}

	m_server.close();
}

/* --------------------------------------------------------
 * TCP recv
 * -------------------------------------------------------- */

class TcpRecvTest : public testing::Test {
protected:
	/* m_server is bound to port 16000 and listening once the fixture is built. */
	SocketTcp m_server{AF_INET, 0};
	SocketTcp m_client{AF_INET, 0};

	/* Threads started by the tests; joined (server first) on teardown. */
	std::thread m_tserver;
	std::thread m_tclient;

public:
	TcpRecvTest()
	{
		/* Rebinding the fixed port right after a previous test must succeed. */
		m_server.set(SOL_SOCKET, SO_REUSEADDR, 1);
		m_server.bind(Internet("*", 16000, AF_INET));
		m_server.listen();
	}

	~TcpRecvTest()
	{
		auto joinIfRunning = [] (std::thread &worker) {
			if (worker.joinable())
				worker.join();
		};

		joinIfRunning(m_tserver);
		joinIfRunning(m_tclient);
	}
};

/* Plain blocking recv() on an accepted connection receives the client's data. */
TEST_F(TcpRecvTest, blockingSuccess)
{
	auto serverSide = [this] () {
		SocketTcp peer = m_server.accept();

		ASSERT_EQ("hello", peer.recv(32));

		peer.close();
		m_server.close();
	};

	auto clientSide = [this] () {
		m_client.connect(Internet("127.0.0.1", 16000, AF_INET));
		m_client.send("hello");
		m_client.close();
	};

	m_tserver = std::thread(serverSide);
	std::this_thread::sleep_for(100ms);
	m_tclient = std::thread(clientSide);
}

/* waitRecv() on a blocking accepted connection receives the client's data. */
TEST_F(TcpRecvTest, blockingWaitSuccess)
{
	auto serverSide = [this] () {
		SocketTcp peer = m_server.accept();

		ASSERT_EQ("hello", peer.waitRecv(32, 3000));

		peer.close();
		m_server.close();
	};

	auto clientSide = [this] () {
		m_client.connect(Internet("127.0.0.1", 16000, AF_INET));
		m_client.send("hello");
		m_client.close();
	};

	m_tserver = std::thread(serverSide);
	std::this_thread::sleep_for(100ms);
	m_tclient = std::thread(clientSide);
}

/* waitRecv() also works after switching the accepted connection to non-blocking. */
TEST_F(TcpRecvTest, nonBlockingWaitSuccess)
{
	auto serverSide = [this] () {
		SocketTcp peer = m_server.accept();

		peer.setBlockMode(false);

		ASSERT_EQ("hello", peer.waitRecv(32, 3000));

		peer.close();
		m_server.close();
	};

	auto clientSide = [this] () {
		m_client.connect(Internet("127.0.0.1", 16000, AF_INET));
		m_client.send("hello");
		m_client.close();
	};

	m_tserver = std::thread(serverSide);
	std::this_thread::sleep_for(100ms);
	m_tclient = std::thread(clientSide);
}

/* --------------------------------------------------------
 * Socket SSL
 * -------------------------------------------------------- */

struct SslTest : testing::Test {
protected:
	/* IPv4 SSL client socket; each test connects it to a remote host itself. */
	SocketSsl client{AF_INET, 0};
};

/* An SSL connection to a public HTTPS host must be established without error. */
TEST_F(SslTest, connect)
{
	const Internet remote("google.fr", 443, AF_INET);

	try {
		client.connect(remote);
		client.close();
	} catch (const SocketError &failure) {
		FAIL() << failure.what();
	}
}

/*
 * Send a minimal HTTP/1.0 request over SSL and check that the decrypted reply
 * starts with an HTTP/1.x status line.
 */
TEST_F(SslTest, recv)
{
	try {
		client.connect(Internet("google.fr", 443, AF_INET));
		client.send("GET / HTTP/1.0\r\n\r\n");

		std::string msg = client.recv(512);

		/* Close before asserting so a failed assertion does not skip it. */
		client.close();

		/*
		 * The status code is chosen by a third-party server and may change
		 * over time (the old test pinned "302 Found"). The point of this test
		 * is the SSL transport, so only the protocol prefix is checked.
		 */
		ASSERT_EQ("HTTP/1.", msg.substr(0, 7));
	} catch (const SocketError &error) {
		FAIL() << error.what();
	}
}

int main(int argc, char **argv)
{
	testing::InitGoogleTest(&argc, argv);

	return RUN_ALL_TESTS();
}