Never overflow the output buffer in wxBase64Decode().
[wxWidgets.git] / src / common / base64.cpp
1 ///////////////////////////////////////////////////////////////////////////////
2 // Name: src/common/base64.cpp
3 // Purpose: implementation of BASE64 encoding/decoding functions
4 // Author: Charles Reimers, Vadim Zeitlin
5 // Created: 2007-06-18
6 // RCS-ID: $Id$
7 // Licence: wxWindows licence
8 ///////////////////////////////////////////////////////////////////////////////
9
10 #include "wx/wxprec.h"
11
12 #ifdef __BORLANDC__
13 #pragma hdrstop
14 #endif
15
16 #if wxUSE_BASE64
17
18 #include "wx/base64.h"
19
20 size_t
21 wxBase64Encode(char *dst, size_t dstLen, const void *src_, size_t srcLen)
22 {
23 wxCHECK_MSG( src_, wxCONV_FAILED, wxT("NULL input buffer") );
24
25 const unsigned char *src = static_cast<const unsigned char *>(src_);
26
27 static const char b64[] =
28 "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
29
30
31 size_t encLen = 0;
32
33 // encode blocks of 3 bytes into 4 base64 characters
34 for ( ; srcLen >= 3; srcLen -= 3, src += 3 )
35 {
36 encLen += 4;
37 if ( dst )
38 {
39 if ( encLen > dstLen )
40 return wxCONV_FAILED;
41
42 *dst++ = b64[src[0] >> 2];
43 *dst++ = b64[((src[0] & 0x03) << 4) | ((src[1] & 0xf0) >> 4)];
44 *dst++ = b64[((src[1] & 0x0f) << 2) | ((src[2] & 0xc0) >> 6)];
45 *dst++ = b64[src[2] & 0x3f];
46 }
47 }
48
49 // finish with the remaining characters
50 if ( srcLen )
51 {
52 encLen += 4;
53 if ( dst )
54 {
55 if ( encLen > dstLen )
56 return wxCONV_FAILED;
57
58 // we have definitely one and maybe two bytes remaining
59 unsigned char next = srcLen == 2 ? src[1] : 0;
60 *dst++ = b64[src[0] >> 2];
61 *dst++ = b64[((src[0] & 0x03) << 4) | ((next & 0xf0) >> 4)];
62 *dst++ = srcLen == 2 ? b64[((next & 0x0f) << 2)] : '=';
63 *dst = '=';
64 }
65 }
66
67 return encLen;
68 }
69
70 size_t
71 wxBase64Decode(void *dst_, size_t dstLen,
72 const char *src, size_t srcLen,
73 wxBase64DecodeMode mode,
74 size_t *posErr)
75 {
76 wxCHECK_MSG( src, wxCONV_FAILED, wxT("NULL input buffer") );
77
78 unsigned char *dst = static_cast<unsigned char *>(dst_);
79
80 size_t decLen = 0;
81
82 if ( srcLen == wxNO_LEN )
83 srcLen = strlen(src);
84
85 // this table contains the values, in base 64, of all valid characters and
86 // special values WSP or INV for white space and invalid characters
87 // respectively as well as a special PAD value for '='
88 enum
89 {
90 WSP = 200,
91 INV,
92 PAD
93 };
94
95 static const unsigned char decode[256] =
96 {
97 WSP,INV,INV,INV,INV,INV,INV,INV,INV,WSP,WSP,INV,WSP,WSP,INV,INV,
98 INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,
99 WSP,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,076,INV,INV,INV,077,
100 064,065,066,067,070,071,072,073,074,075,INV,INV,INV,PAD,INV,INV,
101 INV,000,001,002,003,004,005,006,007,010,011,012,013,014,015,016,
102 017,020,021,022,023,024,025,026,027,030,031,INV,INV,INV,INV,INV,
103 INV,032,033,034,035,036,037,040,041,042,043,044,045,046,047,050,
104 051,052,053,054,055,056,057,060,061,062,063,INV,INV,INV,INV,INV,
105 INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,
106 INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,
107 INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,
108 INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,
109 INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,
110 INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,
111 INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,
112 INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,INV,
113 };
114
115 // we decode input by groups of 4 characters but things are complicated by
116 // the fact that there can be whitespace and other junk in it too so keep
117 // record of where exactly we're inside the current quartet in this var
118 int n = 0;
119 unsigned char in[4]; // current quartet
120 bool end = false; // set when we find padding
121 size_t padLen = 0; // length lost to padding
122 const char *p;
123 for ( p = src; srcLen; p++, srcLen-- )
124 {
125 const unsigned char c = decode[static_cast<unsigned char>(*p)];
126 switch ( c )
127 {
128 case WSP:
129 if ( mode == wxBase64DecodeMode_SkipWS )
130 continue;
131 // fall through
132
133 case INV:
134 if ( mode == wxBase64DecodeMode_Relaxed )
135 continue;
136
137 // force the loop to stop and an error to be returned
138 n = -1;
139 srcLen = 1;
140 break;
141
142 case PAD:
143 // set the flag telling us that we're past the end now
144 end = true;
145
146 // there can be either a single '=' at the end of a quartet or
147 // "==" in positions 2 and 3
148 if ( n == 3 )
149 {
150 padLen = 1;
151 in[n++] = '\0';
152 }
153 else if ( (n == 2) && (--srcLen && *++p == '=') )
154 {
155 padLen = 2;
156 in[n++] = '\0';
157 in[n++] = '\0';
158 }
159 else // invalid padding
160 {
161 // force the loop terminate with an error
162 n = -1;
163 srcLen = 1;
164 }
165 break;
166
167 default:
168 if ( end )
169 {
170 // nothing is allowed after the end so provoke error return
171 n = -1;
172 srcLen = 1;
173 break;
174 }
175
176 in[n++] = c;
177 }
178
179 if ( n == 4 )
180 {
181 // got entire block, decode
182 decLen += 3 - padLen;
183 if ( dst )
184 {
185 if ( decLen > dstLen )
186 return wxCONV_FAILED;
187
188 // undo the bit shifting done during encoding
189 *dst++ = in[0] << 2 | in[1] >> 4;
190
191 // be careful to not overwrite the output buffer with NUL pad
192 // bytes
193 if ( padLen != 2 )
194 {
195 *dst++ = in[1] << 4 | in[2] >> 2;
196 if ( !padLen )
197 *dst++ = in[2] << 6 | in[3];
198 }
199 }
200
201 n = 0;
202 }
203 }
204
205 if ( n )
206 {
207 if ( posErr )
208 {
209 // notice that the error was on a previous position as we did one
210 // extra "p++" in the loop line after it
211 *posErr = p - src - 1;
212 }
213
214 return wxCONV_FAILED;
215 }
216
217 return decLen;
218 }
219
220 wxMemoryBuffer wxBase64Decode(const char *src,
221 size_t srcLen,
222 wxBase64DecodeMode mode,
223 size_t *posErr)
224 {
225 wxMemoryBuffer buf;
226 wxCHECK_MSG( src, buf, wxT("NULL input buffer") );
227
228 if ( srcLen == wxNO_LEN )
229 srcLen = strlen(src);
230
231 size_t len = wxBase64DecodedSize(srcLen);
232 len = wxBase64Decode(buf.GetWriteBuf(len), len, src, srcLen, mode, posErr);
233 if ( len == wxCONV_FAILED )
234 len = 0;
235
236 buf.SetDataLen(len);
237
238 return buf;
239 }
240
241 #endif // wxUSE_BASE64