diff --git a/pkg/gateway/services/stub.go b/pkg/gateway/services/stub.go index dbbcc6261..58f9461d8 100644 --- a/pkg/gateway/services/stub.go +++ b/pkg/gateway/services/stub.go @@ -7,9 +7,12 @@ import ( "fmt" "math" "slices" + "strconv" "strings" "time" + "github.com/rs/zerolog/log" + "github.com/beam-cloud/beta9/pkg/abstractions/endpoint" "github.com/beam-cloud/beta9/pkg/abstractions/function" "github.com/beam-cloud/beta9/pkg/abstractions/taskqueue" @@ -82,6 +85,15 @@ func (gws *GatewayService) GetOrCreateStub(ctx context.Context, in *pb.GetOrCrea }, nil } + for _, gpu := range gpus { + if !gws.multiGPUAvailable(gpu, in.GpuCount) { + return &pb.GetOrCreateStubResponse{ + Ok: false, + ErrMsg: fmt.Sprintf("Multi-GPU is not currently available for %s.", gpu.String()), + }, nil + } + } + stubConfig := types.StubConfigV1{ Runtime: types.Runtime{ Cpu: in.Cpu, @@ -442,3 +454,19 @@ func (gws *GatewayService) getLowCapacityGpus(gpus []types.GpuType) ([]string, e } return lowGpus, nil } + +func (gws *GatewayService) multiGPUAvailable(gpu types.GpuType, reqGpuCount uint32) bool { + for _, poolConfig := range gws.appConfig.Worker.Pools { + if poolConfig.GPUType == gpu.String() { + count, err := strconv.Atoi(poolConfig.PoolSizing.DefaultWorkerGpuCount) + if err != nil { + log.Warn().Msgf("Failed to parse default worker GPU count for %s: %v", gpu.String(), err) + continue + } + if count >= int(reqGpuCount) { + return true + } + } + } + return false +}