view C++/tests/Socket/main.cpp @ 447:828d3dc89f2d

Socket: use own tests for SSL
author David Demelier <markand@malikania.fr>
date Wed, 28 Oct 2015 21:16:27 +0100
parents 8396fd66e57a
children 41d1a36cc461
line wrap: on
line source

/*
 * main.cpp -- test sockets
 *
 * Copyright (c) 2013-2015 David Demelier <markand@malikania.fr>
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

#include <chrono>
#include <iostream>
#include <sstream>
#include <string>
#include <thread>

#include <gtest/gtest.h>

#include <Sockets.h>

using namespace net;
using namespace std::literals::chrono_literals;

/* --------------------------------------------------------
 * TCP tests
 * -------------------------------------------------------- */

/*
 * Fixture for basic TCP server/client tests.
 *
 * Provides one server and one client socket plus a thread slot for each
 * side. Every test binds port 16000, so SO_REUSEADDR is set on the server
 * to let consecutive tests rebind it.
 */
class TcpServerTest : public testing::Test {
protected:
	SocketTcp<Ipv4> m_server;
	SocketTcp<Ipv4> m_client;

	std::thread m_tserver;
	std::thread m_tclient;

public:
	TcpServerTest()
	{
		m_server.set(SOL_SOCKET, SO_REUSEADDR, 1);
	}

	/*
	 * Wait for the server thread, then the client thread, before the
	 * sockets they use are destroyed.
	 */
	~TcpServerTest()
	{
		for (std::thread *worker : {&m_tserver, &m_tclient}) {
			if (worker->joinable()) {
				worker->join();
			}
		}
	}
};

/*
 * A client connects to a server on port 16000; the server accepts once and
 * both ends close.
 *
 * The server is bound and put into listening state on the test thread
 * before any worker starts. The connection is therefore queued by the
 * kernel even if the server thread has not reached accept() yet. The
 * previous version did bind/listen inside the server thread and relied on
 * a fixed 100ms sleep, so a slow server thread could make the client's
 * connect fail inside a std::thread.
 */
TEST_F(TcpServerTest, connect)
{
	m_server.bind(Ipv4{"*", 16000});
	m_server.listen();

	m_tserver = std::thread([this] () {
		m_server.accept();
		m_server.close();
	});

	m_tclient = std::thread([this] () {
		m_client.connect(Ipv4{"127.0.0.1", 16000});
		m_client.close();
	});
}

/*
 * Echo test: the client sends "hello world", the server checks it and
 * sends it back, and the client checks the echo.
 *
 * Bind/listen happen on the test thread before the workers start, so the
 * client can never try to connect before the port is ready. Previously this
 * was done inside the server thread and guarded only by a 100ms sleep.
 * The server uses EXPECT_EQ rather than ASSERT_EQ so that it always echoes
 * what it received and the client is not left waiting on recv().
 */
TEST_F(TcpServerTest, io)
{
	m_server.bind(Ipv4{"*", 16000});
	m_server.listen();

	m_tserver = std::thread([this] () {
		auto client = m_server.accept();
		auto msg = client.recv(512);

		EXPECT_EQ("hello world", msg);

		client.send(msg);
	});

	m_tclient = std::thread([this] () {
		m_client.connect(Ipv4{"127.0.0.1", 16000});
		m_client.send("hello world");

		ASSERT_EQ("hello world", m_client.recv(512));
	});
}

/* --------------------------------------------------------
 * UDP tests
 * -------------------------------------------------------- */

/*
 * Fixture for UDP datagram tests.
 *
 * Holds a server and a client datagram socket and one thread per side.
 * SO_REUSEADDR is enabled on the server so port 16000 can be rebound by
 * later tests.
 */
class UdpServerTest : public testing::Test {
protected:
	SocketUdp<Ipv4> m_server;
	SocketUdp<Ipv4> m_client;

	std::thread m_tserver;
	std::thread m_tclient;

public:
	UdpServerTest()
	{
		m_server.set(SOL_SOCKET, SO_REUSEADDR, 1);
	}

