]> git.saurik.com Git - apple/security.git/blob - tlsnke/tlsnketest/dtls_client.c
Security-55471.14.8.tar.gz
[apple/security.git] / tlsnke / tlsnketest / dtls_client.c
1 //
2 // dtls_client.c
3 // tlsnke
4 //
5 // Created by Fabrice Gautier on 2/7/12.
6 // Copyright (c) 2012 Apple, Inc. All rights reserved.
7 //
8
9 /*
10 * dtlsEchoClient.c
11 * Security
12 *
13 * Created by Fabrice Gautier on 1/31/11.
14 * Copyright 2011 Apple, Inc. All rights reserved.
15 *
16 */
17
18 #include <Security/Security.h>
19
20 #include "ssl-utils.h"
21
22 #include <stdlib.h>
23 #include <sys/types.h>
24 #include <sys/socket.h>
25 #include <netinet/in.h>
26 #include <arpa/inet.h>
27 #include <stdio.h>
28 #include <errno.h>
29 #include <unistd.h> /* close() */
30 #include <string.h> /* memset() */
31 #include <fcntl.h>
32 #include <time.h>
33
34 #include "tlssocket.h"
35
/* Test server address and UDP port; dtls_client() below takes its host as a
   parameter and only uses PORT. */
#define SERVER  "10.0.2.1"
#define PORT    23232

/* Size of the per-message buffer, and number of messages to send. */
#define BUFLEN  128
#define COUNT   10
40
#if 0
/*
 * Debug helper (compiled out): hex-dump `len` bytes of `data` to stdout,
 * sixteen bytes per row, each row prefixed by its hex offset.
 */
static void dumppacket(const unsigned char *data, unsigned long len)
{
    unsigned long off;

    for (off = 0; off < len; off++) {
        if ((off % 16) == 0)
            printf("%04lx :", off);          /* start of a new row */
        printf(" %02x", data[off]);
        if ((off % 16) == 15)
            printf("\n");                    /* end of a full row */
    }
    printf("\n");
}
#endif
54
55
/* print a '.' every few seconds to keep UI alive while connecting */
static time_t lastTime = (time_t)0;   /* when the last dot was printed */
#define TIME_INTERVAL 3               /* minimum seconds between dots */

/*
 * Print a single '.' to stdout (flushed immediately) if at least
 * TIME_INTERVAL seconds have passed since the previous dot; otherwise do
 * nothing. Safe to call in a tight loop: it rate-limits itself via lastTime.
 */
static void sslOutputDot(void)
{
    time_t thisTime = time(NULL);

    /* difftime() is the portable way to subtract two time_t values. */
    if (difftime(thisTime, lastTime) >= TIME_INTERVAL) {
        printf(".");
        fflush(stdout);
        lastTime = thisTime;
    }
}
69
70 static void printSslErrStr(
71 const char *op,
72 OSStatus err)
73 {
74 printf("*** %s: %ld\n", op, (long)err);
75 }
76
/* 2 KiB -- presumably an upper bound on datagram size; not referenced by the
   code in this file. */
