diff --git a/net/ghttp/ghttp_request_param.go b/net/ghttp/ghttp_request_param.go index 4ff6475df30..adb2f173681 100644 --- a/net/ghttp/ghttp_request_param.go +++ b/net/ghttp/ghttp_request_param.go @@ -236,8 +236,20 @@ func (r *Request) parseBody() { if body := r.GetBody(); len(body) > 0 { // Trim space/new line characters. body = bytes.TrimSpace(body) + if len(body) == 0 { + return + } + contentType := r.Header.Get("Content-Type") + jsonContentType := gstr.ContainsI(contentType, contentTypeJson) + // Preserve GET query/form body compatibility while validating JSON-shaped GET bodies. + strictJsonContentType := jsonContentType && (r.Method != http.MethodGet || body[0] == '{' || body[0] == '[') // JSON format checks. - if body[0] == '{' && body[len(body)-1] == '}' { + if strictJsonContentType { + if err := json.UnmarshalUseNumber(body, &r.bodyMap); err != nil { + r.SetError(gerror.WrapCode(gcode.CodeInvalidParameter, err, "Parse JSON body failed")) + return + } + } else if body[0] == '{' && body[len(body)-1] == '}' { _ = json.UnmarshalUseNumber(body, &r.bodyMap) } // XML format checks. @@ -248,7 +260,7 @@ func (r *Request) parseBody() { r.bodyMap, _ = gxml.DecodeWithoutRoot(body) } // Default parameters decoding. - if contentType := r.Header.Get("Content-Type"); (contentType == "" || !gstr.Contains(contentType, "multipart/")) && r.bodyMap == nil { + if (contentType == "" || !gstr.Contains(contentType, "multipart/")) && r.bodyMap == nil { r.bodyMap, _ = gstr.Parse(r.GetBodyString()) } } diff --git a/net/ghttp/ghttp_request_param_form.go b/net/ghttp/ghttp_request_param_form.go index 68deb0a6736..0c154ac09ce 100644 --- a/net/ghttp/ghttp_request_param_form.go +++ b/net/ghttp/ghttp_request_param_form.go @@ -98,6 +98,9 @@ func (r *Request) GetFormStruct(pointer any, mapping ...map[string]string) error func (r *Request) doGetFormStruct(pointer any, mapping ...map[string]string) (data map[string]any, err error) { r.parseForm() + if err = r.GetError(); err != nil { + return nil, err + } data = r.formMap if data == nil { data = map[string]any{} diff --git a/net/ghttp/ghttp_request_param_query.go b/net/ghttp/ghttp_request_param_query.go index 2ad5ca5f59f..8da16734a4e 100644 --- a/net/ghttp/ghttp_request_param_query.go +++ b/net/ghttp/ghttp_request_param_query.go @@ -142,6 +142,9 @@ func (r *Request) GetQueryStruct(pointer any, mapping ...map[string]string) erro func (r *Request) doGetQueryStruct(pointer any, mapping ...map[string]string) (data map[string]any, err error) { r.parseQuery() data = r.GetQueryMap() + if err = r.GetError(); err != nil { + return nil, err + } if data == nil { data = map[string]any{} } diff --git a/net/ghttp/ghttp_request_param_request.go b/net/ghttp/ghttp_request_param_request.go index 31e3b83d706..e45021d0676 100644 --- a/net/ghttp/ghttp_request_param_request.go +++ b/net/ghttp/ghttp_request_param_request.go @@ -174,6 +174,9 @@ func (r *Request) GetRequestStruct(pointer any, mapping ...map[string]string) er func (r *Request) doGetRequestStruct(pointer any, mapping ...map[string]string) (data map[string]any, err error) { data = r.GetRequestMap() + if err = r.GetError(); err != nil { + return nil, err + } if data == nil { data = map[string]any{} } diff --git a/net/ghttp/ghttp_z_unit_feature_request_param_test.go b/net/ghttp/ghttp_z_unit_feature_request_param_test.go index 9495b6e1932..76c1b058669 100644 --- a/net/ghttp/ghttp_z_unit_feature_request_param_test.go +++ b/net/ghttp/ghttp_z_unit_feature_request_param_test.go @@ -8,11 +8,13 @@ package ghttp_test import ( "context" + "encoding/json" "fmt" "strconv" "testing" "time" + "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/net/ghttp" "github.com/gogf/gf/v2/test/gtest" @@ -94,6 +96,26 @@ func (c *cUserTagDefault) User(ctx context.Context, req *UserTagDefaultReq) (res return } +type UserInvalidJsonReq struct { + g.Meta `path:"/user-invalid-json" method:"post,get" summary:"user invalid json api"` + Name string +} + +type UserInvalidJsonRes struct { + g.Meta `mime:"application/json"` + Name string +} + +var ( + UserInvalidJson = cUserInvalidJson{} +) + +type cUserInvalidJson struct{} + +func (c *cUserInvalidJson) User(ctx context.Context, req *UserInvalidJsonReq) (res *UserInvalidJsonRes, err error) { + return &UserInvalidJsonRes{Name: req.Name}, nil +} + func Test_ParamsTagDefault(t *testing.T) { s := g.Server(guid.S()) s.Group("/", func(group *ghttp.RouterGroup) { @@ -150,6 +172,49 @@ func Test_ParamsTagDefault(t *testing.T) { }) } +func Test_ParamsInvalidJsonReportsParseError(t *testing.T) { + s := g.Server(guid.S()) + s.Group("/", func(group *ghttp.RouterGroup) { + group.Middleware(ghttp.MiddlewareHandlerResponse) + group.Bind(UserInvalidJson) + }) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + + for i := 0; i < 100 && s.GetListenedPort() == 0; i++ { + time.Sleep(10 * time.Millisecond) + } + if s.GetListenedPort() == 0 { + t.Fatal("server did not start listening") + } + + gtest.C(t, func(t *gtest.T) { + prefix := fmt.Sprintf("http://127.0.0.1:%d", s.GetListenedPort()) + reqCtx := context.Background() + client := g.Client() + client.SetPrefix(prefix) + + assertInvalidParameter := func(resp string) { + var payload struct { + Code int `json:"code"` + Data *UserInvalidJsonRes `json:"data"` + } + err := json.Unmarshal([]byte(resp), &payload) + t.AssertNil(err) + t.Assert(payload.Code, gcode.CodeInvalidParameter.Code()) + t.Assert(payload.Data, nil) + } + + for _, body := range []string{`{"name":}`, `[{"name":"john"}]`, `name=john`} { + assertInvalidParameter(client.ContentJson().PostContent(reqCtx, "/user-invalid-json", body)) + } + for _, body := range []string{`{"name":}`, `[{"name":"john"}]`} { + assertInvalidParameter(client.ContentJson().GetContent(reqCtx, "/user-invalid-json", body)) + } + }) +} + func Benchmark_ParamTagIn(b *testing.B) { b.StopTimer()