	/* Join the server thread first, then the client thread. */
	~UdpServerTest()
	{
		auto reap = [] (std::thread &worker) {
			if (worker.joinable())
				worker.join();
		};

		reap(m_tserver);
		reap(m_tclient);
	}
};

/*
 * UDP echo test: the client sends "hello world" to port 16000, the server
 * checks it and sends it back to the sender's address, and the client
 * checks the reply.
 *
 * The server is bound on the test thread before either worker starts.
 * Previously the bind ran inside the server thread behind a 100ms sleep.
 * A datagram sent before that bind is simply dropped, which would leave
 * both threads blocked in recvfrom() forever.
 * The server uses EXPECT_EQ so it always replies, even on a mismatch, and
 * the client never hangs waiting for the echo.
 */
TEST_F(UdpServerTest, io)
{
	Ipv4 address{"*", 16000};

	m_server.bind(address);

	m_tserver = std::thread([this] () {
		Ipv4 client;
		auto msg = m_server.recvfrom(512, &client);

		EXPECT_EQ("hello world", msg);

		m_server.sendto(msg, client);
		m_server.close();
	});

	m_tclient = std::thread([this] () {
		Ipv4 info{"127.0.0.1", 16000};

		m_client.sendto("hello world", info);

		ASSERT_EQ("hello world", m_client.recvfrom(512, &info));

		m_client.close();
	});
}

/* --------------------------------------------------------
 * Listener: set function
 * -------------------------------------------------------- */

/*
 * Fake Listener backend that records how set() is invoked.
 *
 * unset() and wait() do nothing; wait() always reports no events.
 */
class TestBackendSet {
public:
	int m_callcount{0};	/* number of set() calls so far */
	bool m_added{false};	/* 'add' argument of the most recent set() */
	int m_flags{0};		/* bitwise union of every flags argument seen */

	inline void set(const ListenerTable &, Handle, int flags, bool add) noexcept
	{
		m_flags = m_flags | flags;
		m_added = add;
		++m_callcount;
	}

	inline void unset(const ListenerTable &, Handle, int, bool) noexcept
	{
		/* intentionally a no-op */
	}

	std::vector<ListenerStatus> wait(const ListenerTable &, int)
	{
		return std::vector<ListenerStatus>{};
	}
};

/*
 * Fake Listener backend whose set() always throws, to check that the
 * Listener does not record a handle the backend failed to register.
 */
class TestBackendSetFail {
public:
	inline void set(const ListenerTable &, Handle, int, bool)
	{
		/* throws a const char *, caught by the test via catch (...) */
		throw "fail";
	}

	inline void unset(const ListenerTable &, Handle, int, bool) noexcept
	{
		/* intentionally a no-op */
	}

	std::vector<ListenerStatus> wait(const ListenerTable &, int)
	{
		return std::vector<ListenerStatus>();
	}
};

/*
 * The first set() on a new handle registers it with the backend exactly
 * once, with add == true and only the requested flag.
 *
 * ASSERT_EQ is used instead of ASSERT_TRUE(a == b) so that a failure
 * prints both values.
 */
TEST(ListenerSet, initialAdd)
{
	Listener<TestBackendSet> listener;
	Handle s{0};

	listener.set(s, FlagRead);

	ASSERT_EQ(1U, listener.size());
	ASSERT_EQ(1, listener.backend().m_callcount);
	ASSERT_TRUE(listener.backend().m_added);
	ASSERT_EQ(FlagRead, listener.backend().m_flags);
}

/*
 * Adding FlagWrite to a handle already registered for FlagRead keeps a
 * single table entry. The backend is called twice, and the second call
 * is not an add.
 *
 * The expected flag set is spelled FlagRead | FlagWrite rather than the
 * magic 0x3, and ASSERT_EQ reports both values on failure.
 */
