diff --git a/contrib/drivers/mysql/testdata/with_tpl_detail_meta_soft_delete.sql b/contrib/drivers/mysql/testdata/with_tpl_detail_meta_soft_delete.sql new file mode 100644 index 00000000000..4209e74c7b7 --- /dev/null +++ b/contrib/drivers/mysql/testdata/with_tpl_detail_meta_soft_delete.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS %s ( + id int(10) unsigned NOT NULL AUTO_INCREMENT, + detail_id int(10) unsigned NOT NULL, + meta_key varchar(50) NOT NULL, + meta_value varchar(100) NOT NULL, + sort_order int(10) unsigned NOT NULL DEFAULT 0, + deleted_at datetime default NULL, + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; diff --git a/contrib/drivers/mysql/testdata/with_tpl_score_details_soft_delete.sql b/contrib/drivers/mysql/testdata/with_tpl_score_details_soft_delete.sql new file mode 100644 index 00000000000..f61319ea6a8 --- /dev/null +++ b/contrib/drivers/mysql/testdata/with_tpl_score_details_soft_delete.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS %s ( + id int(10) unsigned NOT NULL AUTO_INCREMENT, + score_id int(10) unsigned NOT NULL, + detail_info varchar(100) NOT NULL, + rank int(10) unsigned NOT NULL DEFAULT 0, + deleted_at datetime default NULL, + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; diff --git a/contrib/drivers/mysql/testdata/with_tpl_user_detail.sql b/contrib/drivers/mysql/testdata/with_tpl_user_detail.sql index 54bbccd090c..56fad63505d 100644 --- a/contrib/drivers/mysql/testdata/with_tpl_user_detail.sql +++ b/contrib/drivers/mysql/testdata/with_tpl_user_detail.sql @@ -1,5 +1,6 @@ CREATE TABLE IF NOT EXISTS %s ( uid int(10) unsigned NOT NULL AUTO_INCREMENT, address varchar(45) NOT NULL, + deleted_at datetime default NULL, PRIMARY KEY (uid) ) ENGINE=InnoDB DEFAULT CHARSET=utf8; \ No newline at end of file diff --git a/contrib/drivers/mysql/testdata/with_tpl_user_scores_soft_delete.sql b/contrib/drivers/mysql/testdata/with_tpl_user_scores_soft_delete.sql new file mode 100644 index 00000000000..c67229dd9f4 --- /dev/null +++ b/contrib/drivers/mysql/testdata/with_tpl_user_scores_soft_delete.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS %s ( + id int(10) unsigned NOT NULL AUTO_INCREMENT, + uid int(10) unsigned NOT NULL, + score int(10) unsigned NOT NULL, + priority int(10) unsigned NOT NULL DEFAULT 0, + deleted_at datetime default NULL, + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; diff --git a/contrib/drivers/mysql/testdata/with_tpl_user_soft_delete.sql b/contrib/drivers/mysql/testdata/with_tpl_user_soft_delete.sql new file mode 100644 index 00000000000..6733d79f831 --- /dev/null +++ b/contrib/drivers/mysql/testdata/with_tpl_user_soft_delete.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS %s ( + id int(10) unsigned NOT NULL AUTO_INCREMENT, + name varchar(45) NOT NULL, + status int(10) unsigned NOT NULL DEFAULT 1, + deleted_at datetime default NULL, + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index 437c981f8fe..809918c6262 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -66,6 +66,9 @@ const ( OrmTagForWithOrder = "order" OrmTagForWithUnscoped = "unscoped" OrmTagForDo = "do" + OrmTagForChunkName = "chunkName" + OrmTagForChunkSize = "chunkSize" + OrmTagForChunkMinRows = "chunkMinRows" ) var ( diff --git a/database/gdb/gdb_model.go b/database/gdb/gdb_model.go index 39b9fed3580..ee4c48dafba 100644 --- a/database/gdb/gdb_model.go +++ b/database/gdb/gdb_model.go @@ -17,45 +17,56 @@ import ( // Model is core struct implementing the DAO for ORM. type Model struct { - db DB // Underlying DB interface. - tx TX // Underlying TX interface. - rawSql string // rawSql is the raw SQL string which marks a raw SQL based Model not a table based Model. - schema string // Custom database schema. - linkType int // Mark for operation on master or slave. - tablesInit string // Table names when model initialization. - tables string // Operation table names, which can be more than one table names and aliases, like: "user", "user u", "user u, user_detail ud". - fields []any // Operation fields, multiple fields joined using char ','. - fieldsEx []any // Excluded operation fields, it here uses slice instead of string type for quick filtering. - withArray []any // Arguments for With feature. - withAll bool // Enable model association operations on all objects that have "with" tag in the struct. - extraArgs []any // Extra custom arguments for sql, which are prepended to the arguments before sql committed to underlying driver. - whereBuilder *WhereBuilder // Condition builder for where operation. - groupBy string // Used for "group by" statement. - orderBy string // Used for "order by" statement. - having []any // Used for "having..." statement. - start int // Used for "select ... start, limit ..." statement. - limit int // Used for "select ... start, limit ..." statement. - option int // Option for extra operation features. - offset int // Offset statement for some databases grammar. - partition string // Partition table partition name. - data any // Data for operation, which can be type of map/[]map/struct/*struct/string, etc. - batch int // Batch number for batch Insert/Replace/Save operations. - filter bool // Filter data and where key-value pairs according to the fields of the table. - distinct string // Force the query to only return distinct results. - lockInfo string // Lock for update or in shared lock. - cacheEnabled bool // Enable sql result cache feature, which is mainly for indicating cache duration(especially 0) usage. - cacheOption CacheOption // Cache option for query statement. - pageCacheOption []CacheOption // Cache option for paging query statement. - hookHandler HookHandler // Hook functions for model hook feature. - unscoped bool // Disables soft deleting features when select/delete operations. - safe bool // If true, it clones and returns a new model object whenever operation done; or else it changes the attribute of current model. - onDuplicate any // onDuplicate is used for on Upsert clause. - onDuplicateEx any // onDuplicateEx is used for excluding some columns on Upsert clause. - onConflict any // onConflict is used for conflict keys on Upsert clause. - tableAliasMap map[string]string // Table alias to true table name, usually used in join statements. - softTimeOption SoftTimeOption // SoftTimeOption is the option to customize soft time feature for Model. - shardingConfig ShardingConfig // ShardingConfig for database/table sharding feature. - shardingValue any // Sharding value for sharding feature. + db DB // Underlying DB interface. + tx TX // Underlying TX interface. + rawSql string // rawSql is the raw SQL string which marks a raw SQL based Model not a table based Model. + schema string // Custom database schema. + linkType int // Mark for operation on master or slave. + tablesInit string // Table names when model initialization. + tables string // Operation table names, which can be more than one table names and aliases, like: "user", "user u", "user u, user_detail ud". + fields []any // Operation fields, multiple fields joined using char ','. + fieldsEx []any // Excluded operation fields, it here uses slice instead of string type for quick filtering. + withArray []any // Arguments for With feature. + withAll bool // Enable model association operations on all objects that have "with" tag in the struct. + extraArgs []any // Extra custom arguments for sql, which are prepended to the arguments before sql committed to underlying driver. + whereBuilder *WhereBuilder // Condition builder for where operation. + groupBy string // Used for "group by" statement. + orderBy string // Used for "order by" statement. + having []any // Used for "having..." statement. + start int // Used for "select ... start, limit ..." statement. + limit int // Used for "select ... start, limit ..." statement. + option int // Option for extra operation features. + offset int // Offset statement for some databases grammar. + partition string // Partition table partition name. + data any // Data for operation, which can be type of map/[]map/struct/*struct/string, etc. + batch int // Batch number for batch Insert/Replace/Save operations. + filter bool // Filter data and where key-value pairs according to the fields of the table. + distinct string // Force the query to only return distinct results. + lockInfo string // Lock for update or in shared lock. + cacheEnabled bool // Enable sql result cache feature, which is mainly for indicating cache duration(especially 0) usage. + cacheOption CacheOption // Cache option for query statement. + pageCacheOption []CacheOption // Cache option for paging query statement. + hookHandler HookHandler // Hook functions for model hook feature. + unscoped bool // Disables soft deleting features when select/delete operations. + safe bool // If true, it clones and returns a new model object whenever operation done; or else it changes the attribute of current model. + onDuplicate any // onDuplicate is used for on Upsert clause. + onDuplicateEx any // onDuplicateEx is used for excluding some columns on Upsert clause. + onConflict any // onConflict is used for conflict keys on Upsert clause. + tableAliasMap map[string]string // Table alias to true table name, usually used in join statements. + softTimeOption SoftTimeOption // SoftTimeOption is the option to customize soft time feature for Model. + shardingConfig ShardingConfig // ShardingConfig for database/table sharding feature. + shardingValue any // Sharding value for sharding feature. + withOptions map[ChunkName]*WithOption // withOptions is the batch association configuration indexed by chunkName. +} + +// ChunkName is a type alias for chunk group identifier used in WithOptions. +type ChunkName = string + +// WithOption is the configuration for batch association operations. +type WithOption struct { + ChunkName ChunkName // ChunkName is used to match the chunkName in the tag (for grouping fields). + ChunkSize int // ChunkSize is the size of each chunk (0 means no chunking, -1 means use default). + ChunkMinRows int // ChunkMinRows is the minimum number of rows to trigger chunking (0 means always chunk, -1 means use default). } // ModelHandler is a function that handles given Model and returns a new Model that is custom modified. @@ -303,6 +314,13 @@ func (m *Model) Clone() *Model { newModel.having = make([]any, n) copy(newModel.having, m.having) } + if n := len(m.withOptions); n > 0 { + newModel.withOptions = make(map[ChunkName]*WithOption, n) + for k, v := range m.withOptions { + optCopy := *v + newModel.withOptions[k] = &optCopy + } + } return newModel } diff --git a/database/gdb/gdb_model_struct_cache.go b/database/gdb/gdb_model_struct_cache.go new file mode 100644 index 00000000000..caa5a40dad6 --- /dev/null +++ b/database/gdb/gdb_model_struct_cache.go @@ -0,0 +1,85 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gdb + +import ( + "reflect" + + "github.com/gogf/gf/v2/container/gmap" + "github.com/gogf/gf/v2/os/gstructs" +) + +// Global cache for struct information using gmap.KVMap. +// This cache is shared by both single-struct and batch With operations. +// Uses NewKVMapWithChecker to handle typed nil issue for pointer values. +var ( + structInfoChecker = func(v *modelStructCacheItem) bool { return v == nil } + structInfoCache = gmap.NewKVMapWithChecker[reflect.Type, *modelStructCacheItem](structInfoChecker, true) +) + +// modelStructCacheItem holds cached struct information. +// Only caches static information (field metadata from gstructs). +// Tag parsing is done dynamically to maintain flexibility. +// +// IMPORTANT: The Field.Value in cached fields is a zero-value instance. +// Only use Field.Field (StructField), Field.Type(), Field.Name(), Field.Tag() etc. +// DO NOT use Field.Value.Interface() to get actual values from real instances. +type modelStructCacheItem struct { + fields []gstructs.Field // All fields from gstructs (static metadata) +} + +// buildStructCacheItem creates a modelStructCacheItem from a struct type. +// It extracts field information using gstructs and builds the fields slice. +func buildStructCacheItem(structType reflect.Type) (*modelStructCacheItem, error) { + // Use gstructs to get field information + fieldMap, err := gstructs.FieldMap(gstructs.FieldMapInput{ + Pointer: reflect.New(structType).Interface(), + RecursiveOption: gstructs.RecursiveOptionEmbeddedNoTag, + }) + if err != nil { + return nil, err + } + + // Build cache item with fields slice + info := &modelStructCacheItem{ + fields: make([]gstructs.Field, 0, len(fieldMap)), + } + for _, field := range fieldMap { + info.fields = append(info.fields, field) + } + + return info, nil +} + +// getCachedStructInfo gets or creates cached struct information. +// It uses gmap.KVMap's GetOrSetFuncLock for thread-safe lazy initialization. +// This function is used by both single-struct and batch With operations. +func getCachedStructInfo(structType reflect.Type) (*modelStructCacheItem, error) { + // Use GetOrSetFuncLock to ensure thread-safe lazy initialization + // The function is only executed once per key, even under concurrent access + cached := structInfoCache.GetOrSetFuncLock(structType, func() *modelStructCacheItem { + info, err := buildStructCacheItem(structType) + if err != nil { + return nil // Return nil on error, will not be cached + } + return info + }) + + // Check if the cached value is nil (error case) + if cached == nil { + // Re-execute to get the actual error + return buildStructCacheItem(structType) + } + + return cached, nil +} + +// ClearModelStructCache clears the model struct cache. +// This is usually not needed unless you're dynamically loading many struct types. +func ClearModelStructCache() { + structInfoCache.Clear() +} diff --git a/database/gdb/gdb_model_with.go b/database/gdb/gdb_model_with.go index eaea5300b4d..7f59148925a 100644 --- a/database/gdb/gdb_model_with.go +++ b/database/gdb/gdb_model_with.go @@ -8,6 +8,7 @@ package gdb import ( "database/sql" + "errors" "reflect" "github.com/gogf/gf/v2/errors/gcode" @@ -15,7 +16,7 @@ import ( "github.com/gogf/gf/v2/internal/utils" "github.com/gogf/gf/v2/os/gstructs" "github.com/gogf/gf/v2/text/gstr" - "github.com/gogf/gf/v2/util/gutil" + "github.com/gogf/gf/v2/util/gconv" ) // With creates and returns an ORM model based on metadata of given object. @@ -65,6 +66,28 @@ func (m *Model) WithAll() *Model { return model } +// WithOptions sets the batch association configuration options. +// It matches fields by chunkName and allows runtime override of chunk settings. +// Multiple options can be provided to configure different chunkName groups. +func (m *Model) WithOptions(options ...WithOption) *Model { + model := m.getModel() + if model.withOptions == nil { + model.withOptions = make(map[ChunkName]*WithOption) + } + + for _, opt := range options { + // Skip empty chunkName + if opt.ChunkName == "" { + continue + } + // Store a copy of the option + optCopy := opt + model.withOptions[opt.ChunkName] = &optCopy + } + + return model +} + // doWithScanStruct handles model association operations feature for single struct. func (m *Model) doWithScanStruct(pointer any) error { if len(m.withArray) == 0 && !m.withAll { @@ -103,21 +126,21 @@ func (m *Model) doWithScanStruct(pointer any) error { } for _, field := range currentStructFieldMap { var ( - fieldTypeStr = gstr.TrimAll(field.Type().String(), "*[]") - parsedTagOutput = m.parseWithTagInFieldStruct(field) + fieldTypeStr = gstr.TrimAll(field.Type().String(), "*[]") + withTag = parseWithTag(field) ) - if parsedTagOutput.With == "" { + if withTag.With == "" { continue } // It just handlers "with" type attribute struct, so it ignores other struct types. if !m.withAll && !gstr.InArray(allowedTypeStrArray, fieldTypeStr) { continue } - array := gstr.SplitAndTrim(parsedTagOutput.With, "=") + array := gstr.SplitAndTrim(withTag.With, "=") if len(array) == 1 { // It also supports using only one column name // if both tables associates using the same column name. - array = append(array, parsedTagOutput.With) + array = append(array, withTag.With) } var ( model *Model @@ -137,7 +160,7 @@ func (m *Model) doWithScanStruct(pointer any) error { return gerror.NewCodef( gcode.CodeInvalidParameter, `cannot find the target related value of name "%s" in with tag "%s" for attribute "%s.%s"`, - relatedTargetName, parsedTagOutput.With, reflect.TypeOf(pointer).Elem(), field.Name(), + relatedTargetName, withTag.With, reflect.TypeOf(pointer).Elem(), field.Name(), ) } bindToReflectValue := field.Value @@ -163,13 +186,13 @@ func (m *Model) doWithScanStruct(pointer any) error { } else { model = model.With(m.withArray...) } - if parsedTagOutput.Where != "" { - model = model.Where(parsedTagOutput.Where) + if withTag.Where != "" { + model = model.Where(withTag.Where) } - if parsedTagOutput.Order != "" { - model = model.Order(parsedTagOutput.Order) + if withTag.Order != "" { + model = model.Order(withTag.Order) } - if parsedTagOutput.Unscoped == "true" { + if withTag.Unscoped == "true" { model = model.Unscoped() } // With cache feature. @@ -180,7 +203,7 @@ func (m *Model) doWithScanStruct(pointer any) error { Where(relatedSourceName, relatedTargetValue). Scan(bindToReflectValue) // It ignores sql.ErrNoRows in with feature. - if err != nil && err != sql.ErrNoRows { + if err != nil && !errors.Is(err, sql.ErrNoRows) { return err } } @@ -196,154 +219,80 @@ func (m *Model) doWithScanStructs(pointer any) error { if v, ok := pointer.(reflect.Value); ok { pointer = v.Interface() } - - var ( - err error - allowedTypeStrArray = make([]string, 0) - ) - currentStructFieldMap, err := gstructs.FieldMap(gstructs.FieldMapInput{ - Pointer: pointer, - PriorityTagArray: nil, - RecursiveOption: gstructs.RecursiveOptionEmbeddedNoTag, - }) - if err != nil { - return err - } - // It checks the with array and automatically calls the ScanList to complete association querying. - if !m.withAll { - for _, field := range currentStructFieldMap { - for _, withItem := range m.withArray { - withItemReflectValueType, err := gstructs.StructType(withItem) - if err != nil { - return err - } - var ( - fieldTypeStr = gstr.TrimAll(field.Type().String(), "*[]") - withItemReflectValueTypeStr = gstr.TrimAll(withItemReflectValueType.String(), "*[]") - ) - // It does select operation if the field type is in the specified with type array. - if gstr.Compare(fieldTypeStr, withItemReflectValueTypeStr) == 0 { - allowedTypeStrArray = append(allowedTypeStrArray, fieldTypeStr) - } - } - } - } - - for fieldName, field := range currentStructFieldMap { - var ( - fieldTypeStr = gstr.TrimAll(field.Type().String(), "*[]") - parsedTagOutput = m.parseWithTagInFieldStruct(field) - ) - if parsedTagOutput.With == "" { - continue - } - if !m.withAll && !gstr.InArray(allowedTypeStrArray, fieldTypeStr) { - continue - } - array := gstr.SplitAndTrim(parsedTagOutput.With, "=") - if len(array) == 1 { - // It supports using only one column name - // if both tables associates using the same column name. - array = append(array, parsedTagOutput.With) - } - var ( - model *Model - fieldKeys []string - relatedSourceName = array[0] - relatedTargetName = array[1] - relatedTargetValue any - ) - // Find the value slice of related attribute from `pointer`. - for attributeName := range currentStructFieldMap { - if utils.EqualFoldWithoutChars(attributeName, relatedTargetName) { - relatedTargetValue = ListItemValuesUnique(pointer, attributeName) - break - } - } - if relatedTargetValue == nil { - return gerror.NewCodef( - gcode.CodeInvalidParameter, - `cannot find the related value for attribute name "%s" of with tag "%s"`, - relatedTargetName, parsedTagOutput.With, - ) - } - // If related value is empty, it does nothing but just returns. - if gutil.IsEmpty(relatedTargetValue) { - return nil - } - if structFields, err := gstructs.Fields(gstructs.FieldsInput{ - Pointer: field.Value, - RecursiveOption: gstructs.RecursiveOptionEmbeddedNoTag, - }); err != nil { - return err - } else { - fieldKeys = make([]string, len(structFields)) - for i, field := range structFields { - fieldKeys[i] = field.Name() - } - } - // Recursively with feature checks. - model = m.db.With(field.Value).Hook(m.hookHandler) - if m.withAll { - model = model.WithAll() - } else { - model = model.With(m.withArray...) - } - if parsedTagOutput.Where != "" { - model = model.Where(parsedTagOutput.Where) - } - if parsedTagOutput.Order != "" { - model = model.Order(parsedTagOutput.Order) - } - if parsedTagOutput.Unscoped == "true" { - model = model.Unscoped() - } - // With cache feature. - if m.cacheEnabled && m.cacheOption.Name == "" { - model = model.Cache(m.cacheOption) - } - err = model.Fields(fieldKeys). - Where(relatedSourceName, relatedTargetValue). - ScanList(pointer, fieldName, parsedTagOutput.With) - // It ignores sql.ErrNoRows in with feature. - if err != nil && err != sql.ErrNoRows { - return err - } - } - return nil + return m.doBatchWithScan(pointer) } -type parseWithTagInFieldStructOutput struct { +// withTagConfig holds the basic ORM tag configuration for a relation field. +type withTagConfig struct { With string Where string Order string Unscoped string } -func (m *Model) parseWithTagInFieldStruct(field gstructs.Field) (output parseWithTagInFieldStructOutput) { +// chunkTagConfig holds the chunk-related ORM tag configuration for a relation field. +type chunkTagConfig struct { + ChunkName string + ChunkSize int + ChunkMinRows int + Chunked bool +} + +// parseWithTag parses the basic ORM tag configuration (with, where, order, unscoped) from a struct field. +func parseWithTag(field gstructs.Field) withTagConfig { + ormTag := field.Tag(OrmTagForStruct) + data := parseOrmTagData(ormTag) + return withTagConfig{ + With: data[OrmTagForWith], + Where: data[OrmTagForWithWhere], + Order: data[OrmTagForWithOrder], + Unscoped: data[OrmTagForWithUnscoped], + } +} + +// parseChunkTag parses the chunk-related ORM tag configuration from a struct field. +func parseChunkTag(field gstructs.Field) chunkTagConfig { + ormTag := field.Tag(OrmTagForStruct) + data := parseOrmTagData(ormTag) + + chunkName := data[OrmTagForChunkName] + _, ifChunkSize := data[OrmTagForChunkSize] + _, ifChunkMinRows := data[OrmTagForChunkMinRows] + chunkSize := gconv.Int(data[OrmTagForChunkSize]) + chunkMinRows := gconv.Int(data[OrmTagForChunkMinRows]) + + return chunkTagConfig{ + ChunkName: chunkName, + ChunkSize: chunkSize, + ChunkMinRows: chunkMinRows, + Chunked: ifChunkSize && ifChunkMinRows && chunkSize > 0 && chunkMinRows > 0, + } +} + +// parseOrmTagData parses the raw ORM tag string into a key-value map. +func parseOrmTagData(ormTag string) map[string]string { var ( - ormTag = field.Tag(OrmTagForStruct) - data = make(map[string]string) - array []string - key string + data = make(map[string]string) + array []string + key string ) for _, v := range gstr.SplitAndTrim(ormTag, ",") { + v = gstr.Trim(v) + if v == "" { + continue + } array = gstr.Split(v, ":") if len(array) == 2 { key = array[0] data[key] = gstr.Trim(array[1]) } else { - if key == OrmTagForWithOrder { - // supporting multiple order fields + switch key { + case OrmTagForWithOrder: data[key] += "," + gstr.Trim(v) - } else { + default: data[key] += " " + gstr.Trim(v) } } } - output.With = data[OrmTagForWith] - output.Where = data[OrmTagForWithWhere] - output.Order = data[OrmTagForWithOrder] - output.Unscoped = data[OrmTagForWithUnscoped] - return + return data } diff --git a/database/gdb/gdb_model_with_batch.go b/database/gdb/gdb_model_with_batch.go new file mode 100644 index 00000000000..f861b65c349 --- /dev/null +++ b/database/gdb/gdb_model_with_batch.go @@ -0,0 +1,643 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gdb + +import ( + "database/sql" + "errors" + "reflect" + + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/internal/utils" + "github.com/gogf/gf/v2/os/gstructs" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" +) + +// relationFieldInfo holds information about a relation field in a struct. +// It caches parsed information to avoid repeated reflection and parsing. +type relationFieldInfo struct { + Field gstructs.Field // Field information from gstructs + WithTag withTagConfig // Basic ORM tag configuration (with, where, order, unscoped) + ChunkTag chunkTagConfig // Chunk-related tag configuration + isSliceType bool // Cached: whether the field is a slice type + sourceField string // Cached: source field name in the related table (DB column) + targetField string // Cached: target field name in the current struct +} + +// isSlice returns whether this relation field is a slice type (one-to-many). +func (r *relationFieldInfo) isSlice() bool { + return r.isSliceType +} + +// newRelationFieldInfo creates a new relationFieldInfo with pre-computed cache values. +func newRelationFieldInfo(field gstructs.Field, withTag withTagConfig, chunkTag chunkTagConfig) *relationFieldInfo { + info := &relationFieldInfo{ + Field: field, + WithTag: withTag, + ChunkTag: chunkTag, + } + + // Pre-compute slice type + kind := field.Type().Kind() + if kind == reflect.Pointer { + elem := field.Type().Elem() + if elem != nil { + kind = elem.Kind() + } + } + info.isSliceType = kind == reflect.Slice || kind == reflect.Array + + // Pre-parse relation field pair from "with" tag + // Format: "source_field=target_field" or just "field_name" (same name in both tables) + parts := gstr.SplitAndTrim(withTag.With, "=") + if len(parts) == 1 { + // Same field name in both tables + info.sourceField = parts[0] + info.targetField = parts[0] + } else if len(parts) >= 2 { + // Different field names + info.sourceField = parts[0] + info.targetField = parts[1] + } + + return info +} + +// withScanContext holds the context for recursive batch association operations. +// It tracks visited types to detect circular references using backtracking algorithm. +type withScanContext struct { + model *Model // Parent Model (reuses its db, hook, cache configurations) + visitedTypes map[string]bool // Visited types for circular reference detection (backtracking) + allRelations []*relationFieldInfo // All relation fields (for chunkName group lookup) +} + +// batchQueryResult holds the result of a batch query for a relation field. +type batchQueryResult struct { + FieldName string // Name of the relation field + DataMap map[string]Result // Map from relation key (as string) to query results + Error error // Query error if any +} + +// doBatchWithScan is the entry point for batch association scanning. +// It performs batch recursive scanning for association operations to solve the N+1 problem. +func (m *Model) doBatchWithScan(pointer any) error { + ctx := &withScanContext{ + model: m, + visitedTypes: make(map[string]bool), + } + return ctx.doRecursiveWithScan(pointer) +} + +// doRecursiveWithScan performs recursive batch association operations on the given pointer. +// It collects all relation fields, executes batch queries, maps results, and recursively processes nested relations. +// Circular references are detected using a backtracking algorithm with visitedTypes map. +func (p *withScanContext) doRecursiveWithScan(pointer any) error { + // 1. Get element type and check for circular references + sliceValue := reflect.ValueOf(pointer) + if sliceValue.Kind() != reflect.Pointer { + return gerror.NewCode(gcode.CodeInvalidParameter, "pointer must be a pointer to slice") + } + sliceValue = sliceValue.Elem() + if sliceValue.Kind() != reflect.Slice { + return gerror.NewCode(gcode.CodeInvalidParameter, "pointer must be a pointer to slice") + } + + // Empty slice, nothing to do + if sliceValue.Len() == 0 { + return nil + } + + // Get element type: []*Struct -> Struct + elemType := sliceValue.Type().Elem() + if elemType.Kind() == reflect.Pointer { + elemType = elemType.Elem() + } + + // Check for circular reference using backtracking + // This allows A->B->C->A structure as long as they're not in the same path + typeName := elemType.String() + if p.visitedTypes[typeName] { + // Already visiting this type in the current path, skip to avoid infinite loop + return nil + } + p.visitedTypes[typeName] = true + defer delete(p.visitedTypes, typeName) // Backtrack: remove from visited when returning + + // 2. Collect relation fields + relations, err := p.collectRelations(pointer) + if err != nil { + return err + } + if len(relations) == 0 { + return nil // No relations to process + } + + // Store all relations for chunkName group lookup + p.allRelations = relations + + // 3. Batch query all relation fields (sequential execution, no goroutines) + batchResults := make(map[string]*batchQueryResult) + for _, relation := range relations { + result := p.queryRelation(pointer, relation) + batchResults[relation.Field.Name()] = result + if result.Error != nil { + return result.Error + } + } + + // 4. Map results to struct fields + if err := p.mapResults(pointer, relations, batchResults); err != nil { + return err + } + + // 5. Recursively process next level + for _, relation := range relations { + if err := p.doRecursiveWithScanNext(pointer, relation); err != nil { + return err + } + } + + return nil +} + +// collectRelations collects all relation fields from the struct that should be batch-processed. +// It uses struct cache to avoid repeated reflection operations. +func (p *withScanContext) collectRelations(pointer any) ([]*relationFieldInfo, error) { + // Get slice value + sliceValue := reflect.ValueOf(pointer).Elem() + if sliceValue.Len() == 0 { + return nil, nil + } + + // Get element type + elemType := sliceValue.Type().Elem() + if elemType.Kind() == reflect.Pointer { + elemType = elemType.Elem() + } + + // Get cached struct info (only caches gstructs reflection results) + cached, err := getCachedStructInfo(elemType) + if err != nil { + return nil, err + } + + // Iterate fields and parse tags (tag parsing is done every time, not cached) + var relations []*relationFieldInfo + for _, field := range cached.fields { + withTag := parseWithTag(field) + if withTag.With == "" { + continue // No "with" tag, skip + } + + // Check if this field should be processed + if !p.shouldProcess(field) { + continue + } + + chunkTag := parseChunkTag(field) + + // Create relationFieldInfo (not cached) + relation := newRelationFieldInfo(field, withTag, chunkTag) + relations = append(relations, relation) + } + + return relations, nil +} + +// shouldProcess checks if a field should be processed based on Model configuration. +func (p *withScanContext) shouldProcess(field gstructs.Field) bool { + // WithAll mode: all fields with "with" tag are allowed + if p.model.withAll { + return true + } + + // With mode: check if field type is in withArray + fieldTypeStr := gstr.TrimAll(field.Type().String(), "*[]") + for _, withItem := range p.model.withArray { + withItemType, err := gstructs.StructType(withItem) + if err != nil { + continue + } + withItemTypeStr := gstr.TrimAll(withItemType.String(), "*[]") + if gstr.Compare(fieldTypeStr, withItemTypeStr) == 0 { + return true + } + } + + return false +} + +// getChunkConfig returns the chunk configuration for a relation field. +// It follows the priority: API config (by chunkName) > Tag config > chunkName group config. +// +// Design principle: Chunking is opt-in, not default. +// - If no chunk config is provided, use single batch query (no chunking) +// - Only when explicitly configured (Chunked=true or API override), enable chunking +// +// Value semantics: +// - In Tag: Only when both chunkSize and chunkMinRows are configured and > 0, enable chunking +// - In API: -1 means use tag/group config, 0 means disable chunking, >0 means enable with that value +func (p *withScanContext) getChunkConfig(relation *relationFieldInfo) (chunkSize, chunkMinRows int) { + chunkName := relation.ChunkTag.ChunkName + chunkSize = -1 + chunkMinRows = -1 + + // Priority 1: API configuration (matched by chunkName) + if chunkName != "" && p.model.withOptions != nil { + if config, ok := p.model.withOptions[chunkName]; ok { + // API config found, use it + chunkSize = config.ChunkSize + chunkMinRows = config.ChunkMinRows + + // If both are explicitly set (not -1), use them directly + if chunkSize != -1 && chunkMinRows != -1 { + // chunkSize=0 means disable chunking + if chunkSize == 0 { + return 0, 0 + } + return chunkSize, chunkMinRows + } + + // If only one is set, need to get the other from tag/group + // Continue to priority 2 + } + } + + // Priority 2: Tag configuration (only if Chunked=true, meaning both are configured) + if relation.ChunkTag.Chunked { + // Both chunkSize and chunkMinRows are configured in tag + if chunkSize == -1 { + chunkSize = relation.ChunkTag.ChunkSize + } + if chunkMinRows == -1 { + chunkMinRows = relation.ChunkTag.ChunkMinRows + } + return chunkSize, chunkMinRows + } + + // Priority 3: ChunkName group config (look for other fields with same chunkName) + if chunkName != "" { + for _, rel := range p.allRelations { + if rel.ChunkTag.ChunkName == chunkName && rel != relation && rel.ChunkTag.Chunked { + // Found a field with same chunkName that has chunk config + if chunkSize == -1 { + chunkSize = rel.ChunkTag.ChunkSize + } + if chunkMinRows == -1 { + chunkMinRows = rel.ChunkTag.ChunkMinRows + } + if chunkSize > 0 && chunkMinRows > 0 { + return chunkSize, chunkMinRows + } + } + } + } + + // No chunk config found, disable chunking (use single batch query) + return 0, 0 +} + +// queryRelation executes a batch query for a single relation field. +// It collects all unique relation keys and performs a single WHERE IN query. +func (p *withScanContext) queryRelation(pointer any, relation *relationFieldInfo) *batchQueryResult { + result := &batchQueryResult{ + FieldName: relation.Field.Name(), + DataMap: make(map[string]Result), + } + + // 1. Collect unique relation key values + sliceValue := reflect.ValueOf(pointer).Elem() + if sliceValue.Len() == 0 { + return result + } + + // Get the first item to find field names + firstItem := sliceValue.Index(0) + if firstItem.Kind() == reflect.Pointer { + firstItem = firstItem.Elem() + } + + // Use cached struct info to find field name + cached, err := getCachedStructInfo(firstItem.Type()) + if err != nil { + result.Error = err + return result + } + + // Find the actual field name that matches relation.targetField (case-insensitive) + var actualFieldName string + for _, field := range cached.fields { + if utils.EqualFoldWithoutChars(field.Name(), relation.targetField) { + actualFieldName = field.Name() + break + } + } + + if actualFieldName == "" { + return result + } + + targetValues := ListItemValuesUnique(pointer, actualFieldName) + if len(targetValues) == 0 { + return result // No values to query + } + + // 2. Build query model + // Use the field value to get the correct table name from ORM metadata + fieldValue := relation.Field.Value + if fieldValue.Kind() == reflect.Pointer { + // For pointer types, create a new instance to get metadata + elemType := fieldValue.Type().Elem() + fieldValue = reflect.New(elemType) + } else if fieldValue.Kind() == reflect.Slice { + // For slice types, get the element type + elemType := fieldValue.Type().Elem() + if elemType.Kind() == reflect.Pointer { + elemType = elemType.Elem() + } + fieldValue = reflect.New(elemType) + } + + model := p.model.db.Model(fieldValue.Interface()) + model = model.Hook(p.model.hookHandler) + + // Apply tag conditions + if relation.WithTag.Where != "" { + model = model.Where(relation.WithTag.Where) + } + if relation.WithTag.Order != "" { + model = model.Order(relation.WithTag.Order) + } + if relation.WithTag.Unscoped == "true" { + model = model.Unscoped() + } + + // Apply cache if enabled + if p.model.cacheEnabled && p.model.cacheOption.Name == "" { + model = model.Cache(p.model.cacheOption) + } + + // 3. Get chunk configuration (API > Tag > chunkName group) + chunkSize, chunkMinRows := p.getChunkConfig(relation) + + // Determine if chunking is needed + shouldChunk := chunkSize > 0 && len(targetValues) >= chunkMinRows + + // 4. Execute batch query with WHERE IN (with optional chunking) + var records Result + if shouldChunk { + // Execute chunked queries + for i := 0; i < len(targetValues); i += chunkSize { + end := i + chunkSize + if end > len(targetValues) { + end = len(targetValues) + } + chunkValues := targetValues[i:end] + + // IMPORTANT: Clone the model for each chunk to avoid accumulating WHERE conditions + chunkModel := model.Clone() + chunkRecords, err := chunkModel.Where(relation.sourceField, chunkValues).All() + if err != nil && !errors.Is(err, sql.ErrNoRows) { + result.Error = err + return result + } + records = append(records, chunkRecords...) + } + } else { + // Execute single query + records, result.Error = model.Where(relation.sourceField, targetValues).All() + if result.Error != nil && !errors.Is(result.Error, sql.ErrNoRows) { + return result + } + } + result.Error = nil // Ignore ErrNoRows (relation data may not exist) + + // 5. Build map grouped by relation key (use string as key to avoid type mismatch) + for _, record := range records { + key := gconv.String(record[relation.sourceField].Interface()) + result.DataMap[key] = append(result.DataMap[key], record) + } + + return result +} + +// mapResults maps batch query results to struct fields. +func (p *withScanContext) mapResults( + pointer any, + relations []*relationFieldInfo, + batchResults map[string]*batchQueryResult, +) error { + sliceValue := reflect.ValueOf(pointer).Elem() + if sliceValue.Len() == 0 { + return nil + } + + firstItem := sliceValue.Index(0) + if firstItem.Kind() == reflect.Pointer { + firstItem = firstItem.Elem() + } + cached, err := getCachedStructInfo(firstItem.Type()) + if err != nil { + return err + } + + for i := 0; i < sliceValue.Len(); i++ { + item := sliceValue.Index(i) + if item.Kind() == reflect.Pointer { + item = item.Elem() + } + + for _, relation := range relations { + // Get relation key value from current item - need to use actual field name + var actualTargetFieldName string + for _, field := range cached.fields { + if utils.EqualFoldWithoutChars(field.Name(), relation.targetField) { + actualTargetFieldName = field.Name() + break + } + } + + if actualTargetFieldName == "" { + continue + } + + targetField := item.FieldByName(actualTargetFieldName) + if !targetField.IsValid() { + continue + } + targetValue := targetField.Interface() + targetValueStr := gconv.String(targetValue) + + // Get corresponding query results (use string key to avoid type mismatch) + records := batchResults[relation.Field.Name()].DataMap[targetValueStr] + if len(records) == 0 { + continue // No related data + } + + // Map to field + fieldValue := item.FieldByName(relation.Field.Name()) + if !fieldValue.IsValid() || !fieldValue.CanSet() { + continue + } + + if relation.isSlice() { + // Slice type: map all records (one-to-many) + if err := gconv.Scan(records, fieldValue.Addr().Interface()); err != nil { + return err + } + } else { + // Single type: map only first record (one-to-one) + // For pointer fields, we need to create a new instance first + if fieldValue.Kind() == reflect.Pointer { + // Create new instance of the pointer's element type + elemType := fieldValue.Type().Elem() + newElem := reflect.New(elemType) + if err := gconv.Scan(records[0], newElem.Interface()); err != nil { + return err + } + fieldValue.Set(newElem) + } else { + // For non-pointer fields, scan directly + if err := gconv.Scan(records[0], fieldValue.Addr().Interface()); err != nil { + return err + } + } + } + } + } + + return nil +} + +// doRecursiveWithScanNext recursively processes the next level of relations. +func (p *withScanContext) doRecursiveWithScanNext(pointer any, relation *relationFieldInfo) error { + sliceValue := reflect.ValueOf(pointer).Elem() + + if relation.isSlice() { + // For slice type relations, collect all child slices and merge them into one big slice + // This allows batch processing of all nested records together + + // Get the element type of the slice + var sliceElemType reflect.Type + for i := 0; i < sliceValue.Len(); i++ { + item := sliceValue.Index(i) + if item.Kind() == reflect.Pointer { + item = item.Elem() + } + + fieldValue := item.FieldByName(relation.Field.Name()) + if !fieldValue.IsValid() || fieldValue.IsZero() { + continue + } + + if fieldValue.Kind() == reflect.Pointer { + fieldValue = fieldValue.Elem() + } + + if fieldValue.Kind() == reflect.Slice && fieldValue.Len() > 0 { + sliceElemType = fieldValue.Type().Elem() + break + } + } + + if sliceElemType == nil { + return nil // No valid slice found + } + + // Create a merged slice to hold all child records + // IMPORTANT: We need to keep references to the original slices so that + // modifications to the merged slice will be reflected in the original data + mergedSliceType := reflect.SliceOf(sliceElemType) + mergedSlice := reflect.MakeSlice(mergedSliceType, 0, sliceValue.Len()*10) // Pre-allocate with estimated capacity + + // Collect all child records from all parent records + // We append the actual slice elements (which are pointers), so modifications will be reflected + for i := 0; i < sliceValue.Len(); i++ { + item := sliceValue.Index(i) + if item.Kind() == reflect.Pointer { + item = item.Elem() + } + + fieldValue := item.FieldByName(relation.Field.Name()) + if !fieldValue.IsValid() || fieldValue.IsZero() { + continue + } + + if fieldValue.Kind() == reflect.Pointer { + fieldValue = fieldValue.Elem() + } + + if fieldValue.Kind() == reflect.Slice && fieldValue.Len() > 0 { + // Append all elements from this child slice to the merged slice + // Since the elements are pointers, modifications to them will be reflected in the original slice + for j := 0; j < fieldValue.Len(); j++ { + mergedSlice = reflect.Append(mergedSlice, fieldValue.Index(j)) + } + } + } + + // If we have collected records, recursively process them as one batch + if mergedSlice.Len() > 0 { + // Create a pointer to the merged slice + mergedSlicePtr := reflect.New(mergedSliceType) + mergedSlicePtr.Elem().Set(mergedSlice) + + if err := p.doRecursiveWithScan(mergedSlicePtr.Interface()); err != nil { + return err + } + + // IMPORTANT: Since we're working with pointers, the modifications made by doRecursiveWithScan + // are automatically reflected in the original parent slices. No need to copy back. + } + } else { + // For single type relations, collect all non-nil values into a temporary slice + // Get the element type of the pointer field + fieldType := relation.Field.Type().Type + if fieldType.Kind() != reflect.Pointer { + return nil // Not a pointer field, skip + } + + // Create a slice to hold all non-nil pointer values + sliceType := reflect.SliceOf(fieldType) + tempSlice := reflect.MakeSlice(sliceType, 0, sliceValue.Len()) + + // Collect all non-nil pointer values + for i := 0; i < sliceValue.Len(); i++ { + item := sliceValue.Index(i) + if item.Kind() == reflect.Pointer { + item = item.Elem() + } + + fieldValue := item.FieldByName(relation.Field.Name()) + if !fieldValue.IsValid() || fieldValue.IsZero() { + continue + } + + // Append the pointer value directly (it's already a pointer) + if fieldValue.Kind() == reflect.Pointer && !fieldValue.IsNil() { + tempSlice = reflect.Append(tempSlice, fieldValue) + } + } + + // If we have collected values, recursively process them + if tempSlice.Len() > 0 { + // Create a pointer to the temporary slice + tempSlicePtr := reflect.New(sliceType) + tempSlicePtr.Elem().Set(tempSlice) + + if err := p.doRecursiveWithScan(tempSlicePtr.Interface()); err != nil { + return err + } + + // IMPORTANT: Since we're working with pointers, the modifications are automatically reflected + } + } + + return nil +}