Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,13 @@ func SetFileName(name string) {

filename = name

// Validate file extension before creating the file
err := validateConfigFileExtension(filename)
checkErr(err)

file.CreateEmptyIfNotExists(filename)
configureViper(filename)
err = configureViper(filename)
checkErr(err)
}

func SetFileNameForTest(t *testing.T) {
Expand Down
35 changes: 34 additions & 1 deletion internal/config/viper.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ package config

import (
"bytes"
"path/filepath"
"strings"

"github.com/microsoft/go-sqlcmd/internal/localizer"
"github.com/microsoft/go-sqlcmd/internal/pal"
"github.com/spf13/viper"
"gopkg.in/yaml.v2"
Expand Down Expand Up @@ -56,16 +60,45 @@ func GetConfigFileUsed() string {
return viper.ConfigFileUsed()
}

// validateConfigFileExtension checks if the config file has a supported extension.
// It allows .yaml, .yml, and no extension (for default sqlconfig file).
// Returns an error if the extension is not supported.
func validateConfigFileExtension(configFile string) error {
ext := strings.ToLower(filepath.Ext(configFile))

// Allow no extension (for default sqlconfig file)
if ext == "" {
return nil
}

// Allow .yaml and .yml extensions
if ext == ".yaml" || ext == ".yml" {
return nil
}

// Return error for unsupported extensions
return localizer.Errorf(
"Configuration files must use YAML format with .yaml or .yml extension.\n"+
"The file '%s' has an unsupported extension '%s'.",
Comment thread
dlevy-msft-sql marked this conversation as resolved.
Outdated
configFile, ext)
}

// configureViper initializes the Viper library with the given configuration file.
// This function sets the configuration file type to "yaml" and sets the environment variable prefix to "SQLCMD".
// It also sets the configuration file to use to the one provided as an argument to the function.
// This function is intended to be called at the start of the application to configure Viper before any other code uses it.
func configureViper(configFile string) {
func configureViper(configFile string) error {
if configFile == "" {
panic("Must provide configFile")
}

// Validate file extension
if err := validateConfigFileExtension(configFile); err != nil {
return err
}

viper.SetConfigType("yaml")
viper.SetEnvPrefix("SQLCMD")
viper.SetConfigFile(configFile)
return nil
}
94 changes: 94 additions & 0 deletions internal/config/viper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,100 @@ func Test_configureViper(t *testing.T) {
})
}

func Test_validateConfigFileExtension(t *testing.T) {
tests := []struct {
name string
filename string
wantErr bool
}{
{
name: "valid yaml extension",
filename: "config.yaml",
wantErr: false,
},
{
name: "valid yml extension",
filename: "config.yml",
wantErr: false,
},
{
name: "no extension (default sqlconfig)",
filename: "sqlconfig",
wantErr: false,
},
{
name: "no extension with path",
filename: "/home/user/.sqlcmd/sqlconfig",
wantErr: false,
},
{
name: "invalid txt extension",
filename: "config.txt",
wantErr: true,
},
{
name: "invalid json extension",
filename: "config.json",
wantErr: true,
},
{
name: "invalid xml extension",
filename: "config.xml",
wantErr: true,
},
{
name: "uppercase YAML extension",
filename: "config.YAML",
wantErr: false,
},
{
name: "uppercase YML extension",
filename: "config.YML",
wantErr: false,
},
{
name: "mixed case yaml extension",
filename: "config.Yaml",
wantErr: false,
},
}
Comment thread
dlevy-msft-sql marked this conversation as resolved.

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateConfigFileExtension(tt.filename)
if tt.wantErr {
assert.Error(t, err, "Expected error for filename: %s", tt.filename)
assert.Contains(t, err.Error(), "Configuration files must use YAML format")
} else {
assert.NoError(t, err, "Expected no error for filename: %s", tt.filename)
}
})
}
}

func Test_configureViper_withInvalidExtension(t *testing.T) {
err := configureViper("myconfig.txt")
assert.Error(t, err)
assert.Contains(t, err.Error(), "Configuration files must use YAML format")
assert.Contains(t, err.Error(), ".txt")
}

func Test_configureViper_withValidExtensions(t *testing.T) {
testCases := []string{
"config.yaml",
"config.yml",
"sqlconfig",
"/path/to/config.yaml",
}

for _, filename := range testCases {
t.Run(filename, func(t *testing.T) {
err := configureViper(filename)
assert.NoError(t, err, "Expected no error for filename: %s", filename)
})
}
}

func Test_Load(t *testing.T) {
SetFileNameForTest(t)
Clean()
Expand Down
Loading