Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,9 @@ public DataflowPipelineJob run(Pipeline pipeline) {
options.getStager().stageToFile(serializedProtoPipeline, PIPELINE_FILE_NAME);
dataflowOptions.setPipelineUrl(stagedPipeline.getLocation());

String pipelineProtoHash = Hashing.sha256().hashBytes(serializedProtoPipeline).toString();
dataflowOptions.setPipelineProtoHash(pipelineProtoHash);

if (useUnifiedWorker(options)) {
LOG.info("Skipping v1 transform replacements since job will run on v2.");
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ public interface DataflowPipelineOptions

void setPipelineUrl(String urlString);

/** The hex-encoded SHA256 hash of the staged portable pipeline proto. */
@Description("The hex-encoded SHA256 hash of the staged portable pipeline proto")
String getPipelineProtoHash();

/** Sets the hex-encoded SHA256 hash of the staged portable pipeline proto. */
void setPipelineProtoHash(String hash);

@Description("The customized dataflow worker jar")
String getDataflowWorkerJar();

Expand Down
8 changes: 7 additions & 1 deletion sdks/go/pkg/beam/runners/dataflow/dataflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ package dataflow

import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"flag"
"fmt"
Expand All @@ -40,6 +42,7 @@ import (
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/pipelinex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/hooks"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/protox"
"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
"github.com/apache/beam/sdks/v2/go/pkg/beam/log"
"github.com/apache/beam/sdks/v2/go/pkg/beam/options/gcpopts"
Expand Down Expand Up @@ -239,7 +242,10 @@ func Execute(ctx context.Context, p *beam.Pipeline) (beam.PipelineResult, error)
log.Info(ctx, "Dry-run: not submitting job!")

log.Info(ctx, model.String())
job, err := dataflowlib.Translate(ctx, model, opts, workerURL, modelURL)
modelBytes := protox.MustEncode(model)
hash := sha256.Sum256(modelBytes)
pipelineProtoHash := hex.EncodeToString(hash[:])
job, err := dataflowlib.Translate(ctx, model, opts, workerURL, modelURL, pipelineProtoHash)
if err != nil {
return nil, err
}
Expand Down
9 changes: 7 additions & 2 deletions sdks/go/pkg/beam/runners/dataflow/dataflowlib/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package dataflowlib

import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"os"
"strings"
Expand Down Expand Up @@ -83,14 +85,17 @@ func Execute(ctx context.Context, raw *pipepb.Pipeline, opts *JobOptions, worker
// (2) Upload model to GCS
log.Info(ctx, raw.String())

if err := StageModel(ctx, opts.Project, modelURL, protox.MustEncode(raw)); err != nil {
modelBytes := protox.MustEncode(raw)
modelHash := sha256.Sum256(modelBytes)
pipelineProtoHash := hex.EncodeToString(modelHash[:])
if err := StageModel(ctx, opts.Project, modelURL, modelBytes); err != nil {
return presult, err
}
log.Infof(ctx, "Staged model pipeline: %v", modelURL)

// (3) Translate to v1b3 and submit

job, err := Translate(ctx, raw, opts, workerURL, modelURL)
job, err := Translate(ctx, raw, opts, workerURL, modelURL, pipelineProtoHash)
if err != nil {
return presult, err
}
Expand Down
12 changes: 7 additions & 5 deletions sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func containerImages(p *pipepb.Pipeline) ([]*df.SdkHarnessContainerImage, []stri
}

// Translate translates a pipeline to a Dataflow job.
func Translate(ctx context.Context, p *pipepb.Pipeline, opts *JobOptions, workerURL, modelURL string) (*df.Job, error) {
func Translate(ctx context.Context, p *pipepb.Pipeline, opts *JobOptions, workerURL, modelURL string, pipelineProtoHash string) (*df.Job, error) {
// (1) Translate pipeline to v1b3 speak.

jobType := "JOB_TYPE_BATCH"
Expand Down Expand Up @@ -181,10 +181,11 @@ func Translate(ctx context.Context, p *pipepb.Pipeline, opts *JobOptions, worker
SdkPipelineOptions: newMsg(pipelineOptions{
DisplayData: printOptions(opts, images),
Options: dataflowOptions{
PipelineURL: modelURL,
Region: opts.Region,
Experiments: opts.Experiments,
TempLocation: opts.TempLocation,
PipelineURL: modelURL,
PipelineProtoHash: pipelineProtoHash,
Region: opts.Region,
Experiments: opts.Experiments,
TempLocation: opts.TempLocation,
},
GoOptions: opts.Options,
}),
Expand Down Expand Up @@ -359,6 +360,7 @@ func GetMetrics(ctx context.Context, client *df.Service, project, region, jobID
type dataflowOptions struct {
Experiments []string `json:"experiments,omitempty"`
PipelineURL string `json:"pipelineUrl"`
PipelineProtoHash string `json:"pipelineProtoHash,omitempty"`
Region string `json:"region"`
TempLocation string `json:"tempLocation"`
DiskProvisionedIops int64 `json:"diskProvisionedIops"`
Expand Down
51 changes: 50 additions & 1 deletion sdks/go/pkg/beam/runners/dataflow/dataflowlib/job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ import (
"reflect"
"testing"

"encoding/json"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/protox"
pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
Expand Down Expand Up @@ -293,7 +296,7 @@ func TestTranslate(t *testing.T) {
workerURL := "gs://any-location/temp"
modelURL := "gs://any-location/temp"

job, err := Translate(ctx, p, opts, workerURL, modelURL)
job, err := Translate(ctx, p, opts, workerURL, modelURL, "dummy-hash-12345")
if err != nil {
t.Fatalf("Translate(...) error = %v, want nil", err)
}
Expand All @@ -310,3 +313,49 @@ func TestTranslate(t *testing.T) {
t.Errorf("DiskProvisionedThroughputMibps = %v, want 200", wp.DiskProvisionedThroughputMibps)
}
}

// TestTranslateWithPipelineHash verifies that the hash string handed to
// Translate is carried through into the job's serialized SDK pipeline
// options under the "pipelineProtoHash" key.
func TestTranslateWithPipelineHash(t *testing.T) {
	const wantHash = "dummy-hash-12345"

	pipeline := &pipepb.Pipeline{
		Components: &pipepb.Components{
			Environments: map[string]*pipepb.Environment{
				"env1": {
					Payload: protox.MustEncode(&pipepb.DockerPayload{
						ContainerImage: "dummy_image",
					}),
				},
			},
		},
	}
	jobOpts := &JobOptions{
		Name:    "test-job",
		Project: "test-project",
		Region:  "test-region",
		Options: runtime.RawOptions{
			Options: map[string]string{},
		},
	}

	job, err := Translate(context.Background(), pipeline, jobOpts, "worker-url", "model-url", wantHash)
	if err != nil {
		t.Fatalf("Translate failed: %v", err)
	}

	// Decode only the fields of interest from the serialized options blob.
	var got struct {
		Options struct {
			PipelineURL       string `json:"pipelineUrl"`
			PipelineProtoHash string `json:"pipelineProtoHash"`
		} `json:"options"`
	}
	if err := json.Unmarshal(job.Environment.SdkPipelineOptions, &got); err != nil {
		t.Fatalf("Failed to unmarshal SdkPipelineOptions: %v", err)
	}

	if got.Options.PipelineProtoHash != wantHash {
		t.Errorf("Expected PipelineProtoHash %v, got %v", wantHash, got.Options.PipelineProtoHash)
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ def __init__(
for k, v in sdk_pipeline_options.items() if v is not None
}
options_dict["pipelineUrl"] = proto_pipeline_staged_url
if self._proto_pipeline:
serialized_pipeline = self._proto_pipeline.SerializeToString()
options_dict["pipelineProtoHash"] = hashlib.sha256(
serialized_pipeline).hexdigest()
Comment thread
tarun-google marked this conversation as resolved.
Outdated
# Don't pass impersonate_service_account through to the harness.
# Though impersonation should start a job, the workers should
# not try to modify their credentials.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,39 @@ def test_pipeline_url(self):

self.assertEqual(pipeline_url.string_value, FAKE_PIPELINE_URL)

def test_pipeline_proto_hash(self):
  """Checks that the environment records the SHA256 hex digest of the
  serialized pipeline proto under the 'pipelineProtoHash' option."""
  # Import hoisted to the top of the test instead of sitting between the
  # assertions; builtin-shadowing loop variable 'property' renamed.
  import hashlib

  pipeline_options = PipelineOptions(
      ['--temp_location', 'gs://any-location/temp'])
  proto_pipeline = beam_runner_api_pb2.Pipeline()
  proto_pipeline.components.transforms['dummy'].unique_name = 'dummy'

  env = apiclient.Environment([],
                              pipeline_options,
                              '2.0.0',
                              FAKE_PIPELINE_URL,
                              proto_pipeline)

  # Locate the serialized 'options' struct inside the proto.
  recovered_options = None
  for additional_property in env.proto.sdkPipelineOptions.additionalProperties:
    if additional_property.key == 'options':
      recovered_options = additional_property.value
      break
  else:
    self.fail('No pipeline options found')

  # Find the pipelineProtoHash entry within the recovered options.
  pipeline_proto_hash = None
  for prop in recovered_options.object_value.properties:
    if prop.key == 'pipelineProtoHash':
      pipeline_proto_hash = prop.value
      break
  else:
    self.fail('No pipelineProtoHash found')

  # The recorded hash must match the digest of the serialized proto.
  expected_hash = hashlib.sha256(
      proto_pipeline.SerializeToString()).hexdigest()
  self.assertEqual(pipeline_proto_hash.string_value, expected_hash)

def test_set_network(self):
pipeline_options = PipelineOptions([
'--network',
Expand Down
Loading