diff --git a/cmd/config.go b/cmd/config.go index 22b9e541..807515ab 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -150,6 +150,10 @@ func setupFlags(rootCmd *cobra.Command) { rootCmd.PersistentFlags(). StringVar(&customRulesPathVar, customRulesFileFlagName, "", "Path to a custom rules file (JSON or YAML)."+ " Rules should be a list of ruledefine.Rule objects. --rule, --ignore-rule still apply to custom rules") + + rootCmd.PersistentFlags(). + BoolVar(&disableConsoleReportVar, disableConsoleReportFlagName, false, + "disable printing the report to the console. Other log messages are not affected") } func loadRulesFile(path string) ([]*ruledefine.Rule, error) { diff --git a/cmd/main.go b/cmd/main.go index 00bdaf77..b39d1e1e 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -35,17 +35,19 @@ const ( maxSecretSizeFlagName = "max-secret-size" validate = "validate" customRulesFileFlagName = "custom-rules-path" + disableConsoleReportFlagName = "disable-console-report" ) var ( - logLevelVar string - reportPathVar []string - stdoutFormatVar string - customRegexRuleVar []string - ignoreOnExitVar = ignoreOnExitNone - engineConfigVar engine.EngineConfig - validateVar bool - customRulesPathVar string + logLevelVar string + reportPathVar []string + stdoutFormatVar string + customRegexRuleVar []string + ignoreOnExitVar = ignoreOnExitNone + engineConfigVar engine.EngineConfig + validateVar bool + customRulesPathVar string + disableConsoleReportVar bool ) const envPrefix = "2MS" @@ -174,7 +176,7 @@ func postRun(engineInstance engine.IEngine) error { report := engineInstance.GetReport() if report.GetTotalItemsScanned() > 0 { - if zerolog.GlobalLevel() != zerolog.Disabled { + if !disableConsoleReportVar && zerolog.GlobalLevel() != zerolog.Disabled { if err := report.ShowReport(stdoutFormatVar, cfg); err != nil { return err } diff --git a/lib/reporting/report.go b/lib/reporting/report.go index b83cfa98..2b2cea0d 100644 --- a/lib/reporting/report.go +++ b/lib/reporting/report.go @@ -68,12 +68,31 @@ func (r *Report) WriteFile(reportPath []string, cfg *config.Config) error { fileExtension := filepath.Ext(path) format := strings.TrimPrefix(fileExtension, ".") + + // SARIF streams directly to the file to avoid holding the full + // serialized report in memory. + // TODO: apply streaming to all report types + if format == sarifFormat { + err = writeSarifToWriter(file, r, cfg) + if closeErr := file.Close(); closeErr != nil && err == nil { + err = closeErr + } + if err != nil { + return err + } + continue + } + output, err := r.GetOutput(format, cfg) if err != nil { + file.Close() return err } _, err = file.WriteString(output) + if closeErr := file.Close(); closeErr != nil && err == nil { + err = closeErr + } if err != nil { return err } diff --git a/lib/reporting/report_test.go b/lib/reporting/report_test.go index 84857dbd..0081bb67 100644 --- a/lib/reporting/report_test.go +++ b/lib/reporting/report_test.go @@ -2,6 +2,7 @@ package reporting import ( "encoding/json" + "fmt" "os" "path/filepath" "reflect" @@ -438,6 +439,171 @@ func TestGetOutputSarif(t *testing.T) { } } +func TestWriteSarifToWriter(t *testing.T) { + tests := []struct { + name string + arg *Report + want []Runs + }{ + { + name: "streaming_two_results_same_rule", + arg: &Report{ + TotalItemsScanned: 2, + TotalSecretsFound: 2, + Results: map[string][]*secrets.Secret{ + "secret1": {result1}, + "secret3": {result3}, + }, + }, + want: []Runs{ + { + Tool: Tool{ + Driver: Driver{ + Name: "report", + SemanticVersion: "1", + Rules: []*SarifRule{rule1Sarif}, + }, + }, + Results: []Results{result1Sarif, result3Sarif}, + }, + }, + }, + { + name: "streaming_empty_results", + arg: &Report{ + TotalItemsScanned: 0, + TotalSecretsFound: 0, + Results: map[string][]*secrets.Secret{}, + }, + want: []Runs{ + { + Tool: Tool{ + Driver: Driver{ + Name: "report", + SemanticVersion: "1", + }, + }, + Results: []Results{}, + }, + }, + }, + { + name: "streaming_includes_confluence_pageId", + arg: &Report{ + TotalItemsScanned: 1, + TotalSecretsFound: 1, + Results: map[string][]*secrets.Secret{ + "secret1": {result4}, + }, + }, + want: []Runs{ + { + Tool: Tool{ + Driver: Driver{ + Name: "report", + SemanticVersion: "1", + Rules: []*SarifRule{rule4Sarif}, + }, + }, + Results: []Results{result4Sarif}, + }, + }, + }, + } + + cfg := &config.Config{Name: "report", Version: "1"} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf strings.Builder + err := writeSarifToWriter(&buf, tt.arg, cfg) + assert.NoError(t, err) + + var gotReport Sarif + err = json.Unmarshal([]byte(buf.String()), &gotReport) + assert.NoError(t, err, "streaming output must be valid JSON") + SortSarifReports(&gotReport, &Sarif{Runs: tt.want}) + assert.Equal(t, tt.want, gotReport.Runs) + }) + } +} + +func TestWriteFileSarifUsesStreaming(t *testing.T) { + report := &Report{ + TotalItemsScanned: 2, + TotalSecretsFound: 2, + Results: map[string][]*secrets.Secret{ + "secret1": {result1}, + "secret2": {result2}, + }, + } + cfg := &config.Config{Name: "report", Version: "1"} + + tempDir := t.TempDir() + path := filepath.Join(tempDir, "report.sarif") + err := report.WriteFile([]string{path}, cfg) + assert.NoError(t, err) + + data, err := os.ReadFile(path) + assert.NoError(t, err) + + var gotReport Sarif + err = json.Unmarshal(data, &gotReport) + assert.NoError(t, err, "written sarif file must be valid JSON") + assert.Len(t, gotReport.Runs, 1) + assert.Len(t, gotReport.Runs[0].Results, 2) +} + +// errWriter is an io.Writer that always fails, used to test error propagation. +type errWriter struct{} + +func (errWriter) Write(_ []byte) (int, error) { + return 0, fmt.Errorf("write error") +} + +func TestWriteFileMultiplePaths(t *testing.T) { + report := &Report{ + TotalItemsScanned: 1, + TotalSecretsFound: 1, + Results: map[string][]*secrets.Secret{ + "secret1": {result1}, + }, + } + cfg := &config.Config{Name: "report", Version: "1"} + + tempDir := t.TempDir() + sarifPath := filepath.Join(tempDir, "out.sarif") + jsonPath := filepath.Join(tempDir, "out.json") + + err := report.WriteFile([]string{sarifPath, jsonPath}, cfg) + assert.NoError(t, err) + + sarifData, err := os.ReadFile(sarifPath) + assert.NoError(t, err) + var gotSarif Sarif + assert.NoError(t, json.Unmarshal(sarifData, &gotSarif), "sarif file must be valid JSON") + assert.Len(t, gotSarif.Runs[0].Results, 1) + + jsonData, err := os.ReadFile(jsonPath) + assert.NoError(t, err) + var gotReport Report + assert.NoError(t, json.Unmarshal(jsonData, &gotReport), "json file must be valid JSON") + assert.Equal(t, 1, gotReport.TotalSecretsFound) +} + +func TestWriteSarifToWriterError(t *testing.T) { + report := &Report{ + TotalItemsScanned: 1, + TotalSecretsFound: 1, + Results: map[string][]*secrets.Secret{ + "secret1": {result1}, + }, + } + cfg := &config.Config{Name: "report", Version: "1"} + + err := writeSarifToWriter(errWriter{}, report, cfg) + assert.Error(t, err) +} + // SortProject Sorts two sarif reports func SortSarifReports(run1, run2 *Sarif) { // Sort Rules diff --git a/lib/reporting/sarif.go b/lib/reporting/sarif.go index c3daabe2..bb233a1f 100644 --- a/lib/reporting/sarif.go +++ b/lib/reporting/sarif.go @@ -1,8 +1,10 @@ package reporting import ( + "bufio" "encoding/json" "fmt" + "io" "strings" "github.com/checkmarx/2ms/v5/lib/config" @@ -24,6 +26,91 @@ func writeSarif(report *Report, cfg *config.Config) (string, error) { return string(sarifReport), nil } +// writeSarifToWriter streams the SARIF report directly to w, marshaling each +// result individually so that only one result's JSON is in memory at a time. +// This eliminates the intermediate Sarif struct, full []byte, and string copies +// that writeSarif produces. +func writeSarifToWriter(w io.Writer, report *Report, cfg *config.Config) error { + bw := bufio.NewWriter(w) + + // Schema and version + bw.WriteString("{\n") + bw.WriteString(" \"$schema\": \"https://schemastore.azurewebsites.net/schemas/json/sarif-2.1.0-rtm.5.json\",\n") + bw.WriteString(" \"version\": \"2.1.0\",\n") + bw.WriteString(" \"runs\": [\n") + bw.WriteString(" {\n") + + // Tool section (small — safe to marshal in one shot) + tool := getTool(report, cfg) + toolJSON, err := json.MarshalIndent(tool, " ", " ") + if err != nil { + return fmt.Errorf("failed to marshal tool: %w", err) + } + bw.WriteString(" \"tool\": ") + bw.Write(toolJSON) + bw.WriteString(",\n") + + // Results — stream one at a time + bw.WriteString(" \"results\": [\n") + + if hasNoResults(report) { + // empty array body — just close it + } else { + first := true + for _, secretsSlice := range report.Results { + for _, secret := range secretsSlice { + if !first { + bw.WriteString(",\n") + } + result := buildSarifResult(secret) + resultJSON, err := json.MarshalIndent(result, " ", " ") + if err != nil { + return fmt.Errorf("failed to marshal result: %w", err) + } + bw.WriteString(" ") + bw.Write(resultJSON) + first = false + } + } + if !first { + bw.WriteString("\n") + } + } + + bw.WriteString(" ]\n") + bw.WriteString(" }\n") + bw.WriteString(" ]\n") + bw.WriteString("}\n") + + return bw.Flush() +} + +// buildSarifResult converts a single Secret into a SARIF Results object. +func buildSarifResult(secret *secrets.Secret) Results { + props := Properties{ + "validationStatus": secret.ValidationStatus, + "cvssScore": secret.CvssScore, + "resultId": secret.ID, + "severity": secret.Severity, + "ruleName": secret.RuleName, + } + + if secret.ExtraDetails != nil { + if pageID, ok := secret.ExtraDetails["confluence.pageId"]; ok { + props["confluence.pageId"] = pageID + } + } + + return Results{ + Message: Message{ + Text: createMessageText(secret.RuleName, secret.Source), + }, + RuleId: secret.RuleID, + Locations: getLocation(secret), + Properties: props, + } +} + func getRuns(report *Report, cfg *config.Config) []Runs { return []Runs{ { diff --git a/main.go b/main.go index 55b1de8f..99ae3241 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,8 @@ package main import ( + "net/http" + _ "net/http/pprof" "os" "os/signal" @@ -14,6 +16,14 @@ func main() { zerolog.SetGlobalLevel(zerolog.InfoLevel) log.Logger = utils.CreateLogger(zerolog.InfoLevel) + // Start pprof server for profiling + go func() { + log.Info().Msg("Starting pprof server on :6060") + if err := http.ListenAndServe(":6060", nil); err != nil { + log.Error().Err(err).Msg("pprof server failed") + } + }() + // this block sets up a go routine to listen for an interrupt signal // which will immediately exit gitleaks stopChan := make(chan os.Signal, 1)