view C++/Socket.cpp @ 296:5806c767aec7

Zip: fix resize if read less than requested Task: #312
author David Demelier <markand@malikania.fr>
date Thu, 13 Nov 2014 21:10:13 +0100
parents 9b3270513f40
children 836903141476 24085fae3162
line wrap: on
line source

/*
 * Socket.cpp -- portable C++ socket wrappers
 *
 * Copyright (c) 2013, David Demelier <markand@malikania.fr>
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

#include <cstring>

#include "Socket.h"
#include "SocketAddress.h"

/* --------------------------------------------------------
 * Socket exceptions
 * -------------------------------------------------------- */

namespace error {

/*
 * Every exception below formats its full message, "<func>: <description>",
 * once at construction time. The by-value func parameter is extended in place
 * and then stored, so what() only has to expose the stored text.
 */

Timeout::Timeout(std::string func)
	: m_error(func.append(": Timeout exception"))
{
}

const char *Timeout::what() const noexcept
{
	// Pointer stays valid for the lifetime of the exception object.
	return m_error.c_str();
}

InProgress::InProgress(std::string func)
	: m_error(func.append(": Operation in progress"))
{
}

const char *InProgress::what() const noexcept
{
	// Pointer stays valid for the lifetime of the exception object.
	return m_error.c_str();
}

WouldBlock::WouldBlock(std::string func)
	: m_error(func.append(": Operation would block"))
{
}

const char *WouldBlock::what() const noexcept
{
	// Pointer stays valid for the lifetime of the exception object.
	return m_error.c_str();
}

Failure::Failure(std::string func, std::string message)
	: m_error(func.append(": ").append(message))
{
}

const char *Failure::what() const noexcept
{
	// Pointer stays valid for the lifetime of the exception object.
	return m_error.c_str();
}

} // !error

/* --------------------------------------------------------
 * Socket implementation
 * -------------------------------------------------------- */

/*
 * Prepare the platform socket layer before any socket is created.
 *
 * Only Windows needs this (Winsock 2.2 via WSAStartup); on other platforms
 * the function does nothing. The WSAStartup result is not checked.
 */
void Socket::init()
{
#ifdef _WIN32
	WSADATA wsaData;

	(void)WSAStartup(MAKEWORD(2, 2), &wsaData);
#endif
}

/*
 * Counterpart of init(): tears down the platform socket layer.
 *
 * Only Windows has anything to release (WSACleanup); its result is ignored.
 * On other platforms the function does nothing.
 */
void Socket::finish()
{
#ifdef _WIN32
	(void)WSACleanup();
#endif
}

/* --------------------------------------------------------
 * System dependent code
 * -------------------------------------------------------- */

#if defined(_WIN32)

/*
 * Convert a Windows/Winsock error code into a human readable message.
 *
 * FORMAT_MESSAGE_IGNORE_INSERTS is passed because system message texts may
 * contain "%1"-style inserts for which no arguments are supplied here.
 * FormatMessage terminates system messages with "\r\n"; that line ending is
 * stripped so the text embeds cleanly into exception messages.
 *
 * Returns "Unknown error" when no message text is available for errn.
 */
std::string Socket::syserror(int errn)
{
	LPSTR str = nullptr;
	std::string errmsg = "Unknown error";

	DWORD length = FormatMessageA(
		FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
		NULL,
		static_cast<DWORD>(errn),
		MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
		(LPSTR)&str, 0, NULL);

	if (str) {
		std::string text(str, length);

		// The buffer was allocated by FormatMessage (ALLOCATE_BUFFER).
		LocalFree(str);

		// Drop the trailing "\r\n" (and any spaces before it).
		while (!text.empty() && (text.back() == '\n' || text.back() == '\r' || text.back() == ' '))
			text.pop_back();

		if (!text.empty())
			errmsg = text;
	}

	return errmsg;
}

#else

#include <cerrno>

/*
 * Convert an errno value into a human readable message using strerror().
 *
 * NOTE(review): strerror() is not required to be thread-safe by POSIX.
 */
std::string Socket::syserror(int errn)
{
	return strerror(errn);
}

#endif

/*
 * Return the message for the last error reported by the platform's socket
 * layer: WSAGetLastError() on Windows, errno elsewhere.
 */