TEST(ListenerSet, readThenWrite)
{
	Listener<TestBackendSet> listener;
	Handle s{0};

	listener.set(s, FlagRead);
	listener.set(s, FlagWrite);

	ASSERT_EQ(1U, listener.size());
	ASSERT_EQ(2, listener.backend().m_callcount);
	ASSERT_FALSE(listener.backend().m_added);
	ASSERT_EQ(FlagRead | FlagWrite, listener.backend().m_flags);
}

/*
 * Registering read and write together in one call hits the backend once,
 * as an add, with both flags.
 *
 * The expected flag set is spelled FlagRead | FlagWrite rather than the
 * magic 0x3, and ASSERT_EQ reports both values on failure.
 */
TEST(ListenerSet, allOneShot)
{
	Listener<TestBackendSet> listener;
	Handle s{0};

	listener.set(s, FlagRead | FlagWrite);

	ASSERT_EQ(1U, listener.size());
	ASSERT_EQ(1, listener.backend().m_callcount);
	ASSERT_TRUE(listener.backend().m_added);
	ASSERT_EQ(FlagRead | FlagWrite, listener.backend().m_flags);
}

/*
 * Setting the same flag twice does not call the backend a second time.
 *
 * ASSERT_EQ is used instead of ASSERT_TRUE(a == b) so that a failure
 * prints both values.
 */
TEST(ListenerSet, readTwice)
{
	Listener<TestBackendSet> listener;
	Handle s{0};

	listener.set(s, FlagRead);
	listener.set(s, FlagRead);

	ASSERT_EQ(1U, listener.size());
	ASSERT_EQ(1, listener.backend().m_callcount);
	ASSERT_TRUE(listener.backend().m_added);
	ASSERT_EQ(FlagRead, listener.backend().m_flags);
}

/*
 * If the backend's set() throws, the exception reaches the caller and the
 * handle is not added to the listener table.
 *
 * ASSERT_ANY_THROW replaces the hand-written try / FAIL() / catch (...)
 * pattern and produces a clear message when nothing is thrown.
 */
TEST(ListenerSet, failure)
{
	Listener<TestBackendSetFail> listener;
	Handle s{0};

	ASSERT_ANY_THROW(listener.set(s, FlagRead));

	ASSERT_EQ(0U, listener.size());
}

/* --------------------------------------------------------
 * Listener: unset / remove functions
 * -------------------------------------------------------- */

/*
 * Fake Listener backend that tracks set()/unset() calls.
 *
 * m_flags mirrors the flags currently registered: set() ORs flags in and
 * unset() clears them. m_removal keeps the 'remove' argument of the last
 * unset(). wait() always reports no events.
 */
class TestBackendUnset {
public:
	bool m_isset{false};
	bool m_isunset{false};
	int m_flags{0};
	bool m_removal{false};

	inline void set(const ListenerTable &, Handle &, int flags, bool) noexcept
	{
		m_flags = m_flags | flags;
		m_isset = true;
	}

	inline void unset(const ListenerTable &, Handle &, int flags, bool remove) noexcept
	{
		const int keep = ~(flags);

		m_flags = m_flags & keep;
		m_removal = remove;
		m_isunset = true;
	}

	std::vector<ListenerStatus> wait(const ListenerTable &, int) noexcept
	{
		return std::vector<ListenerStatus>{};
	}
};

/*
 * Fake Listener backend whose unset() always throws, to check that a
 * failed removal leaves the handle in the listener table.
 */
class TestBackendUnsetFail {
public:
	inline void set(const ListenerTable &, Handle &, int, bool) noexcept
	{
		/* accept every registration silently */
	}

	inline void unset(const ListenerTable &, Handle &, int, bool)
	{
		/* throws a const char *, caught by the test via catch (...) */
		throw "fail";
	}

	std::vector<ListenerStatus> wait(const ListenerTable &, int)
	{
		return std::vector<ListenerStatus>();
	}
};

/*
 * Unsetting the only flag of a handle removes it from the table and tells
 * the backend it is a removal.
 *
 * ASSERT_EQ is used instead of ASSERT_TRUE(a == b) so that a failure
 * prints both values.
 */
