changeset 477:65b567c8de54

Base64: - Add some asserts, - Add new isValid function, - Throw more errors on invalid strings.
author David Demelier <markand@malikania.fr>
date Tue, 10 Nov 2015 13:32:51 +0100
parents 1ff22c1cb32e
children 453f22449b33
files C++/modules/Base64/Base64.cpp C++/modules/Base64/Base64.h C++/tests/Base64/main.cpp
diffstat 3 files changed, 159 insertions(+), 19 deletions(-) [+]
line wrap: on
line diff
--- a/C++/modules/Base64/Base64.cpp	Tue Nov 10 11:08:38 2015 +0100
+++ b/C++/modules/Base64/Base64.cpp	Tue Nov 10 13:32:51 2015 +0100
@@ -16,35 +16,105 @@
  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  */
 
+#include <cassert>
 #include <iterator>
 #include <sstream>
+#include <unordered_map>
 
 #include "Base64.h"
 
 namespace base64 {
 
-char lookup(int value) noexcept
+namespace {
+
+const std::string table{"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"};
+
+const std::unordered_map<unsigned char, unsigned int> rtable{
+	{ 'A', 0  },
+	{ 'B', 1  },
+	{ 'C', 2  },
+	{ 'D', 3  },
+	{ 'E', 4  },
+	{ 'F', 5  },
+	{ 'G', 6  },
+	{ 'H', 7  },
+	{ 'I', 8  },
+	{ 'J', 9  },
+	{ 'K', 10 },
+	{ 'L', 11 },
+	{ 'M', 12 },
+	{ 'N', 13 },
+	{ 'O', 14 },
+	{ 'P', 15 },
+	{ 'Q', 16 },
+	{ 'R', 17 },
+	{ 'S', 18 },
+	{ 'T', 19 },
+	{ 'U', 20 },
+	{ 'V', 21 },
+	{ 'W', 22 },
+	{ 'X', 23 },
+	{ 'Y', 24 },
+	{ 'Z', 25 },
+	{ 'a', 26 },
+	{ 'b', 27 },
+	{ 'c', 28 },
+	{ 'd', 29 },
+	{ 'e', 30 },
+	{ 'f', 31 },
+	{ 'g', 32 },
+	{ 'h', 33 },
+	{ 'i', 34 },
+	{ 'j', 35 },
+	{ 'k', 36 },
+	{ 'l', 37 },
+	{ 'm', 38 },
+	{ 'n', 39 },
+	{ 'o', 40 },
+	{ 'p', 41 },
+	{ 'q', 42 },
+	{ 'r', 43 },
+	{ 's', 44 },
+	{ 't', 45 },
+	{ 'u', 46 },
+	{ 'v', 47 },
+	{ 'w', 48 },
+	{ 'x', 49 },
+	{ 'y', 50 },
+	{ 'z', 51 },
+	{ '0', 52 },
+	{ '1', 53 },
+	{ '2', 54 },
+	{ '3', 55 },
+	{ '4', 56 },
+	{ '5', 57 },
+	{ '6', 58 },
+	{ '7', 59 },
+	{ '8', 60 },
+	{ '9', 61 },
+	{ '+', 62 },
+	{ '/', 63 }
+};
+
+} // !namespace
+
+unsigned char lookup(unsigned int value) noexcept
 {
-	static const char table[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+	assert(value < 64);
 
 	return table[value];
 }
 
-int rlookup(char ch)
+unsigned int rlookup(unsigned char ch)
 {
-	if (ch == '+')
-		return 62;
-	if (ch == '/')
-		return 63;
+	assert(rtable.count(ch) != 0 && ch != '=');
 
-	if (ch >= '0' && ch <= '9')
-		return ch + 4;
-	if (ch >= 'A' && ch <= 'Z')
-		return ch - 65;
-	if (ch >= 'a' && ch <= 'z')
-		return ch - 71;
+	return rtable.at(ch);
+}
 
-	throw std::invalid_argument("not a valid base64 string");
+bool isValid(unsigned char ch) noexcept
+{
+	return ch == '=' || rtable.count(ch);
 }
 
 std::string encode(const std::string &input)
--- a/C++/modules/Base64/Base64.h	Tue Nov 10 11:08:38 2015 +0100
+++ b/C++/modules/Base64/Base64.h	Tue Nov 10 13:32:51 2015 +0100
@@ -32,16 +32,26 @@
 /**
  * Get the base 64 character from the 6-bits value.
  *
+ * @pre value must be valid
  * @param value the value
  */
-char lookup(int value) noexcept;
+unsigned char lookup(unsigned int value) noexcept;
 
 /**
  * Get the integer value from the base 64 character.
  *
+ * @pre ch must be a valid base 64 character but not '='
  * @param ch the base64 character
  */
-int rlookup(char ch);
+unsigned int rlookup(unsigned char ch);
+
+/**
+ * Check if the given character is a valid base 64 character.
+ *
+ * @param char the character
+ * @return true if the character is valid
+ */
+bool isValid(unsigned char) noexcept;
 
 /**
  * Encode the input to the output. Requirements:
@@ -60,8 +70,9 @@
 		char inputbuf[3] = { 0, 0, 0 };
 		int count;
 
-		for (count = 0; count < 3 && input != end; ++count)
+		for (count = 0; count < 3 && input != end; ++count) {
 			inputbuf[count] = *input++;
+		}
 
 		*output++ = lookup(inputbuf[0] >> 2 & 0x3f);
 		*output++ = lookup((inputbuf[0] << 4 & 0x3f) | (inputbuf[1] >> 4 & 0x0f));
@@ -87,11 +98,24 @@
 OutputIt decode(InputIt input, InputIt end, OutputIt output)
 {
 	while (input != end) {
-		char inputbuf[4] = { 0, 0, 0, 0 };
+		char inputbuf[4] = { -1, -1, -1, -1 };
 		int count;
 
 		for (count = 0; count < 4 && input != end; ++count) {
-			inputbuf[count] = (*input == '=') ? -1: rlookup(*input);
+			if (*input != '=') {
+				/* Check if the character is valid and get its value */
+				if (isValid(*input)) {
+					inputbuf[count] = rlookup(*input);
+				} else {
+					throw std::invalid_argument{"invalid base64 string"};
+				}
+			} else {
+				/* A base 64 string cannot start with "=" or "==" */
+				if (count == 0 || count == 1) {
+					throw std::invalid_argument{"invalid or truncated base64 string"};
+				}
+			}
+			
 			input++;
 		}
 
