Skip to content

Commit 25ee421

Browse files
authored
Support all RuntimeEnvironment parameters in DataflowTemplatedJobStartOperator (#8531)
1 parent 520aeed commit 25ee421

File tree

4 files changed

+112
-120
lines changed

4 files changed

+112
-120
lines changed

airflow/providers/google/cloud/hooks/dataflow.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,13 @@ def start_template_dataflow(
532532
533533
:param job_name: The name of the job.
534534
:type job_name: str
535-
:param variables: Variables passed to the job.
535+
:param variables: Map of job runtime environment options.
536+
537+
.. seealso::
538+
For more information on possible configurations, look at the API documentation
539+
`RuntimeEnvironment
540+
<https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
541+
536542
:type variables: dict
537543
:param parameters: Parameters for the template
538544
:type parameters: dict
@@ -548,23 +554,17 @@ def start_template_dataflow(
548554
:type location: str
549555
"""
550556
name = self._build_dataflow_job_name(job_name, append_job_name)
551-
# Builds RuntimeEnvironment from variables dictionary
552-
# https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
553-
environment = {}
554-
for key in ['numWorkers', 'maxWorkers', 'zone', 'serviceAccountEmail',
555-
'tempLocation', 'bypassTempDirValidation', 'machineType',
556-
'additionalExperiments', 'network', 'subnetwork', 'additionalUserLabels']:
557-
if key in variables:
558-
environment.update({key: variables[key]})
559-
body = {"jobName": name,
560-
"parameters": parameters,
561-
"environment": environment}
557+
562558
service = self.get_conn()
563559
request = service.projects().locations().templates().launch( # pylint: disable=no-member
564560
projectId=project_id,
565561
location=location,
566562
gcsPath=dataflow_template,
567-
body=body
563+
body={
564+
"jobName": name,
565+
"parameters": parameters,
566+
"environment": variables
567+
}
568568
)
569569
response = request.execute(num_retries=self.num_retries)
570570

airflow/providers/google/cloud/operators/dataflow.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import re
2424
from contextlib import ExitStack
2525
from enum import Enum
26-
from typing import List, Optional
26+
from typing import Any, Dict, List, Optional
2727

2828
from airflow.models import BaseOperator
2929
from airflow.providers.google.cloud.hooks.dataflow import DEFAULT_DATAFLOW_LOCATION, DataflowHook
@@ -277,6 +277,14 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
277277
:type template: str
278278
:param job_name: The 'jobName' to use when executing the DataFlow template
279279
(templated).
280+
:param options: Map of job runtime environment options.
281+
282+
.. seealso::
283+
For more information on possible configurations, look at the API documentation
284+
`RuntimeEnvironment
285+
<https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
286+
287+
:type options: dict
280288
:param dataflow_default_options: Map of default job environment options.
281289
:type dataflow_default_options: dict
282290
:param parameters: Map of job specific parameters for the template.
@@ -344,16 +352,25 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
344352
For more detail on job template execution have a look at the reference:
345353
https://cloud.google.com/dataflow/docs/templates/executing-templates
346354
"""
347-
template_fields = ['parameters', 'dataflow_default_options', 'template', 'job_name']
355+
template_fields = [
356+
'template',
357+
'job_name',
358+
'options',
359+
'parameters',
360+
'project_id',
361+
'location',
362+
'gcp_conn_id'
363+
]
348364
ui_color = '#0273d4'
349365

350366
@apply_defaults
351-
def __init__(
367+
def __init__( # pylint: disable=too-many-arguments
352368
self,
353369
template: str,
354370
job_name: str = '{{task.task_id}}',
355-
dataflow_default_options: Optional[dict] = None,
356-
parameters: Optional[dict] = None,
371+
options: Optional[Dict[str, Any]] = None,
372+
dataflow_default_options: Optional[Dict[str, Any]] = None,
373+
parameters: Optional[Dict[str, str]] = None,
357374
project_id: Optional[str] = None,
358375
location: str = DEFAULT_DATAFLOW_LOCATION,
359376
gcp_conn_id: str = 'google_cloud_default',
@@ -362,14 +379,11 @@ def __init__(
362379
*args,
363380
**kwargs) -> None:
364381
super().__init__(*args, **kwargs)
365-
366-
dataflow_default_options = dataflow_default_options or {}
367-
parameters = parameters or {}
368-
369382
self.template = template
370383
self.job_name = job_name
371-
self.dataflow_default_options = dataflow_default_options
372-
self.parameters = parameters
384+
self.options = options or {}
385+
self.dataflow_default_options = dataflow_default_options or {}
386+
self.parameters = parameters or {}
373387
self.project_id = project_id
374388
self.location = location
375389
self.gcp_conn_id = gcp_conn_id
@@ -387,10 +401,12 @@ def execute(self, context):
387401

388402
def set_current_job_id(job_id):
389403
self.job_id = job_id
404+
options = self.dataflow_default_options
405+
options.update(self.options)
390406

391407
job = self.hook.start_template_dataflow(
392408
job_name=self.job_name,
393-
variables=self.dataflow_default_options,
409+
variables=options,
394410
parameters=self.parameters,
395411
dataflow_template=self.template,
396412
on_new_job_id_callback=set_current_job_id,

0 commit comments

Comments
 (0)