00001 #include "factory.h"
00002 #include "integer.h"
00003 #include "filters.h"
00004 #include "hex.h"
00005 #include "randpool.h"
00006 #include "files.h"
00007 #include "trunhash.h"
00008 #include <iostream>
00009 #include <memory>
00010
00011 USING_NAMESPACE(CryptoPP)
00012 USING_NAMESPACE(std)
00013
00014 RandomPool & GlobalRNG();
00015 void RegisterFactories();
00016
00017 typedef std::map<std::string, std::string> TestData;
00018
00019 class TestFailure : public Exception
00020 {
00021 public:
00022 TestFailure() : Exception(OTHER_ERROR, "Validation test failed") {}
00023 };
00024
00025 static const TestData *s_currentTestData = NULL;
00026
00027 void OutputTestData(const TestData &v)
00028 {
00029 for (TestData::const_iterator i = v.begin(); i != v.end(); ++i)
00030 {
00031 cerr << i->first << ": " << i->second << endl;
00032 }
00033 }
00034
00035 void SignalTestFailure()
00036 {
00037 OutputTestData(*s_currentTestData);
00038 throw TestFailure();
00039 }
00040
00041 void SignalTestError()
00042 {
00043 OutputTestData(*s_currentTestData);
00044 throw Exception(Exception::OTHER_ERROR, "Unexpected error during validation test");
00045 }
00046
00047 class TestDataNameValuePairs : public NameValuePairs
00048 {
00049 public:
00050 TestDataNameValuePairs(const TestData &data) : m_data(data) {}
00051
00052 virtual bool GetVoidValue(const char *name, const std::type_info &valueType, void *pValue) const
00053 {
00054 TestData::const_iterator i = m_data.find(name);
00055 if (i == m_data.end())
00056 return false;
00057
00058 const std::string &value = i->second;
00059
00060 if (valueType == typeid(int))
00061 *reinterpret_cast<int *>(pValue) = atoi(value.c_str());
00062 else if (valueType == typeid(Integer))
00063 *reinterpret_cast<Integer *>(pValue) = Integer((std::string(value) + "h").c_str());
00064 else
00065 throw ValueTypeMismatch(name, typeid(std::string), valueType);
00066
00067 return true;
00068 }
00069
00070 private:
00071 const TestData &m_data;
00072 };
00073
00074 const std::string & GetRequiredDatum(const TestData &data, const char *name)
00075 {
00076 TestData::const_iterator i = data.find(name);
00077 if (i == data.end())
00078 SignalTestError();
00079 return i->second;
00080 }
00081
00082 void PutDecodedDatumInto(const TestData &data, const char *name, BufferedTransformation &target)
00083 {
00084 std::string s1 = GetRequiredDatum(data, name), s2;
00085
00086 int repeat = 1;
00087 if (s1[0] == 'r')
00088 {
00089 repeat = atoi(s1.c_str()+1);
00090 s1 = s1.substr(s1.find(' ')+1);
00091 }
00092
00093 if (s1[0] == '\"')
00094 s2 = s1.substr(1, s1.find('\"', 1)-1);
00095 else if (s1.substr(0, 2) == "0x")
00096 StringSource(s1.substr(2), true, new HexDecoder(new StringSink(s2)));
00097 else
00098 StringSource(s1, true, new HexDecoder(new StringSink(s2)));
00099
00100 while (repeat--)
00101 target.Put((const byte *)s2.data(), s2.size());
00102 }
00103
00104 std::string GetDecodedDatum(const TestData &data, const char *name)
00105 {
00106 std::string s;
00107 PutDecodedDatumInto(data, name, StringSink(s).Ref());
00108 return s;
00109 }
00110
00111 void TestKeyPairValidAndConsistent(CryptoMaterial &pub, const CryptoMaterial &priv)
00112 {
00113 if (!pub.Validate(GlobalRNG(), 3))
00114 SignalTestFailure();
00115 if (!priv.Validate(GlobalRNG(), 3))
00116 SignalTestFailure();
00117
00118
00119
00120
00121
00122
00123
00124
00125 }
00126
00127 void TestSignatureScheme(TestData &v)
00128 {
00129 std::string name = GetRequiredDatum(v, "Name");
00130 std::string test = GetRequiredDatum(v, "Test");
00131
00132 std::auto_ptr<PK_Signer> signer(ObjectFactoryRegistry<PK_Signer>::Registry().CreateObject(name.c_str()));
00133 std::auto_ptr<PK_Verifier> verifier(ObjectFactoryRegistry<PK_Verifier>::Registry().CreateObject(name.c_str()));
00134
00135 TestDataNameValuePairs pairs(v);
00136 std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
00137
00138 if (keyFormat == "DER")
00139 verifier->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
00140 else if (keyFormat == "Component")
00141 verifier->AccessMaterial().AssignFrom(pairs);
00142
00143 if (test == "Verify" || test == "NotVerify")
00144 {
00145 VerifierFilter verifierFilter(*verifier, NULL, VerifierFilter::SIGNATURE_AT_BEGIN);
00146 PutDecodedDatumInto(v, "Signature", verifierFilter);
00147 PutDecodedDatumInto(v, "Message", verifierFilter);
00148 verifierFilter.MessageEnd();
00149 if (verifierFilter.GetLastResult() == (test == "NotVerify"))
00150 SignalTestFailure();
00151 }
00152 else if (test == "PublicKeyValid")
00153 {
00154 if (!verifier->GetMaterial().Validate(GlobalRNG(), 3))
00155 SignalTestFailure();
00156 }
00157 else
00158 goto privateKeyTests;
00159
00160 return;
00161
00162 privateKeyTests:
00163 if (keyFormat == "DER")
00164 signer->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
00165 else if (keyFormat == "Component")
00166 signer->AccessMaterial().AssignFrom(pairs);
00167
00168 if (test == "KeyPairValidAndConsistent")
00169 {
00170 TestKeyPairValidAndConsistent(verifier->AccessMaterial(), signer->GetMaterial());
00171 }
00172 else if (test == "Sign")
00173 {
00174 SignerFilter f(GlobalRNG(), *signer, new HexEncoder(new FileSink(cout)));
00175 StringSource ss(GetDecodedDatum(v, "Message"), true, new Redirector(f));
00176 SignalTestFailure();
00177 }
00178 else if (test == "DeterministicSign")
00179 {
00180 SignalTestError();
00181 assert(false);
00182 }
00183 else if (test == "RandomSign")
00184 {
00185 SignalTestError();
00186 assert(false);
00187 }
00188 else if (test == "GenerateKey")
00189 {
00190 SignalTestError();
00191 assert(false);
00192 }
00193 else
00194 {
00195 SignalTestError();
00196 assert(false);
00197 }
00198 }
00199
00200 void TestEncryptionScheme(TestData &v)
00201 {
00202 std::string name = GetRequiredDatum(v, "Name");
00203 std::string test = GetRequiredDatum(v, "Test");
00204
00205 std::auto_ptr<PK_Encryptor> encryptor(ObjectFactoryRegistry<PK_Encryptor>::Registry().CreateObject(name.c_str()));
00206 std::auto_ptr<PK_Decryptor> decryptor(ObjectFactoryRegistry<PK_Decryptor>::Registry().CreateObject(name.c_str()));
00207
00208 std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
00209
00210 if (keyFormat == "DER")
00211 {
00212 decryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
00213 encryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
00214 }
00215 else if (keyFormat == "Component")
00216 {
00217 TestDataNameValuePairs pairs(v);
00218 decryptor->AccessMaterial().AssignFrom(pairs);
00219 encryptor->AccessMaterial().AssignFrom(pairs);
00220 }
00221
00222 if (test == "DecryptMatch")
00223 {
00224 std::string decrypted, expected = GetDecodedDatum(v, "Plaintext");
00225 StringSource ss(GetDecodedDatum(v, "Ciphertext"), true, new PK_DecryptorFilter(GlobalRNG(), *decryptor, new StringSink(decrypted)));
00226 if (decrypted != expected)
00227 SignalTestFailure();
00228 }
00229 else if (test == "KeyPairValidAndConsistent")
00230 {
00231 TestKeyPairValidAndConsistent(encryptor->AccessMaterial(), decryptor->GetMaterial());
00232 }
00233 else
00234 {
00235 SignalTestError();
00236 assert(false);
00237 }
00238 }
00239
00240 void TestDigestOrMAC(TestData &v, bool testDigest)
00241 {
00242 std::string name = GetRequiredDatum(v, "Name");
00243 std::string test = GetRequiredDatum(v, "Test");
00244
00245 member_ptr<MessageAuthenticationCode> mac;
00246 member_ptr<HashTransformation> hash;
00247 HashTransformation *pHash = NULL;
00248
00249 if (testDigest)
00250 {
00251 hash.reset(ObjectFactoryRegistry<HashTransformation>::Registry().CreateObject(name.c_str()));
00252 pHash = hash.get();
00253 }
00254 else
00255 {
00256 mac.reset(ObjectFactoryRegistry<MessageAuthenticationCode>::Registry().CreateObject(name.c_str()));
00257 pHash = mac.get();
00258 std::string key = GetDecodedDatum(v, "Key");
00259 mac->SetKey((const byte *)key.c_str(), key.size());
00260 }
00261
00262 if (test == "Verify" || test == "VerifyTruncated" || test == "NotVerify")
00263 {
00264 int digestSize = pHash->DigestSize();
00265 if (test == "VerifyTruncated")
00266 digestSize = atoi(GetRequiredDatum(v, "TruncatedSize").c_str());
00267 TruncatedHashModule thash(*pHash, digestSize);
00268 HashVerificationFilter verifierFilter(thash, NULL, HashVerificationFilter::HASH_AT_BEGIN);
00269 PutDecodedDatumInto(v, "Digest", verifierFilter);
00270 PutDecodedDatumInto(v, "Message", verifierFilter);
00271 verifierFilter.MessageEnd();
00272 if (verifierFilter.GetLastResult() == (test == "NotVerify"))
00273 SignalTestFailure();
00274 }
00275 else
00276 {
00277 SignalTestError();
00278 assert(false);
00279 }
00280 }
00281
00282 bool GetField(std::istream &is, std::string &name, std::string &value)
00283 {
00284 name.resize(0);
00285 is >> name;
00286 if (name.empty())
00287 return false;
00288
00289 if (name[name.size()-1] != ':')
00290 SignalTestError();
00291 name.erase(name.size()-1);
00292
00293 while (is.peek() == ' ')
00294 is.ignore(1);
00295
00296
00297 char buffer[128];
00298 value.resize(0);
00299 bool continueLine;
00300
00301 do
00302 {
00303 do
00304 {
00305 is.get(buffer, sizeof(buffer));
00306 value += buffer;
00307 }
00308 while (buffer[0] != 0);
00309 is.clear();
00310 is.ignore();
00311
00312 if (value[value.size()-1] == '\\')
00313 {
00314 value.resize(value.size()-1);
00315 continueLine = true;
00316 }
00317 else
00318 continueLine = false;
00319
00320 std::string::size_type i = value.find('#');
00321 if (i != std::string::npos)
00322 value.erase(i);
00323 }
00324 while (continueLine);
00325
00326 return true;
00327 }
00328
00329 void OutputPair(const NameValuePairs &v, const char *name)
00330 {
00331 Integer x;
00332 bool b = v.GetValue(name, x);
00333 assert(b);
00334 cout << name << ": \\\n ";
00335 x.Encode(HexEncoder(new FileSink(cout), false, 64, "\\\n ").Ref(), x.MinEncodedSize());
00336 cout << endl;
00337 }
00338
00339 void OutputNameValuePairs(const NameValuePairs &v)
00340 {
00341 std::string names = v.GetValueNames();
00342 string::size_type i = 0;
00343 while (i < names.size())
00344 {
00345 string::size_type j = names.find_first_of (';', i);
00346
00347 if (j == string::npos)
00348 return;
00349 else
00350 {
00351 std::string name = names.substr(i, j-i);
00352 if (name.find(':') == string::npos)
00353 OutputPair(v, name.c_str());
00354 }
00355
00356 i = j + 1;
00357 }
00358 }
00359
00360 void TestDataFile(const std::string &filename, unsigned int &totalTests, unsigned int &failedTests)
00361 {
00362 std::ifstream file(filename.c_str());
00363 TestData v;
00364 s_currentTestData = &v;
00365 std::string name, value, lastAlgName;
00366
00367 while (file)
00368 {
00369 while (file.peek() == '#')
00370 file.ignore(INT_MAX, '\n');
00371
00372 if (file.peek() == '\n')
00373 v.clear();
00374
00375 if (!GetField(file, name, value))
00376 break;
00377 v[name] = value;
00378
00379 if (name == "Test")
00380 {
00381 bool failed = true;
00382 std::string algType = GetRequiredDatum(v, "AlgorithmType");
00383
00384 if (lastAlgName != GetRequiredDatum(v, "Name"))
00385 {
00386 lastAlgName = GetRequiredDatum(v, "Name");
00387 cout << "\nTesting " << algType.c_str() << " algorithm " << lastAlgName.c_str() << ".\n";
00388 }
00389
00390 try
00391 {
00392 if (algType == "Signature")
00393 TestSignatureScheme(v);
00394 else if (algType == "AsymmetricCipher")
00395 TestEncryptionScheme(v);
00396 else if (algType == "MessageDigest")
00397 TestDigestOrMAC(v, true);
00398 else if (algType == "MAC")
00399 TestDigestOrMAC(v, false);
00400 else if (algType == "FileList")
00401 TestDataFile(GetRequiredDatum(v, "Test"), totalTests, failedTests);
00402 else
00403 SignalTestError();
00404 failed = false;
00405 }
00406 catch (TestFailure &)
00407 {
00408 cout << "\nTest failed.\n";
00409 }
00410 catch (CryptoPP::Exception &e)
00411 {
00412 cout << "\nCryptoPP::Exception caught: " << e.what() << endl;
00413 }
00414 catch (std::exception &e)
00415 {
00416 cout << "\nstd::exception caught: " << e.what() << endl;
00417 }
00418
00419 if (failed)
00420 {
00421 cout << "Skipping to next test.\n";
00422 failedTests++;
00423 }
00424 else
00425 cout << "." << flush;
00426
00427 totalTests++;
00428 }
00429 }
00430 }
00431
00432 bool RunTestDataFile(const char *filename)
00433 {
00434 RegisterFactories();
00435 unsigned int totalTests = 0, failedTests = 0;
00436 TestDataFile(filename, totalTests, failedTests);
00437 cout << "\nTests complete. Total tests = " << totalTests << ". Failed tests = " << failedTests << ".\n";
00438 if (failedTests != 0)
00439 cout << "SOME TESTS FAILED!\n";
00440 return failedTests == 0;
00441 }