X-Git-Url: https://git.saurik.com/apple/security.git/blobdiff_plain/b1ab9ed8d0e0f1c3b66d7daa8fd5564444c56195..420ff9d9379a8d93f2c90f026a797bdea1eb4517:/libsecurity_transform/lib/SecTransformReadTransform.cpp?ds=inline diff --git a/libsecurity_transform/lib/SecTransformReadTransform.cpp b/libsecurity_transform/lib/SecTransformReadTransform.cpp index 902bfb1d..5d998cf3 100644 --- a/libsecurity_transform/lib/SecTransformReadTransform.cpp +++ b/libsecurity_transform/lib/SecTransformReadTransform.cpp @@ -41,15 +41,42 @@ static SecTransformInstanceBlock StreamTransformImplementation(CFStringRef name, } CFArrayRef array = (CFArrayRef) value; - CFReadStreamRef input = (CFReadStreamRef) CFArrayGetValueAtIndex(array, 0); - - // open the stream - if (!CFReadStreamOpen(input)) + CFTypeRef item = (CFTypeRef) CFArrayGetValueAtIndex(array, 0); + + // Ensure that indeed we do have a CFReadStreamRef + if (NULL == item || CFReadStreamGetTypeID() != CFGetTypeID(item)) { - // We didn't open properly. Error out - return (CFTypeRef) CreateSecTransformErrorRef(kSecTransformErrorInvalidInput, "An error occurred while opening the stream."); + return (CFTypeRef) CreateSecTransformErrorRef(kSecTransformErrorInvalidInput, "The input attribute item was nil or not a read stream"); } + // This now is a safe cast + CFReadStreamRef input = (CFReadStreamRef)item; + + // Get the state of the stream + CFStreamStatus streamStatus = CFReadStreamGetStatus(input); + switch (streamStatus) + { + case kCFStreamStatusNotOpen: + { + if (!CFReadStreamOpen(input)) + { + // We didn't open properly. Error out + return (CFTypeRef) CreateSecTransformErrorRef(kSecTransformErrorInvalidInput, "An error occurred while opening the stream."); + } + } + break; + + case kCFStreamStatusError: + { + return (CFTypeRef) CreateSecTransformErrorRef(kSecTransformErrorInvalidInput, "The read stream is in an error state"); + } + break; + + default: + // The assumption is that the stream is ready to go as is. + break; + } + // allocate the read buffer on the heap u_int8_t* buffer = (u_int8_t*) malloc(blockSize);