std::string Socket::syserror()
{
#if !defined(_WIN32)
	const int code = errno;
#else
	const int code = WSAGetLastError();
#endif

	return syserror(code);
}

/* --------------------------------------------------------
 * Standard clear implementation
 * -------------------------------------------------------- */

/*
 * Standard -- clear (plain) SocketInterface implementation.
 *
 * Each operation forwards directly to the corresponding BSD / Winsock call
 * on the socket's native handle and converts failures into error::*
 * exceptions.
 */
class Standard final : public SocketInterface {
public:
	/* Connection setup */
	void bind(Socket &s, const SocketAddress &address) override;
	void listen(Socket &s, int max) override;
	void connect(Socket &s, const SocketAddress &address) override;
	Socket accept(Socket &s, SocketAddress &info) override;

	/* Connected I/O */
	unsigned recv(Socket &s, void *data, unsigned len) override;
	unsigned send(Socket &s, const void *data, unsigned len) override;

	/* Addressed (datagram) I/O */
	unsigned recvfrom(Socket &s, void *data, unsigned len, SocketAddress &info) override;
	unsigned sendto(Socket &s, const void *data, unsigned len, const SocketAddress &info) override;

	/* Teardown */
	void close(Socket &s) override;
};

/*
 * Bind the socket to the local address addr.
 *
 * Throws error::Failure if ::bind reports SOCKET_ERROR.
 */
void Standard::bind(Socket &s, const SocketAddress &addr)
{
	const auto &local = addr.address();
	const auto localLen = addr.length();

	if (::bind(s.handle(), (sockaddr *)&local, localLen) != SOCKET_ERROR)
		return;

	throw error::Failure("bind", Socket::syserror());
}

/*
 * Connect the socket to the remote address addr.
 *
 * Throws error::InProgress when a non-blocking connect has been started but
 * cannot complete immediately (WSAEWOULDBLOCK on Windows, EINPROGRESS
 * elsewhere). Throws error::Failure for any other error.
 */
void Standard::connect(Socket &s, const SocketAddress &addr)
{
	const auto &remote = addr.address();
	const auto remoteLen = addr.length();

	if (::connect(s.handle(), (sockaddr *)&remote, remoteLen) != SOCKET_ERROR)
		return;

	// A non-blocking connect reports "pending" through a dedicated code.
#if defined(_WIN32)
	const bool pending = (WSAGetLastError() == WSAEWOULDBLOCK);
#else
	const bool pending = (errno == EINPROGRESS);
#endif

	if (pending)
		throw error::InProgress("connect");

	throw error::Failure("connect", Socket::syserror());
}

/*
 * Accept a pending connection on the listening socket s.
 *
 * On success the peer address is written to info and the new connection
 * is returned as a Socket that also uses the Standard interface.
 *
 * Throws error::WouldBlock if s is non-blocking and nothing is pending.
 * Throws error::Failure for any other error.
 */
Socket Standard::accept(Socket &s, SocketAddress &info)
{
	sockaddr_storage peer;
	socklen_t peerLen = sizeof (sockaddr_storage);

	Socket::Handle client = ::accept(s.handle(), (sockaddr *)&peer, &peerLen);

	if (client == INVALID_SOCKET) {
#if defined(_WIN32)
		const bool wouldBlock = (WSAGetLastError() == WSAEWOULDBLOCK);
#else
		const bool wouldBlock = (errno == EAGAIN || errno == EWOULDBLOCK);
#endif

		if (wouldBlock)
			throw error::WouldBlock("accept");

		throw error::Failure("accept", Socket::syserror());
	}

	// Usually accept works only with SOCK_STREAM
	info = SocketAddress(peer, peerLen);

	return Socket(client, std::make_shared<Standard>());
}

/*
 * Put the socket into listening mode with a backlog of max pending
 * connections.
 *
 * Throws error::Failure if ::listen reports SOCKET_ERROR.
 */
void Standard::listen(Socket &s, int max)
{
	const bool failed = (::listen(s.handle(), max) == SOCKET_ERROR);

	if (failed)
		throw error::Failure("listen", Socket::syserror());
}

