From 0fd64e71e4db7075d4cc0dc7d28cf4066d9eb81a Mon Sep 17 00:00:00 2001 From: Dmitri Popov Date: Wed, 25 Feb 2026 10:49:51 +0100 Subject: [PATCH 1/6] Extract shared helpers and add env-gated network tests --- network_env_test.go | 257 +++++++++++++++++++++ tftp_test.go | 546 -------------------------------------------- util_test.go | 360 +++++++++++++++++++++++++++++ 3 files changed, 617 insertions(+), 546 deletions(-) create mode 100644 network_env_test.go create mode 100644 util_test.go diff --git a/network_env_test.go b/network_env_test.go new file mode 100644 index 0000000..eaf0e96 --- /dev/null +++ b/network_env_test.go @@ -0,0 +1,257 @@ +package tftp + +import ( + "bytes" + "fmt" + "io" + "math/rand" + "net" + "os" + "sync" + "testing" + "time" +) + +const networkEnvVar = "TFTP_RUN_NETWORK_ENV_TESTS" + +func requireNetworkEnv(t *testing.T) { + t.Helper() + if os.Getenv(networkEnvVar) != "1" { + t.Skipf("set %s=1 to run environment-dependent network tests", networkEnvVar) + } +} + +// TestRequestPacketInfo checks that request packet destination address +// obtained by server using out-of-band socket info is sane. +// It creates server and tries to do transfers using different local interfaces. +// NB: Test ignores transfer errors and validates RequestPacketInfo only +// if transfer is completed successfully. So it checks that LocalIP returns +// correct result if any result is returned, but does not check if result was +// returned at all when it should. +func TestRequestPacketInfo(t *testing.T) { + requireNetworkEnv(t) + + // localIP keeps value received from RequestPacketInfo.LocalIP + // call inside handler. + // If RequestPacketInfo is not supported, value is set to unspecified + // IP address. + var localIP net.IP + var localIPMu sync.Mutex + + s := NewServer( + func(_ string, rf io.ReaderFrom) error { + localIPMu.Lock() + if rpi, ok := rf.(RequestPacketInfo); ok { + localIP = rpi.LocalIP() + } else { + localIP = net.IP{} + } + localIPMu.Unlock() + _, err := rf.ReadFrom(io.LimitReader( + newRandReader(rand.NewSource(42)), 42)) + if err != nil { + t.Logf("sending to client: %v", err) + } + return nil + }, + func(_ string, wt io.WriterTo) error { + localIPMu.Lock() + if rpi, ok := wt.(RequestPacketInfo); ok { + localIP = rpi.LocalIP() + } else { + localIP = net.IP{} + } + localIPMu.Unlock() + _, err := wt.WriteTo(io.Discard) + if err != nil { + t.Logf("receiving from client: %v", err) + } + return nil + }, + ) + + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + t.Fatalf("listen UDP: %v", err) + } + + _, port, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + t.Fatalf("parsing server port: %v", err) + } + + // Start server + errChan := make(chan error, 1) + go func() { + err := s.Serve(conn) + if err != nil { + errChan <- fmt.Errorf("serve: %w", err) + } + }() + defer func() { + s.Shutdown() + select { + case err := <-errChan: + t.Errorf("server error: %v", err) + default: + } + }() + + addrs, err := net.InterfaceAddrs() + if err != nil { + t.Fatalf("listing interface addresses: %v", err) + } + + for _, addr := range addrs { + ip := networkIP(addr.(*net.IPNet)) + if ip == nil { + continue + } + + c, err := NewClient(net.JoinHostPort(ip.String(), port)) + if err != nil { + t.Fatalf("new client: %v", err) + } + + // Skip re-tries to skip non-routable interfaces faster + c.SetRetries(0) + + ot, err := c.Send("a", "octet") + if err != nil { + t.Logf("start sending to %v: %v", ip, err) + continue + } + _, err = ot.ReadFrom(io.LimitReader( + newRandReader(rand.NewSource(42)), 42)) + if err != nil { + t.Logf("sending to %v: %v", ip, err) + continue + } + + // Check that read handler received IP that was used + // to create the client. + localIPMu.Lock() + if localIP != nil && !localIP.IsUnspecified() { // Skip check if no packet info + if !localIP.Equal(ip) { + t.Errorf("sent to: %v, request packet: %v", ip, localIP) + } + } else { + fmt.Printf("Skip %v\n", ip) + } + localIPMu.Unlock() + + it, err := c.Receive("a", "octet") + if err != nil { + t.Logf("start receiving from %v: %v", ip, err) + continue + } + _, err = it.WriteTo(io.Discard) + if err != nil { + t.Logf("receiving from %v: %v", ip, err) + continue + } + + // Check that write handler received IP that was used + // to create the client. + localIPMu.Lock() + if localIP != nil && !localIP.IsUnspecified() { // Skip check if no packet info + if !localIP.Equal(ip) { + t.Errorf("sent to: %v, request packet: %v", ip, localIP) + } + } else { + fmt.Printf("Skip %v\n", ip) + } + localIPMu.Unlock() + + fmt.Printf("Done %v\n", ip) + } +} + +func networkIP(n *net.IPNet) net.IP { + if ip := n.IP.To4(); ip != nil { + return ip + } + if len(n.IP) == net.IPv6len { + return n.IP + } + return nil +} + +func TestSetLocalAddr(t *testing.T) { + requireNetworkEnv(t) + + interfaces, err := net.Interfaces() + if err != nil { + t.Fatalf("failed to get network interfaces: %v", err) + } + + var addrs []net.IP + for _, i := range interfaces { + interfaceAddrs, err := i.Addrs() + if err != nil { + continue + } + for _, addr := range interfaceAddrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + if ip == nil || ip.IsLoopback() || ip.IsLinkLocalUnicast() { + continue + } + addrs = append(addrs, ip) + } + } + + for _, addrA := range addrs { + for _, addrB := range addrs { + t.Run(fmt.Sprintf("%s receives from %s", addrA.String(), addrB.String()), func(t *testing.T) { + var mu sync.Mutex + var remoteAddr net.UDPAddr + s := NewServer(nil, func(filename string, wt io.WriterTo) error { + _, err := wt.WriteTo(io.Discard) + if err != nil { + return err + } + mu.Lock() + defer mu.Unlock() + remoteAddr = wt.(IncomingTransfer).RemoteAddr() + return nil + }) + s.SetTimeout(2 * time.Second) + + conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: addrA, Port: 0}) + if err != nil { + t.Fatalf("listen udp: %v", err) + } + + go s.Serve(conn) + defer s.Shutdown() + + c, err := NewClient(conn.LocalAddr().String()) + if err != nil { + t.Fatalf("creating client: %v", err) + } + c.SetLocalAddr(addrB.String()) + + wt, err := c.Send("testfile", "octet") + if err != nil { + return + } + _, _ = wt.ReadFrom(bytes.NewReader([]byte("test data for cross-interface"))) + + // Compare addrB to remoteAddr in thread-safe manner + mu.Lock() + defer mu.Unlock() + if remoteAddr.IP == nil { + t.Error("remote address was not captured from server") + } else if !remoteAddr.IP.Equal(addrB) { + t.Errorf("remote address mismatch: expected %s, got %s", addrB.String(), remoteAddr.IP.String()) + } + }) + } + } +} diff --git a/tftp_test.go b/tftp_test.go index ed61851..c2b25a2 100644 --- a/tftp_test.go +++ b/tftp_test.go @@ -7,58 +7,11 @@ import ( "io" "math/rand" "net" - "os" - "strconv" - "sync" "testing" "testing/iotest" "time" ) -var localhost = determineLocalhost() - -func determineLocalhost() string { - l, err := net.ListenTCP("tcp", nil) - if err != nil { - panic(fmt.Sprintf("ListenTCP error: %s", err)) - } - _, lport, _ := net.SplitHostPort(l.Addr().String()) - defer l.Close() - - lo := make(chan string) - - go func() { - for { - conn, err := l.Accept() - if err != nil { - break - } - conn.Close() - } - }() - - go func() { - port, _ := strconv.Atoi(lport) - for _, af := range []string{"tcp6", "tcp4"} { - conn, err := net.DialTCP(af, &net.TCPAddr{}, &net.TCPAddr{Port: port}) - if err == nil { - conn.Close() - host, _, _ := net.SplitHostPort(conn.LocalAddr().String()) - lo <- host - return - } - } - panic("could not determine address family") - }() - - return <-lo -} - -func localSystem(c *net.UDPConn) string { - _, port, _ := net.SplitHostPort(c.LocalAddr().String()) - return net.JoinHostPort(localhost, port) -} - func TestPackUnpack(t *testing.T) { v := []string{"test-filename/with-subdir"} testOptsList := []options{ @@ -148,30 +101,6 @@ func Test1810(t *testing.T) { testSendReceive(t, c, 9000+1810) } -type testHook struct { - *sync.Mutex - transfersCompleted int - transfersFailed int -} - -func newTestHook() *testHook { - return &testHook{ - Mutex: &sync.Mutex{}, - } -} - -func (h *testHook) OnSuccess(result TransferStats) { - h.Lock() - defer h.Unlock() - h.transfersCompleted++ -} - -func (h *testHook) OnFailure(result TransferStats, err error) { - h.Lock() - defer h.Unlock() - h.transfersFailed++ -} - func TestHookSuccess(t *testing.T) { s, c := makeTestServer(false) th := newTestHook() @@ -329,48 +258,6 @@ func TestNotFound(t *testing.T) { t.Logf("receiving file that does not exist: %v", err) } -func testSendReceive(t *testing.T, client *Client, length int64) { - filename := fmt.Sprintf("length-%d-bytes", length) - mode := "octet" - writeTransfer, err := client.Send(filename, mode) - if err != nil { - t.Fatalf("requesting write %s: %v", filename, err) - } - r := io.LimitReader(newRandReader(rand.NewSource(42)), length) - n, err := writeTransfer.ReadFrom(r) - if err != nil { - t.Fatalf("%s write error: %v", filename, err) - } - if n != length { - t.Errorf("%s write length mismatch: %d != %d", filename, n, length) - } - readTransfer, err := client.Receive(filename, mode) - if err != nil { - t.Fatalf("requesting read %s: %v", filename, err) - } - if it, ok := readTransfer.(IncomingTransfer); ok { - if n, ok := it.Size(); ok { - fmt.Printf("Transfer size: %d\n", n) - if n != length { - t.Errorf("tsize mismatch: %d vs %d", n, length) - } - } - } - buf := &bytes.Buffer{} - n, err = readTransfer.WriteTo(buf) - if err != nil { - t.Fatalf("%s read error: %v", filename, err) - } - if n != length { - t.Errorf("%s read length mismatch: %d != %d", filename, n, length) - } - bs, _ := io.ReadAll(io.LimitReader( - newRandReader(rand.NewSource(42)), length)) - if !bytes.Equal(bs, buf.Bytes()) { - t.Errorf("\nsent: %x\nrcvd: %x", bs, buf) - } -} - func TestSendTsizeFromSeek(t *testing.T) { // create read-only server s := NewServer(func(filename string, rf io.ReaderFrom) error { @@ -423,39 +310,6 @@ func TestSendTsizeFromSeek(t *testing.T) { r.WriteTo(io.Discard) } -type testBackend struct { - m map[string][]byte - mu sync.Mutex -} - -func makeTestServer(singlePort bool) (*Server, *Client) { - b := &testBackend{} - b.m = make(map[string][]byte) - - // Create server - s := NewServer(b.handleRead, b.handleWrite) - - if singlePort { - s.SetBlockSize(2000) - s.EnableSinglePort() - } - - conn, err := net.ListenUDP("udp", &net.UDPAddr{}) - if err != nil { - panic(err) - } - - go s.Serve(conn) - - // Create client for that server - c, err := NewClient(localSystem(conn)) - if err != nil { - panic(err) - } - - return s, c -} - func TestNoHandlers(t *testing.T) { s := NewServer(nil, nil) @@ -482,151 +336,11 @@ func TestNoHandlers(t *testing.T) { } } -func (b *testBackend) handleWrite(filename string, wt io.WriterTo) error { - b.mu.Lock() - defer b.mu.Unlock() - _, ok := b.m[filename] - if ok { - fmt.Fprintf(os.Stderr, "File %s already exists\n", filename) - return fmt.Errorf("file already exists") - } - if t, ok := wt.(IncomingTransfer); ok { - if n, ok := t.Size(); ok { - fmt.Printf("Transfer size: %d\n", n) - } - } - buf := &bytes.Buffer{} - _, err := wt.WriteTo(buf) - if err != nil { - fmt.Fprintf(os.Stderr, "Can't receive %s: %v\n", filename, err) - return err - } - b.m[filename] = buf.Bytes() - return nil -} - -func (b *testBackend) handleRead(filename string, rf io.ReaderFrom) error { - b.mu.Lock() - defer b.mu.Unlock() - bs, ok := b.m[filename] - if !ok { - fmt.Fprintf(os.Stderr, "File %s not found\n", filename) - return fmt.Errorf("file not found") - } - if t, ok := rf.(OutgoingTransfer); ok { - t.SetSize(int64(len(bs))) - } - _, err := rf.ReadFrom(bytes.NewBuffer(bs)) - if err != nil { - fmt.Fprintf(os.Stderr, "Can't send %s: %v\n", filename, err) - return err - } - return nil -} - -type randReader struct { - src rand.Source - next int64 - i int8 -} - -func newRandReader(src rand.Source) io.Reader { - r := &randReader{ - src: src, - next: src.Int63(), - } - return r -} - -func (r *randReader) Read(p []byte) (n int, err error) { - next, i := r.next, r.i - for n = 0; n < len(p); n++ { - if i == 7 { - next, i = r.src.Int63(), 0 - } - p[n] = byte(next) - next >>= 8 - i++ - } - r.next, r.i = next, i - return -} - -func serverTimeoutSendTest(s *Server, c *Client, t *testing.T) { - s.SetTimeout(time.Second) - s.SetRetries(2) - sec := make(chan error, 1) - s.mu.Lock() - s.readHandler = func(filename string, rf io.ReaderFrom) error { - r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000) - _, err := rf.ReadFrom(r) - sec <- err - return err - } - s.mu.Unlock() - defer s.Shutdown() - filename := "test-server-send-timeout" - mode := "octet" - readTransfer, err := c.Receive(filename, mode) - if err != nil { - t.Fatalf("requesting read %s: %v", filename, err) - } - w := &slowWriter{ - n: 3, - delay: 8 * time.Second, - } - _, _ = readTransfer.WriteTo(w) - servErr := <-sec - netErr, ok := servErr.(net.Error) - if !ok { - t.Fatalf("network error expected: %T", servErr) - } - if !netErr.Timeout() { - t.Fatalf("timout is expected: %v", servErr) - } - -} - func TestServerSendTimeout(t *testing.T) { s, c := makeTestServer(false) serverTimeoutSendTest(s, c, t) } -func serverReceiveTimeoutTest(s *Server, c *Client, t *testing.T) { - s.SetTimeout(time.Second) - s.SetRetries(2) - sec := make(chan error, 1) - s.mu.Lock() - s.writeHandler = func(filename string, wt io.WriterTo) error { - buf := &bytes.Buffer{} - _, err := wt.WriteTo(buf) - sec <- err - return err - } - s.mu.Unlock() - defer s.Shutdown() - filename := "test-server-receive-timeout" - mode := "octet" - writeTransfer, err := c.Send(filename, mode) - if err != nil { - t.Fatalf("requesting write %s: %v", filename, err) - } - r := &slowReader{ - r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000), - n: 3, - delay: 8 * time.Second, - } - _, _ = writeTransfer.ReadFrom(r) - servErr := <-sec - netErr, ok := servErr.(net.Error) - if !ok { - t.Fatalf("network error expected: %T", servErr) - } - if !netErr.Timeout() { - t.Fatalf("timout is expected: %v", servErr) - } -} - func TestServerReceiveTimeout(t *testing.T) { s, c := makeTestServer(false) serverReceiveTimeoutTest(s, c, t) @@ -697,189 +411,6 @@ func TestClientSendTimeout(t *testing.T) { } } -type slowReader struct { - r io.Reader - n int64 - delay time.Duration -} - -func (r *slowReader) Read(p []byte) (n int, err error) { - if r.n > 0 { - r.n-- - return r.r.Read(p) - } - time.Sleep(r.delay) - return r.r.Read(p) -} - -type slowWriter struct { - n int64 - delay time.Duration -} - -func (r *slowWriter) Write(p []byte) (n int, err error) { - if r.n > 0 { - r.n-- - return len(p), nil - } - time.Sleep(r.delay) - return len(p), nil -} - -// TestRequestPacketInfo checks that request packet destination address -// obtained by server using out-of-band socket info is sane. -// It creates server and tries to do transfers using different local interfaces. -// NB: Test ignores transfer errors and validates RequestPacketInfo only -// if transfer is completed successfully. So it checks that LocalIP returns -// correct result if any result is returned, but does not check if result was -// returned at all when it should. -func TestRequestPacketInfo(t *testing.T) { - // localIP keeps value received from RequestPacketInfo.LocalIP - // call inside handler. - // If RequestPacketInfo is not supported, value is set to unspecified - // IP address. - var localIP net.IP - var localIPMu sync.Mutex - - s := NewServer( - func(_ string, rf io.ReaderFrom) error { - localIPMu.Lock() - if rpi, ok := rf.(RequestPacketInfo); ok { - localIP = rpi.LocalIP() - } else { - localIP = net.IP{} - } - localIPMu.Unlock() - _, err := rf.ReadFrom(io.LimitReader( - newRandReader(rand.NewSource(42)), 42)) - if err != nil { - t.Logf("sending to client: %v", err) - } - return nil - }, - func(_ string, wt io.WriterTo) error { - localIPMu.Lock() - if rpi, ok := wt.(RequestPacketInfo); ok { - localIP = rpi.LocalIP() - } else { - localIP = net.IP{} - } - localIPMu.Unlock() - _, err := wt.WriteTo(io.Discard) - if err != nil { - t.Logf("receiving from client: %v", err) - } - return nil - }, - ) - - conn, err := net.ListenUDP("udp", &net.UDPAddr{}) - if err != nil { - t.Fatalf("listen UDP: %v", err) - } - - _, port, err := net.SplitHostPort(conn.LocalAddr().String()) - if err != nil { - t.Fatalf("parsing server port: %v", err) - } - - // Start server - errChan := make(chan error, 1) - go func() { - err := s.Serve(conn) - if err != nil { - errChan <- fmt.Errorf("serve: %w", err) - } - }() - defer func() { - s.Shutdown() - select { - case err := <-errChan: - t.Errorf("server error: %v", err) - default: - } - }() - - addrs, err := net.InterfaceAddrs() - if err != nil { - t.Fatalf("listing interface addresses: %v", err) - } - - for _, addr := range addrs { - ip := networkIP(addr.(*net.IPNet)) - if ip == nil { - continue - } - - c, err := NewClient(net.JoinHostPort(ip.String(), port)) - if err != nil { - t.Fatalf("new client: %v", err) - } - - // Skip re-tries to skip non-routable interfaces faster - c.SetRetries(0) - - ot, err := c.Send("a", "octet") - if err != nil { - t.Logf("start sending to %v: %v", ip, err) - continue - } - _, err = ot.ReadFrom(io.LimitReader( - newRandReader(rand.NewSource(42)), 42)) - if err != nil { - t.Logf("sending to %v: %v", ip, err) - continue - } - - // Check that read handler received IP that was used - // to create the client. - localIPMu.Lock() - if localIP != nil && !localIP.IsUnspecified() { // Skip check if no packet info - if !localIP.Equal(ip) { - t.Errorf("sent to: %v, request packet: %v", ip, localIP) - } - } else { - fmt.Printf("Skip %v\n", ip) - } - localIPMu.Unlock() - - it, err := c.Receive("a", "octet") - if err != nil { - t.Logf("start receiving from %v: %v", ip, err) - continue - } - _, err = it.WriteTo(io.Discard) - if err != nil { - t.Logf("receiving from %v: %v", ip, err) - continue - } - - // Check that write handler received IP that was used - // to create the client. - localIPMu.Lock() - if localIP != nil && !localIP.IsUnspecified() { // Skip check if no packet info - if !localIP.Equal(ip) { - t.Errorf("sent to: %v, request packet: %v", ip, localIP) - } - } else { - fmt.Printf("Skip %v\n", ip) - } - localIPMu.Unlock() - - fmt.Printf("Done %v\n", ip) - } -} - -func networkIP(n *net.IPNet) net.IP { - if ip := n.IP.To4(); ip != nil { - return ip - } - if len(n.IP) == net.IPv6len { - return n.IP - } - return nil -} - // TestFileIOExceptions checks that errors returned by io.Reader or io.Writer used by // the handler are handled correctly. func TestReadWriteErrors(t *testing.T) { @@ -1075,80 +606,3 @@ func testShutdownDuringTransfer(t *testing.T, singlePort bool) { t.Error("client did not finish in time") } } - -func TestSetLocalAddr(t *testing.T) { - interfaces, err := net.Interfaces() - if err != nil { - t.Fatalf("failed to get network interfaces: %v", err) - } - - var addrs []net.IP - for _, i := range interfaces { - interfaceAddrs, err := i.Addrs() - if err != nil { - continue - } - for _, addr := range interfaceAddrs { - var ip net.IP - switch v := addr.(type) { - case *net.IPNet: - ip = v.IP - case *net.IPAddr: - ip = v.IP - } - if ip == nil || ip.IsLoopback() || ip.IsLinkLocalUnicast() { - continue - } - addrs = append(addrs, ip) - } - } - - for _, addrA := range addrs { - for _, addrB := range addrs { - t.Run(fmt.Sprintf("%s receives from %s", addrA.String(), addrB.String()), func(t *testing.T) { - var mu sync.Mutex - var remoteAddr net.UDPAddr - s := NewServer(nil, func(filename string, wt io.WriterTo) error { - _, err := wt.WriteTo(io.Discard) - if err != nil { - return err - } - mu.Lock() - defer mu.Unlock() - remoteAddr = wt.(IncomingTransfer).RemoteAddr() - return nil - }) - s.SetTimeout(2 * time.Second) - - conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: addrA, Port: 0}) - if err != nil { - panic(err) - } - - go s.Serve(conn) - defer s.Shutdown() - - c, err := NewClient(conn.LocalAddr().String()) - if err != nil { - t.Fatalf("creating client: %v", err) - } - c.SetLocalAddr(addrB.String()) - - wt, err := c.Send("testfile", "octet") - if err != nil { - return - } - _, _ = wt.ReadFrom(bytes.NewReader([]byte("test data for cross-interface"))) - - // Compare addrB to remoteAddr in thread-safe manner - mu.Lock() - defer mu.Unlock() - if remoteAddr.IP == nil { - t.Error("remote address was not captured from server") - } else if !remoteAddr.IP.Equal(addrB) { - t.Errorf("remote address mismatch: expected %s, got %s", addrB.String(), remoteAddr.IP.String()) - } - }) - } - } -} diff --git a/util_test.go b/util_test.go new file mode 100644 index 0000000..7c40d10 --- /dev/null +++ b/util_test.go @@ -0,0 +1,360 @@ +package tftp + +import ( + "bytes" + "fmt" + "io" + "math/rand" + "net" + "os" + "strconv" + "sync" + "testing" + "time" +) + +// Shared test fixtures/helpers only. Do not add test cases here. + +type transferMode string + +const ( + modeRegular transferMode = "regular" + modeSinglePort transferMode = "single_port" +) + +func forModes(t *testing.T, fn func(t *testing.T, mode transferMode)) { + t.Helper() + for _, mode := range []transferMode{modeRegular, modeSinglePort} { + mode := mode + t.Run(string(mode), func(t *testing.T) { + fn(t, mode) + }) + } +} + +func newFixture(t *testing.T, mode transferMode) (*Server, *Client) { + t.Helper() + switch mode { + case modeRegular: + return makeTestServer(false) + case modeSinglePort: + return makeTestServer(true) + default: + t.Fatalf("unknown transfer mode: %q", mode) + return nil, nil + } +} + +var localhost = determineLocalhost() + +func determineLocalhost() string { + l, err := net.ListenTCP("tcp", nil) + if err != nil { + panic(fmt.Sprintf("ListenTCP error: %s", err)) + } + _, lport, _ := net.SplitHostPort(l.Addr().String()) + defer l.Close() + + lo := make(chan string) + + go func() { + for { + conn, err := l.Accept() + if err != nil { + break + } + conn.Close() + } + }() + + go func() { + port, _ := strconv.Atoi(lport) + for _, af := range []string{"tcp6", "tcp4"} { + conn, err := net.DialTCP(af, &net.TCPAddr{}, &net.TCPAddr{Port: port}) + if err == nil { + conn.Close() + host, _, _ := net.SplitHostPort(conn.LocalAddr().String()) + lo <- host + return + } + } + panic("could not determine address family") + }() + + return <-lo +} + +func localSystem(c *net.UDPConn) string { + _, port, _ := net.SplitHostPort(c.LocalAddr().String()) + return net.JoinHostPort(localhost, port) +} + +type testHook struct { + *sync.Mutex + transfersCompleted int + transfersFailed int +} + +func newTestHook() *testHook { + return &testHook{ + Mutex: &sync.Mutex{}, + } +} + +func (h *testHook) OnSuccess(result TransferStats) { + h.Lock() + defer h.Unlock() + h.transfersCompleted++ +} + +func (h *testHook) OnFailure(result TransferStats, err error) { + h.Lock() + defer h.Unlock() + h.transfersFailed++ +} + +func testSendReceive(t *testing.T, client *Client, length int64) { + t.Helper() + filename := fmt.Sprintf("length-%d-bytes", length) + mode := "octet" + writeTransfer, err := client.Send(filename, mode) + if err != nil { + t.Fatalf("requesting write %s: %v", filename, err) + } + r := io.LimitReader(newRandReader(rand.NewSource(42)), length) + n, err := writeTransfer.ReadFrom(r) + if err != nil { + t.Fatalf("%s write error: %v", filename, err) + } + if n != length { + t.Errorf("%s write length mismatch: %d != %d", filename, n, length) + } + readTransfer, err := client.Receive(filename, mode) + if err != nil { + t.Fatalf("requesting read %s: %v", filename, err) + } + if it, ok := readTransfer.(IncomingTransfer); ok { + if n, ok := it.Size(); ok { + fmt.Printf("Transfer size: %d\n", n) + if n != length { + t.Errorf("tsize mismatch: %d vs %d", n, length) + } + } + } + buf := &bytes.Buffer{} + n, err = readTransfer.WriteTo(buf) + if err != nil { + t.Fatalf("%s read error: %v", filename, err) + } + if n != length { + t.Errorf("%s read length mismatch: %d != %d", filename, n, length) + } + bs, _ := io.ReadAll(io.LimitReader( + newRandReader(rand.NewSource(42)), length)) + if !bytes.Equal(bs, buf.Bytes()) { + t.Errorf("\nsent: %x\nrcvd: %x", bs, buf) + } +} + +type testBackend struct { + m map[string][]byte + mu sync.Mutex +} + +func makeTestServer(singlePort bool) (*Server, *Client) { + b := &testBackend{} + b.m = make(map[string][]byte) + + // Create server + s := NewServer(b.handleRead, b.handleWrite) + + if singlePort { + s.SetBlockSize(2000) + s.EnableSinglePort() + } + + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + panic(err) + } + + go s.Serve(conn) + + // Create client for that server + c, err := NewClient(localSystem(conn)) + if err != nil { + panic(err) + } + + return s, c +} + +func (b *testBackend) handleWrite(filename string, wt io.WriterTo) error { + b.mu.Lock() + defer b.mu.Unlock() + _, ok := b.m[filename] + if ok { + fmt.Fprintf(os.Stderr, "File %s already exists\n", filename) + return fmt.Errorf("file already exists") + } + if t, ok := wt.(IncomingTransfer); ok { + if n, ok := t.Size(); ok { + fmt.Printf("Transfer size: %d\n", n) + } + } + buf := &bytes.Buffer{} + _, err := wt.WriteTo(buf) + if err != nil { + fmt.Fprintf(os.Stderr, "Can't receive %s: %v\n", filename, err) + return err + } + b.m[filename] = buf.Bytes() + return nil +} + +func (b *testBackend) handleRead(filename string, rf io.ReaderFrom) error { + b.mu.Lock() + defer b.mu.Unlock() + bs, ok := b.m[filename] + if !ok { + fmt.Fprintf(os.Stderr, "File %s not found\n", filename) + return fmt.Errorf("file not found") + } + if t, ok := rf.(OutgoingTransfer); ok { + t.SetSize(int64(len(bs))) + } + _, err := rf.ReadFrom(bytes.NewBuffer(bs)) + if err != nil { + fmt.Fprintf(os.Stderr, "Can't send %s: %v\n", filename, err) + return err + } + return nil +} + +type randReader struct { + src rand.Source + next int64 + i int8 +} + +func newRandReader(src rand.Source) io.Reader { + r := &randReader{ + src: src, + next: src.Int63(), + } + return r +} + +func (r *randReader) Read(p []byte) (n int, err error) { + next, i := r.next, r.i + for n = 0; n < len(p); n++ { + if i == 7 { + next, i = r.src.Int63(), 0 + } + p[n] = byte(next) + next >>= 8 + i++ + } + r.next, r.i = next, i + return +} + +func serverTimeoutSendTest(s *Server, c *Client, t *testing.T) { + t.Helper() + s.SetTimeout(time.Second) + s.SetRetries(2) + sec := make(chan error, 1) + s.mu.Lock() + s.readHandler = func(filename string, rf io.ReaderFrom) error { + r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000) + _, err := rf.ReadFrom(r) + sec <- err + return err + } + s.mu.Unlock() + defer s.Shutdown() + filename := "test-server-send-timeout" + mode := "octet" + readTransfer, err := c.Receive(filename, mode) + if err != nil { + t.Fatalf("requesting read %s: %v", filename, err) + } + w := &slowWriter{ + n: 3, + delay: 8 * time.Second, + } + _, _ = readTransfer.WriteTo(w) + servErr := <-sec + netErr, ok := servErr.(net.Error) + if !ok { + t.Fatalf("network error expected: %T", servErr) + } + if !netErr.Timeout() { + t.Fatalf("timout is expected: %v", servErr) + } +} + +func serverReceiveTimeoutTest(s *Server, c *Client, t *testing.T) { + t.Helper() + s.SetTimeout(time.Second) + s.SetRetries(2) + sec := make(chan error, 1) + s.mu.Lock() + s.writeHandler = func(filename string, wt io.WriterTo) error { + buf := &bytes.Buffer{} + _, err := wt.WriteTo(buf) + sec <- err + return err + } + s.mu.Unlock() + defer s.Shutdown() + filename := "test-server-receive-timeout" + mode := "octet" + writeTransfer, err := c.Send(filename, mode) + if err != nil { + t.Fatalf("requesting write %s: %v", filename, err) + } + r := &slowReader{ + r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000), + n: 3, + delay: 8 * time.Second, + } + _, _ = writeTransfer.ReadFrom(r) + servErr := <-sec + netErr, ok := servErr.(net.Error) + if !ok { + t.Fatalf("network error expected: %T", servErr) + } + if !netErr.Timeout() { + t.Fatalf("timout is expected: %v", servErr) + } +} + +type slowReader struct { + r io.Reader + n int64 + delay time.Duration +} + +func (r *slowReader) Read(p []byte) (n int, err error) { + if r.n > 0 { + r.n-- + return r.r.Read(p) + } + time.Sleep(r.delay) + return r.r.Read(p) +} + +type slowWriter struct { + n int64 + delay time.Duration +} + +func (r *slowWriter) Write(p []byte) (n int, err error) { + if r.n > 0 { + r.n-- + return len(p), nil + } + time.Sleep(r.delay) + return len(p), nil +} From 3b81a38738b53efb9effd520ebedd45a5e6c2ff1 Mon Sep 17 00:00:00 2001 From: Dmitri Popov Date: Wed, 25 Feb 2026 11:02:25 +0100 Subject: [PATCH 2/6] Split tests by functionality --- tftp_anticipate_test.go => anticipate_test.go | 0 api_error_test.go | 159 +++++ blocksize_test.go | 28 + hooks_test.go | 55 ++ lifecycle_test.go | 193 ++++++ packet_test.go | 49 ++ tftp_test.go | 608 ------------------ transfer_test.go | 99 +++ tsize_test.go | 69 ++ 9 files changed, 652 insertions(+), 608 deletions(-) rename tftp_anticipate_test.go => anticipate_test.go (100%) create mode 100644 api_error_test.go create mode 100644 blocksize_test.go create mode 100644 hooks_test.go create mode 100644 lifecycle_test.go create mode 100644 packet_test.go delete mode 100644 tftp_test.go create mode 100644 transfer_test.go create mode 100644 tsize_test.go diff --git a/tftp_anticipate_test.go b/anticipate_test.go similarity index 100% rename from tftp_anticipate_test.go rename to anticipate_test.go diff --git a/api_error_test.go b/api_error_test.go new file mode 100644 index 0000000..dc443ed --- /dev/null +++ b/api_error_test.go @@ -0,0 +1,159 @@ +package tftp + +import ( + "bytes" + "errors" + "fmt" + "io" + "math/rand" + "net" + "testing" +) + +func TestDuplicate(t *testing.T) { + s, c := makeTestServer(false) + defer s.Shutdown() + filename := "test-duplicate" + mode := "octet" + bs := []byte("lalala") + sender, err := c.Send(filename, mode) + if err != nil { + t.Fatalf("requesting write: %v", err) + } + buf := bytes.NewBuffer(bs) + _, err = sender.ReadFrom(buf) + if err != nil { + t.Fatalf("send error: %v", err) + } + _, err = c.Send(filename, mode) + if err == nil { + t.Fatalf("file already exists") + } + t.Logf("sending file that already exists: %v", err) +} + +func TestNotFound(t *testing.T) { + s, c := makeTestServer(false) + defer s.Shutdown() + filename := "test-not-exists" + mode := "octet" + _, err := c.Receive(filename, mode) + if err == nil { + t.Fatalf("file not exists: %v", err) + } + t.Logf("receiving file that does not exist: %v", err) +} + +func TestNoHandlers(t *testing.T) { + s := NewServer(nil, nil) + + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + panic(err) + } + + go s.Serve(conn) + + c, err := NewClient(localSystem(conn)) + if err != nil { + panic(err) + } + + _, err = c.Send("test", "octet") + if err == nil { + t.Errorf("error expected") + } + + _, err = c.Receive("test", "octet") + if err == nil { + t.Errorf("error expected") + } +} + +// TestFileIOExceptions checks that errors returned by io.Reader or io.Writer used by +// the handler are handled correctly. +func TestReadWriteErrors(t *testing.T) { + s := NewServer( + func(_ string, rf io.ReaderFrom) error { + _, err := rf.ReadFrom(&failingReader{}) // Read operation fails immediately. + if err != errRead { + t.Errorf("want: %v, got: %v", errRead, err) + } + // return no error from handler, client still should receive error + return nil + }, + func(_ string, wt io.WriterTo) error { + _, err := wt.WriteTo(&failingWriter{}) // Write operation fails immediately. + if err != errWrite { + t.Errorf("want: %v, got: %v", errWrite, err) + } + // return no error from handler, client still should receive error + return nil + }, + ) + + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + t.Fatalf("listen UDP: %v", err) + } + + _, port, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + t.Fatalf("parsing server port: %v", err) + } + + // Start server + errChan := make(chan error, 1) + go func() { + err := s.Serve(conn) + if err != nil { + errChan <- fmt.Errorf("running serve: %w", err) + } + }() + defer func() { + s.Shutdown() + select { + case err := <-errChan: + t.Errorf("server error: %v", err) + default: + } + }() + + // Create client + c, err := NewClient(net.JoinHostPort(localhost, port)) + if err != nil { + t.Fatalf("creating new client: %v", err) + } + + ot, err := c.Send("a", "octet") + if err != nil { + t.Errorf("start sending: %v", err) + } + + _, err = ot.ReadFrom(io.LimitReader( + newRandReader(rand.NewSource(42)), 42)) + if err == nil { + t.Errorf("missing write error") + } + + _, err = c.Receive("a", "octet") + if err == nil { + t.Errorf("missing read error") + } +} + +type failingReader struct{} + +var errRead = errors.New("read error") + +func (r *failingReader) Read(_ []byte) (int, error) { + return 0, errRead +} + +type failingWriter struct{} + +var errWrite = errors.New("write error") + +func (r *failingWriter) Write(_ []byte) (int, error) { + return 0, errWrite +} diff --git a/blocksize_test.go b/blocksize_test.go new file mode 100644 index 0000000..15f733f --- /dev/null +++ b/blocksize_test.go @@ -0,0 +1,28 @@ +package tftp + +import "testing" + +func Test900(t *testing.T) { + s, c := makeTestServer(false) + defer s.Shutdown() + for i := 600; i < 4000; i++ { + c.SetBlockSize(i) + s.SetBlockSize(4600 - i) + testSendReceive(t, c, 9000+int64(i)) + } +} + +func Test1810(t *testing.T) { + s, c := makeTestServer(false) + defer s.Shutdown() + c.SetBlockSize(1810) + testSendReceive(t, c, 9000+1810) +} + +func TestNearBlockLength(t *testing.T) { + s, c := makeTestServer(false) + defer s.Shutdown() + for i := 450; i < 520; i++ { + testSendReceive(t, c, int64(i)) + } +} diff --git a/hooks_test.go b/hooks_test.go new file mode 100644 index 0000000..bb0d4d4 --- /dev/null +++ b/hooks_test.go @@ -0,0 +1,55 @@ +package tftp + +import ( + "fmt" + "io" + "math/rand" + "testing" + "time" +) + +func TestHookSuccess(t *testing.T) { + s, c := makeTestServer(false) + th := newTestHook() + s.SetHook(th) + c.SetBlockSize(1810) + length := int64(9000) + filename := fmt.Sprintf("length-%d-bytes-%d", length, time.Now().UnixNano()) + rf, err := c.Send(filename, "octet") + if err != nil { + t.Fatalf("requesting %s write: %v", filename, err) + } + r := io.LimitReader(newRandReader(rand.NewSource(length)), length) + n, err := rf.ReadFrom(r) + if err != nil { + t.Fatalf("sending %s: %v", filename, err) + } + if n != length { + t.Errorf("%s length mismatch: %d != %d", filename, n, length) + } + s.Shutdown() + th.Lock() + defer th.Unlock() + if th.transfersCompleted != 1 { + t.Errorf("unexpected completed transfers count: %d", th.transfersCompleted) + } +} + +func TestHookFailure(t *testing.T) { + s, c := makeTestServer(false) + th := newTestHook() + s.SetHook(th) + filename := "test-not-exists" + mode := "octet" + _, err := c.Receive(filename, mode) + if err == nil { + t.Fatalf("file not exists: %v", err) + } + t.Logf("receiving file that does not exist: %v", err) + s.Shutdown() + th.Lock() + defer th.Unlock() + if th.transfersFailed == 0 { // TODO: there are two failures, not one on Windows? + t.Errorf("unexpected failed transfers count: %d", th.transfersFailed) + } +} diff --git a/lifecycle_test.go b/lifecycle_test.go new file mode 100644 index 0000000..46a25a1 --- /dev/null +++ b/lifecycle_test.go @@ -0,0 +1,193 @@ +package tftp + +import ( + "bytes" + "io" + "math/rand" + "net" + "testing" + "time" +) + +func TestServerSendTimeout(t *testing.T) { + s, c := makeTestServer(false) + serverTimeoutSendTest(s, c, t) +} + +func TestServerReceiveTimeout(t *testing.T) { + s, c := makeTestServer(false) + serverReceiveTimeoutTest(s, c, t) +} + +func TestClientReceiveTimeout(t *testing.T) { + s, c := makeTestServer(false) + c.SetTimeout(time.Second) + c.SetRetries(2) + s.mu.Lock() + s.readHandler = func(filename string, rf io.ReaderFrom) error { + r := &slowReader{ + r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000), + n: 3, + delay: 8 * time.Second, + } + _, err := rf.ReadFrom(r) + return err + } + s.mu.Unlock() + defer s.Shutdown() + filename := "test-client-receive-timeout" + mode := "octet" + readTransfer, err := c.Receive(filename, mode) + if err != nil { + t.Fatalf("requesting read %s: %v", filename, err) + } + buf := &bytes.Buffer{} + _, err = readTransfer.WriteTo(buf) + netErr, ok := err.(net.Error) + if !ok { + t.Fatalf("network error expected: %T", err) + } + if !netErr.Timeout() { + t.Fatalf("timout is expected: %v", err) + } +} + +func TestClientSendTimeout(t *testing.T) { + s, c := makeTestServer(false) + c.SetTimeout(time.Second) + c.SetRetries(2) + s.mu.Lock() + s.writeHandler = func(filename string, wt io.WriterTo) error { + w := &slowWriter{ + n: 3, + delay: 8 * time.Second, + } + _, err := wt.WriteTo(w) + return err + } + s.mu.Unlock() + defer s.Shutdown() + filename := "test-client-send-timeout" + mode := "octet" + writeTransfer, err := c.Send(filename, mode) + if err != nil { + t.Fatalf("requesting write %s: %v", filename, err) + } + r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000) + _, err = writeTransfer.ReadFrom(r) + netErr, ok := err.(net.Error) + if !ok { + t.Fatalf("network error expected: %T", err) + } + if !netErr.Timeout() { + t.Fatalf("timout is expected: %v", err) + } +} + +// countingWriter signals through a channel when a certain number of bytes have been written +type countingWriter struct { + w io.Writer + total int64 + threshold int64 + signal chan struct{} + signaled bool +} + +func (w *countingWriter) Write(p []byte) (n int, err error) { + n, err = w.w.Write(p) + w.total += int64(n) + if !w.signaled && w.total >= w.threshold { + w.signal <- struct{}{} + w.signaled = true + } + return n, err +} + +// TestShutdownDuringTransfer starts a transfer, then shuts down the server mid-transfer. +// Checks that neither server nor client hang and server shuts down cleanly. +func TestShutdownDuringTransfer(t *testing.T) { + for _, singlePort := range []bool{false, true} { + name := "regular" + if singlePort { + name = "single_port" + } + t.Run(name, func(t *testing.T) { + testShutdownDuringTransfer(t, singlePort) + }) + } +} + +func testShutdownDuringTransfer(t *testing.T, singlePort bool) { + s := NewServer(func(_ string, rf io.ReaderFrom) error { + // Simulate a slow reader: send 1MB, but slowly + _, err := rf.ReadFrom(&slowReader{r: bytes.NewReader(make([]byte, 1<<23)), n: 1 << 20, delay: 10 * time.Millisecond}) + return err + }, nil) + + if singlePort { + s.EnableSinglePort() + } + + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + t.Fatal(err) + } + + // Start a goroutine to monitor server errors + errChan := make(chan error, 1) + go func() { + errChan <- s.Serve(conn) + }() + + c, err := NewClient(localSystem(conn)) + if err != nil { + t.Fatal(err) + } + + dl := make(chan error, 1) + received := make(chan struct{}, 1) + go func() { + wt, err := c.Receive("file", "octet") + if err != nil { + dl <- err + return + } + // Use custom writer to signal when 100KB is received + counter := &countingWriter{ + w: io.Discard, + threshold: 100 * 1024, // 100KB + signal: received, + } + _, err = wt.WriteTo(counter) + dl <- err + }() + + // Wait for either 100KB to be received or timeout + select { + case <-received: + // Received enough data, proceed with shutdown + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for data transfer to start") + } + s.Shutdown() + + // Server should shut down cleanly + select { + case err := <-errChan: + if err != nil { + t.Errorf("server error: %v", err) + } + case <-time.After(5 * time.Second): + t.Error("server did not shut down in time") + } + + // Client should shutdown cleanly too because server waits for transfers to finish + select { + case err := <-dl: + if err != nil { + t.Errorf("client transfer error: %v", err) + } + case <-time.After(5 * time.Second): + t.Error("client did not finish in time") + } +} diff --git a/packet_test.go b/packet_test.go new file mode 100644 index 0000000..fb8732e --- /dev/null +++ b/packet_test.go @@ -0,0 +1,49 @@ +package tftp + +import "testing" + +func TestPackUnpack(t *testing.T) { + v := []string{"test-filename/with-subdir"} + testOptsList := []options{ + nil, + { + "tsize": "1234", + "blksize": "22", + }, + } + for _, filename := range v { + for _, mode := range []string{"octet", "netascii"} { + for _, opts := range testOptsList { + packUnpack(t, filename, mode, opts) + } + } + } +} + +func packUnpack(t *testing.T, filename, mode string, opts options) { + b := make([]byte, datagramLength) + for _, op := range []uint16{opRRQ, opWRQ} { + n := packRQ(b, op, filename, mode, opts) + f, m, o, err := unpackRQ(b[:n]) + if err != nil { + t.Errorf("%s pack/unpack: %v", filename, err) + } + if f != filename { + t.Errorf("filename mismatch (%s): '%x' vs '%x'", + filename, f, filename) + } + if m != mode { + t.Errorf("mode mismatch (%s): '%x' vs '%x'", + mode, m, mode) + } + for name, value := range opts { + v, ok := o[name] + if !ok { + t.Errorf("missing %s option", name) + } + if v != value { + t.Errorf("option %s mismatch: '%x' vs '%x'", name, v, value) + } + } + } +} diff --git a/tftp_test.go b/tftp_test.go deleted file mode 100644 index c2b25a2..0000000 --- a/tftp_test.go +++ /dev/null @@ -1,608 +0,0 @@ -package tftp - -import ( - "bytes" - "errors" - "fmt" - "io" - "math/rand" - "net" - "testing" - "testing/iotest" - "time" -) - -func TestPackUnpack(t *testing.T) { - v := []string{"test-filename/with-subdir"} - testOptsList := []options{ - nil, - { - "tsize": "1234", - "blksize": "22", - }, - } - for _, filename := range v { - for _, mode := range []string{"octet", "netascii"} { - for _, opts := range testOptsList { - packUnpack(t, filename, mode, opts) - } - } - } -} - -func packUnpack(t *testing.T, filename, mode string, opts options) { - b := make([]byte, datagramLength) - for _, op := range []uint16{opRRQ, opWRQ} { - n := packRQ(b, op, filename, mode, opts) - f, m, o, err := unpackRQ(b[:n]) - if err != nil { - t.Errorf("%s pack/unpack: %v", filename, err) - } - if f != filename { - t.Errorf("filename mismatch (%s): '%x' vs '%x'", - filename, f, filename) - } - if m != mode { - t.Errorf("mode mismatch (%s): '%x' vs '%x'", - mode, m, mode) - } - for name, value := range opts { - v, ok := o[name] - if !ok { - t.Errorf("missing %s option", name) - } - if v != value { - t.Errorf("option %s mismatch: '%x' vs '%x'", name, v, value) - } - } - } -} - -func TestZeroLength(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - testSendReceive(t, c, 0) -} - -func Test900(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - for i := 600; i < 4000; i++ { - c.SetBlockSize(i) - s.SetBlockSize(4600 - i) - testSendReceive(t, c, 9000+int64(i)) - } -} - -func Test1000(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - for i := int64(0); i < 5000; i++ { - filename := fmt.Sprintf("length-%d-bytes-%d", i, time.Now().UnixNano()) - rf, err := c.Send(filename, "octet") - if err != nil { - t.Fatalf("requesting %s write: %v", filename, err) - } - r := io.LimitReader(newRandReader(rand.NewSource(i)), i) - n, err := rf.ReadFrom(r) - if err != nil { - t.Fatalf("sending %s: %v", filename, err) - } - if n != i { - t.Errorf("%s length mismatch: %d != %d", filename, n, i) - } - } -} - -func Test1810(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - c.SetBlockSize(1810) - testSendReceive(t, c, 9000+1810) -} - -func TestHookSuccess(t *testing.T) { - s, c := makeTestServer(false) - th := newTestHook() - s.SetHook(th) - c.SetBlockSize(1810) - length := int64(9000) - filename := fmt.Sprintf("length-%d-bytes-%d", length, time.Now().UnixNano()) - rf, err := c.Send(filename, "octet") - if err != nil { - t.Fatalf("requesting %s write: %v", filename, err) - } - r := io.LimitReader(newRandReader(rand.NewSource(length)), length) - n, err := rf.ReadFrom(r) - if err != nil { - t.Fatalf("sending %s: %v", filename, err) - } - if n != length { - t.Errorf("%s length mismatch: %d != %d", filename, n, length) - } - s.Shutdown() - th.Lock() - defer th.Unlock() - if th.transfersCompleted != 1 { - t.Errorf("unexpected completed transfers count: %d", th.transfersCompleted) - } -} - -func TestHookFailure(t *testing.T) { - s, c := makeTestServer(false) - th := newTestHook() - s.SetHook(th) - filename := "test-not-exists" - mode := "octet" - _, err := c.Receive(filename, mode) - if err == nil { - t.Fatalf("file not exists: %v", err) - } - t.Logf("receiving file that does not exist: %v", err) - s.Shutdown() - th.Lock() - defer th.Unlock() - if th.transfersFailed == 0 { // TODO: there are two failures, not one on Windows? - t.Errorf("unexpected failed transfers count: %d", th.transfersFailed) - } -} - -func TestTSize(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - c.tsize = true - testSendReceive(t, c, 640) -} - -func TestNearBlockLength(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - for i := 450; i < 520; i++ { - testSendReceive(t, c, int64(i)) - } -} - -func TestBlockWrapsAround(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - n := 65535 * 512 - for i := n - 2; i < n+2; i++ { - testSendReceive(t, c, int64(i)) - } -} - -func TestRandomLength(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - r := rand.New(rand.NewSource(42)) - for i := 0; i < 100; i++ { - testSendReceive(t, c, r.Int63n(100000)) - } -} - -func TestBigFile(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - testSendReceive(t, c, 3*1000*1000) -} - -func TestByOneByte(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - filename := "test-by-one-byte" - mode := "octet" - const length = 80000 - sender, err := c.Send(filename, mode) - if err != nil { - t.Fatalf("requesting write: %v", err) - } - r := iotest.OneByteReader(io.LimitReader( - newRandReader(rand.NewSource(42)), length)) - n, err := sender.ReadFrom(r) - if err != nil { - t.Fatalf("send error: %v", err) - } - if n != length { - t.Errorf("%s read length mismatch: %d != %d", filename, n, length) - } - readTransfer, err := c.Receive(filename, mode) - if err != nil { - t.Fatalf("requesting read %s: %v", filename, err) - } - buf := &bytes.Buffer{} - n, err = readTransfer.WriteTo(buf) - if err != nil { - t.Fatalf("%s read error: %v", filename, err) - } - if n != length { - t.Errorf("%s read length mismatch: %d != %d", filename, n, length) - } - bs, _ := io.ReadAll(io.LimitReader( - newRandReader(rand.NewSource(42)), length)) - if !bytes.Equal(bs, buf.Bytes()) { - t.Errorf("\nsent: %x\nrcvd: %x", bs, buf) - } -} - -func TestDuplicate(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - filename := "test-duplicate" - mode := "octet" - bs := []byte("lalala") - sender, err := c.Send(filename, mode) - if err != nil { - t.Fatalf("requesting write: %v", err) - } - buf := bytes.NewBuffer(bs) - _, err = sender.ReadFrom(buf) - if err != nil { - t.Fatalf("send error: %v", err) - } - _, err = c.Send(filename, mode) - if err == nil { - t.Fatalf("file already exists") - } - t.Logf("sending file that already exists: %v", err) -} - -func TestNotFound(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - filename := "test-not-exists" - mode := "octet" - _, err := c.Receive(filename, mode) - if err == nil { - t.Fatalf("file not exists: %v", err) - } - t.Logf("receiving file that does not exist: %v", err) -} - -func TestSendTsizeFromSeek(t *testing.T) { - // create read-only server - s := NewServer(func(filename string, rf io.ReaderFrom) error { - b := make([]byte, 100) - rr := newRandReader(rand.NewSource(42)) - rr.Read(b) - // bytes.Reader implements io.Seek - r := bytes.NewReader(b) - _, err := rf.ReadFrom(r) - if err != nil { - t.Errorf("sending bytes: %v", err) - } - return nil - }, nil) - - conn, err := net.ListenUDP("udp", &net.UDPAddr{}) - if err != nil { - t.Fatalf("listening: %v", err) - } - - go s.Serve(conn) - defer s.Shutdown() - - c, _ := NewClient(localSystem(conn)) - c.RequestTSize(true) - r, _ := c.Receive("f", "octet") - var size int64 - if it, ok := r.(IncomingTransfer); ok { - if n, ok := it.Size(); ok { - size = n - fmt.Printf("Transfer size: %d\n", n) - } - } - - if size != 100 { - t.Errorf("size expected: 100, got %d", size) - } - - r.WriteTo(io.Discard) - - c.RequestTSize(false) - r, _ = c.Receive("f", "octet") - if it, ok := r.(IncomingTransfer); ok { - _, ok := it.Size() - if ok { - t.Errorf("unexpected size received") - } - } - - r.WriteTo(io.Discard) -} - -func TestNoHandlers(t *testing.T) { - s := NewServer(nil, nil) - - conn, err := net.ListenUDP("udp", &net.UDPAddr{}) - if err != nil { - panic(err) - } - - go s.Serve(conn) - - c, err := NewClient(localSystem(conn)) - if err != nil { - panic(err) - } - - _, err = c.Send("test", "octet") - if err == nil { - t.Errorf("error expected") - } - - _, err = c.Receive("test", "octet") - if err == nil { - t.Errorf("error expected") - } -} - -func TestServerSendTimeout(t *testing.T) { - s, c := makeTestServer(false) - serverTimeoutSendTest(s, c, t) -} - -func TestServerReceiveTimeout(t *testing.T) { - s, c := makeTestServer(false) - serverReceiveTimeoutTest(s, c, t) -} - -func TestClientReceiveTimeout(t *testing.T) { - s, c := makeTestServer(false) - c.SetTimeout(time.Second) - c.SetRetries(2) - s.mu.Lock() - s.readHandler = func(filename string, rf io.ReaderFrom) error { - r := &slowReader{ - r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000), - n: 3, - delay: 8 * time.Second, - } - _, err := rf.ReadFrom(r) - return err - } - s.mu.Unlock() - defer s.Shutdown() - filename := "test-client-receive-timeout" - mode := "octet" - readTransfer, err := c.Receive(filename, mode) - if err != nil { - t.Fatalf("requesting read %s: %v", filename, err) - } - buf := &bytes.Buffer{} - _, err = readTransfer.WriteTo(buf) - netErr, ok := err.(net.Error) - if !ok { - t.Fatalf("network error expected: %T", err) - } - if !netErr.Timeout() { - t.Fatalf("timout is expected: %v", err) - } -} - -func TestClientSendTimeout(t *testing.T) { - s, c := makeTestServer(false) - c.SetTimeout(time.Second) - c.SetRetries(2) - s.mu.Lock() - s.writeHandler = func(filename string, wt io.WriterTo) error { - w := &slowWriter{ - n: 3, - delay: 8 * time.Second, - } - _, err := wt.WriteTo(w) - return err - } - s.mu.Unlock() - defer s.Shutdown() - filename := "test-client-send-timeout" - mode := "octet" - writeTransfer, err := c.Send(filename, mode) - if err != nil { - t.Fatalf("requesting write %s: %v", filename, err) - } - r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000) - _, err = writeTransfer.ReadFrom(r) - netErr, ok := err.(net.Error) - if !ok { - t.Fatalf("network error expected: %T", err) - } - if !netErr.Timeout() { - t.Fatalf("timout is expected: %v", err) - } -} - -// TestFileIOExceptions checks that errors returned by io.Reader or io.Writer used by -// the handler are handled correctly. -func TestReadWriteErrors(t *testing.T) { - s := NewServer( - func(_ string, rf io.ReaderFrom) error { - _, err := rf.ReadFrom(&failingReader{}) // Read operation fails immediately. - if err != errRead { - t.Errorf("want: %v, got: %v", errRead, err) - } - // return no error from handler, client still should receive error - return nil - }, - func(_ string, wt io.WriterTo) error { - _, err := wt.WriteTo(&failingWriter{}) // Write operation fails immediately. - if err != errWrite { - t.Errorf("want: %v, got: %v", errWrite, err) - } - // return no error from handler, client still should receive error - return nil - }, - ) - - conn, err := net.ListenUDP("udp", &net.UDPAddr{}) - if err != nil { - t.Fatalf("listen UDP: %v", err) - } - - _, port, err := net.SplitHostPort(conn.LocalAddr().String()) - if err != nil { - t.Fatalf("parsing server port: %v", err) - } - - // Start server - errChan := make(chan error, 1) - go func() { - err := s.Serve(conn) - if err != nil { - errChan <- fmt.Errorf("running serve: %w", err) - } - }() - defer func() { - s.Shutdown() - select { - case err := <-errChan: - t.Errorf("server error: %v", err) - default: - } - }() - - // Create client - c, err := NewClient(net.JoinHostPort(localhost, port)) - if err != nil { - t.Fatalf("creating new client: %v", err) - } - - ot, err := c.Send("a", "octet") - if err != nil { - t.Errorf("start sending: %v", err) - } - - _, err = ot.ReadFrom(io.LimitReader( - newRandReader(rand.NewSource(42)), 42)) - if err == nil { - t.Errorf("missing write error") - } - - _, err = c.Receive("a", "octet") - if err == nil { - t.Errorf("missing read error") - } -} - -type failingReader struct{} - -var errRead = errors.New("read error") - -func (r *failingReader) Read(_ []byte) (int, error) { - return 0, errRead -} - -type failingWriter struct{} - -var errWrite = errors.New("write error") - -func (r *failingWriter) Write(_ []byte) (int, error) { - return 0, errWrite -} - -// countingWriter signals through a channel when a certain number of bytes have been written -type countingWriter struct { - w io.Writer - total int64 - threshold int64 - signal chan struct{} - signaled bool -} - -func (w *countingWriter) Write(p []byte) (n int, err error) { - n, err = w.w.Write(p) - w.total += int64(n) - if !w.signaled && w.total >= w.threshold { - w.signal <- struct{}{} - w.signaled = true - } - return n, err -} - -// TestShutdownDuringTransfer starts a transfer, then shuts down the server mid-transfer. -// Checks that neither server nor client hang and server shuts down cleanly. -func TestShutdownDuringTransfer(t *testing.T) { - for _, singlePort := range []bool{false, true} { - name := "regular" - if singlePort { - name = "single_port" - } - t.Run(name, func(t *testing.T) { - testShutdownDuringTransfer(t, singlePort) - }) - } -} - -func testShutdownDuringTransfer(t *testing.T, singlePort bool) { - s := NewServer(func(_ string, rf io.ReaderFrom) error { - // Simulate a slow reader: send 1MB, but slowly - _, err := rf.ReadFrom(&slowReader{r: bytes.NewReader(make([]byte, 1<<23)), n: 1 << 20, delay: 10 * time.Millisecond}) - return err - }, nil) - - if singlePort { - s.EnableSinglePort() - } - - conn, err := net.ListenUDP("udp", &net.UDPAddr{}) - if err != nil { - t.Fatal(err) - } - - // Start a goroutine to monitor server errors - errChan := make(chan error, 1) - go func() { - errChan <- s.Serve(conn) - }() - - c, err := NewClient(localSystem(conn)) - if err != nil { - t.Fatal(err) - } - - dl := make(chan error, 1) - received := make(chan struct{}, 1) - go func() { - wt, err := c.Receive("file", "octet") - if err != nil { - dl <- err - return - } - // Use custom writer to signal when 100KB is received - counter := &countingWriter{ - w: io.Discard, - threshold: 100 * 1024, // 100KB - signal: received, - } - _, err = wt.WriteTo(counter) - dl <- err - }() - - // Wait for either 100KB to be received or timeout - select { - case <-received: - // Received enough data, proceed with shutdown - case <-time.After(5 * time.Second): - t.Fatal("timeout waiting for data transfer to start") - } - s.Shutdown() - - // Server should shut down cleanly - select { - case err := <-errChan: - if err != nil { - t.Errorf("server error: %v", err) - } - case <-time.After(5 * time.Second): - t.Error("server did not shut down in time") - } - - // Client should shutdown cleanly too because server waits for transfers to finish - select { - case err := <-dl: - if err != nil { - t.Errorf("client transfer error: %v", err) - } - case <-time.After(5 * time.Second): - t.Error("client did not finish in time") - } -} diff --git a/transfer_test.go b/transfer_test.go new file mode 100644 index 0000000..38e9942 --- /dev/null +++ b/transfer_test.go @@ -0,0 +1,99 @@ +package tftp + +import ( + "bytes" + "fmt" + "io" + "math/rand" + "testing" + "testing/iotest" + "time" +) + +func TestZeroLength(t *testing.T) { + s, c := makeTestServer(false) + defer s.Shutdown() + testSendReceive(t, c, 0) +} + +func Test1000(t *testing.T) { + s, c := makeTestServer(false) + defer s.Shutdown() + for i := int64(0); i < 5000; i++ { + filename := fmt.Sprintf("length-%d-bytes-%d", i, time.Now().UnixNano()) + rf, err := c.Send(filename, "octet") + if err != nil { + t.Fatalf("requesting %s write: %v", filename, err) + } + r := io.LimitReader(newRandReader(rand.NewSource(i)), i) + n, err := rf.ReadFrom(r) + if err != nil { + t.Fatalf("sending %s: %v", filename, err) + } + if n != i { + t.Errorf("%s length mismatch: %d != %d", filename, n, i) + } + } +} + +func TestBlockWrapsAround(t *testing.T) { + s, c := makeTestServer(false) + defer s.Shutdown() + n := 65535 * 512 + for i := n - 2; i < n+2; i++ { + testSendReceive(t, c, int64(i)) + } +} + +func TestRandomLength(t *testing.T) { + s, c := makeTestServer(false) + defer s.Shutdown() + r := rand.New(rand.NewSource(42)) + for i := 0; i < 100; i++ { + testSendReceive(t, c, r.Int63n(100000)) + } +} + +func TestBigFile(t *testing.T) { + s, c := makeTestServer(false) + defer s.Shutdown() + testSendReceive(t, c, 3*1000*1000) +} + +func TestByOneByte(t *testing.T) { + s, c := makeTestServer(false) + defer s.Shutdown() + filename := "test-by-one-byte" + mode := "octet" + const length = 80000 + sender, err := c.Send(filename, mode) + if err != nil { + t.Fatalf("requesting write: %v", err) + } + r := iotest.OneByteReader(io.LimitReader( + newRandReader(rand.NewSource(42)), length)) + n, err := sender.ReadFrom(r) + if err != nil { + t.Fatalf("send error: %v", err) + } + if n != length { + t.Errorf("%s read length mismatch: %d != %d", filename, n, length) + } + readTransfer, err := c.Receive(filename, mode) + if err != nil { + t.Fatalf("requesting read %s: %v", filename, err) + } + buf := &bytes.Buffer{} + n, err = readTransfer.WriteTo(buf) + if err != nil { + t.Fatalf("%s read error: %v", filename, err) + } + if n != length { + t.Errorf("%s read length mismatch: %d != %d", filename, n, length) + } + bs, _ := io.ReadAll(io.LimitReader( + newRandReader(rand.NewSource(42)), length)) + if !bytes.Equal(bs, buf.Bytes()) { + t.Errorf("\nsent: %x\nrcvd: %x", bs, buf) + } +} diff --git a/tsize_test.go b/tsize_test.go new file mode 100644 index 0000000..6061bf2 --- /dev/null +++ b/tsize_test.go @@ -0,0 +1,69 @@ +package tftp + +import ( + "bytes" + "fmt" + "io" + "math/rand" + "net" + "testing" +) + +func TestTSize(t *testing.T) { + s, c := makeTestServer(false) + defer s.Shutdown() + c.RequestTSize(true) + testSendReceive(t, c, 640) +} + +func TestSendTsizeFromSeek(t *testing.T) { + // create read-only server + s := NewServer(func(filename string, rf io.ReaderFrom) error { + b := make([]byte, 100) + rr := newRandReader(rand.NewSource(42)) + rr.Read(b) + // bytes.Reader implements io.Seek + r := bytes.NewReader(b) + _, err := rf.ReadFrom(r) + if err != nil { + t.Errorf("sending bytes: %v", err) + } + return nil + }, nil) + + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + t.Fatalf("listening: %v", err) + } + + go s.Serve(conn) + defer s.Shutdown() + + c, _ := NewClient(localSystem(conn)) + c.RequestTSize(true) + r, _ := c.Receive("f", "octet") + var size int64 + if it, ok := r.(IncomingTransfer); ok { + if n, ok := it.Size(); ok { + size = n + fmt.Printf("Transfer size: %d\n", n) + } + } + + if size != 100 { + t.Errorf("size expected: 100, got %d", size) + } + + r.WriteTo(io.Discard) + + c.RequestTSize(false) + r, _ = c.Receive("f", "octet") + if it, ok := r.(IncomingTransfer); ok { + _, ok := it.Size() + if ok { + t.Errorf("unexpected size received") + } + } + + r.WriteTo(io.Discard) +} From 3ee90f8e972380da1cb47efe5be68a75742b1bcc Mon Sep 17 00:00:00 2001 From: Dmitri Popov Date: Wed, 25 Feb 2026 11:12:44 +0100 Subject: [PATCH 3/6] Share regular and single-port coverage for transfer and timeouts --- lifecycle_test.go | 12 ++++++++---- single_port_test.go | 33 --------------------------------- transfer_test.go | 29 ++++++++++++++++++++++++++--- 3 files changed, 34 insertions(+), 40 deletions(-) diff --git a/lifecycle_test.go b/lifecycle_test.go index 46a25a1..82cab76 100644 --- a/lifecycle_test.go +++ b/lifecycle_test.go @@ -10,13 +10,17 @@ import ( ) func TestServerSendTimeout(t *testing.T) { - s, c := makeTestServer(false) - serverTimeoutSendTest(s, c, t) + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + serverTimeoutSendTest(s, c, t) + }) } func TestServerReceiveTimeout(t *testing.T) { - s, c := makeTestServer(false) - serverReceiveTimeoutTest(s, c, t) + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + serverReceiveTimeoutTest(s, c, t) + }) } func TestClientReceiveTimeout(t *testing.T) { diff --git a/single_port_test.go b/single_port_test.go index d9f7fb5..51328de 100644 --- a/single_port_test.go +++ b/single_port_test.go @@ -6,39 +6,6 @@ import ( "time" ) -func TestZeroLengthSinglePort(t *testing.T) { - s, c := makeTestServer(true) - defer s.Shutdown() - testSendReceive(t, c, 0) -} - -func TestSendReceiveSinglePort(t *testing.T) { - s, c := makeTestServer(true) - defer s.Shutdown() - for i := 600; i < 1000; i++ { - testSendReceive(t, c, 5000+int64(i)) - } -} - -func TestSendReceiveSinglePortWithBlockSize(t *testing.T) { - s, c := makeTestServer(true) - defer s.Shutdown() - for i := 600; i < 1000; i++ { - c.blksize = i - testSendReceive(t, c, 5000+int64(i)) - } -} - -func TestServerSendTimeoutSinglePort(t *testing.T) { - s, c := makeTestServer(true) - serverTimeoutSendTest(s, c, t) -} - -func TestServerReceiveTimeoutSinglePort(t *testing.T) { - s, c := makeTestServer(true) - serverReceiveTimeoutTest(s, c, t) -} - func TestSinglePortShutdownReturns(t *testing.T) { s := NewServer(nil, nil) s.EnableSinglePort() diff --git a/transfer_test.go b/transfer_test.go index 38e9942..f519b88 100644 --- a/transfer_test.go +++ b/transfer_test.go @@ -11,9 +11,32 @@ import ( ) func TestZeroLength(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - testSendReceive(t, c, 0) + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + testSendReceive(t, c, 0) + }) +} + +func TestSendReceiveRange(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + for i := 600; i < 1000; i++ { + testSendReceive(t, c, 5000+int64(i)) + } + }) +} + +func TestSendReceiveWithBlockSizeRange(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + for i := 600; i < 1000; i++ { + c.SetBlockSize(i) + testSendReceive(t, c, 5000+int64(i)) + } + }) } func Test1000(t *testing.T) { From 3eb928d94cfdfbd49e855d0fee63727d85071789 Mon Sep 17 00:00:00 2001 From: Dmitri Popov Date: Wed, 25 Feb 2026 11:15:07 +0100 Subject: [PATCH 4/6] Merge blocksize negotiation tests into blocksize file --- blocksize_negotiation_test.go | 188 ---------------------------------- blocksize_test.go | 187 ++++++++++++++++++++++++++++++++- 2 files changed, 186 insertions(+), 189 deletions(-) delete mode 100644 blocksize_negotiation_test.go diff --git a/blocksize_negotiation_test.go b/blocksize_negotiation_test.go deleted file mode 100644 index 67fe9bd..0000000 --- a/blocksize_negotiation_test.go +++ /dev/null @@ -1,188 +0,0 @@ -package tftp - -import ( - "bytes" - "encoding/binary" - "fmt" - "io" - "net" - "strconv" - "testing" - "time" -) - -func TestBlockSizeNegotiation_ClampsByPathLimit(t *testing.T) { - got := negotiateBlockSizeForTest(t, 65432, 1472, true) - if got != "1472" { - t.Fatalf("unexpected negotiated blksize: got %q, want %q", got, "1472") - } -} - -func TestBlockSizeNegotiation_PreservedWhenPathAllows(t *testing.T) { - got := negotiateBlockSizeForTest(t, 65432, 65508, true) - if got != "65432" { - t.Fatalf("unexpected negotiated blksize: got %q, want %q", got, "65432") - } -} - -func TestBlockSizeNegotiation_DisabledHonorsClientBlockSize(t *testing.T) { - got := negotiateBlockSizeForTest(t, 65432, 1472, false) - if got != "65432" { - t.Fatalf("unexpected negotiated blksize: got %q, want %q", got, "65432") - } -} - -func negotiateBlockSizeForTest(t *testing.T, requestedBlockSize int, maxBlockLen int, smartBlock bool) string { - t.Helper() - - clientConn, err := net.ListenUDP("udp4", &net.UDPAddr{ - IP: net.IPv4(127, 0, 0, 1), - Port: 0, - }) - if err != nil { - t.Fatalf("listen udp: %v", err) - } - defer clientConn.Close() - - s := NewServer(func(_ string, rf io.ReaderFrom) error { - _, err := rf.ReadFrom(bytes.NewReader([]byte{1})) - return err - }, nil) - s.SetBlockSizeNegotiation(smartBlock) - s.SetTimeout(100 * time.Millisecond) - s.SetRetries(1) - t.Cleanup(s.Shutdown) - - request := make([]byte, datagramLength) - reqLen := packRQ( - request, - opRRQ, - "blocksize-negotiation.bin", - "octet", - options{"blksize": strconv.Itoa(requestedBlockSize)}, - ) - remoteAddr := clientConn.LocalAddr().(*net.UDPAddr) - if err := s.handlePacket(net.IPv4(127, 0, 0, 1), remoteAddr, request, reqLen, maxBlockLen, nil); err != nil { - t.Fatalf("handle packet: %v", err) - } - - packet := make([]byte, 70*1024) - _ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)) - n, serverAddr, err := clientConn.ReadFromUDP(packet) - if err != nil { - t.Fatalf("read OACK: %v", err) - } - p, err := parsePacket(packet[:n]) - if err != nil { - t.Fatalf("parse OACK: %v", err) - } - oack, ok := p.(pOACK) - if !ok { - t.Fatalf("expected OACK, got %T", p) - } - opts, err := unpackOACK(oack) - if err != nil { - t.Fatalf("unpack OACK: %v", err) - } - got := opts["blksize"] - - ack := make([]byte, 4) - binary.BigEndian.PutUint16(ack[0:2], opACK) - if _, err := clientConn.WriteToUDP(ack, serverAddr); err != nil { - t.Fatalf("write ACK(0): %v", err) - } - - _ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)) - n, serverAddr, err = clientConn.ReadFromUDP(packet) - if err != nil { - t.Fatalf("read DATA(1): %v", err) - } - p, err = parsePacket(packet[:n]) - if err != nil { - t.Fatalf("parse DATA(1): %v", err) - } - data, ok := p.(pDATA) - if !ok { - t.Fatalf("expected DATA, got %T", p) - } - if block := data.block(); block != 1 { - t.Fatalf("unexpected DATA block: got %d, want 1", block) - } - - binary.BigEndian.PutUint16(ack[2:4], 1) - if _, err := clientConn.WriteToUDP(ack, serverAddr); err != nil { - t.Fatalf("write ACK(1): %v", err) - } - - done := make(chan struct{}) - go func() { - s.wg.Wait() - close(done) - }() - select { - case <-done: - case <-time.After(2 * time.Second): - t.Fatal("transfer did not finish") - } - - return got -} - -func BenchmarkBlockSizeNegotiation(b *testing.B) { - const ( - payloadSize = 64 << 20 // 64 MiB - clientBlkSize = 65432 - ) - - payload := bytes.Repeat([]byte{0x5a}, payloadSize) - - bench := func(b *testing.B, serverMaxBlock int) { - s := NewServer(func(_ string, rf io.ReaderFrom) error { - _, err := rf.ReadFrom(bytes.NewReader(payload)) - return err - }, nil) - if serverMaxBlock > 0 { - s.SetBlockSize(serverMaxBlock) - } - - conn, err := net.ListenUDP("udp", &net.UDPAddr{}) - if err != nil { - b.Fatalf("listen udp: %v", err) - } - defer conn.Close() - - go func() { _ = s.Serve(conn) }() - b.Cleanup(s.Shutdown) - - c, err := NewClient(localSystem(conn)) - if err != nil { - b.Fatalf("new client: %v", err) - } - c.SetBlockSize(clientBlkSize) - - b.ReportAllocs() - b.SetBytes(payloadSize) - b.ResetTimer() - - for i := 0; i < b.N; i++ { - wt, err := c.Receive(fmt.Sprintf("blocksize-bench-%d", i), "octet") - if err != nil { - b.Fatalf("receive: %v", err) - } - n, err := wt.WriteTo(io.Discard) - if err != nil { - b.Fatalf("write to discard: %v", err) - } - if n != payloadSize { - b.Fatalf("size mismatch: got %d want %d", n, payloadSize) - } - } - } - - b.Run("LargeBlock_65432", func(b *testing.B) { - bench(b, 65432) - }) - b.Run("ClampedBlock_1472", func(b *testing.B) { - bench(b, 1472) - }) -} diff --git a/blocksize_test.go b/blocksize_test.go index 15f733f..f4e294c 100644 --- a/blocksize_test.go +++ b/blocksize_test.go @@ -1,6 +1,15 @@ package tftp -import "testing" +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "net" + "strconv" + "testing" + "time" +) func Test900(t *testing.T) { s, c := makeTestServer(false) @@ -26,3 +35,179 @@ func TestNearBlockLength(t *testing.T) { testSendReceive(t, c, int64(i)) } } + +func TestBlockSizeNegotiation_ClampsByPathLimit(t *testing.T) { + got := negotiateBlockSizeForTest(t, 65432, 1472, true) + if got != "1472" { + t.Fatalf("unexpected negotiated blksize: got %q, want %q", got, "1472") + } +} + +func TestBlockSizeNegotiation_PreservedWhenPathAllows(t *testing.T) { + got := negotiateBlockSizeForTest(t, 65432, 65508, true) + if got != "65432" { + t.Fatalf("unexpected negotiated blksize: got %q, want %q", got, "65432") + } +} + +func TestBlockSizeNegotiation_DisabledHonorsClientBlockSize(t *testing.T) { + got := negotiateBlockSizeForTest(t, 65432, 1472, false) + if got != "65432" { + t.Fatalf("unexpected negotiated blksize: got %q, want %q", got, "65432") + } +} + +func negotiateBlockSizeForTest(t *testing.T, requestedBlockSize int, maxBlockLen int, smartBlock bool) string { + t.Helper() + + clientConn, err := net.ListenUDP("udp4", &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 0, + }) + if err != nil { + t.Fatalf("listen udp: %v", err) + } + defer clientConn.Close() + + s := NewServer(func(_ string, rf io.ReaderFrom) error { + _, err := rf.ReadFrom(bytes.NewReader([]byte{1})) + return err + }, nil) + s.SetBlockSizeNegotiation(smartBlock) + s.SetTimeout(100 * time.Millisecond) + s.SetRetries(1) + t.Cleanup(s.Shutdown) + + request := make([]byte, datagramLength) + reqLen := packRQ( + request, + opRRQ, + "blocksize-negotiation.bin", + "octet", + options{"blksize": strconv.Itoa(requestedBlockSize)}, + ) + remoteAddr := clientConn.LocalAddr().(*net.UDPAddr) + if err := s.handlePacket(net.IPv4(127, 0, 0, 1), remoteAddr, request, reqLen, maxBlockLen, nil); err != nil { + t.Fatalf("handle packet: %v", err) + } + + packet := make([]byte, 70*1024) + _ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, serverAddr, err := clientConn.ReadFromUDP(packet) + if err != nil { + t.Fatalf("read OACK: %v", err) + } + p, err := parsePacket(packet[:n]) + if err != nil { + t.Fatalf("parse OACK: %v", err) + } + oack, ok := p.(pOACK) + if !ok { + t.Fatalf("expected OACK, got %T", p) + } + opts, err := unpackOACK(oack) + if err != nil { + t.Fatalf("unpack OACK: %v", err) + } + got := opts["blksize"] + + ack := make([]byte, 4) + binary.BigEndian.PutUint16(ack[0:2], opACK) + if _, err := clientConn.WriteToUDP(ack, serverAddr); err != nil { + t.Fatalf("write ACK(0): %v", err) + } + + _ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, serverAddr, err = clientConn.ReadFromUDP(packet) + if err != nil { + t.Fatalf("read DATA(1): %v", err) + } + p, err = parsePacket(packet[:n]) + if err != nil { + t.Fatalf("parse DATA(1): %v", err) + } + data, ok := p.(pDATA) + if !ok { + t.Fatalf("expected DATA, got %T", p) + } + if block := data.block(); block != 1 { + t.Fatalf("unexpected DATA block: got %d, want 1", block) + } + + binary.BigEndian.PutUint16(ack[2:4], 1) + if _, err := clientConn.WriteToUDP(ack, serverAddr); err != nil { + t.Fatalf("write ACK(1): %v", err) + } + + done := make(chan struct{}) + go func() { + s.wg.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("transfer did not finish") + } + + return got +} + +func BenchmarkBlockSizeNegotiation(b *testing.B) { + const ( + payloadSize = 64 << 20 // 64 MiB + clientBlkSize = 65432 + ) + + payload := bytes.Repeat([]byte{0x5a}, payloadSize) + + bench := func(b *testing.B, serverMaxBlock int) { + s := NewServer(func(_ string, rf io.ReaderFrom) error { + _, err := rf.ReadFrom(bytes.NewReader(payload)) + return err + }, nil) + if serverMaxBlock > 0 { + s.SetBlockSize(serverMaxBlock) + } + + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + b.Fatalf("listen udp: %v", err) + } + defer conn.Close() + + go func() { _ = s.Serve(conn) }() + b.Cleanup(s.Shutdown) + + c, err := NewClient(localSystem(conn)) + if err != nil { + b.Fatalf("new client: %v", err) + } + c.SetBlockSize(clientBlkSize) + + b.ReportAllocs() + b.SetBytes(payloadSize) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + wt, err := c.Receive(fmt.Sprintf("blocksize-bench-%d", i), "octet") + if err != nil { + b.Fatalf("receive: %v", err) + } + n, err := wt.WriteTo(io.Discard) + if err != nil { + b.Fatalf("write to discard: %v", err) + } + if n != payloadSize { + b.Fatalf("size mismatch: got %d want %d", n, payloadSize) + } + } + } + + b.Run("LargeBlock_65432", func(b *testing.B) { + bench(b, 65432) + }) + b.Run("ClampedBlock_1472", func(b *testing.B) { + bench(b, 1472) + }) +} From 0867707de2b6ae9173b60a03453f88358eb03ba3 Mon Sep 17 00:00:00 2001 From: Dmitri Popov Date: Thu, 26 Feb 2026 01:29:54 +0100 Subject: [PATCH 5/6] Replace panic paths with test failures in setup helpers --- anticipate_test.go | 13 ++++---- api_error_test.go | 17 ++++------ blocksize_test.go | 8 ++--- hooks_test.go | 4 +-- lifecycle_test.go | 6 ++-- transfer_test.go | 10 +++--- tsize_test.go | 17 +++++++--- util_test.go | 77 +++++++++++++++++++++++++++++----------------- 8 files changed, 88 insertions(+), 64 deletions(-) diff --git a/anticipate_test.go b/anticipate_test.go index 0a3480d..f97d9ec 100644 --- a/anticipate_test.go +++ b/anticipate_test.go @@ -7,7 +7,7 @@ import ( // derived from Test900 func TestAnticipateWindow900(t *testing.T) { - s, c := makeTestServerAnticipateWindow() + s, c := makeTestServerAnticipateWindow(t) defer s.Shutdown() for i := 600; i < 4000; i++ { c.blksize = i @@ -17,7 +17,7 @@ func TestAnticipateWindow900(t *testing.T) { // TestAnticipateHookSuccess verifies that OnSuccess hook is called on transfer completion when SetAnticipate is used func TestAnticipateHookSuccess(t *testing.T) { - s, c := makeTestServerAnticipateWindow() + s, c := makeTestServerAnticipateWindow(t) th := newTestHook() s.SetHook(th) testSendReceive(t, c, 154242) @@ -30,7 +30,8 @@ func TestAnticipateHookSuccess(t *testing.T) { } // derived from makeTestServer -func makeTestServerAnticipateWindow() (*Server, *Client) { +func makeTestServerAnticipateWindow(t *testing.T) (*Server, *Client) { + t.Helper() b := &testBackend{} b.m = make(map[string][]byte) @@ -40,15 +41,15 @@ func makeTestServerAnticipateWindow() (*Server, *Client) { conn, err := net.ListenUDP("udp", &net.UDPAddr{}) if err != nil { - panic(err) + t.Fatalf("listen udp: %v", err) } go s.Serve(conn) // Create client for that server - c, err := NewClient(localSystem(conn)) + c, err := NewClient(localSystem(t, conn)) if err != nil { - panic(err) + t.Fatalf("new client: %v", err) } return s, c diff --git a/api_error_test.go b/api_error_test.go index dc443ed..3b76aed 100644 --- a/api_error_test.go +++ b/api_error_test.go @@ -11,7 +11,7 @@ import ( ) func TestDuplicate(t *testing.T) { - s, c := makeTestServer(false) + s, c := makeTestServer(t, false) defer s.Shutdown() filename := "test-duplicate" mode := "octet" @@ -33,7 +33,7 @@ func TestDuplicate(t *testing.T) { } func TestNotFound(t *testing.T) { - s, c := makeTestServer(false) + s, c := makeTestServer(t, false) defer s.Shutdown() filename := "test-not-exists" mode := "octet" @@ -49,14 +49,14 @@ func TestNoHandlers(t *testing.T) { conn, err := net.ListenUDP("udp", &net.UDPAddr{}) if err != nil { - panic(err) + t.Fatalf("listen udp: %v", err) } go s.Serve(conn) - c, err := NewClient(localSystem(conn)) + c, err := NewClient(localSystem(t, conn)) if err != nil { - panic(err) + t.Fatalf("new client: %v", err) } _, err = c.Send("test", "octet") @@ -97,11 +97,6 @@ func TestReadWriteErrors(t *testing.T) { t.Fatalf("listen UDP: %v", err) } - _, port, err := net.SplitHostPort(conn.LocalAddr().String()) - if err != nil { - t.Fatalf("parsing server port: %v", err) - } - // Start server errChan := make(chan error, 1) go func() { @@ -120,7 +115,7 @@ func TestReadWriteErrors(t *testing.T) { }() // Create client - c, err := NewClient(net.JoinHostPort(localhost, port)) + c, err := NewClient(localSystem(t, conn)) if err != nil { t.Fatalf("creating new client: %v", err) } diff --git a/blocksize_test.go b/blocksize_test.go index f4e294c..40a3566 100644 --- a/blocksize_test.go +++ b/blocksize_test.go @@ -12,7 +12,7 @@ import ( ) func Test900(t *testing.T) { - s, c := makeTestServer(false) + s, c := makeTestServer(t, false) defer s.Shutdown() for i := 600; i < 4000; i++ { c.SetBlockSize(i) @@ -22,14 +22,14 @@ func Test900(t *testing.T) { } func Test1810(t *testing.T) { - s, c := makeTestServer(false) + s, c := makeTestServer(t, false) defer s.Shutdown() c.SetBlockSize(1810) testSendReceive(t, c, 9000+1810) } func TestNearBlockLength(t *testing.T) { - s, c := makeTestServer(false) + s, c := makeTestServer(t, false) defer s.Shutdown() for i := 450; i < 520; i++ { testSendReceive(t, c, int64(i)) @@ -179,7 +179,7 @@ func BenchmarkBlockSizeNegotiation(b *testing.B) { go func() { _ = s.Serve(conn) }() b.Cleanup(s.Shutdown) - c, err := NewClient(localSystem(conn)) + c, err := NewClient(localSystem(b, conn)) if err != nil { b.Fatalf("new client: %v", err) } diff --git a/hooks_test.go b/hooks_test.go index bb0d4d4..72807d8 100644 --- a/hooks_test.go +++ b/hooks_test.go @@ -9,7 +9,7 @@ import ( ) func TestHookSuccess(t *testing.T) { - s, c := makeTestServer(false) + s, c := makeTestServer(t, false) th := newTestHook() s.SetHook(th) c.SetBlockSize(1810) @@ -36,7 +36,7 @@ func TestHookSuccess(t *testing.T) { } func TestHookFailure(t *testing.T) { - s, c := makeTestServer(false) + s, c := makeTestServer(t, false) th := newTestHook() s.SetHook(th) filename := "test-not-exists" diff --git a/lifecycle_test.go b/lifecycle_test.go index 82cab76..648ee27 100644 --- a/lifecycle_test.go +++ b/lifecycle_test.go @@ -24,7 +24,7 @@ func TestServerReceiveTimeout(t *testing.T) { } func TestClientReceiveTimeout(t *testing.T) { - s, c := makeTestServer(false) + s, c := makeTestServer(t, false) c.SetTimeout(time.Second) c.SetRetries(2) s.mu.Lock() @@ -57,7 +57,7 @@ func TestClientReceiveTimeout(t *testing.T) { } func TestClientSendTimeout(t *testing.T) { - s, c := makeTestServer(false) + s, c := makeTestServer(t, false) c.SetTimeout(time.Second) c.SetRetries(2) s.mu.Lock() @@ -143,7 +143,7 @@ func testShutdownDuringTransfer(t *testing.T, singlePort bool) { errChan <- s.Serve(conn) }() - c, err := NewClient(localSystem(conn)) + c, err := NewClient(localSystem(t, conn)) if err != nil { t.Fatal(err) } diff --git a/transfer_test.go b/transfer_test.go index f519b88..0e548d0 100644 --- a/transfer_test.go +++ b/transfer_test.go @@ -40,7 +40,7 @@ func TestSendReceiveWithBlockSizeRange(t *testing.T) { } func Test1000(t *testing.T) { - s, c := makeTestServer(false) + s, c := makeTestServer(t, false) defer s.Shutdown() for i := int64(0); i < 5000; i++ { filename := fmt.Sprintf("length-%d-bytes-%d", i, time.Now().UnixNano()) @@ -60,7 +60,7 @@ func Test1000(t *testing.T) { } func TestBlockWrapsAround(t *testing.T) { - s, c := makeTestServer(false) + s, c := makeTestServer(t, false) defer s.Shutdown() n := 65535 * 512 for i := n - 2; i < n+2; i++ { @@ -69,7 +69,7 @@ func TestBlockWrapsAround(t *testing.T) { } func TestRandomLength(t *testing.T) { - s, c := makeTestServer(false) + s, c := makeTestServer(t, false) defer s.Shutdown() r := rand.New(rand.NewSource(42)) for i := 0; i < 100; i++ { @@ -78,13 +78,13 @@ func TestRandomLength(t *testing.T) { } func TestBigFile(t *testing.T) { - s, c := makeTestServer(false) + s, c := makeTestServer(t, false) defer s.Shutdown() testSendReceive(t, c, 3*1000*1000) } func TestByOneByte(t *testing.T) { - s, c := makeTestServer(false) + s, c := makeTestServer(t, false) defer s.Shutdown() filename := "test-by-one-byte" mode := "octet" diff --git a/tsize_test.go b/tsize_test.go index 6061bf2..c659e01 100644 --- a/tsize_test.go +++ b/tsize_test.go @@ -10,7 +10,7 @@ import ( ) func TestTSize(t *testing.T) { - s, c := makeTestServer(false) + s, c := makeTestServer(t, false) defer s.Shutdown() c.RequestTSize(true) testSendReceive(t, c, 640) @@ -39,9 +39,15 @@ func TestSendTsizeFromSeek(t *testing.T) { go s.Serve(conn) defer s.Shutdown() - c, _ := NewClient(localSystem(conn)) + c, err := NewClient(localSystem(t, conn)) + if err != nil { + t.Fatalf("new client: %v", err) + } c.RequestTSize(true) - r, _ := c.Receive("f", "octet") + r, err := c.Receive("f", "octet") + if err != nil { + t.Fatalf("receive: %v", err) + } var size int64 if it, ok := r.(IncomingTransfer); ok { if n, ok := it.Size(); ok { @@ -57,7 +63,10 @@ func TestSendTsizeFromSeek(t *testing.T) { r.WriteTo(io.Discard) c.RequestTSize(false) - r, _ = c.Receive("f", "octet") + r, err = c.Receive("f", "octet") + if err != nil { + t.Fatalf("receive without tsize: %v", err) + } if it, ok := r.(IncomingTransfer); ok { _, ok := it.Size() if ok { diff --git a/util_test.go b/util_test.go index 7c40d10..7940421 100644 --- a/util_test.go +++ b/util_test.go @@ -7,7 +7,6 @@ import ( "math/rand" "net" "os" - "strconv" "sync" "testing" "time" @@ -36,27 +35,33 @@ func newFixture(t *testing.T, mode transferMode) (*Server, *Client) { t.Helper() switch mode { case modeRegular: - return makeTestServer(false) + return makeTestServer(t, false) case modeSinglePort: - return makeTestServer(true) + return makeTestServer(t, true) default: t.Fatalf("unknown transfer mode: %q", mode) return nil, nil } } -var localhost = determineLocalhost() +var ( + localhostOnce sync.Once + localhostAddr string + localhostErr error +) -func determineLocalhost() string { +func determineLocalhost() (string, error) { l, err := net.ListenTCP("tcp", nil) if err != nil { - panic(fmt.Sprintf("ListenTCP error: %s", err)) + return "", fmt.Errorf("listen tcp: %w", err) + } + _, lport, err := net.SplitHostPort(l.Addr().String()) + if err != nil { + l.Close() + return "", fmt.Errorf("split host port: %w", err) } - _, lport, _ := net.SplitHostPort(l.Addr().String()) defer l.Close() - lo := make(chan string) - go func() { for { conn, err := l.Accept() @@ -67,26 +72,39 @@ func determineLocalhost() string { } }() - go func() { - port, _ := strconv.Atoi(lport) - for _, af := range []string{"tcp6", "tcp4"} { - conn, err := net.DialTCP(af, &net.TCPAddr{}, &net.TCPAddr{Port: port}) - if err == nil { - conn.Close() - host, _, _ := net.SplitHostPort(conn.LocalAddr().String()) - lo <- host - return - } + for _, af := range []string{"tcp6", "tcp4"} { + conn, err := net.Dial(af, net.JoinHostPort("", lport)) + if err != nil { + continue } - panic("could not determine address family") - }() + host, _, splitErr := net.SplitHostPort(conn.LocalAddr().String()) + conn.Close() + if splitErr == nil { + return host, nil + } + } + + return "", fmt.Errorf("could not determine localhost address family") +} - return <-lo +func resolveLocalhost() (string, error) { + localhostOnce.Do(func() { + localhostAddr, localhostErr = determineLocalhost() + }) + return localhostAddr, localhostErr } -func localSystem(c *net.UDPConn) string { - _, port, _ := net.SplitHostPort(c.LocalAddr().String()) - return net.JoinHostPort(localhost, port) +func localSystem(tb testing.TB, c *net.UDPConn) string { + tb.Helper() + host, err := resolveLocalhost() + if err != nil { + tb.Fatalf("resolve localhost: %v", err) + } + _, port, err := net.SplitHostPort(c.LocalAddr().String()) + if err != nil { + tb.Fatalf("split listener address: %v", err) + } + return net.JoinHostPort(host, port) } type testHook struct { @@ -161,7 +179,8 @@ type testBackend struct { mu sync.Mutex } -func makeTestServer(singlePort bool) (*Server, *Client) { +func makeTestServer(t *testing.T, singlePort bool) (*Server, *Client) { + t.Helper() b := &testBackend{} b.m = make(map[string][]byte) @@ -175,15 +194,15 @@ func makeTestServer(singlePort bool) (*Server, *Client) { conn, err := net.ListenUDP("udp", &net.UDPAddr{}) if err != nil { - panic(err) + t.Fatalf("listen udp: %v", err) } go s.Serve(conn) // Create client for that server - c, err := NewClient(localSystem(conn)) + c, err := NewClient(localSystem(t, conn)) if err != nil { - panic(err) + t.Fatalf("new client: %v", err) } return s, c From f7e5969317db6817dda620a99e9f1e94fe7a0e9d Mon Sep 17 00:00:00 2001 From: Dmitri Popov Date: Thu, 26 Feb 2026 07:58:01 +0100 Subject: [PATCH 6/6] Expand regular and single-port coverage --- anticipate_test.go | 2 +- api_error_test.go | 221 ++++++++++++++++++++++++--------------------- blocksize_test.go | 43 +++++---- hooks_test.go | 84 +++++++++-------- lifecycle_test.go | 134 ++++++++++++++------------- transfer_test.go | 142 +++++++++++++++-------------- tsize_test.go | 112 ++++++++++++----------- 7 files changed, 392 insertions(+), 346 deletions(-) diff --git a/anticipate_test.go b/anticipate_test.go index f97d9ec..39b9ba4 100644 --- a/anticipate_test.go +++ b/anticipate_test.go @@ -10,7 +10,7 @@ func TestAnticipateWindow900(t *testing.T) { s, c := makeTestServerAnticipateWindow(t) defer s.Shutdown() for i := 600; i < 4000; i++ { - c.blksize = i + c.SetBlockSize(i) testSendReceive(t, c, 9000+int64(i)) } } diff --git a/api_error_test.go b/api_error_test.go index 3b76aed..ed020a3 100644 --- a/api_error_test.go +++ b/api_error_test.go @@ -11,130 +11,145 @@ import ( ) func TestDuplicate(t *testing.T) { - s, c := makeTestServer(t, false) - defer s.Shutdown() - filename := "test-duplicate" - mode := "octet" - bs := []byte("lalala") - sender, err := c.Send(filename, mode) - if err != nil { - t.Fatalf("requesting write: %v", err) - } - buf := bytes.NewBuffer(bs) - _, err = sender.ReadFrom(buf) - if err != nil { - t.Fatalf("send error: %v", err) - } - _, err = c.Send(filename, mode) - if err == nil { - t.Fatalf("file already exists") - } - t.Logf("sending file that already exists: %v", err) + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + filename := "test-duplicate" + xferMode := "octet" + bs := []byte("lalala") + sender, err := c.Send(filename, xferMode) + if err != nil { + t.Fatalf("requesting write: %v", err) + } + buf := bytes.NewBuffer(bs) + _, err = sender.ReadFrom(buf) + if err != nil { + t.Fatalf("send error: %v", err) + } + _, err = c.Send(filename, xferMode) + if err == nil { + t.Fatalf("file already exists") + } + t.Logf("sending file that already exists: %v", err) + }) } func TestNotFound(t *testing.T) { - s, c := makeTestServer(t, false) - defer s.Shutdown() - filename := "test-not-exists" - mode := "octet" - _, err := c.Receive(filename, mode) - if err == nil { - t.Fatalf("file not exists: %v", err) - } - t.Logf("receiving file that does not exist: %v", err) + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + filename := "test-not-exists" + xferMode := "octet" + _, err := c.Receive(filename, xferMode) + if err == nil { + t.Fatalf("file not exists: %v", err) + } + t.Logf("receiving file that does not exist: %v", err) + }) } func TestNoHandlers(t *testing.T) { - s := NewServer(nil, nil) + forModes(t, func(t *testing.T, mode transferMode) { + s := NewServer(nil, nil) + if mode == modeSinglePort { + s.EnableSinglePort() + } - conn, err := net.ListenUDP("udp", &net.UDPAddr{}) - if err != nil { - t.Fatalf("listen udp: %v", err) - } + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + t.Fatalf("listen udp: %v", err) + } - go s.Serve(conn) + go s.Serve(conn) + defer s.Shutdown() - c, err := NewClient(localSystem(t, conn)) - if err != nil { - t.Fatalf("new client: %v", err) - } + c, err := NewClient(localSystem(t, conn)) + if err != nil { + t.Fatalf("new client: %v", err) + } - _, err = c.Send("test", "octet") - if err == nil { - t.Errorf("error expected") - } + _, err = c.Send("test", "octet") + if err == nil { + t.Errorf("error expected") + } - _, err = c.Receive("test", "octet") - if err == nil { - t.Errorf("error expected") - } + _, err = c.Receive("test", "octet") + if err == nil { + t.Errorf("error expected") + } + }) } // TestFileIOExceptions checks that errors returned by io.Reader or io.Writer used by // the handler are handled correctly. func TestReadWriteErrors(t *testing.T) { - s := NewServer( - func(_ string, rf io.ReaderFrom) error { - _, err := rf.ReadFrom(&failingReader{}) // Read operation fails immediately. - if err != errRead { - t.Errorf("want: %v, got: %v", errRead, err) + forModes(t, func(t *testing.T, mode transferMode) { + s := NewServer( + func(_ string, rf io.ReaderFrom) error { + _, err := rf.ReadFrom(&failingReader{}) // Read operation fails immediately. + if err != errRead { + t.Errorf("want: %v, got: %v", errRead, err) + } + // return no error from handler, client still should receive error + return nil + }, + func(_ string, wt io.WriterTo) error { + _, err := wt.WriteTo(&failingWriter{}) // Write operation fails immediately. + if err != errWrite { + t.Errorf("want: %v, got: %v", errWrite, err) + } + // return no error from handler, client still should receive error + return nil + }, + ) + if mode == modeSinglePort { + s.EnableSinglePort() + } + + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + t.Fatalf("listen UDP: %v", err) + } + + // Start server + errChan := make(chan error, 1) + go func() { + err := s.Serve(conn) + if err != nil { + errChan <- fmt.Errorf("running serve: %w", err) } - // return no error from handler, client still should receive error - return nil - }, - func(_ string, wt io.WriterTo) error { - _, err := wt.WriteTo(&failingWriter{}) // Write operation fails immediately. - if err != errWrite { - t.Errorf("want: %v, got: %v", errWrite, err) + }() + defer func() { + s.Shutdown() + select { + case err := <-errChan: + t.Errorf("server error: %v", err) + default: } - // return no error from handler, client still should receive error - return nil - }, - ) - - conn, err := net.ListenUDP("udp", &net.UDPAddr{}) - if err != nil { - t.Fatalf("listen UDP: %v", err) - } - - // Start server - errChan := make(chan error, 1) - go func() { - err := s.Serve(conn) + }() + + // Create client + c, err := NewClient(localSystem(t, conn)) if err != nil { - errChan <- fmt.Errorf("running serve: %w", err) + t.Fatalf("creating new client: %v", err) } - }() - defer func() { - s.Shutdown() - select { - case err := <-errChan: - t.Errorf("server error: %v", err) - default: + + ot, err := c.Send("a", "octet") + if err != nil { + t.Errorf("start sending: %v", err) + } + + _, err = ot.ReadFrom(io.LimitReader( + newRandReader(rand.NewSource(42)), 42)) + if err == nil { + t.Errorf("missing write error") + } + + _, err = c.Receive("a", "octet") + if err == nil { + t.Errorf("missing read error") } - }() - - // Create client - c, err := NewClient(localSystem(t, conn)) - if err != nil { - t.Fatalf("creating new client: %v", err) - } - - ot, err := c.Send("a", "octet") - if err != nil { - t.Errorf("start sending: %v", err) - } - - _, err = ot.ReadFrom(io.LimitReader( - newRandReader(rand.NewSource(42)), 42)) - if err == nil { - t.Errorf("missing write error") - } - - _, err = c.Receive("a", "octet") - if err == nil { - t.Errorf("missing read error") - } + }) } type failingReader struct{} diff --git a/blocksize_test.go b/blocksize_test.go index 40a3566..cae741b 100644 --- a/blocksize_test.go +++ b/blocksize_test.go @@ -12,28 +12,39 @@ import ( ) func Test900(t *testing.T) { - s, c := makeTestServer(t, false) - defer s.Shutdown() - for i := 600; i < 4000; i++ { - c.SetBlockSize(i) - s.SetBlockSize(4600 - i) - testSendReceive(t, c, 9000+int64(i)) - } + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + for i := 600; i < 4000; i++ { + c.SetBlockSize(i) + // In single-port mode the server read loop continuously reads maxBlockLen, + // so mutating block size at runtime races with Serve under -race. + // Keep runtime server mutation only in regular mode. + if mode == modeRegular { + s.SetBlockSize(4600 - i) + } + testSendReceive(t, c, 9000+int64(i)) + } + }) } func Test1810(t *testing.T) { - s, c := makeTestServer(t, false) - defer s.Shutdown() - c.SetBlockSize(1810) - testSendReceive(t, c, 9000+1810) + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + c.SetBlockSize(1810) + testSendReceive(t, c, 9000+1810) + }) } func TestNearBlockLength(t *testing.T) { - s, c := makeTestServer(t, false) - defer s.Shutdown() - for i := 450; i < 520; i++ { - testSendReceive(t, c, int64(i)) - } + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + for i := 450; i < 520; i++ { + testSendReceive(t, c, int64(i)) + } + }) } func TestBlockSizeNegotiation_ClampsByPathLimit(t *testing.T) { diff --git a/hooks_test.go b/hooks_test.go index 72807d8..2a87fcf 100644 --- a/hooks_test.go +++ b/hooks_test.go @@ -9,47 +9,51 @@ import ( ) func TestHookSuccess(t *testing.T) { - s, c := makeTestServer(t, false) - th := newTestHook() - s.SetHook(th) - c.SetBlockSize(1810) - length := int64(9000) - filename := fmt.Sprintf("length-%d-bytes-%d", length, time.Now().UnixNano()) - rf, err := c.Send(filename, "octet") - if err != nil { - t.Fatalf("requesting %s write: %v", filename, err) - } - r := io.LimitReader(newRandReader(rand.NewSource(length)), length) - n, err := rf.ReadFrom(r) - if err != nil { - t.Fatalf("sending %s: %v", filename, err) - } - if n != length { - t.Errorf("%s length mismatch: %d != %d", filename, n, length) - } - s.Shutdown() - th.Lock() - defer th.Unlock() - if th.transfersCompleted != 1 { - t.Errorf("unexpected completed transfers count: %d", th.transfersCompleted) - } + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + th := newTestHook() + s.SetHook(th) + c.SetBlockSize(1810) + length := int64(9000) + filename := fmt.Sprintf("length-%d-bytes-%d", length, time.Now().UnixNano()) + rf, err := c.Send(filename, "octet") + if err != nil { + t.Fatalf("requesting %s write: %v", filename, err) + } + r := io.LimitReader(newRandReader(rand.NewSource(length)), length) + n, err := rf.ReadFrom(r) + if err != nil { + t.Fatalf("sending %s: %v", filename, err) + } + if n != length { + t.Errorf("%s length mismatch: %d != %d", filename, n, length) + } + s.Shutdown() + th.Lock() + defer th.Unlock() + if th.transfersCompleted != 1 { + t.Errorf("unexpected completed transfers count: %d", th.transfersCompleted) + } + }) } func TestHookFailure(t *testing.T) { - s, c := makeTestServer(t, false) - th := newTestHook() - s.SetHook(th) - filename := "test-not-exists" - mode := "octet" - _, err := c.Receive(filename, mode) - if err == nil { - t.Fatalf("file not exists: %v", err) - } - t.Logf("receiving file that does not exist: %v", err) - s.Shutdown() - th.Lock() - defer th.Unlock() - if th.transfersFailed == 0 { // TODO: there are two failures, not one on Windows? - t.Errorf("unexpected failed transfers count: %d", th.transfersFailed) - } + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + th := newTestHook() + s.SetHook(th) + filename := "test-not-exists" + xferMode := "octet" + _, err := c.Receive(filename, xferMode) + if err == nil { + t.Fatalf("file not exists: %v", err) + } + t.Logf("receiving file that does not exist: %v", err) + s.Shutdown() + th.Lock() + defer th.Unlock() + if th.transfersFailed == 0 { // TODO: there are two failures, not one on Windows? + t.Errorf("unexpected failed transfers count: %d", th.transfersFailed) + } + }) } diff --git a/lifecycle_test.go b/lifecycle_test.go index 648ee27..928d945 100644 --- a/lifecycle_test.go +++ b/lifecycle_test.go @@ -24,68 +24,72 @@ func TestServerReceiveTimeout(t *testing.T) { } func TestClientReceiveTimeout(t *testing.T) { - s, c := makeTestServer(t, false) - c.SetTimeout(time.Second) - c.SetRetries(2) - s.mu.Lock() - s.readHandler = func(filename string, rf io.ReaderFrom) error { - r := &slowReader{ - r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000), - n: 3, - delay: 8 * time.Second, + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + c.SetTimeout(time.Second) + c.SetRetries(2) + s.mu.Lock() + s.readHandler = func(filename string, rf io.ReaderFrom) error { + r := &slowReader{ + r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000), + n: 3, + delay: 8 * time.Second, + } + _, err := rf.ReadFrom(r) + return err } - _, err := rf.ReadFrom(r) - return err - } - s.mu.Unlock() - defer s.Shutdown() - filename := "test-client-receive-timeout" - mode := "octet" - readTransfer, err := c.Receive(filename, mode) - if err != nil { - t.Fatalf("requesting read %s: %v", filename, err) - } - buf := &bytes.Buffer{} - _, err = readTransfer.WriteTo(buf) - netErr, ok := err.(net.Error) - if !ok { - t.Fatalf("network error expected: %T", err) - } - if !netErr.Timeout() { - t.Fatalf("timout is expected: %v", err) - } + s.mu.Unlock() + defer s.Shutdown() + filename := "test-client-receive-timeout" + xferMode := "octet" + readTransfer, err := c.Receive(filename, xferMode) + if err != nil { + t.Fatalf("requesting read %s: %v", filename, err) + } + buf := &bytes.Buffer{} + _, err = readTransfer.WriteTo(buf) + netErr, ok := err.(net.Error) + if !ok { + t.Fatalf("network error expected: %T", err) + } + if !netErr.Timeout() { + t.Fatalf("timout is expected: %v", err) + } + }) } func TestClientSendTimeout(t *testing.T) { - s, c := makeTestServer(t, false) - c.SetTimeout(time.Second) - c.SetRetries(2) - s.mu.Lock() - s.writeHandler = func(filename string, wt io.WriterTo) error { - w := &slowWriter{ - n: 3, - delay: 8 * time.Second, + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + c.SetTimeout(time.Second) + c.SetRetries(2) + s.mu.Lock() + s.writeHandler = func(filename string, wt io.WriterTo) error { + w := &slowWriter{ + n: 3, + delay: 8 * time.Second, + } + _, err := wt.WriteTo(w) + return err } - _, err := wt.WriteTo(w) - return err - } - s.mu.Unlock() - defer s.Shutdown() - filename := "test-client-send-timeout" - mode := "octet" - writeTransfer, err := c.Send(filename, mode) - if err != nil { - t.Fatalf("requesting write %s: %v", filename, err) - } - r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000) - _, err = writeTransfer.ReadFrom(r) - netErr, ok := err.(net.Error) - if !ok { - t.Fatalf("network error expected: %T", err) - } - if !netErr.Timeout() { - t.Fatalf("timout is expected: %v", err) - } + s.mu.Unlock() + defer s.Shutdown() + filename := "test-client-send-timeout" + xferMode := "octet" + writeTransfer, err := c.Send(filename, xferMode) + if err != nil { + t.Fatalf("requesting write %s: %v", filename, err) + } + r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000) + _, err = writeTransfer.ReadFrom(r) + netErr, ok := err.(net.Error) + if !ok { + t.Fatalf("network error expected: %T", err) + } + if !netErr.Timeout() { + t.Fatalf("timout is expected: %v", err) + } + }) } // countingWriter signals through a channel when a certain number of bytes have been written @@ -110,25 +114,19 @@ func (w *countingWriter) Write(p []byte) (n int, err error) { // TestShutdownDuringTransfer starts a transfer, then shuts down the server mid-transfer. // Checks that neither server nor client hang and server shuts down cleanly. func TestShutdownDuringTransfer(t *testing.T) { - for _, singlePort := range []bool{false, true} { - name := "regular" - if singlePort { - name = "single_port" - } - t.Run(name, func(t *testing.T) { - testShutdownDuringTransfer(t, singlePort) - }) - } + forModes(t, func(t *testing.T, mode transferMode) { + testShutdownDuringTransfer(t, mode) + }) } -func testShutdownDuringTransfer(t *testing.T, singlePort bool) { +func testShutdownDuringTransfer(t *testing.T, mode transferMode) { s := NewServer(func(_ string, rf io.ReaderFrom) error { // Simulate a slow reader: send 1MB, but slowly _, err := rf.ReadFrom(&slowReader{r: bytes.NewReader(make([]byte, 1<<23)), n: 1 << 20, delay: 10 * time.Millisecond}) return err }, nil) - if singlePort { + if mode == modeSinglePort { s.EnableSinglePort() } diff --git a/transfer_test.go b/transfer_test.go index 0e548d0..15915e7 100644 --- a/transfer_test.go +++ b/transfer_test.go @@ -40,83 +40,93 @@ func TestSendReceiveWithBlockSizeRange(t *testing.T) { } func Test1000(t *testing.T) { - s, c := makeTestServer(t, false) - defer s.Shutdown() - for i := int64(0); i < 5000; i++ { - filename := fmt.Sprintf("length-%d-bytes-%d", i, time.Now().UnixNano()) - rf, err := c.Send(filename, "octet") - if err != nil { - t.Fatalf("requesting %s write: %v", filename, err) - } - r := io.LimitReader(newRandReader(rand.NewSource(i)), i) - n, err := rf.ReadFrom(r) - if err != nil { - t.Fatalf("sending %s: %v", filename, err) - } - if n != i { - t.Errorf("%s length mismatch: %d != %d", filename, n, i) + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + for i := int64(0); i < 5000; i++ { + filename := fmt.Sprintf("length-%d-bytes-%d", i, time.Now().UnixNano()) + rf, err := c.Send(filename, "octet") + if err != nil { + t.Fatalf("requesting %s write: %v", filename, err) + } + r := io.LimitReader(newRandReader(rand.NewSource(i)), i) + n, err := rf.ReadFrom(r) + if err != nil { + t.Fatalf("sending %s: %v", filename, err) + } + if n != i { + t.Errorf("%s length mismatch: %d != %d", filename, n, i) + } } - } + }) } func TestBlockWrapsAround(t *testing.T) { - s, c := makeTestServer(t, false) - defer s.Shutdown() - n := 65535 * 512 - for i := n - 2; i < n+2; i++ { - testSendReceive(t, c, int64(i)) - } + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + n := 65535 * 512 + for i := n - 2; i < n+2; i++ { + testSendReceive(t, c, int64(i)) + } + }) } func TestRandomLength(t *testing.T) { - s, c := makeTestServer(t, false) - defer s.Shutdown() - r := rand.New(rand.NewSource(42)) - for i := 0; i < 100; i++ { - testSendReceive(t, c, r.Int63n(100000)) - } + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + r := rand.New(rand.NewSource(42)) + for i := 0; i < 100; i++ { + testSendReceive(t, c, r.Int63n(100000)) + } + }) } func TestBigFile(t *testing.T) { - s, c := makeTestServer(t, false) - defer s.Shutdown() - testSendReceive(t, c, 3*1000*1000) + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + testSendReceive(t, c, 3*1000*1000) + }) } func TestByOneByte(t *testing.T) { - s, c := makeTestServer(t, false) - defer s.Shutdown() - filename := "test-by-one-byte" - mode := "octet" - const length = 80000 - sender, err := c.Send(filename, mode) - if err != nil { - t.Fatalf("requesting write: %v", err) - } - r := iotest.OneByteReader(io.LimitReader( - newRandReader(rand.NewSource(42)), length)) - n, err := sender.ReadFrom(r) - if err != nil { - t.Fatalf("send error: %v", err) - } - if n != length { - t.Errorf("%s read length mismatch: %d != %d", filename, n, length) - } - readTransfer, err := c.Receive(filename, mode) - if err != nil { - t.Fatalf("requesting read %s: %v", filename, err) - } - buf := &bytes.Buffer{} - n, err = readTransfer.WriteTo(buf) - if err != nil { - t.Fatalf("%s read error: %v", filename, err) - } - if n != length { - t.Errorf("%s read length mismatch: %d != %d", filename, n, length) - } - bs, _ := io.ReadAll(io.LimitReader( - newRandReader(rand.NewSource(42)), length)) - if !bytes.Equal(bs, buf.Bytes()) { - t.Errorf("\nsent: %x\nrcvd: %x", bs, buf) - } + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + filename := "test-by-one-byte" + xferMode := "octet" + const length = 80000 + sender, err := c.Send(filename, xferMode) + if err != nil { + t.Fatalf("requesting write: %v", err) + } + r := iotest.OneByteReader(io.LimitReader( + newRandReader(rand.NewSource(42)), length)) + n, err := sender.ReadFrom(r) + if err != nil { + t.Fatalf("send error: %v", err) + } + if n != length { + t.Errorf("%s read length mismatch: %d != %d", filename, n, length) + } + readTransfer, err := c.Receive(filename, xferMode) + if err != nil { + t.Fatalf("requesting read %s: %v", filename, err) + } + buf := &bytes.Buffer{} + n, err = readTransfer.WriteTo(buf) + if err != nil { + t.Fatalf("%s read error: %v", filename, err) + } + if n != length { + t.Errorf("%s read length mismatch: %d != %d", filename, n, length) + } + bs, _ := io.ReadAll(io.LimitReader( + newRandReader(rand.NewSource(42)), length)) + if !bytes.Equal(bs, buf.Bytes()) { + t.Errorf("\nsent: %x\nrcvd: %x", bs, buf) + } + }) } diff --git a/tsize_test.go b/tsize_test.go index c659e01..cb29f50 100644 --- a/tsize_test.go +++ b/tsize_test.go @@ -10,69 +10,77 @@ import ( ) func TestTSize(t *testing.T) { - s, c := makeTestServer(t, false) - defer s.Shutdown() - c.RequestTSize(true) - testSendReceive(t, c, 640) + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + c.RequestTSize(true) + testSendReceive(t, c, 640) + }) } func TestSendTsizeFromSeek(t *testing.T) { - // create read-only server - s := NewServer(func(filename string, rf io.ReaderFrom) error { - b := make([]byte, 100) - rr := newRandReader(rand.NewSource(42)) - rr.Read(b) - // bytes.Reader implements io.Seek - r := bytes.NewReader(b) - _, err := rf.ReadFrom(r) - if err != nil { - t.Errorf("sending bytes: %v", err) + forModes(t, func(t *testing.T, mode transferMode) { + // create read-only server + s := NewServer(func(filename string, rf io.ReaderFrom) error { + b := make([]byte, 100) + rr := newRandReader(rand.NewSource(42)) + rr.Read(b) + // bytes.Reader implements io.Seek + r := bytes.NewReader(b) + _, err := rf.ReadFrom(r) + if err != nil { + t.Errorf("sending bytes: %v", err) + } + return nil + }, nil) + if mode == modeSinglePort { + s.SetBlockSize(2000) + s.EnableSinglePort() } - return nil - }, nil) - conn, err := net.ListenUDP("udp", &net.UDPAddr{}) - if err != nil { - t.Fatalf("listening: %v", err) - } + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + t.Fatalf("listening: %v", err) + } - go s.Serve(conn) - defer s.Shutdown() + go s.Serve(conn) + defer s.Shutdown() - c, err := NewClient(localSystem(t, conn)) - if err != nil { - t.Fatalf("new client: %v", err) - } - c.RequestTSize(true) - r, err := c.Receive("f", "octet") - if err != nil { - t.Fatalf("receive: %v", err) - } - var size int64 - if it, ok := r.(IncomingTransfer); ok { - if n, ok := it.Size(); ok { - size = n - fmt.Printf("Transfer size: %d\n", n) + c, err := NewClient(localSystem(t, conn)) + if err != nil { + t.Fatalf("new client: %v", err) + } + c.RequestTSize(true) + r, err := c.Receive("f", "octet") + if err != nil { + t.Fatalf("receive: %v", err) + } + var size int64 + if it, ok := r.(IncomingTransfer); ok { + if n, ok := it.Size(); ok { + size = n + fmt.Printf("Transfer size: %d\n", n) + } } - } - if size != 100 { - t.Errorf("size expected: 100, got %d", size) - } + if size != 100 { + t.Errorf("size expected: 100, got %d", size) + } - r.WriteTo(io.Discard) + r.WriteTo(io.Discard) - c.RequestTSize(false) - r, err = c.Receive("f", "octet") - if err != nil { - t.Fatalf("receive without tsize: %v", err) - } - if it, ok := r.(IncomingTransfer); ok { - _, ok := it.Size() - if ok { - t.Errorf("unexpected size received") + c.RequestTSize(false) + r, err = c.Receive("f", "octet") + if err != nil { + t.Fatalf("receive without tsize: %v", err) + } + if it, ok := r.(IncomingTransfer); ok { + _, ok := it.Size() + if ok { + t.Errorf("unexpected size received") + } } - } - r.WriteTo(io.Discard) + r.WriteTo(io.Discard) + }) }