git.saurik.com Git - apple/security.git/blob - OSX/libsecurity_ssl/regressions/SecureTransportTests/STLegacyTests+sessionstate.m
Security-59754.80.3.tar.gz
[apple/security.git] / OSX / libsecurity_ssl / regressions / SecureTransportTests / STLegacyTests+sessionstate.m
1 #include <stdbool.h>
2 #include <pthread.h>
3 #include <fcntl.h>
4 #include <sys/mman.h>
5 #include <unistd.h>
6
7 #include <CoreFoundation/CoreFoundation.h>
8
9 #include <AssertMacros.h>
10 #include <Security/SecureTransportPriv.h> /* SSLSetOption */
11 #include <Security/SecureTransport.h>
12 #include <Security/SecPolicy.h>
13 #include <Security/SecTrust.h>
14 #include <Security/SecIdentity.h>
15 #include <Security/SecIdentityPriv.h>
16 #include <Security/SecCertificatePriv.h>
17 #include <Security/SecKeyPriv.h>
18 #include <Security/SecItem.h>
19 #include <Security/SecRandom.h>
20
21 #include <string.h>
22 #include <sys/types.h>
23 #include <sys/socket.h>
24 #include <errno.h>
25 #include <stdlib.h>
26 #include <mach/mach_time.h>
27
28 #if TARGET_OS_IPHONE
29 #include <Security/SecRSAKey.h>
30 #endif
31
32 #include "ssl-utils.h"
33
34 #include <tls_stream_parser.h>
35 #include <tls_handshake.h>
36 #include <tls_record.h>
37
38 #include <sys/queue.h>
39
40 #import "STLegacyTests.h"
41
42 #pragma clang diagnostic push
43 #pragma clang diagnostic ignored "-Wdeprecated-declarations"
44
45 @implementation STLegacyTests (sessionstate)
46
/* Debug tracing for this file. The macro expands to nothing, so every
 * test_printf() call (and its arguments) compiles away. */
#define test_printf(x...)

#include <corecrypto/ccrng.h>
/* Random-number generator passed to tls_record_create() when building the coretls peer. */
#define CCRNGSTATE ccrng(NULL)
51
/* One encrypted record produced by the in-process coretls peer and waiting to be
 * handed to SecureTransport through SocketRead(). */
struct RecQueueItem {
    STAILQ_ENTRY(RecQueueItem) next; /* link to next queued entry or NULL */
    tls_buffer record;               /* encrypted record bytes; heap storage owned by this item */
    size_t offset;                   /* number of bytes of `record` already returned by SocketRead() */
};
57
/* State shared by a SecureTransport context and the coretls peer it talks to.
 * A pointer to this struct is used both as the SSLConnectionRef for SocketRead/
 * SocketWrite and as the context for the coretls parser/handshake callbacks. */
typedef struct {
    SSLContextRef st;            /* SecureTransport side of the connection */
    tls_stream_parser_t parser;  /* splits bytes written by `st` into records for process() */
    tls_record_t record;         /* coretls record layer (encrypt/decrypt, cipher state) */
    tls_handshake_t hdsk;        /* coretls handshake state machine */
    STAILQ_HEAD(, RecQueueItem) rec_queue; // coretls server queue packet in this queue
    int ready_count;             /* number of write-ready notifications seen by the ready callback */
} ssl_test_handle;
66
67
68 static
69 int tls_buffer_alloc(tls_buffer *buf, size_t length)
70 {
71 buf->data = malloc(length);
72 if(!buf->data) return -ENOMEM;
73 buf->length = length;
74 return 0;
75 }
76
77 static
78 int tls_buffer_free(tls_buffer *buf)
79 {
80 free(buf->data);
81 buf->data = NULL;
82 buf->length = 0;
83 return 0;
84 }
85
86 #pragma mark -
87 #pragma mark SecureTransport support
88
#if SECTRANS_VERBOSE_DEBUG
/*
 * Debug helper: print a header naming `s` and the buffer, then `len` bytes of
 * `bytes` in hex, starting a new row every 16 bytes.
 * Uses %zu for the size_t length and passes a void pointer to %p, as the printf
 * format rules require.
 */
