Skip to content
2 changes: 1 addition & 1 deletion cli/cobra.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func setupCommonRootCommand(rootCmd *cobra.Command) (*cliflags.ClientOptions, *c
rootCmd.PersistentFlags().Lookup("help").Hidden = true

rootCmd.Annotations = map[string]string{
"additionalHelp": "For more help on how to use wpm, see https://docs.wpm.so",
"additionalHelp": "For more help on how to use wpm, see https://wpm.so/docs",
}

return opts, helpCommand
Expand Down
6 changes: 3 additions & 3 deletions cli/command/publish/publish.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func NewPublishCommand(wpmCli command.Cli) *cobra.Command {
return cmd
}

func pack(path string, opts publishOptions, out *output.Output) (*archive.Tarballer, error) {
func pack(ctx context.Context, path string, opts publishOptions, out *output.Output) (*archive.Tarballer, error) {
ignorePatterns, err := wpmignore.ReadWpmIgnore(path)
if err != nil {
return nil, err
Expand All @@ -78,7 +78,7 @@ func pack(path string, opts publishOptions, out *output.Output) (*archive.Tarbal
},
}

tar, err := archive.Tar(path, tarOptions, func(fileInfo os.FileInfo) {
tar, err := archive.Tar(ctx, path, tarOptions, func(fileInfo os.FileInfo) {
if opts.verbose {
sizeString := units.HumanSize(float64(fileInfo.Size()))
sizeString = fmt.Sprintf("%-7s", sizeString) // pad to 7 spaces since size string is capped to 4 numbers
Expand Down Expand Up @@ -174,7 +174,7 @@ func runPublish(ctx context.Context, wpmCli command.Cli, opts publishOptions) er
_ = os.Remove(tempFile.Name())
}()

tarballer, err := pack(cwd, opts, wpmCli.Output())
tarballer, err := pack(ctx, cwd, opts, wpmCli.Output())
if err != nil {
return fmt.Errorf("failed to pack the package into a tarball: %w", err)
}
Expand Down
99 changes: 23 additions & 76 deletions cmd/wpm/wpm.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@ import (
"context"
"errors"
"fmt"
"io"
"os"
"os/signal"
"syscall"
"time"

"github.com/morikuni/aec"
Expand All @@ -23,62 +21,44 @@ import (
platformsignals "go.wpm.so/cli/cmd/wpm/internal/signals"
)

type errCtxSignalTerminated struct {
signal os.Signal
}

func (errCtxSignalTerminated) Error() string {
return ""
}
// exitCodeInterrupted is the exit code used when the process is interrupted by a signal (e.g., Ctrl-C).
const exitCodeInterrupted = 130

func main() {
err := wpmMain(context.Background())
if _, ok := errors.AsType[errCtxSignalTerminated](err); ok {
os.Exit(getExitCode(err))
}

if err != nil && !errors.Is(err, context.Canceled) {
if err.Error() != "" {
_, _ = fmt.Fprintln(os.Stderr, err)
}
os.Exit(getExitCode(err))
}
os.Exit(run())
}

func notifyContext(ctx context.Context, signals ...os.Signal) (context.Context, context.CancelFunc) {
ch := make(chan os.Signal, 1)
signal.Notify(ch, signals...)

ctxCause, cancel := context.WithCancelCause(ctx)
func run() int {
ctx, stop := signal.NotifyContext(context.Background(), platformsignals.TerminationSignals...)
defer stop()

go func() {
select {
case <-ctx.Done():
signal.Stop(ch)
return
case sig := <-ch:
cancel(errCtxSignalTerminated{
signal: sig,
})
signal.Stop(ch)
return
}
<-ctx.Done()
stop()
}()

return ctxCause, func() {
signal.Stop(ch)
cancel(nil)
err := wpmMain(ctx)
if err == nil {
return 0
}

if ctx.Err() != nil || errors.Is(err, context.Canceled) {
return exitCodeInterrupted
}

if err.Error() != "" {
_, _ = fmt.Fprintln(os.Stderr, err)
}

return getExitCode(err)
}

func wpmMain(ctx context.Context) error {
ctx, cancelNotify := notifyContext(ctx, platformsignals.TerminationSignals...)
defer cancelNotify()

wpmCli, err := command.NewWpmCli()
if err != nil {
return err
}

log.Logger = zerolog.New(zerolog.ConsoleWriter{
Out: wpmCli.Err(),
NoColor: !wpmCli.Err().IsColorEnabled(),
Expand All @@ -96,14 +76,6 @@ func getExitCode(err error) int {
return 0
}

if userTerminatedErr, ok := errors.AsType[errCtxSignalTerminated](err); ok {
s, ok := userTerminatedErr.signal.(syscall.Signal)
if !ok {
return 1
}
return 128 + int(s)
}

if stErr, ok := errors.AsType[cli.StatusError](err); ok && stErr.StatusCode != 0 {
return stErr.StatusCode
}
Expand Down Expand Up @@ -189,25 +161,6 @@ func newWpmCommand(wpmCli *command.WpmCli) *cli.TopLevelCommand {
return cli.NewTopLevelCommand(cmd, wpmCli, opts, cmd.Flags())
}

// forceExitAfter3TerminationSignals waits for the first termination signal
// to be caught and the context to be marked as done, then registers a new
// signal handler for subsequent signals. It forces the process to exit
// after 3 SIGTERM/SIGINT signals.
func forceExitAfter3TerminationSignals(ctx context.Context, w io.Writer) {
// wait for the first signal to be caught and the context to be marked as done
<-ctx.Done()
// register a new signal handler for subsequent signals
sig := make(chan os.Signal, 2)
signal.Notify(sig, platformsignals.TerminationSignals...)

// once we have received a total of 3 signals we force exit the cli
for i := 0; i < 2; i++ {
<-sig
}
_, _ = fmt.Fprint(w, "\ngot 3 SIGTERM/SIGINTs, forcefully exiting\n")
os.Exit(1)
}

func setupHelpCommand(helpCmd *cobra.Command) {
origRun := helpCmd.Run
origRunE := helpCmd.RunE
Expand All @@ -234,14 +187,8 @@ func runWpm(ctx context.Context, wpmCli *command.WpmCli) error {
return err
}

// This is a fallback for the case where the command does not exit
// based on context cancellation.
go forceExitAfter3TerminationSignals(ctx, wpmCli.Err())

// We've parsed global args already, so reset args to those
// which remain.
cmd.SetArgs(args)
err = cmd.ExecuteContext(ctx)

return err
return cmd.ExecuteContext(ctx)
}
20 changes: 14 additions & 6 deletions pkg/api/cache.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"context"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
Expand Down Expand Up @@ -142,7 +143,7 @@ func (t *Transport) executeRequest(req *http.Request, finalPath string, force bo
}

if resp.StatusCode == http.StatusOK {
resp.Body = t.write(resp.Body, finalPath, resp.Header)
resp.Body = t.write(req.Context(), resp.Body, finalPath, resp.Header)
}

return resp, nil
Expand Down Expand Up @@ -227,7 +228,7 @@ func readCacheEnvelope(f *os.File) (http.Header, int64, int64, error) {
return m.Headers, bodyStart, stat.Size() - 4, nil
}

func (t *Transport) write(src io.ReadCloser, finalPath string, h http.Header) io.ReadCloser {
func (t *Transport) write(ctx context.Context, src io.ReadCloser, finalPath string, h http.Header) io.ReadCloser {
tmpDir := filepath.Join(t.cacheDir, "tmp")
if err := os.MkdirAll(tmpDir, 0o750); err != nil {
return src
Expand All @@ -247,6 +248,7 @@ func (t *Transport) write(src io.ReadCloser, finalPath string, h http.Header) io
}

return &writer{
ctx: ctx,
src: src,
dst: f,
tmp: f.Name(),
Expand Down Expand Up @@ -286,6 +288,7 @@ func (*Transport) writeMeta(w io.Writer, h http.Header) error {
}

type writer struct {
ctx context.Context
src io.ReadCloser
dst *os.File
tmp string
Expand Down Expand Up @@ -328,22 +331,27 @@ func (w *writer) Close() error {
return srcErr
}

if err := renameFile(w.tmp, w.final); err != nil {
if err := renameFile(w.ctx, w.tmp, w.final); err != nil {
_ = os.Remove(w.tmp)
}
return srcErr
}

// renameFile retries on transient Windows errors (file locked by reader,
// AV scanner). Exponential backoff up to ~3.5s.
func renameFile(src, dst string) error {
// AV scanner). Exponential backoff up to ~3.5s, abandoned early on ctx
// cancel so SIGINT mid-retry is honored.
func renameFile(ctx context.Context, src, dst string) error {
const maxAttempts = 8
backoff := 25 * time.Millisecond

var err error
for i := range maxAttempts {
if i > 0 {
time.Sleep(backoff)
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(backoff):
}
Comment thread
thelovekesh marked this conversation as resolved.
if backoff < time.Second {
backoff *= 2
}
Expand Down
10 changes: 8 additions & 2 deletions pkg/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import (
"strings"
)

// maxErrorBodySize caps how much of an error response we'll read into memory
// before parsing.
const maxErrorBodySize = 1 << 18 // 256 KiB

// HTTPError represents an error response from the wpm API.
type HTTPError struct {
Message string
Expand All @@ -23,7 +27,9 @@ func (err *HTTPError) Error() string {
return "wpm registry error: " + strings.ToLower(err.Message)
}

// HandleHTTPError parses a http.Response into a HTTPError.
// HandleHTTPError parses a http.Response into a HTTPError. The response body
// is not closed here and the caller owns the response and must close it. We do,
// however, fully consume the body so the connection stays reusable.
func HandleHTTPError(resp *http.Response) error {
httpError := &HTTPError{
Headers: resp.Header,
Expand All @@ -36,7 +42,7 @@ func HandleHTTPError(resp *http.Response) error {
return httpError
}

body, err := io.ReadAll(resp.Body)
body, err := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodySize))
if err != nil {
httpError.Message = err.Error()
return httpError
Expand Down
12 changes: 11 additions & 1 deletion pkg/api/http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ func NewHTTPClient(opts ClientOptions) (*http.Client, error) {
}
}

// Sweep stale cache tmp files left behind by aborted requests. Runs in
// the background so a slow filesystem doesn't delay the first request;
// it's idempotent across concurrent invocations and safe to outlive the
// process (the goroutine dies with the CLI either way).
if opts.CacheDir != "" {
go func() { _ = CleanupStale(opts.CacheDir) }()
}

transport := &Transport{
Base: &http.Transport{
MaxIdleConns: 100,
Expand Down Expand Up @@ -210,7 +218,9 @@ func (d decompressingRoundTripper) RoundTrip(req *http.Request) (*http.Response,
decoder := zstdDecoderPool.Get().(*zstd.Decoder)
if err := decoder.Reset(resp.Body); err != nil {
_ = resp.Body.Close()
zstdDecoderPool.Put(decoder)
// Reset left the decoder in an unknown state hence we discard it instead of putting it back in the pool.
// The next request needing a decoder will allocate a new one to replace it, so we don't leak resources by doing this.
decoder.Close()
return nil, fmt.Errorf("failed to reset zstd reader: %w", err)
}

Expand Down
39 changes: 15 additions & 24 deletions pkg/api/rest_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ func (c *RESTClient) DoWithContext(ctx context.Context, method, path string, bod
if err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
defer func() {
drainBody(resp.Body)
_ = resp.Body.Close()
}()

if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return HandleHTTPError(resp)
Expand Down Expand Up @@ -111,35 +114,23 @@ func (c *RESTClient) RequestStream(ctx context.Context, method, path string, bod
}

if resp.StatusCode < 200 || resp.StatusCode >= 300 {
defer func() { _ = resp.Body.Close() }()
defer func() {
drainBody(resp.Body)
_ = resp.Body.Close()
}()
return nil, HandleHTTPError(resp)
}

return resp.Body, nil
}

func (c *RESTClient) Do(method, path string, body io.Reader, response any, opts ...RequestOption) error {
return c.DoWithContext(context.Background(), method, path, body, response, opts...)
}

func (c *RESTClient) Delete(path string, resp any, opts ...RequestOption) error {
return c.Do(http.MethodDelete, path, nil, resp, opts...)
}

func (c *RESTClient) Get(path string, resp any, opts ...RequestOption) error {
return c.Do(http.MethodGet, path, nil, resp, opts...)
}

func (c *RESTClient) Patch(path string, body io.Reader, resp any, opts ...RequestOption) error {
return c.Do(http.MethodPatch, path, body, resp, opts...)
}

func (c *RESTClient) Post(path string, body io.Reader, resp any, opts ...RequestOption) error {
return c.Do(http.MethodPost, path, body, resp, opts...)
}

func (c *RESTClient) Put(path string, body io.Reader, resp any, opts ...RequestOption) error {
return c.Do(http.MethodPut, path, body, resp, opts...)
// drainBody discards any unread bytes from body so the underlying TCP
// connection can be reused from the idle pool instead of being torn down.
//
// The drain is capped so a hostile or buggy server cannot keep us reading
// indefinitely.
func drainBody(body io.Reader) {
_, _ = io.Copy(io.Discard, io.LimitReader(body, 64<<10))
}

func restURL(hostname, pathOrURL string) string {
Expand Down
Loading