diff --git a/remotefs/hostinfo_test.go b/remotefs/hostinfo_test.go index c9b0c3c3..8b4e281f 100644 --- a/remotefs/hostinfo_test.go +++ b/remotefs/hostinfo_test.go @@ -1,6 +1,8 @@ package remotefs_test import ( + "context" + "encoding/base64" "errors" "io/fs" "testing" @@ -473,22 +475,55 @@ func TestWindowsChownVariantsNotSupported(t *testing.T) { require.ErrorIs(t, f.ChownTreeInt("/tmp", 0, 0), remotefs.ErrNotSupported) } -func TestPosixCreateTemp(t *testing.T) { - t.Run("with prefix", func(t *testing.T) { +func TestPosixHTTPStatus(t *testing.T) { + t.Run("200", func(t *testing.T) { mr := rigtest.NewMockRunner() - mr.AddCommandOutput(rigtest.Equal("echo ${TMPDIR:-/tmp}"), "/tmp") - mr.AddCommandOutput(rigtest.Contains("mktemp"), "/tmp/rig-abc123") + mr.AddCommandOutput(rigtest.Equal("command -v curl"), "/usr/bin/curl") + mr.AddCommandOutput(rigtest.Equal("command -v base64"), "/usr/bin/base64") + resp200 := base64.StdEncoding.EncodeToString([]byte("HTTP/1.1 200 OK\r\n\r\n")) + mr.AddCommandOutput(rigtest.Contains("--http1.1"), resp200) f := remotefs.NewPosixFS(mr) - path, err := f.CreateTemp("", "rig-") + code, err := remotefs.HTTPStatus(context.Background(), f, "http://example.com/health") require.NoError(t, err) - require.Equal(t, "/tmp/rig-abc123", path) - require.NoError(t, mr.Received(rigtest.Contains("mktemp -- /tmp/rig-XXXXXX"))) + require.Equal(t, 200, code) }) - t.Run("failure", func(t *testing.T) { + t.Run("503", func(t *testing.T) { mr := rigtest.NewMockRunner() - mr.AddCommandFailure(rigtest.Contains("mktemp"), errors.New("permission denied")) + mr.AddCommandOutput(rigtest.Equal("command -v curl"), "/usr/bin/curl") + mr.AddCommandOutput(rigtest.Equal("command -v base64"), "/usr/bin/base64") + resp503 := base64.StdEncoding.EncodeToString([]byte("HTTP/1.1 503 Service Unavailable\r\n\r\n")) + mr.AddCommandOutput(rigtest.Contains("--http1.1"), resp503) f := remotefs.NewPosixFS(mr) - _, err := f.CreateTemp("/srv", "rig-") + code, err := remotefs.HTTPStatus(context.Background(), f, "http://example.com/health") + require.NoError(t, err) + require.Equal(t, 503, code) + }) + t.Run("curl unavailable", func(t *testing.T) { + mr := rigtest.NewMockRunner() + mr.AddCommandFailure(rigtest.Equal("command -v curl"), errors.New("not found")) + f := remotefs.NewPosixFS(mr) + _, err := remotefs.HTTPStatus(context.Background(), f, "http://example.com/health") + require.Error(t, err) + }) +} + +func TestWindowsHTTPStatus(t *testing.T) { + t.Run("200", func(t *testing.T) { + mr := rigtest.NewMockRunner() + mr.Windows = true + resp200 := base64.StdEncoding.EncodeToString([]byte("HTTP/1.1 200 OK\r\n\r\n")) + mr.AddCommandOutput(rigtest.HasPrefix("powershell.exe"), resp200) + f := remotefs.NewWindowsFS(mr) + code, err := remotefs.HTTPStatus(context.Background(), f, "http://example.com/health") + require.NoError(t, err) + require.Equal(t, 200, code) + }) + t.Run("failure", func(t *testing.T) { + mr := rigtest.NewMockRunner() + mr.Windows = true + mr.AddCommandFailure(rigtest.HasPrefix("powershell.exe"), errors.New("exit 1")) + f := remotefs.NewWindowsFS(mr) + _, err := remotefs.HTTPStatus(context.Background(), f, "http://example.com/health") require.Error(t, err) }) } @@ -507,3 +542,24 @@ func TestWindowsCreateTemp(t *testing.T) { require.Equal(t, "C:/Windows/Temp/rig-abc123.tmp", path) }) } + +func TestPosixCreateTemp(t *testing.T) { + t.Run("with prefix", func(t *testing.T) { + mr := rigtest.NewMockRunner() + mr.AddCommandOutput(rigtest.Equal("echo ${TMPDIR:-/tmp}"), "/tmp") + mr.AddCommandOutput(rigtest.Contains("mktemp"), "/tmp/rig-abc123") + f := remotefs.NewPosixFS(mr) + path, err := f.CreateTemp("", "rig-") + require.NoError(t, err) + require.Equal(t, "/tmp/rig-abc123", path) + require.NoError(t, mr.Received(rigtest.Contains("mktemp -- /tmp/rig-XXXXXX"))) + }) + t.Run("failure", func(t *testing.T) { + mr := rigtest.NewMockRunner() + mr.AddCommandFailure(rigtest.Contains("mktemp"), errors.New("permission denied")) + f := remotefs.NewPosixFS(mr) + _, err := f.CreateTemp("/srv", "rig-") + require.Error(t, err) + }) +} + diff --git a/remotefs/http.go b/remotefs/http.go new file mode 100644 index 00000000..561a1c72 --- /dev/null +++ b/remotefs/http.go @@ -0,0 +1,99 @@ +package remotefs + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "errors" + "fmt" + "net/http" + "net/url" + "strings" +) + +const maxEncodedResponseSize = 32 * 1024 * 1024 // base64-encoded limit (~24 MiB decoded) + +var ( + errNilRequest = errors.New("net/http: nil Request") + errNilRequestURL = errors.New("net/http: nil Request.URL") + errUnsupportedScheme = errors.New("unsupported URL scheme") + errMissingURLHost = errors.New("missing URL host") + errUserinfoInURL = errors.New("URL must not contain userinfo (credentials in URLs leak to remote command line)") + errInvalidURLChars = errors.New("URL contains control characters (CR, LF, or NUL)") + errInvalidHeader = errors.New("invalid header: contains CR, LF, or NUL") + errResponseTooLarge = errors.New("http response exceeds maximum size") +) + +// validateRoundTripURL returns an error if u has an unsupported scheme, no host, userinfo, or control characters. +// Must be called after a nil-URL guard. +func validateRoundTripURL(target *url.URL) error { + if strings.ContainsAny(target.String(), "\r\n\x00") { + return errInvalidURLChars + } + switch strings.ToLower(target.Scheme) { + case "http", "https": + default: + return fmt.Errorf("%w: %q", errUnsupportedScheme, target.Scheme) + } + if target.Host == "" { + return errMissingURLHost + } + if target.User != nil { + return errUserinfoInURL + } + return nil +} + +// HTTPTransport is implemented by remote filesystems that can proxy HTTP requests +// through the remote host. Since RoundTrip matches the http.RoundTripper signature, +// any FS value satisfies http.RoundTripper and can be used directly as http.Client.Transport. +// +// Note: RoundTrip materializes the entire HTTP response on the remote side before +// transferring it to the caller as a base64-encoded string. Depending on the +// implementation, the remote side may buffer the response in memory or write it to a +// temporary file, and the caller also buffers the full response while decoding and +// parsing it. It is not suitable for large response bodies; use DownloadURL for +// downloading large files instead. +type HTTPTransport interface { + DownloadURL(url, dst string) error + RoundTrip(req *http.Request) (*http.Response, error) +} + +// HTTPStatus issues a HEAD request via t and returns the HTTP status code. +func HTTPStatus(ctx context.Context, t http.RoundTripper, url string) (int, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) + if err != nil { + return 0, fmt.Errorf("http-status: %w", err) + } + resp, err := t.RoundTrip(req) + if err != nil { + return 0, fmt.Errorf("http-status %s: %w", req.URL.Redacted(), err) + } + _ = resp.Body.Close() + return resp.StatusCode, nil +} + +// parseRawHTTPResponse decodes a base64-encoded raw HTTP/1.1 response and parses it. +// 100 Continue responses are consumed and discarded; all other responses are returned as-is. +func parseRawHTTPResponse(encoded string, req *http.Request) (*http.Response, error) { + cleaned := strings.NewReplacer("\r", "", "\n", "").Replace(strings.TrimSpace(encoded)) + if len(cleaned) > maxEncodedResponseSize { + return nil, fmt.Errorf("%w (%d bytes)", errResponseTooLarge, len(cleaned)) + } + raw, err := base64.StdEncoding.DecodeString(cleaned) + if err != nil { + return nil, fmt.Errorf("decode http response: %w", err) + } + reader := bufio.NewReader(bytes.NewReader(raw)) + for { + resp, err := http.ReadResponse(reader, req) + if err != nil { + return nil, fmt.Errorf("parse http response: %w", err) + } + if resp.StatusCode != http.StatusContinue { + return resp, nil + } + _ = resp.Body.Close() + } +} diff --git a/remotefs/http_test.go b/remotefs/http_test.go new file mode 100644 index 00000000..8c1f06df --- /dev/null +++ b/remotefs/http_test.go @@ -0,0 +1,363 @@ +package remotefs_test + +import ( + "context" + "encoding/base64" + "errors" + "io" + "net/http" + "strings" + "testing" + + "github.com/k0sproject/rig/v2/remotefs" + "github.com/k0sproject/rig/v2/rigtest" + "github.com/stretchr/testify/require" +) + +// decodePSScript extracts and decodes the PowerShell script from a command using the -E encoded flag. +// PS -EncodedCommand encoding: script is UTF-16LE encoded then base64 encoded. For ASCII-only scripts +// each character occupies two bytes (char_byte, 0x00), so decoding strips every other byte. +func decodePSScript(psCmd string) string { + const marker = " -E " + idx := strings.Index(psCmd, marker) + if idx < 0 { + return "" + } + raw, err := base64.StdEncoding.DecodeString(psCmd[idx+len(marker):]) + if err != nil { + return "" + } + var sb strings.Builder + for i := 0; i+1 < len(raw); i += 2 { + sb.WriteByte(raw[i]) + } + return sb.String() +} + +func TestPosixRoundTripGET(t *testing.T) { + rawResp := "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\nhello" + encoded := base64.StdEncoding.EncodeToString([]byte(rawResp)) + + mr := rigtest.NewMockRunner() + mr.AddCommandOutput(rigtest.Equal("command -v curl"), "/usr/bin/curl") + mr.AddCommandOutput(rigtest.Equal("command -v base64"), "/usr/bin/base64") + mr.AddCommandOutput(rigtest.Contains("--http1.1"), encoded) + f := remotefs.NewPosixFS(mr) + + req, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) + require.NoError(t, err) + + resp, err := f.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, 200, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "hello", string(body)) +} + +func TestPosixRoundTrip404(t *testing.T) { + rawResp := "HTTP/1.1 404 Not Found\r\n\r\n" + encoded := base64.StdEncoding.EncodeToString([]byte(rawResp)) + + mr := rigtest.NewMockRunner() + mr.AddCommandOutput(rigtest.Equal("command -v curl"), "/usr/bin/curl") + mr.AddCommandOutput(rigtest.Equal("command -v base64"), "/usr/bin/base64") + mr.AddCommandOutput(rigtest.Contains("--http1.1"), encoded) + f := remotefs.NewPosixFS(mr) + + req, err := http.NewRequest(http.MethodGet, "http://example.com/missing", nil) + require.NoError(t, err) + + resp, err := f.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, 404, resp.StatusCode) +} + +func TestPosixRoundTripWithRequestBody(t *testing.T) { + rawResp := "HTTP/1.1 201 Created\r\n\r\n" + encoded := base64.StdEncoding.EncodeToString([]byte(rawResp)) + + mr := rigtest.NewMockRunner() + mr.AddCommandOutput(rigtest.Equal("command -v curl"), "/usr/bin/curl") + mr.AddCommandOutput(rigtest.Equal("command -v base64"), "/usr/bin/base64") + mr.AddCommandOutput(rigtest.Contains("--http1.1"), encoded) + f := remotefs.NewPosixFS(mr) + + req, err := http.NewRequest(http.MethodPost, "http://example.com/api", strings.NewReader(`{"key":"val"}`)) + require.NoError(t, err) + + resp, err := f.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, 201, resp.StatusCode) + require.Contains(t, mr.LastCommand(), "--data-binary") + // curl adds Content-Type: application/x-www-form-urlencoded for --data-binary when none is set; + // we suppress it so callers get the same default-free behavior as net/http. + require.Contains(t, mr.LastCommand(), "Content-Type:") +} + +func TestPosixRoundTripBodyWithContentType(t *testing.T) { + rawResp := "HTTP/1.1 200 OK\r\n\r\n" + encoded := base64.StdEncoding.EncodeToString([]byte(rawResp)) + + mr := rigtest.NewMockRunner() + mr.AddCommandOutput(rigtest.Equal("command -v curl"), "/usr/bin/curl") + mr.AddCommandOutput(rigtest.Equal("command -v base64"), "/usr/bin/base64") + mr.AddCommandOutput(rigtest.Contains("--http1.1"), encoded) + f := remotefs.NewPosixFS(mr) + + req, err := http.NewRequest(http.MethodPost, "http://example.com/api", strings.NewReader(`{"key":"val"}`)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + resp, err := f.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, 200, resp.StatusCode) + require.Contains(t, mr.LastCommand(), "Content-Type: application/json") +} + +func TestPosixRoundTripCurlUnavailable(t *testing.T) { + mr := rigtest.NewMockRunner() + mr.AddCommandFailure(rigtest.Equal("command -v curl"), errors.New("not found")) + f := remotefs.NewPosixFS(mr) + + req, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) + require.NoError(t, err) + + _, err = f.RoundTrip(req) + require.Error(t, err) +} + +func TestPosixRoundTripBase64Unavailable(t *testing.T) { + mr := rigtest.NewMockRunner() + mr.AddCommandOutput(rigtest.Equal("command -v curl"), "/usr/bin/curl") + mr.AddCommandFailure(rigtest.Equal("command -v base64"), errors.New("not found")) + f := remotefs.NewPosixFS(mr) + + req, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) + require.NoError(t, err) + + _, err = f.RoundTrip(req) + require.Error(t, err) +} + +func TestHTTPStatusFreeFuncPosix(t *testing.T) { + rawResp := "HTTP/1.1 200 OK\r\n\r\n" + encoded := base64.StdEncoding.EncodeToString([]byte(rawResp)) + + mr := rigtest.NewMockRunner() + mr.AddCommandOutput(rigtest.Equal("command -v curl"), "/usr/bin/curl") + mr.AddCommandOutput(rigtest.Equal("command -v base64"), "/usr/bin/base64") + mr.AddCommandOutput(rigtest.Contains("--http1.1"), encoded) + f := remotefs.NewPosixFS(mr) + + code, err := remotefs.HTTPStatus(context.Background(), f, "http://example.com/health") + require.NoError(t, err) + require.Equal(t, 200, code) + + require.Contains(t, mr.LastCommand(), "-I") +} + +func TestWinRoundTripGET(t *testing.T) { + rawResp := "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\nhello" + encoded := base64.StdEncoding.EncodeToString([]byte(rawResp)) + + mr := rigtest.NewMockRunner() + mr.Windows = true + mr.AddCommandOutput(rigtest.HasPrefix("powershell.exe"), encoded) + f := remotefs.NewWindowsFS(mr) + + req, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) + require.NoError(t, err) + + resp, err := f.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, 200, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "hello", string(body)) +} + +func TestWinRoundTripCommandError(t *testing.T) { + mr := rigtest.NewMockRunner() + mr.Windows = true + mr.AddCommandFailure(rigtest.HasPrefix("powershell.exe"), errors.New("exit 1")) + f := remotefs.NewWindowsFS(mr) + + req, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) + require.NoError(t, err) + + _, err = f.RoundTrip(req) + require.Error(t, err) +} + +func TestWinRoundTrip404(t *testing.T) { + rawResp := "HTTP/1.1 404 Not Found\r\n\r\n" + encoded := base64.StdEncoding.EncodeToString([]byte(rawResp)) + + mr := rigtest.NewMockRunner() + mr.Windows = true + mr.AddCommandOutput(rigtest.HasPrefix("powershell.exe"), encoded) + f := remotefs.NewWindowsFS(mr) + + req, err := http.NewRequest(http.MethodGet, "http://example.com/missing", nil) + require.NoError(t, err) + + resp, err := f.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, 404, resp.StatusCode) +} + +func TestWinRoundTripWithRequestBody(t *testing.T) { + rawResp := "HTTP/1.1 201 Created\r\n\r\n" + encoded := base64.StdEncoding.EncodeToString([]byte(rawResp)) + + mr := rigtest.NewMockRunner() + mr.Windows = true + mr.AddCommandOutput(rigtest.HasPrefix("powershell.exe"), encoded) + f := remotefs.NewWindowsFS(mr) + + req, err := http.NewRequest(http.MethodPost, "http://example.com/api", strings.NewReader(`{"key":"val"}`)) + require.NoError(t, err) + + resp, err := f.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, 201, resp.StatusCode) + require.Contains(t, decodePSScript(mr.LastCommand()), "OpenStandardInput") +} + +func TestRoundTripURLValidation(t *testing.T) { + tests := []struct { + name string + req func() *http.Request + }{ + {"unsupported scheme", func() *http.Request { + req, _ := http.NewRequest(http.MethodGet, "ftp://example.com/", nil) + return req + }}, + {"missing host", func() *http.Request { + req, _ := http.NewRequest(http.MethodGet, "http://example.com/", nil) + req.URL.Host = "" + return req + }}, + {"userinfo in URL", func() *http.Request { + req, _ := http.NewRequest(http.MethodGet, "http://user:pass@example.com/", nil) + return req + }}, + {"CR in URL host", func() *http.Request { + req, _ := http.NewRequest(http.MethodGet, "http://example.com/", nil) + req.URL.Host = "example.com\r" + return req + }}, + {"LF in URL host", func() *http.Request { + req, _ := http.NewRequest(http.MethodGet, "http://example.com/", nil) + req.URL.Host = "example.com\n" + return req + }}, + {"NUL in URL host", func() *http.Request { + req, _ := http.NewRequest(http.MethodGet, "http://example.com/", nil) + req.URL.Host = "example.com\x00" + return req + }}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Posix: validateRoundTripURL runs before requireHTTPTools, so no tool stubs needed. + mr := rigtest.NewMockRunner() + f := remotefs.NewPosixFS(mr) + _, err := f.RoundTrip(tc.req()) + require.Error(t, err) + }) + } +} + +func TestCurlHeaderSanitization(t *testing.T) { + okResp := base64.StdEncoding.EncodeToString([]byte("HTTP/1.1 200 OK\r\n\r\n")) + + stubTools := func(mr *rigtest.MockRunner) { + mr.AddCommandOutput(rigtest.Equal("command -v curl"), "/usr/bin/curl") + mr.AddCommandOutput(rigtest.Equal("command -v base64"), "/usr/bin/base64") + } + + t.Run("CR in header name rejected", func(t *testing.T) { + mr := rigtest.NewMockRunner() + stubTools(mr) + f := remotefs.NewPosixFS(mr) + req, _ := http.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Header["X-Bad\rName"] = []string{"value"} + _, err := f.RoundTrip(req) + require.Error(t, err) + }) + + t.Run("LF in header value rejected", func(t *testing.T) { + mr := rigtest.NewMockRunner() + stubTools(mr) + f := remotefs.NewPosixFS(mr) + req, _ := http.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Header["X-Custom"] = []string{"val\nue"} + _, err := f.RoundTrip(req) + require.Error(t, err) + }) + + t.Run("NUL in header value rejected", func(t *testing.T) { + mr := rigtest.NewMockRunner() + stubTools(mr) + f := remotefs.NewPosixFS(mr) + req, _ := http.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Header["X-Custom"] = []string{"val\x00ue"} + _, err := f.RoundTrip(req) + require.Error(t, err) + }) + + t.Run("req.Host injected as Host header", func(t *testing.T) { + mr := rigtest.NewMockRunner() + stubTools(mr) + mr.AddCommandOutput(rigtest.Contains("--http1.1"), okResp) + f := remotefs.NewPosixFS(mr) + req, _ := http.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Host = "override.example.com" + resp, err := f.RoundTrip(req) + require.NoError(t, err) + _ = resp.Body.Close() + require.Contains(t, mr.LastCommand(), "Host: override.example.com") + }) + + t.Run("Host in req.Header is skipped", func(t *testing.T) { + mr := rigtest.NewMockRunner() + stubTools(mr) + mr.AddCommandOutput(rigtest.Contains("--http1.1"), okResp) + f := remotefs.NewPosixFS(mr) + req, _ := http.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Header.Set("Host", "should-not-appear.example.com") + resp, err := f.RoundTrip(req) + require.NoError(t, err) + _ = resp.Body.Close() + require.NotContains(t, mr.LastCommand(), "should-not-appear.example.com") + }) + + t.Run("Cookie values joined with semicolon", func(t *testing.T) { + mr := rigtest.NewMockRunner() + stubTools(mr) + mr.AddCommandOutput(rigtest.Contains("--http1.1"), okResp) + f := remotefs.NewPosixFS(mr) + req, _ := http.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Header["Cookie"] = []string{"a=1", "b=2"} + resp, err := f.RoundTrip(req) + require.NoError(t, err) + _ = resp.Body.Close() + require.Contains(t, mr.LastCommand(), "Cookie: a=1; b=2") + }) +} diff --git a/remotefs/posixfs.go b/remotefs/posixfs.go index 2e51ded3..deb0b71b 100644 --- a/remotefs/posixfs.go +++ b/remotefs/posixfs.go @@ -7,10 +7,12 @@ import ( "fmt" "io" "io/fs" + "net/http" "os" "path" "strconv" "strings" + "sync" "time" "github.com/k0sproject/rig/v2/cmd" @@ -24,6 +26,8 @@ var ( _ FS = (*PosixFS)(nil) errInvalid = errors.New("invalid") errNoDownloadTool = errors.New("neither curl nor wget is available on the remote host") + errCurlRequired = errors.New("curl is not available on the remote host") + errBase64Required = errors.New("base64 is not available on the remote host") errGrepFailed = errors.New("grep failed") errTestFailed = errors.New("test failed") statCmdGNU = `env -i PATH="$PATH" LC_ALL=C stat -c '%%#f %%s %%.9Y //%%n//' -- %s 2> /dev/null` @@ -44,6 +48,9 @@ type PosixFS struct { statCmd *string chtimesFn func(name string, atime, mtime int64) error timeTrunc time.Duration + + httpToolsOnce sync.Once + httpToolsErr error } // NewPosixFS returns a fs.FS implementation for a remote filesystem that uses POSIX commands for access. @@ -414,6 +421,135 @@ func (s *PosixFS) DownloadURL(url, dst string) error { return fmt.Errorf("download %s: %w", url, errNoDownloadTool) } +// headerKeyExists reports whether h contains a key matching name, case-insensitively. +// http.Header keys assigned directly to the map may not be canonical, so +// http.Header.Get (which only looks up canonical keys) can miss them. +func headerKeyExists(h http.Header, name string) bool { + for k := range h { + if strings.EqualFold(k, name) { + return true + } + } + return false +} + +// curlHeaderArgs builds the -H flag list for curl from an http.Header map. +// Header["Host"] is always skipped (ignored in net/http client requests); host is injected +// when non-empty. Returns errInvalidHeader if any key or value contains CR, LF, or NUL. +func curlHeaderArgs(header http.Header, host string) ([]string, error) { + args := make([]string, 0, (len(header)+1)*2) + for name, vals := range header { + if strings.EqualFold(name, "Host") { + continue + } + if strings.ContainsAny(name, "\r\n\x00") { + return nil, fmt.Errorf("%w: %q", errInvalidHeader, name) + } + if len(vals) == 0 { + continue + } + if strings.EqualFold(name, "Cookie") { + joined := strings.Join(vals, "; ") + if strings.ContainsAny(joined, "\r\n\x00") { + return nil, fmt.Errorf("%w: %q", errInvalidHeader, name) + } + args = append(args, "-H", name+": "+joined) + } else { + for _, v := range vals { + if strings.ContainsAny(v, "\r\n\x00") { + return nil, fmt.Errorf("%w: %q", errInvalidHeader, name) + } + args = append(args, "-H", name+": "+v) + } + } + } + if host != "" { + if strings.ContainsAny(host, "\r\n\x00") { + return nil, fmt.Errorf("%w: %q", errInvalidHeader, "Host") + } + args = append(args, "-H", "Host: "+host) + } + return args, nil +} + +func (s *PosixFS) requireHTTPTools() error { + s.httpToolsOnce.Do(func() { + if _, err := s.LookPath("curl"); err != nil { + s.httpToolsErr = fmt.Errorf("%w: %w", errCurlRequired, err) + return + } + if _, err := s.LookPath("base64"); err != nil { + s.httpToolsErr = fmt.Errorf("%w: %w", errBase64Required, err) + } + }) + return s.httpToolsErr +} + +// buildCurlArgs constructs the curl arguments and execution options for the request. +// It does not verify that curl or base64 are available; callers are expected to +// perform any required tool checks before invoking it. It returns an error only +// if the request headers contain invalid values. +func buildCurlArgs(req *http.Request) ([]string, []cmd.ExecOption, error) { + args := []string{"curl", "-si", "--http1.1", "--raw"} + if !headerKeyExists(req.Header, "Expect") { + args = append(args, "-H", "Expect:") + } + method := req.Method + if method == "" { + method = http.MethodGet + } + hasBody := req.Body != nil && req.Body != http.NoBody + if method == http.MethodHead && !hasBody { + args = append(args, "-I") + } else { + args = append(args, "-X", method) + } + headerArgs, err := curlHeaderArgs(req.Header, req.Host) + if err != nil { + return nil, nil, err + } + args = append(args, headerArgs...) + var execOpts []cmd.ExecOption + if hasBody { + args = append(args, "--data-binary", "@-") + if !headerKeyExists(req.Header, "Content-Type") { + args = append(args, "-H", "Content-Type:") + } + execOpts = append(execOpts, cmd.Stdin(req.Body)) + } + return args, execOpts, nil +} + +// RoundTrip implements http.RoundTripper by executing the request via curl on the remote host. +func (s *PosixFS) RoundTrip(req *http.Request) (*http.Response, error) { + if req == nil { + return nil, errNilRequest + } + if req.URL == nil { + return nil, errNilRequestURL + } + if req.Body != nil && req.Body != http.NoBody { + defer req.Body.Close() + } + if err := validateRoundTripURL(req.URL); err != nil { + return nil, fmt.Errorf("http round-trip %s: %w", req.URL.Redacted(), err) + } + if err := s.requireHTTPTools(); err != nil { + return nil, fmt.Errorf("http round-trip %s: %w", req.URL.Redacted(), err) + } + args, execOpts, err := buildCurlArgs(req) + if err != nil { + return nil, fmt.Errorf("http round-trip %s: %w", req.URL.Redacted(), err) + } + curlCmd := sh.CommandBuilder(sh.Command(args[0], args[1:]...)).Raw("--").Arg(req.URL.String()) + script := `_t=$(mktemp "${TMPDIR:-/tmp}/rigXXXXXX") && ` + curlCmd.String() + ` > "$_t" && base64 < "$_t"; _e=$?; rm -f "$_t" 2>/dev/null; exit "$_e"` + out, err := s.ExecOutputContext(req.Context(), script, append(execOpts, cmd.HideOutput(), cmd.HideCommand())...) + if err != nil { + return nil, fmt.Errorf("http round-trip %s: %w", req.URL.Redacted(), err) + } + return parseRawHTTPResponse(out, req) +} + // FileContains reports whether the file at path contains the given substring. // Returns a not-exist error if the file does not exist. func (s *PosixFS) FileContains(name, substr string) (bool, error) { diff --git a/remotefs/types.go b/remotefs/types.go index 5974b8e2..056bc33c 100644 --- a/remotefs/types.go +++ b/remotefs/types.go @@ -17,6 +17,7 @@ type FS interface { fs.ReadFileFS fs.ReadDirFS OS + HTTPTransport Opener Sha256summer } @@ -53,7 +54,6 @@ type OS interface { //nolint:interfacebloat // intentionally large interface Truncate(path string, size int64) error Getenv(key string) string Rename(oldpath, newpath string) error - DownloadURL(url, dst string) error FileContains(path, substr string) (bool, error) IsContainer() (bool, error) Hostname() (string, error) diff --git a/remotefs/upload_test.go b/remotefs/upload_test.go index 261b77cb..0bc18157 100644 --- a/remotefs/upload_test.go +++ b/remotefs/upload_test.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "io" "io/fs" + "net/http" "os" "testing" "time" @@ -64,6 +65,7 @@ func (f *uploadFS) Truncate(_ string, _ int64) error { pa func (f *uploadFS) Getenv(_ string) string { panic("not implemented") } func (f *uploadFS) Rename(_, _ string) error { panic("not implemented") } func (f *uploadFS) DownloadURL(_ string, _ string) error { panic("not implemented") } +func (f *uploadFS) RoundTrip(_ *http.Request) (*http.Response, error) { panic("not implemented") } func (f *uploadFS) FileContains(_ string, _ string) (bool, error) { panic("not implemented") } func (f *uploadFS) IsContainer() (bool, error) { panic("not implemented") } func (f *uploadFS) Hostname() (string, error) { panic("not implemented") } diff --git a/remotefs/winfs.go b/remotefs/winfs.go index 6aaaaceb..fdb1273c 100644 --- a/remotefs/winfs.go +++ b/remotefs/winfs.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "io/fs" + "net/http" "os" "strconv" "strings" @@ -437,6 +438,115 @@ try { return nil } +// winHeaderScript builds the PowerShell $params['Headers'] assignment from an http.Header map. +// Header["Host"] is always skipped (ignored in net/http client requests); host is injected when +// non-empty. Returns errInvalidHeader if any key, value, or host contains CR, LF, or NUL. +func winHeaderScript(header http.Header, host string) (string, error) { + lines := make([]string, 0, len(header)+1) + for name, vals := range header { + if strings.EqualFold(name, "Host") { + continue + } + if strings.ContainsAny(name, "\r\n\x00") { + return "", fmt.Errorf("%w: %q", errInvalidHeader, name) + } + if len(vals) == 0 { + continue + } + sep := ", " + if strings.EqualFold(name, "Cookie") { + sep = "; " + } + joined := strings.Join(vals, sep) + if strings.ContainsAny(joined, "\r\n\x00") { + return "", fmt.Errorf("%w: %q", errInvalidHeader, name) + } + lines = append(lines, " "+ps.SingleQuote(name)+"="+ps.SingleQuote(joined)) + } + if host != "" { + if strings.ContainsAny(host, "\r\n\x00") { + return "", fmt.Errorf("%w: host %q", errInvalidHeader, host) + } + lines = append(lines, " "+ps.SingleQuote("Host")+"="+ps.SingleQuote(host)) + } + if len(lines) == 0 { + return "", nil + } + return "$params['Headers']=@{\n" + strings.Join(lines, "\n") + "\n}", nil +} + +// RoundTrip implements http.RoundTripper by executing the request via Invoke-WebRequest on the remote host. +func (s *WinFS) RoundTrip(req *http.Request) (*http.Response, error) { + if req == nil { + return nil, errNilRequest + } + if req.URL == nil { + return nil, errNilRequestURL + } + if req.Body != nil && req.Body != http.NoBody { + defer req.Body.Close() + } + if err := validateRoundTripURL(req.URL); err != nil { + return nil, fmt.Errorf("http round-trip %s: %w", req.URL.Redacted(), err) + } + method := req.Method + if method == "" { + method = http.MethodGet + } + headerScript, err := winHeaderScript(req.Header, req.Host) + if err != nil { + return nil, fmt.Errorf("http round-trip %s: %w", req.URL.Redacted(), err) + } + + var execOpts []cmd.ExecOption + bodyScript := "" + if req.Body != nil && req.Body != http.NoBody { + execOpts = append(execOpts, cmd.Stdin(req.Body)) + bodyScript = `$ms=[IO.MemoryStream]::new() +[Console]::OpenStandardInput().CopyTo($ms) +$params['Body']=$ms.ToArray() +$ms.Dispose()` + } + + script := fmt.Sprintf(`$ProgressPreference='SilentlyContinue' +$ErrorActionPreference='Stop' +function ConvertTo-RawResponse($code,$desc,$hdrs,$body){ + $crlf=[char]13+[char]10 + $head="HTTP/1.1 $code $desc"+$crlf + if($hdrs.PSObject.Properties['AllKeys']){$hdrs.AllKeys|ForEach-Object{$key=$_;if($key -ne 'Transfer-Encoding' -and $key -ne 'Content-Length'){$hdrs.GetValues($key)|ForEach-Object{$head+="${key}: $_"+$crlf}}}}else{$hdrs.GetEnumerator()|ForEach-Object{$key=$_.Key;if($key -ne 'Transfer-Encoding' -and $key -ne 'Content-Length'){@($_.Value)|ForEach-Object{$head+="${key}: $_"+$crlf}}}} + $head+="Content-Length: "+$body.Length+$crlf + $head+=$crlf + $hb=[Text.Encoding]::GetEncoding(28591).GetBytes($head) + $out=[byte[]]::new($hb.Length+$body.Length) + [Buffer]::BlockCopy($hb,0,$out,0,$hb.Length) + [Buffer]::BlockCopy($body,0,$out,$hb.Length,$body.Length) + [Convert]::ToBase64String($out) +} +$params=@{Uri=%s;Method=%s;UseBasicParsing=$true;ErrorAction='Stop';MaximumRedirection=0} +%s +%s +try{ + $r=Invoke-WebRequest @params + $b=if($r.PSObject.Properties['RawContentStream']){$r.RawContentStream.ToArray()}elseif($r.PSObject.Properties['BaseResponse']){$ms2=New-Object System.IO.MemoryStream;$r.BaseResponse.GetResponseStream().CopyTo($ms2);$ms2.ToArray()}else{[byte[]]$r.Content} + ConvertTo-RawResponse ([int]$r.StatusCode) $r.StatusDescription $r.Headers $b +}catch [System.Net.WebException]{ + if($_.Exception.Response -ne $null){ + $er=$_.Exception.Response + $ms=New-Object System.IO.MemoryStream + $er.GetResponseStream().CopyTo($ms) + $eh=@{} + $er.Headers.AllKeys|ForEach-Object{$eh[$_]=$er.Headers[$_]} + ConvertTo-RawResponse ([int]$er.StatusCode) $er.StatusDescription $eh $ms.ToArray() + }else{Write-Error $_.Exception.Message;exit 1} +}catch{Write-Error $_.Exception.Message;exit 1}`, ps.SingleQuote(req.URL.String()), ps.SingleQuote(method), headerScript, bodyScript) + + out, err := s.ExecOutputContext(req.Context(), script, append(execOpts, cmd.PS(), cmd.HideOutput(), cmd.HideCommand())...) + if err != nil { + return nil, fmt.Errorf("http round-trip %s: %w", req.URL.Redacted(), err) + } + return parseRawHTTPResponse(out, req) +} + // FileContains reports whether the file at path contains the given substring. // Returns a not-exist error if the file does not exist. func (s *WinFS) FileContains(name, substr string) (bool, error) {