Signed trailers for signature v4 (#16484)

This commit is contained in:
Klaus Post 2023-05-05 19:53:12 -07:00 committed by GitHub
parent 2f44dac14f
commit 76913a9fd5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 919 additions and 282 deletions

View File

@ -28,6 +28,7 @@ import (
"strings"
"github.com/Azure/azure-storage-blob-go/azblob"
"github.com/minio/minio/internal/ioutil"
"google.golang.org/api/googleapi"
"github.com/minio/madmin-go/v2"
@ -199,6 +200,7 @@ const (
ErrInvalidTagDirective
ErrPolicyAlreadyAttached
ErrPolicyNotAttached
ErrExcessData
// Add new error codes here.
// SSE-S3/SSE-KMS related API errors
@ -527,6 +529,11 @@ var errorCodes = errorCodeMap{
Description: "Your proposed upload exceeds the maximum allowed object size.",
HTTPStatusCode: http.StatusBadRequest,
},
ErrExcessData: {
Code: "ExcessData",
Description: "More data provided than indicated content length",
HTTPStatusCode: http.StatusBadRequest,
},
ErrPolicyTooLarge: {
Code: "PolicyTooLarge",
Description: "Policy exceeds the maximum allowed document size.",
@ -2099,6 +2106,8 @@ func toAPIErrorCode(ctx context.Context, err error) (apiErr APIErrorCode) {
apiErr = ErrMalformedXML
case errInvalidMaxParts:
apiErr = ErrInvalidMaxParts
case ioutil.ErrOverread:
apiErr = ErrExcessData
}
// Compression errors

File diff suppressed because one or more lines are too long

View File

@ -88,6 +88,18 @@ func isRequestSignStreamingV4(r *http.Request) bool {
r.Method == http.MethodPut
}
// Verify if the request has AWS Streaming Signature Version '4'. This is only valid for 'PUT' operation.
func isRequestSignStreamingTrailerV4(r *http.Request) bool {
return r.Header.Get(xhttp.AmzContentSha256) == streamingContentSHA256Trailer &&
r.Method == http.MethodPut
}
// Verify if the request has AWS Streaming Signature Version '4', with unsigned content and trailer.
func isRequestUnsignedTrailerV4(r *http.Request) bool {
return r.Header.Get(xhttp.AmzContentSha256) == unsignedPayloadTrailer &&
r.Method == http.MethodPut && strings.Contains(r.Header.Get(xhttp.ContentEncoding), streamingContentEncoding)
}
// Authorization type.
//
//go:generate stringer -type=authType -trimprefix=authType $GOFILE
@ -105,10 +117,12 @@ const (
authTypeSignedV2
authTypeJWT
authTypeSTS
authTypeStreamingSignedTrailer
authTypeStreamingUnsignedTrailer
)
// Get request authentication type.
func getRequestAuthType(r *http.Request) authType {
func getRequestAuthType(r *http.Request) (at authType) {
if r.URL != nil {
var err error
r.Form, err = url.ParseQuery(r.URL.RawQuery)
@ -123,6 +137,10 @@ func getRequestAuthType(r *http.Request) authType {
return authTypePresignedV2
} else if isRequestSignStreamingV4(r) {
return authTypeStreamingSigned
} else if isRequestSignStreamingTrailerV4(r) {
return authTypeStreamingSignedTrailer
} else if isRequestUnsignedTrailerV4(r) {
return authTypeStreamingUnsignedTrailer
} else if isRequestSignatureV4(r) {
return authTypeSigned
} else if isRequestPresignedSignatureV4(r) {
@ -560,13 +578,15 @@ func isReqAuthenticated(ctx context.Context, r *http.Request, region string, sty
// List of all support S3 auth types.
var supportedS3AuthTypes = map[authType]struct{}{
authTypeAnonymous: {},
authTypePresigned: {},
authTypePresignedV2: {},
authTypeSigned: {},
authTypeSignedV2: {},
authTypePostPolicy: {},
authTypeStreamingSigned: {},
authTypeAnonymous: {},
authTypePresigned: {},
authTypePresignedV2: {},
authTypeSigned: {},
authTypeSignedV2: {},
authTypePostPolicy: {},
authTypeStreamingSigned: {},
authTypeStreamingSignedTrailer: {},
authTypeStreamingUnsignedTrailer: {},
}
// Validate if the authType is valid and supported.
@ -582,7 +602,8 @@ func setAuthHandler(h http.Handler) http.Handler {
tc, ok := r.Context().Value(mcontext.ContextTraceKey).(*mcontext.TraceCtxt)
aType := getRequestAuthType(r)
if aType == authTypeSigned || aType == authTypeSignedV2 || aType == authTypeStreamingSigned {
switch aType {
case authTypeSigned, authTypeSignedV2, authTypeStreamingSigned, authTypeStreamingSignedTrailer:
// Verify if date headers are set, if not reject the request
amzDate, errCode := parseAmzDateHeader(r)
if errCode != ErrNone {
@ -613,10 +634,16 @@ func setAuthHandler(h http.Handler) http.Handler {
atomic.AddUint64(&globalHTTPStats.rejectedRequestsTime, 1)
return
}
}
if isSupportedS3AuthType(aType) || aType == authTypeJWT || aType == authTypeSTS {
h.ServeHTTP(w, r)
return
case authTypeJWT, authTypeSTS:
h.ServeHTTP(w, r)
return
default:
if isSupportedS3AuthType(aType) {
h.ServeHTTP(w, r)
return
}
}
if ok {
@ -710,7 +737,7 @@ func isPutActionAllowed(ctx context.Context, atype authType, bucketName, objectN
return ErrSignatureVersionNotSupported
case authTypeSignedV2, authTypePresignedV2:
cred, owner, s3Err = getReqAccessKeyV2(r)
case authTypeStreamingSigned, authTypePresigned, authTypeSigned:
case authTypeStreamingSigned, authTypePresigned, authTypeSigned, authTypeStreamingSignedTrailer, authTypeStreamingUnsignedTrailer:
cred, owner, s3Err = getReqAccessKeyV4(r, region, serviceS3)
}
if s3Err != ErrNone {

View File

@ -18,11 +18,13 @@ func _() {
_ = x[authTypeSignedV2-7]
_ = x[authTypeJWT-8]
_ = x[authTypeSTS-9]
_ = x[authTypeStreamingSignedTrailer-10]
_ = x[authTypeStreamingUnsignedTrailer-11]
}
const _authType_name = "UnknownAnonymousPresignedPresignedV2PostPolicyStreamingSignedSignedSignedV2JWTSTS"
const _authType_name = "UnknownAnonymousPresignedPresignedV2PostPolicyStreamingSignedSignedSignedV2JWTSTSStreamingSignedTrailerStreamingUnsignedTrailer"
var _authType_index = [...]uint8{0, 7, 16, 25, 36, 46, 61, 67, 75, 78, 81}
var _authType_index = [...]uint8{0, 7, 16, 25, 36, 46, 61, 67, 75, 78, 81, 103, 127}
func (i authType) String() string {
if i < 0 || i >= authType(len(_authType_index)-1) {

View File

@ -591,7 +591,7 @@ func (er erasureObjects) PutObjectPart(ctx context.Context, bucket, object, uplo
return pi, InvalidArgument{
Bucket: bucket,
Object: fi.Name,
Err: fmt.Errorf("checksum missing, want %s, got %s", cs, r.ContentCRCType().String()),
Err: fmt.Errorf("checksum missing, want %q, got %q", cs, r.ContentCRCType().String()),
}
}
}
@ -707,6 +707,7 @@ func (er erasureObjects) PutObjectPart(ctx context.Context, bucket, object, uplo
Index: index,
Checksums: r.ContentCRC(),
}
fi.Parts = []ObjectPartInfo{partInfo}
partFI, err := fi.MarshalMsg(nil)
if err != nil {

View File

@ -29,6 +29,7 @@ import (
"github.com/dustin/go-humanize"
"github.com/minio/minio/internal/config/storageclass"
"github.com/minio/minio/internal/hash"
"github.com/minio/minio/internal/ioutil"
)
// Wrapper for calling NewMultipartUpload tests for both Erasure multiple disks and single node setup.
@ -277,7 +278,7 @@ func testObjectAPIPutObjectPart(obj ObjectLayer, instanceType string, t TestErrH
// Input with size less than the size of actual data inside the reader.
{
bucketName: bucket, objName: object, uploadID: uploadID, PartID: 1, inputReaderData: "abcd", inputMd5: "900150983cd24fb0d6963f7d28e17f73", intputDataSize: int64(len("abcd") - 1),
expectedError: hash.BadDigest{ExpectedMD5: "900150983cd24fb0d6963f7d28e17f73", CalculatedMD5: "900150983cd24fb0d6963f7d28e17f72"},
expectedError: ioutil.ErrOverread,
},
// Test case - 16-19.

View File

@ -29,6 +29,7 @@ import (
"github.com/dustin/go-humanize"
"github.com/minio/minio/internal/hash"
"github.com/minio/minio/internal/ioutil"
)
func md5Header(data []byte) map[string]string {
@ -123,7 +124,7 @@ func testObjectAPIPutObject(obj ObjectLayer, instanceType string, t TestErrHandl
9: {
bucketName: bucket, objName: object, inputData: []byte("abcd"),
inputMeta: map[string]string{"etag": "900150983cd24fb0d6963f7d28e17f73"}, intputDataSize: int64(len("abcd") - 1),
expectedError: hash.BadDigest{ExpectedMD5: "900150983cd24fb0d6963f7d28e17f73", CalculatedMD5: "900150983cd24fb0d6963f7d28e17f72"},
expectedError: ioutil.ErrOverread,
},
// Validating for success cases.
@ -162,9 +163,9 @@ func testObjectAPIPutObject(obj ObjectLayer, instanceType string, t TestErrHandl
},
// data with size different from the actual number of bytes available in the reader
26: {bucketName: bucket, objName: object, inputData: data, intputDataSize: int64(len(data) - 1), expectedMd5: getMD5Hash(data[:len(data)-1])},
26: {bucketName: bucket, objName: object, inputData: data, intputDataSize: int64(len(data) - 1), expectedMd5: getMD5Hash(data[:len(data)-1]), expectedError: ioutil.ErrOverread},
27: {bucketName: bucket, objName: object, inputData: nilBytes, intputDataSize: int64(len(nilBytes) + 1), expectedMd5: getMD5Hash(nilBytes), expectedError: IncompleteBody{Bucket: bucket, Object: object}},
28: {bucketName: bucket, objName: object, inputData: fiveMBBytes, expectedMd5: getMD5Hash(fiveMBBytes)},
28: {bucketName: bucket, objName: object, inputData: fiveMBBytes, expectedMd5: getMD5Hash(fiveMBBytes), expectedError: ioutil.ErrOverread},
// valid data with X-Amz-Meta- meta
29: {bucketName: bucket, objName: object, inputData: data, inputMeta: map[string]string{"X-Amz-Meta-AppID": "a42"}, intputDataSize: int64(len(data)), expectedMd5: getMD5Hash(data)},
@ -173,7 +174,7 @@ func testObjectAPIPutObject(obj ObjectLayer, instanceType string, t TestErrHandl
30: {bucketName: bucket, objName: "emptydir/", inputData: []byte{}, expectedMd5: getMD5Hash([]byte{})},
// Put an object inside the empty directory
31: {bucketName: bucket, objName: "emptydir/" + object, inputData: data, intputDataSize: int64(len(data)), expectedMd5: getMD5Hash(data)},
// Put the empty object with a trailing slash again (refer to Test case 31), this needs to succeed
// Put the empty object with a trailing slash again (refer to Test case 30), this needs to succeed
32: {bucketName: bucket, objName: "emptydir/", inputData: []byte{}, expectedMd5: getMD5Hash([]byte{})},
// With invalid crc32.
@ -187,23 +188,23 @@ func testObjectAPIPutObject(obj ObjectLayer, instanceType string, t TestErrHandl
in := mustGetPutObjReader(t, bytes.NewReader(testCase.inputData), testCase.intputDataSize, testCase.inputMeta["etag"], testCase.inputSHA256)
objInfo, actualErr := obj.PutObject(context.Background(), testCase.bucketName, testCase.objName, in, ObjectOptions{UserDefined: testCase.inputMeta})
if actualErr != nil && testCase.expectedError == nil {
t.Errorf("Test %d: %s: Expected to pass, but failed with: error %s.", i+1, instanceType, actualErr.Error())
t.Errorf("Test %d: %s: Expected to pass, but failed with: error %s.", i, instanceType, actualErr.Error())
continue
}
if actualErr == nil && testCase.expectedError != nil {
t.Errorf("Test %d: %s: Expected to fail with error \"%s\", but passed instead.", i+1, instanceType, testCase.expectedError.Error())
t.Errorf("Test %d: %s: Expected to fail with error \"%s\", but passed instead.", i, instanceType, testCase.expectedError.Error())
continue
}
// Failed as expected, but does it fail for the expected reason.
if actualErr != nil && actualErr != testCase.expectedError {
t.Errorf("Test %d: %s: Expected to fail with error \"%v\", but instead failed with error \"%v\" instead.", i+1, instanceType, testCase.expectedError, actualErr)
t.Errorf("Test %d: %s: Expected to fail with error \"%v\", but instead failed with error \"%v\" instead.", i, instanceType, testCase.expectedError, actualErr)
continue
}
// Test passes as expected, but the output values are verified for correctness here.
if actualErr == nil {
// Asserting whether the md5 output is correct.
if expectedMD5, ok := testCase.inputMeta["etag"]; ok && expectedMD5 != objInfo.ETag {
t.Errorf("Test %d: %s: Calculated Md5 different from the actual one %s.", i+1, instanceType, objInfo.ETag)
t.Errorf("Test %d: %s: Calculated Md5 different from the actual one %s.", i, instanceType, objInfo.ETag)
continue
}
}

View File

@ -1615,7 +1615,9 @@ func (api objectAPIHandlers) PutObjectHandler(w http.ResponseWriter, r *http.Req
// if Content-Length is unknown/missing, deny the request
size := r.ContentLength
rAuthType := getRequestAuthType(r)
if rAuthType == authTypeStreamingSigned {
switch rAuthType {
// Check signature types that must have content length
case authTypeStreamingSigned, authTypeStreamingSignedTrailer, authTypeStreamingUnsignedTrailer:
if sizeStr, ok := r.Header[xhttp.AmzDecodedContentLength]; ok {
if sizeStr[0] == "" {
writeErrorResponse(ctx, w, errorCodes.ToAPIErr(ErrMissingContentLength), r.URL)
@ -1669,9 +1671,16 @@ func (api objectAPIHandlers) PutObjectHandler(w http.ResponseWriter, r *http.Req
}
switch rAuthType {
case authTypeStreamingSigned:
case authTypeStreamingSigned, authTypeStreamingSignedTrailer:
// Initialize stream signature verifier.
reader, s3Err = newSignV4ChunkedReader(r)
reader, s3Err = newSignV4ChunkedReader(r, rAuthType == authTypeStreamingSignedTrailer)
if s3Err != ErrNone {
writeErrorResponse(ctx, w, errorCodes.ToAPIErr(s3Err), r.URL)
return
}
case authTypeStreamingUnsignedTrailer:
// Initialize stream chunked reader with optional trailers.
reader, s3Err = newUnsignedV4ChunkedReader(r, true)
if s3Err != ErrNone {
writeErrorResponse(ctx, w, errorCodes.ToAPIErr(s3Err), r.URL)
return
@ -1903,7 +1912,6 @@ func (api objectAPIHandlers) PutObjectHandler(w http.ResponseWriter, r *http.Req
}
setPutObjHeaders(w, objInfo, false)
writeSuccessResponseHeadersOnly(w)
// Notify object created event.
evt := eventArgs{
@ -1921,6 +1929,10 @@ func (api objectAPIHandlers) PutObjectHandler(w http.ResponseWriter, r *http.Req
sendEvent(evt)
}
// Do not send checksums in events to avoid leaks.
hash.TransferChecksumHeader(w, r)
writeSuccessResponseHeadersOnly(w)
// Remove the transitioned object whose object version is being overwritten.
if !globalTierConfigMgr.Empty() {
// Schedule object for immediate transition if eligible.
@ -1928,8 +1940,6 @@ func (api objectAPIHandlers) PutObjectHandler(w http.ResponseWriter, r *http.Req
enqueueTransitionImmediate(objInfo)
logger.LogIf(ctx, os.Sweep())
}
// Do not send checksums in events to avoid leaks.
hash.TransferChecksumHeader(w, r)
}
// PutObjectExtractHandler - PUT Object extract is an extended API
@ -1983,7 +1993,7 @@ func (api objectAPIHandlers) PutObjectExtractHandler(w http.ResponseWriter, r *h
// if Content-Length is unknown/missing, deny the request
size := r.ContentLength
rAuthType := getRequestAuthType(r)
if rAuthType == authTypeStreamingSigned {
if rAuthType == authTypeStreamingSigned || rAuthType == authTypeStreamingSignedTrailer {
if sizeStr, ok := r.Header[xhttp.AmzDecodedContentLength]; ok {
if sizeStr[0] == "" {
writeErrorResponse(ctx, w, errorCodes.ToAPIErr(ErrMissingContentLength), r.URL)
@ -2023,9 +2033,9 @@ func (api objectAPIHandlers) PutObjectExtractHandler(w http.ResponseWriter, r *h
}
switch rAuthType {
case authTypeStreamingSigned:
case authTypeStreamingSigned, authTypeStreamingSignedTrailer:
// Initialize stream signature verifier.
reader, s3Err = newSignV4ChunkedReader(r)
reader, s3Err = newSignV4ChunkedReader(r, rAuthType == authTypeStreamingSignedTrailer)
if s3Err != ErrNone {
writeErrorResponse(ctx, w, errorCodes.ToAPIErr(s3Err), r.URL)
return

View File

@ -1100,6 +1100,7 @@ func testAPIPutObjectStreamSigV4Handler(obj ObjectLayer, instanceType, bucketNam
},
// Test case - 7
// Chunk with malformed encoding.
// Causes signature mismatch.
{
bucketName: bucketName,
objectName: objectName,
@ -1107,7 +1108,7 @@ func testAPIPutObjectStreamSigV4Handler(obj ObjectLayer, instanceType, bucketNam
dataLen: 1024,
chunkSize: 1024,
expectedContent: []byte{},
expectedRespStatus: http.StatusBadRequest,
expectedRespStatus: http.StatusForbidden,
accessKey: credentials.AccessKey,
secretKey: credentials.SecretKey,
shouldPass: false,

View File

@ -590,7 +590,9 @@ func (api objectAPIHandlers) PutObjectPartHandler(w http.ResponseWriter, r *http
rAuthType := getRequestAuthType(r)
// For auth type streaming signature, we need to gather a different content length.
if rAuthType == authTypeStreamingSigned {
switch rAuthType {
// Check signature types that must have content length
case authTypeStreamingSigned, authTypeStreamingSignedTrailer, authTypeStreamingUnsignedTrailer:
if sizeStr, ok := r.Header[xhttp.AmzDecodedContentLength]; ok {
if sizeStr[0] == "" {
writeErrorResponse(ctx, w, errorCodes.ToAPIErr(ErrMissingContentLength), r.URL)
@ -603,6 +605,7 @@ func (api objectAPIHandlers) PutObjectPartHandler(w http.ResponseWriter, r *http
}
}
}
if size == -1 {
writeErrorResponse(ctx, w, errorCodes.ToAPIErr(ErrMissingContentLength), r.URL)
return
@ -641,9 +644,16 @@ func (api objectAPIHandlers) PutObjectPartHandler(w http.ResponseWriter, r *http
}
switch rAuthType {
case authTypeStreamingSigned:
case authTypeStreamingSigned, authTypeStreamingSignedTrailer:
// Initialize stream signature verifier.
reader, s3Error = newSignV4ChunkedReader(r)
reader, s3Error = newSignV4ChunkedReader(r, rAuthType == authTypeStreamingSignedTrailer)
if s3Error != ErrNone {
writeErrorResponse(ctx, w, errorCodes.ToAPIErr(s3Error), r.URL)
return
}
case authTypeStreamingUnsignedTrailer:
// Initialize stream signature verifier.
reader, s3Error = newUnsignedV4ChunkedReader(r, true)
if s3Error != ErrNone {
writeErrorResponse(ctx, w, errorCodes.ToAPIErr(s3Error), r.URL)
return
@ -689,7 +699,6 @@ func (api objectAPIHandlers) PutObjectPartHandler(w http.ResponseWriter, r *http
// Read compression metadata preserved in the init multipart for the decision.
_, isCompressed := mi.UserDefined[ReservedMetadataPrefix+"compression"]
var idxCb func() []byte
if isCompressed {
actualReader, err := hash.NewReader(reader, size, md5hex, sha256hex, actualSize)
@ -718,6 +727,7 @@ func (api objectAPIHandlers) PutObjectPartHandler(w http.ResponseWriter, r *http
writeErrorResponse(ctx, w, toAPIError(ctx, err), r.URL)
return
}
if err := hashReader.AddChecksum(r, size < 0); err != nil {
writeErrorResponse(ctx, w, errorCodes.ToAPIErr(ErrInvalidChecksum), r.URL)
return

View File

@ -1,4 +1,4 @@
// Copyright (c) 2015-2021 MinIO, Inc.
// Copyright (c) 2015-2023 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
@ -37,6 +37,10 @@ import (
// client did not calculate sha256 of the payload.
const unsignedPayload = "UNSIGNED-PAYLOAD"
// http Header "x-amz-content-sha256" == "STREAMING-UNSIGNED-PAYLOAD-TRAILER" indicates that the
// client did not calculate sha256 of the payload and there is a trailer.
const unsignedPayloadTrailer = "STREAMING-UNSIGNED-PAYLOAD-TRAILER"
// skipContentSha256Cksum returns true if caller needs to skip
// payload checksum, false if not.
func skipContentSha256Cksum(r *http.Request) bool {
@ -62,7 +66,7 @@ func skipContentSha256Cksum(r *http.Request) bool {
// If x-amz-content-sha256 is set and the value is not
// 'UNSIGNED-PAYLOAD' we should validate the content sha256.
switch v[0] {
case unsignedPayload:
case unsignedPayload, unsignedPayloadTrailer:
return true
case emptySHA256:
// some broken clients set empty-sha256
@ -70,12 +74,11 @@ func skipContentSha256Cksum(r *http.Request) bool {
// we should skip such clients and allow
// blindly such insecure clients only if
// S3 strict compatibility is disabled.
if r.ContentLength > 0 && !globalCLIContext.StrictS3Compat {
// We return true only in situations when
// deployment has asked MinIO to allow for
// such broken clients and content-length > 0.
return true
}
// We return true only in situations when
// deployment has asked MinIO to allow for
// such broken clients and content-length > 0.
return r.ContentLength > 0 && !globalCLIContext.StrictS3Compat
}
return false
}

View File

@ -24,9 +24,11 @@ import (
"bytes"
"encoding/hex"
"errors"
"fmt"
"hash"
"io"
"net/http"
"strings"
"time"
"github.com/dustin/go-humanize"
@ -37,24 +39,53 @@ import (
// Streaming AWS Signature Version '4' constants.
const (
emptySHA256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
streamingContentSHA256 = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD"
signV4ChunkedAlgorithm = "AWS4-HMAC-SHA256-PAYLOAD"
streamingContentEncoding = "aws-chunked"
emptySHA256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
streamingContentSHA256 = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD"
streamingContentSHA256Trailer = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD-TRAILER"
signV4ChunkedAlgorithm = "AWS4-HMAC-SHA256-PAYLOAD"
signV4ChunkedAlgorithmTrailer = "AWS4-HMAC-SHA256-TRAILER"
streamingContentEncoding = "aws-chunked"
awsTrailerHeader = "X-Amz-Trailer"
trailerKVSeparator = ":"
)
// getChunkSignature - get chunk signature.
func getChunkSignature(cred auth.Credentials, seedSignature string, region string, date time.Time, hashedChunk string) string {
// Does not update anything in cr.
func (cr *s3ChunkedReader) getChunkSignature() string {
hashedChunk := hex.EncodeToString(cr.chunkSHA256Writer.Sum(nil))
// Calculate string to sign.
stringToSign := signV4ChunkedAlgorithm + "\n" +
date.Format(iso8601Format) + "\n" +
getScope(date, region) + "\n" +
seedSignature + "\n" +
alg := signV4ChunkedAlgorithm + "\n"
stringToSign := alg +
cr.seedDate.Format(iso8601Format) + "\n" +
getScope(cr.seedDate, cr.region) + "\n" +
cr.seedSignature + "\n" +
emptySHA256 + "\n" +
hashedChunk
// Get hmac signing key.
signingKey := getSigningKey(cred.SecretKey, date, region, serviceS3)
signingKey := getSigningKey(cr.cred.SecretKey, cr.seedDate, cr.region, serviceS3)
// Calculate signature.
newSignature := getSignature(signingKey, stringToSign)
return newSignature
}
// getTrailerChunkSignature - get trailer chunk signature.
func (cr *s3ChunkedReader) getTrailerChunkSignature() string {
hashedChunk := hex.EncodeToString(cr.chunkSHA256Writer.Sum(nil))
// Calculate string to sign.
alg := signV4ChunkedAlgorithmTrailer + "\n"
stringToSign := alg +
cr.seedDate.Format(iso8601Format) + "\n" +
getScope(cr.seedDate, cr.region) + "\n" +
cr.seedSignature + "\n" +
hashedChunk
// Get hmac signing key.
signingKey := getSigningKey(cr.cred.SecretKey, cr.seedDate, cr.region, serviceS3)
// Calculate signature.
newSignature := getSignature(signingKey, stringToSign)
@ -67,7 +98,7 @@ func getChunkSignature(cred auth.Credentials, seedSignature string, region strin
//
// returns signature, error otherwise if the signature mismatches or any other
// error while parsing and validating.
func calculateSeedSignature(r *http.Request) (cred auth.Credentials, signature string, region string, date time.Time, errCode APIErrorCode) {
func calculateSeedSignature(r *http.Request, trailers bool) (cred auth.Credentials, signature string, region string, date time.Time, errCode APIErrorCode) {
// Copy request.
req := *r
@ -82,6 +113,9 @@ func calculateSeedSignature(r *http.Request) (cred auth.Credentials, signature s
// Payload streaming.
payload := streamingContentSHA256
if trailers {
payload = streamingContentSHA256Trailer
}
// Payload for STREAMING signature should be 'STREAMING-AWS4-HMAC-SHA256-PAYLOAD'
if payload != req.Header.Get(xhttp.AmzContentSha256) {
@ -158,13 +192,24 @@ var errChunkTooBig = errors.New("chunk too big: choose chunk size <= 16MiB")
//
// NewChunkedReader is not needed by normal applications. The http package
// automatically decodes chunking when reading response bodies.
func newSignV4ChunkedReader(req *http.Request) (io.ReadCloser, APIErrorCode) {
cred, seedSignature, region, seedDate, errCode := calculateSeedSignature(req)
func newSignV4ChunkedReader(req *http.Request, trailer bool) (io.ReadCloser, APIErrorCode) {
cred, seedSignature, region, seedDate, errCode := calculateSeedSignature(req, trailer)
if errCode != ErrNone {
return nil, errCode
}
if trailer {
// Discard anything unsigned.
req.Trailer = make(http.Header)
trailers := req.Header.Values(awsTrailerHeader)
for _, key := range trailers {
req.Trailer.Add(key, "")
}
} else {
req.Trailer = nil
}
return &s3ChunkedReader{
trailers: req.Trailer,
reader: bufio.NewReader(req.Body),
cred: cred,
seedSignature: seedSignature,
@ -172,6 +217,7 @@ func newSignV4ChunkedReader(req *http.Request) (io.ReadCloser, APIErrorCode) {
region: region,
chunkSHA256Writer: sha256.New(),
buffer: make([]byte, 64*1024),
debug: false,
}, ErrNone
}
@ -183,11 +229,13 @@ type s3ChunkedReader struct {
seedSignature string
seedDate time.Time
region string
trailers http.Header
chunkSHA256Writer hash.Hash // Calculates sha256 of chunk data.
buffer []byte
offset int
err error
debug bool // Print details on failure. Add your own if more are needed.
}
func (cr *s3ChunkedReader) Close() (err error) {
@ -214,6 +262,19 @@ const maxChunkSize = 16 << 20 // 16 MiB
// Read - implements `io.Reader`, which transparently decodes
// the incoming AWS Signature V4 streaming signature.
func (cr *s3ChunkedReader) Read(buf []byte) (n int, err error) {
if cr.err != nil {
if cr.debug {
fmt.Printf("s3ChunkedReader: Returning err: %v (%T)\n", cr.err, cr.err)
}
return 0, cr.err
}
defer func() {
if err != nil && err != io.EOF {
if cr.debug {
fmt.Println("Read err:", err)
}
}
}()
// First, if there is any unread data, copy it to the client
// provided buffer.
if cr.offset > 0 {
@ -319,8 +380,43 @@ func (cr *s3ChunkedReader) Read(buf []byte) (n int, err error) {
cr.err = err
return n, cr.err
}
// Once we have read the entire chunk successfully, we verify
// that the received signature matches our computed signature.
cr.chunkSHA256Writer.Write(cr.buffer)
newSignature := cr.getChunkSignature()
if !compareSignatureV4(string(signature[16:]), newSignature) {
cr.err = errSignatureMismatch
return n, cr.err
}
cr.seedSignature = newSignature
cr.chunkSHA256Writer.Reset()
// If the chunk size is zero we return io.EOF. As specified by AWS,
// only the last chunk is zero-sized.
if len(cr.buffer) == 0 {
if cr.debug {
fmt.Println("EOF. Reading Trailers:", cr.trailers)
}
if cr.trailers != nil {
err = cr.readTrailers()
if cr.debug {
fmt.Println("trailers returned:", err, "now:", cr.trailers)
}
if err != nil {
cr.err = err
return 0, err
}
}
cr.err = io.EOF
return n, cr.err
}
b, err = cr.reader.ReadByte()
if b != '\r' || err != nil {
if cr.debug {
fmt.Printf("want %q, got %q\n", "\r", string(b))
}
cr.err = errMalformedEncoding
return n, cr.err
}
@ -333,33 +429,133 @@ func (cr *s3ChunkedReader) Read(buf []byte) (n int, err error) {
return n, cr.err
}
if b != '\n' {
if cr.debug {
fmt.Printf("want %q, got %q\n", "\r", string(b))
}
cr.err = errMalformedEncoding
return n, cr.err
}
// Once we have read the entire chunk successfully, we verify
// that the received signature matches our computed signature.
cr.chunkSHA256Writer.Write(cr.buffer)
newSignature := getChunkSignature(cr.cred, cr.seedSignature, cr.region, cr.seedDate, hex.EncodeToString(cr.chunkSHA256Writer.Sum(nil)))
if !compareSignatureV4(string(signature[16:]), newSignature) {
cr.err = errSignatureMismatch
return n, cr.err
}
cr.seedSignature = newSignature
cr.chunkSHA256Writer.Reset()
// If the chunk size is zero we return io.EOF. As specified by AWS,
// only the last chunk is zero-sized.
if size == 0 {
cr.err = io.EOF
return n, cr.err
}
cr.offset = copy(buf, cr.buffer)
n += cr.offset
return n, err
}
// readTrailers will read all trailers and populate cr.trailers with actual values.
func (cr *s3ChunkedReader) readTrailers() error {
var valueBuffer bytes.Buffer
// Read value
for {
v, err := cr.reader.ReadByte()
if err != nil {
if err == io.EOF {
return io.ErrUnexpectedEOF
}
}
if v != '\r' {
valueBuffer.WriteByte(v)
continue
}
// End of buffer, do not add to value.
v, err = cr.reader.ReadByte()
if err != nil {
if err == io.EOF {
return io.ErrUnexpectedEOF
}
}
if v != '\n' {
return errMalformedEncoding
}
break
}
// Read signature
var signatureBuffer bytes.Buffer
for {
v, err := cr.reader.ReadByte()
if err != nil {
if err == io.EOF {
return io.ErrUnexpectedEOF
}
}
if v != '\r' {
signatureBuffer.WriteByte(v)
continue
}
var tmp [3]byte
_, err = io.ReadFull(cr.reader, tmp[:])
if err != nil {
if err == io.EOF {
return io.ErrUnexpectedEOF
}
}
if string(tmp[:]) != "\n\r\n" {
if cr.debug {
fmt.Printf("signature, want %q, got %q", "\n\r\n", string(tmp[:]))
}
return errMalformedEncoding
}
// No need to write final newlines to buffer.
break
}
// Verify signature.
sig := signatureBuffer.Bytes()
if !bytes.HasPrefix(sig, []byte("x-amz-trailer-signature:")) {
if cr.debug {
fmt.Printf("prefix, want prefix %q, got %q", "x-amz-trailer-signature:", string(sig))
}
return errMalformedEncoding
}
sig = sig[len("x-amz-trailer-signature:"):]
sig = bytes.TrimSpace(sig)
cr.chunkSHA256Writer.Write(valueBuffer.Bytes())
wantSig := cr.getTrailerChunkSignature()
if !compareSignatureV4(string(sig), wantSig) {
if cr.debug {
fmt.Printf("signature, want: %q, got %q\nSignature buffer: %q\n", wantSig, string(sig), string(valueBuffer.Bytes()))
}
return errSignatureMismatch
}
// Parse trailers.
wantTrailers := make(map[string]struct{}, len(cr.trailers))
for k := range cr.trailers {
wantTrailers[strings.ToLower(k)] = struct{}{}
}
input := bufio.NewScanner(bytes.NewReader(valueBuffer.Bytes()))
for input.Scan() {
line := strings.TrimSpace(input.Text())
if line == "" {
continue
}
// Find first separator.
idx := strings.IndexByte(line, trailerKVSeparator[0])
if idx <= 0 || idx >= len(line) {
if cr.debug {
fmt.Printf("index, ':' not found in %q\n", line)
}
return errMalformedEncoding
}
key := line[:idx]
value := line[idx+1:]
if _, ok := wantTrailers[key]; !ok {
if cr.debug {
fmt.Printf("%q not found in %q\n", key, cr.trailers)
}
return errMalformedEncoding
}
cr.trailers.Set(key, value)
delete(wantTrailers, key)
}
// Check if we got all we want.
if len(wantTrailers) > 0 {
return io.ErrUnexpectedEOF
}
return nil
}
// readCRLF - check if reader only has '\r\n' CRLF character.
// returns malformed encoding if it doesn't.
func readCRLF(reader io.Reader) error {

View File

@ -0,0 +1,257 @@
// Copyright (c) 2015-2023 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package cmd
import (
"bufio"
"bytes"
"fmt"
"io"
"net/http"
"strings"
)
// newUnsignedV4ChunkedReader returns a new s3UnsignedChunkedReader that translates the data read from r
// out of HTTP "chunked" format before returning it.
// The s3ChunkedReader returns io.EOF when the final 0-length chunk is read.
func newUnsignedV4ChunkedReader(req *http.Request, trailer bool) (io.ReadCloser, APIErrorCode) {
if trailer {
// Discard anything unsigned.
req.Trailer = make(http.Header)
trailers := req.Header.Values(awsTrailerHeader)
for _, key := range trailers {
req.Trailer.Add(key, "")
}
} else {
req.Trailer = nil
}
return &s3UnsignedChunkedReader{
trailers: req.Trailer,
reader: bufio.NewReader(req.Body),
buffer: make([]byte, 64*1024),
}, ErrNone
}
// Represents the overall state that is required for decoding a
// AWS Signature V4 chunked reader.
type s3UnsignedChunkedReader struct {
reader *bufio.Reader
trailers http.Header
buffer []byte
offset int
err error
debug bool
}
func (cr *s3UnsignedChunkedReader) Close() (err error) {
return cr.err
}
// Read - implements `io.Reader`, which transparently decodes
// the incoming AWS Signature V4 streaming signature.
func (cr *s3UnsignedChunkedReader) Read(buf []byte) (n int, err error) {
// First, if there is any unread data, copy it to the client
// provided buffer.
if cr.offset > 0 {
n = copy(buf, cr.buffer[cr.offset:])
if n == len(buf) {
cr.offset += n
return n, nil
}
cr.offset = 0
buf = buf[n:]
}
// mustRead reads from input and compares against provided slice.
mustRead := func(b ...byte) error {
for _, want := range b {
got, err := cr.reader.ReadByte()
if err == io.EOF {
return io.ErrUnexpectedEOF
}
if got != want {
if cr.debug {
fmt.Printf("mustread: want: %q got: %q\n", string(want), string(got))
}
return errMalformedEncoding
}
if err != nil {
return err
}
}
return nil
}
var size int
for {
b, err := cr.reader.ReadByte()
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
if err != nil {
cr.err = err
return n, cr.err
}
if b == '\r' { // \r\n denotes end of size.
err := mustRead('\n')
if err != nil {
cr.err = err
return n, cr.err
}
break
}
// Manually deserialize the size since AWS specified
// the chunk size to be of variable width. In particular,
// a size of 16 is encoded as `10` while a size of 64 KB
// is `10000`.
switch {
case b >= '0' && b <= '9':
size = size<<4 | int(b-'0')
case b >= 'a' && b <= 'f':
size = size<<4 | int(b-('a'-10))
case b >= 'A' && b <= 'F':
size = size<<4 | int(b-('A'-10))
default:
if cr.debug {
fmt.Printf("err size: %v\n", string(b))
}
cr.err = errMalformedEncoding
return n, cr.err
}
if size > maxChunkSize {
cr.err = errChunkTooBig
return n, cr.err
}
}
if cap(cr.buffer) < size {
cr.buffer = make([]byte, size)
} else {
cr.buffer = cr.buffer[:size]
}
// Now, we read the payload.
_, err = io.ReadFull(cr.reader, cr.buffer)
if err == io.EOF && size != 0 {
err = io.ErrUnexpectedEOF
}
if err != nil && err != io.EOF {
cr.err = err
return n, cr.err
}
// If the chunk size is zero we return io.EOF. As specified by AWS,
// only the last chunk is zero-sized.
if len(cr.buffer) == 0 {
if cr.debug {
fmt.Println("EOF")
}
if cr.trailers != nil {
err = cr.readTrailers()
if cr.debug {
fmt.Println("trailer returned:", err)
}
if err != nil {
cr.err = err
return 0, err
}
}
cr.err = io.EOF
return n, cr.err
}
// read final terminator.
err = mustRead('\r', '\n')
if err != nil && err != io.EOF {
cr.err = err
return n, cr.err
}
cr.offset = copy(buf, cr.buffer)
n += cr.offset
return n, err
}
// readTrailers will read all trailers and populate cr.trailers with actual values.
func (cr *s3UnsignedChunkedReader) readTrailers() error {
var valueBuffer bytes.Buffer
// Read value
for {
v, err := cr.reader.ReadByte()
if err != nil {
if err == io.EOF {
return io.ErrUnexpectedEOF
}
}
if v != '\r' {
valueBuffer.WriteByte(v)
continue
}
// Must end with \r\n\r\n
var tmp [3]byte
_, err = io.ReadFull(cr.reader, tmp[:])
if err != nil {
if err == io.EOF {
return io.ErrUnexpectedEOF
}
}
if !bytes.Equal(tmp[:], []byte{'\n', '\r', '\n'}) {
if cr.debug {
fmt.Printf("got %q, want %q\n", string(tmp[:]), "\n\r\n")
}
return errMalformedEncoding
}
break
}
// Parse trailers.
wantTrailers := make(map[string]struct{}, len(cr.trailers))
for k := range cr.trailers {
wantTrailers[strings.ToLower(k)] = struct{}{}
}
input := bufio.NewScanner(bytes.NewReader(valueBuffer.Bytes()))
for input.Scan() {
line := strings.TrimSpace(input.Text())
if line == "" {
continue
}
// Find first separator.
idx := strings.IndexByte(line, trailerKVSeparator[0])
if idx <= 0 || idx >= len(line) {
if cr.debug {
fmt.Printf("Could not find separator, got %q\n", line)
}
return errMalformedEncoding
}
key := strings.ToLower(line[:idx])
value := line[idx+1:]
if _, ok := wantTrailers[key]; !ok {
if cr.debug {
fmt.Printf("Unknown key %q - expected on of %v\n", key, cr.trailers)
}
return errMalformedEncoding
}
cr.trailers.Set(key, value)
delete(wantTrailers, key)
}
// Check if we got all we want.
if len(wantTrailers) > 0 {
return io.ErrUnexpectedEOF
}
return nil
}

View File

@ -303,8 +303,8 @@ func (c Checksum) Valid() bool {
if c.Type == ChecksumInvalid {
return false
}
if len(c.Encoded) == 0 || c.Type.Is(ChecksumTrailing) {
return c.Type.Is(ChecksumNone) || c.Type.Is(ChecksumTrailing)
if len(c.Encoded) == 0 || c.Type.Trailing() {
return c.Type.Is(ChecksumNone) || c.Type.Trailing()
}
raw := c.Raw
return c.Type.RawByteLen() == len(raw)
@ -339,10 +339,21 @@ func (c *Checksum) AsMap() map[string]string {
}
// TransferChecksumHeader will transfer any checksum value that has been checked.
// If checksum was trailing, they must have been added to r.Trailer.
func TransferChecksumHeader(w http.ResponseWriter, r *http.Request) {
t, s := getContentChecksum(r)
if !t.IsSet() || t.Is(ChecksumTrailing) {
// TODO: Add trailing when we can read it.
c, err := GetContentChecksum(r)
if err != nil || c == nil {
return
}
t, s := c.Type, c.Encoded
if !c.Type.IsSet() {
return
}
if c.Type.Is(ChecksumTrailing) {
val := r.Trailer.Get(t.Key())
if val != "" {
w.Header().Set(t.Key(), val)
}
return
}
w.Header().Set(t.Key(), s)
@ -365,6 +376,32 @@ func AddChecksumHeader(w http.ResponseWriter, c map[string]string) {
// Returns ErrInvalidChecksum if so.
// Returns nil, nil if no checksum.
func GetContentChecksum(r *http.Request) (*Checksum, error) {
if trailing := r.Header.Values(xhttp.AmzTrailer); len(trailing) > 0 {
var res *Checksum
for _, header := range trailing {
var duplicates bool
switch {
case strings.EqualFold(header, ChecksumCRC32C.Key()):
duplicates = res != nil
res = NewChecksumWithType(ChecksumCRC32C|ChecksumTrailing, "")
case strings.EqualFold(header, ChecksumCRC32.Key()):
duplicates = res != nil
res = NewChecksumWithType(ChecksumCRC32|ChecksumTrailing, "")
case strings.EqualFold(header, ChecksumSHA256.Key()):
duplicates = res != nil
res = NewChecksumWithType(ChecksumSHA256|ChecksumTrailing, "")
case strings.EqualFold(header, ChecksumSHA1.Key()):
duplicates = res != nil
res = NewChecksumWithType(ChecksumSHA1|ChecksumTrailing, "")
}
if duplicates {
return nil, ErrInvalidChecksum
}
}
if res != nil {
return res, nil
}
}
t, s := getContentChecksum(r)
if t == ChecksumNone {
if s == "" {
@ -389,11 +426,6 @@ func getContentChecksum(r *http.Request) (t ChecksumType, s string) {
if t.IsSet() {
hdr := t.Key()
if s = r.Header.Get(hdr); s == "" {
if strings.EqualFold(r.Header.Get(xhttp.AmzTrailer), hdr) {
t |= ChecksumTrailing
} else {
t = ChecksumInvalid
}
return ChecksumNone, ""
}
}
@ -409,6 +441,7 @@ func getContentChecksum(r *http.Request) (t ChecksumType, s string) {
t = c
s = got
}
return
}
}
checkType(ChecksumCRC32)

View File

@ -28,6 +28,7 @@ import (
"github.com/minio/minio/internal/etag"
"github.com/minio/minio/internal/hash/sha256"
"github.com/minio/minio/internal/ioutil"
)
// A Reader wraps an io.Reader and computes the MD5 checksum
@ -51,6 +52,8 @@ type Reader struct {
contentHash Checksum
contentHasher hash.Hash
trailer http.Header
sha256 hash.Hash
}
@ -107,7 +110,7 @@ func NewReader(src io.Reader, size int64, md5Hex, sha256Hex string, actualSize i
r.checksum = MD5
r.contentSHA256 = SHA256
if r.size < 0 && size >= 0 {
r.src = etag.Wrap(io.LimitReader(r.src, size), r.src)
r.src = etag.Wrap(ioutil.HardLimitReader(r.src, size), r.src)
r.size = size
}
if r.actualSize <= 0 && actualSize >= 0 {
@ -117,7 +120,7 @@ func NewReader(src io.Reader, size int64, md5Hex, sha256Hex string, actualSize i
}
if size >= 0 {
r := io.LimitReader(src, size)
r := ioutil.HardLimitReader(src, size)
if _, ok := src.(etag.Tagger); !ok {
src = etag.NewReader(r, MD5)
} else {
@ -155,10 +158,14 @@ func (r *Reader) AddChecksum(req *http.Request, ignoreValue bool) error {
return nil
}
r.contentHash = *cs
if cs.Type.Trailing() || ignoreValue {
// Ignore until we have trailing headers.
if cs.Type.Trailing() {
r.trailer = req.Trailer
}
if ignoreValue {
// Do not validate, but allow for transfer
return nil
}
r.contentHasher = cs.Type.Hasher()
if r.contentHasher == nil {
return ErrInvalidChecksum
@ -186,6 +193,14 @@ func (r *Reader) Read(p []byte) (int, error) {
}
}
if r.contentHasher != nil {
if r.contentHash.Type.Trailing() {
var err error
r.contentHash.Encoded = r.trailer.Get(r.contentHash.Type.Key())
r.contentHash.Raw, err = base64.StdEncoding.DecodeString(r.contentHash.Encoded)
if err != nil || len(r.contentHash.Raw) == 0 {
return 0, ChecksumMismatch{Got: r.contentHash.Encoded}
}
}
if sum := r.contentHasher.Sum(nil); !bytes.Equal(r.contentHash.Raw, sum) {
err := ChecksumMismatch{
Want: r.contentHash.Encoded,
@ -276,6 +291,9 @@ func (r *Reader) ContentCRC() map[string]string {
if r.contentHash.Type == ChecksumNone || !r.contentHash.Valid() {
return nil
}
if r.contentHash.Type.Trailing() {
return map[string]string{r.contentHash.Type.String(): r.trailer.Get(r.contentHash.Type.Key())}
}
return map[string]string{r.contentHash.Type.String(): r.contentHash.Encoded}
}

View File

@ -23,6 +23,8 @@ import (
"fmt"
"io"
"testing"
"github.com/minio/minio/internal/ioutil"
)
// Tests functions like Size(), MD5*(), SHA256*()
@ -79,7 +81,7 @@ func TestHashReaderVerification(t *testing.T) {
md5hex, sha256hex string
err error
}{
{
0: {
desc: "Success, no checksum verification provided.",
src: bytes.NewReader([]byte("abcd")),
size: 4,
@ -124,7 +126,7 @@ func TestHashReaderVerification(t *testing.T) {
CalculatedSHA256: "88d4266fd4e6338d13b845fcf289579d209c897823b9217da3e161936f031589",
},
},
{
5: {
desc: "Correct sha256, nested",
src: mustReader(t, bytes.NewReader([]byte("abcd")), 4, "", "", 4),
size: 4,
@ -137,13 +139,15 @@ func TestHashReaderVerification(t *testing.T) {
size: 4,
actualSize: -1,
sha256hex: "88d4266fd4e6338d13b845fcf289579d209c897823b9217da3e161936f031589",
err: ioutil.ErrOverread,
},
{
7: {
desc: "Correct sha256, nested, truncated, swapped",
src: mustReader(t, bytes.NewReader([]byte("abcd-more-stuff-to-be ignored")), 4, "", "", -1),
size: 4,
actualSize: -1,
sha256hex: "88d4266fd4e6338d13b845fcf289579d209c897823b9217da3e161936f031589",
err: ioutil.ErrOverread,
},
{
desc: "Incorrect MD5, nested",
@ -162,6 +166,7 @@ func TestHashReaderVerification(t *testing.T) {
size: 4,
actualSize: 4,
sha256hex: "88d4266fd4e6338d13b845fcf289579d209c897823b9217da3e161936f031589",
err: ioutil.ErrOverread,
},
{
desc: "Correct MD5, nested",
@ -177,6 +182,7 @@ func TestHashReaderVerification(t *testing.T) {
actualSize: 4,
sha256hex: "",
md5hex: "e2fc714c4727ee9395f324cd2e7f331f",
err: ioutil.ErrOverread,
},
{
desc: "Correct MD5, nested, truncated",
@ -184,6 +190,7 @@ func TestHashReaderVerification(t *testing.T) {
size: 4,
actualSize: 4,
md5hex: "e2fc714c4727ee9395f324cd2e7f331f",
err: ioutil.ErrOverread,
},
}
for i, testCase := range testCases {
@ -194,6 +201,10 @@ func TestHashReaderVerification(t *testing.T) {
}
_, err = io.Copy(io.Discard, r)
if err != nil {
if testCase.err == nil {
t.Errorf("Test %q; got unexpected error: %v", testCase.desc, err)
return
}
if err.Error() != testCase.err.Error() {
t.Errorf("Test %q: Expected error %s, got error %s", testCase.desc, testCase.err, err)
}

View File

@ -0,0 +1,56 @@
// Copyright (c) 2015-2023 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
// Package ioutil implements some I/O utility functions which are not covered
// by the standard library.
package ioutil
import (
"errors"
"io"
)
// ErrOverread is returned to the reader when the hard limit of HardLimitReader is exceeded.
var ErrOverread = errors.New("input provided more bytes than specified")
// HardLimitReader returns a Reader that reads from r
// but returns an error if the source provides more data than allowed.
// This means the source *will* be overread unless EOF is returned prior.
// The underlying implementation is a *HardLimitedReader.
// This will ensure that at most n bytes are returned and EOF is reached.
func HardLimitReader(r io.Reader, n int64) io.Reader { return &HardLimitedReader{r, n} }
// A HardLimitedReader reads from R but limits the amount of
// data returned to just N bytes. Each call to Read
// updates N to reflect the new amount remaining.
// Read returns EOF when N <= 0 or when the underlying R returns EOF.
type HardLimitedReader struct {
R io.Reader // underlying reader
N int64 // max bytes remaining
}
func (l *HardLimitedReader) Read(p []byte) (n int, err error) {
if l.N < 0 {
return 0, ErrOverread
}
n, err = l.R.Read(p)
l.N -= int64(n)
if l.N < 0 {
return 0, ErrOverread
}
return
}