/*
 * Receive at most dataLen bytes into data.
 *
 * Returns the byte count reported by ::recv. A value of 0 is passed through
 * unchanged; for stream sockets that is how ::recv reports an orderly
 * shutdown.
 *
 * Throws error::WouldBlock if the socket is non-blocking and no data is
 * available. Throws error::Failure for any other error.
 */
unsigned Standard::recv(Socket &s, void *data, unsigned dataLen)
{
	const int received = ::recv(s.handle(), (Socket::Arg)data, dataLen, 0);

	if (received != SOCKET_ERROR)
		return (unsigned)received;

#if defined(_WIN32)
	const bool wouldBlock = (WSAGetLastError() == WSAEWOULDBLOCK);
#else
	const bool wouldBlock = (errno == EAGAIN || errno == EWOULDBLOCK);
#endif

	if (wouldBlock)
		throw error::WouldBlock("recv");

	throw error::Failure("recv", Socket::syserror());
}

/*
 * Send up to dataLen bytes from data.
 *
 * Returns the byte count reported by ::send, which may be smaller than
 * dataLen.
 *
 * Throws error::WouldBlock if the socket is non-blocking and the send cannot
 * proceed now. Throws error::Failure for any other error.
 */
unsigned Standard::send(Socket &s, const void *data, unsigned dataLen)
{
	const int sent = ::send(s.handle(), (Socket::ConstArg)data, dataLen, 0);

	if (sent != SOCKET_ERROR)
		return (unsigned)sent;

#if defined(_WIN32)
	const bool wouldBlock = (WSAGetLastError() == WSAEWOULDBLOCK);
#else
	const bool wouldBlock = (errno == EAGAIN || errno == EWOULDBLOCK);
#endif

	if (wouldBlock)
		throw error::WouldBlock("send");

	throw error::Failure("send", Socket::syserror());
}

/*
 * Receive at most dataLen bytes into data and report the sender's address
 * through info.
 *
 * info is only updated once ::recvfrom has succeeded. On failure the
 * address buffer was never filled in, and building a SocketAddress from it
 * would read uninitialized memory. It could also disturb errno /
 * WSAGetLastError() before they are inspected below.
 *
 * Returns the byte count reported by ::recvfrom.
 * Throws error::WouldBlock if the socket is non-blocking and no data is
 * available. Throws error::Failure for any other error.
 */
unsigned Standard::recvfrom(Socket &s, void *data, unsigned dataLen, SocketAddress &info)
{
	int nbread;

	// Sender address; zeroed so a call that succeeds without filling it in
	// never exposes indeterminate bytes through info.
	sockaddr_storage address{};
	socklen_t addrlen;

	addrlen = sizeof (struct sockaddr_storage);
	nbread = ::recvfrom(s.handle(), (Socket::Arg)data, dataLen, 0, (sockaddr *)&address, &addrlen);

	// Inspect the error code first, before anything else can overwrite it.
	if (nbread == SOCKET_ERROR) {
#if defined(_WIN32)
		if (WSAGetLastError() == WSAEWOULDBLOCK)
			throw error::WouldBlock("recvfrom");

		throw error::Failure("recvfrom", Socket::syserror());
#else
		if (errno == EAGAIN || errno == EWOULDBLOCK)
			throw error::WouldBlock("recvfrom");

		throw error::Failure("recvfrom", Socket::syserror());
#endif
	}

	info = SocketAddress(address, addrlen);

	return (unsigned)nbread;
}

/*
 * Send up to dataLen bytes from data to the destination address info.
 *
 * Returns the byte count reported by ::sendto.
 * Throws error::WouldBlock if the socket is non-blocking and the send cannot
 * proceed now. Throws error::Failure for any other error.
 */
unsigned Standard::sendto(Socket &s, const void *data, unsigned dataLen, const SocketAddress &info)
{
	const int sent = ::sendto(s.handle(), (Socket::ConstArg)data, dataLen, 0,
	    (const sockaddr *)&info.address(), info.length());

	if (sent != SOCKET_ERROR)
		return (unsigned)sent;

#if defined(_WIN32)
	const bool wouldBlock = (WSAGetLastError() == WSAEWOULDBLOCK);
#else
	const bool wouldBlock = (errno == EAGAIN || errno == EWOULDBLOCK);
#endif

	if (wouldBlock)
		throw error::WouldBlock("sendto");

	throw error::Failure("sendto", Socket::syserror());
}

