RNNs, Huggingface Trainer, and PackedSequence’s

code
huggingface
research
transformers relies on a bunch of internal heuristics to maintain the façade of intuitiveness. One such heuristic breaks recurrent models, but (fortunately) there’s a simple fix
Author

Enyan Zhang

Published

January 31, 2025

TL;DR

The Issue

When training with Huggingface Trainer, if your data collator (data_collator in Trainer or collate_fn for a PyTorch DataLoader) outputs a PackedSequence for training a recurrent model (RNN/LSTM/GRU/who knows), you will get an assertion error assert isinstance(data, (list, tuple)) and len(data) == 2 triggered by line 254 of torch/nn/utils/rnn.py.

The Solution

Huggingface Trainer mishandles sending the PackedSequence to the correct device (e.g. the GPU); you need to override one method. See this section for the code.

An Even Better Way

See the afterword. This issue is completely avoidable if you define your model class differently.

Full Story

Background

Recently I was training toy RNNs for a project. Writing a train function with a for epoch in range(epochs) loop in 2024 felt very wrong (and unnecessary), so I thought about making everything work with the Trainer from Huggingface Transformers. There are many good reasons for doing so (and it was a huge quality-of-life improvement!). I'll list a few I've already used (and that worked pretty much out of the box):

  • saving/loading models with a one-liner
  • adding/changing learning rate schedules
  • generating with .generate()
  • doing simple hyperparameter sweeps (see Hyperparameter Search)

But things don’t always work, and when they don’t work, debugging Trainer is frustrating — it does too many things, and many of those things rely on heuristics. Below is an incomplete list of issues I’ve already come across (and still remember debugging):

  • it assumes the training target is a dict entry called label or labels, and will effectively skip evaluate() otherwise — it still runs the eval loop, but only returns eval metainfo such as runtime. The solution is to specify label_names in TrainingArguments (see the snippet after this list).
  • it sends tensors to the model’s device by recursively iterating all inputs until it reaches the basic data elements (which should normally be some Tensor), but the heuristic for stopping this recursion is hasattr(data, "to") (see source) — so if you define a class that contains your custom data, it absolutely cannot have a to method that does something else.
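
For the first item, a minimal sketch of the fix ("target" is a placeholder for whatever key your dataset actually uses):

from transformers import TrainingArguments

# tell Trainer explicitly which dict entries hold the training targets,
# otherwise evaluation silently returns only runtime-style metainfo
args = TrainingArguments(
    output_dir="out",        # placeholder output directory
    label_names=["target"],  # placeholder: use your actual target key(s)
)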

And unfortunately, one such heuristic breaks PyTorch RNNs. Here’s the premise:

  1. Transformers deal with variable-length sequences by padding inputs
    1. This is usually done by a DataCollator, which gets a list of dict and returns a dict of collated tensors (the action of “creating a batch” from samples)
    2. Additionally, attention_mask helps model zero-out attention on padding tokens, so effectively the model does not “see” the padded tokens
  2. RNNs also need to deal with variable-length input sequences
    1. It’s best if we also delegate this task to a data-collating function
    2. But RNNs can’t deal with padding! There’s no trivial parallel to something like attention_mask, especially because PyTorch RNNs are called with the entire sequence at once, as opposed to manually “unrolling” the model.

The solution to the above problem is to use a PackedSequence. The underlying idea is quite simple: instead of viewing the input as a batch of sequences, view it as a sequence of batches, where each batch can have a different batch size. The figure below illustrates it quite well¹:

A Visual Illustration of PackedSequence from @sgrvinod
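
As a tiny concrete example of the same idea (the numbers are arbitrary token ids), here is what pack_padded_sequence produces for a right-padded batch of three sequences of lengths 3, 2, and 1:

import torch
from torch.nn.utils.rnn import pack_padded_sequence

padded = torch.tensor([[1, 2, 3],
                       [4, 5, 0],
                       [6, 0, 0]])  # right-padded, batch_first, already sorted by length

packed = pack_padded_sequence(padded, lengths=torch.tensor([3, 2, 1]), batch_first=True)
print(packed.data)         # tensor([1, 4, 6, 2, 5, 3]): step 0 of every sequence, then step 1, ...
print(packed.batch_sizes)  # tensor([3, 2, 1]): how many sequences are still "alive" at each step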

So the solution seems simple enough: we just need to define a data collating function that creates a PackedSequence from a list of samples, like the one below, and life’s good, right?

import torch
from torch.nn.utils.rnn import pack_padded_sequence

def collate_fn(examples):
  # first collate, e.g. using torch.stack
  examples = {k: torch.stack([e[k] for e in examples]) for k in examples[0]}

  # assume you previously tokenized the input with a transformers tokenizer,
  # so each example's true length is the sum of its attention_mask
  input_lengths = torch.sum(examples["attention_mask"], dim=1)
  examples["input_ids"] = pack_padded_sequence(
    examples["input_ids"],
    input_lengths,
    batch_first=True,
    enforce_sorted=False
    )

  return examples

Why Trainer cannot process PackedSequence’s

If only life were so easy — I invite you to re-read the title of this post and realize that we’ve only just gotten to the issue. If you try trainer.train() with a collate_fn like the one above, you will get the following cryptic error message:

  0%|                                                    | 0/12520 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/<project-dir>/src/train.py", line 243, in <module>
    main()
  File "/<project-dir>/src/train.py", line 173, in main
    trainer.train()
  File "/<project-dir>/.venv/lib64/python3.11/site-packages/transformers/trainer.py", line 2123, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/<project-dir>/.venv/lib64/python3.11/site-packages/transformers/trainer.py", line 2481, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/<project-dir>/.venv/lib64/python3.11/site-packages/transformers/trainer.py", line 3573, in training_step
    inputs = self._prepare_inputs(inputs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/<project-dir>/.venv/lib64/python3.11/site-packages/transformers/trainer.py", line 3520, in _prepare_inputs
    inputs = self._prepare_input(inputs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/<project-dir>/.venv/lib64/python3.11/site-packages/transformers/trainer.py", line 3502, in _prepare_input
    return type(data)({k: self._prepare_input(v) for k, v in data.items()})
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/<project-dir>/.venv/lib64/python3.11/site-packages/transformers/trainer.py", line 3502, in <dictcomp>
    return type(data)({k: self._prepare_input(v) for k, v in data.items()})
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/<project-dir>/.venv/lib64/python3.11/site-packages/transformers/trainer.py", line 3504, in _prepare_input
    return type(data)(self._prepare_input(v) for v in data)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/<project-dir>/.venv/lib64/python3.11/site-packages/torch/nn/utils/rnn.py", line 93, in __new__
    *_packed_sequence_init_args(
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/<project-dir>/.venv/lib64/python3.11/site-packages/torch/nn/utils/rnn.py", line 254, in _packed_sequence_init_args
    assert isinstance(data, (list, tuple)) and len(data) == 2
                                               ^^^^^^^^^^^^^^
AssertionError: 
  In call to configurable 'main' (<function main at 0x148164687240>)

What happened?? If you look at the call stack at this point, it’s roughly the following:

  1. Trainer dispatches a batch (a list of examples) to our collator
  2. The collator does its job, returning a dict where the value corresponding to input_ids is a PackedSequence
  3. The collated batch (now one dict) gets sent to _prepare_inputs, which then sends the batch to _prepare_input to map the inputs onto the right devices
  4. Since the collated batch can have arbitrary nesting (think a dict of lists of tensors), _prepare_input recursively calls itself until it reaches the bottom level — tensors — and puts them on the right device. See below:
def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
    """
    Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
    """
    if isinstance(data, Mapping):
        return type(data)({k: self._prepare_input(v) for k, v in data.items()})
    elif isinstance(data, (tuple, list)):
        return type(data)(self._prepare_input(v) for v in data)
    elif isinstance(data, torch.Tensor):
        kwargs = {"device": self.args.device}
        if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
            # NLP models inputs are int/uint and those get adjusted to the right dtype of the
            # embedding. Other models such as wav2vec2's inputs are already float and thus
            # may need special handling to match the dtypes of the model
            kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
        return data.to(**kwargs)
    return data

If you look at the error message, PackedSequence’s constructor here is complaining about its data argument: when batch_sizes isn’t passed separately, it expects data to be a list or tuple of exactly two elements, the packed data tensor and the batch sizes. If you step through with a debugger you’ll find that what actually arrives is a single generator expression instead. Why?

It turns out, PackedSequence inherits from NamedTuple, which in turn is a tuple!

$ python
Python 3.12.7 | packaged by conda-forge | (main, Oct  4 2024, 15:57:01) [Clang 17.0.6 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> from torch.nn.utils.rnn import PackedSequence
>>> a = PackedSequence(torch.tensor([[1, 2], [1, 1]]), torch.tensor([1, 2]))
>>> isinstance(a, tuple)
True

So in the second elif of _prepare_input, Huggingface Trainer incorrectly iterates over it, thinking it’s a list of some sort, and then attempts to instantiate a new PackedSequence from the resulting generator. All this fuss because of a slightly wrong heuristic.
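
You can reproduce the failure outside of any training loop; here is a minimal sketch of what the tuple/list branch effectively does:

import torch
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence

packed = pack_padded_sequence(
    torch.tensor([[1, 2], [3, 0]]), torch.tensor([2, 1]), batch_first=True
)

# isinstance(packed, (tuple, list)) is True, so Trainer takes the tuple/list branch
# and hands the constructor a single generator instead of a (data, batch_sizes) pair:
PackedSequence(v for v in packed)  # AssertionError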

Fixing the issue

Fixing the problem once we know what happened is fairly easy: a specific problem calls for a specific solution. Just define a mixin for Trainer classes that overrides the default behavior if the data is a PackedSequence, and subsequently define new Trainers that inherit from the mixin.

If you have the exact same issue, adding the code block below should be a simple fix (notice that it replaces Trainer and Seq2SeqTrainer by subclassing them).

from typing import Any, Union

import torch
from torch.nn.utils.rnn import PackedSequence

from transformers import (
    Seq2SeqTrainer as HFSeq2SeqTrainer,
    Trainer as HFTrainer,
)

class PrepareInputMixin:
    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
        # intercept PackedSequence before the tuple/list branch: move the underlying
        # data tensor to the right device and rebuild the PackedSequence around it
        if isinstance(data, PackedSequence):
            return PackedSequence(
                self._prepare_input(data.data),
                data.batch_sizes,
                data.sorted_indices,
                data.unsorted_indices,
            )
        else:
            return super()._prepare_input(data)

class Seq2SeqTrainer(PrepareInputMixin, HFSeq2SeqTrainer):
    pass

class Trainer(PrepareInputMixin, HFTrainer):
    pass

The code should now run! (or, at least, you should now see a different bug!)

Afterword

Only after I fixed this bug did I realize that it is totally preventable: an even better way to train RNNs is to do the packing (and unpacking) of tensors within the model’s forward method. This has a few advantages: it’s more compatible with Huggingface’s API (you can, for example, sum attention_mask to infer the sequence lengths, or add an input_lengths argument), and it also makes embedding and encoder-decoder structures more intuitive. So, something like the following:

class RecurrentEncoder(PreTrainedModel):
    config_class = RecurrentEncoderConfig

    ... other methods ...

    def forward(
        self,
        input_ids: torch.LongTensor,
        input_lengths: Optional[torch.LongTensor],
        return_hidden_states: bool = False,
    ) -> BaseModelOutputWithNoAttention:

        embedded = self.embedding(input_ids)

        # pack inside the model; lengths must live on the CPU, and enforce_sorted=True
        # assumes each batch is pre-sorted by length (use enforce_sorted=False otherwise)
        packed_embedded = torch.nn.utils.rnn.pack_padded_sequence(
            embedded, input_lengths, batch_first=True, enforce_sorted=True
        )

        packed_output, hidden = self.recurrent_unit(packed_embedded)

        # unpack back to a padded tensor so downstream code sees a regular batch
        hidden_states, _ = torch.nn.utils.rnn.pad_packed_sequence(
            packed_output, batch_first=True, padding_value=0.0
        )

        return BaseModelOutputWithNoAttention(
            hidden_states=hidden_states,
        )
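
With this design, the collator only needs to pad. A minimal sketch of the call site, assuming right-padded input_ids and an attention_mask from a standard tokenizer (the names are placeholders):

# derive per-example lengths from the attention mask;
# pack_padded_sequence wants lengths on the CPU
input_lengths = attention_mask.sum(dim=1).cpu()
outputs = model(input_ids=input_ids, input_lengths=input_lengths)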

I should probably tidy up and make a release of the recurrent models I wrote at some point.

Credits

Thumbnail image: Stanford CS 230
PackedSequence’s: This Github demo, and this Stackoverflow answer

Footnotes

  1. In addition, I liked this StackOverflow answer explaining how it works.↩︎