Divide and conquer with branching

The largest chunk of work in our project is our model fitting and the associated grid search. When scaling up these kinds of processes to larger data and larger grids you typically hit some stumbling blocks.

For example:

  • You ‘add more cores’, or utilise more parallel threads, but then you unexpectedly run out of memory.
  • You get iterations with model convergence problems due to a bad combination of hyperparameters.
  • A colleague suggests you need to expand your hyperparameter search.

These problems are frustrating because they might not appear until hours into a very long-running process, and the intermediate results of all those hours are immediately dumped.

With {targets} we can use a technique called ‘dynamic branching’ to promote every iteration in a set to its own target.

  • Each result is individually cached, which means large iterative processes are now resumable.
  • We can also add or remove iterations by changing input data, while reusing the results from previous iterations.
  • We can take advantage of {targets} parallelism features to run the iterations in parallel.
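
To make the idea concrete before we touch the project, here is a minimal standalone sketch (toy names, not project code): pattern = map(x) asks {targets} to create one branch target per element of x, each built and cached independently.

  library(targets)

  list(
    tar_target(x, c(1, 5, 10)),
    tar_target(
      x_squared,
      x^2,
      pattern = map(x) # three branches: 1, 25, 100
    )
  )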

We’ll refactor the model grid search in our project to use this approach. After we see how it works it’s going to be a little easier to explain why it is called ‘dynamic branching’.

Dynamic branching refactor steps:

In the process of this refactor we’re going to remove our training grid. If you recall, it was a dataframe with one row per combination of training fold and model hyperparameters:

  occurrence_cv_splits <-
    vfold_cv(training_data, v = 5, repeats = 1)

  training_grid <-
    expand.grid(
      fold_id = occurrence_cv_splits$id,
      mtry = mtry_candidates,
      num_trees = num_trees_candidates
    ) |>
    as_tibble() |>
    left_join(
      occurrence_cv_splits,
      by = c(fold_id = "id")
    )

  training_grid

This refactor will give {targets} the job of materialising the grid, with one target per combination of parameters and data (row).

  1. Create a new target that summarises the grid training results:
  • The summarising code itself remains unchanged.

i.e. put this code:

  summarised_training_results <-
    training_results |>
    summarise(
      mean_auc = mean(auc),
      sd_auc = sd(auc),
      mean_accuracy = mean(accuracy),
      sd_accuracy = sd(accuracy),
      .by = c(mtry, num_trees)
    ) |>
    arrange(-mean_auc)

  summarised_training_results

inside a new target:

  tar_target(
    species_classification_model_training_summary,
    summarise_species_model_training_results(
      species_classification_model_training_results
    )
  ),
  2. Promote mtry_candidates and num_trees_candidates to plan targets:

  tar_target(
    mtry_candidates,
    c(1, 2, 3)
  ),
  tar_target(
    num_trees_candidates,
    c(200, 500, 100)
  ),
  3. Make the cross-validation fold dataset into a plan target:

  tar_target(
    training_cross_validation_folds,
    vfold_cv(training(test_train_split), v = 5, repeats = 1)
  ),
  4. What’s left is to refactor the middle bit, actually fitting the models, i.e. this code:

  training_results <-
    training_grid |>
    mutate(
      pmap(
        .l = list(
          training_grid$splits,
          training_grid$num_trees,
          training_grid$mtry
        ),
        .f = fit_fold_calc_results
      ) |>
        bind_rows()
      # by returning a dataframe inside mutate, the resulting columns are appended to training_grid
    )

We change that into a new target that looks like this:

  tar_target(
    species_classification_model_training_results,
    fit_fold_calc_results(
      training_cross_validation_folds$splits[[1]],
      num_trees_candidates,
      mtry_candidates
    ),
    pattern = cross(training_cross_validation_folds, mtry_candidates, num_trees_candidates)
  )

We’re introducing a bit of magic here: pattern = cross(training_cross_validation_folds, mtry_candidates, num_trees_candidates).
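
A handy way to preview the branches a pattern will generate is tarchetypes::tar_pattern(), which emulates the pattern without running anything. This is just a sketch; the lengths below are assumptions matching our 5 cross-validation folds and 3 candidate values for each hyperparameter:

  library(tarchetypes)

  tar_pattern(
    cross(training_cross_validation_folds, mtry_candidates, num_trees_candidates),
    training_cross_validation_folds = 5,
    mtry_candidates = 3,
    num_trees_candidates = 3
  )
  # one row per branch: 5 x 3 x 3 = 45 combinations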

This says to {targets}: we’re declaring a group of targets here that you’re going to create for us. That group is defined by evaluating this target’s expression on a set of inputs, in this case a cross product of input targets. If we call tar_make() at this point we get:

▶ dispatched branch species_classification_model_training_results_d19c364718f81c21
Setting levels: control = 0, case = 1
Setting direction: controls < cases
● completed branch species_classification_model_training_results_d19c364718f81c21 [0.051 seconds]
▶ dispatched branch species_classification_model_training_results_bde0689544bd8cc7
Setting levels: control = 0, case = 1
Setting direction: controls < cases
● completed branch species_classification_model_training_results_bde0689544bd8cc7 [0.058 seconds]
● completed pattern species_classification_model_training_results
▶ dispatched target species_classification_model_training_summary
✖ errored target species_classification_model_training_summary

If we tar_load(species_classification_model_training_results) and interactively run:

  summarise_species_model_training_results(
    species_classification_model_training_results
  )

we can see the problem more clearly:

Error in `summarise()` at R/summarise_species_model_training_results.R:12:3:
! Can't select columns that don't exist.
✖ Column `mtry` doesn't exist.
Run `rlang::last_trace()` to see where the error occurred.

This is happening because species_classification_model_training_results no longer has columns mtry and num_trees. These were being included in the data because of the way we were calling mutate in our earlier grid search code.

To fix this we can modify the object returned by fit_fold_calc_results():

# use auc and accuracy as our summary statistics
  data.frame(
    auc = auc(roc_object) |> as.numeric(),
    accuracy = sum(test_set$is_moluccus > 0.5 & test_set$is_moluccus == 1) / nrow(test_set),
    mtry = mtry,
    num_trees = num_trees
  )

And now tar_make() should succeed!

  5. Move fit_fold_calc_results() into its own file, R/fit_fold_calc_results.R:
  • The name of the file it currently lives in no longer reflects its contents.
  • The old file can be deleted.

We have successfully refactored to ‘dynamic branching’. We’ll see in a moment everything that buys us. But first…

Why it’s called ‘Dynamic Branching’…

‘Branching’ is a reference to the tree-like appearance of pipeline graphs.

{targets} has a number of ways to add targets to the graph programmatically. In our case we instructed {targets} to add a training target to our graph for each combination of model parameters and training data. After being computed, those targets are immediately consolidated into the dataset species_classification_model_training_results, which we then summarised.

We do not need to immediately consolidate the dynamically generated targets, though. We could, for example, create a new target from each of our dynamically generated targets, which, if you recall, are one-row dataframes (see fit_fold_calc_results()).

This would look like:

tar_target(
  new_dynamic_target,
  a_function(species_classification_model_training_results),
  pattern = map(species_classification_model_training_results)
)

We again use pattern to express this, but this time with map we’re expressing a 1:1 transformation of the input targets, instead of a cross product.

Our pipeline graph has its ‘branches’ extended, so you can imagine it looking something like:

  species_classification_model_training_results_a - new_dynamic_target_a \
  species_classification_model_training_results_b - new_dynamic_target_b - final_summary
  species_classification_model_training_results_c - new_dynamic_target_c /

These branches can be chains of targets that continue on, perhaps even splitting into more branches themselves, before finally being consolidated into a collection like a list, dataframe, or vector.

The analogy of ‘branches’ fits this idea of splitting and potentially growing and splitting further.

‘Dynamic’ comes from the fact that there are actually two ways to do branching in {targets}.

  • ‘Dynamic branching’, where {targets} generates branches for you at run time, once all of the target’s dependencies have been computed. This matters because it may not be known ahead of time how many items a list/vector/dataframe target contains, and thus how many branches {targets} would need to create.
  • ‘Static branching’, where the exact number of branches that need to be created is known from fixed inputs in the plan. For example, we might have been able to use this to create a branch for each statically known combination of parameters in our grid search (sketched below).
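
For contrast, here is a rough sketch of what a static version of (part of) our grid search could look like using tarchetypes::tar_map(). The grid is fixed in the plan, so one named target is stamped out per row of values when _targets.R is parsed. To keep the sketch short it uses only the first cross-validation fold:

  library(tarchetypes)

  tar_map(
    values = expand.grid(mtry = c(1, 2, 3), num_trees = c(200, 500, 100)),
    tar_target(
      static_training_result,
      fit_fold_calc_results(
        training_cross_validation_folds$splits[[1]],
        num_trees,
        mtry
      )
    )
  )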

Static Branching was developed first, and is largely superseded by Dynamic Branching. There is little reason to prefer the static mode.

The proof is in the pudding

Reusing existing grid search points

Now here’s where, if you do a bit of modelling, {targets} should get really exciting.

First let’s explore what happens if we expand the grid search, e.g. by trying a model version with 1000 trees:

  tar_target(
    num_trees_candidates,
    c(200, 500, 100, 1000)
  ),

Running tar_make() gives:

...
✔ skipped branch species_classification_model_training_results_d15e5fdbcf21af3b
✔ skipped branch species_classification_model_training_results_d19c364718f81c21
✔ skipped branch species_classification_model_training_results_bde0689544bd8cc7
▶ dispatched branch species_classification_model_training_results_e8827b8265f7d539
Setting levels: control = 0, case = 1
Setting direction: controls < cases
● completed branch species_classification_model_training_results_e8827b8265f7d539 [0.043 seconds]
● completed pattern species_classification_model_training_results
▶ dispatched target species_classification_model_training_summary
● completed target species_classification_model_training_summary [0.009 seconds]
▶ dispatched target species_classification_model
● completed target species_classification_model [0.049 seconds]
▶ dispatched target species_model_validation_data
● completed target species_model_validation_data [0.014 seconds]
✔ skipped target base_plot_model_roc_object
✔ skipped target gg_species_class_accuracy_hexes
▶ dispatched target report
● completed target report [3.984 seconds]
▶ ended pipeline [5.811 seconds]

We can see that:

  • We skipped a lot of branches in calculating species_classification_model_training_results
    • We only calculated the new combinations in our training grid with num_trees = 1000
      • 3 mtry values x 5 folds x 1 new num_trees value = 15 of these
  • species_classification_model_training_summary changed, and since it is an input to species_classification_model, the model was refit.
  • BUT it turned out that the best model remained the same. So we did not rebuild:
    • base_plot_model_roc_object
    • gg_species_class_accuracy_hexes
  • Question: how could we refactor the plan so that the final model is not refit if the best model didn’t change?
  • Refactoring idea:
    • We could make a separate target which is just the first row of species_classification_model_training_summary, which represents the best model.
    • The final model would then only be refit if that target changes (sketched below).
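
A sketch of that idea; fit_final_model() is a placeholder for however the final model is actually fit in the plan:

  tar_target(
    best_model_parameters,
    head(species_classification_model_training_summary, 1)
  ),
  tar_target(
    species_classification_model,
    fit_final_model(
      best_model_parameters,
      training(test_train_split)
    )
  ),

Because best_model_parameters only changes when the top row of the summary changes, the final model and everything downstream of it stays skipped when new grid points don’t beat the current best.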

Remember workspaces?

Initially when I made this refactor I made this mistake:

  tar_target(
    species_classification_model_training_results,
    fit_fold_calc_results(
      training_cross_validation_folds$splits,
      num_trees_candidates,
      mtry_candidates
    ),
    pattern = cross(training_cross_validation_folds, mtry_candidates, num_trees_candidates)
  )

I forgot that splits is a list column, and so within each branch the split object I want has an extra layer of list wrapping that needs to be stripped off.

The error this generated is hard to debug:

▶ dispatched target training_cross_validation_folds
● completed target training_cross_validation_folds [0.007 seconds]
▶ dispatched branch species_classification_model_training_results_2f9ab41c1360f0ce
✖ errored branch species_classification_model_training_results_2f9ab41c1360f0ce
✖ errored pipeline [9.093 seconds]
Error:
! Error running targets::tar_make()
Error messages: targets::tar_meta(fields = error, complete_only = TRUE)
Debugging guide: https://books.ropensci.org/targets/debugging.html
How to ask for help: https://books.ropensci.org/targets/help.html
Last error message:
    No method for objects of class: list
Last error trace back:
    fit_fold_calc_results(training_cross_validation_folds$splits,      mtry_...

This is hard partly because the inputs to species_classification_model_training_results_2f9ab41c1360f0ce are not known exactly. They could be any combination of elements from training_cross_validation_folds, mtry_candidates, and num_trees_candidates. So what would we tar_load() to investigate the problem interactively?

This is the situation we discussed earlier in the context of workspaces. To debug we set:

tar_option_set(
  seed = 2048,
  workspace_on_error = TRUE
)
  • It’s actually not a bad idea to turn this on defensively when working with dynamic branches.

And run tar_make():

▶ dispatched branch species_classification_model_training_results_2f9ab41c1360f0ce
▶ recorded workspace species_classification_model_training_results_2f9ab41c1360f0ce
✖ errored branch species_classification_model_training_results_2f9ab41c1360f0ce
✖ errored pipeline [0.349 seconds]

and then tar_workspace(species_classification_model_training_results_2f9ab41c1360f0ce).

We can now observe that the data object we pass to the fitting function for this branch is:

> training_cross_validation_folds$splits
[[1]]
<Analysis/Assess/Total>
<5910/1478/7388>

That is, the split we want is inside a length-1 list. So when we try to run training() on it inside fit_fold_calc_results() we get this error:

> training(training_cross_validation_folds$splits)
Error in `training()`:
! No method for objects of class: list
Run `rlang::last_trace()` to see where the error occurred.

So the quick fix is the [[1]] I added.

Converting to parallel

BUT WAIT THERE’S MORE:

Things run pretty fast now. But what if we wanted to speed things up by making more cores available to run model fits in parallel?

We add library(crew) to our packages and then set our options like:

  tar_option_set(
    seed = 2048,
    controller = crew_controller_local(workers = 2)
  )

Let’s blow away our targets store with tar_destroy(), and then run tar_make() to see:

An error!

✖ errored target occurrences
✖ errored pipeline [3.665 seconds]
Error:
! Error running targets::tar_make()
Error messages: targets::tar_meta(fields = error, complete_only = TRUE)
Debugging guide: https://books.ropensci.org/targets/debugging.html
How to ask for help: https://books.ropensci.org/targets/help.html
Last error message:
    [conflicted] filter found in 2 packages.
Either pick the one you want with `::`:
• dplyr::filter
• stats::filter
Or declare a preference with `conflicts_prefer()`:
• `conflicts_prefer(dplyr::filter)`
• `conflicts_prefer(stats::filter)`

What gives?! We called conflicts_prefer() at the start of our _targets.R.

So in many cases your {targets} plan can be made parallel with just that one config change. Unfortunately, in our case there is a small issue:

  • The environment we create in _targets.R is copied to the worker processes that will run targets in parallel.
  • {targets} doesn’t reach into package namespaces and copy their internal state. That’s a can of worms!
  • So any package that keeps internal state in its own namespace could have problems when that state is not replicated to workers.
  • This is also a problem with calling impure functions!

We have two packages that utilise internal state in their namespaces:

  • {conflicted} for the conflict resolution data
  • {galah} for its credentials

Hooks to the rescue

Luckily this gives us a really good motivating case for something the {targets} ecosystem calls ‘hooks’ (provided by the {tarchetypes} package). Hooks are ways to modify the target definitions in our plan after we have defined them. We have at our disposal:

  • tar_hook_before(): code to evaluate before a target is built
  • tar_hook_inner(): code to wrap around a target any time it appears as a dependency for another target (i.e. in input position)
  • tar_hook_outer(): code to wrap around a target after it is built, but before it is saved to the store.

In our case we can pipe our list() of target definitions into tar_hook_before():

  list(
  # inside list of targets
  ) |>
  tar_hook_before(
    hook = {
      conflicts_prefer(
        dplyr::filter
      )
      galah_config(
        atlas = "ALA",
        email = Sys.getenv("ALA_EMAIL") # You can replace this with your email to run. But you might not want to commit it to a public repository!
      )
    }
  )

Hooks can be targeted at specific targets using the names argument. A classic use is to strip back {ggplot2} objects before they are saved to the store, since they can hold onto a reference to a large dataset. E.g.

  list(
    # inside list of targets
  ) |>
  tar_hook_outer(
    hook = lighten_ggplot(.x), # .x is a placeholder we use for the target output
    names = starts_with("gg")
  )

This will postprocess any targets whose names start with “gg”.

Parallel, finally

With the hook in place, we can now build our pipeline in parallel. Exactly how that is done we leave to {targets} and the parallel backend in {crew}. Targets that are not dependent on each other are fair game to run in parallel.

{targets} also supports parallel backends from the {clustermq} and {future} packages.

One thing that commonly crops up when running in parallel is that the increased memory pressure can cause out-of-memory errors. Luckily, with {targets}, we don’t lose work. If you were working on AWS you could change your instance config for more RAM and resume processing.

There are helpful options for dealing with resource usage, for example:

  • tar_option_set(memory = "transient") can force targets to be dropped from memory after they’re stored. Under the defaults, targets can stick around in workers’ memory. Using this is slower, but it lowers peak memory usage.
  • Also see the storage = "worker" option, which controls whether the result needs to be copied back to the main process or not.
  • We can also use the deployment option to specify that only certain targets go to workers, while the rest get processed in the main process that runs our plan (see the sketch below).
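
Here’s a sketch of those options in use. The option and argument names are real {targets} settings; render_report() is just a placeholder for a target’s actual command:

  tar_option_set(
    memory = "transient", # drop each target from memory once it has been stored
    storage = "worker"    # workers write results to the store themselves
  )

  tar_target(
    report,
    render_report(species_classification_model), # placeholder command
    deployment = "main" # run this target in the main process, not on a worker
  )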

You don’t need to remember these.

  • You can find them in the help for tar_option_set().
  • They can all be set globally or at an individual target level
  • These and other options give you valuable control over the ‘shape’ of your pipeline’s process.
  • Just remember they exist if you run into resource usage problems.

Dynamic Branching refactor

The completed refactor for this section is available on this branch of the example project.

Review

  • Dynamic branching lets us dynamically create targets
    • These targets can represent iterations over other targets
    • Makes iterations resumable
    • Makes iterations parallelisable
    • Can add iterations as inputs are updated, but keep prior work
    • There are two important tar_target() arguments for this:
      • We discussed pattern (with map and cross). There are other patterns available.
      • The iteration argument is also useful, but not covered here.
  • Hooks are useful for changing selections of targets after we have defined them
    • We can apply pre / post processing steps or do setup work.
  • {targets} has options for controlling resource usage.