Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pool/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ func marshalJob(job *Job) []byte {
if err := binary.Write(&buf, binary.LittleEndian, job.CreatedAt.UnixNano()); err != nil {
panic(err)
}
if err := binary.Write(&buf, binary.LittleEndian, job.Requeued); err != nil {
panic(err)
}
return buf.Bytes()
}

Expand Down Expand Up @@ -73,6 +76,7 @@ func unmarshalJob(data []byte) *Job {
Payload: payload,
CreatedAt: time.Unix(0, createdAtTimestamp).UTC(),
NodeID: nodeID,
Requeued: unmarshalBool(reader),
}
}

Expand Down Expand Up @@ -122,6 +126,14 @@ func unmarshalJobKeyAndNodeID(data []byte) (string, string) {
return string(keyBytes), string(nodeIDBytes)
}

func unmarshalBool(reader *bytes.Reader) bool {
var value bool
if err := binary.Read(reader, binary.LittleEndian, &value); err != nil {
panic(err)
}
return value
}

func marshalNotification(key string, payload []byte) []byte {
var buf bytes.Buffer
if err := binary.Write(&buf, binary.LittleEndian, int32(len(key))); err != nil {
Expand Down
10 changes: 10 additions & 0 deletions pool/marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ func TestMarshalJob(t *testing.T) {
CreatedAt: time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC),
},
},
{
name: "requeued job",
job: Job{
Key: "test-key",
Payload: []byte("test-payload"),
CreatedAt: time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC),
Requeued: true,
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
Expand All @@ -38,6 +47,7 @@ func TestMarshalJob(t *testing.T) {
assert.Equal(t, tc.job.Key, job.Key)
assert.Equal(t, tc.job.Payload, job.Payload)
assert.Equal(t, tc.job.CreatedAt, job.CreatedAt)
assert.Equal(t, tc.job.Requeued, job.Requeued)

// Compare original and unmarshaled byte slices
marshaled2 := marshalJob(job)
Expand Down
106 changes: 98 additions & 8 deletions pool/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,13 @@ const (
// pendingEventTTL is the TTL for pending events.
var pendingEventTTL = 2 * time.Minute

// ErrJobExists is returned when attempting to dispatch a job with a key that already exists.
var ErrJobExists = errors.New("job already exists")
var (
// ErrJobExists is returned when attempting to dispatch a job with a key that already exists.
ErrJobExists = errors.New("job already exists")

errJobAwaitingOwner = errors.New("job awaiting active owner")
errJobNotFound = errors.New("job not found")
)

// AddNode adds a new node to the pool with the given name and returns it. The
// node can be used to dispatch jobs and add new workers. A node also routes
Expand Down Expand Up @@ -696,11 +701,20 @@ func (node *Node) routeWorkerEvent(ev *streaming.Event) error {

// Compute the worker ID that will handle the job.
key := unmarshalJobKey(ev.Payload)
activeWorkers := node.activeWorkers()
if len(activeWorkers) == 0 {
return fmt.Errorf("routeWorkerEvent: no active worker in pool %q", node.PoolName)
wid, err := node.workerForEvent(ev.EventName, key)
if err != nil {
if errors.Is(err, errJobAwaitingOwner) {
node.logger.Debug("routeWorkerEvent: job has no active owner yet", "event", ev.EventName, "id", ev.ID, "key", key)
return nil
}
if errors.Is(err, errJobNotFound) {
if ackErr := node.poolSink.Ack(context.Background(), ev); ackErr != nil {
node.logger.Error(fmt.Errorf("routeWorkerEvent: failed to ack event for missing job: %w", ackErr), "event", ev.EventName, "id", ev.ID)
}
return nil
}
return err
}
wid := activeWorkers[node.h.Hash(key, int64(len(activeWorkers)))]

// Stream the event to the worker corresponding to the key hash.
stream, err := node.getWorkerStream(wid)
Expand Down Expand Up @@ -824,6 +838,82 @@ func (node *Node) returnDispatchStatus(ev *streaming.Event) {
val.(chan error) <- err
}

// workerForEvent returns the worker that should receive a pool event. Start
// events are routed by the current consistent hash ring; stop and notification
// events target the worker that currently owns the job.
func (node *Node) workerForEvent(eventName, key string) (string, error) {
if eventName == evStartJob {
activeWorkers := node.activeWorkers()
if len(activeWorkers) == 0 {
return "", fmt.Errorf("routeWorkerEvent: no active worker in pool %q", node.PoolName)
}
return activeWorkers[node.h.Hash(key, int64(len(activeWorkers)))], nil
}
if eventName == evStopJob || eventName == evNotify {
owner, ok, err := node.activeJobOwner(key)
if err != nil {
return "", err
}
if !ok {
exists, err := node.jobPayloadExists(context.Background(), key)
if err != nil {
return "", err
}
if exists {
return "", fmt.Errorf("%w: %q", errJobAwaitingOwner, key)
}
return "", fmt.Errorf("%w: %q", errJobNotFound, key)
}
return owner, nil
}
return "", fmt.Errorf("routeWorkerEvent: unknown worker event %q", eventName)
}

// jobPayloadExists reads the durable job record from Redis, which is the source
// of truth when the local ownership map has no active owner during handoff.
func (node *Node) jobPayloadExists(ctx context.Context, key string) (bool, error) {
exists, err := node.rdb.HExists(ctx, rmapContentKey(jobPayloadMapName(node.PoolName)), key).Result()
if err != nil {
return false, fmt.Errorf("routeWorkerEvent: failed to check job payload %q: %w", key, err)
}
return exists, nil
}

// activeJobOwner returns the single active worker that owns a job key according
// to the replicated ownership map.
func (node *Node) activeJobOwner(key string) (string, bool, error) {
activeWorkers := node.activeWorkers()
active := make(map[string]struct{}, len(activeWorkers))
for _, workerID := range activeWorkers {
active[workerID] = struct{}{}
}

var owner string
for workerID := range node.jobMap.Map() {
if _, ok := active[workerID]; !ok {
continue
}
keys, ok := node.jobMap.GetValues(workerID)
if !ok {
continue
}
for _, ownedKey := range keys {
if ownedKey != key {
continue
}
if owner != "" {
return "", false, fmt.Errorf("routeWorkerEvent: job %q has multiple active owners", key)
}
owner = workerID
break
}
}
if owner == "" {
return "", false, nil
}
return owner, true, nil
}

// watches monitors the workers replicated map and triggers job rebalancing
// when workers are added or removed from the pool.
func (node *Node) watchWorkers(ctx context.Context) {
Expand Down Expand Up @@ -1067,7 +1157,7 @@ func (node *Node) requeueOrphanedPayloads(ctx context.Context) {
node.orphanedPayloads.Delete(key)
continue
}
job := &Job{Key: key, Payload: payload, CreatedAt: now, NodeID: node.ID}
job := &Job{Key: key, Payload: payload, CreatedAt: now, NodeID: node.ID, Requeued: true}
if _, err := node.poolStream.Add(ctx, evStartJob, marshalJob(job)); err != nil {
node.logger.Error(fmt.Errorf("requeueOrphanedPayloads: failed to requeue orphaned job: %w", err), "key", key)
continue
Expand Down Expand Up @@ -1117,7 +1207,7 @@ func (node *Node) cleanupWorker(ctx context.Context, workerID string) {
processed++
continue
}
job := &Job{Key: key, Payload: payload, CreatedAt: time.Now(), NodeID: node.ID}
job := &Job{Key: key, Payload: payload, CreatedAt: time.Now(), NodeID: node.ID, Requeued: true}
// Requeue by adding an event back to the pool stream.
// We intentionally do not wait for the job to start (which can time out
// under heavy churn) - the pool sink will retry routing until it is acked.
Expand Down
75 changes: 75 additions & 0 deletions pool/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,81 @@ func TestDispatchJobTwoWorkers(t *testing.T) {
assert.NoError(t, node.Shutdown(ctx), "Failed to shutdown node")
}

func TestStopJobRoutesToCurrentOwner(t *testing.T) {
testName := strings.Replace(t.Name(), "/", "_", -1)
ctx := ptesting.NewTestContext(t)
rdb := ptesting.NewRedisClient(t)
node := newTestNode(t, ctx, rdb, testName)
defer ptesting.CleanupRedis(t, rdb, true, testName)

node.h = &ptesting.Hasher{Index: 0}
worker1 := newTestWorker(t, ctx, node)
worker2 := newTestWorker(t, ctx, node)

stopped := make(chan struct{})
var stoppedOnce sync.Once
worker1.handler.(*mockHandler).stopFunc = func(key string) error {
stoppedOnce.Do(func() { close(stopped) })
return nil
}
worker2.handler.(*mockHandler).stopFunc = func(key string) error {
t.Errorf("stop routed to worker without ownership: %s", key)
return nil
}

payload := []byte("payload")
require.NoError(t, node.DispatchJob(ctx, testName, payload))
require.Eventually(t, func() bool {
return len(worker1.Jobs()) == 1
}, max, delay)
require.Eventually(t, func() bool {
return sameStrings(jobOwners(node, testName), []string{worker1.ID})
}, max, delay)

node.h = &ptesting.Hasher{Index: 1}
require.NoError(t, node.StopJob(ctx, testName))
select {
case <-stopped:
case <-time.After(max):
t.Fatal("job stop was not routed to current owner")
}
require.Eventually(t, func() bool {
return len(worker1.Jobs()) == 0 && len(worker2.Jobs()) == 0
}, max, delay)
require.Eventually(t, func() bool {
_, ok := node.JobPayload(testName)
return !ok && len(jobOwners(node, testName)) == 0
}, max, delay)

assert.NoError(t, node.Shutdown(ctx), "Failed to shutdown node")
}

func TestControlEventRoutingDuringOwnershipGaps(t *testing.T) {
testName := strings.Replace(t.Name(), "/", "_", -1)
ctx := ptesting.NewTestContext(t)
rdb := ptesting.NewRedisClient(t)
node := newTestNode(t, ctx, rdb, testName)
defer ptesting.CleanupRedis(t, rdb, true, testName)

const jobKey = "handoff-job"
_, err := node.jobPayloadMap.SetAndWait(ctx, jobKey, "payload")
require.NoError(t, err)

_, err = node.workerForEvent(evStopJob, jobKey)
assert.ErrorIs(t, err, errJobAwaitingOwner)

_, err = node.workerForEvent(evNotify, jobKey)
assert.ErrorIs(t, err, errJobAwaitingOwner)

_, err = node.jobPayloadMap.Delete(ctx, jobKey)
require.NoError(t, err)

_, err = node.workerForEvent(evStopJob, jobKey)
assert.ErrorIs(t, err, errJobNotFound)

assert.NoError(t, node.Shutdown(ctx), "Failed to shutdown node")
}

func TestDispatchJobRaceCondition(t *testing.T) {
testName := strings.Replace(t.Name(), "/", "_", -1)
ctx := ptesting.NewTestContext(t)
Expand Down
Loading
Loading