diff --git a/tftp_anticipate_test.go b/anticipate_test.go similarity index 76% rename from tftp_anticipate_test.go rename to anticipate_test.go index 0a3480d..39b9ba4 100644 --- a/tftp_anticipate_test.go +++ b/anticipate_test.go @@ -7,17 +7,17 @@ import ( // derived from Test900 func TestAnticipateWindow900(t *testing.T) { - s, c := makeTestServerAnticipateWindow() + s, c := makeTestServerAnticipateWindow(t) defer s.Shutdown() for i := 600; i < 4000; i++ { - c.blksize = i + c.SetBlockSize(i) testSendReceive(t, c, 9000+int64(i)) } } // TestAnticipateHookSuccess verifies that OnSuccess hook is called on transfer completion when SetAnticipate is used func TestAnticipateHookSuccess(t *testing.T) { - s, c := makeTestServerAnticipateWindow() + s, c := makeTestServerAnticipateWindow(t) th := newTestHook() s.SetHook(th) testSendReceive(t, c, 154242) @@ -30,7 +30,8 @@ func TestAnticipateHookSuccess(t *testing.T) { } // derived from makeTestServer -func makeTestServerAnticipateWindow() (*Server, *Client) { +func makeTestServerAnticipateWindow(t *testing.T) (*Server, *Client) { + t.Helper() b := &testBackend{} b.m = make(map[string][]byte) @@ -40,15 +41,15 @@ func makeTestServerAnticipateWindow() (*Server, *Client) { conn, err := net.ListenUDP("udp", &net.UDPAddr{}) if err != nil { - panic(err) + t.Fatalf("listen udp: %v", err) } go s.Serve(conn) // Create client for that server - c, err := NewClient(localSystem(conn)) + c, err := NewClient(localSystem(t, conn)) if err != nil { - panic(err) + t.Fatalf("new client: %v", err) } return s, c diff --git a/api_error_test.go b/api_error_test.go new file mode 100644 index 0000000..ed020a3 --- /dev/null +++ b/api_error_test.go @@ -0,0 +1,169 @@ +package tftp + +import ( + "bytes" + "errors" + "fmt" + "io" + "math/rand" + "net" + "testing" +) + +func TestDuplicate(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + filename := "test-duplicate" + xferMode := "octet" + bs := []byte("lalala") + sender, err := c.Send(filename, xferMode) + if err != nil { + t.Fatalf("requesting write: %v", err) + } + buf := bytes.NewBuffer(bs) + _, err = sender.ReadFrom(buf) + if err != nil { + t.Fatalf("send error: %v", err) + } + _, err = c.Send(filename, xferMode) + if err == nil { + t.Fatalf("file already exists") + } + t.Logf("sending file that already exists: %v", err) + }) +} + +func TestNotFound(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + filename := "test-not-exists" + xferMode := "octet" + _, err := c.Receive(filename, xferMode) + if err == nil { + t.Fatalf("file not exists: %v", err) + } + t.Logf("receiving file that does not exist: %v", err) + }) +} + +func TestNoHandlers(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s := NewServer(nil, nil) + if mode == modeSinglePort { + s.EnableSinglePort() + } + + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + t.Fatalf("listen udp: %v", err) + } + + go s.Serve(conn) + defer s.Shutdown() + + c, err := NewClient(localSystem(t, conn)) + if err != nil { + t.Fatalf("new client: %v", err) + } + + _, err = c.Send("test", "octet") + if err == nil { + t.Errorf("error expected") + } + + _, err = c.Receive("test", "octet") + if err == nil { + t.Errorf("error expected") + } + }) +} + +// TestFileIOExceptions checks that errors returned by io.Reader or io.Writer used by +// the handler are handled correctly. +func TestReadWriteErrors(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s := NewServer( + func(_ string, rf io.ReaderFrom) error { + _, err := rf.ReadFrom(&failingReader{}) // Read operation fails immediately. + if err != errRead { + t.Errorf("want: %v, got: %v", errRead, err) + } + // return no error from handler, client still should receive error + return nil + }, + func(_ string, wt io.WriterTo) error { + _, err := wt.WriteTo(&failingWriter{}) // Write operation fails immediately. + if err != errWrite { + t.Errorf("want: %v, got: %v", errWrite, err) + } + // return no error from handler, client still should receive error + return nil + }, + ) + if mode == modeSinglePort { + s.EnableSinglePort() + } + + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + t.Fatalf("listen UDP: %v", err) + } + + // Start server + errChan := make(chan error, 1) + go func() { + err := s.Serve(conn) + if err != nil { + errChan <- fmt.Errorf("running serve: %w", err) + } + }() + defer func() { + s.Shutdown() + select { + case err := <-errChan: + t.Errorf("server error: %v", err) + default: + } + }() + + // Create client + c, err := NewClient(localSystem(t, conn)) + if err != nil { + t.Fatalf("creating new client: %v", err) + } + + ot, err := c.Send("a", "octet") + if err != nil { + t.Errorf("start sending: %v", err) + } + + _, err = ot.ReadFrom(io.LimitReader( + newRandReader(rand.NewSource(42)), 42)) + if err == nil { + t.Errorf("missing write error") + } + + _, err = c.Receive("a", "octet") + if err == nil { + t.Errorf("missing read error") + } + }) +} + +type failingReader struct{} + +var errRead = errors.New("read error") + +func (r *failingReader) Read(_ []byte) (int, error) { + return 0, errRead +} + +type failingWriter struct{} + +var errWrite = errors.New("write error") + +func (r *failingWriter) Write(_ []byte) (int, error) { + return 0, errWrite +} diff --git a/blocksize_negotiation_test.go b/blocksize_test.go similarity index 82% rename from blocksize_negotiation_test.go rename to blocksize_test.go index 67fe9bd..cae741b 100644 --- a/blocksize_negotiation_test.go +++ b/blocksize_test.go @@ -11,6 +11,42 @@ import ( "time" ) +func Test900(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + for i := 600; i < 4000; i++ { + c.SetBlockSize(i) + // In single-port mode the server read loop continuously reads maxBlockLen, + // so mutating block size at runtime races with Serve under -race. + // Keep runtime server mutation only in regular mode. + if mode == modeRegular { + s.SetBlockSize(4600 - i) + } + testSendReceive(t, c, 9000+int64(i)) + } + }) +} + +func Test1810(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + c.SetBlockSize(1810) + testSendReceive(t, c, 9000+1810) + }) +} + +func TestNearBlockLength(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + for i := 450; i < 520; i++ { + testSendReceive(t, c, int64(i)) + } + }) +} + func TestBlockSizeNegotiation_ClampsByPathLimit(t *testing.T) { got := negotiateBlockSizeForTest(t, 65432, 1472, true) if got != "1472" { @@ -154,7 +190,7 @@ func BenchmarkBlockSizeNegotiation(b *testing.B) { go func() { _ = s.Serve(conn) }() b.Cleanup(s.Shutdown) - c, err := NewClient(localSystem(conn)) + c, err := NewClient(localSystem(b, conn)) if err != nil { b.Fatalf("new client: %v", err) } diff --git a/hooks_test.go b/hooks_test.go new file mode 100644 index 0000000..2a87fcf --- /dev/null +++ b/hooks_test.go @@ -0,0 +1,59 @@ +package tftp + +import ( + "fmt" + "io" + "math/rand" + "testing" + "time" +) + +func TestHookSuccess(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + th := newTestHook() + s.SetHook(th) + c.SetBlockSize(1810) + length := int64(9000) + filename := fmt.Sprintf("length-%d-bytes-%d", length, time.Now().UnixNano()) + rf, err := c.Send(filename, "octet") + if err != nil { + t.Fatalf("requesting %s write: %v", filename, err) + } + r := io.LimitReader(newRandReader(rand.NewSource(length)), length) + n, err := rf.ReadFrom(r) + if err != nil { + t.Fatalf("sending %s: %v", filename, err) + } + if n != length { + t.Errorf("%s length mismatch: %d != %d", filename, n, length) + } + s.Shutdown() + th.Lock() + defer th.Unlock() + if th.transfersCompleted != 1 { + t.Errorf("unexpected completed transfers count: %d", th.transfersCompleted) + } + }) +} + +func TestHookFailure(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + th := newTestHook() + s.SetHook(th) + filename := "test-not-exists" + xferMode := "octet" + _, err := c.Receive(filename, xferMode) + if err == nil { + t.Fatalf("file not exists: %v", err) + } + t.Logf("receiving file that does not exist: %v", err) + s.Shutdown() + th.Lock() + defer th.Unlock() + if th.transfersFailed == 0 { // TODO: there are two failures, not one on Windows? + t.Errorf("unexpected failed transfers count: %d", th.transfersFailed) + } + }) +} diff --git a/lifecycle_test.go b/lifecycle_test.go new file mode 100644 index 0000000..928d945 --- /dev/null +++ b/lifecycle_test.go @@ -0,0 +1,195 @@ +package tftp + +import ( + "bytes" + "io" + "math/rand" + "net" + "testing" + "time" +) + +func TestServerSendTimeout(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + serverTimeoutSendTest(s, c, t) + }) +} + +func TestServerReceiveTimeout(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + serverReceiveTimeoutTest(s, c, t) + }) +} + +func TestClientReceiveTimeout(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + c.SetTimeout(time.Second) + c.SetRetries(2) + s.mu.Lock() + s.readHandler = func(filename string, rf io.ReaderFrom) error { + r := &slowReader{ + r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000), + n: 3, + delay: 8 * time.Second, + } + _, err := rf.ReadFrom(r) + return err + } + s.mu.Unlock() + defer s.Shutdown() + filename := "test-client-receive-timeout" + xferMode := "octet" + readTransfer, err := c.Receive(filename, xferMode) + if err != nil { + t.Fatalf("requesting read %s: %v", filename, err) + } + buf := &bytes.Buffer{} + _, err = readTransfer.WriteTo(buf) + netErr, ok := err.(net.Error) + if !ok { + t.Fatalf("network error expected: %T", err) + } + if !netErr.Timeout() { + t.Fatalf("timout is expected: %v", err) + } + }) +} + +func TestClientSendTimeout(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + c.SetTimeout(time.Second) + c.SetRetries(2) + s.mu.Lock() + s.writeHandler = func(filename string, wt io.WriterTo) error { + w := &slowWriter{ + n: 3, + delay: 8 * time.Second, + } + _, err := wt.WriteTo(w) + return err + } + s.mu.Unlock() + defer s.Shutdown() + filename := "test-client-send-timeout" + xferMode := "octet" + writeTransfer, err := c.Send(filename, xferMode) + if err != nil { + t.Fatalf("requesting write %s: %v", filename, err) + } + r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000) + _, err = writeTransfer.ReadFrom(r) + netErr, ok := err.(net.Error) + if !ok { + t.Fatalf("network error expected: %T", err) + } + if !netErr.Timeout() { + t.Fatalf("timout is expected: %v", err) + } + }) +} + +// countingWriter signals through a channel when a certain number of bytes have been written +type countingWriter struct { + w io.Writer + total int64 + threshold int64 + signal chan struct{} + signaled bool +} + +func (w *countingWriter) Write(p []byte) (n int, err error) { + n, err = w.w.Write(p) + w.total += int64(n) + if !w.signaled && w.total >= w.threshold { + w.signal <- struct{}{} + w.signaled = true + } + return n, err +} + +// TestShutdownDuringTransfer starts a transfer, then shuts down the server mid-transfer. +// Checks that neither server nor client hang and server shuts down cleanly. +func TestShutdownDuringTransfer(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + testShutdownDuringTransfer(t, mode) + }) +} + +func testShutdownDuringTransfer(t *testing.T, mode transferMode) { + s := NewServer(func(_ string, rf io.ReaderFrom) error { + // Simulate a slow reader: send 1MB, but slowly + _, err := rf.ReadFrom(&slowReader{r: bytes.NewReader(make([]byte, 1<<23)), n: 1 << 20, delay: 10 * time.Millisecond}) + return err + }, nil) + + if mode == modeSinglePort { + s.EnableSinglePort() + } + + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + t.Fatal(err) + } + + // Start a goroutine to monitor server errors + errChan := make(chan error, 1) + go func() { + errChan <- s.Serve(conn) + }() + + c, err := NewClient(localSystem(t, conn)) + if err != nil { + t.Fatal(err) + } + + dl := make(chan error, 1) + received := make(chan struct{}, 1) + go func() { + wt, err := c.Receive("file", "octet") + if err != nil { + dl <- err + return + } + // Use custom writer to signal when 100KB is received + counter := &countingWriter{ + w: io.Discard, + threshold: 100 * 1024, // 100KB + signal: received, + } + _, err = wt.WriteTo(counter) + dl <- err + }() + + // Wait for either 100KB to be received or timeout + select { + case <-received: + // Received enough data, proceed with shutdown + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for data transfer to start") + } + s.Shutdown() + + // Server should shut down cleanly + select { + case err := <-errChan: + if err != nil { + t.Errorf("server error: %v", err) + } + case <-time.After(5 * time.Second): + t.Error("server did not shut down in time") + } + + // Client should shutdown cleanly too because server waits for transfers to finish + select { + case err := <-dl: + if err != nil { + t.Errorf("client transfer error: %v", err) + } + case <-time.After(5 * time.Second): + t.Error("client did not finish in time") + } +} diff --git a/network_env_test.go b/network_env_test.go new file mode 100644 index 0000000..eaf0e96 --- /dev/null +++ b/network_env_test.go @@ -0,0 +1,257 @@ +package tftp + +import ( + "bytes" + "fmt" + "io" + "math/rand" + "net" + "os" + "sync" + "testing" + "time" +) + +const networkEnvVar = "TFTP_RUN_NETWORK_ENV_TESTS" + +func requireNetworkEnv(t *testing.T) { + t.Helper() + if os.Getenv(networkEnvVar) != "1" { + t.Skipf("set %s=1 to run environment-dependent network tests", networkEnvVar) + } +} + +// TestRequestPacketInfo checks that request packet destination address +// obtained by server using out-of-band socket info is sane. +// It creates server and tries to do transfers using different local interfaces. +// NB: Test ignores transfer errors and validates RequestPacketInfo only +// if transfer is completed successfully. So it checks that LocalIP returns +// correct result if any result is returned, but does not check if result was +// returned at all when it should. +func TestRequestPacketInfo(t *testing.T) { + requireNetworkEnv(t) + + // localIP keeps value received from RequestPacketInfo.LocalIP + // call inside handler. + // If RequestPacketInfo is not supported, value is set to unspecified + // IP address. + var localIP net.IP + var localIPMu sync.Mutex + + s := NewServer( + func(_ string, rf io.ReaderFrom) error { + localIPMu.Lock() + if rpi, ok := rf.(RequestPacketInfo); ok { + localIP = rpi.LocalIP() + } else { + localIP = net.IP{} + } + localIPMu.Unlock() + _, err := rf.ReadFrom(io.LimitReader( + newRandReader(rand.NewSource(42)), 42)) + if err != nil { + t.Logf("sending to client: %v", err) + } + return nil + }, + func(_ string, wt io.WriterTo) error { + localIPMu.Lock() + if rpi, ok := wt.(RequestPacketInfo); ok { + localIP = rpi.LocalIP() + } else { + localIP = net.IP{} + } + localIPMu.Unlock() + _, err := wt.WriteTo(io.Discard) + if err != nil { + t.Logf("receiving from client: %v", err) + } + return nil + }, + ) + + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + t.Fatalf("listen UDP: %v", err) + } + + _, port, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + t.Fatalf("parsing server port: %v", err) + } + + // Start server + errChan := make(chan error, 1) + go func() { + err := s.Serve(conn) + if err != nil { + errChan <- fmt.Errorf("serve: %w", err) + } + }() + defer func() { + s.Shutdown() + select { + case err := <-errChan: + t.Errorf("server error: %v", err) + default: + } + }() + + addrs, err := net.InterfaceAddrs() + if err != nil { + t.Fatalf("listing interface addresses: %v", err) + } + + for _, addr := range addrs { + ip := networkIP(addr.(*net.IPNet)) + if ip == nil { + continue + } + + c, err := NewClient(net.JoinHostPort(ip.String(), port)) + if err != nil { + t.Fatalf("new client: %v", err) + } + + // Skip re-tries to skip non-routable interfaces faster + c.SetRetries(0) + + ot, err := c.Send("a", "octet") + if err != nil { + t.Logf("start sending to %v: %v", ip, err) + continue + } + _, err = ot.ReadFrom(io.LimitReader( + newRandReader(rand.NewSource(42)), 42)) + if err != nil { + t.Logf("sending to %v: %v", ip, err) + continue + } + + // Check that read handler received IP that was used + // to create the client. + localIPMu.Lock() + if localIP != nil && !localIP.IsUnspecified() { // Skip check if no packet info + if !localIP.Equal(ip) { + t.Errorf("sent to: %v, request packet: %v", ip, localIP) + } + } else { + fmt.Printf("Skip %v\n", ip) + } + localIPMu.Unlock() + + it, err := c.Receive("a", "octet") + if err != nil { + t.Logf("start receiving from %v: %v", ip, err) + continue + } + _, err = it.WriteTo(io.Discard) + if err != nil { + t.Logf("receiving from %v: %v", ip, err) + continue + } + + // Check that write handler received IP that was used + // to create the client. + localIPMu.Lock() + if localIP != nil && !localIP.IsUnspecified() { // Skip check if no packet info + if !localIP.Equal(ip) { + t.Errorf("sent to: %v, request packet: %v", ip, localIP) + } + } else { + fmt.Printf("Skip %v\n", ip) + } + localIPMu.Unlock() + + fmt.Printf("Done %v\n", ip) + } +} + +func networkIP(n *net.IPNet) net.IP { + if ip := n.IP.To4(); ip != nil { + return ip + } + if len(n.IP) == net.IPv6len { + return n.IP + } + return nil +} + +func TestSetLocalAddr(t *testing.T) { + requireNetworkEnv(t) + + interfaces, err := net.Interfaces() + if err != nil { + t.Fatalf("failed to get network interfaces: %v", err) + } + + var addrs []net.IP + for _, i := range interfaces { + interfaceAddrs, err := i.Addrs() + if err != nil { + continue + } + for _, addr := range interfaceAddrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + if ip == nil || ip.IsLoopback() || ip.IsLinkLocalUnicast() { + continue + } + addrs = append(addrs, ip) + } + } + + for _, addrA := range addrs { + for _, addrB := range addrs { + t.Run(fmt.Sprintf("%s receives from %s", addrA.String(), addrB.String()), func(t *testing.T) { + var mu sync.Mutex + var remoteAddr net.UDPAddr + s := NewServer(nil, func(filename string, wt io.WriterTo) error { + _, err := wt.WriteTo(io.Discard) + if err != nil { + return err + } + mu.Lock() + defer mu.Unlock() + remoteAddr = wt.(IncomingTransfer).RemoteAddr() + return nil + }) + s.SetTimeout(2 * time.Second) + + conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: addrA, Port: 0}) + if err != nil { + t.Fatalf("listen udp: %v", err) + } + + go s.Serve(conn) + defer s.Shutdown() + + c, err := NewClient(conn.LocalAddr().String()) + if err != nil { + t.Fatalf("creating client: %v", err) + } + c.SetLocalAddr(addrB.String()) + + wt, err := c.Send("testfile", "octet") + if err != nil { + return + } + _, _ = wt.ReadFrom(bytes.NewReader([]byte("test data for cross-interface"))) + + // Compare addrB to remoteAddr in thread-safe manner + mu.Lock() + defer mu.Unlock() + if remoteAddr.IP == nil { + t.Error("remote address was not captured from server") + } else if !remoteAddr.IP.Equal(addrB) { + t.Errorf("remote address mismatch: expected %s, got %s", addrB.String(), remoteAddr.IP.String()) + } + }) + } + } +} diff --git a/packet_test.go b/packet_test.go new file mode 100644 index 0000000..fb8732e --- /dev/null +++ b/packet_test.go @@ -0,0 +1,49 @@ +package tftp + +import "testing" + +func TestPackUnpack(t *testing.T) { + v := []string{"test-filename/with-subdir"} + testOptsList := []options{ + nil, + { + "tsize": "1234", + "blksize": "22", + }, + } + for _, filename := range v { + for _, mode := range []string{"octet", "netascii"} { + for _, opts := range testOptsList { + packUnpack(t, filename, mode, opts) + } + } + } +} + +func packUnpack(t *testing.T, filename, mode string, opts options) { + b := make([]byte, datagramLength) + for _, op := range []uint16{opRRQ, opWRQ} { + n := packRQ(b, op, filename, mode, opts) + f, m, o, err := unpackRQ(b[:n]) + if err != nil { + t.Errorf("%s pack/unpack: %v", filename, err) + } + if f != filename { + t.Errorf("filename mismatch (%s): '%x' vs '%x'", + filename, f, filename) + } + if m != mode { + t.Errorf("mode mismatch (%s): '%x' vs '%x'", + mode, m, mode) + } + for name, value := range opts { + v, ok := o[name] + if !ok { + t.Errorf("missing %s option", name) + } + if v != value { + t.Errorf("option %s mismatch: '%x' vs '%x'", name, v, value) + } + } + } +} diff --git a/single_port_test.go b/single_port_test.go index d9f7fb5..51328de 100644 --- a/single_port_test.go +++ b/single_port_test.go @@ -6,39 +6,6 @@ import ( "time" ) -func TestZeroLengthSinglePort(t *testing.T) { - s, c := makeTestServer(true) - defer s.Shutdown() - testSendReceive(t, c, 0) -} - -func TestSendReceiveSinglePort(t *testing.T) { - s, c := makeTestServer(true) - defer s.Shutdown() - for i := 600; i < 1000; i++ { - testSendReceive(t, c, 5000+int64(i)) - } -} - -func TestSendReceiveSinglePortWithBlockSize(t *testing.T) { - s, c := makeTestServer(true) - defer s.Shutdown() - for i := 600; i < 1000; i++ { - c.blksize = i - testSendReceive(t, c, 5000+int64(i)) - } -} - -func TestServerSendTimeoutSinglePort(t *testing.T) { - s, c := makeTestServer(true) - serverTimeoutSendTest(s, c, t) -} - -func TestServerReceiveTimeoutSinglePort(t *testing.T) { - s, c := makeTestServer(true) - serverReceiveTimeoutTest(s, c, t) -} - func TestSinglePortShutdownReturns(t *testing.T) { s := NewServer(nil, nil) s.EnableSinglePort() diff --git a/tftp_test.go b/tftp_test.go deleted file mode 100644 index ed61851..0000000 --- a/tftp_test.go +++ /dev/null @@ -1,1154 +0,0 @@ -package tftp - -import ( - "bytes" - "errors" - "fmt" - "io" - "math/rand" - "net" - "os" - "strconv" - "sync" - "testing" - "testing/iotest" - "time" -) - -var localhost = determineLocalhost() - -func determineLocalhost() string { - l, err := net.ListenTCP("tcp", nil) - if err != nil { - panic(fmt.Sprintf("ListenTCP error: %s", err)) - } - _, lport, _ := net.SplitHostPort(l.Addr().String()) - defer l.Close() - - lo := make(chan string) - - go func() { - for { - conn, err := l.Accept() - if err != nil { - break - } - conn.Close() - } - }() - - go func() { - port, _ := strconv.Atoi(lport) - for _, af := range []string{"tcp6", "tcp4"} { - conn, err := net.DialTCP(af, &net.TCPAddr{}, &net.TCPAddr{Port: port}) - if err == nil { - conn.Close() - host, _, _ := net.SplitHostPort(conn.LocalAddr().String()) - lo <- host - return - } - } - panic("could not determine address family") - }() - - return <-lo -} - -func localSystem(c *net.UDPConn) string { - _, port, _ := net.SplitHostPort(c.LocalAddr().String()) - return net.JoinHostPort(localhost, port) -} - -func TestPackUnpack(t *testing.T) { - v := []string{"test-filename/with-subdir"} - testOptsList := []options{ - nil, - { - "tsize": "1234", - "blksize": "22", - }, - } - for _, filename := range v { - for _, mode := range []string{"octet", "netascii"} { - for _, opts := range testOptsList { - packUnpack(t, filename, mode, opts) - } - } - } -} - -func packUnpack(t *testing.T, filename, mode string, opts options) { - b := make([]byte, datagramLength) - for _, op := range []uint16{opRRQ, opWRQ} { - n := packRQ(b, op, filename, mode, opts) - f, m, o, err := unpackRQ(b[:n]) - if err != nil { - t.Errorf("%s pack/unpack: %v", filename, err) - } - if f != filename { - t.Errorf("filename mismatch (%s): '%x' vs '%x'", - filename, f, filename) - } - if m != mode { - t.Errorf("mode mismatch (%s): '%x' vs '%x'", - mode, m, mode) - } - for name, value := range opts { - v, ok := o[name] - if !ok { - t.Errorf("missing %s option", name) - } - if v != value { - t.Errorf("option %s mismatch: '%x' vs '%x'", name, v, value) - } - } - } -} - -func TestZeroLength(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - testSendReceive(t, c, 0) -} - -func Test900(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - for i := 600; i < 4000; i++ { - c.SetBlockSize(i) - s.SetBlockSize(4600 - i) - testSendReceive(t, c, 9000+int64(i)) - } -} - -func Test1000(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - for i := int64(0); i < 5000; i++ { - filename := fmt.Sprintf("length-%d-bytes-%d", i, time.Now().UnixNano()) - rf, err := c.Send(filename, "octet") - if err != nil { - t.Fatalf("requesting %s write: %v", filename, err) - } - r := io.LimitReader(newRandReader(rand.NewSource(i)), i) - n, err := rf.ReadFrom(r) - if err != nil { - t.Fatalf("sending %s: %v", filename, err) - } - if n != i { - t.Errorf("%s length mismatch: %d != %d", filename, n, i) - } - } -} - -func Test1810(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - c.SetBlockSize(1810) - testSendReceive(t, c, 9000+1810) -} - -type testHook struct { - *sync.Mutex - transfersCompleted int - transfersFailed int -} - -func newTestHook() *testHook { - return &testHook{ - Mutex: &sync.Mutex{}, - } -} - -func (h *testHook) OnSuccess(result TransferStats) { - h.Lock() - defer h.Unlock() - h.transfersCompleted++ -} - -func (h *testHook) OnFailure(result TransferStats, err error) { - h.Lock() - defer h.Unlock() - h.transfersFailed++ -} - -func TestHookSuccess(t *testing.T) { - s, c := makeTestServer(false) - th := newTestHook() - s.SetHook(th) - c.SetBlockSize(1810) - length := int64(9000) - filename := fmt.Sprintf("length-%d-bytes-%d", length, time.Now().UnixNano()) - rf, err := c.Send(filename, "octet") - if err != nil { - t.Fatalf("requesting %s write: %v", filename, err) - } - r := io.LimitReader(newRandReader(rand.NewSource(length)), length) - n, err := rf.ReadFrom(r) - if err != nil { - t.Fatalf("sending %s: %v", filename, err) - } - if n != length { - t.Errorf("%s length mismatch: %d != %d", filename, n, length) - } - s.Shutdown() - th.Lock() - defer th.Unlock() - if th.transfersCompleted != 1 { - t.Errorf("unexpected completed transfers count: %d", th.transfersCompleted) - } -} - -func TestHookFailure(t *testing.T) { - s, c := makeTestServer(false) - th := newTestHook() - s.SetHook(th) - filename := "test-not-exists" - mode := "octet" - _, err := c.Receive(filename, mode) - if err == nil { - t.Fatalf("file not exists: %v", err) - } - t.Logf("receiving file that does not exist: %v", err) - s.Shutdown() - th.Lock() - defer th.Unlock() - if th.transfersFailed == 0 { // TODO: there are two failures, not one on Windows? - t.Errorf("unexpected failed transfers count: %d", th.transfersFailed) - } -} - -func TestTSize(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - c.tsize = true - testSendReceive(t, c, 640) -} - -func TestNearBlockLength(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - for i := 450; i < 520; i++ { - testSendReceive(t, c, int64(i)) - } -} - -func TestBlockWrapsAround(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - n := 65535 * 512 - for i := n - 2; i < n+2; i++ { - testSendReceive(t, c, int64(i)) - } -} - -func TestRandomLength(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - r := rand.New(rand.NewSource(42)) - for i := 0; i < 100; i++ { - testSendReceive(t, c, r.Int63n(100000)) - } -} - -func TestBigFile(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - testSendReceive(t, c, 3*1000*1000) -} - -func TestByOneByte(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - filename := "test-by-one-byte" - mode := "octet" - const length = 80000 - sender, err := c.Send(filename, mode) - if err != nil { - t.Fatalf("requesting write: %v", err) - } - r := iotest.OneByteReader(io.LimitReader( - newRandReader(rand.NewSource(42)), length)) - n, err := sender.ReadFrom(r) - if err != nil { - t.Fatalf("send error: %v", err) - } - if n != length { - t.Errorf("%s read length mismatch: %d != %d", filename, n, length) - } - readTransfer, err := c.Receive(filename, mode) - if err != nil { - t.Fatalf("requesting read %s: %v", filename, err) - } - buf := &bytes.Buffer{} - n, err = readTransfer.WriteTo(buf) - if err != nil { - t.Fatalf("%s read error: %v", filename, err) - } - if n != length { - t.Errorf("%s read length mismatch: %d != %d", filename, n, length) - } - bs, _ := io.ReadAll(io.LimitReader( - newRandReader(rand.NewSource(42)), length)) - if !bytes.Equal(bs, buf.Bytes()) { - t.Errorf("\nsent: %x\nrcvd: %x", bs, buf) - } -} - -func TestDuplicate(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - filename := "test-duplicate" - mode := "octet" - bs := []byte("lalala") - sender, err := c.Send(filename, mode) - if err != nil { - t.Fatalf("requesting write: %v", err) - } - buf := bytes.NewBuffer(bs) - _, err = sender.ReadFrom(buf) - if err != nil { - t.Fatalf("send error: %v", err) - } - _, err = c.Send(filename, mode) - if err == nil { - t.Fatalf("file already exists") - } - t.Logf("sending file that already exists: %v", err) -} - -func TestNotFound(t *testing.T) { - s, c := makeTestServer(false) - defer s.Shutdown() - filename := "test-not-exists" - mode := "octet" - _, err := c.Receive(filename, mode) - if err == nil { - t.Fatalf("file not exists: %v", err) - } - t.Logf("receiving file that does not exist: %v", err) -} - -func testSendReceive(t *testing.T, client *Client, length int64) { - filename := fmt.Sprintf("length-%d-bytes", length) - mode := "octet" - writeTransfer, err := client.Send(filename, mode) - if err != nil { - t.Fatalf("requesting write %s: %v", filename, err) - } - r := io.LimitReader(newRandReader(rand.NewSource(42)), length) - n, err := writeTransfer.ReadFrom(r) - if err != nil { - t.Fatalf("%s write error: %v", filename, err) - } - if n != length { - t.Errorf("%s write length mismatch: %d != %d", filename, n, length) - } - readTransfer, err := client.Receive(filename, mode) - if err != nil { - t.Fatalf("requesting read %s: %v", filename, err) - } - if it, ok := readTransfer.(IncomingTransfer); ok { - if n, ok := it.Size(); ok { - fmt.Printf("Transfer size: %d\n", n) - if n != length { - t.Errorf("tsize mismatch: %d vs %d", n, length) - } - } - } - buf := &bytes.Buffer{} - n, err = readTransfer.WriteTo(buf) - if err != nil { - t.Fatalf("%s read error: %v", filename, err) - } - if n != length { - t.Errorf("%s read length mismatch: %d != %d", filename, n, length) - } - bs, _ := io.ReadAll(io.LimitReader( - newRandReader(rand.NewSource(42)), length)) - if !bytes.Equal(bs, buf.Bytes()) { - t.Errorf("\nsent: %x\nrcvd: %x", bs, buf) - } -} - -func TestSendTsizeFromSeek(t *testing.T) { - // create read-only server - s := NewServer(func(filename string, rf io.ReaderFrom) error { - b := make([]byte, 100) - rr := newRandReader(rand.NewSource(42)) - rr.Read(b) - // bytes.Reader implements io.Seek - r := bytes.NewReader(b) - _, err := rf.ReadFrom(r) - if err != nil { - t.Errorf("sending bytes: %v", err) - } - return nil - }, nil) - - conn, err := net.ListenUDP("udp", &net.UDPAddr{}) - if err != nil { - t.Fatalf("listening: %v", err) - } - - go s.Serve(conn) - defer s.Shutdown() - - c, _ := NewClient(localSystem(conn)) - c.RequestTSize(true) - r, _ := c.Receive("f", "octet") - var size int64 - if it, ok := r.(IncomingTransfer); ok { - if n, ok := it.Size(); ok { - size = n - fmt.Printf("Transfer size: %d\n", n) - } - } - - if size != 100 { - t.Errorf("size expected: 100, got %d", size) - } - - r.WriteTo(io.Discard) - - c.RequestTSize(false) - r, _ = c.Receive("f", "octet") - if it, ok := r.(IncomingTransfer); ok { - _, ok := it.Size() - if ok { - t.Errorf("unexpected size received") - } - } - - r.WriteTo(io.Discard) -} - -type testBackend struct { - m map[string][]byte - mu sync.Mutex -} - -func makeTestServer(singlePort bool) (*Server, *Client) { - b := &testBackend{} - b.m = make(map[string][]byte) - - // Create server - s := NewServer(b.handleRead, b.handleWrite) - - if singlePort { - s.SetBlockSize(2000) - s.EnableSinglePort() - } - - conn, err := net.ListenUDP("udp", &net.UDPAddr{}) - if err != nil { - panic(err) - } - - go s.Serve(conn) - - // Create client for that server - c, err := NewClient(localSystem(conn)) - if err != nil { - panic(err) - } - - return s, c -} - -func TestNoHandlers(t *testing.T) { - s := NewServer(nil, nil) - - conn, err := net.ListenUDP("udp", &net.UDPAddr{}) - if err != nil { - panic(err) - } - - go s.Serve(conn) - - c, err := NewClient(localSystem(conn)) - if err != nil { - panic(err) - } - - _, err = c.Send("test", "octet") - if err == nil { - t.Errorf("error expected") - } - - _, err = c.Receive("test", "octet") - if err == nil { - t.Errorf("error expected") - } -} - -func (b *testBackend) handleWrite(filename string, wt io.WriterTo) error { - b.mu.Lock() - defer b.mu.Unlock() - _, ok := b.m[filename] - if ok { - fmt.Fprintf(os.Stderr, "File %s already exists\n", filename) - return fmt.Errorf("file already exists") - } - if t, ok := wt.(IncomingTransfer); ok { - if n, ok := t.Size(); ok { - fmt.Printf("Transfer size: %d\n", n) - } - } - buf := &bytes.Buffer{} - _, err := wt.WriteTo(buf) - if err != nil { - fmt.Fprintf(os.Stderr, "Can't receive %s: %v\n", filename, err) - return err - } - b.m[filename] = buf.Bytes() - return nil -} - -func (b *testBackend) handleRead(filename string, rf io.ReaderFrom) error { - b.mu.Lock() - defer b.mu.Unlock() - bs, ok := b.m[filename] - if !ok { - fmt.Fprintf(os.Stderr, "File %s not found\n", filename) - return fmt.Errorf("file not found") - } - if t, ok := rf.(OutgoingTransfer); ok { - t.SetSize(int64(len(bs))) - } - _, err := rf.ReadFrom(bytes.NewBuffer(bs)) - if err != nil { - fmt.Fprintf(os.Stderr, "Can't send %s: %v\n", filename, err) - return err - } - return nil -} - -type randReader struct { - src rand.Source - next int64 - i int8 -} - -func newRandReader(src rand.Source) io.Reader { - r := &randReader{ - src: src, - next: src.Int63(), - } - return r -} - -func (r *randReader) Read(p []byte) (n int, err error) { - next, i := r.next, r.i - for n = 0; n < len(p); n++ { - if i == 7 { - next, i = r.src.Int63(), 0 - } - p[n] = byte(next) - next >>= 8 - i++ - } - r.next, r.i = next, i - return -} - -func serverTimeoutSendTest(s *Server, c *Client, t *testing.T) { - s.SetTimeout(time.Second) - s.SetRetries(2) - sec := make(chan error, 1) - s.mu.Lock() - s.readHandler = func(filename string, rf io.ReaderFrom) error { - r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000) - _, err := rf.ReadFrom(r) - sec <- err - return err - } - s.mu.Unlock() - defer s.Shutdown() - filename := "test-server-send-timeout" - mode := "octet" - readTransfer, err := c.Receive(filename, mode) - if err != nil { - t.Fatalf("requesting read %s: %v", filename, err) - } - w := &slowWriter{ - n: 3, - delay: 8 * time.Second, - } - _, _ = readTransfer.WriteTo(w) - servErr := <-sec - netErr, ok := servErr.(net.Error) - if !ok { - t.Fatalf("network error expected: %T", servErr) - } - if !netErr.Timeout() { - t.Fatalf("timout is expected: %v", servErr) - } - -} - -func TestServerSendTimeout(t *testing.T) { - s, c := makeTestServer(false) - serverTimeoutSendTest(s, c, t) -} - -func serverReceiveTimeoutTest(s *Server, c *Client, t *testing.T) { - s.SetTimeout(time.Second) - s.SetRetries(2) - sec := make(chan error, 1) - s.mu.Lock() - s.writeHandler = func(filename string, wt io.WriterTo) error { - buf := &bytes.Buffer{} - _, err := wt.WriteTo(buf) - sec <- err - return err - } - s.mu.Unlock() - defer s.Shutdown() - filename := "test-server-receive-timeout" - mode := "octet" - writeTransfer, err := c.Send(filename, mode) - if err != nil { - t.Fatalf("requesting write %s: %v", filename, err) - } - r := &slowReader{ - r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000), - n: 3, - delay: 8 * time.Second, - } - _, _ = writeTransfer.ReadFrom(r) - servErr := <-sec - netErr, ok := servErr.(net.Error) - if !ok { - t.Fatalf("network error expected: %T", servErr) - } - if !netErr.Timeout() { - t.Fatalf("timout is expected: %v", servErr) - } -} - -func TestServerReceiveTimeout(t *testing.T) { - s, c := makeTestServer(false) - serverReceiveTimeoutTest(s, c, t) -} - -func TestClientReceiveTimeout(t *testing.T) { - s, c := makeTestServer(false) - c.SetTimeout(time.Second) - c.SetRetries(2) - s.mu.Lock() - s.readHandler = func(filename string, rf io.ReaderFrom) error { - r := &slowReader{ - r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000), - n: 3, - delay: 8 * time.Second, - } - _, err := rf.ReadFrom(r) - return err - } - s.mu.Unlock() - defer s.Shutdown() - filename := "test-client-receive-timeout" - mode := "octet" - readTransfer, err := c.Receive(filename, mode) - if err != nil { - t.Fatalf("requesting read %s: %v", filename, err) - } - buf := &bytes.Buffer{} - _, err = readTransfer.WriteTo(buf) - netErr, ok := err.(net.Error) - if !ok { - t.Fatalf("network error expected: %T", err) - } - if !netErr.Timeout() { - t.Fatalf("timout is expected: %v", err) - } -} - -func TestClientSendTimeout(t *testing.T) { - s, c := makeTestServer(false) - c.SetTimeout(time.Second) - c.SetRetries(2) - s.mu.Lock() - s.writeHandler = func(filename string, wt io.WriterTo) error { - w := &slowWriter{ - n: 3, - delay: 8 * time.Second, - } - _, err := wt.WriteTo(w) - return err - } - s.mu.Unlock() - defer s.Shutdown() - filename := "test-client-send-timeout" - mode := "octet" - writeTransfer, err := c.Send(filename, mode) - if err != nil { - t.Fatalf("requesting write %s: %v", filename, err) - } - r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000) - _, err = writeTransfer.ReadFrom(r) - netErr, ok := err.(net.Error) - if !ok { - t.Fatalf("network error expected: %T", err) - } - if !netErr.Timeout() { - t.Fatalf("timout is expected: %v", err) - } -} - -type slowReader struct { - r io.Reader - n int64 - delay time.Duration -} - -func (r *slowReader) Read(p []byte) (n int, err error) { - if r.n > 0 { - r.n-- - return r.r.Read(p) - } - time.Sleep(r.delay) - return r.r.Read(p) -} - -type slowWriter struct { - n int64 - delay time.Duration -} - -func (r *slowWriter) Write(p []byte) (n int, err error) { - if r.n > 0 { - r.n-- - return len(p), nil - } - time.Sleep(r.delay) - return len(p), nil -} - -// TestRequestPacketInfo checks that request packet destination address -// obtained by server using out-of-band socket info is sane. -// It creates server and tries to do transfers using different local interfaces. -// NB: Test ignores transfer errors and validates RequestPacketInfo only -// if transfer is completed successfully. So it checks that LocalIP returns -// correct result if any result is returned, but does not check if result was -// returned at all when it should. -func TestRequestPacketInfo(t *testing.T) { - // localIP keeps value received from RequestPacketInfo.LocalIP - // call inside handler. - // If RequestPacketInfo is not supported, value is set to unspecified - // IP address. - var localIP net.IP - var localIPMu sync.Mutex - - s := NewServer( - func(_ string, rf io.ReaderFrom) error { - localIPMu.Lock() - if rpi, ok := rf.(RequestPacketInfo); ok { - localIP = rpi.LocalIP() - } else { - localIP = net.IP{} - } - localIPMu.Unlock() - _, err := rf.ReadFrom(io.LimitReader( - newRandReader(rand.NewSource(42)), 42)) - if err != nil { - t.Logf("sending to client: %v", err) - } - return nil - }, - func(_ string, wt io.WriterTo) error { - localIPMu.Lock() - if rpi, ok := wt.(RequestPacketInfo); ok { - localIP = rpi.LocalIP() - } else { - localIP = net.IP{} - } - localIPMu.Unlock() - _, err := wt.WriteTo(io.Discard) - if err != nil { - t.Logf("receiving from client: %v", err) - } - return nil - }, - ) - - conn, err := net.ListenUDP("udp", &net.UDPAddr{}) - if err != nil { - t.Fatalf("listen UDP: %v", err) - } - - _, port, err := net.SplitHostPort(conn.LocalAddr().String()) - if err != nil { - t.Fatalf("parsing server port: %v", err) - } - - // Start server - errChan := make(chan error, 1) - go func() { - err := s.Serve(conn) - if err != nil { - errChan <- fmt.Errorf("serve: %w", err) - } - }() - defer func() { - s.Shutdown() - select { - case err := <-errChan: - t.Errorf("server error: %v", err) - default: - } - }() - - addrs, err := net.InterfaceAddrs() - if err != nil { - t.Fatalf("listing interface addresses: %v", err) - } - - for _, addr := range addrs { - ip := networkIP(addr.(*net.IPNet)) - if ip == nil { - continue - } - - c, err := NewClient(net.JoinHostPort(ip.String(), port)) - if err != nil { - t.Fatalf("new client: %v", err) - } - - // Skip re-tries to skip non-routable interfaces faster - c.SetRetries(0) - - ot, err := c.Send("a", "octet") - if err != nil { - t.Logf("start sending to %v: %v", ip, err) - continue - } - _, err = ot.ReadFrom(io.LimitReader( - newRandReader(rand.NewSource(42)), 42)) - if err != nil { - t.Logf("sending to %v: %v", ip, err) - continue - } - - // Check that read handler received IP that was used - // to create the client. - localIPMu.Lock() - if localIP != nil && !localIP.IsUnspecified() { // Skip check if no packet info - if !localIP.Equal(ip) { - t.Errorf("sent to: %v, request packet: %v", ip, localIP) - } - } else { - fmt.Printf("Skip %v\n", ip) - } - localIPMu.Unlock() - - it, err := c.Receive("a", "octet") - if err != nil { - t.Logf("start receiving from %v: %v", ip, err) - continue - } - _, err = it.WriteTo(io.Discard) - if err != nil { - t.Logf("receiving from %v: %v", ip, err) - continue - } - - // Check that write handler received IP that was used - // to create the client. - localIPMu.Lock() - if localIP != nil && !localIP.IsUnspecified() { // Skip check if no packet info - if !localIP.Equal(ip) { - t.Errorf("sent to: %v, request packet: %v", ip, localIP) - } - } else { - fmt.Printf("Skip %v\n", ip) - } - localIPMu.Unlock() - - fmt.Printf("Done %v\n", ip) - } -} - -func networkIP(n *net.IPNet) net.IP { - if ip := n.IP.To4(); ip != nil { - return ip - } - if len(n.IP) == net.IPv6len { - return n.IP - } - return nil -} - -// TestFileIOExceptions checks that errors returned by io.Reader or io.Writer used by -// the handler are handled correctly. -func TestReadWriteErrors(t *testing.T) { - s := NewServer( - func(_ string, rf io.ReaderFrom) error { - _, err := rf.ReadFrom(&failingReader{}) // Read operation fails immediately. - if err != errRead { - t.Errorf("want: %v, got: %v", errRead, err) - } - // return no error from handler, client still should receive error - return nil - }, - func(_ string, wt io.WriterTo) error { - _, err := wt.WriteTo(&failingWriter{}) // Write operation fails immediately. - if err != errWrite { - t.Errorf("want: %v, got: %v", errWrite, err) - } - // return no error from handler, client still should receive error - return nil - }, - ) - - conn, err := net.ListenUDP("udp", &net.UDPAddr{}) - if err != nil { - t.Fatalf("listen UDP: %v", err) - } - - _, port, err := net.SplitHostPort(conn.LocalAddr().String()) - if err != nil { - t.Fatalf("parsing server port: %v", err) - } - - // Start server - errChan := make(chan error, 1) - go func() { - err := s.Serve(conn) - if err != nil { - errChan <- fmt.Errorf("running serve: %w", err) - } - }() - defer func() { - s.Shutdown() - select { - case err := <-errChan: - t.Errorf("server error: %v", err) - default: - } - }() - - // Create client - c, err := NewClient(net.JoinHostPort(localhost, port)) - if err != nil { - t.Fatalf("creating new client: %v", err) - } - - ot, err := c.Send("a", "octet") - if err != nil { - t.Errorf("start sending: %v", err) - } - - _, err = ot.ReadFrom(io.LimitReader( - newRandReader(rand.NewSource(42)), 42)) - if err == nil { - t.Errorf("missing write error") - } - - _, err = c.Receive("a", "octet") - if err == nil { - t.Errorf("missing read error") - } -} - -type failingReader struct{} - -var errRead = errors.New("read error") - -func (r *failingReader) Read(_ []byte) (int, error) { - return 0, errRead -} - -type failingWriter struct{} - -var errWrite = errors.New("write error") - -func (r *failingWriter) Write(_ []byte) (int, error) { - return 0, errWrite -} - -// countingWriter signals through a channel when a certain number of bytes have been written -type countingWriter struct { - w io.Writer - total int64 - threshold int64 - signal chan struct{} - signaled bool -} - -func (w *countingWriter) Write(p []byte) (n int, err error) { - n, err = w.w.Write(p) - w.total += int64(n) - if !w.signaled && w.total >= w.threshold { - w.signal <- struct{}{} - w.signaled = true - } - return n, err -} - -// TestShutdownDuringTransfer starts a transfer, then shuts down the server mid-transfer. -// Checks that neither server nor client hang and server shuts down cleanly. -func TestShutdownDuringTransfer(t *testing.T) { - for _, singlePort := range []bool{false, true} { - name := "regular" - if singlePort { - name = "single_port" - } - t.Run(name, func(t *testing.T) { - testShutdownDuringTransfer(t, singlePort) - }) - } -} - -func testShutdownDuringTransfer(t *testing.T, singlePort bool) { - s := NewServer(func(_ string, rf io.ReaderFrom) error { - // Simulate a slow reader: send 1MB, but slowly - _, err := rf.ReadFrom(&slowReader{r: bytes.NewReader(make([]byte, 1<<23)), n: 1 << 20, delay: 10 * time.Millisecond}) - return err - }, nil) - - if singlePort { - s.EnableSinglePort() - } - - conn, err := net.ListenUDP("udp", &net.UDPAddr{}) - if err != nil { - t.Fatal(err) - } - - // Start a goroutine to monitor server errors - errChan := make(chan error, 1) - go func() { - errChan <- s.Serve(conn) - }() - - c, err := NewClient(localSystem(conn)) - if err != nil { - t.Fatal(err) - } - - dl := make(chan error, 1) - received := make(chan struct{}, 1) - go func() { - wt, err := c.Receive("file", "octet") - if err != nil { - dl <- err - return - } - // Use custom writer to signal when 100KB is received - counter := &countingWriter{ - w: io.Discard, - threshold: 100 * 1024, // 100KB - signal: received, - } - _, err = wt.WriteTo(counter) - dl <- err - }() - - // Wait for either 100KB to be received or timeout - select { - case <-received: - // Received enough data, proceed with shutdown - case <-time.After(5 * time.Second): - t.Fatal("timeout waiting for data transfer to start") - } - s.Shutdown() - - // Server should shut down cleanly - select { - case err := <-errChan: - if err != nil { - t.Errorf("server error: %v", err) - } - case <-time.After(5 * time.Second): - t.Error("server did not shut down in time") - } - - // Client should shutdown cleanly too because server waits for transfers to finish - select { - case err := <-dl: - if err != nil { - t.Errorf("client transfer error: %v", err) - } - case <-time.After(5 * time.Second): - t.Error("client did not finish in time") - } -} - -func TestSetLocalAddr(t *testing.T) { - interfaces, err := net.Interfaces() - if err != nil { - t.Fatalf("failed to get network interfaces: %v", err) - } - - var addrs []net.IP - for _, i := range interfaces { - interfaceAddrs, err := i.Addrs() - if err != nil { - continue - } - for _, addr := range interfaceAddrs { - var ip net.IP - switch v := addr.(type) { - case *net.IPNet: - ip = v.IP - case *net.IPAddr: - ip = v.IP - } - if ip == nil || ip.IsLoopback() || ip.IsLinkLocalUnicast() { - continue - } - addrs = append(addrs, ip) - } - } - - for _, addrA := range addrs { - for _, addrB := range addrs { - t.Run(fmt.Sprintf("%s receives from %s", addrA.String(), addrB.String()), func(t *testing.T) { - var mu sync.Mutex - var remoteAddr net.UDPAddr - s := NewServer(nil, func(filename string, wt io.WriterTo) error { - _, err := wt.WriteTo(io.Discard) - if err != nil { - return err - } - mu.Lock() - defer mu.Unlock() - remoteAddr = wt.(IncomingTransfer).RemoteAddr() - return nil - }) - s.SetTimeout(2 * time.Second) - - conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: addrA, Port: 0}) - if err != nil { - panic(err) - } - - go s.Serve(conn) - defer s.Shutdown() - - c, err := NewClient(conn.LocalAddr().String()) - if err != nil { - t.Fatalf("creating client: %v", err) - } - c.SetLocalAddr(addrB.String()) - - wt, err := c.Send("testfile", "octet") - if err != nil { - return - } - _, _ = wt.ReadFrom(bytes.NewReader([]byte("test data for cross-interface"))) - - // Compare addrB to remoteAddr in thread-safe manner - mu.Lock() - defer mu.Unlock() - if remoteAddr.IP == nil { - t.Error("remote address was not captured from server") - } else if !remoteAddr.IP.Equal(addrB) { - t.Errorf("remote address mismatch: expected %s, got %s", addrB.String(), remoteAddr.IP.String()) - } - }) - } - } -} diff --git a/transfer_test.go b/transfer_test.go new file mode 100644 index 0000000..15915e7 --- /dev/null +++ b/transfer_test.go @@ -0,0 +1,132 @@ +package tftp + +import ( + "bytes" + "fmt" + "io" + "math/rand" + "testing" + "testing/iotest" + "time" +) + +func TestZeroLength(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + testSendReceive(t, c, 0) + }) +} + +func TestSendReceiveRange(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + for i := 600; i < 1000; i++ { + testSendReceive(t, c, 5000+int64(i)) + } + }) +} + +func TestSendReceiveWithBlockSizeRange(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + for i := 600; i < 1000; i++ { + c.SetBlockSize(i) + testSendReceive(t, c, 5000+int64(i)) + } + }) +} + +func Test1000(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + for i := int64(0); i < 5000; i++ { + filename := fmt.Sprintf("length-%d-bytes-%d", i, time.Now().UnixNano()) + rf, err := c.Send(filename, "octet") + if err != nil { + t.Fatalf("requesting %s write: %v", filename, err) + } + r := io.LimitReader(newRandReader(rand.NewSource(i)), i) + n, err := rf.ReadFrom(r) + if err != nil { + t.Fatalf("sending %s: %v", filename, err) + } + if n != i { + t.Errorf("%s length mismatch: %d != %d", filename, n, i) + } + } + }) +} + +func TestBlockWrapsAround(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + n := 65535 * 512 + for i := n - 2; i < n+2; i++ { + testSendReceive(t, c, int64(i)) + } + }) +} + +func TestRandomLength(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + r := rand.New(rand.NewSource(42)) + for i := 0; i < 100; i++ { + testSendReceive(t, c, r.Int63n(100000)) + } + }) +} + +func TestBigFile(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + testSendReceive(t, c, 3*1000*1000) + }) +} + +func TestByOneByte(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + filename := "test-by-one-byte" + xferMode := "octet" + const length = 80000 + sender, err := c.Send(filename, xferMode) + if err != nil { + t.Fatalf("requesting write: %v", err) + } + r := iotest.OneByteReader(io.LimitReader( + newRandReader(rand.NewSource(42)), length)) + n, err := sender.ReadFrom(r) + if err != nil { + t.Fatalf("send error: %v", err) + } + if n != length { + t.Errorf("%s read length mismatch: %d != %d", filename, n, length) + } + readTransfer, err := c.Receive(filename, xferMode) + if err != nil { + t.Fatalf("requesting read %s: %v", filename, err) + } + buf := &bytes.Buffer{} + n, err = readTransfer.WriteTo(buf) + if err != nil { + t.Fatalf("%s read error: %v", filename, err) + } + if n != length { + t.Errorf("%s read length mismatch: %d != %d", filename, n, length) + } + bs, _ := io.ReadAll(io.LimitReader( + newRandReader(rand.NewSource(42)), length)) + if !bytes.Equal(bs, buf.Bytes()) { + t.Errorf("\nsent: %x\nrcvd: %x", bs, buf) + } + }) +} diff --git a/tsize_test.go b/tsize_test.go new file mode 100644 index 0000000..cb29f50 --- /dev/null +++ b/tsize_test.go @@ -0,0 +1,86 @@ +package tftp + +import ( + "bytes" + "fmt" + "io" + "math/rand" + "net" + "testing" +) + +func TestTSize(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + s, c := newFixture(t, mode) + defer s.Shutdown() + c.RequestTSize(true) + testSendReceive(t, c, 640) + }) +} + +func TestSendTsizeFromSeek(t *testing.T) { + forModes(t, func(t *testing.T, mode transferMode) { + // create read-only server + s := NewServer(func(filename string, rf io.ReaderFrom) error { + b := make([]byte, 100) + rr := newRandReader(rand.NewSource(42)) + rr.Read(b) + // bytes.Reader implements io.Seek + r := bytes.NewReader(b) + _, err := rf.ReadFrom(r) + if err != nil { + t.Errorf("sending bytes: %v", err) + } + return nil + }, nil) + if mode == modeSinglePort { + s.SetBlockSize(2000) + s.EnableSinglePort() + } + + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + t.Fatalf("listening: %v", err) + } + + go s.Serve(conn) + defer s.Shutdown() + + c, err := NewClient(localSystem(t, conn)) + if err != nil { + t.Fatalf("new client: %v", err) + } + c.RequestTSize(true) + r, err := c.Receive("f", "octet") + if err != nil { + t.Fatalf("receive: %v", err) + } + var size int64 + if it, ok := r.(IncomingTransfer); ok { + if n, ok := it.Size(); ok { + size = n + fmt.Printf("Transfer size: %d\n", n) + } + } + + if size != 100 { + t.Errorf("size expected: 100, got %d", size) + } + + r.WriteTo(io.Discard) + + c.RequestTSize(false) + r, err = c.Receive("f", "octet") + if err != nil { + t.Fatalf("receive without tsize: %v", err) + } + if it, ok := r.(IncomingTransfer); ok { + _, ok := it.Size() + if ok { + t.Errorf("unexpected size received") + } + } + + r.WriteTo(io.Discard) + }) +} diff --git a/util_test.go b/util_test.go new file mode 100644 index 0000000..7940421 --- /dev/null +++ b/util_test.go @@ -0,0 +1,379 @@ +package tftp + +import ( + "bytes" + "fmt" + "io" + "math/rand" + "net" + "os" + "sync" + "testing" + "time" +) + +// Shared test fixtures/helpers only. Do not add test cases here. + +type transferMode string + +const ( + modeRegular transferMode = "regular" + modeSinglePort transferMode = "single_port" +) + +func forModes(t *testing.T, fn func(t *testing.T, mode transferMode)) { + t.Helper() + for _, mode := range []transferMode{modeRegular, modeSinglePort} { + mode := mode + t.Run(string(mode), func(t *testing.T) { + fn(t, mode) + }) + } +} + +func newFixture(t *testing.T, mode transferMode) (*Server, *Client) { + t.Helper() + switch mode { + case modeRegular: + return makeTestServer(t, false) + case modeSinglePort: + return makeTestServer(t, true) + default: + t.Fatalf("unknown transfer mode: %q", mode) + return nil, nil + } +} + +var ( + localhostOnce sync.Once + localhostAddr string + localhostErr error +) + +func determineLocalhost() (string, error) { + l, err := net.ListenTCP("tcp", nil) + if err != nil { + return "", fmt.Errorf("listen tcp: %w", err) + } + _, lport, err := net.SplitHostPort(l.Addr().String()) + if err != nil { + l.Close() + return "", fmt.Errorf("split host port: %w", err) + } + defer l.Close() + + go func() { + for { + conn, err := l.Accept() + if err != nil { + break + } + conn.Close() + } + }() + + for _, af := range []string{"tcp6", "tcp4"} { + conn, err := net.Dial(af, net.JoinHostPort("", lport)) + if err != nil { + continue + } + host, _, splitErr := net.SplitHostPort(conn.LocalAddr().String()) + conn.Close() + if splitErr == nil { + return host, nil + } + } + + return "", fmt.Errorf("could not determine localhost address family") +} + +func resolveLocalhost() (string, error) { + localhostOnce.Do(func() { + localhostAddr, localhostErr = determineLocalhost() + }) + return localhostAddr, localhostErr +} + +func localSystem(tb testing.TB, c *net.UDPConn) string { + tb.Helper() + host, err := resolveLocalhost() + if err != nil { + tb.Fatalf("resolve localhost: %v", err) + } + _, port, err := net.SplitHostPort(c.LocalAddr().String()) + if err != nil { + tb.Fatalf("split listener address: %v", err) + } + return net.JoinHostPort(host, port) +} + +type testHook struct { + *sync.Mutex + transfersCompleted int + transfersFailed int +} + +func newTestHook() *testHook { + return &testHook{ + Mutex: &sync.Mutex{}, + } +} + +func (h *testHook) OnSuccess(result TransferStats) { + h.Lock() + defer h.Unlock() + h.transfersCompleted++ +} + +func (h *testHook) OnFailure(result TransferStats, err error) { + h.Lock() + defer h.Unlock() + h.transfersFailed++ +} + +func testSendReceive(t *testing.T, client *Client, length int64) { + t.Helper() + filename := fmt.Sprintf("length-%d-bytes", length) + mode := "octet" + writeTransfer, err := client.Send(filename, mode) + if err != nil { + t.Fatalf("requesting write %s: %v", filename, err) + } + r := io.LimitReader(newRandReader(rand.NewSource(42)), length) + n, err := writeTransfer.ReadFrom(r) + if err != nil { + t.Fatalf("%s write error: %v", filename, err) + } + if n != length { + t.Errorf("%s write length mismatch: %d != %d", filename, n, length) + } + readTransfer, err := client.Receive(filename, mode) + if err != nil { + t.Fatalf("requesting read %s: %v", filename, err) + } + if it, ok := readTransfer.(IncomingTransfer); ok { + if n, ok := it.Size(); ok { + fmt.Printf("Transfer size: %d\n", n) + if n != length { + t.Errorf("tsize mismatch: %d vs %d", n, length) + } + } + } + buf := &bytes.Buffer{} + n, err = readTransfer.WriteTo(buf) + if err != nil { + t.Fatalf("%s read error: %v", filename, err) + } + if n != length { + t.Errorf("%s read length mismatch: %d != %d", filename, n, length) + } + bs, _ := io.ReadAll(io.LimitReader( + newRandReader(rand.NewSource(42)), length)) + if !bytes.Equal(bs, buf.Bytes()) { + t.Errorf("\nsent: %x\nrcvd: %x", bs, buf) + } +} + +type testBackend struct { + m map[string][]byte + mu sync.Mutex +} + +func makeTestServer(t *testing.T, singlePort bool) (*Server, *Client) { + t.Helper() + b := &testBackend{} + b.m = make(map[string][]byte) + + // Create server + s := NewServer(b.handleRead, b.handleWrite) + + if singlePort { + s.SetBlockSize(2000) + s.EnableSinglePort() + } + + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + t.Fatalf("listen udp: %v", err) + } + + go s.Serve(conn) + + // Create client for that server + c, err := NewClient(localSystem(t, conn)) + if err != nil { + t.Fatalf("new client: %v", err) + } + + return s, c +} + +func (b *testBackend) handleWrite(filename string, wt io.WriterTo) error { + b.mu.Lock() + defer b.mu.Unlock() + _, ok := b.m[filename] + if ok { + fmt.Fprintf(os.Stderr, "File %s already exists\n", filename) + return fmt.Errorf("file already exists") + } + if t, ok := wt.(IncomingTransfer); ok { + if n, ok := t.Size(); ok { + fmt.Printf("Transfer size: %d\n", n) + } + } + buf := &bytes.Buffer{} + _, err := wt.WriteTo(buf) + if err != nil { + fmt.Fprintf(os.Stderr, "Can't receive %s: %v\n", filename, err) + return err + } + b.m[filename] = buf.Bytes() + return nil +} + +func (b *testBackend) handleRead(filename string, rf io.ReaderFrom) error { + b.mu.Lock() + defer b.mu.Unlock() + bs, ok := b.m[filename] + if !ok { + fmt.Fprintf(os.Stderr, "File %s not found\n", filename) + return fmt.Errorf("file not found") + } + if t, ok := rf.(OutgoingTransfer); ok { + t.SetSize(int64(len(bs))) + } + _, err := rf.ReadFrom(bytes.NewBuffer(bs)) + if err != nil { + fmt.Fprintf(os.Stderr, "Can't send %s: %v\n", filename, err) + return err + } + return nil +} + +type randReader struct { + src rand.Source + next int64 + i int8 +} + +func newRandReader(src rand.Source) io.Reader { + r := &randReader{ + src: src, + next: src.Int63(), + } + return r +} + +func (r *randReader) Read(p []byte) (n int, err error) { + next, i := r.next, r.i + for n = 0; n < len(p); n++ { + if i == 7 { + next, i = r.src.Int63(), 0 + } + p[n] = byte(next) + next >>= 8 + i++ + } + r.next, r.i = next, i + return +} + +func serverTimeoutSendTest(s *Server, c *Client, t *testing.T) { + t.Helper() + s.SetTimeout(time.Second) + s.SetRetries(2) + sec := make(chan error, 1) + s.mu.Lock() + s.readHandler = func(filename string, rf io.ReaderFrom) error { + r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000) + _, err := rf.ReadFrom(r) + sec <- err + return err + } + s.mu.Unlock() + defer s.Shutdown() + filename := "test-server-send-timeout" + mode := "octet" + readTransfer, err := c.Receive(filename, mode) + if err != nil { + t.Fatalf("requesting read %s: %v", filename, err) + } + w := &slowWriter{ + n: 3, + delay: 8 * time.Second, + } + _, _ = readTransfer.WriteTo(w) + servErr := <-sec + netErr, ok := servErr.(net.Error) + if !ok { + t.Fatalf("network error expected: %T", servErr) + } + if !netErr.Timeout() { + t.Fatalf("timout is expected: %v", servErr) + } +} + +func serverReceiveTimeoutTest(s *Server, c *Client, t *testing.T) { + t.Helper() + s.SetTimeout(time.Second) + s.SetRetries(2) + sec := make(chan error, 1) + s.mu.Lock() + s.writeHandler = func(filename string, wt io.WriterTo) error { + buf := &bytes.Buffer{} + _, err := wt.WriteTo(buf) + sec <- err + return err + } + s.mu.Unlock() + defer s.Shutdown() + filename := "test-server-receive-timeout" + mode := "octet" + writeTransfer, err := c.Send(filename, mode) + if err != nil { + t.Fatalf("requesting write %s: %v", filename, err) + } + r := &slowReader{ + r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000), + n: 3, + delay: 8 * time.Second, + } + _, _ = writeTransfer.ReadFrom(r) + servErr := <-sec + netErr, ok := servErr.(net.Error) + if !ok { + t.Fatalf("network error expected: %T", servErr) + } + if !netErr.Timeout() { + t.Fatalf("timout is expected: %v", servErr) + } +} + +type slowReader struct { + r io.Reader + n int64 + delay time.Duration +} + +func (r *slowReader) Read(p []byte) (n int, err error) { + if r.n > 0 { + r.n-- + return r.r.Read(p) + } + time.Sleep(r.delay) + return r.r.Read(p) +} + +type slowWriter struct { + n int64 + delay time.Duration +} + +func (r *slowWriter) Write(p []byte) (n int, err error) { + if r.n > 0 { + r.n-- + return len(p), nil + } + time.Sleep(r.delay) + return len(p), nil +}