#define MTU (2 * 1024)
79
80
81 int dtls_client(const char *hostname, int bypass);
82
83 int dtls_client(const char *hostname, int bypass)
84 {
85 int fd;
86 int tlsfd;
87 struct sockaddr_in sa;
88
89 printf("Running dtls_client test with hostname=%s, bypass=%d\n", hostname, bypass);
90
91 if ((fd=socket(AF_INET, SOCK_DGRAM, 0))==-1) {
92 perror("socket");
93 exit(-1);
94 }
95
96 memset((char *) &sa, 0, sizeof(sa));
97 sa.sin_family = AF_INET;
98 sa.sin_port = htons(PORT);
99 if (inet_aton(hostname, &sa.sin_addr)==0) {
100 fprintf(stderr, "inet_aton() failed\n");
101 exit(1);
102 }
103
104 if(connect(fd, (struct sockaddr *)&sa, sizeof(sa))==-1)
105 {
106 perror("connect");
107 return errno;
108 }
109
110 /* Change to non blocking io */
111 fcntl(fd, F_SETFL, O_NONBLOCK);
112
113 SSLRecordContextRef c=(intptr_t)fd;
114
115
116 OSStatus ortn;
117 SSLContextRef ctx = NULL;
118
119 SSLClientCertificateState certState;
120 SSLCipherSuite negCipher;
121 SSLProtocol negVersion;
122
123 /*
124 * Set up a SecureTransport session.
125 */
126
127 ctx = SSLCreateContextWithRecordFuncs(kCFAllocatorDefault, kSSLClientSide, kSSLDatagramType, &TLSSocket_Funcs);
128 if(!ctx) {
129 printSslErrStr("SSLCreateContextWithRecordFuncs", -1);
130 return -1;
131 }
132
133 printf("Attaching filter\n");
134 ortn = TLSSocket_Attach(fd);
135 if(ortn) {
136 printSslErrStr("TLSSocket_Attach", ortn);
137 return ortn;
138 }
139
140 if(bypass) {
141 tlsfd = open("/dev/tlsnke", O_RDWR);
142 if(tlsfd<0) {
143 perror("opening tlsnke dev");
144 exit(-1);
145 }
146 }
147
148 ortn = SSLSetRecordContext(ctx, c);
149 if(ortn) {
150 printSslErrStr("SSLSetRecordContext", ortn);
151 return ortn;
152 }
153
154 ortn = SSLSetMaxDatagramRecordSize(ctx, 600);
155 if(ortn) {
156 printSslErrStr("SSLSetMaxDatagramRecordSize", ortn);
157 return ortn;
158 }
159
160 /* Lets not verify the cert, which is a random test cert */
161 ortn = SSLSetEnableCertVerify(ctx, false);
162 if(ortn) {
163 printSslErrStr("SSLSetEnableCertVerify", ortn);
164 return ortn;
165 }
166
167 ortn = SSLSetCertificate(ctx, server_chain());
168 if(ortn) {
169 printSslErrStr("SSLSetCertificate", ortn);
170 return ortn;
171 }
172
173 printf("Handshake...\n");
174
175 do {
176 ortn = SSLHandshake(ctx);
177 if(ortn == errSSLWouldBlock) {
178 /* keep UI responsive */
179 sslOutputDot();
180 }
181 } while (ortn == errSSLWouldBlock);
182
183
184 SSLGetClientCertificateState(ctx, &certState);
185 SSLGetNegotiatedCipher(ctx, &negCipher);
186 SSLGetNegotiatedProtocolVersion(ctx, &negVersion);
187
188 int count;
189 size_t len;
190 ssize_t sreadLen, swriteLen;
191 size_t readLen, writeLen;
192
193 char buffer[BUFLEN];
194
195 count = 0;
196 while(count<COUNT) {
197 int timeout = 10000;
198
199 snprintf(buffer, BUFLEN, "Message %d", count);
200 len = strlen(buffer);
201
202 if(bypass) {
203 /* Send data through the side channel, kind of like utun would */
204 swriteLen=write(tlsfd, buffer, len);
205 if(swriteLen<0) {
206 perror("write to tlsfd");
207 break;
208 }
209 writeLen=swriteLen;
210 } else {
211 ortn=SSLWrite(ctx, buffer, len, &writeLen);
212 if(ortn) {
213 printSslErrStr("SSLWrite", ortn);
214 break;
215 }
216 }
217
218 printf("Wrote %lu bytes\n", writeLen);
219
220 count++;
221
222 if(bypass) {
223 do {
224 sreadLen=read(tlsfd, buffer, BUFLEN);
225 } while((sreadLen==-1) && (errno==EAGAIN) && (timeout--));
226 if((sreadLen==-1) && (errno==EAGAIN)) {
227 printf("Read timeout...\n");
228 continue;
229 }
230 if(sreadLen<0) {
231 perror("read from tlsfd");
232 break;
233 }
234 readLen=sreadLen;
235 }
236 else {
237 do {
238 ortn=SSLRead(ctx, buffer, BUFLEN, &readLen);
239 } while((ortn==errSSLWouldBlock) && (timeout--));
240 if(ortn==errSSLWouldBlock) {
241 printf("SSLRead timeout...\n");
242 continue;
243 }
244 if(ortn) {
245 printSslErrStr("SSLRead", ortn);
246 break;
247 }
248 }
249
250 buffer[readLen]=0;
251 printf("Received %lu bytes: %s\n", readLen, buffer);
252
253 }
254
255 SSLClose(ctx);
256
257 SSLDisposeContext(ctx);
258
259 return ortn;
260 }
261