Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions internal/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,35 @@ func (g *generator) ProcessResult(ctx context.Context, combo config.CombinedSett
// out is specified by the user, not a plugin
absout := filepath.Join(g.dir, out)

// When the Go codegen is configured to emit the models file into a
// separate package directory, route that file to its own absolute path.
// This is the only file allowed to live outside of `out`.
var (
modelsFileName string
modelsAbsout string
modelsAbsfile string
)
if sql.Gen.Go != nil && sql.Gen.Go.OutputModelsPath != "" && sql.Gen.Go.ModelsEmitEnabled() {
modelsFileName = sql.Gen.Go.OutputModelsFileName
if modelsFileName == "" {
modelsFileName = "models.go"
}
modelsAbsout = filepath.Join(g.dir, sql.Gen.Go.OutputModelsPath)
modelsAbsfile = filepath.Join(modelsAbsout, modelsFileName)
}

for n, source := range files {
if modelsFileName != "" && n == modelsFileName {
// Models file routed to a separate package directory.
if strings.Contains(modelsAbsfile, "..") {
return fmt.Errorf("invalid file output path: %s", modelsAbsfile)
}
if !strings.HasPrefix(modelsAbsfile, modelsAbsout) {
return fmt.Errorf("invalid file output path: %s", modelsAbsfile)
}
g.output[modelsAbsfile] = source
continue
}
filename := filepath.Join(g.dir, out, n)
// filepath.Join calls filepath.Clean which should remove all "..", but
// double check to make sure
Expand Down
48 changes: 31 additions & 17 deletions internal/codegen/golang/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@ import (
)

type tmplCtx struct {
Q string
Package string
SQLDriver opts.SQLDriver
Enums []Enum
Structs []Struct
GoQueries []Query
SqlcVersion string
Q string
Package string
ModelsPackage string
SQLDriver opts.SQLDriver
Enums []Enum
Structs []Struct
GoQueries []Query
SqlcVersion string

// TODO: Race conditions
SourceName string
Expand Down Expand Up @@ -120,13 +121,13 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat

enums := buildEnums(req, options)
structs := buildStructs(req, options)
queries, err := buildQueries(req, options, structs)
queries, err := buildQueries(req, options, enums, structs)
if err != nil {
return nil, err
}

if options.OmitUnusedStructs {
enums, structs = filterUnusedStructs(enums, structs, queries)
enums, structs = filterUnusedStructs(enums, structs, queries, options.ModelsTypeQualifier())
}

if err := validate(options, enums, structs, queries); err != nil {
Expand Down Expand Up @@ -186,6 +187,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
SQLDriver: parseDriver(options.SqlPackage),
Q: "`",
Package: options.Package,
ModelsPackage: options.ModelsPackage(),
Enums: enums,
Structs: structs,
SqlcVersion: req.SqlcVersion,
Expand Down Expand Up @@ -292,8 +294,10 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
if err := execute(dbFileName, "dbFile"); err != nil {
return nil, err
}
if err := execute(modelsFileName, "modelsFile"); err != nil {
return nil, err
if options.ModelsEmitEnabled() {
if err := execute(modelsFileName, "modelsFile"); err != nil {
return nil, err
}
}
if options.EmitInterface {
if err := execute(querierFileName, "interfaceFile"); err != nil {
Expand Down Expand Up @@ -367,25 +371,35 @@ func checkNoTimesForMySQLCopyFrom(queries []Query) error {
return nil
}

func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query) ([]Enum, []Struct) {
func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query, qualifier string) ([]Enum, []Struct) {
keepTypes := make(map[string]struct{})

keep := func(t string) {
keepTypes[t] = struct{}{}
// Also store the bare type name so that lookups against
// bare struct/enum names match even when types have been
// qualified with the models package prefix (e.g. "model.User").
if bare := stripQualifier(t, qualifier); bare != t {
keepTypes[bare] = struct{}{}
}
}

for _, query := range queries {
if !query.Arg.isEmpty() {
keepTypes[query.Arg.Type()] = struct{}{}
keep(query.Arg.Type())
if query.Arg.IsStruct() {
for _, field := range query.Arg.Struct.Fields {
keepTypes[field.Type] = struct{}{}
keep(field.Type)
}
}
}
if query.hasRetType() {
keepTypes[query.Ret.Type()] = struct{}{}
keep(query.Ret.Type())
if query.Ret.IsStruct() {
for _, field := range query.Ret.Struct.Fields {
keepTypes[strings.TrimPrefix(field.Type, "[]")] = struct{}{}
keep(strings.TrimPrefix(field.Type, "[]"))
for _, embedField := range field.EmbedFields {
keepTypes[embedField.Type] = struct{}{}
keep(embedField.Type)
}
}
}
Expand Down
11 changes: 11 additions & 0 deletions internal/codegen/golang/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,17 @@ func buildImports(options *opts.Options, queries []Query, uses func(string) bool
}
}

// Models package import. When models live in a separate Go package and
// any type in this file references a qualified model type, import the
// models package under a fixed `models` alias so query files always
// reference types as `models.User` regardless of how the actual
// package is named.
if options.ModelsAreExternal() {
if uses(options.ModelsTypeQualifier()) {
pkg[ImportSpec{Path: options.OutputModelsImport, ID: opts.ModelsImportAlias}] = struct{}{}
}
}

return std, pkg
}

Expand Down
85 changes: 85 additions & 0 deletions internal/codegen/golang/opts/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ type Options struct {
OutputBatchFileName string `json:"output_batch_file_name,omitempty" yaml:"output_batch_file_name"`
OutputDbFileName string `json:"output_db_file_name,omitempty" yaml:"output_db_file_name"`
OutputModelsFileName string `json:"output_models_file_name,omitempty" yaml:"output_models_file_name"`
OutputModelsPath string `json:"output_models_path,omitempty" yaml:"output_models_path"`
OutputModelsPackage string `json:"output_models_package,omitempty" yaml:"output_models_package"`
OutputModelsImport string `json:"output_models_import,omitempty" yaml:"output_models_import"`
OutputModelsEmit *bool `json:"output_models_emit,omitempty" yaml:"output_models_emit"`
OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_querier_file_name"`
OutputCopyfromFileName string `json:"output_copyfrom_file_name,omitempty" yaml:"output_copyfrom_file_name"`
OutputFilesSuffix string `json:"output_files_suffix,omitempty" yaml:"output_files_suffix"`
Expand Down Expand Up @@ -94,6 +98,17 @@ func parseOpts(req *plugin.GenerateRequest) (*Options, error) {
}
}

// Default the models package name to the base of the models path. When
// the user only configures output_models_emit: false (no path), fall
// back to the base of the import path.
if options.OutputModelsPackage == "" {
if options.OutputModelsPath != "" {
options.OutputModelsPackage = filepath.Base(options.OutputModelsPath)
} else if options.OutputModelsImport != "" {
options.OutputModelsPackage = filepath.Base(options.OutputModelsImport)
}
}

for i := range options.Overrides {
if err := options.Overrides[i].parse(req); err != nil {
return nil, err
Expand Down Expand Up @@ -154,5 +169,75 @@ func ValidateOpts(opts *Options) error {
return fmt.Errorf("invalid options: query parameter limit must not be negative")
}

if err := validateModelsOptions(opts); err != nil {
return err
}

return nil
}

// ModelsEmitEnabled reports whether this codegen block should write the
// models file. Defaults to true when the option is unset.
func (o *Options) ModelsEmitEnabled() bool {
if o.OutputModelsEmit == nil {
return true
}
return *o.OutputModelsEmit
}

// ModelsImportAlias is the fixed Go import alias used for the models
// package in query files. Using a constant alias keeps the type qualifier
// consistent regardless of how the user names the actual package.
const ModelsImportAlias = "models"

// ModelsPackage returns the Go package name to use in the models file
// itself (i.e. the `package X` declaration). When the caller has not
// configured a separate models package, this is the same as Package.
func (o *Options) ModelsPackage() string {
if o.OutputModelsPackage != "" {
return o.OutputModelsPackage
}
return o.Package
}

// ModelsAreExternal reports whether model types live in a different Go
// package than the queries package. When true, query files must import the
// models package and reference types as `models.Type`.
func (o *Options) ModelsAreExternal() bool {
return o.OutputModelsImport != ""
}

// ModelsTypeQualifier returns the prefix to use when referencing a model
// type from a query file ("models."). Empty string when no qualifier is
// needed.
func (o *Options) ModelsTypeQualifier() string {
if o.ModelsAreExternal() {
return ModelsImportAlias + "."
}
return ""
}

func validateModelsOptions(opts *Options) error {
hasAnyModelsOpt := opts.OutputModelsPath != "" ||
opts.OutputModelsPackage != "" ||
opts.OutputModelsImport != "" ||
opts.OutputModelsEmit != nil

if !hasAnyModelsOpt {
return nil
}

if opts.OutputModelsImport == "" {
return fmt.Errorf("invalid options: output_models_import is required when any output_models_* option is set")
}

if opts.ModelsEmitEnabled() && opts.OutputModelsPath == "" {
return fmt.Errorf("invalid options: output_models_path is required when emitting models to a separate package")
}

if opts.ModelsEmitEnabled() && opts.OutputModelsPath == opts.Out {
return fmt.Errorf("invalid options: output_models_path matches out; models would overwrite the queries package")
}

return nil
}
79 changes: 79 additions & 0 deletions internal/codegen/golang/qualify.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package golang

import "strings"

// stripQualifier removes a leading slice/pointer prefix and the given
// `pkg.` qualifier from a Go type expression. When the qualifier is empty
// or absent from the type, the input is returned unchanged.
func stripQualifier(t, qualifier string) string {
if qualifier == "" {
return t
}
prefix := ""
rest := t
for {
if strings.HasPrefix(rest, "[]") {
prefix += "[]"
rest = rest[2:]
continue
}
if strings.HasPrefix(rest, "*") {
prefix += "*"
rest = rest[1:]
continue
}
break
}
if strings.HasPrefix(rest, qualifier) {
return prefix + rest[len(qualifier):]
}
return t
}

// modelTypeSet is the set of Go type names that live in the models file.
type modelTypeSet map[string]struct{}

// buildModelTypeSet returns the set of type names that are declared in
// models.go for the current codegen invocation.
func buildModelTypeSet(enums []Enum, structs []Struct) modelTypeSet {
set := make(modelTypeSet, len(enums)*4+len(structs))
for _, e := range enums {
set[e.Name] = struct{}{}
set["Null"+e.Name] = struct{}{}
}
for _, s := range structs {
if s.IsModel {
set[s.Name] = struct{}{}
}
}
return set
}

// qualifyType prefixes a Go type expression with `qualifier` when the bare
// type name belongs to `models`. Slice and pointer prefixes are preserved.
// When qualifier is empty (i.e. models live in the queries package), the
// input is returned unchanged.
func qualifyType(t string, models modelTypeSet, qualifier string) string {
if qualifier == "" || t == "" || len(models) == 0 {
return t
}
prefix := ""
rest := t
for {
if strings.HasPrefix(rest, "[]") {
prefix += "[]"
rest = rest[2:]
continue
}
if strings.HasPrefix(rest, "*") {
prefix += "*"
rest = rest[1:]
continue
}
break
}
if _, ok := models[rest]; ok {
return prefix + qualifier + rest
}
return t
}
12 changes: 12 additions & 0 deletions internal/codegen/golang/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ type QueryValue struct {
Typ string
SQLDriver opts.SQLDriver

// ModelQualifier prefixes references to model types when the models file
// lives in a different Go package (e.g. "model."). Empty otherwise.
ModelQualifier string

// Column is kept so late in the generation process around to differentiate
// between mysql slices and pg arrays
Column *plugin.Column
Expand Down Expand Up @@ -88,6 +92,14 @@ func (v QueryValue) Type() string {
return v.Typ
}
if v.Struct != nil {
// Model structs (table-derived) live in the models file. When that
// file is generated into a different Go package, references from
// query files must be qualified. Synthetic structs (Params/Row)
// are defined in the same query file as their use, so they stay
// bare.
if v.Struct.IsModel && v.ModelQualifier != "" {
return v.ModelQualifier + v.Struct.Name
}
return v.Struct.Name
}
panic("no type for QueryValue: " + v.Name)
Expand Down
Loading
Loading