diff --git a/pool/marshal.go b/pool/marshal.go index 6d26ea5..f5d81c4 100644 --- a/pool/marshal.go +++ b/pool/marshal.go @@ -30,6 +30,9 @@ func marshalJob(job *Job) []byte { if err := binary.Write(&buf, binary.LittleEndian, job.CreatedAt.UnixNano()); err != nil { panic(err) } + if err := binary.Write(&buf, binary.LittleEndian, job.Requeued); err != nil { + panic(err) + } return buf.Bytes() } @@ -73,6 +76,7 @@ func unmarshalJob(data []byte) *Job { Payload: payload, CreatedAt: time.Unix(0, createdAtTimestamp).UTC(), NodeID: nodeID, + Requeued: unmarshalBool(reader), } } @@ -122,6 +126,14 @@ func unmarshalJobKeyAndNodeID(data []byte) (string, string) { return string(keyBytes), string(nodeIDBytes) } +func unmarshalBool(reader *bytes.Reader) bool { + var value bool + if err := binary.Read(reader, binary.LittleEndian, &value); err != nil { + panic(err) + } + return value +} + func marshalNotification(key string, payload []byte) []byte { var buf bytes.Buffer if err := binary.Write(&buf, binary.LittleEndian, int32(len(key))); err != nil { diff --git a/pool/marshal_test.go b/pool/marshal_test.go index 92d6c35..e017042 100644 --- a/pool/marshal_test.go +++ b/pool/marshal_test.go @@ -28,6 +28,15 @@ func TestMarshalJob(t *testing.T) { CreatedAt: time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC), }, }, + { + name: "requeued job", + job: Job{ + Key: "test-key", + Payload: []byte("test-payload"), + CreatedAt: time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC), + Requeued: true, + }, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -38,6 +47,7 @@ func TestMarshalJob(t *testing.T) { assert.Equal(t, tc.job.Key, job.Key) assert.Equal(t, tc.job.Payload, job.Payload) assert.Equal(t, tc.job.CreatedAt, job.CreatedAt) + assert.Equal(t, tc.job.Requeued, job.Requeued) // Compare original and unmarshaled byte slices marshaled2 := marshalJob(job) diff --git a/pool/node.go b/pool/node.go index 4e0ed75..8a7ed70 100644 --- a/pool/node.go +++ b/pool/node.go @@ -96,8 +96,13 @@ const ( // pendingEventTTL is the TTL for pending events. var pendingEventTTL = 2 * time.Minute -// ErrJobExists is returned when attempting to dispatch a job with a key that already exists. -var ErrJobExists = errors.New("job already exists") +var ( + // ErrJobExists is returned when attempting to dispatch a job with a key that already exists. + ErrJobExists = errors.New("job already exists") + + errJobAwaitingOwner = errors.New("job awaiting active owner") + errJobNotFound = errors.New("job not found") +) // AddNode adds a new node to the pool with the given name and returns it. The // node can be used to dispatch jobs and add new workers. A node also routes @@ -696,11 +701,20 @@ func (node *Node) routeWorkerEvent(ev *streaming.Event) error { // Compute the worker ID that will handle the job. key := unmarshalJobKey(ev.Payload) - activeWorkers := node.activeWorkers() - if len(activeWorkers) == 0 { - return fmt.Errorf("routeWorkerEvent: no active worker in pool %q", node.PoolName) + wid, err := node.workerForEvent(ev.EventName, key) + if err != nil { + if errors.Is(err, errJobAwaitingOwner) { + node.logger.Debug("routeWorkerEvent: job has no active owner yet", "event", ev.EventName, "id", ev.ID, "key", key) + return nil + } + if errors.Is(err, errJobNotFound) { + if ackErr := node.poolSink.Ack(context.Background(), ev); ackErr != nil { + node.logger.Error(fmt.Errorf("routeWorkerEvent: failed to ack event for missing job: %w", ackErr), "event", ev.EventName, "id", ev.ID) + } + return nil + } + return err } - wid := activeWorkers[node.h.Hash(key, int64(len(activeWorkers)))] // Stream the event to the worker corresponding to the key hash. stream, err := node.getWorkerStream(wid) @@ -824,6 +838,82 @@ func (node *Node) returnDispatchStatus(ev *streaming.Event) { val.(chan error) <- err } +// workerForEvent returns the worker that should receive a pool event. Start +// events are routed by the current consistent hash ring; stop and notification +// events target the worker that currently owns the job. +func (node *Node) workerForEvent(eventName, key string) (string, error) { + if eventName == evStartJob { + activeWorkers := node.activeWorkers() + if len(activeWorkers) == 0 { + return "", fmt.Errorf("routeWorkerEvent: no active worker in pool %q", node.PoolName) + } + return activeWorkers[node.h.Hash(key, int64(len(activeWorkers)))], nil + } + if eventName == evStopJob || eventName == evNotify { + owner, ok, err := node.activeJobOwner(key) + if err != nil { + return "", err + } + if !ok { + exists, err := node.jobPayloadExists(context.Background(), key) + if err != nil { + return "", err + } + if exists { + return "", fmt.Errorf("%w: %q", errJobAwaitingOwner, key) + } + return "", fmt.Errorf("%w: %q", errJobNotFound, key) + } + return owner, nil + } + return "", fmt.Errorf("routeWorkerEvent: unknown worker event %q", eventName) +} + +// jobPayloadExists reads the durable job record from Redis, which is the source +// of truth when the local ownership map has no active owner during handoff. +func (node *Node) jobPayloadExists(ctx context.Context, key string) (bool, error) { + exists, err := node.rdb.HExists(ctx, rmapContentKey(jobPayloadMapName(node.PoolName)), key).Result() + if err != nil { + return false, fmt.Errorf("routeWorkerEvent: failed to check job payload %q: %w", key, err) + } + return exists, nil +} + +// activeJobOwner returns the single active worker that owns a job key according +// to the replicated ownership map. +func (node *Node) activeJobOwner(key string) (string, bool, error) { + activeWorkers := node.activeWorkers() + active := make(map[string]struct{}, len(activeWorkers)) + for _, workerID := range activeWorkers { + active[workerID] = struct{}{} + } + + var owner string + for workerID := range node.jobMap.Map() { + if _, ok := active[workerID]; !ok { + continue + } + keys, ok := node.jobMap.GetValues(workerID) + if !ok { + continue + } + for _, ownedKey := range keys { + if ownedKey != key { + continue + } + if owner != "" { + return "", false, fmt.Errorf("routeWorkerEvent: job %q has multiple active owners", key) + } + owner = workerID + break + } + } + if owner == "" { + return "", false, nil + } + return owner, true, nil +} + // watches monitors the workers replicated map and triggers job rebalancing // when workers are added or removed from the pool. func (node *Node) watchWorkers(ctx context.Context) { @@ -1067,7 +1157,7 @@ func (node *Node) requeueOrphanedPayloads(ctx context.Context) { node.orphanedPayloads.Delete(key) continue } - job := &Job{Key: key, Payload: payload, CreatedAt: now, NodeID: node.ID} + job := &Job{Key: key, Payload: payload, CreatedAt: now, NodeID: node.ID, Requeued: true} if _, err := node.poolStream.Add(ctx, evStartJob, marshalJob(job)); err != nil { node.logger.Error(fmt.Errorf("requeueOrphanedPayloads: failed to requeue orphaned job: %w", err), "key", key) continue @@ -1117,7 +1207,7 @@ func (node *Node) cleanupWorker(ctx context.Context, workerID string) { processed++ continue } - job := &Job{Key: key, Payload: payload, CreatedAt: time.Now(), NodeID: node.ID} + job := &Job{Key: key, Payload: payload, CreatedAt: time.Now(), NodeID: node.ID, Requeued: true} // Requeue by adding an event back to the pool stream. // We intentionally do not wait for the job to start (which can time out // under heavy churn) - the pool sink will retry routing until it is acked. diff --git a/pool/node_test.go b/pool/node_test.go index 3677ae2..f74236d 100644 --- a/pool/node_test.go +++ b/pool/node_test.go @@ -250,6 +250,81 @@ func TestDispatchJobTwoWorkers(t *testing.T) { assert.NoError(t, node.Shutdown(ctx), "Failed to shutdown node") } +func TestStopJobRoutesToCurrentOwner(t *testing.T) { + testName := strings.Replace(t.Name(), "/", "_", -1) + ctx := ptesting.NewTestContext(t) + rdb := ptesting.NewRedisClient(t) + node := newTestNode(t, ctx, rdb, testName) + defer ptesting.CleanupRedis(t, rdb, true, testName) + + node.h = &ptesting.Hasher{Index: 0} + worker1 := newTestWorker(t, ctx, node) + worker2 := newTestWorker(t, ctx, node) + + stopped := make(chan struct{}) + var stoppedOnce sync.Once + worker1.handler.(*mockHandler).stopFunc = func(key string) error { + stoppedOnce.Do(func() { close(stopped) }) + return nil + } + worker2.handler.(*mockHandler).stopFunc = func(key string) error { + t.Errorf("stop routed to worker without ownership: %s", key) + return nil + } + + payload := []byte("payload") + require.NoError(t, node.DispatchJob(ctx, testName, payload)) + require.Eventually(t, func() bool { + return len(worker1.Jobs()) == 1 + }, max, delay) + require.Eventually(t, func() bool { + return sameStrings(jobOwners(node, testName), []string{worker1.ID}) + }, max, delay) + + node.h = &ptesting.Hasher{Index: 1} + require.NoError(t, node.StopJob(ctx, testName)) + select { + case <-stopped: + case <-time.After(max): + t.Fatal("job stop was not routed to current owner") + } + require.Eventually(t, func() bool { + return len(worker1.Jobs()) == 0 && len(worker2.Jobs()) == 0 + }, max, delay) + require.Eventually(t, func() bool { + _, ok := node.JobPayload(testName) + return !ok && len(jobOwners(node, testName)) == 0 + }, max, delay) + + assert.NoError(t, node.Shutdown(ctx), "Failed to shutdown node") +} + +func TestControlEventRoutingDuringOwnershipGaps(t *testing.T) { + testName := strings.Replace(t.Name(), "/", "_", -1) + ctx := ptesting.NewTestContext(t) + rdb := ptesting.NewRedisClient(t) + node := newTestNode(t, ctx, rdb, testName) + defer ptesting.CleanupRedis(t, rdb, true, testName) + + const jobKey = "handoff-job" + _, err := node.jobPayloadMap.SetAndWait(ctx, jobKey, "payload") + require.NoError(t, err) + + _, err = node.workerForEvent(evStopJob, jobKey) + assert.ErrorIs(t, err, errJobAwaitingOwner) + + _, err = node.workerForEvent(evNotify, jobKey) + assert.ErrorIs(t, err, errJobAwaitingOwner) + + _, err = node.jobPayloadMap.Delete(ctx, jobKey) + require.NoError(t, err) + + _, err = node.workerForEvent(evStopJob, jobKey) + assert.ErrorIs(t, err, errJobNotFound) + + assert.NoError(t, node.Shutdown(ctx), "Failed to shutdown node") +} + func TestDispatchJobRaceCondition(t *testing.T) { testName := strings.Replace(t.Name(), "/", "_", -1) ctx := ptesting.NewTestContext(t) diff --git a/pool/worker.go b/pool/worker.go index 85a2378..4371cba 100644 --- a/pool/worker.go +++ b/pool/worker.go @@ -56,6 +56,9 @@ type ( Payload []byte // CreatedAt is the time the job was created. CreatedAt time.Time + // Requeued indicates that this start event moves or recovers an existing + // durable job payload rather than admitting a new dispatched job. + Requeued bool // Worker is the worker that handles the job. Worker *Worker // NodeID is the ID of the node that created the job. @@ -85,6 +88,8 @@ type ( } ) +var errJobNotOwned = errors.New("job not owned by worker") + // newWorker creates a new worker. func newWorker(ctx context.Context, node *Node, h JobHandler) (*Worker, error) { wid := ulid.Make().String() @@ -247,6 +252,9 @@ func (w *Worker) startJob(ctx context.Context, job *Job) error { } if _, err := w.jobPayloadsMap.Set(ctx, job.Key, string(job.Payload)); err != nil { w.logger.Error(fmt.Errorf("failed to add job payload %q to job payloads map: %w, requeueing", job.Key, err)) + if _, _, removeErr := w.jobsMap.RemoveValues(ctx, w.ID, job.Key); removeErr != nil { + w.logger.Error(fmt.Errorf("start failure handling: failed to remove job %q from jobs map: %w", job.Key, removeErr)) + } return ErrRequeue } job.Worker = w @@ -255,8 +263,10 @@ func (w *Worker) startJob(ctx context.Context, job *Job) error { if _, _, err := w.jobsMap.RemoveValues(ctx, w.ID, job.Key); err != nil { w.logger.Error(fmt.Errorf("start failure handling: failed to remove job %q from jobs map: %w", job.Key, err)) } - if _, err := w.jobPayloadsMap.Delete(ctx, job.Key); err != nil { - w.logger.Error(fmt.Errorf("start failure handling: failed to remove job payload %q from job payloads map: %w", job.Key, err)) + if !job.Requeued { + if _, err := w.jobPayloadsMap.Delete(ctx, job.Key); err != nil { + w.logger.Error(fmt.Errorf("start failure handling: failed to remove job payload %q from job payloads map: %w", job.Key, err)) + } } return err } @@ -267,8 +277,24 @@ func (w *Worker) startJob(ctx context.Context, job *Job) error { // stopJob stops a job. func (w *Worker) stopJob(ctx context.Context, key string) error { + if err := w.releaseJob(ctx, key); err != nil { + if errors.Is(err, errJobNotOwned) { + return ErrRequeue + } + return err + } + if _, err := w.jobPayloadsMap.Delete(ctx, key); err != nil { + w.logger.Error(fmt.Errorf("stop job: failed to remove job payload %q from job payloads map: %w", key, err)) + } + w.logger.Info("stopped job", "job", key) + return nil +} + +// releaseJob stops local execution and removes this worker's ownership while +// preserving the shared payload for another worker to claim. +func (w *Worker) releaseJob(ctx context.Context, key string) error { if _, ok := w.jobs.Load(key); !ok { - return fmt.Errorf("job %s not found in local worker", key) + return fmt.Errorf("%w: %s", errJobNotOwned, key) } if err := w.handler.Stop(key); err != nil { return fmt.Errorf("failed to stop job %q: %w", key, err) @@ -276,12 +302,8 @@ func (w *Worker) stopJob(ctx context.Context, key string) error { w.logger.Debug("stopped job", "job", key) w.jobs.Delete(key) if _, _, err := w.jobsMap.RemoveValues(ctx, w.ID, key); err != nil { - w.logger.Error(fmt.Errorf("stop job: failed to remove job %q from jobs map: %w", key, err)) - } - if _, err := w.jobPayloadsMap.Delete(ctx, key); err != nil { - w.logger.Error(fmt.Errorf("stop job: failed to remove job payload %q from job payloads map: %w", key, err)) + return fmt.Errorf("failed to release job %q from jobs map: %w", key, err) } - w.logger.Info("stopped job", "job", key) return nil } @@ -291,6 +313,9 @@ func (w *Worker) notify(_ context.Context, key string, payload []byte) error { w.logger.Debug("worker stopped, ignoring notification") return nil } + if _, ok := w.jobs.Load(key); !ok { + return ErrRequeue + } nh, ok := w.handler.(NotificationHandler) if !ok { w.logger.Error(fmt.Errorf("worker does not implement NotificationHandler, ignoring notification"), "worker", w.ID) @@ -360,21 +385,22 @@ func (w *Worker) rebalance(ctx context.Context, activeWorkers []string) { return } for key, job := range rebalanced { - if err := w.handler.Stop(key); err != nil { - w.logger.Error(fmt.Errorf("rebalance: failed to stop job: %w", err), "job", key) + job.Requeued = true + if err := w.releaseJob(ctx, key); err != nil { + w.logger.Error(fmt.Errorf("rebalance: failed to release job: %w", err), "job", key) + if _, ok := w.jobs.Load(key); !ok { + if err := w.startJob(ctx, job); err != nil { + w.logger.Error(fmt.Errorf("rebalance: failed to restart job: %w", err), "job", key) + } + } continue } - w.logger.Debug("stopped job", "job", key) - w.jobs.Delete(key) if _, err := w.node.poolStream.Add(ctx, evStartJob, marshalJob(job)); err != nil { w.logger.Error(fmt.Errorf("rebalance: failed to requeue job: %w", err), "job", key) - if err := w.handler.Start(job); err != nil { + if err := w.startJob(ctx, job); err != nil { w.logger.Error(fmt.Errorf("rebalance: failed to restart job: %w", err), "job", key) continue } - // Requeue failed but we restarted the job locally; restore local - // tracking so future close/shutdown can still requeue it. - w.jobs.Store(key, job) continue } delete(rebalanced, key) @@ -486,6 +512,7 @@ func (w *Worker) attemptRequeue(ctx context.Context, jobsToRequeue map[string]*J // requeueJob requeues a job. func (w *Worker) requeueJob(ctx context.Context, job *Job) error { + job.Requeued = true eventID, err := w.node.poolStream.Add(ctx, evStartJob, marshalJob(job)) if err != nil { return fmt.Errorf("requeueJob: failed to add job to pool stream: %w", err) diff --git a/pool/worker_test.go b/pool/worker_test.go index d8c000e..f3ea2c8 100644 --- a/pool/worker_test.go +++ b/pool/worker_test.go @@ -2,6 +2,8 @@ package pool import ( "context" + "errors" + "sort" "strconv" "strings" "testing" @@ -10,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "goa.design/pulse/rmap" ptesting "goa.design/pulse/testing" ) @@ -54,6 +57,116 @@ func TestWorkerRequeueJobs(t *testing.T) { assert.NoError(t, node.Shutdown(ctx)) } +func TestWorkerRebalanceReleasesPreviousJobOwner(t *testing.T) { + var ( + ctx = ptesting.NewTestContext(t) + testName = strings.Replace(t.Name(), "/", "_", -1) + rdb = ptesting.NewRedisClient(t) + node = newTestNode(t, ctx, rdb, testName) + ) + defer ptesting.CleanupRedis(t, rdb, true, testName) + node.h = &ptesting.Hasher{IndexFunc: func(_ string, numBuckets int64) int64 { + return numBuckets - 1 + }} + + const jobKey = "rebalance-job" + payload := []byte("payload") + worker1 := newTestWorker(t, ctx, node) + require.NoError(t, node.DispatchJob(ctx, jobKey, payload)) + require.Eventually(t, func() bool { + return len(worker1.Jobs()) == 1 + }, max, delay) + require.Eventually(t, func() bool { + return sameStrings(jobOwners(node, jobKey), []string{worker1.ID}) + }, max, delay) + + worker2 := newTestWorker(t, ctx, node) + worker1.rebalance(ctx, []string{worker1.ID, worker2.ID}) + require.Eventually(t, func() bool { + return len(worker1.Jobs()) == 0 && len(worker2.Jobs()) == 1 + }, max, delay) + require.Eventually(t, func() bool { + return sameStrings(jobOwners(node, jobKey), []string{worker2.ID}) + }, max, delay, "job must have exactly one replicated owner after rebalance") + gotPayload, ok := node.JobPayload(jobKey) + require.True(t, ok) + assert.Equal(t, payload, gotPayload) + + assert.NoError(t, node.Shutdown(ctx)) +} + +func TestWorkerStartFailurePayloadOwnership(t *testing.T) { + var ( + ctx = ptesting.NewTestContext(t) + testName = strings.Replace(t.Name(), "/", "_", -1) + rdb = ptesting.NewRedisClient(t) + node = newTestNode(t, ctx, rdb, testName) + ) + defer ptesting.CleanupRedis(t, rdb, true, testName) + + errStart := errors.New("start failed") + worker := newTestWorker(t, ctx, node) + worker.handler.(*mockHandler).startFunc = func(job *Job) error { + return errStart + } + + err := worker.startJob(ctx, &Job{ + Key: "new-job", + Payload: []byte("new payload"), + CreatedAt: time.Now(), + NodeID: node.ID, + }) + assert.ErrorIs(t, err, errStart) + _, ok := snapshotValue(t, ctx, node, node.jobPayloadMap, "new-job") + assert.False(t, ok) + assert.Empty(t, snapshotJobOwners(t, ctx, node, "new-job")) + + err = worker.startJob(ctx, &Job{ + Key: "requeued-job", + Payload: []byte("requeued payload"), + CreatedAt: time.Now(), + NodeID: node.ID, + Requeued: true, + }) + assert.ErrorIs(t, err, errStart) + assert.Empty(t, snapshotJobOwners(t, ctx, node, "requeued-job")) + gotPayload, ok := snapshotValue(t, ctx, node, node.jobPayloadMap, "requeued-job") + require.True(t, ok) + assert.Equal(t, "requeued payload", gotPayload) + + assert.NoError(t, node.Shutdown(ctx)) +} + +func TestWorkerControlEventsRequireLocalOwnership(t *testing.T) { + var ( + ctx = ptesting.NewTestContext(t) + testName = strings.Replace(t.Name(), "/", "_", -1) + rdb = ptesting.NewRedisClient(t) + node = newTestNode(t, ctx, rdb, testName) + ) + defer ptesting.CleanupRedis(t, rdb, true, testName) + + worker := newTestWorker(t, ctx, node) + stopCalled := false + notifyCalled := false + worker.handler.(*mockHandler).stopFunc = func(key string) error { + stopCalled = true + return nil + } + worker.handler.(*mockHandler).notifyFunc = func(key string, payload []byte) error { + notifyCalled = true + return nil + } + + assert.ErrorIs(t, worker.stopJob(ctx, "missing-job"), ErrRequeue) + assert.False(t, stopCalled) + + assert.ErrorIs(t, worker.notify(ctx, "missing-job", []byte("payload")), ErrRequeue) + assert.False(t, notifyCalled) + + assert.NoError(t, node.Shutdown(ctx)) +} + func TestStaleWorkerCleanupInNode(t *testing.T) { var ( ctx = ptesting.NewTestContext(t) @@ -134,3 +247,63 @@ func TestStaleWorkerCleanupAcrossNodes(t *testing.T) { assert.NoError(t, node1.Shutdown(ctx)) assert.NoError(t, node2.Shutdown(ctx)) } + +func jobOwners(node *Node, key string) []string { + var owners []string + for workerID := range node.jobMap.Map() { + values, ok := node.jobMap.GetValues(workerID) + if !ok { + continue + } + for _, value := range values { + if value == key { + owners = append(owners, workerID) + break + } + } + } + sort.Strings(owners) + return owners +} + +func snapshotJobOwners(t *testing.T, ctx context.Context, node *Node, key string) []string { + t.Helper() + snapshot, err := rmap.Join(ctx, node.jobMap.Name, node.rdb) + require.NoError(t, err) + defer snapshot.Close() + var owners []string + for workerID := range snapshot.Map() { + values, ok := snapshot.GetValues(workerID) + if !ok { + continue + } + for _, value := range values { + if value == key { + owners = append(owners, workerID) + break + } + } + } + sort.Strings(owners) + return owners +} + +func snapshotValue(t *testing.T, ctx context.Context, node *Node, source *rmap.Map, key string) (string, bool) { + t.Helper() + snapshot, err := rmap.Join(ctx, source.Name, node.rdb) + require.NoError(t, err) + defer snapshot.Close() + return snapshot.Get(key) +} + +func sameStrings(got, want []string) bool { + if len(got) != len(want) { + return false + } + for i := range got { + if got[i] != want[i] { + return false + } + } + return true +}