TEST(ListenerUnsetRemove, unset)
{
	Listener<TestBackendUnset> listener;
	Handle s{0};

	listener.set(s, FlagRead);
	listener.unset(s, FlagRead);

	ASSERT_EQ(0U, listener.size());
	ASSERT_TRUE(listener.backend().m_isset);
	ASSERT_TRUE(listener.backend().m_isunset);
	ASSERT_EQ(0, listener.backend().m_flags);
	ASSERT_TRUE(listener.backend().m_removal);
}

/*
 * Unsetting one of two flags keeps the handle registered with the other
 * flag and is not reported to the backend as a removal.
 *
 * ASSERT_EQ is used instead of ASSERT_TRUE(a == b) so that a failure
 * prints both values.
 */
TEST(ListenerUnsetRemove, unsetOne)
{
	Listener<TestBackendUnset> listener;
	Handle s{0};

	listener.set(s, FlagRead | FlagWrite);
	listener.unset(s, FlagRead);

	ASSERT_EQ(1U, listener.size());
	ASSERT_TRUE(listener.backend().m_isset);
	ASSERT_TRUE(listener.backend().m_isunset);
	ASSERT_EQ(FlagWrite, listener.backend().m_flags);
	ASSERT_FALSE(listener.backend().m_removal);
}

/*
 * Unsetting both flags one at a time removes the handle. The last unset()
 * is reported to the backend as a removal.
 *
 * ASSERT_EQ is used instead of ASSERT_TRUE(a == b) so that a failure
 * prints both values.
 */
TEST(ListenerUnsetRemove, unsetAll)
{
	Listener<TestBackendUnset> listener;
	Handle s{0};

	listener.set(s, FlagRead | FlagWrite);
	listener.unset(s, FlagRead);
	listener.unset(s, FlagWrite);

	ASSERT_EQ(0U, listener.size());
	ASSERT_TRUE(listener.backend().m_isset);
	ASSERT_TRUE(listener.backend().m_isunset);
	ASSERT_EQ(0, listener.backend().m_flags);
	ASSERT_TRUE(listener.backend().m_removal);
}

/*
 * remove() drops a handle and all of its flags in one call, as a removal.
 *
 * ASSERT_EQ is used instead of ASSERT_TRUE(a == b) so that a failure
 * prints both values.
 */
TEST(ListenerUnsetRemove, remove)
{
	Listener<TestBackendUnset> listener;
	Handle s{0};

	listener.set(s, FlagRead | FlagWrite);
	listener.remove(s);

	ASSERT_EQ(0U, listener.size());
	ASSERT_TRUE(listener.backend().m_isset);
	ASSERT_TRUE(listener.backend().m_isunset);
	ASSERT_EQ(0, listener.backend().m_flags);
	ASSERT_TRUE(listener.backend().m_removal);
}

/*
 * If the backend's unset() throws during remove(), the exception reaches
 * the caller and the handle stays in the table.
 *
 * ASSERT_ANY_THROW replaces the hand-written try / FAIL() / catch (...)
 * pattern and produces a clear message when nothing is thrown.
 */
TEST(ListenerUnsetRemove, failure)
{
	Listener<TestBackendUnsetFail> listener;
	Handle s{0};

	listener.set(s, FlagRead | FlagWrite);

	ASSERT_ANY_THROW(listener.remove(s));

	/* If removal fails, the handle is kept in the table */
	ASSERT_EQ(1U, listener.size());
}

/* --------------------------------------------------------
 * Listener: system
 * -------------------------------------------------------- */

/*
 * Fixture for Listener tests against the real Select backend.
 *
 * The master TCP socket is bound to port 16000 (with SO_REUSEADDR) and
 * listening before the test body starts. Both worker threads are joined on
 * teardown, server first.
 */
