Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

zinflate.cpp

00001 // zinflate.cpp - written and placed in the public domain by Wei Dai
00002 
00003 // This is a complete reimplementation of the DEFLATE decompression algorithm.
00004 // It should not be affected by any security vulnerabilities in the zlib 
00005 // compression library. In particular it is not affected by the double free bug
00006 // (http://www.kb.cert.org/vuls/id/368819).
00007 
00008 #include "pch.h"
00009 #include "zinflate.h"
00010 
00011 NAMESPACE_BEGIN(CryptoPP)
00012 
00013 struct CodeLessThan
00014 {
00015         inline bool operator()(const CryptoPP::HuffmanDecoder::code_t lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs)
00016                 {return lhs < rhs.code;}
00017 };
00018 
00019 inline bool LowFirstBitReader::FillBuffer(unsigned int length)
00020 {
00021         while (m_bitsBuffered < length)
00022         {
00023                 byte b;
00024                 if (!m_store.Get(b))
00025                         return false;
00026                 m_buffer |= (unsigned long)b << m_bitsBuffered;
00027                 m_bitsBuffered += 8;
00028         }
00029         assert(m_bitsBuffered <= sizeof(unsigned long)*8);
00030         return true;
00031 }
00032 
00033 inline unsigned long LowFirstBitReader::PeekBits(unsigned int length)
00034 {
00035         bool result = FillBuffer(length);
00036         assert(result);
00037         return m_buffer & (((unsigned long)1 << length) - 1);
00038 }
00039 
00040 inline void LowFirstBitReader::SkipBits(unsigned int length)
00041 {
00042         assert(m_bitsBuffered >= length);
00043         m_buffer >>= length;
00044         m_bitsBuffered -= length;
00045 }
00046 
00047 inline unsigned long LowFirstBitReader::GetBits(unsigned int length)
00048 {
00049         unsigned long result = PeekBits(length);
00050         SkipBits(length);
00051         return result;
00052 }
00053 
00054 inline HuffmanDecoder::code_t HuffmanDecoder::NormalizeCode(HuffmanDecoder::code_t code, unsigned int codeBits)
00055 {
00056         return code << (MAX_CODE_BITS - codeBits);
00057 }
00058 
00059 void HuffmanDecoder::Initialize(const unsigned int *codeBits, unsigned int nCodes)
00060 {
00061         // the Huffman codes are represented in 3 ways in this code:
00062         //
00063         // 1. most significant code bit (i.e. top of code tree) in the least significant bit position
00064         // 2. most significant code bit (i.e. top of code tree) in the most significant bit position
00065         // 3. most significant code bit (i.e. top of code tree) in n-th least significant bit position,
00066         //    where n is the maximum code length for this code tree
00067         //
00068         // (1) is the way the codes come in from the deflate stream
00069         // (2) is used to sort codes so they can be binary searched
00070         // (3) is used in this function to compute codes from code lengths
00071         //
00072         // a code in representation (2) is called "normalized" here
00073         // The BitReverse() function is used to convert between (1) and (2)
00074         // The NormalizeCode() function is used to convert from (3) to (2)
00075 
00076         if (nCodes == 0)
00077                 throw Err("null code");
00078 
00079         m_maxCodeBits = *std::max_element(codeBits, codeBits+nCodes);
00080 
00081         if (m_maxCodeBits > MAX_CODE_BITS)
00082                 throw Err("code length exceeds maximum");
00083 
00084         if (m_maxCodeBits == 0)
00085                 throw Err("null code");
00086 
00087         // count number of codes of each length
00088         SecBlockWithHint<unsigned int, 15+1> blCount(m_maxCodeBits+1);
00089         std::fill(blCount.begin(), blCount.end(), 0);
00090         unsigned int i;
00091         for (i=0; i<nCodes; i++)
00092                 blCount[codeBits[i]]++;
00093 
00094         // compute the starting code of each length
00095         code_t code = 0;
00096         SecBlockWithHint<code_t, 15+1> nextCode(m_maxCodeBits+1);
00097         nextCode[1] = 0;
00098         for (i=2; i<=m_maxCodeBits; i++)
00099         {
00100                 // compute this while checking for overflow: code = (code + blCount[i-1]) << 1
00101                 if (code > code + blCount[i-1])
00102                         throw Err("codes oversubscribed");
00103                 code += blCount[i-1];
00104                 if (code > (code << 1))
00105                         throw Err("codes oversubscribed");
00106                 code <<= 1;
00107                 nextCode[i] = code;
00108         }
00109 
00110         if (code > (1 << m_maxCodeBits) - blCount[m_maxCodeBits])
00111                 throw Err("codes oversubscribed");
00112         else if (m_maxCodeBits != 1 && code < (1 << m_maxCodeBits) - blCount[m_maxCodeBits])
00113                 throw Err("codes incomplete");
00114 
00115         // compute a vector of <code, length, value> triples sorted by code
00116         m_codeToValue.resize(nCodes - blCount[0]);
00117         unsigned int j=0;
00118         for (i=0; i<nCodes; i++) 
00119         {
00120                 unsigned int len = codeBits[i];
00121                 if (len != 0)
00122                 {
00123                         code = NormalizeCode(nextCode[len]++, len);
00124                         m_codeToValue[j].code = code;
00125                         m_codeToValue[j].len = len;
00126                         m_codeToValue[j].value = i;
00127                         j++;
00128                 }
00129         }
00130         std::sort(m_codeToValue.begin(), m_codeToValue.end());
00131 
00132         // initialize the decoding cache
00133         m_cacheBits = STDMIN(9U, m_maxCodeBits);
00134         m_cacheMask = (1 << m_cacheBits) - 1;
00135         m_normalizedCacheMask = NormalizeCode(m_cacheMask, m_cacheBits);
00136         assert(m_normalizedCacheMask == BitReverse(m_cacheMask));
00137 
00138         if (m_cache.size() != 1 << m_cacheBits)
00139                 m_cache.resize(1 << m_cacheBits);
00140 
00141         for (i=0; i<m_cache.size(); i++)
00142                 m_cache[i].type = 0;
00143 }
00144 
00145 void HuffmanDecoder::FillCacheEntry(LookupEntry &entry, code_t normalizedCode) const
00146 {
00147         normalizedCode &= m_normalizedCacheMask;
00148         const CodeInfo &codeInfo = *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode, CodeLessThan())-1);
00149         if (codeInfo.len <= m_cacheBits)
00150         {
00151                 entry.type = 1;
00152                 entry.value = codeInfo.value;
00153                 entry.len = codeInfo.len;
00154         }
00155         else
00156         {
00157                 entry.begin = &codeInfo;
00158                 const CodeInfo *last = & *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode + ~m_normalizedCacheMask, CodeLessThan())-1);
00159                 if (codeInfo.len == last->len)
00160                 {
00161                         entry.type = 2;
00162                         entry.len = codeInfo.len;
00163                 }
00164                 else
00165                 {
00166                         entry.type = 3;
00167                         entry.end = last+1;
00168                 }
00169         }
00170 }
00171 
00172 inline unsigned int HuffmanDecoder::Decode(code_t code, /* out */ value_t &value) const
00173 {
00174         assert(m_codeToValue.size() > 0);
00175         LookupEntry &entry = m_cache[code & m_cacheMask];
00176 
00177         code_t normalizedCode;
00178         if (entry.type != 1)
00179                 normalizedCode = BitReverse(code);
00180 
00181         if (entry.type == 0)
00182                 FillCacheEntry(entry, normalizedCode);
00183 
00184         if (entry.type == 1)
00185         {
00186                 value = entry.value;
00187                 return entry.len;
00188         }
00189         else
00190         {
00191                 const CodeInfo &codeInfo = (entry.type == 2)
00192                         ? entry.begin[(normalizedCode << m_cacheBits) >> (MAX_CODE_BITS - (entry.len - m_cacheBits))]
00193                         : *(std::upper_bound(entry.begin, entry.end, normalizedCode, CodeLessThan())-1);
00194                 value = codeInfo.value;
00195                 return codeInfo.len;
00196         }
00197 }
00198 
00199 bool HuffmanDecoder::Decode(LowFirstBitReader &reader, value_t &value) const
00200 {
00201         reader.FillBuffer(m_maxCodeBits);
00202         unsigned int codeBits = Decode(reader.PeekBuffer(), value);
00203         if (codeBits > reader.BitsBuffered())
00204                 return false;
00205         reader.SkipBits(codeBits);
00206         return true;
00207 }
00208 
00209 // *************************************************************
00210 
00211 Inflator::Inflator(BufferedTransformation *attachment, bool repeat, int propagation)
00212         : AutoSignaling<Filter>(attachment, propagation)
00213         , m_state(PRE_STREAM), m_repeat(repeat)
00214         , m_decodersInitializedWithFixedCodes(false), m_reader(m_inQueue)
00215 {
00216 }
00217 
00218 void Inflator::IsolatedInitialize(const NameValuePairs &parameters)
00219 {
00220         m_state = PRE_STREAM;
00221         parameters.GetValue("Repeat", m_repeat);
00222         m_inQueue.Clear();
00223         m_reader.SkipBits(m_reader.BitsBuffered());
00224 }
00225 
00226 inline void Inflator::OutputByte(byte b)
00227 {
00228         m_window[m_current++] = b;
00229         if (m_current == m_window.size())
00230         {
00231                 ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush);
00232                 m_lastFlush = 0;
00233                 m_current = 0;
00234         }
00235         if (m_maxDistance < m_window.size())
00236                 m_maxDistance++;
00237 }
00238 
00239 void Inflator::OutputString(const byte *string, unsigned int length)
00240 {
00241         while (length--)
00242                 OutputByte(*string++);
00243 }
00244 
00245 void Inflator::OutputPast(unsigned int length, unsigned int distance)
00246 {
00247         if (distance > m_maxDistance)
00248                 throw BadBlockErr();
00249         unsigned int start;
00250         if (m_current > distance)
00251                 start = m_current - distance;
00252         else
00253                 start = m_current + m_window.size() - distance;
00254 
00255         if (start + length > m_window.size())
00256         {
00257                 for (; start < m_window.size(); start++, length--)
00258                         OutputByte(m_window[start]);
00259                 start = 0;
00260         }
00261 
00262         if (start + length > m_current || m_current + length >= m_window.size())
00263         {
00264                 while (length--)
00265                         OutputByte(m_window[start++]);
00266         }
00267         else
00268         {
00269                 memcpy(m_window + m_current, m_window + start, length);
00270                 m_current += length;
00271                 m_maxDistance = STDMIN((unsigned int)m_window.size(), m_maxDistance + length);
00272         }
00273 }
00274 
00275 unsigned int Inflator::Put2(const byte *inString, unsigned int length, int messageEnd, bool blocking)
00276 {
00277         if (!blocking)
00278                 throw BlockingInputOnly("Inflator");
00279 
00280         LazyPutter lp(m_inQueue, inString, length);
00281         ProcessInput(messageEnd != 0);
00282 
00283         if (messageEnd)
00284                 if (!(m_state == PRE_STREAM || m_state == AFTER_END))
00285                         throw UnexpectedEndErr();
00286 
00287         Output(0, NULL, 0, messageEnd, blocking);
00288         return 0;
00289 }
00290 
00291 bool Inflator::IsolatedFlush(bool hardFlush, bool blocking)
00292 {
00293         if (!blocking)
00294                 throw BlockingInputOnly("Inflator");
00295 
00296         if (hardFlush)
00297                 ProcessInput(true);
00298         FlushOutput();
00299 
00300         return false;
00301 }
00302 
00303 void Inflator::ProcessInput(bool flush)
00304 {
00305         while (true)
00306         {
00307                 if (m_inQueue.IsEmpty())
00308                         return;
00309 
00310                 switch (m_state)
00311                 {
00312                 case PRE_STREAM:
00313                         if (!flush && m_inQueue.CurrentSize() < MaxPrestreamHeaderSize())
00314                                 return;
00315                         ProcessPrestreamHeader();
00316                         m_state = WAIT_HEADER;
00317                         m_maxDistance = 0;
00318                         m_current = 0;
00319                         m_lastFlush = 0;
00320                         m_window.New(1 << GetLog2WindowSize());
00321                         break;
00322                 case WAIT_HEADER:
00323                         {
00324                         // maximum number of bytes before actual compressed data starts
00325                         const unsigned int MAX_HEADER_SIZE = BitsToBytes(3+5+5+4+19*7+286*15+19*15);
00326                         if (m_inQueue.CurrentSize() < (flush ? 1 : MAX_HEADER_SIZE))
00327                                 return;
00328                         DecodeHeader();
00329                         break;
00330                         }
00331                 case DECODING_BODY:
00332                         if (!DecodeBody())
00333                                 return;
00334                         break;
00335                 case POST_STREAM:
00336                         if (!flush && m_inQueue.CurrentSize() < MaxPoststreamTailSize())
00337                                 return;
00338                         ProcessPoststreamTail();
00339                         m_state = m_repeat ? PRE_STREAM : AFTER_END;
00340                         Output(0, NULL, 0, GetAutoSignalPropagation(), true);   // TODO: non-blocking
00341                         break;
00342                 case AFTER_END:
00343                         m_inQueue.TransferTo(*AttachedTransformation());
00344                         return;
00345                 }
00346         }
00347 }
00348 
00349 void Inflator::DecodeHeader()
00350 {
00351         if (!m_reader.FillBuffer(3))
00352                 throw UnexpectedEndErr();
00353         m_eof = m_reader.GetBits(1) != 0;
00354         m_blockType = (byte)m_reader.GetBits(2);
00355         switch (m_blockType)
00356         {
00357         case 0: // stored
00358                 {
00359                 m_reader.SkipBits(m_reader.BitsBuffered() % 8);
00360                 if (!m_reader.FillBuffer(32))
00361                         throw UnexpectedEndErr();
00362                 m_storedLen = (word16)m_reader.GetBits(16);
00363                 word16 nlen = (word16)m_reader.GetBits(16);
00364                 if (nlen != (word16)~m_storedLen)
00365                         throw BadBlockErr();
00366                 break;
00367                 }
00368         case 1: // fixed codes
00369                 if (!m_decodersInitializedWithFixedCodes)
00370                 {
00371                         unsigned int codeLengths[288];
00372                         std::fill(codeLengths + 0, codeLengths + 144, 8);
00373                         std::fill(codeLengths + 144, codeLengths + 256, 9);
00374                         std::fill(codeLengths + 256, codeLengths + 280, 7);
00375                         std::fill(codeLengths + 280, codeLengths + 288, 8);
00376                         m_literalDecoder.Initialize(codeLengths, 288);
00377                         std::fill(codeLengths + 0, codeLengths + 32, 5);
00378                         m_distanceDecoder.Initialize(codeLengths, 32);
00379                         m_decodersInitializedWithFixedCodes = true;
00380                 }
00381                 m_nextDecode = LITERAL;
00382                 break;
00383         case 2: // dynamic codes
00384                 {
00385                 m_decodersInitializedWithFixedCodes = false;
00386                 if (!m_reader.FillBuffer(5+5+4))
00387                         throw UnexpectedEndErr();
00388                 unsigned int hlit = m_reader.GetBits(5);
00389                 unsigned int hdist = m_reader.GetBits(5);
00390                 unsigned int hclen = m_reader.GetBits(4);
00391 
00392                 FixedSizeSecBlock<unsigned int, 286+32> codeLengths;
00393                 unsigned int i;
00394                 static const unsigned int border[] = {    // Order of the bit length code lengths
00395                         16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
00396                 std::fill(codeLengths.begin(), codeLengths+19, 0);
00397                 for (i=0; i<hclen+4; i++)
00398                         codeLengths[border[i]] = m_reader.GetBits(3);
00399 
00400                 try
00401                 {
00402                         HuffmanDecoder codeLengthDecoder(codeLengths, 19);
00403                         for (i = 0; i < hlit+257+hdist+1; )
00404                         {
00405                                 unsigned int k, count, repeater;
00406                                 bool result = codeLengthDecoder.Decode(m_reader, k);
00407                                 if (!result)
00408                                         throw UnexpectedEndErr();
00409                                 if (k <= 15)
00410                                 {
00411                                         count = 1;
00412                                         repeater = k;
00413                                 }
00414                                 else switch (k)
00415                                 {
00416                                 case 16:
00417                                         if (!m_reader.FillBuffer(2))
00418                                                 throw UnexpectedEndErr();
00419                                         count = 3 + m_reader.GetBits(2);
00420                                         if (i == 0)
00421                                                 throw BadBlockErr();
00422                                         repeater = codeLengths[i-1];
00423                                         break;
00424                                 case 17:
00425                                         if (!m_reader.FillBuffer(3))
00426                                                 throw UnexpectedEndErr();
00427                                         count = 3 + m_reader.GetBits(3);
00428                                         repeater = 0;
00429                                         break;
00430                                 case 18:
00431                                         if (!m_reader.FillBuffer(7))
00432                                                 throw UnexpectedEndErr();
00433                                         count = 11 + m_reader.GetBits(7);
00434                                         repeater = 0;
00435                                         break;
00436                                 }
00437                                 if (i + count > hlit+257+hdist+1)
00438                                         throw BadBlockErr();
00439                                 std::fill(codeLengths + i, codeLengths + i + count, repeater);
00440                                 i += count;
00441                         }
00442                         m_literalDecoder.Initialize(codeLengths, hlit+257);
00443                         if (hdist == 0 && codeLengths[hlit+257] == 0)
00444                         {
00445                                 if (hlit != 0)  // a single zero distance code length means all literals
00446                                         throw BadBlockErr();
00447                         }
00448                         else
00449                                 m_distanceDecoder.Initialize(codeLengths+hlit+257, hdist+1);
00450                         m_nextDecode = LITERAL;
00451                 }
00452                 catch (HuffmanDecoder::Err &)
00453                 {
00454                         throw BadBlockErr();
00455                 }
00456                 break;
00457                 }
00458         default:
00459                 throw BadBlockErr();    // reserved block type
00460         }
00461         m_state = DECODING_BODY;
00462 }
00463 
00464 bool Inflator::DecodeBody()
00465 {
00466         bool blockEnd = false;
00467         switch (m_blockType)
00468         {
00469         case 0: // stored
00470                 assert(m_reader.BitsBuffered() == 0);
00471                 while (!m_inQueue.IsEmpty() && !blockEnd)
00472                 {
00473                         unsigned int size;
00474                         const byte *block = m_inQueue.Spy(size);
00475                         size = STDMIN(size, (unsigned int)m_storedLen);
00476                         OutputString(block, size);
00477                         m_inQueue.Skip(size);
00478                         m_storedLen -= size;
00479                         if (m_storedLen == 0)
00480                                 blockEnd = true;
00481                 }
00482                 break;
00483         case 1: // fixed codes
00484         case 2: // dynamic codes
00485                 static const unsigned int lengthStarts[] = {
00486                         3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31,
00487                         35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258};
00488                 static const unsigned int lengthExtraBits[] = {
00489                         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
00490                         3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0};
00491                 static const unsigned int distanceStarts[] = {
00492                         1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193,
00493                         257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145,
00494                         8193, 12289, 16385, 24577};
00495                 static const unsigned int distanceExtraBits[] = {
00496                         0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
00497                         7, 7, 8, 8, 9, 9, 10, 10, 11, 11,
00498                         12, 12, 13, 13};
00499 
00500                 switch (m_nextDecode)
00501                 {
00502                         while (true)
00503                         {
00504                 case LITERAL:
00505                                 if (!m_literalDecoder.Decode(m_reader, m_literal))
00506                                 {
00507                                         m_nextDecode = LITERAL;
00508                                         break;
00509                                 }
00510                                 if (m_literal < 256)
00511                                         OutputByte((byte)m_literal);
00512                                 else if (m_literal == 256)      // end of block
00513                                 {
00514                                         blockEnd = true;
00515                                         break;
00516                                 }
00517                                 else
00518                                 {
00519                                         if (m_literal > 285)
00520                                                 throw BadBlockErr();
00521                                         unsigned int bits;
00522                 case LENGTH_BITS:
00523                                         bits = lengthExtraBits[m_literal-257];
00524                                         if (!m_reader.FillBuffer(bits))
00525                                         {
00526                                                 m_nextDecode = LENGTH_BITS;
00527                                                 break;
00528                                         }
00529                                         m_literal = m_reader.GetBits(bits) + lengthStarts[m_literal-257];
00530                 case DISTANCE:
00531                                         if (!m_distanceDecoder.Decode(m_reader, m_distance))
00532                                         {
00533                                                 m_nextDecode = DISTANCE;
00534                                                 break;
00535                                         }
00536                 case DISTANCE_BITS:
00537                                         bits = distanceExtraBits[m_distance];
00538                                         if (!m_reader.FillBuffer(bits))
00539                                         {
00540                                                 m_nextDecode = DISTANCE_BITS;
00541                                                 break;
00542                                         }
00543                                         m_distance = m_reader.GetBits(bits) + distanceStarts[m_distance];
00544                                         OutputPast(m_literal, m_distance);
00545                                 }
00546                         }
00547                 }
00548         }
00549         if (blockEnd)
00550         {
00551                 if (m_eof)
00552                 {
00553                         FlushOutput();
00554                         m_reader.SkipBits(m_reader.BitsBuffered()%8);
00555                         if (m_reader.BitsBuffered())
00556                         {
00557                                 // undo too much lookahead
00558                                 SecBlockWithHint<byte, 4> buffer(m_reader.BitsBuffered() / 8);
00559                                 for (unsigned int i=0; i<buffer.size(); i++)
00560                                         buffer[i] = (byte)m_reader.GetBits(8);
00561                                 m_inQueue.Unget(buffer, buffer.size());
00562                         }
00563                         m_state = POST_STREAM;
00564                 }
00565                 else
00566                         m_state = WAIT_HEADER;
00567         }
00568         return blockEnd;
00569 }
00570 
00571 void Inflator::FlushOutput()
00572 {
00573         if (m_state != PRE_STREAM)
00574         {
00575                 assert(m_current >= m_lastFlush);
00576                 ProcessDecompressedData(m_window + m_lastFlush, m_current - m_lastFlush);
00577                 m_lastFlush = m_current;
00578         }
00579 }
00580 
00581 NAMESPACE_END

Generated on Mon Apr 19 18:12:34 2004 for Crypto++ by doxygen 1.3.6-20040222