Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions net/ghttp/ghttp_request_param.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,20 @@ func (r *Request) parseBody() {
if body := r.GetBody(); len(body) > 0 {
// Trim space/new line characters.
body = bytes.TrimSpace(body)
if len(body) == 0 {
return
}
contentType := r.Header.Get("Content-Type")
jsonContentType := gstr.ContainsI(contentType, contentTypeJson)
// Preserve GET query/form body compatibility while validating JSON-shaped GET bodies.
strictJsonContentType := jsonContentType && (r.Method != http.MethodGet || body[0] == '{' || body[0] == '[')
// JSON format checks.
if body[0] == '{' && body[len(body)-1] == '}' {
if strictJsonContentType {
if err := json.UnmarshalUseNumber(body, &r.bodyMap); err != nil {
r.SetError(gerror.WrapCode(gcode.CodeInvalidParameter, err, "Parse JSON body failed"))
return
}
} else if body[0] == '{' && body[len(body)-1] == '}' {
_ = json.UnmarshalUseNumber(body, &r.bodyMap)
}
// XML format checks.
Expand All @@ -248,7 +260,7 @@ func (r *Request) parseBody() {
r.bodyMap, _ = gxml.DecodeWithoutRoot(body)
}
// Default parameters decoding.
if contentType := r.Header.Get("Content-Type"); (contentType == "" || !gstr.Contains(contentType, "multipart/")) && r.bodyMap == nil {
if (contentType == "" || !gstr.Contains(contentType, "multipart/")) && r.bodyMap == nil {
r.bodyMap, _ = gstr.Parse(r.GetBodyString())
}
}
Expand Down
3 changes: 3 additions & 0 deletions net/ghttp/ghttp_request_param_form.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ func (r *Request) GetFormStruct(pointer any, mapping ...map[string]string) error

func (r *Request) doGetFormStruct(pointer any, mapping ...map[string]string) (data map[string]any, err error) {
r.parseForm()
if err = r.GetError(); err != nil {
return nil, err
}
data = r.formMap
if data == nil {
data = map[string]any{}
Expand Down
3 changes: 3 additions & 0 deletions net/ghttp/ghttp_request_param_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ func (r *Request) GetQueryStruct(pointer any, mapping ...map[string]string) erro
func (r *Request) doGetQueryStruct(pointer any, mapping ...map[string]string) (data map[string]any, err error) {
r.parseQuery()
data = r.GetQueryMap()
if err = r.GetError(); err != nil {
return nil, err
}
if data == nil {
data = map[string]any{}
}
Expand Down
3 changes: 3 additions & 0 deletions net/ghttp/ghttp_request_param_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ func (r *Request) GetRequestStruct(pointer any, mapping ...map[string]string) er

func (r *Request) doGetRequestStruct(pointer any, mapping ...map[string]string) (data map[string]any, err error) {
data = r.GetRequestMap()
if err = r.GetError(); err != nil {
return nil, err
}
if data == nil {
data = map[string]any{}
}
Expand Down
65 changes: 65 additions & 0 deletions net/ghttp/ghttp_z_unit_feature_request_param_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ package ghttp_test

import (
"context"
"encoding/json"
"fmt"
"strconv"
"testing"
"time"

"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/net/ghttp"
"github.com/gogf/gf/v2/test/gtest"
Expand Down Expand Up @@ -94,6 +96,26 @@ func (c *cUserTagDefault) User(ctx context.Context, req *UserTagDefaultReq) (res
return
}

type UserInvalidJsonReq struct {
g.Meta `path:"/user-invalid-json" method:"post,get" summary:"user invalid json api"`
Name string
}

type UserInvalidJsonRes struct {
g.Meta `mime:"application/json"`
Name string
}

var (
UserInvalidJson = cUserInvalidJson{}
)

type cUserInvalidJson struct{}

func (c *cUserInvalidJson) User(ctx context.Context, req *UserInvalidJsonReq) (res *UserInvalidJsonRes, err error) {
return &UserInvalidJsonRes{Name: req.Name}, nil
}

func Test_ParamsTagDefault(t *testing.T) {
s := g.Server(guid.S())
s.Group("/", func(group *ghttp.RouterGroup) {
Expand Down Expand Up @@ -150,6 +172,49 @@ func Test_ParamsTagDefault(t *testing.T) {
})
}

func Test_ParamsInvalidJsonReportsParseError(t *testing.T) {
s := g.Server(guid.S())
s.Group("/", func(group *ghttp.RouterGroup) {
group.Middleware(ghttp.MiddlewareHandlerResponse)
group.Bind(UserInvalidJson)
})
s.SetDumpRouterMap(false)
s.Start()
defer s.Shutdown()

for i := 0; i < 100 && s.GetListenedPort() == 0; i++ {
time.Sleep(10 * time.Millisecond)
}
if s.GetListenedPort() == 0 {
t.Fatal("server did not start listening")
}

gtest.C(t, func(t *gtest.T) {
prefix := fmt.Sprintf("http://127.0.0.1:%d", s.GetListenedPort())
reqCtx := context.Background()
client := g.Client()
client.SetPrefix(prefix)
Comment thread
puneetdixit200 marked this conversation as resolved.

assertInvalidParameter := func(resp string) {
var payload struct {
Code int `json:"code"`
Data *UserInvalidJsonRes `json:"data"`
}
err := json.Unmarshal([]byte(resp), &payload)
t.AssertNil(err)
t.Assert(payload.Code, gcode.CodeInvalidParameter.Code())
t.Assert(payload.Data, nil)
}

for _, body := range []string{`{"name":}`, `[{"name":"john"}]`, `name=john`} {
assertInvalidParameter(client.ContentJson().PostContent(reqCtx, "/user-invalid-json", body))
}
for _, body := range []string{`{"name":}`, `[{"name":"john"}]`} {
assertInvalidParameter(client.ContentJson().GetContent(reqCtx, "/user-invalid-json", body))
}
})
}

func Benchmark_ParamTagIn(b *testing.B) {
b.StopTimer()

Expand Down
Loading