diff --git a/contrib/drivers/mysql/mysql_z_unit_feature_hook_test.go b/contrib/drivers/mysql/mysql_z_unit_feature_hook_test.go index b8be0aa706b..bf69d3d4f87 100644 --- a/contrib/drivers/mysql/mysql_z_unit_feature_hook_test.go +++ b/contrib/drivers/mysql/mysql_z_unit_feature_hook_test.go @@ -8,7 +8,6 @@ package mysql_test import ( "context" - "database/sql" "fmt" "testing" @@ -23,19 +22,18 @@ func Test_Model_Hook_Select(t *testing.T) { defer dropTable(table) gtest.C(t, func(t *gtest.T) { - m := db.Model(table).Hook(gdb.HookHandler{ - Select: func(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) { - result, err = in.Next(ctx) + m := db.Model(table).Hook( + gdb.AfterSelect(func(ctx context.Context, in *gdb.HookSelectInput, result gdb.Result, err error) (gdb.Result, error) { if err != nil { - return + return result, err } for i, record := range result { record["test"] = gvar.New(100 + record["id"].Int()) result[i] = record } - return - }, - }) + return result, nil + }), + ) all, err := m.Where(`id > 6`).OrderAsc(`id`).All(ctx) t.AssertNil(err) t.Assert(len(all), 4) @@ -52,16 +50,16 @@ func Test_Model_Hook_Insert(t *testing.T) { defer dropTable(table) gtest.C(t, func(t *gtest.T) { - m := db.Model(table).Hook(gdb.HookHandler{ - Insert: func(ctx context.Context, in *gdb.HookInsertInput) (result sql.Result, err error) { + m := db.Model(table).Hook( + gdb.BeforeInsert(func(ctx context.Context, in *gdb.HookInsertInput) error { for i, item := range in.Data { item["passport"] = fmt.Sprintf(`test_port_%d`, item["id"]) item["nickname"] = fmt.Sprintf(`test_name_%d`, item["id"]) in.Data[i] = item } - return in.Next(ctx) - }, - }) + return nil + }), + ) _, err := m.Data(g.Map{ "id": 1, "nickname": "name_1", @@ -80,8 +78,8 @@ func Test_Model_Hook_Update(t *testing.T) { defer dropTable(table) gtest.C(t, func(t *gtest.T) { - m := db.Model(table).Hook(gdb.HookHandler{ - Update: func(ctx context.Context, in *gdb.HookUpdateInput) (result sql.Result, err error) { + m := db.Model(table).Hook( + gdb.BeforeUpdate(func(ctx context.Context, in *gdb.HookUpdateInput) error { switch value := in.Data.(type) { case gdb.List: for i, data := range value { @@ -96,9 +94,9 @@ func Test_Model_Hook_Update(t *testing.T) { value["nickname"] = `name` in.Data = value } - return in.Next(ctx) - }, - }) + return nil + }), + ) _, err := m.Data(g.Map{ "nickname": "name_1", }).WherePri(1).Update(ctx) @@ -117,13 +115,17 @@ func Test_Model_Hook_Delete(t *testing.T) { defer dropTable(table) gtest.C(t, func(t *gtest.T) { - m := db.Model(table).Hook(gdb.HookHandler{ - Delete: func(ctx context.Context, in *gdb.HookDeleteInput) (result sql.Result, err error) { - return db.Model(table).Data(g.Map{ + m := db.Model(table).Hook( + gdb.BeforeDelete(func(ctx context.Context, in *gdb.HookDeleteInput) error { + origCondition := in.Condition + // Make delete a no-op, then execute the intended update. + in.Condition = "1=0" + _, err := in.Model.Data(g.Map{ "nickname": `deleted`, - }).Where(in.Condition).Update(ctx) - }, - }) + }).Where(origCondition).Update(ctx) + return err + }), + ) _, err := m.Where(1).Delete(ctx) t.AssertNil(err) @@ -134,3 +136,109 @@ func Test_Model_Hook_Delete(t *testing.T) { } }) } + +// Test_Model_Hook_Multiple tests multiple hooks execution order +func Test_Model_Hook_Multiple(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + var afterCalls []string + m := db.Model(table). + Hook( + gdb.AfterSelect(func(ctx context.Context, in *gdb.HookSelectInput, result gdb.Result, err error) (gdb.Result, error) { + afterCalls = append(afterCalls, "hook1") + if err != nil { + return result, err + } + for i, record := range result { + record["hook1"] = gvar.New("value1") + result[i] = record + } + return result, nil + }), + ). + Hook( + gdb.AfterSelect(func(ctx context.Context, in *gdb.HookSelectInput, result gdb.Result, err error) (gdb.Result, error) { + afterCalls = append(afterCalls, "hook2") + if err != nil { + return result, err + } + for i, record := range result { + record["hook2"] = gvar.New("value2") + result[i] = record + } + return result, nil + }), + ) + + _, err := m.Where("id", 1).One(ctx) + t.AssertNil(err) + + one, err := m.One(ctx) + t.AssertNil(err) + t.Assert(one["hook1"].String(), "value1") + t.Assert(one["hook2"].String(), "value2") + t.Assert(afterCalls, g.Slice{"hook1", "hook2"}) + }) +} + +// Test_Model_Hook_Error_Abort tests hook returning error aborts operation +func Test_Model_Hook_Error_Abort(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + m := db.Model(table).Hook( + gdb.BeforeInsert(func(ctx context.Context, in *gdb.HookInsertInput) error { + // Return error to abort insert. + return fmt.Errorf("hook aborted insert") + }), + ) + + _, err := m.Data(g.Map{ + "passport": "test_abort", + "password": "pass", + "nickname": "name", + }).Insert(ctx) + t.AssertNE(err, nil) + t.Assert(err.Error(), "hook aborted insert") + + // Verify record was not inserted + count, err := db.Model(table).Where("passport", "test_abort").Count(ctx) + t.AssertNil(err) + t.Assert(count, 0) + }) +} + +// Test_Model_Hook_Modify_Data tests hook modifying data before insert +func Test_Model_Hook_Modify_Data(t *testing.T) { + table := createTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + m := db.Model(table).Hook( + gdb.BeforeInsert(func(ctx context.Context, in *gdb.HookInsertInput) error { + // Modify all data items + for i := range in.Data { + in.Data[i]["password"] = "encrypted_" + fmt.Sprint(in.Data[i]["password"]) + in.Data[i]["nickname"] = "verified_" + fmt.Sprint(in.Data[i]["nickname"]) + } + return nil + }), + ) + + _, err := m.Data(g.Map{ + "passport": "test_user", + "password": "plain123", + "nickname": "john", + }).Insert(ctx) + t.AssertNil(err) + + // Verify data was modified by hook + one, err := db.Model(table).Where("passport", "test_user").One(ctx) + t.AssertNil(err) + t.Assert(one["password"].String(), "encrypted_plain123") + t.Assert(one["nickname"].String(), "verified_john") + }) +} diff --git a/contrib/drivers/mysql/mysql_z_unit_feature_model_sharding_test.go b/contrib/drivers/mysql/mysql_z_unit_feature_model_sharding_test.go index ce5600758ac..5c6035f660c 100644 --- a/contrib/drivers/mysql/mysql_z_unit_feature_model_sharding_test.go +++ b/contrib/drivers/mysql/mysql_z_unit_feature_model_sharding_test.go @@ -8,7 +8,6 @@ package mysql_test import ( "context" - "database/sql" "fmt" "testing" @@ -70,7 +69,7 @@ func dropShardingDatabase(t *gtest.T) { } func Test_Sharding_Basic(t *testing.T) { - return + t.Skip("disabled by default: requires sharding test databases") gtest.C(t, func(t *gtest.T) { var ( tablePrefix = "user_" @@ -144,7 +143,7 @@ func Test_Sharding_Basic(t *testing.T) { // Test_Sharding_Error tests error cases func Test_Sharding_Error(t *testing.T) { - return + t.Skip("disabled by default: requires sharding test databases") gtest.C(t, func(t *gtest.T) { // Create test databases and tables createShardingDatabase(t) @@ -182,7 +181,7 @@ func Test_Sharding_Error(t *testing.T) { // Test_Sharding_Complex tests complex sharding scenarios func Test_Sharding_Complex(t *testing.T) { - return + t.Skip("disabled by default: requires sharding test databases") gtest.C(t, func(t *gtest.T) { // Create test databases and tables createShardingDatabase(t) @@ -251,24 +250,24 @@ func Test_Model_Sharding_Table_Using_Hook(t *testing.T) { createTable(table2) defer dropTable(table2) - shardingModel := db.Model(table1).Hook(gdb.HookHandler{ - Select: func(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) { + shardingModel := db.Model(table1).Hook( + gdb.BeforeSelect(func(ctx context.Context, in *gdb.HookSelectInput) error { in.Table = table2 - return in.Next(ctx) - }, - Insert: func(ctx context.Context, in *gdb.HookInsertInput) (result sql.Result, err error) { + return nil + }), + gdb.BeforeInsert(func(ctx context.Context, in *gdb.HookInsertInput) error { in.Table = table2 - return in.Next(ctx) - }, - Update: func(ctx context.Context, in *gdb.HookUpdateInput) (result sql.Result, err error) { + return nil + }), + gdb.BeforeUpdate(func(ctx context.Context, in *gdb.HookUpdateInput) error { in.Table = table2 - return in.Next(ctx) - }, - Delete: func(ctx context.Context, in *gdb.HookDeleteInput) (result sql.Result, err error) { + return nil + }), + gdb.BeforeDelete(func(ctx context.Context, in *gdb.HookDeleteInput) error { in.Table = table2 - return in.Next(ctx) - }, - }) + return nil + }), + ) gtest.C(t, func(t *gtest.T) { r, err := shardingModel.Data(g.Map{ "id": 1, @@ -359,28 +358,28 @@ func Test_Model_Sharding_Schema_Using_Hook(t *testing.T) { createTableWithDb(db2, table) defer dropTableWithDb(db2, table) - shardingModel := db.Model(table).Hook(gdb.HookHandler{ - Select: func(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) { + shardingModel := db.Model(table).Hook( + gdb.BeforeSelect(func(ctx context.Context, in *gdb.HookSelectInput) error { in.Table = table in.Schema = db2.GetSchema() - return in.Next(ctx) - }, - Insert: func(ctx context.Context, in *gdb.HookInsertInput) (result sql.Result, err error) { + return nil + }), + gdb.BeforeInsert(func(ctx context.Context, in *gdb.HookInsertInput) error { in.Table = table in.Schema = db2.GetSchema() - return in.Next(ctx) - }, - Update: func(ctx context.Context, in *gdb.HookUpdateInput) (result sql.Result, err error) { + return nil + }), + gdb.BeforeUpdate(func(ctx context.Context, in *gdb.HookUpdateInput) error { in.Table = table in.Schema = db2.GetSchema() - return in.Next(ctx) - }, - Delete: func(ctx context.Context, in *gdb.HookDeleteInput) (result sql.Result, err error) { + return nil + }), + gdb.BeforeDelete(func(ctx context.Context, in *gdb.HookDeleteInput) error { in.Table = table in.Schema = db2.GetSchema() - return in.Next(ctx) - }, - }) + return nil + }), + ) gtest.C(t, func(t *gtest.T) { r, err := shardingModel.Data(g.Map{ "id": 1, diff --git a/contrib/drivers/mysql/mysql_z_unit_feature_scanlist_test.go b/contrib/drivers/mysql/mysql_z_unit_feature_scanlist_test.go index 9fd21fcb5e5..b46a9d465ec 100644 --- a/contrib/drivers/mysql/mysql_z_unit_feature_scanlist_test.go +++ b/contrib/drivers/mysql/mysql_z_unit_feature_scanlist_test.go @@ -1293,135 +1293,6 @@ CREATE TABLE %s ( err = all.ScanList(ctx, &users, "UserScores", "User", "uid:uid") t.AssertNil(err) }) - return - // Result ScanList with pointer elements and pointer attributes. - gtest.C(t, func(t *gtest.T) { - var users []*Entity - // User - all, err := db.Model(tableUser).Where("uid", g.Slice{3, 4}).Order("uid asc").All(ctx) - t.AssertNil(err) - err = all.ScanList(ctx, &users, "User") - t.AssertNil(err) - t.Assert(len(users), 0) - - // Detail - all, err = db.Model(tableUserDetail).Where("uid", gdb.ListItemValues(users, "User", "Uid")).Order("uid asc").All(ctx) - t.AssertNil(err) - err = all.ScanList(ctx, &users, "UserDetail", "User", "Uid:UID") - t.AssertNil(err) - - // Scores - all, err = db.Model(tableUserScores).Where("uid", gdb.ListItemValues(users, "User", "Uid")).Order("id asc").All(ctx) - t.AssertNil(err) - err = all.ScanList(ctx, &users, "UserScores", "User", "Uid:UID") - t.AssertNil(err) - }) - - // Result ScanList with struct elements and struct attributes. - gtest.C(t, func(t *gtest.T) { - type EntityUser struct { - Uid int `json:"uid"` - Name string `json:"name"` - } - type EntityUserDetail struct { - Uid int `json:"uid"` - Address string `json:"address"` - } - type EntityUserScores struct { - Id int `json:"id"` - Uid int `json:"uid"` - Score int `json:"score"` - } - type Entity struct { - User EntityUser - UserDetail EntityUserDetail - UserScores []EntityUserScores - } - var users []Entity - // User - all, err := db.Model(tableUser).Where("uid", g.Slice{3, 4}).Order("uid asc").All(ctx) - t.AssertNil(err) - err = all.ScanList(ctx, &users, "User") - t.AssertNil(err) - - // Detail - all, err = db.Model(tableUserDetail).Where("uid", gdb.ListItemValues(users, "User", "Uid")).Order("uid asc").All(ctx) - t.AssertNil(err) - err = all.ScanList(ctx, &users, "UserDetail", "User", "uid:UId") - t.AssertNil(err) - - // Scores - all, err = db.Model(tableUserScores).Where("uid", gdb.ListItemValues(users, "User", "Uid")).Order("id asc").All(ctx) - t.AssertNil(err) - err = all.ScanList(ctx, &users, "UserScores", "User", "UId:Uid") - t.AssertNil(err) - }) - - // Result ScanList with pointer elements and struct attributes. - gtest.C(t, func(t *gtest.T) { - type EntityUser struct { - Uid int `json:"uid"` - Name string `json:"name"` - } - type EntityUserDetail struct { - Uid int `json:"uid"` - Address string `json:"address"` - } - type EntityUserScores struct { - Id int `json:"id"` - Uid int `json:"uid"` - Score int `json:"score"` - } - type Entity struct { - User EntityUser - UserDetail EntityUserDetail - UserScores []EntityUserScores - } - var users []*Entity - - // User - all, err := db.Model(tableUser).Where("uid", g.Slice{3, 4}).Order("uid asc").All(ctx) - t.AssertNil(err) - err = all.ScanList(ctx, &users, "User") - t.AssertNil(err) - t.Assert(len(users), 0) - // Detail - all, err = db.Model(tableUserDetail).Where("uid", gdb.ListItemValues(users, "User", "Uid")).Order("uid asc").All(ctx) - t.AssertNil(err) - err = all.ScanList(ctx, &users, "UserDetail", "User", "uid:Uid") - t.AssertNil(err) - - // Scores - all, err = db.Model(tableUserScores).Where("uid", gdb.ListItemValues(users, "User", "Uid")).Order("id asc").All(ctx) - t.AssertNil(err) - err = all.ScanList(ctx, &users, "UserScores", "User", "UID:Uid") - t.AssertNil(err) - }) - - // Model ScanList with pointer elements and pointer attributes. - gtest.C(t, func(t *gtest.T) { - var users []*Entity - // User - err := db.Model(tableUser). - Where("uid", g.Slice{3, 4}). - Order("uid asc"). - ScanList(ctx, &users, "User") - t.AssertNil(err) - // Detail - err = db.Model(tableUserDetail). - Where("uid", gdb.ListItemValues(users, "User", "Uid")). - Order("uid asc"). - ScanList(ctx, &users, "UserDetail", "User", "uid:Uid") - t.AssertNil(err) - // Scores - err = db.Model(tableUserScores). - Where("uid", gdb.ListItemValues(users, "User", "Uid")). - Order("id asc"). - ScanList(ctx, &users, "UserScores", "User", "uid:Uid") - t.AssertNil(err) - - t.Assert(len(users), 0) - }) } func Test_Table_Relation_NoneEqualDataSize(t *testing.T) { diff --git a/contrib/drivers/mysql/mysql_z_unit_issue_test.go b/contrib/drivers/mysql/mysql_z_unit_issue_test.go index 8c5b418b85d..3f9f486518d 100644 --- a/contrib/drivers/mysql/mysql_z_unit_issue_test.go +++ b/contrib/drivers/mysql/mysql_z_unit_issue_test.go @@ -1290,11 +1290,10 @@ func Test_Issue3238(t *testing.T) { gtest.C(t, func(t *gtest.T) { for i := 0; i < 100; i++ { - _, err := db.Model(table).Hook(gdb.HookHandler{ - Select: func(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) { - result, err = in.Next(ctx) + _, err := db.Model(table).Hook( + gdb.AfterSelect(func(ctx context.Context, in *gdb.HookSelectInput, result gdb.Result, err error) (gdb.Result, error) { if err != nil { - return + return result, err } var wg sync.WaitGroup for _, record := range result { @@ -1308,9 +1307,8 @@ func Test_Issue3238(t *testing.T) { }(record) } wg.Wait() - return - }, - }, + return result, nil + }), ).All(ctx) t.AssertNil(err) } @@ -1428,27 +1426,26 @@ func Test_Issue3626(t *testing.T) { var ( cacheKey = guid.S() - cacheFunc = func(duration time.Duration) gdb.HookHandler { - return gdb.HookHandler{ - Select: func(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) { - get, err := db.GetCache().Get(ctx, cacheKey) - if err == nil && !get.IsEmpty() { - err = get.Scan(&result) - if err == nil { - return result, nil - } - } - result, err = in.Next(ctx) - if err != nil { - return nil, err - } - if result == nil || result.Len() < 1 { - result = make(gdb.Result, 0) + cacheFunc = func(duration time.Duration) gdb.HookDescriptor { + return gdb.AfterSelect(func(ctx context.Context, in *gdb.HookSelectInput, result gdb.Result, err error) (gdb.Result, error) { + // In new hook design, DB select always runs before after hook. + // We still keep cache as an override for returning results. + get, getErr := db.GetCache().Get(ctx, cacheKey) + if getErr == nil && !get.IsEmpty() { + var cached gdb.Result + if scanErr := get.Scan(&cached); scanErr == nil { + return cached, nil } - _ = db.GetCache().Set(ctx, cacheKey, result, duration) - return - }, - } + } + if err != nil { + return result, err + } + if result == nil || result.Len() < 1 { + result = make(gdb.Result, 0) + } + _ = db.GetCache().Set(ctx, cacheKey, result, duration) + return result, nil + }) } ) gtest.C(t, func(t *gtest.T) { @@ -1503,20 +1500,15 @@ func Test_Issue3968(t *testing.T) { defer dropTable(table) gtest.C(t, func(t *gtest.T) { - var hook = gdb.HookHandler{ - Select: func(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) { - result, err = in.Next(ctx) - if err != nil { - return nil, err - } - if result != nil { - for i, _ := range result { - result[i]["location"] = gvar.New("ny") - } - } - return - }, - } + var hook = gdb.AfterSelect(func(ctx context.Context, in *gdb.HookSelectInput, result gdb.Result, err error) (gdb.Result, error) { + if err != nil { + return nil, err + } + for i := range result { + result[i]["location"] = gvar.New("ny") + } + return result, nil + }) var ( count int result gdb.Result diff --git a/contrib/drivers/mysql/mysql_z_unit_transaction_test.go b/contrib/drivers/mysql/mysql_z_unit_transaction_test.go index 702e91b17e7..5236f2650b6 100644 --- a/contrib/drivers/mysql/mysql_z_unit_transaction_test.go +++ b/contrib/drivers/mysql/mysql_z_unit_transaction_test.go @@ -743,7 +743,6 @@ func Test_Transaction_Panic(t *testing.T) { }) t.AssertNil(err) panic("error") - return nil }) t.AssertNE(err, nil) @@ -854,7 +853,6 @@ func Test_Transaction_Nested_TX_Transaction_UseTX(t *testing.T) { t.AssertNil(err) panic("error") - return err }) t.AssertNE(err, nil) return nil @@ -916,7 +914,6 @@ func Test_Transaction_Nested_TX_Transaction_UseTX(t *testing.T) { t.AssertNil(err) panic("error") - return err }) t.AssertNE(err, nil) return nil @@ -993,7 +990,6 @@ func Test_Transaction_Nested_TX_Transaction_UseDB(t *testing.T) { // panic makes this transaction rollback. panic("error") - return err }) t.AssertNE(err, nil) return nil @@ -1055,7 +1051,6 @@ func Test_Transaction_Nested_TX_Transaction_UseDB(t *testing.T) { // panic makes this transaction rollback. panic("error") - return err }) t.AssertNE(err, nil) return nil diff --git a/database/gdb/gdb_model_delete.go b/database/gdb/gdb_model_delete.go index ba4f8943c4a..2a9c77a0b47 100644 --- a/database/gdb/gdb_model_delete.go +++ b/database/gdb/gdb_model_delete.go @@ -54,11 +54,8 @@ func (m *Model) Delete(ctx context.Context) (result sql.Result, err error) { ctx, "", fieldNameDelete, fieldTypeDelete, ) in := &HookUpdateInput{ - internalParamHookUpdate: internalParamHookUpdate{ - internalParamHook: internalParamHook{ - link: model.getLink(ctx, true), - }, - handler: model.hookHandler.Update, + internalParamHook: internalParamHook{ + link: model.getLink(ctx, true), }, Model: model, Table: model.tables, @@ -67,15 +64,12 @@ func (m *Model) Delete(ctx context.Context) (result sql.Result, err error) { Condition: conditionStr, Args: append([]interface{}{dataValue}, conditionArgs...), } - return in.Next(ctx) + return m.hookHandler.runUpdate(ctx, in) } in := &HookDeleteInput{ - internalParamHookDelete: internalParamHookDelete{ - internalParamHook: internalParamHook{ - link: model.getLink(ctx, true), - }, - handler: model.hookHandler.Delete, + internalParamHook: internalParamHook{ + link: model.getLink(ctx, true), }, Model: model, Table: model.tables, @@ -83,5 +77,5 @@ func (m *Model) Delete(ctx context.Context) (result sql.Result, err error) { Condition: conditionStr, Args: conditionArgs, } - return in.Next(ctx) + return m.hookHandler.runDelete(ctx, in) } diff --git a/database/gdb/gdb_model_hook.go b/database/gdb/gdb_model_hook.go index fe7e6bc96c1..2af1465e539 100644 --- a/database/gdb/gdb_model_hook.go +++ b/database/gdb/gdb_model_hook.go @@ -17,66 +17,113 @@ import ( ) type ( - HookFuncSelect func(ctx context.Context, in *HookSelectInput) (result Result, err error) - HookFuncInsert func(ctx context.Context, in *HookInsertInput) (result sql.Result, err error) - HookFuncUpdate func(ctx context.Context, in *HookUpdateInput) (result sql.Result, err error) - HookFuncDelete func(ctx context.Context, in *HookDeleteInput) (result sql.Result, err error) + HookBeforeSelect func(ctx context.Context, in *HookSelectInput) error + HookAfterSelect func(ctx context.Context, in *HookSelectInput, result Result, err error) (Result, error) + + HookBeforeInsert func(ctx context.Context, in *HookInsertInput) error + HookAfterInsert func(ctx context.Context, in *HookInsertInput, result sql.Result, err error) (sql.Result, error) + + HookBeforeUpdate func(ctx context.Context, in *HookUpdateInput) error + HookAfterUpdate func(ctx context.Context, in *HookUpdateInput, result sql.Result, err error) (sql.Result, error) + + HookBeforeDelete func(ctx context.Context, in *HookDeleteInput) error + HookAfterDelete func(ctx context.Context, in *HookDeleteInput, result sql.Result, err error) (sql.Result, error) ) // HookHandler manages all supported hook functions for Model. type HookHandler struct { - Select HookFuncSelect - Insert HookFuncInsert - Update HookFuncUpdate - Delete HookFuncDelete + selectBefore []HookBeforeSelect + selectAfter []HookAfterSelect + + insertBefore []HookBeforeInsert + insertAfter []HookAfterInsert + + updateBefore []HookBeforeUpdate + updateAfter []HookAfterUpdate + + deleteBefore []HookBeforeDelete + deleteAfter []HookAfterDelete } -// internalParamHook manages all internal parameters for hook operations. -// The `internal` obviously means you cannot access these parameters outside this package. -type internalParamHook struct { - link Link // Connection object from third party sql driver. - handlerCalled bool // Simple mark for custom handler called, in case of recursive calling. - removedWhere bool // Removed mark for condition string that was removed `WHERE` prefix. - originalTableName *gvar.Var // The original table name. - originalSchemaName *gvar.Var // The original schema name. +type HookType int + +const ( + HookTypeSelect HookType = 1 + HookTypeInsert HookType = 2 + HookTypeUpdate HookType = 3 + HookTypeDelete HookType = 4 +) + +type HookStage int + +const ( + HookStageBefore HookStage = 1 + HookStageAfter HookStage = 2 +) + +type HookDescriptor struct { + Type HookType + Stage HookStage + Handler any } -type internalParamHookSelect struct { - internalParamHook - handler HookFuncSelect +func BeforeSelect(handler HookBeforeSelect) HookDescriptor { + return HookDescriptor{Type: HookTypeSelect, Stage: HookStageBefore, Handler: handler} } -type internalParamHookInsert struct { - internalParamHook - handler HookFuncInsert +func AfterSelect(handler HookAfterSelect) HookDescriptor { + return HookDescriptor{Type: HookTypeSelect, Stage: HookStageAfter, Handler: handler} } -type internalParamHookUpdate struct { - internalParamHook - handler HookFuncUpdate +func BeforeInsert(handler HookBeforeInsert) HookDescriptor { + return HookDescriptor{Type: HookTypeInsert, Stage: HookStageBefore, Handler: handler} } -type internalParamHookDelete struct { - internalParamHook - handler HookFuncDelete +func AfterInsert(handler HookAfterInsert) HookDescriptor { + return HookDescriptor{Type: HookTypeInsert, Stage: HookStageAfter, Handler: handler} +} + +func BeforeUpdate(handler HookBeforeUpdate) HookDescriptor { + return HookDescriptor{Type: HookTypeUpdate, Stage: HookStageBefore, Handler: handler} +} + +func AfterUpdate(handler HookAfterUpdate) HookDescriptor { + return HookDescriptor{Type: HookTypeUpdate, Stage: HookStageAfter, Handler: handler} +} + +func BeforeDelete(handler HookBeforeDelete) HookDescriptor { + return HookDescriptor{Type: HookTypeDelete, Stage: HookStageBefore, Handler: handler} +} + +func AfterDelete(handler HookAfterDelete) HookDescriptor { + return HookDescriptor{Type: HookTypeDelete, Stage: HookStageAfter, Handler: handler} +} + +// internalParamHook manages all internal parameters for hook operations. +// The `internal` obviously means you cannot access these parameters outside this package. +type internalParamHook struct { + link Link // Connection object from third party sql driver. + removedWhere bool // Removed mark for condition string that was removed `WHERE` prefix. + originalTableName *gvar.Var // The original table name. + originalSchemaName *gvar.Var // The original schema name. } // HookSelectInput holds the parameters for select hook operation. // Note that, COUNT statement will also be hooked by this feature, // which is usually not be interesting for upper business hook handler. type HookSelectInput struct { - internalParamHookSelect - Model *Model // Current operation Model. - Table string // The table name that to be used. Update this attribute to change target table name. - Schema string // The schema name that to be used. Update this attribute to change target schema name. - Sql string // The sql string that to be committed. - Args []interface{} // The arguments of sql. - SelectType SelectType // The type of this SELECT operation. + internalParamHook + Model *Model // Current operation Model. + Table string // The table name that to be used. Update this attribute to change target table name. + Schema string // The schema name that to be used. Update this attribute to change target schema name. + Sql string // The sql string that to be committed. + Args []any // The arguments of sql. + SelectType SelectType // The type of this SELECT operation. } // HookInsertInput holds the parameters for insert hook operation. type HookInsertInput struct { - internalParamHookInsert + internalParamHook Model *Model // Current operation Model. Table string // The table name that to be used. Update this attribute to change target table name. Schema string // The schema name that to be used. Update this attribute to change target schema name. @@ -86,23 +133,23 @@ type HookInsertInput struct { // HookUpdateInput holds the parameters for update hook operation. type HookUpdateInput struct { - internalParamHookUpdate - Model *Model // Current operation Model. - Table string // The table name that to be used. Update this attribute to change target table name. - Schema string // The schema name that to be used. Update this attribute to change target schema name. - Data interface{} // Data can be type of: map[string]interface{}/string. You can use type assertion on `Data`. - Condition string // The where condition string for updating. - Args []interface{} // The arguments for sql place-holders. + internalParamHook + Model *Model // Current operation Model. + Table string // The table name that to be used. Update this attribute to change target table name. + Schema string // The schema name that to be used. Update this attribute to change target schema name. + Data any // Data can be type of: map[string]any/string. You can use type assertion on `Data`. + Condition string // The where condition string for updating. + Args []any // The arguments for sql place-holders. } // HookDeleteInput holds the parameters for delete hook operation. type HookDeleteInput struct { - internalParamHookDelete - Model *Model // Current operation Model. - Table string // The table name that to be used. Update this attribute to change target table name. - Schema string // The schema name that to be used. Update this attribute to change target schema name. - Condition string // The where condition string for deleting. - Args []interface{} // The arguments for sql place-holders. + internalParamHook + Model *Model // Current operation Model. + Table string // The table name that to be used. Update this attribute to change target table name. + Schema string // The schema name that to be used. Update this attribute to change target schema name. + Condition string // The where condition string for deleting. + Args []any // The arguments for sql place-holders. } const ( @@ -114,181 +161,375 @@ func (h *internalParamHook) IsTransaction() bool { return h.link.IsTransaction() } -// Next calls the next hook handler. -func (h *HookSelectInput) Next(ctx context.Context) (result Result, err error) { - if h.originalTableName.IsNil() { - h.originalTableName = gvar.New(h.Table) +// Hook sets the hook functions for current model. +// Can be used multiple times without overwriting. +func (m *Model) Hook(descriptors ...HookDescriptor) *Model { + model := m + + for _, descriptor := range descriptors { + model.hookHandler = model.hookHandler.append(descriptor) } - if h.originalSchemaName.IsNil() { - h.originalSchemaName = gvar.New(h.Schema) + + return model +} + +func (h HookHandler) append(descriptor HookDescriptor) HookHandler { + switch descriptor.Type { + case HookTypeSelect: + switch descriptor.Stage { + case HookStageBefore: + handler, ok := descriptor.Handler.(HookBeforeSelect) + if !ok { + panic(fmt.Sprintf("invalid hook handler type for select/before: %T", descriptor.Handler)) + } + h.selectBefore = append(h.selectBefore, handler) + case HookStageAfter: + handler, ok := descriptor.Handler.(HookAfterSelect) + if !ok { + panic(fmt.Sprintf("invalid hook handler type for select/after: %T", descriptor.Handler)) + } + h.selectAfter = append(h.selectAfter, handler) + default: + panic(fmt.Sprintf("invalid hook stage: %d", descriptor.Stage)) + } + + case HookTypeInsert: + switch descriptor.Stage { + case HookStageBefore: + handler, ok := descriptor.Handler.(HookBeforeInsert) + if !ok { + panic(fmt.Sprintf("invalid hook handler type for insert/before: %T", descriptor.Handler)) + } + h.insertBefore = append(h.insertBefore, handler) + case HookStageAfter: + handler, ok := descriptor.Handler.(HookAfterInsert) + if !ok { + panic(fmt.Sprintf("invalid hook handler type for insert/after: %T", descriptor.Handler)) + } + h.insertAfter = append(h.insertAfter, handler) + default: + panic(fmt.Sprintf("invalid hook stage: %d", descriptor.Stage)) + } + + case HookTypeUpdate: + switch descriptor.Stage { + case HookStageBefore: + handler, ok := descriptor.Handler.(HookBeforeUpdate) + if !ok { + panic(fmt.Sprintf("invalid hook handler type for update/before: %T", descriptor.Handler)) + } + h.updateBefore = append(h.updateBefore, handler) + case HookStageAfter: + handler, ok := descriptor.Handler.(HookAfterUpdate) + if !ok { + panic(fmt.Sprintf("invalid hook handler type for update/after: %T", descriptor.Handler)) + } + h.updateAfter = append(h.updateAfter, handler) + default: + panic(fmt.Sprintf("invalid hook stage: %d", descriptor.Stage)) + } + + case HookTypeDelete: + switch descriptor.Stage { + case HookStageBefore: + handler, ok := descriptor.Handler.(HookBeforeDelete) + if !ok { + panic(fmt.Sprintf("invalid hook handler type for delete/before: %T", descriptor.Handler)) + } + h.deleteBefore = append(h.deleteBefore, handler) + case HookStageAfter: + handler, ok := descriptor.Handler.(HookAfterDelete) + if !ok { + panic(fmt.Sprintf("invalid hook handler type for delete/after: %T", descriptor.Handler)) + } + h.deleteAfter = append(h.deleteAfter, handler) + default: + panic(fmt.Sprintf("invalid hook stage: %d", descriptor.Stage)) + } + + default: + panic(fmt.Sprintf("invalid hook type: %d", descriptor.Type)) } + return h +} - // Sharding feature. - h.Schema, err = h.Model.getActualSchema(ctx, h.Schema) - if err != nil { +func (h HookHandler) Clone() HookHandler { + return HookHandler{ + selectBefore: append([]HookBeforeSelect(nil), h.selectBefore...), + selectAfter: append([]HookAfterSelect(nil), h.selectAfter...), + insertBefore: append([]HookBeforeInsert(nil), h.insertBefore...), + insertAfter: append([]HookAfterInsert(nil), h.insertAfter...), + updateBefore: append([]HookBeforeUpdate(nil), h.updateBefore...), + updateAfter: append([]HookAfterUpdate(nil), h.updateAfter...), + deleteBefore: append([]HookBeforeDelete(nil), h.deleteBefore...), + deleteAfter: append([]HookAfterDelete(nil), h.deleteAfter...), + } +} + +func (h HookHandler) runSelect(ctx context.Context, in *HookSelectInput) (result Result, err error) { + in.initOriginalNames() + if err = in.applySharding(ctx); err != nil { return nil, err } - h.Table, err = h.Model.getActualTable(ctx, h.Table) - if err != nil { + for _, before := range h.selectBefore { + if err = before(ctx, in); err != nil { + return nil, err + } + } + if result, err = in.doSelect(ctx); err != nil { return nil, err } + for _, after := range h.selectAfter { + result, err = after(ctx, in, result, err) + if err != nil { + return nil, err + } + } + return result, nil +} - // Custom hook handler call. - if h.handler != nil && !h.handlerCalled { - h.handlerCalled = true - return h.handler(ctx, h) +func (h HookHandler) runInsert(ctx context.Context, in *HookInsertInput) (result sql.Result, err error) { + in.initOriginalNames() + if err = in.applySharding(ctx); err != nil { + return nil, err } - var toBeCommittedSql = h.Sql - // Table change. - if h.Table != h.originalTableName.String() { - toBeCommittedSql, err = gregex.ReplaceStringFuncMatch( - `(?i) FROM ([\S]+)`, - toBeCommittedSql, - func(match []string) string { - charL, charR := h.Model.db.GetChars() - return fmt.Sprintf(` FROM %s%s%s`, charL, h.Table, charR) - }, - ) - if err != nil { - return + for _, before := range h.insertBefore { + if err = before(ctx, in); err != nil { + return nil, err } } - // Schema change. - if h.Schema != "" && h.Schema != h.originalSchemaName.String() { - h.link, err = h.Model.db.GetCore().SlaveLink(h.Schema) + if result, err = in.doInsert(ctx); err != nil { + return nil, err + } + for _, after := range h.insertAfter { + result, err = after(ctx, in, result, err) if err != nil { - return + return nil, err } } - return h.Model.db.DoSelect(ctx, h.link, toBeCommittedSql, h.Args...) + return result, nil } -// Next calls the next hook handler. -func (h *HookInsertInput) Next(ctx context.Context) (result sql.Result, err error) { - if h.originalTableName.IsNil() { - h.originalTableName = gvar.New(h.Table) +func (h HookHandler) runUpdate(ctx context.Context, in *HookUpdateInput) (result sql.Result, err error) { + in.initOriginalNames() + if err = in.applySharding(ctx); err != nil { + return nil, err } - if h.originalSchemaName.IsNil() { - h.originalSchemaName = gvar.New(h.Schema) + in.normalizeWhereForHooks() + for _, before := range h.updateBefore { + if err = before(ctx, in); err != nil { + return nil, err + } } - - // Sharding feature. - h.Schema, err = h.Model.getActualSchema(ctx, h.Schema) - if err != nil { + if result, err = in.doUpdate(ctx); err != nil { return nil, err } - h.Table, err = h.Model.getActualTable(ctx, h.Table) - if err != nil { - return nil, err + for _, after := range h.updateAfter { + result, err = after(ctx, in, result, err) + if err != nil { + return nil, err + } } + return result, nil +} - if h.handler != nil && !h.handlerCalled { - h.handlerCalled = true - return h.handler(ctx, h) +func (h HookHandler) runDelete(ctx context.Context, in *HookDeleteInput) (result sql.Result, err error) { + in.initOriginalNames() + if err = in.applySharding(ctx); err != nil { + return nil, err } - - // No need to handle table change. - - // Schema change. - if h.Schema != "" && h.Schema != h.originalSchemaName.String() { - h.link, err = h.Model.db.GetCore().MasterLink(h.Schema) + in.normalizeWhereForHooks() + for _, before := range h.deleteBefore { + if err = before(ctx, in); err != nil { + return nil, err + } + } + if result, err = in.doDelete(ctx); err != nil { + return nil, err + } + for _, after := range h.deleteAfter { + result, err = after(ctx, in, result, err) if err != nil { - return + return nil, err } } - return h.Model.db.DoInsert(ctx, h.link, h.Table, h.Data, h.Option) + return result, nil } -// Next calls the next hook handler. -func (h *HookUpdateInput) Next(ctx context.Context) (result sql.Result, err error) { +func (h *internalParamHook) initOriginalNames(schema, table string) { if h.originalTableName.IsNil() { - h.originalTableName = gvar.New(h.Table) + h.originalTableName = gvar.New(table) } if h.originalSchemaName.IsNil() { - h.originalSchemaName = gvar.New(h.Schema) + h.originalSchemaName = gvar.New(schema) } +} + +func (in *HookSelectInput) initOriginalNames() { + in.internalParamHook.initOriginalNames(in.Schema, in.Table) +} + +func (in *HookInsertInput) initOriginalNames() { + in.internalParamHook.initOriginalNames(in.Schema, in.Table) +} - // Sharding feature. - h.Schema, err = h.Model.getActualSchema(ctx, h.Schema) +func (in *HookUpdateInput) initOriginalNames() { + in.internalParamHook.initOriginalNames(in.Schema, in.Table) +} + +func (in *HookDeleteInput) initOriginalNames() { + in.internalParamHook.initOriginalNames(in.Schema, in.Table) +} + +func (in *HookSelectInput) applySharding(ctx context.Context) (err error) { + in.Schema, err = in.Model.getActualSchema(ctx, in.Schema) if err != nil { - return nil, err + return err } - h.Table, err = h.Model.getActualTable(ctx, h.Table) + in.Table, err = in.Model.getActualTable(ctx, in.Table) if err != nil { - return nil, err + return err } + return nil +} - if h.handler != nil && !h.handlerCalled { - h.handlerCalled = true - if gstr.HasPrefix(h.Condition, whereKeyInCondition) { - h.removedWhere = true - h.Condition = gstr.TrimLeftStr(h.Condition, whereKeyInCondition) - } - return h.handler(ctx, h) +func (in *HookInsertInput) applySharding(ctx context.Context) (err error) { + in.Schema, err = in.Model.getActualSchema(ctx, in.Schema) + if err != nil { + return err } - if h.removedWhere { - h.Condition = whereKeyInCondition + h.Condition + in.Table, err = in.Model.getActualTable(ctx, in.Table) + if err != nil { + return err } + return nil +} - // No need to handle table change. +func (in *HookUpdateInput) applySharding(ctx context.Context) (err error) { + in.Schema, err = in.Model.getActualSchema(ctx, in.Schema) + if err != nil { + return err + } + in.Table, err = in.Model.getActualTable(ctx, in.Table) + if err != nil { + return err + } + return nil +} - // Schema change. - if h.Schema != "" && h.Schema != h.originalSchemaName.String() { - h.link, err = h.Model.db.GetCore().MasterLink(h.Schema) - if err != nil { - return - } +func (in *HookDeleteInput) applySharding(ctx context.Context) (err error) { + in.Schema, err = in.Model.getActualSchema(ctx, in.Schema) + if err != nil { + return err } - return h.Model.db.DoUpdate(ctx, h.link, h.Table, h.Data, h.Condition, h.Args...) + in.Table, err = in.Model.getActualTable(ctx, in.Table) + if err != nil { + return err + } + return nil } -// Next calls the next hook handler. -func (h *HookDeleteInput) Next(ctx context.Context) (result sql.Result, err error) { - if h.originalTableName.IsNil() { - h.originalTableName = gvar.New(h.Table) +func (in *HookUpdateInput) normalizeWhereForHooks() { + if gstr.HasPrefix(in.Condition, whereKeyInCondition) { + in.removedWhere = true + in.Condition = gstr.TrimLeftStr(in.Condition, whereKeyInCondition) } - if h.originalSchemaName.IsNil() { - h.originalSchemaName = gvar.New(h.Schema) +} + +func (in *HookDeleteInput) normalizeWhereForHooks() { + if gstr.HasPrefix(in.Condition, whereKeyInCondition) { + in.removedWhere = true + in.Condition = gstr.TrimLeftStr(in.Condition, whereKeyInCondition) } +} - // Sharding feature. - h.Schema, err = h.Model.getActualSchema(ctx, h.Schema) - if err != nil { - return nil, err +func (in *HookUpdateInput) restoreWhereForCommit() string { + if in.removedWhere { + return whereKeyInCondition + in.Condition } - h.Table, err = h.Model.getActualTable(ctx, h.Table) - if err != nil { - return nil, err + return in.Condition +} + +func (in *HookDeleteInput) restoreWhereForCommit() string { + if in.removedWhere { + return whereKeyInCondition + in.Condition } + return in.Condition +} - if h.handler != nil && !h.handlerCalled { - h.handlerCalled = true - if gstr.HasPrefix(h.Condition, whereKeyInCondition) { - h.removedWhere = true - h.Condition = gstr.TrimLeftStr(h.Condition, whereKeyInCondition) +func (in *HookSelectInput) doSelect(ctx context.Context) (result Result, err error) { + // Table change. + if in.Table != in.originalTableName.String() { + in.Sql, err = gregex.ReplaceStringFuncMatch( + `(?i) FROM ([\S]+)`, + in.Sql, + func(match []string) string { + charL, charR := in.Model.db.GetChars() + return fmt.Sprintf(` FROM %s%s%s`, charL, in.Table, charR) + }, + ) + if err != nil { + return nil, err } - return h.handler(ctx, h) } - if h.removedWhere { - h.Condition = whereKeyInCondition + h.Condition + + // Schema change. + if in.Schema != "" && in.Schema != in.originalSchemaName.String() { + in.link, err = in.Model.db.GetCore().SlaveLink(in.Schema) + if err != nil { + return nil, err + } + in.Model.db.GetCore().schema = in.Schema + defer func() { + in.Model.db.GetCore().schema = in.originalSchemaName.String() + }() } + return in.Model.db.DoSelect(ctx, in.link, in.Sql, in.Args...) +} - // No need to handle table change. +func (in *HookInsertInput) doInsert(ctx context.Context) (result sql.Result, err error) { + // Schema change. + if in.Schema != "" && in.Schema != in.originalSchemaName.String() { + in.link, err = in.Model.db.GetCore().MasterLink(in.Schema) + if err != nil { + return nil, err + } + in.Model.db.GetCore().schema = in.Schema + defer func() { + in.Model.db.GetCore().schema = in.originalSchemaName.String() + }() + } + return in.Model.db.DoInsert(ctx, in.link, in.Table, in.Data, in.Option) +} +func (in *HookUpdateInput) doUpdate(ctx context.Context) (result sql.Result, err error) { + condition := in.restoreWhereForCommit() // Schema change. - if h.Schema != "" && h.Schema != h.originalSchemaName.String() { - h.link, err = h.Model.db.GetCore().MasterLink(h.Schema) + if in.Schema != "" && in.Schema != in.originalSchemaName.String() { + in.link, err = in.Model.db.GetCore().MasterLink(in.Schema) if err != nil { - return + return nil, err } + in.Model.db.GetCore().schema = in.Schema + defer func() { + in.Model.db.GetCore().schema = in.originalSchemaName.String() + }() } - return h.Model.db.DoDelete(ctx, h.link, h.Table, h.Condition, h.Args...) + return in.Model.db.DoUpdate(ctx, in.link, in.Table, in.Data, condition, in.Args...) } -// Hook sets the hook functions for current model. -func (m *Model) Hook(hook HookHandler) *Model { - return m.Handler(func(ctx context.Context, model *Model) *Model { - model.hookHandler = hook - return model - }) +func (in *HookDeleteInput) doDelete(ctx context.Context) (result sql.Result, err error) { + condition := in.restoreWhereForCommit() + // Schema change. + if in.Schema != "" && in.Schema != in.originalSchemaName.String() { + in.link, err = in.Model.db.GetCore().MasterLink(in.Schema) + if err != nil { + return nil, err + } + in.Model.db.GetCore().schema = in.Schema + defer func() { + in.Model.db.GetCore().schema = in.originalSchemaName.String() + }() + } + return in.Model.db.DoDelete(ctx, in.link, in.Table, condition, in.Args...) } diff --git a/database/gdb/gdb_model_insert.go b/database/gdb/gdb_model_insert.go index 745bc4fac21..42c32a4826b 100644 --- a/database/gdb/gdb_model_insert.go +++ b/database/gdb/gdb_model_insert.go @@ -316,11 +316,8 @@ func (m *Model) doInsertWithOption(ctx context.Context, insertOption InsertOptio } in := &HookInsertInput{ - internalParamHookInsert: internalParamHookInsert{ - internalParamHook: internalParamHook{ - link: model.getLink(ctx, true), - }, - handler: model.hookHandler.Insert, + internalParamHook: internalParamHook{ + link: model.getLink(ctx, true), }, Model: model, Table: model.tables, @@ -328,7 +325,7 @@ func (m *Model) doInsertWithOption(ctx context.Context, insertOption InsertOptio Data: list, Option: doInsertOption, } - return in.Next(ctx) + return m.hookHandler.runInsert(ctx, in) } func (m *Model) formatDoInsertOption(insertOption InsertOption, columnNames []string) (option DoInsertOption, err error) { diff --git a/database/gdb/gdb_model_select.go b/database/gdb/gdb_model_select.go index af2c742db3f..e7bb309e01f 100644 --- a/database/gdb/gdb_model_select.go +++ b/database/gdb/gdb_model_select.go @@ -599,11 +599,8 @@ func (m *Model) doGetAllBySql( return } in := &HookSelectInput{ - internalParamHookSelect: internalParamHookSelect{ - internalParamHook: internalParamHook{ - link: m.getLink(ctx, false), - }, - handler: m.hookHandler.Select, + internalParamHook: internalParamHook{ + link: m.getLink(ctx, false), }, Model: m, Table: m.tables, @@ -612,7 +609,7 @@ func (m *Model) doGetAllBySql( Args: m.mergeArguments(args), SelectType: selectType, } - if result, err = in.Next(ctx); err != nil { + if result, err = m.hookHandler.runSelect(ctx, in); err != nil { return } diff --git a/database/gdb/gdb_model_update.go b/database/gdb/gdb_model_update.go index ad8ade20b87..3a313c6aa70 100644 --- a/database/gdb/gdb_model_update.go +++ b/database/gdb/gdb_model_update.go @@ -90,11 +90,8 @@ func (m *Model) Update(ctx context.Context) (result sql.Result, err error) { } in := &HookUpdateInput{ - internalParamHookUpdate: internalParamHookUpdate{ - internalParamHook: internalParamHook{ - link: model.getLink(ctx, true), - }, - handler: model.hookHandler.Update, + internalParamHook: internalParamHook{ + link: model.getLink(ctx, true), }, Model: model, Table: model.tables, @@ -103,7 +100,7 @@ func (m *Model) Update(ctx context.Context) (result sql.Result, err error) { Condition: conditionStr, Args: model.mergeArguments(conditionArgs), } - return in.Next(ctx) + return m.hookHandler.runUpdate(ctx, in) } // UpdateAndGetAffected performs update statement and returns the affected rows number. diff --git a/database/gdb/gdb_model_with.go b/database/gdb/gdb_model_with.go index 9e991b5ce93..41868b65e6c 100644 --- a/database/gdb/gdb_model_with.go +++ b/database/gdb/gdb_model_with.go @@ -161,7 +161,8 @@ func (m *Model) doWithScanStruct(ctx context.Context, pointer any) error { } } // Recursively with feature checks. - model = m.db.With(field.Value).Hook(m.hookHandler) + model = m.db.With(field.Value) + model.hookHandler = m.hookHandler.Clone() if m.withAll { model = model.WithAll() } else { @@ -287,7 +288,8 @@ func (m *Model) doWithScanStructs(ctx context.Context, pointer any) error { } } // Recursively with feature checks. - model = m.db.With(field.Value).Hook(m.hookHandler) + model = m.db.With(field.Value) + model.hookHandler = m.hookHandler.Clone() if m.withAll { model = model.WithAll() } else { diff --git a/database/gdb/gdb_z_mysql_internal_test.go b/database/gdb/gdb_z_mysql_internal_test.go index b8130bdd51f..48d5f3b6e6f 100644 --- a/database/gdb/gdb_z_mysql_internal_test.go +++ b/database/gdb/gdb_z_mysql_internal_test.go @@ -7,6 +7,7 @@ package gdb import ( + "context" "fmt" "testing" @@ -57,6 +58,70 @@ func Test_HookSelect_Regex(t *testing.T) { }) } +func Test_Multiple_Hooks(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var ( + beforeCalls []string + afterCalls []string + ) + + model := (&Model{}). + // One Hook call registers multiple handlers, in order. + Hook( + BeforeSelect(func(ctx context.Context, in *HookSelectInput) error { + beforeCalls = append(beforeCalls, "before1") + return nil + }), + BeforeSelect(func(ctx context.Context, in *HookSelectInput) error { + beforeCalls = append(beforeCalls, "before2") + return nil + }), + AfterSelect(func(ctx context.Context, in *HookSelectInput, result Result, err error) (Result, error) { + afterCalls = append(afterCalls, "after1") + return result, err + }), + ). + // Chained Hook calls append handlers in chain order. + Hook( + BeforeSelect(func(ctx context.Context, in *HookSelectInput) error { + beforeCalls = append(beforeCalls, "before3") + return nil + }), + AfterSelect(func(ctx context.Context, in *HookSelectInput, result Result, err error) (Result, error) { + afterCalls = append(afterCalls, "after2") + return result, err + }), + ). + Hook( + BeforeSelect(func(ctx context.Context, in *HookSelectInput) error { + beforeCalls = append(beforeCalls, "before4") + return nil + }), + AfterSelect(func(ctx context.Context, in *HookSelectInput, result Result, err error) (Result, error) { + afterCalls = append(afterCalls, "after3") + return result, err + }), + ) + + // Execute registered handlers without touching DB. + in := &HookSelectInput{Model: &Model{}} + for _, h := range model.hookHandler.selectBefore { + t.AssertNil(h(context.Background(), in)) + } + var ( + result Result = nil + err error = nil + ) + for _, h := range model.hookHandler.selectAfter { + result, err = h(context.Background(), in, result, err) + t.AssertNil(err) + } + + t.Assert(beforeCalls, []string{"before1", "before2", "before3", "before4"}) + t.Assert(afterCalls, []string{"after1", "after2", "after3"}) + }) +} + func Test_parseConfigNodeLink_WithType(t *testing.T) { gtest.C(t, func(t *gtest.T) { node := &ConfigNode{