21
21
import logging
22
22
import re
23
23
import warnings
24
- from typing import List , Optional
24
+ from typing import Dict , List , Optional
25
25
26
26
from airflow .exceptions import AirflowException
27
27
from airflow .models import BaseOperator , BaseOperatorLink
@@ -151,6 +151,8 @@ class MLEngineStartBatchPredictionJobOperator(BaseOperator):
151
151
For this to work, the service account making the request must
152
152
have domain-wide delegation enabled.
153
153
:type delegate_to: str
154
+ :param labels: a dictionary containing labels for the job; passed to BigQuery
155
+ :type labels: Dict[str, str]
154
156
:raises: ``ValueError``: if a unique model/version origin cannot be
155
157
determined.
156
158
"""
@@ -183,6 +185,7 @@ def __init__(self, # pylint: disable=too-many-arguments
183
185
project_id : Optional [str ] = None ,
184
186
gcp_conn_id : str = 'google_cloud_default' ,
185
187
delegate_to : Optional [str ] = None ,
188
+ labels : Optional [Dict [str , str ]] = None ,
186
189
** kwargs ) -> None :
187
190
super ().__init__ (** kwargs )
188
191
@@ -200,6 +203,7 @@ def __init__(self, # pylint: disable=too-many-arguments
200
203
self ._signature_name = signature_name
201
204
self ._gcp_conn_id = gcp_conn_id
202
205
self ._delegate_to = delegate_to
206
+ self ._labels = labels
203
207
204
208
if not self ._project_id :
205
209
raise AirflowException ('Google Cloud project id is required.' )
@@ -234,6 +238,8 @@ def execute(self, context):
234
238
'region' : self ._region
235
239
}
236
240
}
241
+ if self ._labels :
242
+ prediction_request ['labels' ] = self ._labels
237
243
238
244
if self ._uri :
239
245
prediction_request ['predictionInput' ]['uri' ] = self ._uri
@@ -953,6 +959,8 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
953
959
will be printed out. In 'CLOUD' mode, a real MLEngine training job
954
960
creation request will be issued.
955
961
:type mode: str
962
+ :param labels: a dictionary containing labels for the job; passed to BigQuery
963
+ :type labels: Dict[str, str]
956
964
"""
957
965
958
966
template_fields = [
@@ -990,6 +998,7 @@ def __init__(self, # pylint: disable=too-many-arguments
990
998
gcp_conn_id : str = 'google_cloud_default' ,
991
999
delegate_to : Optional [str ] = None ,
992
1000
mode : str = 'PRODUCTION' ,
1001
+ labels : Optional [Dict [str , str ]] = None ,
993
1002
** kwargs ) -> None :
994
1003
super ().__init__ (** kwargs )
995
1004
self ._project_id = project_id
@@ -1006,6 +1015,7 @@ def __init__(self, # pylint: disable=too-many-arguments
1006
1015
self ._gcp_conn_id = gcp_conn_id
1007
1016
self ._delegate_to = delegate_to
1008
1017
self ._mode = mode
1018
+ self ._labels = labels
1009
1019
1010
1020
if not self ._project_id :
1011
1021
raise AirflowException ('Google Cloud project id is required.' )
@@ -1039,6 +1049,8 @@ def execute(self, context):
1039
1049
'args' : self ._training_args ,
1040
1050
}
1041
1051
}
1052
+ if self ._labels :
1053
+ training_request ['labels' ] = self ._labels
1042
1054
1043
1055
if self ._runtime_version :
1044
1056
training_request ['trainingInput' ]['runtimeVersion' ] = self ._runtime_version
0 commit comments