Control Flow

Create pipelines with control flow

Although a KFP pipeline decorated with the @dsl.pipeline decorator looks like a normal Python function, it is actually an expression of pipeline topology and control flow semantics, constructed using the KFP domain-specific language (DSL). Pipeline Basics covered how data passing expresses pipeline topology through task dependencies. This section describes how to use control flow in your pipelines using the KFP DSL. The DSL features three types of control flow, each implemented by a Python context manager:

  1. Conditions
  2. Looping
  3. Exit handling

Conditions (dsl.If, dsl.Elif, dsl.Else)

The dsl.If context manager enables conditional execution of tasks within its scope based on the output of an upstream task or pipeline input parameter. The context manager takes two arguments: a required condition and an optional name. The condition is a comparative expression where at least one of the two operands is an output from an upstream task or a pipeline input parameter.

In the following pipeline, conditional_task only executes if coin_flip_task has the output 'heads'.

from kfp import dsl

@dsl.pipeline
def my_pipeline():
    coin_flip_task = flip_coin()
    with dsl.If(coin_flip_task.output == 'heads'):
        conditional_task = my_comp()
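coin_flip_task.output refers to a value that only exists at run time, so the comparison passed to dsl.If cannot be an ordinary Python bool. The plain-Python sketch below shows the general technique a DSL can use for this: overloading == so that the comparison is recorded and evaluated later. It is not KFP code. Placeholder, Condition and the 'flip-coin' task name are hypothetical, and KFP's internal classes differ.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass(frozen=True)
class Condition:
    """A comparison recorded at definition time and evaluated later."""

    producer: str  # name of the upstream task whose output is compared
    operator: str
    operand: Any

    def evaluate(self, outputs: Dict[str, Any]) -> bool:
        """Resolve the comparison once the producer's output is known."""
        ops: Dict[str, Callable[[Any, Any], bool]] = {
            '==': lambda a, b: a == b,
            '!=': lambda a, b: a != b,
        }
        return ops[self.operator](outputs[self.producer], self.operand)


class Placeholder:
    """Hypothetical stand-in for a task output that does not exist yet."""

    def __init__(self, producer: str) -> None:
        self.producer = producer

    def __eq__(self, other: Any) -> Condition:  # type: ignore[override]
        return Condition(self.producer, '==', other)

    def __ne__(self, other: Any) -> Condition:  # type: ignore[override]
        return Condition(self.producer, '!=', other)


# Definition time: the comparison builds a Condition instead of a bool.
cond = Placeholder('flip-coin') == 'heads'
print(cond)

# Run time: the actual output is supplied and the condition is evaluated.
print(cond.evaluate({'flip-coin': 'heads'}))  # True
print(cond.evaluate({'flip-coin': 'tails'}))  # False
```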

You can also use the dsl.Elif and dsl.Else context managers immediately after a dsl.If block to add further conditional branches:

from kfp import dsl

@dsl.pipeline
def my_pipeline():
    coin_flip_task = flip_three_sided_coin()
    with dsl.If(coin_flip_task.output == 'heads'):
        print_comp(text='Got heads!')
    with dsl.Elif(coin_flip_task.output == 'tails'):
        print_comp(text='Got tails!')
    with dsl.Else():
        print_comp(text='Draw!')
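The branch-selection rule can be modeled in plain Python. Conditions are tested in order, the first one that holds wins, and dsl.Else catches every remaining case. This sketch models the semantics only, not how KFP evaluates conditions. select_branch is a hypothetical helper, and 'edge' is an assumed third outcome of flip_three_sided_coin.

```python
from typing import Callable, List, Optional, Tuple

# (predicate, label); a predicate of None stands for the dsl.Else branch.
Branch = Tuple[Optional[Callable[[str], bool]], str]


def select_branch(value: str, branches: List[Branch]) -> Optional[str]:
    """Return the label of the single branch that runs, or None if none does."""
    for predicate, label in branches:
        # The first matching condition wins; an Else entry always matches.
        if predicate is None or predicate(value):
            return label
    return None


branches: List[Branch] = [
    (lambda v: v == 'heads', 'Got heads!'),  # dsl.If
    (lambda v: v == 'tails', 'Got tails!'),  # dsl.Elif
    (None, 'Draw!'),                         # dsl.Else
]

for outcome in ['heads', 'tails', 'edge']:
    print(outcome, '->', select_branch(outcome, branches))
```

If you drop the Else entry, select_branch returns None for 'edge', which means no branch runs for that outcome.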

dsl.OneOf

dsl.OneOf gathers outputs from mutually exclusive branches into a single task output, which a downstream task can consume or the pipeline can return. Branches are mutually exclusive if exactly one of them is executed. To enforce this, the KFP SDK compiler requires that dsl.OneOf consume from tasks within a single, logically associated group of conditional branches, and that one of those branches is a dsl.Else branch.

from kfp import dsl

@dsl.pipeline
def my_pipeline() -> str:
    coin_flip_task = flip_three_sided_coin()
    with dsl.If(coin_flip_task.output == 'heads'):
        t1 = print_and_return(text='Got heads!')
    with dsl.Elif(coin_flip_task.output == 'tails'):
        t2 = print_and_return(text='Got tails!')
    with dsl.Else():
        t3 = print_and_return(text='Draw!')
    
    oneof = dsl.OneOf(t1.output, t2.output, t3.output)
    announce_result(oneof)
    return oneof

Provide task outputs to dsl.OneOf using .output or .outputs[<key>], just as you would pass an output to a downstream task. The outputs provided to dsl.OneOf must all be of the same type and cannot be other instances of dsl.OneOf or dsl.Collected.
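The constraints on dsl.OneOf can be summarized in plain Python. The sketch below is a model, not KFP's compiler or runtime. check_oneof_inputs and resolve_oneof are hypothetical helpers, and using None to represent a skipped branch is an assumption made for illustration.

```python
from typing import Any, Optional, Sequence


def check_oneof_inputs(input_types: Sequence[type], has_else: bool) -> None:
    """Definition-time checks modeled on the rules described above."""
    if not has_else:
        # Without dsl.Else, some outcome could run no branch at all.
        raise ValueError('branch group must include a dsl.Else branch')
    if len(set(input_types)) != 1:
        raise TypeError(f'OneOf inputs must share one type, got {list(input_types)}')


def resolve_oneof(branch_outputs: Sequence[Optional[Any]]) -> Any:
    """Run-time model: pick the output of the single branch that executed.

    None marks a branch that was skipped (an assumption of this sketch).
    """
    produced = [out for out in branch_outputs if out is not None]
    if len(produced) != 1:
        raise ValueError(f'expected exactly one branch output, got {len(produced)}')
    return produced[0]


check_oneof_inputs([str, str, str], has_else=True)  # t1, t2, t3 all return str
print(resolve_oneof([None, 'Got tails!', None]))  # Got tails!
```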

Parallel looping (dsl.ParallelFor)

The dsl.ParallelFor context manager allows parallel execution of tasks over a static set of items. The context manager takes three arguments: items (required), parallelism (optional), and name (optional). items is the static set of items to loop over, and parallelism is the maximum number of concurrent iterations permitted while executing the dsl.ParallelFor group; parallelism=0 indicates unconstrained parallelism.

In the following pipeline, train_model will train a model for 1, 5, 10, and 25 epochs, with no more than two training tasks running at one time:

from kfp import dsl

@dsl.pipeline
def my_pipeline():
    with dsl.ParallelFor(
        items=[1, 5, 10, 25],
        parallelism=2
    ) as epochs:
        train_model(epochs=epochs)
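In plain Python, a bounded pool of workers enforces the same kind of concurrency limit that parallelism describes. The sketch below uses concurrent.futures.ThreadPoolExecutor only as an analogy. parallel_for and the train_model stand-in are hypothetical, and KFP runs each iteration as a separate task on the backend rather than as a thread in one process.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Sequence, TypeVar

T = TypeVar('T')
R = TypeVar('R')


def parallel_for(items: Sequence[T], body: Callable[[T], R], parallelism: int = 0) -> List[R]:
    """Run body(item) for every item, at most `parallelism` at a time.

    Follows the documented convention that parallelism=0 means unconstrained.
    """
    max_workers = len(items) if parallelism == 0 else parallelism
    with ThreadPoolExecutor(max_workers=max(1, max_workers)) as pool:
        return list(pool.map(body, items))  # results keep input order


running = 0
peak = 0
lock = threading.Lock()


def train_model(epochs: int) -> str:
    """Hypothetical stand-in for the train_model component."""
    global running, peak
    with lock:
        running += 1
        peak = max(peak, running)
    time.sleep(0.05)  # pretend to train
    with lock:
        running -= 1
    return f'model trained for {epochs} epochs'


results = parallel_for([1, 5, 10, 25], train_model, parallelism=2)
print(results)
print('peak concurrency:', peak)  # never exceeds 2
```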

dsl.Collected

Use dsl.Collected with dsl.ParallelFor to gather outputs from a parallel loop of tasks:

from kfp import dsl

@dsl.pipeline
def my_pipeline():
    with dsl.ParallelFor(
        items=[1, 5, 10, 25],
    ) as epochs:
        train_model_task = train_model(epochs=epochs)
    max_accuracy(models=dsl.Collected(train_model_task.outputs['model']))

Downstream tasks can consume dsl.Collected outputs through an input annotated with a List of parameters or a List of artifacts. For example, max_accuracy in the preceding example has the input models with type Input[List[Model]], as shown by the following component definition:

from typing import List

from kfp import dsl
from kfp.dsl import Input, Model

@dsl.component
def max_accuracy(models: Input[List[Model]]) -> float:
    # score_model is a placeholder for your own scoring logic. In a lightweight
    # Python component it must be defined or imported inside this function body.
    return max(score_model(model) for model in models)
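The fan-in that dsl.Collected performs can be pictured with a plain-Python list: each loop iteration contributes one output, and the downstream consumer receives all of them in a single call. collect, toy_accuracy and this max_accuracy are hypothetical stand-ins, and the accuracy formula is invented purely for illustration. The sketch keeps input order, which is an assumption and not a documented KFP guarantee.

```python
from typing import Callable, List, Sequence


def collect(items: Sequence[int], body: Callable[[int], float]) -> List[float]:
    """Fan-in model: one output per loop iteration, gathered into one list."""
    return [body(item) for item in items]


def toy_accuracy(epochs: int) -> float:
    """Hypothetical scoring function used only for illustration."""
    return 1 - 1 / (epochs + 1)


def max_accuracy(scores: List[float]) -> float:
    """Downstream consumer: receives the whole list once, not one item per run."""
    return max(scores)


scores = collect([1, 5, 10, 25], toy_accuracy)
print(len(scores), max_accuracy(scores))
```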

You can use dsl.Collected to collect outputs from nested loops into a nested list of parameters. For example, output parameters from two nested dsl.ParallelFor groups are collected in a multilevel nested list of parameters, where each nested list contains the output parameters from one of the dsl.ParallelFor groups. The number of nesting levels matches the number of nested dsl.ParallelFor contexts.

By comparison, artifacts created in nested loops are collected in a flat list.
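Ordinary list comprehensions over two nested loops can illustrate the difference in shape between collected parameters and collected artifacts. The helper names and string values are hypothetical, and the element order shown is an assumption of the sketch. Only the nesting reflects the behavior described above: one level per loop for parameters, and a single flat list for artifacts.

```python
from typing import List


def nested_parameter_shape(outer: List[str], inner: List[int]) -> List[List[str]]:
    """Parameters from two nested loops: one inner list per outer iteration."""
    return [[f'{o}-{i}' for i in inner] for o in outer]


def nested_artifact_shape(outer: List[str], inner: List[int]) -> List[str]:
    """Artifacts from the same nested loops: a single flat list."""
    return [f'{o}-{i}' for o in outer for i in inner]


outer = ['a', 'b']
inner = [1, 2, 3]
print(nested_parameter_shape(outer, inner))  # [['a-1', 'a-2', 'a-3'], ['b-1', 'b-2', 'b-3']]
print(nested_artifact_shape(outer, inner))   # ['a-1', 'a-2', 'a-3', 'b-1', 'b-2', 'b-3']
```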

You can also return a dsl.Collected from a pipeline. Use a List of parameters or a List of artifacts in the return annotation, as shown in the following example:

from typing import List

from kfp import dsl
from kfp.dsl import Model

@dsl.pipeline
def my_pipeline() -> List[Model]:
    with dsl.ParallelFor(
        items=[1, 5, 10, 25],
    ) as epochs:
        train_model_task = train_model(epochs=epochs)
    return dsl.Collected(train_model_task.outputs['model'])

Exit handling (dsl.ExitHandler)

The dsl.ExitHandler context manager allows pipeline authors to specify an exit task which will run after the tasks within the context manager’s scope finish execution, even if one of those tasks fails. This is analogous to using a try: block followed by a finally: block in normal Python, where the exit task is in the finally: block. The context manager takes two arguments: a required exit_task and an optional name. exit_task accepts an instantiated PipelineTask.

In the following pipeline, clean_up_task executes after create_datasets and train_and_save_models finish, even if either of them fails:

from kfp import dsl

@dsl.pipeline
def my_pipeline():
    clean_up_task = clean_up_resources()
    with dsl.ExitHandler(exit_task=clean_up_task):
        dataset_task = create_datasets()
        train_task = train_and_save_models(dataset=dataset_task.output)
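The following plain-Python code makes the try/finally analogy concrete. The functions are hypothetical stand-ins for the components in the pipeline above. train_and_save_models is made to fail so you can see the clean-up step run anyway. In this analogy, the original error still propagates after the clean-up runs.

```python
from typing import Callable, List

log: List[str] = []


def run_with_exit_handler(body: Callable[[], None], exit_task: Callable[[], None]) -> None:
    """Plain-Python analogy of dsl.ExitHandler: exit_task runs no matter what."""
    try:
        body()
    finally:
        exit_task()


def create_datasets() -> None:
    log.append('create_datasets')


def train_and_save_models() -> None:
    log.append('train_and_save_models')
    raise RuntimeError('training failed')


def clean_up_resources() -> None:
    log.append('clean_up_resources')


def body() -> None:
    # The tasks inside the ExitHandler's scope.
    create_datasets()
    train_and_save_models()


try:
    run_with_exit_handler(body, clean_up_resources)
except RuntimeError as err:
    print('pipeline failed:', err)

print(log)  # ['create_datasets', 'train_and_save_models', 'clean_up_resources']
```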

The exit task can use a special input that provides access to pipeline and task status metadata, including whether the pipeline failed or succeeded. To use this input, annotate a parameter of the exit task's component with dsl.PipelineTaskFinalStatus. The backend provides the argument for this parameter automatically at run time, so do not pass a value for it when you instantiate the exit task.

The following pipeline uses dsl.PipelineTaskFinalStatus to obtain information about the pipeline and task failure, even after fail_op fails:

from kfp import dsl
from kfp.dsl import PipelineTaskFinalStatus


@dsl.component
def exit_op(user_input: str, status: PipelineTaskFinalStatus):
    """Prints pipeline run status."""
    print(user_input)
    print('Pipeline status: ', status.state)
    print('Job resource name: ', status.pipeline_job_resource_name)
    print('Pipeline task name: ', status.pipeline_task_name)
    print('Error code: ', status.error_code)
    print('Error message: ', status.error_message)

@dsl.component
def fail_op():
    import sys
    sys.exit(1)

@dsl.pipeline
def my_pipeline():
    print_op()
    print_status_task = exit_op(user_input='Task execution status:')
    with dsl.ExitHandler(exit_task=print_status_task):
        fail_op()
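The key point is that you supply only the exit task's ordinary inputs, and the runner fills in the status argument. The sketch below models this in plain Python. FinalStatus is a hypothetical stand-in that mirrors the attributes exit_op reads. The state strings 'SUCCEEDED' and 'FAILED', the error code, and the job and task names are placeholder values chosen for illustration, not values reported by a KFP backend.

```python
from dataclasses import dataclass
from functools import partial
from typing import Callable, List, Optional

printed: List[str] = []


@dataclass
class FinalStatus:
    """Hypothetical stand-in with the same fields exit_op reads above."""

    state: str
    pipeline_job_resource_name: str
    pipeline_task_name: str
    error_code: Optional[int] = None
    error_message: Optional[str] = None


def run_with_final_status(body: Callable[[], None],
                          exit_task: Callable[[FinalStatus], None]) -> FinalStatus:
    """The runner, not the pipeline author, builds and passes the status."""
    status = FinalStatus(state='SUCCEEDED',
                         pipeline_job_resource_name='example-job',
                         pipeline_task_name='example-task')
    try:
        body()
    except Exception as err:  # record the failure for the exit task
        status.state = 'FAILED'
        status.error_code = 1
        status.error_message = str(err)
    exit_task(status)
    return status


def exit_op(user_input: str, status: FinalStatus) -> None:
    printed.append(user_input)
    printed.append(f'Pipeline status: {status.state}')
    printed.append(f'Error message: {status.error_message}')


def fail_op() -> None:
    raise RuntimeError('fail_op exited with code 1')


# Only user_input is bound by the author; status is supplied by the runner.
final = run_with_final_status(fail_op, partial(exit_op, 'Task execution status:'))
print('\n'.join(printed))
```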

Ignore upstream failure

The .ignore_upstream_failure() method on PipelineTask offers another way to author pipelines with exit-handling behavior. Calling this method on a task causes the task to ignore failures of its upstream tasks (as established by data exchange or by use of .after()). If the task has no upstream tasks, this method has no effect.

In the following pipeline definition, clean_up_task is executed after fail_op, regardless of whether fail_op succeeds:

from kfp import dsl

@dsl.pipeline()
def my_pipeline(text: str = 'message'):
    task = fail_op(message=text)
    clean_up_task = print_op(
        message=task.output).ignore_upstream_failure()

Note that the component used for the task that calls .ignore_upstream_failure() (print_op in the example above) requires a default value for every input it consumes from an upstream task. The default value is used if the upstream task fails to produce the output that would otherwise be passed in. Specifying default values ensures that the task can still run, regardless of the upstream task's status.
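The role of the default value can be sketched in plain Python. When the upstream step fails, the downstream call is made without the missing argument, so the parameter's default fills the gap. run_ignoring_upstream_failure and the stand-in components are hypothetical. They only model the behavior described above and do not reflect how the KFP backend passes values.

```python
from typing import Callable, Optional


def run_ignoring_upstream_failure(upstream: Callable[[], str],
                                  downstream: Callable[..., str]) -> str:
    """Always run downstream; fall back to its defaults if upstream fails."""
    try:
        value: Optional[str] = upstream()
    except Exception:
        value = None  # the upstream task produced no output
    if value is None:
        return downstream()  # the default for `message` applies
    return downstream(message=value)


def fail_op() -> str:
    raise RuntimeError('upstream failed')


def print_op(message: str = 'upstream failed; using default') -> str:
    """Hypothetical stand-in: the default makes the call valid without input."""
    print(message)
    return message


print(run_ignoring_upstream_failure(fail_op, print_op))
print(run_ignoring_upstream_failure(lambda: 'message', print_op))
```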


Last modified October 27, 2023: sdk: add dsl.OneOf docs (#3605) (f697081)