diff --git a/contrib/drivers/mysql/go.mod b/contrib/drivers/mysql/go.mod index e74d1e9cb07..07e9024ba21 100644 --- a/contrib/drivers/mysql/go.mod +++ b/contrib/drivers/mysql/go.mod @@ -5,6 +5,7 @@ go 1.23.0 require ( github.com/go-sql-driver/mysql v1.7.1 github.com/gogf/gf/v2 v2.10.0 + github.com/shopspring/decimal v1.4.0 ) require ( diff --git a/contrib/drivers/mysql/go.sum b/contrib/drivers/mysql/go.sum index f96db96f24f..ff9460bec25 100644 --- a/contrib/drivers/mysql/go.sum +++ b/contrib/drivers/mysql/go.sum @@ -50,6 +50,8 @@ github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= diff --git a/contrib/drivers/mysql/mysql_z_unit_issue_test.go b/contrib/drivers/mysql/mysql_z_unit_issue_test.go index 72c7185de87..d15caae1ae5 100644 --- a/contrib/drivers/mysql/mysql_z_unit_issue_test.go +++ b/contrib/drivers/mysql/mysql_z_unit_issue_test.go @@ -14,6 +14,8 @@ import ( "testing" "time" + "github.com/shopspring/decimal" + "github.com/gogf/gf/v2/container/gvar" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/frame/g" @@ -2066,3 +2068,171 @@ func Test_Issue4698(t *testing.T) { t.Assert(len(result), TableSize) }) } + +// https://github.com/gogf/gf/issues/3977 +func Test_Issue3977(t *testing.T) { + table := "issue3977" + array := gstr.SplitAndTrim(gtest.DataContent(`issues`, `3977.sql`), ";") + for _, v := range array { + if _, err := db.Exec(ctx, v); err != nil { + gtest.Error(err) + } + } + defer dropTable(table) + // string *string []string []*string + gtest.C(t, func(t *gtest.T) { + var err error + var username string + err = db.Model(table).Fields("username").Where("id", 1).Scan(&username) + t.Assert(err, nil) + t.Assert(username, "username1") + var username2 *string + err = db.Model(table).Fields("username").Where("id", 1).Scan(&username2) + t.Assert(err, nil) + t.Assert(username2, "username1") + + var usernames []string + err = db.Model(table).Fields("username").Scan(&usernames) + t.AssertNil(err) + t.Assert(usernames, []string{"username1", "username2", "username3"}) + + var usernames2 []*string + err = db.Model(table).Fields("username").Scan(&usernames2) + t.AssertNil(err) + t.Assert(usernames2, []string{"username1", "username2", "username3"}) + + }) + // float64 *float64 + gtest.C(t, func(t *gtest.T) { + var err error + var balance float64 + err = db.Model(table).Fields("balance").Where("id", 1).Scan(&balance) + t.Assert(err, nil) + t.Assert(balance, 1.01) + + var balance2 *float64 + err = db.Model(table).Fields("balance").Where("id", 1).Scan(&balance2) + t.Assert(err, nil) + t.Assert(balance2, 1.01) + + var balances []float64 + err = db.Model(table).Fields("balance").Scan(&balances) + t.Assert(err, nil) + t.Assert(balances, []float64{1.01, 2.02, 3.03}) + + var balances2 []*float64 + err = db.Model(table).Fields("balance").Scan(&balances2) + t.Assert(err, nil) + expectedBalances := []float64{1.01, 2.02, 3.03} + actualBalances := make([]float64, len(balances2)) + for i, v := range balances2 { + if v != nil { + actualBalances[i] = *v + } + } + t.Assert(actualBalances, expectedBalances) + + var totalBalances float64 + err = db.Model(table).FieldSum("balance").WhereIn("id", []int64{1, 2, 3, 4, 5}).Scan(&totalBalances) + t.Assert(err, nil) + t.Assert(totalBalances, 6.06) + + var totalBalances2 *float64 + err = db.Model(table).FieldSum("balance").WhereIn("id", []int64{1, 2, 3, 4, 5}).Scan(&totalBalances2) + t.Assert(err, nil) + t.Assert(totalBalances2, 6.06) + + }) + // int []int + gtest.C(t, func(t *gtest.T) { + var err error + var age int + err = db.Model(table).Fields("age").Where("id", 1).Scan(&age) + t.Assert(err, nil) + t.Assert(age, 18) + + var age2 *int + err = db.Model(table).Fields("age").Where("id", 1).Scan(&age2) + t.Assert(err, nil) + t.Assert(age2, 18) + + var ids []int64 + err = db.Model(table).Fields("id").Where("state", true).Scan(&ids) + t.AssertNil(err) + t.Assert(ids, []int64{1, 2}) + + var id2s []*int64 + err = db.Model(table).Fields("id").Where("state", true).Scan(&id2s) + t.AssertNil(err) + t.Assert(id2s, []int64{1, 2}) + + var total int64 + err = db.Model(table).FieldSum("id").WhereIn("id", []int64{1, 2, 3, 4, 5}).Scan(&total) + t.Assert(err, nil) + t.Assert(total, 6) + + var total2 int64 + err = db.Model(table).FieldSum("id").WhereIn("id", []int64{1, 2, 3, 4, 5}).Scan(&total2) + t.Assert(err, nil) + t.Assert(total2, 6) + + }) + // bool + gtest.C(t, func(t *gtest.T) { + var err error + var state bool + err = db.Model(table).Fields("state").Where("id", 1).Scan(&state) + t.Assert(err, nil) + t.Assert(state, true) + + var state2 *bool + err = db.Model(table).Fields("state").Where("id", 1).Scan(&state2) + t.Assert(err, nil) + t.Assert(state2, true) + + var states []bool + err = db.Model(table).Fields("state").Scan(&states) + t.AssertNil(err) + t.Assert(states, []bool{true, true, false}) + var states2 []*bool + err = db.Model(table).Fields("state").Scan(&states2) + t.AssertNil(err) + t.Assert(states2, []bool{true, true, false}) + }) + // decimal.Decimal + gtest.C(t, func(t *gtest.T) { + var err error + var balance decimal.Decimal + err = db.Model(table).Fields("balance").Where("id", 1).Scan(&balance) + t.Assert(err, nil) + t.Assert(balance, decimal.NewFromFloat(1.01)) + + var balance2 *decimal.Decimal + err = db.Model(table).Fields("balance").Where("id", 1).Scan(&balance2) + t.Assert(err, nil) + t.Assert(balance2, decimal.NewFromFloat(1.01)) + + var totalBalances decimal.Decimal + err = db.Model(table).FieldSum("balance").WhereIn("id", []int64{1, 2, 3, 4, 5}).Scan(&totalBalances) + t.Assert(err, nil) + t.Assert(totalBalances, decimal.NewFromFloat(6.06)) + + var totalBalances2 *decimal.Decimal + err = db.Model(table).FieldSum("balance").WhereIn("id", []int64{1, 2, 3, 4, 5}).Scan(&totalBalances2) + t.Assert(err, nil) + t.Assert(totalBalances2, decimal.NewFromFloat(6.06)) + }) + + gtest.C(t, func(t *gtest.T) { + var err error + var createdAt gtime.Time + err = db.Model(table).Fields("create_at").Where("id", 1).Scan(&createdAt) + t.AssertNil(err) + t.Assert(createdAt, "2020-01-01 00:00:00") + + var createdAt2 *gtime.Time + err = db.Model(table).Fields("create_at").Where("id", 1).Scan(&createdAt2) + t.AssertNil(err) + t.Assert(createdAt2, "2020-01-01 00:00:00") + }) +} diff --git a/contrib/drivers/mysql/testdata/issues/3977.sql b/contrib/drivers/mysql/testdata/issues/3977.sql new file mode 100644 index 00000000000..d08d4d22600 --- /dev/null +++ b/contrib/drivers/mysql/testdata/issues/3977.sql @@ -0,0 +1,15 @@ +DROP TABLE IF EXISTS `issue3977`; +CREATE TABLE `issue3977` ( + `id` bigint NOT NULL, + `username` varchar(255) DEFAULT "", + `balance` decimal(10,2) DEFAULT 0.00, + `state` bool DEFAULT 0, + `age` int DEFAULT 0, + `create_at` datetime(0) DEFAULT NULL, + `update_at` datetime(0) DEFAULT NULL, + PRIMARY KEY (`id`) +) ENGINE=InnoDB; + +INSERT INTO `issue3977` VALUES (1, "username1", 1.01, 1, 18, "2020-01-01 00:00:00", "2020-01-01 00:00:00"); +INSERT INTO `issue3977` VALUES (2, "username2", 2.02, 1, 100, "2020-01-01 00:00:00", "2020-01-01 00:00:00"); +INSERT INTO `issue3977` VALUES (3, "username3", 3.03, 0, 56, "2020-01-01 00:00:00", "2020-01-01 00:00:00"); \ No newline at end of file diff --git a/database/gdb/gdb_model_select.go b/database/gdb/gdb_model_select.go index 19f44cb3409..278f9146bf8 100644 --- a/database/gdb/gdb_model_select.go +++ b/database/gdb/gdb_model_select.go @@ -8,6 +8,7 @@ package gdb import ( "context" + "database/sql" "fmt" "reflect" @@ -302,11 +303,63 @@ func (m *Model) Scan(pointer any, where ...any) error { } switch reflectInfo.OriginKind { case reflect.Slice, reflect.Array: + elemType := reflectInfo.InputType + for elemType.Kind() == reflect.Pointer { + elemType = elemType.Elem() + } + elemType = elemType.Elem() + originalType := elemType + for elemType.Kind() == reflect.Pointer { + originalType = elemType + elemType = elemType.Elem() + } + + if elemType != nil && (elemType.Kind() != reflect.Struct || elemType.Implements(reflect.TypeFor[sql.Scanner]()) || originalType.Implements(reflect.TypeFor[sql.Scanner]())) { + if len(m.fields) == 1 { + args := append([]any{m.fields[0]}, where...) + valueArr, err := m.Array(args...) + if err != nil { + return err + } + return valueArr.Scan(pointer) + } + } return m.doStructs(pointer, where...) case reflect.Struct, reflect.Invalid: + elemType := reflectInfo.InputType + originalType := elemType + for elemType.Kind() == reflect.Pointer { + originalType = elemType + + elemType = elemType.Elem() + } + if elemType != nil && (elemType.Implements(reflect.TypeFor[sql.Scanner]()) || originalType.Implements(reflect.TypeFor[sql.Scanner]())) { + if len(m.fields) == 1 { + args := append([]any{m.fields[0]}, where...) + valueArr, err := m.Value(args...) + if err != nil { + return err + } + return valueArr.Scan(pointer) + } + } return m.doStruct(pointer, where...) - + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64, reflect.Bool, reflect.String: + if len(m.fields) != 1 { + return gerror.NewCode( + gcode.CodeInvalidParameter, + fmt.Sprintf("Scan operation failed: expected 1 field, but got %d", len(m.fields)), + ) + } + args := append([]any{m.fields[0]}, where...) + value, err := m.Value(args...) + if err != nil { + return err + } + return value.Scan(pointer) default: return gerror.NewCode( gcode.CodeInvalidParameter, diff --git a/util/gconv/internal/converter/converter_bool.go b/util/gconv/internal/converter/converter_bool.go index 60bceda06eb..136a7716bef 100644 --- a/util/gconv/internal/converter/converter_bool.go +++ b/util/gconv/internal/converter/converter_bool.go @@ -40,6 +40,11 @@ func (c *Converter) Bool(anyInput any) (bool, error) { } return true, nil default: + if rv, ok := value.(reflect.Value); ok { + if rv.IsValid() && rv.CanInterface() { + return c.Bool(rv.Interface()) + } + } if f, ok := value.(localinterface.IBool); ok { return f.Bool(), nil } diff --git a/util/gconv/internal/converter/converter_scan.go b/util/gconv/internal/converter/converter_scan.go index feb7cadcbe8..f6618c0ef97 100644 --- a/util/gconv/internal/converter/converter_scan.go +++ b/util/gconv/internal/converter/converter_scan.go @@ -185,41 +185,53 @@ func (c *Converter) Scan(srcValue any, dstPointer any, option ...ScanOption) (er ) for i := 0; i < srcLen; i++ { srcElem := srcValueReflectValue.Index(i).Interface() + elem := newSlice.Index(i) + + if elem.Kind() == reflect.Pointer { + if elem.IsNil() { + elem.Set(reflect.New(elem.Type().Elem())) + } + elem = elem.Elem() + } + switch dstElemType.Kind() { case reflect.String: v, err := c.String(srcElem) if err != nil && !scanOption.ContinueOnError { return err } - newSlice.Index(i).SetString(v) + elem.SetString(v) case reflect.Int: v, err := c.Int64(srcElem) if err != nil && !scanOption.ContinueOnError { return err } - newSlice.Index(i).SetInt(v) + elem.SetInt(v) case reflect.Int64: v, err := c.Int64(srcElem) if err != nil && !scanOption.ContinueOnError { return err } - newSlice.Index(i).SetInt(v) + elem.SetInt(v) case reflect.Float64: v, err := c.Float64(srcElem) if err != nil && !scanOption.ContinueOnError { return err } - newSlice.Index(i).SetFloat(v) + elem.SetFloat(v) case reflect.Bool: v, err := c.Bool(srcElem) if err != nil && !scanOption.ContinueOnError { return err } - newSlice.Index(i).SetBool(v) + elem.SetBool(v) default: - err = c.Scan( - srcElem, newSlice.Index(i).Addr().Interface(), option..., - ) + target := newSlice.Index(i) + if target.Kind() == reflect.Pointer { + err = c.Scan(srcElem, target.Interface(), option...) + } else { + err = c.Scan(srcElem, target.Addr().Interface(), option...) + } if err != nil && !scanOption.ContinueOnError { return err }