Skip to content

Commit c295338

Browse files
authored
Add labels param to Google MLEngine Operators (#10222)
1 parent 8a655cf commit c295338

File tree

3 files changed

+21
-1
lines changed

3 files changed

+21
-1
lines changed

airflow/providers/google/cloud/example_dags/example_mlengine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
package_uris=[TRAINER_URI],
7474
training_python_module=TRAINER_PY_MODULE,
7575
training_args=[],
76+
labels={"job_type": "training"},
7677
)
7778
# [END howto_operator_gcp_mlengine_training]
7879

@@ -169,6 +170,7 @@
169170
data_format="TEXT",
170171
input_paths=[PREDICTION_INPUT],
171172
output_path=PREDICTION_OUTPUT,
173+
labels={"job_type": "prediction"},
172174
)
173175
# [END howto_operator_gcp_mlengine_get_prediction]
174176

airflow/providers/google/cloud/operators/mlengine.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import logging
2222
import re
2323
import warnings
24-
from typing import List, Optional
24+
from typing import Dict, List, Optional
2525

2626
from airflow.exceptions import AirflowException
2727
from airflow.models import BaseOperator, BaseOperatorLink
@@ -151,6 +151,8 @@ class MLEngineStartBatchPredictionJobOperator(BaseOperator):
151151
For this to work, the service account making the request must
152152
have domain-wide delegation enabled.
153153
:type delegate_to: str
154+
:param labels: a dictionary containing labels for the job; passed to MLEngine
155+
:type labels: Dict[str, str]
154156
:raises: ``ValueError``: if a unique model/version origin cannot be
155157
determined.
156158
"""
@@ -183,6 +185,7 @@ def __init__(self, # pylint: disable=too-many-arguments
183185
project_id: Optional[str] = None,
184186
gcp_conn_id: str = 'google_cloud_default',
185187
delegate_to: Optional[str] = None,
188+
labels: Optional[Dict[str, str]] = None,
186189
**kwargs) -> None:
187190
super().__init__(**kwargs)
188191

@@ -200,6 +203,7 @@ def __init__(self, # pylint: disable=too-many-arguments
200203
self._signature_name = signature_name
201204
self._gcp_conn_id = gcp_conn_id
202205
self._delegate_to = delegate_to
206+
self._labels = labels
203207

204208
if not self._project_id:
205209
raise AirflowException('Google Cloud project id is required.')
@@ -234,6 +238,8 @@ def execute(self, context):
234238
'region': self._region
235239
}
236240
}
241+
if self._labels:
242+
prediction_request['labels'] = self._labels
237243

238244
if self._uri:
239245
prediction_request['predictionInput']['uri'] = self._uri
@@ -953,6 +959,8 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
953959
will be printed out. In 'CLOUD' mode, a real MLEngine training job
954960
creation request will be issued.
955961
:type mode: str
962+
:param labels: a dictionary containing labels for the job; passed to MLEngine
963+
:type labels: Dict[str, str]
956964
"""
957965

958966
template_fields = [
@@ -990,6 +998,7 @@ def __init__(self, # pylint: disable=too-many-arguments
990998
gcp_conn_id: str = 'google_cloud_default',
991999
delegate_to: Optional[str] = None,
9921000
mode: str = 'PRODUCTION',
1001+
labels: Optional[Dict[str, str]] = None,
9931002
**kwargs) -> None:
9941003
super().__init__(**kwargs)
9951004
self._project_id = project_id
@@ -1006,6 +1015,7 @@ def __init__(self, # pylint: disable=too-many-arguments
10061015
self._gcp_conn_id = gcp_conn_id
10071016
self._delegate_to = delegate_to
10081017
self._mode = mode
1018+
self._labels = labels
10091019

10101020
if not self._project_id:
10111021
raise AirflowException('Google Cloud project id is required.')
@@ -1039,6 +1049,8 @@ def execute(self, context):
10391049
'args': self._training_args,
10401050
}
10411051
}
1052+
if self._labels:
1053+
training_request['labels'] = self._labels
10421054

10431055
if self._runtime_version:
10441056
training_request['trainingInput']['runtimeVersion'] = self._runtime_version

tests/providers/google/cloud/operators/test_mlengine.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class TestMLEngineBatchPredictionOperator(unittest.TestCase):
6363
}
6464
SUCCESS_MESSAGE_MISSING_INPUT = {
6565
'jobId': 'test_prediction',
66+
'labels': {'some': 'labels'},
6667
'predictionOutput': {
6768
'outputPath': 'gs://fake-output-path',
6869
'predictionCount': 5000,
@@ -74,6 +75,7 @@ class TestMLEngineBatchPredictionOperator(unittest.TestCase):
7475
BATCH_PREDICTION_DEFAULT_ARGS = {
7576
'project_id': 'test-project',
7677
'job_id': 'test_prediction',
78+
'labels': {'some': 'labels'},
7779
'region': 'us-east1',
7880
'data_format': 'TEXT',
7981
'input_paths': ['gs://legal-bucket-dash-Capital/legal-input-path/*'],
@@ -116,6 +118,7 @@ def test_success_with_model(self, mock_hook):
116118
input_paths=input_with_model['inputPaths'],
117119
output_path=input_with_model['outputPath'],
118120
model_name=input_with_model['modelName'].split('/')[-1],
121+
labels={'some': 'labels'},
119122
dag=self.dag,
120123
task_id='test-prediction')
121124
prediction_output = prediction_task.execute(None)
@@ -125,6 +128,7 @@ def test_success_with_model(self, mock_hook):
125128
project_id='test-project',
126129
job={
127130
'jobId': 'test_prediction',
131+
'labels': {'some': 'labels'},
128132
'predictionInput': input_with_model
129133
},
130134
use_existing_job_fn=ANY
@@ -308,11 +312,13 @@ class TestMLEngineTrainingOperator(unittest.TestCase):
308312
'training_args': '--some_arg=\'aaa\'',
309313
'region': 'us-east1',
310314
'scale_tier': 'STANDARD_1',
315+
'labels': {'some': 'labels'},
311316
'task_id': 'test-training',
312317
'start_date': days_ago(1)
313318
}
314319
TRAINING_INPUT = {
315320
'jobId': 'test_training',
321+
'labels': {'some': 'labels'},
316322
'trainingInput': {
317323
'scaleTier': 'STANDARD_1',
318324
'packageUris': ['gs://some-bucket/package1'],

0 commit comments

Comments
 (0)