diff --git a/operator/charts/templates/clusterrole.yaml b/operator/charts/templates/clusterrole.yaml
index 8ed60a40d..2c7348f4d 100644
--- a/operator/charts/templates/clusterrole.yaml
+++ b/operator/charts/templates/clusterrole.yaml
@@ -142,6 +142,18 @@ rules:
   - patch
   - update
   - delete
+- apiGroups:
+  - scheduling.run.ai
+  resources:
+  - podgroups
+  verbs:
+  - create
+  - get
+  - list
+  - watch
+  - patch
+  - update
+  - delete
 {{- if .Values.config.network.autoMNNVLEnabled }}
 # MNNVL (Multi-Node NVLink) support requires permissions for ComputeDomain and ResourceClaimTemplate resources.
 # Note: Kubernetes allows RBAC rules for resources that don't exist yet. If the ComputeDomain CRD is not installed,
diff --git a/operator/internal/client/scheme.go b/operator/internal/client/scheme.go
index 8b43467d8..78e6fcd38 100644
--- a/operator/internal/client/scheme.go
+++ b/operator/internal/client/scheme.go
@@ -21,6 +21,7 @@ import (
     grovecorev1alpha1 "github.com/ai-dynamo/grove/operator/api/core/v1alpha1"
     kaitopologyv1alpha1 "github.com/NVIDIA/KAI-scheduler/pkg/apis/kai/v1alpha1"
+    kaischedulingv2alpha2 "github.com/NVIDIA/KAI-scheduler/pkg/apis/scheduling/v2alpha2"
     schedv1alpha1 "github.com/ai-dynamo/grove/scheduler/api/core/v1alpha1"
     metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
     "k8s.io/apimachinery/pkg/runtime"
@@ -37,6 +38,7 @@ func init() {
         grovecorev1alpha1.AddToScheme,
         schedv1alpha1.AddToScheme,
         kaitopologyv1alpha1.AddToScheme,
+        kaischedulingv2alpha2.AddToScheme,
         k8sscheme.AddToScheme,
     )
     utilruntime.Must(metav1.AddMetaToScheme(Scheme))
diff --git a/operator/internal/scheduler/kai/backend.go b/operator/internal/scheduler/kai/backend.go
index da8c83bb0..d28aefa3a 100644
--- a/operator/internal/scheduler/kai/backend.go
+++ b/operator/internal/scheduler/kai/backend.go
@@ -18,20 +18,27 @@ package kai
 import (
     "context"
+    "fmt"
+    "reflect"
+
+    apicommonconstants "github.com/ai-dynamo/grove/operator/api/common/constants"
     configv1alpha1 "github.com/ai-dynamo/grove/operator/api/config/v1alpha1"
     grovecorev1alpha1 "github.com/ai-dynamo/grove/operator/api/core/v1alpha1"
     "github.com/ai-dynamo/grove/operator/internal/scheduler"
+    kaischedulingv2alpha2 "github.com/NVIDIA/KAI-scheduler/pkg/apis/scheduling/v2alpha2"
     groveschedulerv1alpha1 "github.com/ai-dynamo/grove/scheduler/api/core/v1alpha1"
     corev1 "k8s.io/api/core/v1"
+    apierrors "k8s.io/apimachinery/pkg/api/errors"
+    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
     "k8s.io/apimachinery/pkg/runtime"
     "k8s.io/client-go/tools/record"
+    "k8s.io/utils/ptr"
     "sigs.k8s.io/controller-runtime/pkg/client"
+    "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
 )

 // schedulerBackend implements the scheduler Backend interface (Backend in scheduler package) for KAI scheduler.
-// TODO: Converts PodGang → PodGroup
 type schedulerBackend struct {
     client client.Client
     scheme *runtime.Scheme
@@ -42,6 +49,13 @@ type schedulerBackend struct {

 var _ scheduler.Backend = (*schedulerBackend)(nil)

+const (
+    labelKeyQueueName        = "kai.scheduler/queue"
+    labelKeyNodePoolName     = "kai.scheduler/node-pool"
+    annotationKeyIgnoreGrove = "grove.io/ignore"
+    annotationValIgnoreGrove = "true"
+)
+
 // New creates a new KAI backend instance. profile is the scheduler profile for kai-scheduler;
 // schedulerBackend uses profile.Name and may unmarshal profile.Config for kai-specific options.
 func New(cl client.Client, scheme *runtime.Scheme, eventRecorder record.EventRecorder, profile configv1alpha1.SchedulerProfile) scheduler.Backend {
@@ -65,13 +79,47 @@ func (b *schedulerBackend) Init() error {
 }

 // SyncPodGang converts PodGang to KAI PodGroup and synchronizes it
-func (b *schedulerBackend) SyncPodGang(_ context.Context, _ *groveschedulerv1alpha1.PodGang) error {
-    return nil
+func (b *schedulerBackend) SyncPodGang(ctx context.Context, podGang *groveschedulerv1alpha1.PodGang) error {
+    if podGang == nil {
+        return fmt.Errorf("podGang is nil")
+    }
+    if err := b.ensurePodGangIgnoredByGrovePlugin(ctx, podGang); err != nil {
+        return err
+    }
+
+    newPodGroup, err := b.buildPodGroupForPodGang(podGang)
+    if err != nil {
+        return err
+    }
+
+    oldPodGroup := &kaischedulingv2alpha2.PodGroup{}
+    key := client.ObjectKeyFromObject(newPodGroup)
+    if err = b.client.Get(ctx, key, oldPodGroup); err != nil {
+        if apierrors.IsNotFound(err) {
+            return b.client.Create(ctx, newPodGroup)
+        }
+        return err
+    }
+
+    newPodGroup = b.inheritRuntimeManagedFields(oldPodGroup, newPodGroup)
+    if podGroupsEqual(oldPodGroup, newPodGroup) {
+        return nil
+    }
+    updatePodGroup(oldPodGroup, newPodGroup)
+    return b.client.Update(ctx, oldPodGroup)
 }

 // OnPodGangDelete removes the PodGroup owned by this PodGang
-func (b *schedulerBackend) OnPodGangDelete(_ context.Context, _ *groveschedulerv1alpha1.PodGang) error {
-    return nil
+func (b *schedulerBackend) OnPodGangDelete(ctx context.Context, podGang *groveschedulerv1alpha1.PodGang) error {
+    if podGang == nil {
+        return nil
+    }
+    return client.IgnoreNotFound(b.client.Delete(ctx, &kaischedulingv2alpha2.PodGroup{
+        ObjectMeta: metav1.ObjectMeta{
+            Name:      podGang.Name,
+            Namespace: podGang.Namespace,
+        },
+    }))
 }

 // PreparePod adds KAI scheduler-specific configuration to the Pod.
@@ -84,3 +132,198 @@ func (b *schedulerBackend) ValidatePodCliqueSet(_ context.Context, _ *grovecorev1alpha1.PodCliqueSet) error {
     return nil
 }
+
+// ensurePodGangIgnoredByGrovePlugin marks PodGang so legacy Grove podgrouper ignores it.
+func (b *schedulerBackend) ensurePodGangIgnoredByGrovePlugin(ctx context.Context, podGang *groveschedulerv1alpha1.PodGang) error {
+    if podGang.Annotations != nil && podGang.Annotations[annotationKeyIgnoreGrove] == annotationValIgnoreGrove {
+        return nil
+    }
+    patchBase := podGang.DeepCopy()
+    if podGang.Annotations == nil {
+        podGang.Annotations = map[string]string{}
+    }
+    podGang.Annotations[annotationKeyIgnoreGrove] = annotationValIgnoreGrove
+    return b.client.Patch(ctx, podGang, client.MergeFrom(patchBase))
+}
+
+// buildPodGroupForPodGang translates a Grove PodGang into a KAI PodGroup object.
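+// Each PodGang.Spec.PodGroups entry becomes a SubGroup whose MinMember equals its MinReplicas; each
+// TopologyConstraintGroupConfig becomes a zero-MinMember parent SubGroup that member SubGroups reference
+// via SubGroup.Parent. The PodGroup's top-level MinMember is the sum of all MinReplicas.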
+func (b *schedulerBackend) buildPodGroupForPodGang(podGang *groveschedulerv1alpha1.PodGang) (*kaischedulingv2alpha2.PodGroup, error) {
+    topologyName := getTopologyName(podGang)
+    topologyConstraint, err := toKAITopologyConstraint(podGang.Spec.TopologyConstraint, topologyName)
+    if err != nil {
+        return nil, err
+    }
+
+    parentBySubGroupName := map[string]string{}
+    subGroups := make([]kaischedulingv2alpha2.SubGroup, 0, len(podGang.Spec.TopologyConstraintGroupConfigs)+len(podGang.Spec.PodGroups))
+
+    for _, groupConfig := range podGang.Spec.TopologyConstraintGroupConfigs {
+        groupTopologyConstraint, groupErr := toKAITopologyConstraint(groupConfig.TopologyConstraint, topologyName)
+        if groupErr != nil {
+            return nil, groupErr
+        }
+        subGroups = append(subGroups, kaischedulingv2alpha2.SubGroup{
+            Name:               groupConfig.Name,
+            MinMember:          0,
+            TopologyConstraint: groupTopologyConstraint,
+        })
+        for _, podGroupName := range groupConfig.PodGroupNames {
+            parentBySubGroupName[podGroupName] = groupConfig.Name
+        }
+    }
+
+    var minMember int32
+    for _, podGroup := range podGang.Spec.PodGroups {
+        subGroupTopologyConstraint, groupErr := toKAITopologyConstraint(podGroup.TopologyConstraint, topologyName)
+        if groupErr != nil {
+            return nil, groupErr
+        }
+        subGroup := kaischedulingv2alpha2.SubGroup{
+            Name:               podGroup.Name,
+            MinMember:          podGroup.MinReplicas,
+            TopologyConstraint: subGroupTopologyConstraint,
+        }
+        if parentName, found := parentBySubGroupName[podGroup.Name]; found {
+            subGroup.Parent = ptr.To(parentName)
+        }
+        subGroups = append(subGroups, subGroup)
+        minMember += podGroup.MinReplicas
+    }
+
+    result := &kaischedulingv2alpha2.PodGroup{
+        ObjectMeta: metav1.ObjectMeta{
+            Name:        podGang.Name,
+            Namespace:   podGang.Namespace,
+            Labels:      cloneStringMap(podGang.Labels),
+            Annotations: cloneStringMap(podGang.Annotations),
+        },
+        Spec: kaischedulingv2alpha2.PodGroupSpec{
+            MinMember:         minMember,
+            Queue:             resolveQueueName(podGang),
+            PriorityClassName: podGang.Spec.PriorityClassName,
+            SubGroups:         subGroups,
+        },
+    }
+    if topologyConstraint != nil {
+        result.Spec.TopologyConstraint = *topologyConstraint
+    }
+    if err := controllerutil.SetControllerReference(podGang, result, b.scheme); err != nil {
+        return nil, err
+    }
+    return result, nil
+}
+
+// getTopologyName resolves topology name from PodGang annotations with fallback keys.
+func getTopologyName(podGang *groveschedulerv1alpha1.PodGang) string {
+    if podGang.Annotations == nil {
+        return ""
+    }
+    if topologyName := podGang.Annotations[apicommonconstants.AnnotationTopologyName]; topologyName != "" {
+        return topologyName
+    }
+    // Backward compatibility with KAI annotation key.
+    return podGang.Annotations["kai.scheduler/topology"]
+}
+
+// toKAITopologyConstraint converts Grove topology constraint to KAI topology constraint.
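+// A nil result with a nil error means no pack constraint is set; an error is returned when a pack
+// constraint exists but no topology name could be resolved for the PodGang.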
+func toKAITopologyConstraint(topologyConstraint *groveschedulerv1alpha1.TopologyConstraint, topologyName string) (*kaischedulingv2alpha2.TopologyConstraint, error) {
+    if topologyConstraint == nil || topologyConstraint.PackConstraint == nil {
+        return nil, nil
+    }
+    if topologyName == "" {
+        return nil, fmt.Errorf("topology name cannot be empty when topology constraints are defined")
+    }
+    result := &kaischedulingv2alpha2.TopologyConstraint{
+        Topology: topologyName,
+    }
+    if topologyConstraint.PackConstraint.Preferred != nil {
+        result.PreferredTopologyLevel = *topologyConstraint.PackConstraint.Preferred
+    }
+    if topologyConstraint.PackConstraint.Required != nil {
+        result.RequiredTopologyLevel = *topologyConstraint.PackConstraint.Required
+    }
+    return result, nil
+}
+
+// resolveQueueName returns queue from labels first, then falls back to annotations.
+func resolveQueueName(podGang *groveschedulerv1alpha1.PodGang) string {
+    if podGang.Labels != nil && podGang.Labels[labelKeyQueueName] != "" {
+        return podGang.Labels[labelKeyQueueName]
+    }
+    if podGang.Annotations != nil {
+        return podGang.Annotations[labelKeyQueueName]
+    }
+    return ""
+}
+
+// inheritRuntimeManagedFields preserves fields that are managed by KAI runtime components.
+func (b *schedulerBackend) inheritRuntimeManagedFields(oldPodGroup, newPodGroup *kaischedulingv2alpha2.PodGroup) *kaischedulingv2alpha2.PodGroup {
+    newPodGroupCopy := newPodGroup.DeepCopy()
+    // These fields are managed by KAI components after initial creation.
+    newPodGroupCopy.Spec.MarkUnschedulable = oldPodGroup.Spec.MarkUnschedulable
+    newPodGroupCopy.Spec.SchedulingBackoff = oldPodGroup.Spec.SchedulingBackoff
+    newPodGroupCopy.Spec.Queue = oldPodGroup.Spec.Queue
+
+    if newPodGroupCopy.Labels == nil {
+        newPodGroupCopy.Labels = map[string]string{}
+    }
+    if nodePoolName := oldPodGroup.Labels[labelKeyNodePoolName]; nodePoolName != "" {
+        newPodGroupCopy.Labels[labelKeyNodePoolName] = nodePoolName
+    }
+    if queueName := oldPodGroup.Labels[labelKeyQueueName]; queueName != "" {
+        newPodGroupCopy.Labels[labelKeyQueueName] = queueName
+    }
+    return newPodGroupCopy
+}
+
+// podGroupsEqual compares spec plus source-owned metadata fields for update decisions.
+func podGroupsEqual(oldPodGroup, newPodGroup *kaischedulingv2alpha2.PodGroup) bool {
+    return reflect.DeepEqual(oldPodGroup.Spec, newPodGroup.Spec) &&
+        reflect.DeepEqual(oldPodGroup.OwnerReferences, newPodGroup.OwnerReferences) &&
+        mapsEqualBySourceKeys(newPodGroup.Labels, oldPodGroup.Labels) &&
+        mapsEqualBySourceKeys(newPodGroup.Annotations, oldPodGroup.Annotations)
+}
+
+// mapsEqualBySourceKeys checks whether target contains all key-values from source.
+func mapsEqualBySourceKeys(source, target map[string]string) bool {
+    if source != nil && target == nil {
+        return false
+    }
+    for key, sourceValue := range source {
+        if targetValue, exists := target[key]; !exists || targetValue != sourceValue {
+            return false
+        }
+    }
+    return true
+}
+
+// updatePodGroup copies desired fields from newPodGroup into existing object.
+func updatePodGroup(oldPodGroup, newPodGroup *kaischedulingv2alpha2.PodGroup) {
+    oldPodGroup.Annotations = copyStringMap(newPodGroup.Annotations, oldPodGroup.Annotations)
+    oldPodGroup.Labels = copyStringMap(newPodGroup.Labels, oldPodGroup.Labels)
+    oldPodGroup.Spec = newPodGroup.Spec
+    oldPodGroup.OwnerReferences = newPodGroup.OwnerReferences
+}
+
+// copyStringMap copies all key-values from source into target map.
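+// A nil target is allocated only when source is non-nil; keys already present in target are kept
+// unless source overrides them.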
+func copyStringMap(source, target map[string]string) map[string]string {
+    if source != nil && target == nil {
+        target = map[string]string{}
+    }
+    for k, v := range source {
+        target[k] = v
+    }
+    return target
+}
+
+// cloneStringMap returns a shallow copy of the input string map.
+func cloneStringMap(input map[string]string) map[string]string {
+    if input == nil {
+        return nil
+    }
+    cloned := make(map[string]string, len(input))
+    for k, v := range input {
+        cloned[k] = v
+    }
+    return cloned
+}
diff --git a/operator/internal/scheduler/kai/backend_test.go b/operator/internal/scheduler/kai/backend_test.go
index 9d05da489..77225e68d 100644
--- a/operator/internal/scheduler/kai/backend_test.go
+++ b/operator/internal/scheduler/kai/backend_test.go
@@ -17,13 +17,20 @@ package kai
 import (
+    "context"
     "testing"

     configv1alpha1 "github.com/ai-dynamo/grove/operator/api/config/v1alpha1"
     testutils "github.com/ai-dynamo/grove/operator/test/utils"
+    kaischedulingv2alpha2 "github.com/NVIDIA/KAI-scheduler/pkg/apis/scheduling/v2alpha2"
+    groveschedulerv1alpha1 "github.com/ai-dynamo/grove/scheduler/api/core/v1alpha1"
     "github.com/stretchr/testify/assert"
+    "github.com/stretchr/testify/require"
+    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
     "k8s.io/client-go/tools/record"
+    "k8s.io/utils/ptr"
+    "sigs.k8s.io/controller-runtime/pkg/client"
 )

 func TestBackend_PreparePod(t *testing.T) {
@@ -40,3 +47,117 @@ func TestBackend_PreparePod(t *testing.T) {

     assert.Equal(t, "kai-scheduler", pod.Spec.SchedulerName)
 }
+
+func TestBackend_SyncPodGang_CreateAndUpdate(t *testing.T) {
+    podGang := testutils.NewPodGangBuilder("test-podgang", "default").
+        WithSchedulerName(string(configv1alpha1.SchedulerNameKai)).
+        Build()
+    podGang.Labels["kai.scheduler/queue"] = "team-a"
+    podGang.Annotations = map[string]string{"grove.io/topology-name": "cluster-topology"}
+    podGang.Spec.PriorityClassName = "high-priority"
+    podGang.Spec.TopologyConstraint = &groveschedulerv1alpha1.TopologyConstraint{
+        PackConstraint: &groveschedulerv1alpha1.TopologyPackConstraint{
+            Required: ptr.To("zone"),
+        },
+    }
+    podGang.Spec.TopologyConstraintGroupConfigs = []groveschedulerv1alpha1.TopologyConstraintGroupConfig{
+        {
+            Name:          "decoder-group",
+            PodGroupNames: []string{"decoder"},
+            TopologyConstraint: &groveschedulerv1alpha1.TopologyConstraint{
+                PackConstraint: &groveschedulerv1alpha1.TopologyPackConstraint{
+                    Preferred: ptr.To("rack"),
+                },
+            },
+        },
+    }
+    podGang.Spec.PodGroups = []groveschedulerv1alpha1.PodGroup{
+        {
+            Name:        "encoder",
+            MinReplicas: 2,
+            TopologyConstraint: &groveschedulerv1alpha1.TopologyConstraint{
+                PackConstraint: &groveschedulerv1alpha1.TopologyPackConstraint{
+                    Required: ptr.To("host"),
+                },
+            },
+        },
+        {
+            Name:        "decoder",
+            MinReplicas: 3,
+        },
+    }
+
+    cl := testutils.NewTestClientBuilder().
+        WithObjects(podGang).
+        Build()
+    recorder := record.NewFakeRecorder(10)
+    profile := configv1alpha1.SchedulerProfile{Name: configv1alpha1.SchedulerNameKai}
+    b := New(cl, cl.Scheme(), recorder, profile)
+
+    ctx := context.Background()
+    require.NoError(t, b.SyncPodGang(ctx, podGang))
+
+    syncedPodGang := &groveschedulerv1alpha1.PodGang{}
+    require.NoError(t, cl.Get(ctx, client.ObjectKeyFromObject(podGang), syncedPodGang))
+    assert.Equal(t, "true", syncedPodGang.Annotations["grove.io/ignore"])
+
+    gotPodGroup := &kaischedulingv2alpha2.PodGroup{}
+    require.NoError(t, cl.Get(ctx, client.ObjectKey{Name: podGang.Name, Namespace: podGang.Namespace}, gotPodGroup))
+
+    assert.Equal(t, int32(5), gotPodGroup.Spec.MinMember)
+    assert.Equal(t, "team-a", gotPodGroup.Spec.Queue)
+    assert.Equal(t, "high-priority", gotPodGroup.Spec.PriorityClassName)
+    assert.Equal(t, "zone", gotPodGroup.Spec.TopologyConstraint.RequiredTopologyLevel)
+
+    require.Len(t, gotPodGroup.Spec.SubGroups, 3)
+    assert.Equal(t, "decoder-group", gotPodGroup.Spec.SubGroups[0].Name)
+    assert.Equal(t, int32(0), gotPodGroup.Spec.SubGroups[0].MinMember)
+
+    assert.Equal(t, "encoder", gotPodGroup.Spec.SubGroups[1].Name)
+    assert.Equal(t, int32(2), gotPodGroup.Spec.SubGroups[1].MinMember)
+    assert.Nil(t, gotPodGroup.Spec.SubGroups[1].Parent)
+    assert.Equal(t, "host", gotPodGroup.Spec.SubGroups[1].TopologyConstraint.RequiredTopologyLevel)
+
+    assert.Equal(t, "decoder", gotPodGroup.Spec.SubGroups[2].Name)
+    require.NotNil(t, gotPodGroup.Spec.SubGroups[2].Parent)
+    assert.Equal(t, "decoder-group", *gotPodGroup.Spec.SubGroups[2].Parent)
+    assert.Equal(t, int32(3), gotPodGroup.Spec.SubGroups[2].MinMember)
+
+    // Update PodGang: remove queue label and change min replicas.
+    updatedPodGang := syncedPodGang.DeepCopy()
+    delete(updatedPodGang.Labels, "kai.scheduler/queue")
+    updatedPodGang.Spec.PodGroups[0].MinReplicas = 4
+    require.NoError(t, cl.Update(ctx, updatedPodGang))
+
+    require.NoError(t, b.SyncPodGang(ctx, updatedPodGang))
+    gotAfterUpdate := &kaischedulingv2alpha2.PodGroup{}
+    require.NoError(t, cl.Get(ctx, client.ObjectKey{Name: podGang.Name, Namespace: podGang.Namespace}, gotAfterUpdate))
+
+    // Existing queue should be preserved even when source label is removed.
+    assert.Equal(t, "team-a", gotAfterUpdate.Spec.Queue)
+    assert.Equal(t, int32(7), gotAfterUpdate.Spec.MinMember)
+    assert.Equal(t, int32(4), gotAfterUpdate.Spec.SubGroups[1].MinMember)
+}
+
+func TestBackend_OnPodGangDelete(t *testing.T) {
+    podGang := testutils.NewPodGangBuilder("to-delete", "default").
+        WithSchedulerName(string(configv1alpha1.SchedulerNameKai)).
+        Build()
+    podGroup := &kaischedulingv2alpha2.PodGroup{
+        ObjectMeta: metav1.ObjectMeta{
+            Name:      "to-delete",
+            Namespace: "default",
+        },
+    }
+
+    cl := testutils.NewTestClientBuilder().WithObjects(podGang, podGroup).Build()
+    recorder := record.NewFakeRecorder(10)
+    profile := configv1alpha1.SchedulerProfile{Name: configv1alpha1.SchedulerNameKai}
+    b := New(cl, cl.Scheme(), recorder, profile)
+
+    ctx := context.Background()
+    require.NoError(t, b.OnPodGangDelete(ctx, podGang))
+
+    err := cl.Get(ctx, client.ObjectKey{Name: podGroup.Name, Namespace: podGroup.Namespace}, &kaischedulingv2alpha2.PodGroup{})
+    assert.Error(t, err)
+}