From 10c256a1ad660b89f2681a316eb6e6965c68ca2d Mon Sep 17 00:00:00 2001 From: thelovekesh Date: Tue, 26 May 2026 22:47:50 +0530 Subject: [PATCH 01/11] Fix docs url --- cli/cobra.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/cobra.go b/cli/cobra.go index 3a26a09..8ce4827 100644 --- a/cli/cobra.go +++ b/cli/cobra.go @@ -50,7 +50,7 @@ func setupCommonRootCommand(rootCmd *cobra.Command) (*cliflags.ClientOptions, *c rootCmd.PersistentFlags().Lookup("help").Hidden = true rootCmd.Annotations = map[string]string{ - "additionalHelp": "For more help on how to use wpm, see https://docs.wpm.so", + "additionalHelp": "For more help on how to use wpm, see https://wpm.so/docs", } return opts, helpCommand From 67b13ece4816e1c8b836235db6148bd87cda2958 Mon Sep 17 00:00:00 2001 From: thelovekesh Date: Tue, 26 May 2026 22:49:17 +0530 Subject: [PATCH 02/11] Update ctx handling in cli bootstrapping --- cmd/wpm/wpm.go | 99 ++++++++++++-------------------------------------- 1 file changed, 23 insertions(+), 76 deletions(-) diff --git a/cmd/wpm/wpm.go b/cmd/wpm/wpm.go index 7e9aa58..92c54b3 100644 --- a/cmd/wpm/wpm.go +++ b/cmd/wpm/wpm.go @@ -4,10 +4,8 @@ import ( "context" "errors" "fmt" - "io" "os" "os/signal" - "syscall" "time" "github.com/morikuni/aec" @@ -23,62 +21,44 @@ import ( platformsignals "go.wpm.so/cli/cmd/wpm/internal/signals" ) -type errCtxSignalTerminated struct { - signal os.Signal -} - -func (errCtxSignalTerminated) Error() string { - return "" -} +// exitCodeInterrupted is the exit code used when the process is interrupted by a signal (e.g., Ctrl-C). +const exitCodeInterrupted = 130 func main() { - err := wpmMain(context.Background()) - if _, ok := errors.AsType[errCtxSignalTerminated](err); ok { - os.Exit(getExitCode(err)) - } - - if err != nil && !errors.Is(err, context.Canceled) { - if err.Error() != "" { - _, _ = fmt.Fprintln(os.Stderr, err) - } - os.Exit(getExitCode(err)) - } + os.Exit(run()) } -func notifyContext(ctx context.Context, signals ...os.Signal) (context.Context, context.CancelFunc) { - ch := make(chan os.Signal, 1) - signal.Notify(ch, signals...) - - ctxCause, cancel := context.WithCancelCause(ctx) +func run() int { + ctx, stop := signal.NotifyContext(context.Background(), platformsignals.TerminationSignals...) + defer stop() go func() { - select { - case <-ctx.Done(): - signal.Stop(ch) - return - case sig := <-ch: - cancel(errCtxSignalTerminated{ - signal: sig, - }) - signal.Stop(ch) - return - } + <-ctx.Done() + stop() }() - return ctxCause, func() { - signal.Stop(ch) - cancel(nil) + err := wpmMain(ctx) + if err == nil { + return 0 + } + + if errors.Is(err, context.Canceled) { + return exitCodeInterrupted } + + if err.Error() != "" { + _, _ = fmt.Fprintln(os.Stderr, err) + } + + return getExitCode(err) } func wpmMain(ctx context.Context) error { - ctx, cancelNotify := notifyContext(ctx, platformsignals.TerminationSignals...) - defer cancelNotify() - wpmCli, err := command.NewWpmCli() if err != nil { return err } + log.Logger = zerolog.New(zerolog.ConsoleWriter{ Out: wpmCli.Err(), NoColor: !wpmCli.Err().IsColorEnabled(), @@ -96,14 +76,6 @@ func getExitCode(err error) int { return 0 } - if userTerminatedErr, ok := errors.AsType[errCtxSignalTerminated](err); ok { - s, ok := userTerminatedErr.signal.(syscall.Signal) - if !ok { - return 1 - } - return 128 + int(s) - } - if stErr, ok := errors.AsType[cli.StatusError](err); ok && stErr.StatusCode != 0 { return stErr.StatusCode } @@ -189,25 +161,6 @@ func newWpmCommand(wpmCli *command.WpmCli) *cli.TopLevelCommand { return cli.NewTopLevelCommand(cmd, wpmCli, opts, cmd.Flags()) } -// forceExitAfter3TerminationSignals waits for the first termination signal -// to be caught and the context to be marked as done, then registers a new -// signal handler for subsequent signals. It forces the process to exit -// after 3 SIGTERM/SIGINT signals. -func forceExitAfter3TerminationSignals(ctx context.Context, w io.Writer) { - // wait for the first signal to be caught and the context to be marked as done - <-ctx.Done() - // register a new signal handler for subsequent signals - sig := make(chan os.Signal, 2) - signal.Notify(sig, platformsignals.TerminationSignals...) - - // once we have received a total of 3 signals we force exit the cli - for i := 0; i < 2; i++ { - <-sig - } - _, _ = fmt.Fprint(w, "\ngot 3 SIGTERM/SIGINTs, forcefully exiting\n") - os.Exit(1) -} - func setupHelpCommand(helpCmd *cobra.Command) { origRun := helpCmd.Run origRunE := helpCmd.RunE @@ -234,14 +187,8 @@ func runWpm(ctx context.Context, wpmCli *command.WpmCli) error { return err } - // This is a fallback for the case where the command does not exit - // based on context cancellation. - go forceExitAfter3TerminationSignals(ctx, wpmCli.Err()) - // We've parsed global args already, so reset args to those // which remain. cmd.SetArgs(args) - err = cmd.ExecuteContext(ctx) - - return err + return cmd.ExecuteContext(ctx) } From c204861b2a7414b41f53a9deb76ba49634c1fde2 Mon Sep 17 00:00:00 2001 From: thelovekesh Date: Tue, 26 May 2026 23:44:49 +0530 Subject: [PATCH 03/11] Fix context cancellation issues in pkg/api --- cmd/wpm/wpm.go | 2 +- pkg/api/cache.go | 20 ++++++++++++++------ pkg/api/errors.go | 10 ++++++++-- pkg/api/http_client.go | 12 +++++++++++- pkg/api/rest_client.go | 39 +++++++++++++++------------------------ pkg/pm/registry/client.go | 16 ++++++++++++---- 6 files changed, 61 insertions(+), 38 deletions(-) diff --git a/cmd/wpm/wpm.go b/cmd/wpm/wpm.go index 92c54b3..2e80128 100644 --- a/cmd/wpm/wpm.go +++ b/cmd/wpm/wpm.go @@ -42,7 +42,7 @@ func run() int { return 0 } - if errors.Is(err, context.Canceled) { + if ctx.Err() != nil || errors.Is(err, context.Canceled) { return exitCodeInterrupted } diff --git a/pkg/api/cache.go b/pkg/api/cache.go index b55efcf..e43dd7d 100644 --- a/pkg/api/cache.go +++ b/pkg/api/cache.go @@ -1,6 +1,7 @@ package api import ( + "context" "crypto/sha256" "encoding/binary" "encoding/hex" @@ -142,7 +143,7 @@ func (t *Transport) executeRequest(req *http.Request, finalPath string, force bo } if resp.StatusCode == http.StatusOK { - resp.Body = t.write(resp.Body, finalPath, resp.Header) + resp.Body = t.write(req.Context(), resp.Body, finalPath, resp.Header) } return resp, nil @@ -227,7 +228,7 @@ func readCacheEnvelope(f *os.File) (http.Header, int64, int64, error) { return m.Headers, bodyStart, stat.Size() - 4, nil } -func (t *Transport) write(src io.ReadCloser, finalPath string, h http.Header) io.ReadCloser { +func (t *Transport) write(ctx context.Context, src io.ReadCloser, finalPath string, h http.Header) io.ReadCloser { tmpDir := filepath.Join(t.cacheDir, "tmp") if err := os.MkdirAll(tmpDir, 0o750); err != nil { return src @@ -247,6 +248,7 @@ func (t *Transport) write(src io.ReadCloser, finalPath string, h http.Header) io } return &writer{ + ctx: ctx, src: src, dst: f, tmp: f.Name(), @@ -286,6 +288,7 @@ func (*Transport) writeMeta(w io.Writer, h http.Header) error { } type writer struct { + ctx context.Context src io.ReadCloser dst *os.File tmp string @@ -328,22 +331,27 @@ func (w *writer) Close() error { return srcErr } - if err := renameFile(w.tmp, w.final); err != nil { + if err := renameFile(w.ctx, w.tmp, w.final); err != nil { _ = os.Remove(w.tmp) } return srcErr } // renameFile retries on transient Windows errors (file locked by reader, -// AV scanner). Exponential backoff up to ~3.5s. -func renameFile(src, dst string) error { +// AV scanner). Exponential backoff up to ~3.5s, abandoned early on ctx +// cancel so SIGINT mid-retry is honored. +func renameFile(ctx context.Context, src, dst string) error { const maxAttempts = 8 backoff := 25 * time.Millisecond var err error for i := range maxAttempts { if i > 0 { - time.Sleep(backoff) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(backoff): + } if backoff < time.Second { backoff *= 2 } diff --git a/pkg/api/errors.go b/pkg/api/errors.go index 6b0d525..27dcb13 100644 --- a/pkg/api/errors.go +++ b/pkg/api/errors.go @@ -8,6 +8,10 @@ import ( "strings" ) +// maxErrorBodySize caps how much of an error response we'll read into memory +// before parsing. +const maxErrorBodySize = 1 << 18 // 256 KiB + // HTTPError represents an error response from the wpm API. type HTTPError struct { Message string @@ -23,7 +27,9 @@ func (err *HTTPError) Error() string { return "wpm registry error: " + strings.ToLower(err.Message) } -// HandleHTTPError parses a http.Response into a HTTPError. +// HandleHTTPError parses a http.Response into a HTTPError. The response body +// is not closed here and the caller owns the response and must close it. We do, +// however, fully consume the body so the connection stays reusable. func HandleHTTPError(resp *http.Response) error { httpError := &HTTPError{ Headers: resp.Header, @@ -36,7 +42,7 @@ func HandleHTTPError(resp *http.Response) error { return httpError } - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodySize)) if err != nil { httpError.Message = err.Error() return httpError diff --git a/pkg/api/http_client.go b/pkg/api/http_client.go index a1c5bc1..1d09eb6 100644 --- a/pkg/api/http_client.go +++ b/pkg/api/http_client.go @@ -62,6 +62,14 @@ func NewHTTPClient(opts ClientOptions) (*http.Client, error) { } } + // Sweep stale cache tmp files left behind by aborted requests. Runs in + // the background so a slow filesystem doesn't delay the first request; + // it's idempotent across concurrent invocations and safe to outlive the + // process (the goroutine dies with the CLI either way). + if opts.CacheDir != "" { + go func() { _ = CleanupStale(opts.CacheDir) }() + } + transport := &Transport{ Base: &http.Transport{ MaxIdleConns: 100, @@ -210,7 +218,9 @@ func (d decompressingRoundTripper) RoundTrip(req *http.Request) (*http.Response, decoder := zstdDecoderPool.Get().(*zstd.Decoder) if err := decoder.Reset(resp.Body); err != nil { _ = resp.Body.Close() - zstdDecoderPool.Put(decoder) + // Reset left the decoder in an unknown state hence we discard it instead of putting it back in the pool. + // The next request needing a decoder will allocate a new one to replace it, so we don't leak resources by doing this. + decoder.Close() return nil, fmt.Errorf("failed to reset zstd reader: %w", err) } diff --git a/pkg/api/rest_client.go b/pkg/api/rest_client.go index f65d580..70cbed2 100644 --- a/pkg/api/rest_client.go +++ b/pkg/api/rest_client.go @@ -64,7 +64,10 @@ func (c *RESTClient) DoWithContext(ctx context.Context, method, path string, bod if err != nil { return err } - defer func() { _ = resp.Body.Close() }() + defer func() { + drainBody(resp.Body) + _ = resp.Body.Close() + }() if resp.StatusCode < 200 || resp.StatusCode >= 300 { return HandleHTTPError(resp) @@ -111,35 +114,23 @@ func (c *RESTClient) RequestStream(ctx context.Context, method, path string, bod } if resp.StatusCode < 200 || resp.StatusCode >= 300 { - defer func() { _ = resp.Body.Close() }() + defer func() { + drainBody(resp.Body) + _ = resp.Body.Close() + }() return nil, HandleHTTPError(resp) } return resp.Body, nil } -func (c *RESTClient) Do(method, path string, body io.Reader, response any, opts ...RequestOption) error { - return c.DoWithContext(context.Background(), method, path, body, response, opts...) -} - -func (c *RESTClient) Delete(path string, resp any, opts ...RequestOption) error { - return c.Do(http.MethodDelete, path, nil, resp, opts...) -} - -func (c *RESTClient) Get(path string, resp any, opts ...RequestOption) error { - return c.Do(http.MethodGet, path, nil, resp, opts...) -} - -func (c *RESTClient) Patch(path string, body io.Reader, resp any, opts ...RequestOption) error { - return c.Do(http.MethodPatch, path, body, resp, opts...) -} - -func (c *RESTClient) Post(path string, body io.Reader, resp any, opts ...RequestOption) error { - return c.Do(http.MethodPost, path, body, resp, opts...) -} - -func (c *RESTClient) Put(path string, body io.Reader, resp any, opts ...RequestOption) error { - return c.Do(http.MethodPut, path, body, resp, opts...) +// drainBody discards any unread bytes from body so the underlying TCP +// connection can be reused from the idle pool instead of being torn down. +// +// The drain is capped so a hostile or buggy server cannot keep us reading +// indefinitely. +func drainBody(body io.Reader) { + _, _ = io.Copy(io.Discard, io.LimitReader(body, 64<<10)) } func restURL(hostname, pathOrURL string) string { diff --git a/pkg/pm/registry/client.go b/pkg/pm/registry/client.go index e1d425e..9081c7e 100644 --- a/pkg/pm/registry/client.go +++ b/pkg/pm/registry/client.go @@ -62,7 +62,9 @@ func (c *client) PutPackage(ctx context.Context, data *manifest.Package, tarball return err } - return c.restClient.Put( + return c.restClient.DoWithContext( + ctx, + http.MethodPut, "/", tarball, nil, @@ -84,8 +86,11 @@ func (c *client) GetPackageManifest(ctx context.Context, packageName, versionOrT header = api.HeaderCacheRevalidate } - err := c.restClient.Get( + err := c.restClient.DoWithContext( + ctx, + http.MethodGet, "/"+packageName+"/"+versionOrTag, + nil, &pkg, api.WithHeader(header, "true"), // Used by cache round tripper. api.WithHeader(api.HeaderAccept, wpmContentTypeManifestV1), @@ -118,7 +123,7 @@ func (c *client) Whoami(ctx context.Context, token string) (string, error) { opts = append(opts, api.WithHeader(api.HeaderAuthorization, "Bearer "+token)) } - if err := c.restClient.Get("/-/whoami", &response, opts...); err != nil { + if err := c.restClient.DoWithContext(ctx, http.MethodGet, "/-/whoami", nil, &response, opts...); err != nil { return "", err } @@ -129,8 +134,11 @@ func (c *client) Whoami(ctx context.Context, token string) (string, error) { func (c *client) GetKeysJson(ctx context.Context) (signatures.KeysJson, error) { var keys signatures.KeysJson - err := c.restClient.Get( + err := c.restClient.DoWithContext( + ctx, + http.MethodGet, "/keys.json", + nil, &keys, ) if err != nil { From 762df3387f30a1270d8a622b573fba0d7db5bdbc Mon Sep 17 00:00:00 2001 From: thelovekesh Date: Wed, 27 May 2026 00:03:34 +0530 Subject: [PATCH 04/11] Add ctx handling in archive API --- cli/command/publish/publish.go | 6 ++--- pkg/archive/archive.go | 40 ++++++++++++++++++++++++++++------ pkg/pm/installer/installer.go | 10 ++++++--- 3 files changed, 43 insertions(+), 13 deletions(-) diff --git a/cli/command/publish/publish.go b/cli/command/publish/publish.go index b6ecb53..a002c5d 100644 --- a/cli/command/publish/publish.go +++ b/cli/command/publish/publish.go @@ -65,7 +65,7 @@ func NewPublishCommand(wpmCli command.Cli) *cobra.Command { return cmd } -func pack(path string, opts publishOptions, out *output.Output) (*archive.Tarballer, error) { +func pack(ctx context.Context, path string, opts publishOptions, out *output.Output) (*archive.Tarballer, error) { ignorePatterns, err := wpmignore.ReadWpmIgnore(path) if err != nil { return nil, err @@ -78,7 +78,7 @@ func pack(path string, opts publishOptions, out *output.Output) (*archive.Tarbal }, } - tar, err := archive.Tar(path, tarOptions, func(fileInfo os.FileInfo) { + tar, err := archive.Tar(ctx, path, tarOptions, func(fileInfo os.FileInfo) { if opts.verbose { sizeString := units.HumanSize(float64(fileInfo.Size())) sizeString = fmt.Sprintf("%-7s", sizeString) // pad to 7 spaces since size string is capped to 4 numbers @@ -174,7 +174,7 @@ func runPublish(ctx context.Context, wpmCli command.Cli, opts publishOptions) er _ = os.Remove(tempFile.Name()) }() - tarballer, err := pack(cwd, opts, wpmCli.Output()) + tarballer, err := pack(ctx, cwd, opts, wpmCli.Output()) if err != nil { return fmt.Errorf("failed to pack the package into a tarball: %w", err) } diff --git a/pkg/archive/archive.go b/pkg/archive/archive.go index 5dcf3ba..4b182e1 100644 --- a/pkg/archive/archive.go +++ b/pkg/archive/archive.go @@ -5,6 +5,7 @@ import ( "archive/tar" "bufio" "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -340,8 +341,8 @@ func createTarFile(path string, hdr *tar.Header, reader io.Reader, options *TarO // Tar creates an archive from the directory at `path`, only including files whose relative // paths are included in `options.IncludeFiles` (if non-nil) or not in `options.ExcludePatterns`. -func Tar(srcPath string, options *TarOptions, reporterFn func(fs.FileInfo)) (*Tarballer, error) { - tb, err := NewTarballer(srcPath, options, reporterFn) +func Tar(ctx context.Context, srcPath string, options *TarOptions, reporterFn func(fs.FileInfo)) (*Tarballer, error) { + tb, err := NewTarballer(ctx, srcPath, options, reporterFn) if err != nil { return nil, err } @@ -352,6 +353,7 @@ func Tar(srcPath string, options *TarOptions, reporterFn func(fs.FileInfo)) (*Ta // Tarballer is a lower-level interface to TarWithOptions which gives the caller // control over which goroutine the archiving operation executes on. type Tarballer struct { + ctx context.Context srcPath string options *TarOptions pm *patternmatcher.PatternMatcher @@ -365,7 +367,7 @@ type Tarballer struct { // NewTarballer constructs a new tarballer. The arguments are the same as for // TarWithOptions. -func NewTarballer(srcPath string, options *TarOptions, reporterFn func(fs.FileInfo)) (*Tarballer, error) { +func NewTarballer(ctx context.Context, srcPath string, options *TarOptions, reporterFn func(fs.FileInfo)) (*Tarballer, error) { pm, err := patternmatcher.New(options.ExcludePatterns) if err != nil { return nil, err @@ -381,6 +383,7 @@ func NewTarballer(srcPath string, options *TarOptions, reporterFn func(fs.FileIn } return &Tarballer{ + ctx: ctx, // Fix the source path to work with long path names. This is a no-op // on platforms other than Windows. srcPath: addLongPathPrefix(srcPath), @@ -465,6 +468,9 @@ func (t *Tarballer) Do() { walkRoot := filepath.Join(t.srcPath, include) doErr = filepath.WalkDir(walkRoot, func(filePath string, f os.DirEntry, err error) error { + if ctxErr := t.ctx.Err(); ctxErr != nil { + return ctxErr + } if err != nil { return fmt.Errorf("unable to stat file %q: %w", t.srcPath, err) } @@ -578,7 +584,7 @@ func (t *Tarballer) Do() { // Unpack unpacks the decompressedArchive to dest with options. // //nolint:gocyclo // this function is necessarily complex due to the file walking and pattern matching logic -func Unpack(decompressedArchive io.Reader, dest string, options *TarOptions) error { +func Unpack(ctx context.Context, decompressedArchive io.Reader, dest string, options *TarOptions) error { tr := tar.NewReader(decompressedArchive) var dirs []*tar.Header @@ -587,6 +593,10 @@ func Unpack(decompressedArchive io.Reader, dest string, options *TarOptions) err // Iterate through the files in the archive. loop: for { + if err := ctx.Err(); err != nil { + return err + } + hdr, err := tr.Next() if errors.Is(err, io.EOF) { // end of archive @@ -719,6 +729,22 @@ func (t *readTracker) Read(p []byte) (int, error) { return n, err } +// ctxReader wraps an io.Reader to honor ctx cancellation. The downstream +// tar / zstd readers call Read in a tight loop; checking ctx at each Read +// boundary is fine-grained enough to abort a multi-gigabyte extract within +// a single read buffer (typically 32 KiB). +type ctxReader struct { + ctx context.Context + r io.Reader +} + +func (cr *ctxReader) Read(p []byte) (int, error) { + if err := cr.ctx.Err(); err != nil { + return 0, err + } + return cr.r.Read(p) +} + // extractionLimiter wraps the decompressed stream and enforces size and ratio limits. type extractionLimiter struct { decompressedStream io.Reader @@ -758,7 +784,7 @@ func (b *extractionLimiter) Read(p []byte) (int, error) { // Untar reads a stream of bytes from `archive`, parses it as a tar archive, // and unpacks it into the directory at `dest`. -func Untar(tarArchive io.Reader, dest string, options *TarOptions) error { +func Untar(ctx context.Context, tarArchive io.Reader, dest string, options *TarOptions) error { if tarArchive == nil { return errors.New("empty archive") } @@ -771,7 +797,7 @@ func Untar(tarArchive io.Reader, dest string, options *TarOptions) error { options.ExcludePatterns = []string{} } - compressedTracker := &readTracker{reader: tarArchive} + compressedTracker := &readTracker{reader: &ctxReader{ctx: ctx, r: tarArchive}} decompressedArchive, err := DecompressStream(compressedTracker) if err != nil { @@ -784,5 +810,5 @@ func Untar(tarArchive io.Reader, dest string, options *TarOptions) error { decompressedStream: decompressedArchive, } - return Unpack(detector, dest, options) + return Unpack(ctx, detector, dest, options) } diff --git a/pkg/pm/installer/installer.go b/pkg/pm/installer/installer.go index 3d13e71..ad2ae36 100644 --- a/pkg/pm/installer/installer.go +++ b/pkg/pm/installer/installer.go @@ -189,7 +189,7 @@ func (i *Installer) installOrUpdate(ctx context.Context, action Action, targetDi } defer func() { <-i.extractSem }() - extractedPath, tempContainer, err := i.unpackToStaging(stream) + extractedPath, tempContainer, err := i.unpackToStaging(ctx, stream) defer func() { _ = i.removeAll(context.Background(), tempContainer) }() @@ -197,6 +197,10 @@ func (i *Installer) installOrUpdate(ctx context.Context, action Action, targetDi return fmt.Errorf("failed to unpack package: %w", err) } + if err := ctx.Err(); err != nil { + return err + } + if _, err := io.Copy(io.Discard, stream); err != nil { return fmt.Errorf("failed to drain download stream: %w", err) } @@ -210,14 +214,14 @@ func (i *Installer) installOrUpdate(ctx context.Context, action Action, targetDi return i.replaceDir(ctx, extractedPath, targetDir) } -func (i *Installer) unpackToStaging(r io.Reader) (string, string, error) { +func (i *Installer) unpackToStaging(ctx context.Context, r io.Reader) (string, string, error) { rootTemp, err := os.MkdirTemp(i.runDir, "pkg-*") if err != nil { return "", "", fmt.Errorf("failed to create staging directory: %w", err) } opts := &archive.TarOptions{Logger: i.logger} - if err := archive.Untar(r, rootTemp, opts); err != nil { + if err := archive.Untar(ctx, r, rootTemp, opts); err != nil { return "", rootTemp, fmt.Errorf("failed to extract tarball: %w", err) } From e5cbc47c60f0ec96dcd4dddc23da47a99c7431bb Mon Sep 17 00:00:00 2001 From: thelovekesh Date: Wed, 27 May 2026 00:39:49 +0530 Subject: [PATCH 05/11] Update ctx handling in pkg/pm --- pkg/pm/resolution/resolver.go | 14 +++++----- pkg/pm/utils.go | 51 +++++++++++++++++++++++++++++++++++ pkg/pm/wpmjson/wpmjson.go | 2 +- pkg/pm/wpmlock/lockfile.go | 2 +- 4 files changed, 61 insertions(+), 8 deletions(-) diff --git a/pkg/pm/resolution/resolver.go b/pkg/pm/resolution/resolver.go index 5fdf850..03b96a4 100644 --- a/pkg/pm/resolution/resolver.go +++ b/pkg/pm/resolution/resolver.go @@ -56,7 +56,6 @@ type ProgressReporter interface { type fetchResult struct { req dependencyRequest manifest *manifest.Package - err error } func (r *Resolver) Resolve(ctx context.Context, progress ProgressReporter, w io.Writer) (map[string]Node, error) { @@ -70,6 +69,10 @@ func (r *Resolver) Resolve(ctx context.Context, progress ProgressReporter, w io. }() for len(queue) > 0 { + if err := ctx.Err(); err != nil { + return nil, err + } + uniqueRequests := dedupeRequests(queue, resolved) queue = nil // clear queue for next iteration @@ -131,7 +134,10 @@ func (r *Resolver) fetchAll(ctx context.Context, requests map[string]dependencyR g.Go(func() error { manifest, err := r.fetchMetadata(gtx, req.name, req.version) - results <- fetchResult{req: req, manifest: manifest, err: err} + if err != nil { + return fmt.Errorf("failed to fetch metadata for %s@%s required by %s: %w", req.name, req.version, req.requestor, err) + } + results <- fetchResult{req: req, manifest: manifest} return nil }) } @@ -151,10 +157,6 @@ func (r *Resolver) fetchAll(ctx context.Context, requests map[string]dependencyR // applyResult validates and registers a single fetch result into `resolved`, // returning any newly discovered child dependencies to enqueue. func (r *Resolver) applyResult(res fetchResult, resolved map[string]Node) ([]dependencyRequest, error) { - if res.err != nil { - return nil, fmt.Errorf("failed to fetch metadata for %s@%s required by %s: %w", res.req.name, res.req.version, res.req.requestor, res.err) - } - if existing, ok := resolved[res.req.name]; ok { if existing.Version == res.req.version { return nil, nil diff --git a/pkg/pm/utils.go b/pkg/pm/utils.go index af7562f..dc55f61 100644 --- a/pkg/pm/utils.go +++ b/pkg/pm/utils.go @@ -2,9 +2,60 @@ package pm import ( "bytes" + "fmt" + "os" + "path/filepath" + "runtime" "strings" ) +// WriteFileAtomic writes data to path atomically: it writes to a temp file in +// the same directory, fsyncs it, then renames over the target. +func WriteFileAtomic(path string, data []byte, perm os.FileMode) (retErr error) { + dir := filepath.Dir(path) + + tmp, err := os.CreateTemp(dir, "."+filepath.Base(path)+".tmp-*") + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + tmpName := tmp.Name() + defer func() { + if retErr != nil { + _ = os.Remove(tmpName) + } + }() + + if _, err := tmp.Write(data); err != nil { + _ = tmp.Close() + return err + } + if err := tmp.Sync(); err != nil { + _ = tmp.Close() + return err + } + if err := tmp.Close(); err != nil { + return err + } + if err := os.Chmod(tmpName, perm); err != nil { + return err + } + + if err := os.Rename(tmpName, path); err != nil { + return err + } + + // Sync the directory to guarantee the POSIX rename is flushed to disk. + // Windows handles NTFS journaling automatically and rejects directory syncs. + if runtime.GOOS != "windows" { + if d, err := os.Open(dir); err == nil { //nolint:gosec // we need to open the directory to fsync it + _ = d.Sync() + _ = d.Close() + } + } + + return nil +} + // DetectIndentation scans the first few lines to find the indentation style. // // Defaults to 2 spaces if it can't decide. diff --git a/pkg/pm/wpmjson/wpmjson.go b/pkg/pm/wpmjson/wpmjson.go index cf197d2..8cd92f4 100644 --- a/pkg/pm/wpmjson/wpmjson.go +++ b/pkg/pm/wpmjson/wpmjson.go @@ -181,7 +181,7 @@ func (c *Config) Write(cwd string) error { } // Write with 0644 permissions (rw-r--r--) - if err := os.WriteFile(path, data, 0o644); err != nil { + if err := pm.WriteFileAtomic(path, data, 0o644); err != nil { return fmt.Errorf("failed to write wpm.json to disk: %w", err) } diff --git a/pkg/pm/wpmlock/lockfile.go b/pkg/pm/wpmlock/lockfile.go index 94a793a..0da5c72 100644 --- a/pkg/pm/wpmlock/lockfile.go +++ b/pkg/pm/wpmlock/lockfile.go @@ -96,7 +96,7 @@ func (l *Lockfile) Write(cwd string) error { } // Write with 0644 permissions (rw-r--r--) - if err := os.WriteFile(path, data, 0o644); err != nil { + if err := pm.WriteFileAtomic(path, data, 0o644); err != nil { return fmt.Errorf("failed to write lockfile to disk: %w", err) } From 3e17f655420f0281c346ed2b0ddeed42152a7c12 Mon Sep 17 00:00:00 2001 From: thelovekesh Date: Wed, 27 May 2026 01:35:36 +0530 Subject: [PATCH 06/11] Add pkg/atomicwriter --- pkg/atomicwriter/atomicwriter.go | 56 ++++++++++++++++++++ pkg/atomicwriter/atomicwriter_test.go | 76 +++++++++++++++++++++++++++ 2 files changed, 132 insertions(+) create mode 100644 pkg/atomicwriter/atomicwriter.go create mode 100644 pkg/atomicwriter/atomicwriter_test.go diff --git a/pkg/atomicwriter/atomicwriter.go b/pkg/atomicwriter/atomicwriter.go new file mode 100644 index 0000000..9825a93 --- /dev/null +++ b/pkg/atomicwriter/atomicwriter.go @@ -0,0 +1,56 @@ +package atomicwriter + +import ( + "fmt" + "os" + "path/filepath" + "runtime" +) + +// WriteFile atomically writes data to path with the given permissions. It is a +// drop-in replacement for os.WriteFile that survives an interrupt or power loss +// mid-write. +func WriteFile(path string, data []byte, perm os.FileMode) (retErr error) { + dir := filepath.Dir(path) + + tmp, err := os.CreateTemp(dir, "."+filepath.Base(path)+".tmp-*") + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + tmpName := tmp.Name() + defer func() { + if retErr != nil { + _ = os.Remove(tmpName) + } + }() + + if _, err := tmp.Write(data); err != nil { + _ = tmp.Close() + return err + } + if err := tmp.Sync(); err != nil { + _ = tmp.Close() + return err + } + if err := tmp.Close(); err != nil { + return err + } + if err := os.Chmod(tmpName, perm); err != nil { + return err + } + + if err := os.Rename(tmpName, path); err != nil { + return err + } + + // Sync the directory to guarantee the POSIX rename is flushed to disk. + // Windows handles NTFS journaling automatically and rejects directory syncs. + if runtime.GOOS != "windows" { + if d, err := os.Open(dir); err == nil { //nolint:gosec // we need to open the directory to fsync it + _ = d.Sync() + _ = d.Close() + } + } + + return nil +} diff --git a/pkg/atomicwriter/atomicwriter_test.go b/pkg/atomicwriter/atomicwriter_test.go new file mode 100644 index 0000000..73e341e --- /dev/null +++ b/pkg/atomicwriter/atomicwriter_test.go @@ -0,0 +1,76 @@ +package atomicwriter + +import ( + "os" + "path/filepath" + "testing" +) + +func TestWriteFileCreatesFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "wpm.json") + + if err := WriteFile(path, []byte("hello"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + got, err := os.ReadFile(path) //nolint:gosec // path is built from t.TempDir() + a constant + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if string(got) != "hello" { + t.Fatalf("content = %q, want %q", got, "hello") + } + + fi, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat: %v", err) + } + if perm := fi.Mode().Perm(); perm != 0o644 { + t.Fatalf("perm = %o, want 0644", perm) + } +} + +func TestWriteFileOverwrites(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "wpm.lock") + + if err := WriteFile(path, []byte("old"), 0o644); err != nil { + t.Fatalf("first WriteFile: %v", err) + } + if err := WriteFile(path, []byte("new content"), 0o644); err != nil { + t.Fatalf("second WriteFile: %v", err) + } + + got, err := os.ReadFile(path) //nolint:gosec // path is built from t.TempDir() + a constant + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if string(got) != "new content" { + t.Fatalf("content = %q, want %q", got, "new content") + } +} + +// TestWriteFileLeavesNoTempFiles guards the success-path cleanup: the only +// entry left in the directory must be the target file, never a leftover +// ".target.tmp-*" staging file. +func TestWriteFileLeavesNoTempFiles(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "wpm.json") + + if err := WriteFile(path, []byte("data"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + entries, err := os.ReadDir(dir) + if err != nil { + t.Fatalf("ReadDir: %v", err) + } + if len(entries) != 1 || entries[0].Name() != "wpm.json" { + names := make([]string, 0, len(entries)) + for _, e := range entries { + names = append(names, e.Name()) + } + t.Fatalf("directory entries = %v, want [wpm.json]", names) + } +} From 19835a12c2a88da1eb7efd9bc238bfe023880682 Mon Sep 17 00:00:00 2001 From: thelovekesh Date: Wed, 27 May 2026 02:12:23 +0530 Subject: [PATCH 07/11] Remove WriteFileAtomic in favor of atomicwriter.WriteFile --- pkg/pm/utils.go | 51 ------------------------------------------------- 1 file changed, 51 deletions(-) diff --git a/pkg/pm/utils.go b/pkg/pm/utils.go index dc55f61..af7562f 100644 --- a/pkg/pm/utils.go +++ b/pkg/pm/utils.go @@ -2,60 +2,9 @@ package pm import ( "bytes" - "fmt" - "os" - "path/filepath" - "runtime" "strings" ) -// WriteFileAtomic writes data to path atomically: it writes to a temp file in -// the same directory, fsyncs it, then renames over the target. -func WriteFileAtomic(path string, data []byte, perm os.FileMode) (retErr error) { - dir := filepath.Dir(path) - - tmp, err := os.CreateTemp(dir, "."+filepath.Base(path)+".tmp-*") - if err != nil { - return fmt.Errorf("failed to create temp file: %w", err) - } - tmpName := tmp.Name() - defer func() { - if retErr != nil { - _ = os.Remove(tmpName) - } - }() - - if _, err := tmp.Write(data); err != nil { - _ = tmp.Close() - return err - } - if err := tmp.Sync(); err != nil { - _ = tmp.Close() - return err - } - if err := tmp.Close(); err != nil { - return err - } - if err := os.Chmod(tmpName, perm); err != nil { - return err - } - - if err := os.Rename(tmpName, path); err != nil { - return err - } - - // Sync the directory to guarantee the POSIX rename is flushed to disk. - // Windows handles NTFS journaling automatically and rejects directory syncs. - if runtime.GOOS != "windows" { - if d, err := os.Open(dir); err == nil { //nolint:gosec // we need to open the directory to fsync it - _ = d.Sync() - _ = d.Close() - } - } - - return nil -} - // DetectIndentation scans the first few lines to find the indentation style. // // Defaults to 2 spaces if it can't decide. From 7c9ff43566c396d933a739ccd1b62909adf53f60 Mon Sep 17 00:00:00 2001 From: thelovekesh Date: Wed, 27 May 2026 02:12:30 +0530 Subject: [PATCH 08/11] Use `atomicwriter.WriteFile` for atomic file writes --- pkg/pm/wpmjson/wpmjson.go | 3 ++- pkg/pm/wpmlock/lockfile.go | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pkg/pm/wpmjson/wpmjson.go b/pkg/pm/wpmjson/wpmjson.go index 8cd92f4..22e2fd7 100644 --- a/pkg/pm/wpmjson/wpmjson.go +++ b/pkg/pm/wpmjson/wpmjson.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" + "go.wpm.so/cli/pkg/atomicwriter" "go.wpm.so/cli/pkg/pm" "go.wpm.so/cli/pkg/pm/wpmjson/types" "go.wpm.so/cli/pkg/pm/wpmjson/validator" @@ -181,7 +182,7 @@ func (c *Config) Write(cwd string) error { } // Write with 0644 permissions (rw-r--r--) - if err := pm.WriteFileAtomic(path, data, 0o644); err != nil { + if err := atomicwriter.WriteFile(path, data, 0o644); err != nil { return fmt.Errorf("failed to write wpm.json to disk: %w", err) } diff --git a/pkg/pm/wpmlock/lockfile.go b/pkg/pm/wpmlock/lockfile.go index 0da5c72..eea1b95 100644 --- a/pkg/pm/wpmlock/lockfile.go +++ b/pkg/pm/wpmlock/lockfile.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" + "go.wpm.so/cli/pkg/atomicwriter" "go.wpm.so/cli/pkg/pm" "go.wpm.so/cli/pkg/pm/wpmjson/types" "go.wpm.so/cli/pkg/pm/wpmjson/validator" @@ -96,7 +97,7 @@ func (l *Lockfile) Write(cwd string) error { } // Write with 0644 permissions (rw-r--r--) - if err := pm.WriteFileAtomic(path, data, 0o644); err != nil { + if err := atomicwriter.WriteFile(path, data, 0o644); err != nil { return fmt.Errorf("failed to write lockfile to disk: %w", err) } From 85dca65514af76ce440f3bfd4b97d49f801dd24e Mon Sep 17 00:00:00 2001 From: thelovekesh Date: Wed, 27 May 2026 02:12:36 +0530 Subject: [PATCH 09/11] Close buffer on errors --- pkg/archive/archive.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pkg/archive/archive.go b/pkg/archive/archive.go index 4b182e1..234a03a 100644 --- a/pkg/archive/archive.go +++ b/pkg/archive/archive.go @@ -153,16 +153,19 @@ func DecompressStream(archive io.Reader) (io.ReadCloser, error) { buf := newBufferedReader(archive) bs, err := buf.Peek(10) if err != nil && err != io.EOF { + _ = buf.Close() return nil, err } // check if the stream is compressed with zstd if !isZstd(bs) { + _ = buf.Close() return nil, errors.New("unsupported archive format: expected zstd compressed archive") } zstdReader, err := zstd.NewReader(buf, zstd.WithDecoderMaxWindow(zstdMaxWindowSize)) if err != nil { + _ = buf.Close() return nil, err } From ed4dbb0100007aef03250fdfa6e15a1dce02428f Mon Sep 17 00:00:00 2001 From: thelovekesh Date: Wed, 27 May 2026 02:31:08 +0530 Subject: [PATCH 10/11] Fix bypassing os umask in atomicwriter --- pkg/atomicwriter/atomicwriter.go | 41 ++++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/pkg/atomicwriter/atomicwriter.go b/pkg/atomicwriter/atomicwriter.go index 9825a93..a3212ff 100644 --- a/pkg/atomicwriter/atomicwriter.go +++ b/pkg/atomicwriter/atomicwriter.go @@ -1,6 +1,9 @@ package atomicwriter import ( + "crypto/rand" + "encoding/hex" + "errors" "fmt" "os" "path/filepath" @@ -13,11 +16,10 @@ import ( func WriteFile(path string, data []byte, perm os.FileMode) (retErr error) { dir := filepath.Dir(path) - tmp, err := os.CreateTemp(dir, "."+filepath.Base(path)+".tmp-*") + tmp, tmpName, err := createTemp(dir, filepath.Base(path), perm) if err != nil { return fmt.Errorf("failed to create temp file: %w", err) } - tmpName := tmp.Name() defer func() { if retErr != nil { _ = os.Remove(tmpName) @@ -35,9 +37,6 @@ func WriteFile(path string, data []byte, perm os.FileMode) (retErr error) { if err := tmp.Close(); err != nil { return err } - if err := os.Chmod(tmpName, perm); err != nil { - return err - } if err := os.Rename(tmpName, path); err != nil { return err @@ -54,3 +53,35 @@ func WriteFile(path string, data []byte, perm os.FileMode) (retErr error) { return nil } + +// createTemp opens a new, uniquely-named file in dir with O_CREATE|O_EXCL so it +// never follows a pre-existing symlink or clobbers an existing file. perm is +// passed straight to the open syscall so the kernel masks it by the process +// umask, exactly as os.WriteFile does. +func createTemp(dir, base string, perm os.FileMode) (*os.File, string, error) { + for range 10000 { + suffix, err := randomSuffix() + if err != nil { + return nil, "", err + } + + name := filepath.Join(dir, "."+base+".tmp-"+suffix) + f, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_EXCL, perm) //nolint:gosec // name is confined to dir; O_EXCL blocks clobbering and symlink-following + if err == nil { + return f, name, nil + } + if errors.Is(err, os.ErrExist) { + continue + } + return nil, "", err + } + return nil, "", errors.New("exhausted attempts to create a unique temp file") +} + +func randomSuffix() (string, error) { + var b [8]byte + if _, err := rand.Read(b[:]); err != nil { + return "", err + } + return hex.EncodeToString(b[:]), nil +} From 0dd6e26cd84276dbe403e1ae8c2010fba59807de Mon Sep 17 00:00:00 2001 From: thelovekesh Date: Wed, 27 May 2026 02:31:30 +0530 Subject: [PATCH 11/11] Update test cases to test masked perms --- pkg/atomicwriter/atomicwriter_test.go | 10 +++++-- pkg/atomicwriter/atomicwriter_unix_test.go | 34 ++++++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) create mode 100644 pkg/atomicwriter/atomicwriter_unix_test.go diff --git a/pkg/atomicwriter/atomicwriter_test.go b/pkg/atomicwriter/atomicwriter_test.go index 73e341e..3de9c33 100644 --- a/pkg/atomicwriter/atomicwriter_test.go +++ b/pkg/atomicwriter/atomicwriter_test.go @@ -26,8 +26,14 @@ func TestWriteFileCreatesFile(t *testing.T) { if err != nil { t.Fatalf("Stat: %v", err) } - if perm := fi.Mode().Perm(); perm != 0o644 { - t.Fatalf("perm = %o, want 0644", perm) + if !fi.Mode().IsRegular() { + t.Fatalf("mode = %v, want a regular file", fi.Mode()) + } + // Exact permission bits are umask-dependent (see the unix-only + // TestWriteFileRespectsUmask); here we only assert the owner can still + // read+write, which survives any standard umask. + if perm := fi.Mode().Perm(); perm&0o600 != 0o600 { + t.Fatalf("perm = %o, want at least owner rw (0600)", perm) } } diff --git a/pkg/atomicwriter/atomicwriter_unix_test.go b/pkg/atomicwriter/atomicwriter_unix_test.go new file mode 100644 index 0000000..06f4340 --- /dev/null +++ b/pkg/atomicwriter/atomicwriter_unix_test.go @@ -0,0 +1,34 @@ +//go:build unix + +package atomicwriter + +import ( + "os" + "path/filepath" + "syscall" + "testing" +) + +// TestWriteFileRespectsUmask verifies that WriteFile respects the process +// umask when creating the temp file. +// +// syscall.Umask is process-global, so this test must not run in parallel. +func TestWriteFileRespectsUmask(t *testing.T) { + old := syscall.Umask(0o077) + defer syscall.Umask(old) + + dir := t.TempDir() + path := filepath.Join(dir, "wpm.json") + + if err := WriteFile(path, []byte("x"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + fi, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat: %v", err) + } + if got := fi.Mode().Perm(); got != 0o600 { + t.Fatalf("perm = %o, want 0600 (0644 masked by umask 077)", got) + } +}