]> git.saurik.com Git - apple/security.git/blobdiff - SecureTransport/sslHandshakeHello.cpp
Security-163.tar.gz
[apple/security.git] / SecureTransport / sslHandshakeHello.cpp
index 74a7681cd77fa3eaa00daeed4a24905b3475cd6b..9264e42ab65ac6521813271892a178ac15c9b0c9 100644 (file)
@@ -64,6 +64,10 @@ SSLEncodeServerHello(SSLRecord &serverHello, SSLContext *ctx)
        }       
        #endif  /* SSL_IE_NULL_RESUME_BUG */
                
+       /* this was set to a known quantity in SSLProcessClientHello */
+       assert(ctx->negProtocolVersion != SSL_Version_Undetermined);
+       /* should not be here in this case */
+       assert(ctx->negProtocolVersion != SSL_Version_2_0);
        sslLogNegotiateDebug("===SSL3 server: sending version %d_%d",
                ctx->negProtocolVersion >> 8, ctx->negProtocolVersion & 0xff);
        sslLogNegotiateDebug("...sessionIDLen = %d", sessionIDLen);
@@ -112,7 +116,7 @@ SSLEncodeServerHello(SSLRecord &serverHello, SSLContext *ctx)
 OSStatus
 SSLProcessServerHello(SSLBuffer message, SSLContext *ctx)
 {   OSStatus            err;
-    SSLProtocolVersion  protocolVersion;
+    SSLProtocolVersion  protocolVersion, negVersion;
     unsigned int        sessionIDLen;
     UInt8               *p;
     
@@ -126,11 +130,13 @@ SSLProcessServerHello(SSLBuffer message, SSLContext *ctx)
     
     protocolVersion = (SSLProtocolVersion)SSLDecodeInt(p, 2);
     p += 2;
-    if (protocolVersion > ctx->maxProtocolVersion) {
-        return errSSLNegotiation;
+       /* FIXME this should probably send appropriate alerts */
+       err = sslVerifyProtVersion(ctx, protocolVersion, &negVersion);
+       if(err) {
+               return err;
        }
-    ctx->negProtocolVersion = protocolVersion;
-       switch(protocolVersion) {
+    ctx->negProtocolVersion = negVersion;
+       switch(negVersion) {
                case SSL_Version_3_0:
                        ctx->sslTslCalls = &Ssl3Callouts;
                        break;
@@ -141,7 +147,7 @@ SSLProcessServerHello(SSLBuffer message, SSLContext *ctx)
                        return errSSLNegotiation;
        }
     sslLogNegotiateDebug("===SSL3 client: negVersion is %d_%d",
-               (protocolVersion >> 8) & 0xff, protocolVersion & 0xff);
+               (negVersion >> 8) & 0xff, negVersion & 0xff);
     
     memcpy(ctx->serverRandom, p, 32);
     p += 32;
@@ -184,7 +190,7 @@ SSLEncodeClientHello(SSLRecord &clientHello, SSLContext *ctx)
     UInt16          sessionIDLen;
     
     assert(ctx->protocolSide == SSL_ClientSide);
-    
+       
     sessionIDLen = 0;
     if (ctx->resumableSession.data != 0)
     {   if ((err = SSLRetrieveSessionID(ctx->resumableSession,
@@ -196,7 +202,11 @@ SSLEncodeClientHello(SSLRecord &clientHello, SSLContext *ctx)
     
     length = 39 + 2*(ctx->numValidCipherSpecs) + sessionIDLen;
     
-    clientHello.protocolVersion = ctx->maxProtocolVersion;
+       err = sslGetMaxProtVersion(ctx, &clientHello.protocolVersion);
+       if(err) {
+               /* we don't have a protocol enabled */
+               return err;
+       }
     clientHello.contentType = SSL_RecordTypeHandshake;
     if ((err = SSLAllocBuffer(clientHello.contents, length + 4, ctx)) != 0)
         return err;
@@ -204,10 +214,10 @@ SSLEncodeClientHello(SSLRecord &clientHello, SSLContext *ctx)
     p = clientHello.contents.data;
     *p++ = SSL_HdskClientHello;
     p = SSLEncodeInt(p, length, 3);
-    p = SSLEncodeInt(p, ctx->maxProtocolVersion, 2);
+    p = SSLEncodeInt(p, clientHello.protocolVersion, 2);
        sslLogNegotiateDebug("===SSL3 client: proclaiming max protocol "
                "%d_%d capable ONLY",
-               ctx->maxProtocolVersion >> 8, ctx->maxProtocolVersion & 0xff);
+               clientHello.protocolVersion >> 8, clientHello.protocolVersion & 0xff);
    if ((err = SSLEncodeRandom(p, ctx)) != 0)
     {   SSLFreeBuffer(clientHello.contents, ctx);
         return err;
@@ -239,7 +249,7 @@ SSLEncodeClientHello(SSLRecord &clientHello, SSLContext *ctx)
 OSStatus
 SSLProcessClientHello(SSLBuffer message, SSLContext *ctx)
 {   OSStatus            err;
-    SSLProtocolVersion  clientVersion;
+    SSLProtocolVersion  negVersion;
     UInt16              cipherListLen, cipherCount, desiredSpec, cipherSpec;
     UInt8               sessionIDLen, compressionCount;
     UInt8               *charPtr;
@@ -250,12 +260,13 @@ SSLProcessClientHello(SSLBuffer message, SSLContext *ctx)
         return errSSLProtocol;
     }
     charPtr = message.data;
-    clientVersion = (SSLProtocolVersion)SSLDecodeInt(charPtr, 2);
+    ctx->clientReqProtocol = (SSLProtocolVersion)SSLDecodeInt(charPtr, 2);
     charPtr += 2;
-       if(clientVersion > ctx->maxProtocolVersion) {
-               clientVersion = ctx->maxProtocolVersion;
+       err = sslVerifyProtVersion(ctx, ctx->clientReqProtocol, &negVersion);
+       if(err) {
+               return err;
        }
-       switch(clientVersion) {
+       switch(negVersion) {
                case SSL_Version_3_0:
                        ctx->sslTslCalls = &Ssl3Callouts;
                        break;
@@ -265,9 +276,9 @@ SSLProcessClientHello(SSLBuffer message, SSLContext *ctx)
                default:
                        return errSSLNegotiation;
        }
-       ctx->negProtocolVersion = clientVersion;
+       ctx->negProtocolVersion = negVersion;
     sslLogNegotiateDebug("===SSL3 server: negVersion is %d_%d",
-               clientVersion >> 8, clientVersion & 0xff);
+               negVersion >> 8, negVersion & 0xff);
     
     memcpy(ctx->clientRandom, charPtr, SSL_CLIENT_SRVR_RAND_SIZE);
     charPtr += 32;
@@ -276,7 +287,8 @@ SSLProcessClientHello(SSLBuffer message, SSLContext *ctx)
        sslErrorLog("SSLProcessClientHello: msg len error 2\n");
         return errSSLProtocol;
     }
-    if (sessionIDLen > 0 && ctx->peerID.data != 0)
+       /* FIXME peerID is never set on server side.... */
+    if (sessionIDLen > 0 && ctx->peerID.data != 0) 
     {   /* Don't die on error; just treat it as an uncacheable session */
         err = SSLAllocBuffer(ctx->sessionID, sessionIDLen, ctx);
         if (err == 0)