/*
 * Close the socket's native handle. Any error from closesocket is
 * deliberately ignored.
 */
void Standard::close(Socket &s)
{
	const auto native = s.handle();

	static_cast<void>(closesocket(native));
}

/* --------------------------------------------------------
 * Socket code
 * -------------------------------------------------------- */

/*
 * Create an empty socket that uses the Standard (clear) interface.
 *
 * The handle starts out as INVALID_SOCKET. Without this, handle(),
 * operator== and operator< would read an indeterminate value on a
 * default-constructed Socket.
 */
Socket::Socket()
	: m_interface(std::make_shared<Standard>())
	, m_handle(INVALID_SOCKET)
{
}

/*
 * Create a new native socket with socket(domain, type, protocol), using the
 * Standard interface.
 *
 * Throws error::Failure if the socket could not be created.
 */
Socket::Socket(int domain, int type, int protocol)
	: Socket()
{
	// Store first, then check: the member holds the result even when we throw.
	m_handle = socket(domain, type, protocol);

	if (m_handle != INVALID_SOCKET)
		return;

	throw error::Failure("socket", syserror());
}

Socket::Socket(Handle handle, std::shared_ptr<SocketInterface> iface)
	: m_interface(std::move(iface))
	, m_handle(handle)
{
}

/*
 * Return the native handle wrapped by this socket.
 */
Socket::Handle Socket::handle() const
{
	return this->m_handle;
}

/*
 * Switch the socket between blocking (block == true) and non-blocking
 * (block == false) mode.
 *
 * On POSIX the O_NONBLOCK file status flag is toggled and every other flag
 * is preserved. If the current flags cannot be read, error::Failure is
 * thrown rather than assuming 0. Writing 0-based flags back would silently
 * clear every other status flag on the descriptor.
 *
 * On Windows (or without O_NONBLOCK) FIONBIO is set through ioctlsocket.
 *
 * Throws error::Failure on error.
 */
void Socket::blockMode(bool block)
{
#if defined(O_NONBLOCK) && !defined(_WIN32)
	int flags = fcntl(m_handle, F_GETFL, 0);

	if (flags == -1)
		throw error::Failure("blockMode", Socket::syserror());

	if (block)
		flags &= ~(O_NONBLOCK);
	else
		flags |= O_NONBLOCK;

	if (fcntl(m_handle, F_SETFL, flags) == -1)
		throw error::Failure("blockMode", Socket::syserror());
#else
	unsigned long flags = (block) ? 0 : 1;

	if (ioctlsocket(m_handle, FIONBIO, &flags) == SOCKET_ERROR)
		throw error::Failure("blockMode", Socket::syserror());
#endif
}

/*
 * Accept a pending connection through this socket's interface. The peer
 * address is collected but not reported to the caller.
 */
Socket Socket::accept()
{
	SocketAddress ignoredPeer;

	return m_interface->accept(*this, ignoredPeer);
}

/*
 * Convenience overload: send the bytes of message.
 * Returns the byte count reported by the buffer-based send().
 */
unsigned Socket::send(const std::string &message)
{
	const char *bytes = message.c_str();

	return Socket::send(bytes, message.length());
}

unsigned Socket::recvfrom(void *data, unsigned dataLen)
{
	SocketAddress dummy;

	return m_interface->recvfrom(*this, data, dataLen, dummy);
}

/*
 * Convenience overload: send the bytes of message to the address info.
 * Returns the byte count reported by the buffer-based sendto().
 */
unsigned Socket::sendto(const std::string &message, const SocketAddress &info)
{
	const char *bytes = message.c_str();

	return sendto(bytes, message.length(), info);
}

/*
 * Two sockets are equal when they wrap the same native handle; their
 * interfaces are not compared.
 */
bool operator==(const Socket &s1, const Socket &s2)
{
	const auto lhs = s1.handle();
	const auto rhs = s2.handle();

	return lhs == rhs;
}

/*
 * Order sockets by their native handle value (e.g. for use as keys in
 * ordered containers).
 */
bool operator<(const Socket &s1, const Socket &s2)
{
	return s2.handle() > s1.handle();
}