Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,9 @@ public DataflowPipelineJob run(Pipeline pipeline) {
options.getStager().stageToFile(serializedProtoPipeline, PIPELINE_FILE_NAME);
dataflowOptions.setPipelineUrl(stagedPipeline.getLocation());

String pipelineProtoHash = Hashing.sha256().hashBytes(serializedProtoPipeline).toString();
dataflowOptions.setPipelineProtoHash(pipelineProtoHash);

if (useUnifiedWorker(options)) {
LOG.info("Skipping v1 transform replacements since job will run on v2.");
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ public interface DataflowPipelineOptions

void setPipelineUrl(String urlString);

/** The hex-encoded SHA256 hash of the staged portable pipeline proto. */
@Description("The hex-encoded SHA256 hash of the staged portable pipeline proto")
String getPipelineProtoHash();

/** Sets the hex-encoded SHA256 hash of the staged portable pipeline proto. */
void setPipelineProtoHash(String hash);

@Description("The customized dataflow worker jar")
String getDataflowWorkerJar();

Expand Down
8 changes: 7 additions & 1 deletion sdks/go/pkg/beam/runners/dataflow/dataflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ package dataflow

import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"flag"
"fmt"
Expand All @@ -40,6 +42,7 @@ import (
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/pipelinex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/hooks"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/protox"
"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
"github.com/apache/beam/sdks/v2/go/pkg/beam/log"
"github.com/apache/beam/sdks/v2/go/pkg/beam/options/gcpopts"
Expand Down Expand Up @@ -239,7 +242,10 @@ func Execute(ctx context.Context, p *beam.Pipeline) (beam.PipelineResult, error)
log.Info(ctx, "Dry-run: not submitting job!")

log.Info(ctx, model.String())
job, err := dataflowlib.Translate(ctx, model, opts, workerURL, modelURL)
modelBytes := protox.MustEncode(model)
hash := sha256.Sum256(modelBytes)
pipelineProtoHash := hex.EncodeToString(hash[:])
job, err := dataflowlib.Translate(ctx, model, opts, workerURL, modelURL, pipelineProtoHash)
if err != nil {
return nil, err
}
Expand Down
9 changes: 7 additions & 2 deletions sdks/go/pkg/beam/runners/dataflow/dataflowlib/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package dataflowlib

import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"os"
"strings"
Expand Down Expand Up @@ -83,14 +85,17 @@ func Execute(ctx context.Context, raw *pipepb.Pipeline, opts *JobOptions, worker
// (2) Upload model to GCS
log.Info(ctx, raw.String())

if err := StageModel(ctx, opts.Project, modelURL, protox.MustEncode(raw)); err != nil {
modelBytes := protox.MustEncode(raw)
modelHash := sha256.Sum256(modelBytes)
pipelineProtoHash := hex.EncodeToString(modelHash[:])
if err := StageModel(ctx, opts.Project, modelURL, modelBytes); err != nil {
return presult, err
}
log.Infof(ctx, "Staged model pipeline: %v", modelURL)

// (3) Translate to v1b3 and submit

job, err := Translate(ctx, raw, opts, workerURL, modelURL)
job, err := Translate(ctx, raw, opts, workerURL, modelURL, pipelineProtoHash)
if err != nil {
return presult, err
}
Expand Down
12 changes: 7 additions & 5 deletions sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func containerImages(p *pipepb.Pipeline) ([]*df.SdkHarnessContainerImage, []stri
}

// Translate translates a pipeline to a Dataflow job.
func Translate(ctx context.Context, p *pipepb.Pipeline, opts *JobOptions, workerURL, modelURL string) (*df.Job, error) {
func Translate(ctx context.Context, p *pipepb.Pipeline, opts *JobOptions, workerURL, modelURL string, pipelineProtoHash string) (*df.Job, error) {
// (1) Translate pipeline to v1b3 speak.

jobType := "JOB_TYPE_BATCH"
Expand Down Expand Up @@ -181,10 +181,11 @@ func Translate(ctx context.Context, p *pipepb.Pipeline, opts *JobOptions, worker
SdkPipelineOptions: newMsg(pipelineOptions{
DisplayData: printOptions(opts, images),
Options: dataflowOptions{
PipelineURL: modelURL,
Region: opts.Region,
Experiments: opts.Experiments,
TempLocation: opts.TempLocation,
PipelineURL: modelURL,
PipelineProtoHash: pipelineProtoHash,
Region: opts.Region,
Experiments: opts.Experiments,
TempLocation: opts.TempLocation,
},
GoOptions: opts.Options,
}),
Expand Down Expand Up @@ -359,6 +360,7 @@ func GetMetrics(ctx context.Context, client *df.Service, project, region, jobID
type dataflowOptions struct {
Experiments []string `json:"experiments,omitempty"`
PipelineURL string `json:"pipelineUrl"`
PipelineProtoHash string `json:"pipelineProtoHash,omitempty"`
Region string `json:"region"`
TempLocation string `json:"tempLocation"`
DiskProvisionedIops int64 `json:"diskProvisionedIops"`
Expand Down
51 changes: 50 additions & 1 deletion sdks/go/pkg/beam/runners/dataflow/dataflowlib/job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ import (
"reflect"
"testing"

"encoding/json"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/protox"
pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
Expand Down Expand Up @@ -293,7 +296,7 @@ func TestTranslate(t *testing.T) {
workerURL := "gs://any-location/temp"
modelURL := "gs://any-location/temp"

job, err := Translate(ctx, p, opts, workerURL, modelURL)
job, err := Translate(ctx, p, opts, workerURL, modelURL, "dummy-hash-12345")
if err != nil {
t.Fatalf("Translate(...) error = %v, want nil", err)
}
Expand All @@ -310,3 +313,49 @@ func TestTranslate(t *testing.T) {
t.Errorf("DiskProvisionedThroughputMibps = %v, want 200", wp.DiskProvisionedThroughputMibps)
}
}

// TestTranslateWithPipelineHash checks that the pipeline proto hash passed to
// Translate is propagated into the job's serialized SDK pipeline options.
func TestTranslateWithPipelineHash(t *testing.T) {
	pipeline := &pipepb.Pipeline{
		Components: &pipepb.Components{
			Environments: map[string]*pipepb.Environment{
				"env1": {
					Payload: protox.MustEncode(&pipepb.DockerPayload{
						ContainerImage: "dummy_image",
					}),
				},
			},
		},
	}
	jobOpts := &JobOptions{
		Name:    "test-job",
		Project: "test-project",
		Region:  "test-region",
		Options: runtime.RawOptions{
			Options: make(map[string]string),
		},
	}
	wantHash := "dummy-hash-12345"

	job, err := Translate(context.Background(), pipeline, jobOpts, "worker-url", "model-url", wantHash)
	if err != nil {
		t.Fatalf("Translate failed: %v", err)
	}

	// Decode only the fields under inspection from the serialized options.
	var decoded struct {
		Options struct {
			PipelineURL       string `json:"pipelineUrl"`
			PipelineProtoHash string `json:"pipelineProtoHash"`
		} `json:"options"`
	}
	if err := json.Unmarshal(job.Environment.SdkPipelineOptions, &decoded); err != nil {
		t.Fatalf("Failed to unmarshal SdkPipelineOptions: %v", err)
	}

	if got := decoded.Options.PipelineProtoHash; got != wantHash {
		t.Errorf("Expected PipelineProtoHash %v, got %v", wantHash, got)
	}
}
13 changes: 10 additions & 3 deletions sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def __init__(
options,
environment_version,
proto_pipeline_staged_url,
proto_pipeline=None):
proto_pipeline=None,
pipeline_proto_hash=None):
self.standard_options = options.view_as(StandardOptions)
self.google_cloud_options = options.view_as(GoogleCloudOptions)
self.worker_options = options.view_as(WorkerOptions)
Expand Down Expand Up @@ -279,6 +280,8 @@ def __init__(
for k, v in sdk_pipeline_options.items() if v is not None
}
options_dict["pipelineUrl"] = proto_pipeline_staged_url
if pipeline_proto_hash:
options_dict["pipelineProtoHash"] = pipeline_proto_hash
# Don't pass impersonate_service_account through to the harness.
# Though impersonation should start a job, the workers should
# not try to modify their credentials.
Expand Down Expand Up @@ -831,10 +834,13 @@ def create_job_description(self, job):
resources = self._stage_resources(job.proto_pipeline, job.options)

# Stage proto pipeline.
serialized_pipeline = job.proto_pipeline.SerializeToString()
pipeline_proto_hash = hashlib.sha256(serialized_pipeline).hexdigest()

self.stage_file_with_retry(
job.google_cloud_options.staging_location,
shared_names.STAGED_PIPELINE_FILENAME,
io.BytesIO(job.proto_pipeline.SerializeToString()))
io.BytesIO(serialized_pipeline))

job.proto.environment = Environment(
proto_pipeline_staged_url=FileSystems.join(
Expand All @@ -843,7 +849,8 @@ def create_job_description(self, job):
packages=resources,
options=job.options,
environment_version=self.environment_version,
proto_pipeline=job.proto_pipeline).proto
proto_pipeline=job.proto_pipeline,
pipeline_proto_hash=pipeline_proto_hash).proto
_LOGGER.debug('JOB: %s', job)

@retry.with_exponential_backoff(num_retries=3, initial_delay_secs=3)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

# pytype: skip-file

import hashlib
import io
import itertools
import json
Expand Down Expand Up @@ -97,6 +98,40 @@ def test_pipeline_url(self):

self.assertEqual(pipeline_url.string_value, FAKE_PIPELINE_URL)

def test_pipeline_proto_hash(self):
  """Verifies that a pipeline_proto_hash passed to Environment surfaces as
  the 'pipelineProtoHash' entry in the serialized pipeline options."""
  pipeline_options = PipelineOptions(
      ['--temp_location', 'gs://any-location/temp'])
  proto_pipeline = beam_runner_api_pb2.Pipeline()
  proto_pipeline.components.transforms['dummy'].unique_name = 'dummy'

  # The hash the runner would compute over the staged proto bytes.
  expected_hash = hashlib.sha256(
      proto_pipeline.SerializeToString()).hexdigest()

  env = apiclient.Environment([],
                              pipeline_options,
                              '2.0.0',
                              FAKE_PIPELINE_URL,
                              proto_pipeline,
                              pipeline_proto_hash=expected_hash)

  # Locate the nested 'options' struct inside sdkPipelineOptions.
  recovered_options = None
  for additional_property in env.proto.sdkPipelineOptions.additionalProperties:
    if additional_property.key == 'options':
      recovered_options = additional_property.value
      break
  else:
    self.fail('No pipeline options found')

  # Find the pipelineProtoHash property; 'prop' avoids shadowing the
  # builtin 'property'.
  pipeline_proto_hash = None
  for prop in recovered_options.object_value.properties:
    if prop.key == 'pipelineProtoHash':
      pipeline_proto_hash = prop.value
      break
  else:
    self.fail('No pipelineProtoHash found')

  self.assertEqual(pipeline_proto_hash.string_value, expected_hash)

def test_set_network(self):
pipeline_options = PipelineOptions([
'--network',
Expand Down
Loading