static void hexdump(const char *s, const uint8_t *bytes, size_t len) {
    size_t ix;
    printf("socket %s(%p, %zu)\n", s, (const void *)bytes, len);
    for (ix = 0; ix < len; ++ix) {
        if (!(ix % 16)) {
            printf("\n");
        }
        printf("%02X ", bytes[ix]);
    }
    printf("\n");
}
#else
/* Verbose debugging disabled: hexdump() calls compile away. */
#define hexdump(string, bytes, len)
#endif
103
104 static OSStatus SocketWrite(SSLConnectionRef h, const void *data, size_t *length)
105 {
106 ssl_test_handle *handle =(ssl_test_handle *)h;
107
108 size_t len = *length;
109 uint8_t *ptr = (uint8_t *)data;
110
111 tls_buffer buffer;
112 buffer.data = ptr;
113 buffer.length = len;
114 return tls_stream_parser_parse(handle->parser, buffer);
115 }
116
117 static OSStatus SocketRead(SSLConnectionRef h, void *data, size_t *length)
118 {
119 ssl_test_handle *handle =(ssl_test_handle *)h;
120
121 test_printf("%s: %p requesting len=%zd\n", __FUNCTION__, h, *length);
122
123 struct RecQueueItem *item = STAILQ_FIRST(&handle->rec_queue);
124
125 if(item == NULL) {
126 test_printf("%s: %p no data available\n", __FUNCTION__, h);
127 return errSSLWouldBlock;
128 }
129
130 size_t avail = item->record.length - item->offset;
131
132 test_printf("%s: %p %zd bytes available in %p\n", __FUNCTION__, h, avail, item);
133
134 if (avail > *length) {
135 memcpy(data, item->record.data+item->offset, *length);
136 item->offset += *length;
137 } else {
138 memcpy(data, item->record.data+item->offset, avail);
139 *length = avail;
140 STAILQ_REMOVE_HEAD(&handle->rec_queue, next);
141 tls_buffer_free(&item->record);
142 free(item);
143 }
144
145 test_printf("%s: %p %zd bytes read\n", __FUNCTION__, h, *length);
146
147
148 return 0;
149 }
150
151 static int process(tls_stream_parser_ctx_t ctx, tls_buffer record)
152 {
153 ssl_test_handle *h = (ssl_test_handle *)ctx;
154 tls_buffer decrypted;
155 uint8_t ct;
156 int err;
157
158 test_printf("%s: %p processing %zd bytes\n", __FUNCTION__, ctx, record.length);
159
160
161 decrypted.length = tls_record_decrypted_size(h->record, record.length);
162 decrypted.data = malloc(decrypted.length);
163
164 require_action(decrypted.data, errOut, err=-ENOMEM);
165 require_noerr((err=tls_record_decrypt(h->record, record, &decrypted, &ct)), errOut);
166
167 test_printf("%s: %p decrypted %zd bytes, ct=%d\n", __FUNCTION__, ctx, decrypted.length, ct);
168
169 err=tls_handshake_process(h->hdsk, decrypted, ct);
170
171 test_printf("%s: %p processed, err=%d\n", __FUNCTION__, ctx, err);
172
173 errOut:
174 free(decrypted.data);
175 return err;
176 }
177
178 static int
179 tls_handshake_write_callback(tls_handshake_ctx_t ctx, const tls_buffer data, uint8_t content_type)
180 {
181 int err = 0;
182 ssl_test_handle *handle = (ssl_test_handle *)ctx;
183
184 test_printf("%s: %p writing data ct=%d, len=%zd\n", __FUNCTION__, ctx, content_type, data.length);
185
186 struct RecQueueItem *item = malloc(sizeof(struct RecQueueItem));
187 require_action(item, errOut, err=-ENOMEM);
188
189 err=tls_buffer_alloc(&item->record, tls_record_encrypted_size(handle->record, content_type, data.length));
190 require_noerr(err, errOut);
191
192 err=tls_record_encrypt(handle->record, data, content_type, &item->record);
193 require_noerr(err, errOut);
194
195 item->offset = 0;
196
197 test_printf("%s: %p queing %zd encrypted bytes, item=%p\n", __FUNCTION__, ctx, item->record.length, item);
198
199 STAILQ_INSERT_TAIL(&handle->rec_queue, item, next);
200
201 return 0;
202
203 errOut:
204 if(item) {
205 tls_buffer_free(&item->record);
206 free(item);
207 }
208 return err;
209 }
210
211
212 static int
213 tls_handshake_message_callback(tls_handshake_ctx_t ctx, tls_handshake_message_t event)
214 {
215 ssl_test_handle __unused *handle = (ssl_test_handle *)ctx;
216
217 test_printf("%s: %p event = %d\n", __FUNCTION__, handle, event);
218
219 int err = 0;
220
221 return err;
222 }
223
224
225
226 static uint8_t appdata[] = "appdata";
227
228 tls_buffer appdata_buffer = {
229 .data = appdata,
230 .length = sizeof(appdata),
231 };
232
233
234 static void
235 tls_handshake_ready_callback(tls_handshake_ctx_t ctx, bool write, bool ready)
236 {
237 ssl_test_handle *handle = (ssl_test_handle *)ctx;
238
239 test_printf("%s: %p %s ready=%d\n", __FUNCTION__, handle, write?"write":"read", ready);
240
241 if(ready) {
242 if(write) {
243 if(handle->ready_count == 0) {
244 tls_handshake_request_renegotiation(handle->hdsk);
245 } else {
246 tls_handshake_write_callback(ctx, appdata_buffer, tls_record_type_AppData);
247 }
248 handle->ready_count++;;
249 }
250 }
251 }
252
253 static int
254 tls_handshake_set_retransmit_timer_callback(tls_handshake_ctx_t ctx, int attempt)
255 {
256 ssl_test_handle __unused *handle = (ssl_test_handle *)ctx;
257
258 test_printf("%s: %p attempt = %d\n", __FUNCTION__, handle, attempt);
259
260 return -1;
261 }
262
263 static
264 int mySSLRecordInitPendingCiphersFunc(tls_handshake_ctx_t ref,
265 uint16_t selectedCipher,
266 bool server,
267 tls_buffer key)
268 {
269 ssl_test_handle *handle = (ssl_test_handle *)ref;
270
271 test_printf("%s: %p, cipher=%04x, server=%d\n", __FUNCTION__, ref, selectedCipher, server);
272 return tls_record_init_pending_ciphers(handle->record, selectedCipher, server, key);
273 }
274
275 static
276 int mySSLRecordAdvanceWriteCipherFunc(tls_handshake_ctx_t ref)
277 {
278 ssl_test_handle *handle = (ssl_test_handle *)ref;
279 test_printf("%s: %p\n", __FUNCTION__, ref);
280 return tls_record_advance_write_cipher(handle->record);
281 }
282
283 static
284 int mySSLRecordRollbackWriteCipherFunc(tls_handshake_ctx_t ref)
285 {
286 ssl_test_handle *handle = (ssl_test_handle *)ref;
287 test_printf("%s: %p\n", __FUNCTION__, ref);
288 return tls_record_rollback_write_cipher(handle->record);
289 }
290
291 static
292 int mySSLRecordAdvanceReadCipherFunc(tls_handshake_ctx_t ref)
293 {
294 ssl_test_handle *handle = (ssl_test_handle *)ref;
295 test_printf("%s: %p\n", __FUNCTION__, ref);
296 return tls_record_advance_read_cipher(handle->record);
297 }
298
299 static
300 int mySSLRecordSetProtocolVersionFunc(tls_handshake_ctx_t ref,
301 tls_protocol_version protocolVersion)
302 {
303 ssl_test_handle *handle = (ssl_test_handle *)ref;
304 test_printf("%s: %p, version=%04x\n", __FUNCTION__, ref, protocolVersion);
305 return tls_record_set_protocol_version(handle->record, protocolVersion);
306 }
307
308
309 /* TLS callbacks */
310 tls_handshake_callbacks_t myTLS_handshake_callbacks = {
311 .write = tls_handshake_write_callback,
312 .message = tls_handshake_message_callback,
313 .ready = tls_handshake_ready_callback,
314 .set_retransmit_timer = tls_handshake_set_retransmit_timer_callback,
315 .init_pending_cipher = mySSLRecordInitPendingCiphersFunc,
316 .advance_write_cipher = mySSLRecordAdvanceWriteCipherFunc,
317 .rollback_write_cipher = mySSLRecordRollbackWriteCipherFunc,
318 .advance_read_cipher = mySSLRecordAdvanceReadCipherFunc,
319 .set_protocol_version = mySSLRecordSetProtocolVersionFunc,
320 };
321
322
323 static void
324 ssl_test_handle_destroy(ssl_test_handle *handle)
325 {
326 if (handle) {
327 if(handle->parser) tls_stream_parser_destroy(handle->parser);
328 if(handle->record) tls_record_destroy(handle->record);
329 if(handle->hdsk) tls_handshake_destroy(handle->hdsk);
330 if(handle->st) CFRelease(handle->st);
331 free(handle);
332 }
333 }
334
335 static uint16_t ciphers[] = {
336 TLS_PSK_WITH_AES_128_CBC_SHA,
337 };
338 static int nciphers = sizeof(ciphers)/sizeof(ciphers[0]);
339
340 static SSLCipherSuite ciphersuites[] = {
341 TLS_PSK_WITH_AES_128_CBC_SHA,
342 };
343 static int nciphersuites = sizeof(ciphersuites)/sizeof(ciphersuites[0]);
344
345
346
347 static uint8_t shared_secret[] = "secret";
348
349 tls_buffer psk_secret = {
350 .data = shared_secret,
351 .length = sizeof(shared_secret),
352 };
353
354 static ssl_test_handle *
355 ssl_test_handle_create(bool server)
356 {
357 ssl_test_handle *handle = calloc(1, sizeof(ssl_test_handle));
358 SSLContextRef ctx = SSLCreateContext(kCFAllocatorDefault, server?kSSLServerSide:kSSLClientSide, kSSLStreamType);
359
360 require(handle, out);
361 require(ctx, out);
362
363 require_noerr(SSLSetIOFuncs(ctx, (SSLReadFunc)SocketRead, (SSLWriteFunc)SocketWrite), out);
364 require_noerr(SSLSetConnection(ctx, (SSLConnectionRef)handle), out);
365 require_noerr(SSLSetSessionOption(ctx, kSSLSessionOptionBreakOnServerAuth, true), out);
366 require_noerr(SSLSetEnabledCiphers(ctx, ciphersuites, nciphersuites), out);
367 require_noerr(SSLSetPSKSharedSecret(ctx, shared_secret, sizeof(shared_secret)), out);
368
369 handle->st = ctx;
370 handle->parser = tls_stream_parser_create(handle, process);
371 handle->record = tls_record_create(false, CCRNGSTATE);
372 handle->hdsk = tls_handshake_create(false, true); // server.
373
374 require_noerr(tls_handshake_set_ciphersuites(handle->hdsk, ciphers, nciphers), out);
375 require_noerr(tls_handshake_set_callbacks(handle->hdsk, &myTLS_handshake_callbacks, handle), out);
376 require_noerr(tls_handshake_set_psk_secret(handle->hdsk, &psk_secret), out);
377 require_noerr(tls_handshake_set_renegotiation(handle->hdsk, true), out);
378
379 // Initialize the record queue
380 STAILQ_INIT(&handle->rec_queue);
381
382 return handle;
383
384 out:
385 if (handle) free(handle);
386 if (ctx) CFRelease(ctx);
387 return NULL;
388 }
389
/*
 * Drive a SecureTransport client against the in-process coretls PSK server and
 * check SSLGetSessionState at each stage:
 *   - Idle before the handshake,
 *   - Handshake while SSLHandshake/SSLRead report would-block or peer-auth breaks,
 *   - Connected after the handshake and again once application data has been read
 *     (the server requests a renegotiation before sending that data).
 */
