Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

datatest.cpp

00001 #include "factory.h"
00002 #include "integer.h"
00003 #include "filters.h"
00004 #include "hex.h"
00005 #include "randpool.h"
00006 #include "files.h"
00007 #include "trunhash.h"
00008 #include <iostream>
00009 #include <memory>
00010 
00011 USING_NAMESPACE(CryptoPP)
00012 USING_NAMESPACE(std)
00013 
00014 RandomPool & GlobalRNG();
00015 void RegisterFactories();
00016 
00017 typedef std::map<std::string, std::string> TestData;
00018 
00019 class TestFailure : public Exception
00020 {
00021 public:
00022         TestFailure() : Exception(OTHER_ERROR, "Validation test failed") {}
00023 };
00024 
00025 static const TestData *s_currentTestData = NULL;
00026 
00027 void OutputTestData(const TestData &v)
00028 {
00029         for (TestData::const_iterator i = v.begin(); i != v.end(); ++i)
00030         {
00031                 cerr << i->first << ": " << i->second << endl;
00032         }
00033 }
00034 
00035 void SignalTestFailure()
00036 {
00037         OutputTestData(*s_currentTestData);
00038         throw TestFailure();
00039 }
00040 
00041 void SignalTestError()
00042 {
00043         OutputTestData(*s_currentTestData);
00044         throw Exception(Exception::OTHER_ERROR, "Unexpected error during validation test");
00045 }
00046 
00047 class TestDataNameValuePairs : public NameValuePairs
00048 {
00049 public:
00050         TestDataNameValuePairs(const TestData &data) : m_data(data) {}
00051 
00052         virtual bool GetVoidValue(const char *name, const std::type_info &valueType, void *pValue) const
00053         {
00054                 TestData::const_iterator i = m_data.find(name);
00055                 if (i == m_data.end())
00056                         return false;
00057                 
00058                 const std::string &value = i->second;
00059                 
00060                 if (valueType == typeid(int))
00061                         *reinterpret_cast<int *>(pValue) = atoi(value.c_str());
00062                 else if (valueType == typeid(Integer))
00063                         *reinterpret_cast<Integer *>(pValue) = Integer((std::string(value) + "h").c_str());
00064                 else
00065                         throw ValueTypeMismatch(name, typeid(std::string), valueType);
00066 
00067                 return true;
00068         }
00069 
00070 private:
00071         const TestData &m_data;
00072 };
00073 
00074 const std::string & GetRequiredDatum(const TestData &data, const char *name)
00075 {
00076         TestData::const_iterator i = data.find(name);
00077         if (i == data.end())
00078                 SignalTestError();
00079         return i->second;
00080 }
00081 
00082 void PutDecodedDatumInto(const TestData &data, const char *name, BufferedTransformation &target)
00083 {
00084         std::string s1 = GetRequiredDatum(data, name), s2;
00085 
00086         int repeat = 1;
00087         if (s1[0] == 'r')
00088         {
00089                 repeat = atoi(s1.c_str()+1);
00090                 s1 = s1.substr(s1.find(' ')+1);
00091         }
00092         
00093         if (s1[0] == '\"')
00094                 s2 = s1.substr(1, s1.find('\"', 1)-1);
00095         else if (s1.substr(0, 2) == "0x")
00096                 StringSource(s1.substr(2), true, new HexDecoder(new StringSink(s2)));
00097         else
00098                 StringSource(s1, true, new HexDecoder(new StringSink(s2)));
00099 
00100         while (repeat--)
00101                 target.Put((const byte *)s2.data(), s2.size());
00102 }
00103 
00104 std::string GetDecodedDatum(const TestData &data, const char *name)
00105 {
00106         std::string s;
00107         PutDecodedDatumInto(data, name, StringSink(s).Ref());
00108         return s;
00109 }
00110 
00111 void TestKeyPairValidAndConsistent(CryptoMaterial &pub, const CryptoMaterial &priv)
00112 {
00113         if (!pub.Validate(GlobalRNG(), 3))
00114                 SignalTestFailure();
00115         if (!priv.Validate(GlobalRNG(), 3))
00116                 SignalTestFailure();
00117 
00118 /*      EqualityComparisonFilter comparison;
00119         pub.Save(ChannelSwitch(comparison, "0"));
00120         pub.AssignFrom(priv);
00121         pub.Save(ChannelSwitch(comparison, "1"));
00122         comparison.ChannelMessageSeriesEnd("0");
00123         comparison.ChannelMessageSeriesEnd("1");
00124 */
00125 }
00126 
00127 void TestSignatureScheme(TestData &v)
00128 {
00129         std::string name = GetRequiredDatum(v, "Name");
00130         std::string test = GetRequiredDatum(v, "Test");
00131 
00132         std::auto_ptr<PK_Signer> signer(ObjectFactoryRegistry<PK_Signer>::Registry().CreateObject(name.c_str()));
00133         std::auto_ptr<PK_Verifier> verifier(ObjectFactoryRegistry<PK_Verifier>::Registry().CreateObject(name.c_str()));
00134 
00135         TestDataNameValuePairs pairs(v);
00136         std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
00137 
00138         if (keyFormat == "DER")
00139                 verifier->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
00140         else if (keyFormat == "Component")
00141                 verifier->AccessMaterial().AssignFrom(pairs);
00142 
00143         if (test == "Verify" || test == "NotVerify")
00144         {
00145                 VerifierFilter verifierFilter(*verifier, NULL, VerifierFilter::SIGNATURE_AT_BEGIN);
00146                 PutDecodedDatumInto(v, "Signature", verifierFilter);
00147                 PutDecodedDatumInto(v, "Message", verifierFilter);
00148                 verifierFilter.MessageEnd();
00149                 if (verifierFilter.GetLastResult() == (test == "NotVerify"))
00150                         SignalTestFailure();
00151         }
00152         else if (test == "PublicKeyValid")
00153         {
00154                 if (!verifier->GetMaterial().Validate(GlobalRNG(), 3))
00155                         SignalTestFailure();
00156         }
00157         else
00158                 goto privateKeyTests;
00159 
00160         return;
00161 
00162 privateKeyTests:
00163         if (keyFormat == "DER")
00164                 signer->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
00165         else if (keyFormat == "Component")
00166                 signer->AccessMaterial().AssignFrom(pairs);
00167         
00168         if (test == "KeyPairValidAndConsistent")
00169         {
00170                 TestKeyPairValidAndConsistent(verifier->AccessMaterial(), signer->GetMaterial());
00171         }
00172         else if (test == "Sign")
00173         {
00174                 SignerFilter f(GlobalRNG(), *signer, new HexEncoder(new FileSink(cout)));
00175                 StringSource ss(GetDecodedDatum(v, "Message"), true, new Redirector(f));
00176                 SignalTestFailure();
00177         }
00178         else if (test == "DeterministicSign")
00179         {
00180                 SignalTestError();
00181                 assert(false);  // TODO: implement
00182         }
00183         else if (test == "RandomSign")
00184         {
00185                 SignalTestError();
00186                 assert(false);  // TODO: implement
00187         }
00188         else if (test == "GenerateKey")
00189         {
00190                 SignalTestError();
00191                 assert(false);
00192         }
00193         else
00194         {
00195                 SignalTestError();
00196                 assert(false);
00197         }
00198 }
00199 
00200 void TestEncryptionScheme(TestData &v)
00201 {
00202         std::string name = GetRequiredDatum(v, "Name");
00203         std::string test = GetRequiredDatum(v, "Test");
00204 
00205         std::auto_ptr<PK_Encryptor> encryptor(ObjectFactoryRegistry<PK_Encryptor>::Registry().CreateObject(name.c_str()));
00206         std::auto_ptr<PK_Decryptor> decryptor(ObjectFactoryRegistry<PK_Decryptor>::Registry().CreateObject(name.c_str()));
00207 
00208         std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
00209 
00210         if (keyFormat == "DER")
00211         {
00212                 decryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
00213                 encryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
00214         }
00215         else if (keyFormat == "Component")
00216         {
00217                 TestDataNameValuePairs pairs(v);
00218                 decryptor->AccessMaterial().AssignFrom(pairs);
00219                 encryptor->AccessMaterial().AssignFrom(pairs);
00220         }
00221 
00222         if (test == "DecryptMatch")
00223         {
00224                 std::string decrypted, expected = GetDecodedDatum(v, "Plaintext");
00225                 StringSource ss(GetDecodedDatum(v, "Ciphertext"), true, new PK_DecryptorFilter(GlobalRNG(), *decryptor, new StringSink(decrypted)));
00226                 if (decrypted != expected)
00227                         SignalTestFailure();
00228         }
00229         else if (test == "KeyPairValidAndConsistent")
00230         {
00231                 TestKeyPairValidAndConsistent(encryptor->AccessMaterial(), decryptor->GetMaterial());
00232         }
00233         else
00234         {
00235                 SignalTestError();
00236                 assert(false);
00237         }
00238 }
00239 
00240 void TestDigestOrMAC(TestData &v, bool testDigest)
00241 {
00242         std::string name = GetRequiredDatum(v, "Name");
00243         std::string test = GetRequiredDatum(v, "Test");
00244 
00245         member_ptr<MessageAuthenticationCode> mac;
00246         member_ptr<HashTransformation> hash;
00247         HashTransformation *pHash = NULL;
00248 
00249         if (testDigest)
00250         {
00251                 hash.reset(ObjectFactoryRegistry<HashTransformation>::Registry().CreateObject(name.c_str()));
00252                 pHash = hash.get();
00253         }
00254         else
00255         {
00256                 mac.reset(ObjectFactoryRegistry<MessageAuthenticationCode>::Registry().CreateObject(name.c_str()));
00257                 pHash = mac.get();
00258                 std::string key = GetDecodedDatum(v, "Key");
00259                 mac->SetKey((const byte *)key.c_str(), key.size());
00260         }
00261 
00262         if (test == "Verify" || test == "VerifyTruncated" || test == "NotVerify")
00263         {
00264                 int digestSize = pHash->DigestSize();
00265                 if (test == "VerifyTruncated")
00266                         digestSize = atoi(GetRequiredDatum(v, "TruncatedSize").c_str());
00267                 TruncatedHashModule thash(*pHash, digestSize);
00268                 HashVerificationFilter verifierFilter(thash, NULL, HashVerificationFilter::HASH_AT_BEGIN);
00269                 PutDecodedDatumInto(v, "Digest", verifierFilter);
00270                 PutDecodedDatumInto(v, "Message", verifierFilter);
00271                 verifierFilter.MessageEnd();
00272                 if (verifierFilter.GetLastResult() == (test == "NotVerify"))
00273                         SignalTestFailure();
00274         }
00275         else
00276         {
00277                 SignalTestError();
00278                 assert(false);
00279         }
00280 }
00281 
00282 bool GetField(std::istream &is, std::string &name, std::string &value)
00283 {
00284         name.resize(0);         // GCC workaround: 2.95.3 doesn't have clear()
00285         is >> name;
00286         if (name.empty())
00287                 return false;
00288 
00289         if (name[name.size()-1] != ':')
00290                 SignalTestError();
00291         name.erase(name.size()-1);
00292 
00293         while (is.peek() == ' ')
00294                 is.ignore(1);
00295 
00296         // VC60 workaround: getline bug
00297         char buffer[128];
00298         value.resize(0);        // GCC workaround: 2.95.3 doesn't have clear()
00299         bool continueLine;
00300 
00301         do
00302         {
00303                 do
00304                 {
00305                         is.get(buffer, sizeof(buffer));
00306                         value += buffer;
00307                 }
00308                 while (buffer[0] != 0);
00309                 is.clear();
00310                 is.ignore();
00311                 
00312                 if (value[value.size()-1] == '\\')
00313                 {
00314                         value.resize(value.size()-1);
00315                         continueLine = true;
00316                 }
00317                 else
00318                         continueLine = false;
00319 
00320                 std::string::size_type i = value.find('#');
00321                 if (i != std::string::npos)
00322                         value.erase(i);
00323         }
00324         while (continueLine);
00325 
00326         return true;
00327 }
00328 
00329 void OutputPair(const NameValuePairs &v, const char *name)
00330 {
00331         Integer x;
00332         bool b = v.GetValue(name, x);
00333         assert(b);
00334         cout << name << ": \\\n    ";
00335         x.Encode(HexEncoder(new FileSink(cout), false, 64, "\\\n    ").Ref(), x.MinEncodedSize());
00336         cout << endl;
00337 }
00338 
00339 void OutputNameValuePairs(const NameValuePairs &v)
00340 {
00341         std::string names = v.GetValueNames();
00342         string::size_type i = 0;
00343         while (i < names.size())
00344         {
00345                 string::size_type j = names.find_first_of (';', i);
00346 
00347                 if (j == string::npos)
00348                         return;
00349                 else
00350                 {
00351                         std::string name = names.substr(i, j-i);
00352                         if (name.find(':') == string::npos)
00353                                 OutputPair(v, name.c_str());
00354                 }
00355 
00356                 i = j + 1;
00357         }
00358 }
00359 
00360 void TestDataFile(const std::string &filename, unsigned int &totalTests, unsigned int &failedTests)
00361 {
00362         std::ifstream file(filename.c_str());
00363         TestData v;
00364         s_currentTestData = &v;
00365         std::string name, value, lastAlgName;
00366 
00367         while (file)
00368         {
00369                 while (file.peek() == '#')
00370                         file.ignore(INT_MAX, '\n');
00371 
00372                 if (file.peek() == '\n')
00373                         v.clear();
00374 
00375                 if (!GetField(file, name, value))
00376                         break;
00377                 v[name] = value;
00378 
00379                 if (name == "Test")
00380                 {
00381                         bool failed = true;
00382                         std::string algType = GetRequiredDatum(v, "AlgorithmType");
00383 
00384                         if (lastAlgName != GetRequiredDatum(v, "Name"))
00385                         {
00386                                 lastAlgName = GetRequiredDatum(v, "Name");
00387                                 cout << "\nTesting " << algType.c_str() << " algorithm " << lastAlgName.c_str() << ".\n";
00388                         }
00389 
00390                         try
00391                         {
00392                                 if (algType == "Signature")
00393                                         TestSignatureScheme(v);
00394                                 else if (algType == "AsymmetricCipher")
00395                                         TestEncryptionScheme(v);
00396                                 else if (algType == "MessageDigest")
00397                                         TestDigestOrMAC(v, true);
00398                                 else if (algType == "MAC")
00399                                         TestDigestOrMAC(v, false);
00400                                 else if (algType == "FileList")
00401                                         TestDataFile(GetRequiredDatum(v, "Test"), totalTests, failedTests);
00402                                 else
00403                                         SignalTestError();
00404                                 failed = false;
00405                         }
00406                         catch (TestFailure &)
00407                         {
00408                                 cout << "\nTest failed.\n";
00409                         }
00410                         catch (CryptoPP::Exception &e)
00411                         {
00412                                 cout << "\nCryptoPP::Exception caught: " << e.what() << endl;
00413                         }
00414                         catch (std::exception &e)
00415                         {
00416                                 cout << "\nstd::exception caught: " << e.what() << endl;
00417                         }
00418 
00419                         if (failed)
00420                         {
00421                                 cout << "Skipping to next test.\n";
00422                                 failedTests++;
00423                         }
00424                         else
00425                                 cout << "." << flush;
00426 
00427                         totalTests++;
00428                 }
00429         }
00430 }
00431 
00432 bool RunTestDataFile(const char *filename)
00433 {
00434         RegisterFactories();
00435         unsigned int totalTests = 0, failedTests = 0;
00436         TestDataFile(filename, totalTests, failedTests);
00437         cout << "\nTests complete. Total tests = " << totalTests << ". Failed tests = " << failedTests << ".\n";
00438         if (failedTests != 0)
00439                 cout << "SOME TESTS FAILED!\n";
00440         return failedTests == 0;
00441 }

Generated on Mon Apr 19 18:12:29 2004 for Crypto++ by doxygen 1.3.6-20040222