class ListenerTest : public testing::Test {
protected:
	Listener<Select> m_listener;
	SocketTcp<Ipv4> m_masterTcp;
	SocketTcp<Ipv4> m_clientTcp;

	std::thread m_tserver;
	std::thread m_tclient;

public:
	ListenerTest()
	{
		const Ipv4 endpoint{"*", 16000};

		m_masterTcp.set(SOL_SOCKET, SO_REUSEADDR, 1);
		m_masterTcp.bind(endpoint);
		m_masterTcp.listen();
	}

	~ListenerTest()
	{
		for (std::thread *worker : {&m_tserver, &m_tclient}) {
			if (worker->joinable()) {
				worker->join();
			}
		}
	}
};

/*
 * The listener reports the master socket as readable once a client
 * connects. The server then accepts that connection and closes the master
 * socket.
 */
TEST_F(ListenerTest, accept)
{
	auto server = [this] () {
		try {
			m_listener.set(m_masterTcp.handle(), FlagRead);
			m_listener.wait();
			m_masterTcp.accept();
			m_masterTcp.close();
		} catch (const std::exception &error) {
			FAIL() << error.what();
		}
	};

	auto client = [this] () {
		m_clientTcp.connect(Ipv4{"127.0.0.1", 16000});
	};

	m_tserver = std::thread(server);
	std::this_thread::sleep_for(100ms);
	m_tclient = std::thread(client);
}

/*
 * After the listener signals an incoming connection, the accepted socket
 * receives the "hello" sent by the client.
 */
TEST_F(ListenerTest, recv)
{
	auto server = [this] () {
		try {
			m_listener.set(m_masterTcp.handle(), FlagRead);
			m_listener.wait();

			auto peer = m_masterTcp.accept();

			ASSERT_EQ("hello", peer.recv(512));
		} catch (const std::exception &error) {
			FAIL() << error.what();
		}
	};

	auto client = [this] () {
		m_clientTcp.connect(Ipv4{"127.0.0.1", 16000});
		m_clientTcp.send("hello");
	};

	m_tserver = std::thread(server);
	std::this_thread::sleep_for(100ms);
	m_tclient = std::thread(client);
}

/* --------------------------------------------------------
 * Non-blocking connect
 * -------------------------------------------------------- */

/*
 * Fixture for non-blocking connect tests.
 *
 * The client socket is switched out of blocking mode when the fixture is
 * constructed. Both worker threads are joined on teardown, server first.
 */
class NonBlockingConnectTest : public testing::Test {
protected:
	SocketTcp<Ipv4> m_server;
	SocketTcp<Ipv4> m_client;

	std::thread m_tserver;
	std::thread m_tclient;

public:
	NonBlockingConnectTest()
	{
		constexpr bool blocking = false;

		m_client.setBlockMode(blocking);
	}

	~NonBlockingConnectTest()
	{
		auto finish = [] (std::thread &worker) {
			if (worker.joinable())
				worker.join();
		};

		finish(m_tserver);
		finish(m_tclient);
	}
};

/* --------------------------------------------------------
 * TCP accept
 * -------------------------------------------------------- */

/*
 * Fixture for TCP accept tests.
 *
 * The server is bound to port 16000 (with SO_REUSEADDR) and listening
 * before the test body runs. Both worker threads are joined on teardown,
 * server first.
 */
class TcpAcceptTest : public testing::Test {
protected:
	SocketTcp<Ipv4> m_server;
	SocketTcp<Ipv4> m_client;

	std::thread m_tserver;
	std::thread m_tclient;

public:
	TcpAcceptTest()
	{
		const Ipv4 endpoint{"*", 16000};

		m_server.set(SOL_SOCKET, SO_REUSEADDR, 1);
		m_server.bind(endpoint);
		m_server.listen();
	}

	~TcpAcceptTest()
	{
		auto finish = [] (std::thread &worker) {
			if (worker.joinable())
				worker.join();
		};

		finish(m_tserver);
		finish(m_tclient);
	}
};

/* --------------------------------------------------------
 * TCP recv
 * -------------------------------------------------------- */