-(void)testSessionState
{
    /* Upper bound on SSLHandshake/SSLRead calls per loop. The peer is synchronous
     * and in-process, so a stalled exchange would otherwise spin forever instead
     * of failing. */
    enum { kMaxLoopIterations = 1000 };

    OSStatus ortn;
    OSStatus stateErr;
    int iterations;

    ssl_test_handle *client;
    SSLSessionState state;
    Boolean option;

    uint8_t buffer[128];
    size_t available = 0;
    size_t avail = 0;

    client = ssl_test_handle_create(false);

    require_action(client, out, ortn = -1);

    ortn = SSLGetSessionState(client->st, &state);
    require_noerr(ortn, out);
    XCTAssertEqual(state, kSSLIdle, "State should be Idle");
    ortn = SSLGetSessionOption(client->st, kSSLSessionOptionBreakOnServerAuth, &option);
    require_noerr(ortn, out);
    XCTAssertEqual(option, true, "Session should break on Server auth");

    iterations = 0;
    do {
        require_action(iterations < kMaxLoopIterations, out, ortn = -1);
        iterations++;

        ortn = SSLHandshake(client->st);
        test_printf("SSLHandshake returned err=%d\n", (int)ortn);

        /* Report a state-query failure through ortn, so that jumping to `out`
         * cannot pass the final check with a stale ortn of 0. */
        stateErr = SSLGetSessionState(client->st, &state);
        require_noerr_action(stateErr, out, ortn = stateErr);

        if (ortn == errSSLPeerAuthCompleted || ortn == errSSLWouldBlock) {
            require_action(state == kSSLHandshake, out, ortn = -1);
        }

    } while(ortn == errSSLWouldBlock ||
            ortn == errSSLPeerAuthCompleted);

    XCTAssertEqual(ortn, 0, "Unexpected SSLHandshake exit code");
    XCTAssertEqual(state, kSSLConnected, "State should be Connected");

    test_printf("Initial handshake done\n");

    iterations = 0;
    do {
        require_action(iterations < kMaxLoopIterations, out, ortn = -1);
        iterations++;

        XCTAssertEqual(errSecSuccess, SSLGetBufferedReadSize(client->st, &avail));
        ortn = SSLRead(client->st, buffer, sizeof(buffer), &available);
        test_printf("SSLRead returned err=%d, avail=%zd\n", (int)ortn, available);

        stateErr = SSLGetSessionState(client->st, &state);
        require_noerr_action(stateErr, out, ortn = stateErr);

        if (ortn == errSSLPeerAuthCompleted) {
            require_action(state == kSSLHandshake, out, ortn = -1);
        }

        /* Any status other than success, "no data yet" or a peer-auth break is a
         * hard failure. `available` would stay 0 forever, so stop here instead of
         * looping. */
        require(ortn == errSecSuccess || ortn == errSSLWouldBlock || ortn == errSSLPeerAuthCompleted, out);
    } while (available == 0);

    XCTAssertEqual(ortn, 0, "Unexpected SSLRead exit code");
    XCTAssertEqual(state, kSSLConnected, "State should be Connected");


out:
    XCTAssertEqual(ortn, 0, "Final result is non zero");
    ssl_test_handle_destroy(client);

}
452
453 @end
454
455 #pragma clang diagnostic pop