--- a/C++/tests/Base64/main.cpp	Tue Nov 10 11:08:38 2015 +0100
+++ b/C++/tests/Base64/main.cpp	Tue Nov 10 13:32:51 2015 +0100
@@ -174,6 +174,52 @@
 
 	ASSERT_EQ("hello", base64::decode("aGVsbG8="));
 	ASSERT_EQ("this is a long sentence", base64::decode("dGhpcyBpcyBhIGxvbmcgc2VudGVuY2U="));
+	ASSERT_EQ("Welcome to our server dude", base64::decode("V2VsY29tZSB0byBvdXIgc2VydmVyIGR1ZGU="));
+}
+
+TEST(Failure, truncated)
+{
+	try {
+		base64::decode("YW=");
+
+		FAIL() << "exception expected";
+	} catch (...) {}
+}
+
+TEST(Failure, invalid)
+{
+	try {
+		base64::decode("?!");
+
+		FAIL() << "exception expected";
+	} catch (...) {}
+}
+
+TEST(Failure, wrong1)
+{
+	try {
+		base64::decode("=ABC");
+
+		FAIL() << "exception expected";
+	} catch (...) {}
+}
+
+TEST(Failure, wrong2)
+{
+	try {
+		base64::decode("A=BC");
+
+		FAIL() << "exception expected";
+	} catch (...) {}
+}
+
+TEST(Failure, wrong3)
+{
+	try {
+		base64::decode("==BC");
+
+		FAIL() << "exception expected";
+	} catch (...) {}
 }
 
 int main(int argc, char **argv)