/*
 * Fixture for TCP receive tests.
 *
 * The server is bound to port 16000 (with SO_REUSEADDR) and listening
 * before the test body runs. Both worker threads are joined on teardown,
 * server first.
 */
class TcpRecvTest : public testing::Test {
protected:
	SocketTcp<Ipv4> m_server;
	SocketTcp<Ipv4> m_client;

	std::thread m_tserver;
	std::thread m_tclient;

public:
	TcpRecvTest()
	{
		const Ipv4 endpoint{"*", 16000};

		m_server.set(SOL_SOCKET, SO_REUSEADDR, 1);
		m_server.bind(endpoint);
		m_server.listen();
	}

	~TcpRecvTest()
	{
		for (std::thread *worker : {&m_tserver, &m_tclient}) {
			if (worker->joinable()) {
				worker->join();
			}
		}
	}
};

/*
 * A blocking recv() on the accepted connection returns the "hello" that
 * the client sent before closing.
 */
TEST_F(TcpRecvTest, blockingSuccess)
{
	auto serve = [this] () {
		auto peer = m_server.accept();

		ASSERT_EQ("hello", peer.recv(32));
	};

	auto talk = [this] () {
		m_client.connect(Ipv4{"127.0.0.1", 16000});
		m_client.send("hello");
		m_client.close();
	};

	m_tserver = std::thread(serve);
	std::this_thread::sleep_for(100ms);
	m_tclient = std::thread(talk);
}

/* --------------------------------------------------------
 * TLS sockets (SocketTls)
 * -------------------------------------------------------- */

/*
 * Fixture for TLS receive tests.
 *
 * The server is given "Socket/test.key" and "Socket/test.crt" (presumably
 * its private key and certificate). The client is given only
 * "Socket/test.crt". These paths are relative, so the test must run from
 * the directory that contains "Socket/". The server is bound to port 16000
 * (with SO_REUSEADDR) and listening before the test body runs.
 */
class TlsRecvTest : public testing::Test {
protected:
	SocketTls<Ipv4> m_server;
	SocketTls<Ipv4> m_client;

	std::thread m_tserver;
	std::thread m_tclient;

public:
	TlsRecvTest()
		: m_server{Ipv4{}, Tls{Tls::Tlsv1, false, "Socket/test.key", "Socket/test.crt"}}
		, m_client{Ipv4{}, Tls{Tls::Tlsv1, false, "", "Socket/test.crt"}}
	{
		const Ipv4 endpoint{"*", 16000};

		m_server.set(SOL_SOCKET, SO_REUSEADDR, 1);
		m_server.bind(endpoint);
		m_server.listen();
	}

	~TlsRecvTest()
	{
		auto finish = [] (std::thread &worker) {
			if (worker.joinable())
				worker.join();
		};

		finish(m_tserver);
		finish(m_tclient);
	}
};

/*
 * Over TLS, the accepted connection receives the "hello" sent by the
 * client. Exceptions in either thread are reported as test failures
 * instead of escaping the thread.
 */
TEST_F(TlsRecvTest, blockingSuccess)
{
	auto serve = [this] () {
		try {
			auto peer = m_server.accept();

			ASSERT_EQ("hello", peer.recv(32));
		} catch (const std::exception &error) {
			FAIL() << error.what();
		}
	};

	auto talk = [this] () {
		try {
			m_client.connect(Ipv4{"127.0.0.1", 16000});
			m_client.send("hello");
			m_client.close();
		} catch (const std::exception &error) {
			FAIL() << error.what();
		}
	};

	m_tserver = std::thread(serve);
	std::this_thread::sleep_for(100ms);
	m_tclient = std::thread(talk);
}

/*
 * Test entry point: let Google Test consume its own command-line flags,
 * run every registered test and return the framework's exit status.
 */
int main(int argc, char **argv)
{
	testing::InitGoogleTest(&argc, argv);

	const int status = RUN_ALL_TESTS();

	return status;
}