Skip to content

Commit 7aa0f47

Browse files
astro-sql-decorator authored and ashb committed
[AIRFLOW-5413] Allow K8S worker pod to be configured from JSON/YAML file (#6230)
* [AIRFLOW-5413] enable pod config from file * Update airflow/kubernetes/pod_generator.py Co-Authored-By: Ash Berlin-Taylor <[email protected]> * Update airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py Co-Authored-By: Ash Berlin-Taylor <[email protected]> Co-authored-by: Ash Berlin-Taylor <[email protected]> (cherry picked from commit 967930c)
1 parent bde12c8 commit 7aa0f47

File tree

14 files changed

+304
-85
lines changed

14 files changed

+304
-85
lines changed

airflow/config_templates/config.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1736,6 +1736,13 @@
17361736
type: string
17371737
example: ~
17381738
default: ""
1739+
- name: pod_template_file
1740+
description: |
1741+
Path to the YAML pod file. If set, all other kubernetes-related fields are ignored.
1742+
version_added: ~
1743+
type: string
1744+
example: ~
1745+
default: ""
17391746
- name: worker_container_tag
17401747
description: ~
17411748
version_added: ~

airflow/config_templates/default_airflow.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,9 @@ verify_certs = True
801801
[kubernetes]
802802
# The repository, tag and imagePullPolicy of the Kubernetes Image for the Worker to Run
803803
worker_container_repository =
804+
805+
# Path to the YAML pod file. If set, all other kubernetes-related fields are ignored.
806+
pod_template_file =
804807
worker_container_tag =
805808
worker_container_image_pull_policy = IfNotPresent
806809

airflow/contrib/operators/kubernetes_pod_operator.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,18 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-
130130
:type schedulername: str
131131
:param full_pod_spec: The complete podSpec
132132
:type full_pod_spec: kubernetes.client.models.V1Pod
133+
:param init_containers: init container for the launched Pod
134+
:type init_containers: list[kubernetes.client.models.V1Container]
135+
:param log_events_on_failure: Log the pod's events if a failure occurs
136+
:type log_events_on_failure: bool
137+
:param do_xcom_push: If True, the content of the file
138+
/airflow/xcom/return.json in the container will also be pushed to an
139+
XCom when the container completes.
140+
:type do_xcom_push: bool
141+
:param pod_template_file: path to pod template file
142+
:type pod_template_file: str
133143
"""
134-
template_fields = ('cmds', 'arguments', 'env_vars', 'config_file')
144+
template_fields = ('cmds', 'arguments', 'env_vars', 'config_file', 'pod_template_file')
135145

136146
@apply_defaults
137147
def __init__(self, # pylint: disable=too-many-arguments,too-many-locals
@@ -215,8 +225,8 @@ def __init__(self, # pylint: disable=too-many-arguments,too-many-locals
215225
self.full_pod_spec = full_pod_spec
216226
self.init_containers = init_containers or []
217227
self.log_events_on_failure = log_events_on_failure
218-
self.priority_class_name = priority_class_name
219228
self.pod_template_file = pod_template_file
229+
self.priority_class_name = priority_class_name
220230
self.name = self._set_name(name)
221231

222232
@staticmethod
@@ -348,6 +358,7 @@ def create_new_pod_for_operator(self, labels, launcher):
348358
init_containers=self.init_containers,
349359
restart_policy='Never',
350360
schedulername=self.schedulername,
361+
pod_template_file=self.pod_template_file,
351362
priority_class_name=self.priority_class_name,
352363
pod=self.full_pod_spec,
353364
).gen_pod()

airflow/example_dags/example_kubernetes_executor_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,14 @@ def test_volume_mount():
8383
}
8484
)
8585

86-
# Test that we can run tasks as a normal user
86+
# Test that we can add labels to pods
8787
third_task = PythonOperator(
8888
task_id="non_root_task",
8989
python_callable=print_stuff,
9090
executor_config={
9191
"KubernetesExecutor": {
92-
"securityContext": {
93-
"runAsUser": 1000
92+
"labels": {
93+
"release": "stable"
9494
}
9595
}
9696
}

airflow/executors/kubernetes_executor.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import multiprocessing
2121
import time
2222
from queue import Empty
23-
from uuid import uuid4
2423

2524
import kubernetes
2625
from dateutil import parser
@@ -72,6 +71,9 @@ def __init__(self):
7271
)
7372
self.kube_node_selectors = configuration_dict.get('kubernetes_node_selectors', {})
7473
self.kube_annotations = configuration_dict.get('kubernetes_annotations', {}) or None
74+
self.pod_template_file = conf.get(self.kubernetes_section, 'pod_template_file',
75+
fallback=None)
76+
7577
self.kube_labels = configuration_dict.get('kubernetes_labels', {})
7678
self.delete_worker_pods = conf.getboolean(
7779
self.kubernetes_section, 'delete_worker_pods')
@@ -220,6 +222,8 @@ def _get_security_context_val(self, scontext):
220222
return int(val)
221223

222224
def _validate(self):
225+
if self.pod_template_file:
226+
return
223227
# TODO: use XOR for dags_volume_claim and git_dags_folder_mount_point
224228
if not self.dags_volume_claim \
225229
and not self.dags_volume_host \
@@ -498,10 +502,7 @@ def _create_pod_id(dag_id, task_id):
498502
dag_id)
499503
safe_task_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
500504
task_id)
501-
safe_uuid = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
502-
uuid4().hex)
503-
return AirflowKubernetesScheduler._make_safe_pod_id(safe_dag_id, safe_task_id,
504-
safe_uuid)
505+
return safe_dag_id + safe_task_id
505506

506507
@staticmethod
507508
def _label_safe_datestring_to_datetime(string):

airflow/kubernetes/pod_generator.py

Lines changed: 99 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,35 @@
2424
import copy
2525
import hashlib
2626
import re
27+
try:
28+
from inspect import signature
29+
except ImportError:
30+
# Python 2.7
31+
from funcsigs import signature # type: ignore
32+
import os
2733
import uuid
34+
from functools import reduce
2835

2936
import kubernetes.client.models as k8s
37+
import yaml
38+
from kubernetes.client.api_client import ApiClient
3039

40+
from airflow.exceptions import AirflowConfigException
3141
from airflow.version import version as airflow_version
3242

3343
MAX_LABEL_LEN = 63
3444

3545
MAX_POD_ID_LEN = 253
3646

3747

38-
class PodDefaults:
48+
class PodDefaults(object):
3949
"""
4050
Static defaults for the PodGenerator
4151
"""
52+
53+
def __init__(self):
54+
pass
55+
4256
XCOM_MOUNT_PATH = '/airflow/xcom'
4357
SIDECAR_CONTAINER_NAME = 'airflow-xcom-sidecar'
4458
XCOM_CMD = 'trap "exit 0" INT; while true; do sleep 30; done;'
@@ -82,7 +96,7 @@ def make_safe_label_value(string):
8296
return safe_label
8397

8498

85-
class PodGenerator:
99+
class PodGenerator(object):
86100
"""
87101
Contains Kubernetes Airflow Worker configuration logic
88102
@@ -147,9 +161,11 @@ class PodGenerator:
147161
:param dnspolicy: Specify a dnspolicy for the pod
148162
:type dnspolicy: str
149163
:param schedulername: Specify a schedulername for the pod
150-
:type schedulername: str
151-
:param pod: The fully specified pod.
152-
:type pod: kubernetes.client.models.V1Pod
164+
:type schedulername: Optional[str]
165+
:param pod: The fully specified pod. Mutually exclusive with `pod_template_file`
166+
:type pod: Optional[kubernetes.client.models.V1Pod]
167+
:param pod_template_file: Path to YAML file. Mutually exclusive with `pod`
168+
:type pod_template_file: Optional[str]
153169
:param extract_xcom: Whether to bring up a container for xcom
154170
:type extract_xcom: bool
155171
"""
@@ -167,8 +183,8 @@ def __init__(
167183
node_selectors=None,
168184
ports=None,
169185
volumes=None,
170-
image_pull_policy='IfNotPresent',
171-
restart_policy='Never',
186+
image_pull_policy=None,
187+
restart_policy=None,
172188
image_pull_secrets=None,
173189
init_containers=None,
174190
service_account_name=None,
@@ -183,9 +199,16 @@ def __init__(
183199
schedulername=None,
184200
priority_class_name=None,
185201
pod=None,
202+
pod_template_file=None,
186203
extract_xcom=False,
187204
):
188-
self.ud_pod = pod
205+
self.validate_pod_generator_args(locals())
206+
207+
if pod_template_file:
208+
self.ud_pod = self.deserialize_model_file(pod_template_file)
209+
else:
210+
self.ud_pod = pod
211+
189212
self.pod = k8s.V1Pod()
190213
self.pod.api_version = 'v1'
191214
self.pod.kind = 'Pod'
@@ -348,37 +371,7 @@ def from_obj(obj):
348371
'iam.cloud.google.com/service-account': gcp_service_account_key
349372
})
350373

351-
pod_spec_generator = PodGenerator(
352-
image=namespaced.get('image'),
353-
envs=namespaced.get('env'),
354-
cmds=namespaced.get('cmds'),
355-
args=namespaced.get('args'),
356-
labels=namespaced.get('labels'),
357-
node_selectors=namespaced.get('node_selectors'),
358-
name=namespaced.get('name'),
359-
ports=namespaced.get('ports'),
360-
volumes=namespaced.get('volumes'),
361-
volume_mounts=namespaced.get('volume_mounts'),
362-
namespace=namespaced.get('namespace'),
363-
image_pull_policy=namespaced.get('image_pull_policy'),
364-
restart_policy=namespaced.get('restart_policy'),
365-
image_pull_secrets=namespaced.get('image_pull_secrets'),
366-
init_containers=namespaced.get('init_containers'),
367-
service_account_name=namespaced.get('service_account_name'),
368-
resources=resources,
369-
annotations=namespaced.get('annotations'),
370-
affinity=namespaced.get('affinity'),
371-
hostnetwork=namespaced.get('hostnetwork'),
372-
tolerations=namespaced.get('tolerations'),
373-
security_context=namespaced.get('security_context'),
374-
configmaps=namespaced.get('configmaps'),
375-
dnspolicy=namespaced.get('dnspolicy'),
376-
schedulername=namespaced.get('schedulername'),
377-
pod=namespaced.get('pod'),
378-
extract_xcom=namespaced.get('extract_xcom'),
379-
)
380-
381-
return pod_spec_generator.gen_pod()
374+
return PodGenerator(**namespaced).gen_pod()
382375

383376
@staticmethod
384377
def reconcile_pods(base_pod, client_pod):
@@ -495,12 +488,73 @@ def construct_pod(
495488
name=pod_id
496489
).gen_pod()
497490

498-
# Reconcile the pod generated by the Operator and the Pod
499-
# generated by the .cfg file
500-
pod_with_executor_config = PodGenerator.reconcile_pods(worker_config,
501-
kube_executor_config)
502-
# Reconcile that pod with the dynamic fields.
503-
return PodGenerator.reconcile_pods(pod_with_executor_config, dynamic_pod)
491+
# Reconcile the pods starting with the first chronologically,
492+
# Pod from the airflow.cfg -> Pod from executor_config arg -> Pod from the K8s executor
493+
pod_list = [worker_config, kube_executor_config, dynamic_pod]
494+
495+
return reduce(PodGenerator.reconcile_pods, pod_list)
496+
497+
@staticmethod
def deserialize_model_file(path):
    """
    Load a YAML pod definition and deserialize it with ``k8s.V1Pod`` as the model.

    When ``path`` exists on disk the file is opened and its contents are
    parsed as YAML. When it does not exist, the ``path`` string itself is
    handed to the YAML parser, so inline YAML text is accepted as well.

    :param path: Path to the file (or, if no such path exists, the YAML text itself)
    :return: a kubernetes.client.models.V1Pod

    Unfortunately we need access to the private method
    ``_ApiClient__deserialize_model`` from the kubernetes client.
    This issue is tracked here: https://github.com/kubernetes-client/python/issues/977.
    """
    # The client is created up front, before any file access.
    client = ApiClient()

    if not os.path.exists(path):
        # Nothing on disk under that name: treat the argument as YAML text.
        pod_dict = yaml.safe_load(path)
    else:
        with open(path) as yaml_file:
            pod_dict = yaml.safe_load(yaml_file)

    # pylint: disable=protected-access
    return client._ApiClient__deserialize_model(pod_dict, k8s.V1Pod)
516+
517+
@staticmethod
def validate_pod_generator_args(given_args):
    """
    Reject conflicting combinations of PodGenerator constructor arguments.

    A fully specified pod (``pod``) and a pod template file
    (``pod_template_file``) are mutually exclusive, and neither may be combined
    with the arguments that configure the pod field by field.

    :param given_args: The arguments passed to the PodGenerator constructor.
    :type given_args: dict
    :return: None
    :raises AirflowConfigException: if both ``pod`` and ``pod_template_file``
        are given, or if either is given together with any truthy
        pod-configuring argument.
    """
    # Arguments that may accompany `pod` / `pod_template_file` without being
    # treated as an attempt to configure the pod.
    exempt_fields = {
        'pod', 'pod_template_file', 'extract_xcom', 'service_account_name', 'image_pull_policy',
        'restart_policy'
    }

    def is_configuring_arg(arg_name, parameter):
        """
        :param arg_name: an arg to PodGenerator
        :type arg_name: string
        :param parameter: the parameter of the given arg
        :type parameter: inspect.Parameter
        :return: bool

        True when the parameter has no default or defaults to None, and the
        argument is not one of the exempt fields.
        """
        has_empty_default = parameter.default is None or parameter.default is parameter.empty
        return has_empty_default and arg_name not in exempt_fields

    # Collect every pod-configuring argument the caller actually supplied
    # (i.e. whose value is truthy).
    supplied_config_args = {}
    for arg_name, parameter in signature(PodGenerator).parameters.items():
        if is_configuring_arg(arg_name, parameter) and given_args[arg_name]:
            supplied_config_args[arg_name] = given_args[arg_name]

    if given_args['pod'] and given_args['pod_template_file']:
        raise AirflowConfigException("Cannot pass both `pod` and `pod_template_file` arguments")

    if not supplied_config_args:
        return

    if given_args['pod'] or given_args['pod_template_file']:
        raise AirflowConfigException(
            "Cannot configure pod and pass either `pod` or `pod_template_file`. Fields {} passed.".format(
                list(supplied_config_args.keys())
            )
        )
504558

505559

506560
def merge_objects(base_obj, client_obj):

airflow/kubernetes/worker_configuration.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@
2828

2929

3030
class WorkerConfiguration(LoggingMixin):
31-
"""Contains Kubernetes Airflow Worker configuration logic"""
31+
"""
32+
Contains Kubernetes Airflow Worker configuration logic
33+
34+
:param kube_config: the kubernetes configuration from airflow.cfg
35+
:type kube_config: airflow.executors.kubernetes_executor.KubeConfig
36+
"""
3237

3338
dags_volume_name = 'airflow-dags'
3439
logs_volume_name = 'airflow-logs'
@@ -424,9 +429,12 @@ def generate_dag_volume_mount_path(self):
424429

425430
def as_pod(self):
426431
"""Creates POD."""
427-
pod_generator = PodGenerator(
432+
if self.kube_config.pod_template_file:
433+
return PodGenerator(pod_template_file=self.kube_config.pod_template_file).gen_pod()
434+
435+
pod = PodGenerator(
428436
image=self.kube_config.kube_image,
429-
image_pull_policy=self.kube_config.kube_image_pull_policy,
437+
image_pull_policy=self.kube_config.kube_image_pull_policy or 'IfNotPresent',
430438
image_pull_secrets=self.kube_config.image_pull_secrets,
431439
volumes=self._get_volumes(),
432440
volume_mounts=self._get_volume_mounts(),
@@ -436,10 +444,10 @@ def as_pod(self):
436444
tolerations=self.kube_config.kube_tolerations,
437445
envs=self._get_environment(),
438446
node_selectors=self.kube_config.kube_node_selectors,
439-
service_account_name=self.kube_config.worker_service_account_name,
440-
)
447+
service_account_name=self.kube_config.worker_service_account_name or 'default',
448+
restart_policy='Never'
449+
).gen_pod()
441450

442-
pod = pod_generator.gen_pod()
443451
pod.spec.containers[0].env_from = pod.spec.containers[0].env_from or []
444452
pod.spec.containers[0].env_from.extend(self._get_env_from())
445453
pod.spec.security_context = self._get_security_context()

0 commit comments

Comments
 (0)