RNNs, Huggingface Trainer, and PackedSequence's
transformers relies on a bunch of internal heuristics to maintain the façade of intuitiveness. One such heuristic breaks recurrent models, but (fortunately) there's a simple fix.
TL;DR
The Issue
When training with the Huggingface Trainer: if your data collator (data_collator in Trainer, or collate_fn for a PyTorch DataLoader) outputs a PackedSequence for training a recurrent model (RNN/LSTM/GRU/what have you), you will hit an assertion error, assert isinstance(data, (list, tuple)) and len(data) == 2, triggered by line 254 of torch/nn/utils/rnn.py.
The Solution
The Huggingface trainer sends the PackedSequence to the target device (e.g. a GPU) incorrectly. You need to override a single method; see this section for the code.
An Even Better Way
See the afterword: this issue is completely avoidable if you define your model class differently.
Full Story
Background
Recently I was training toy RNNs for a project. Writing a train function with a for epoch in range(epochs) loop in 2024 felt very wrong (and unnecessary), so I thought about making everything work with the Trainer from Huggingface Transformers. There are many good reasons for doing so (and it was a huge quality-of-life improvement!). I'll list a few features I've already used that worked pretty much out of the box:
- saving/loading models with a one-liner
- adding/changing learning rate schedules
- generating with .generate()
- doing simple hyperparameter sweeps (see Hyperparameter Search)
But things don't always work, and when they don't, debugging Trainer is frustrating: it does too many things, and many of those things rely on heuristics. Below is an incomplete list of issues I've already come across (and still remember debugging):
- it assumes the training target is a dict entry called label or labels, and will skip evaluate() otherwise. It won't skip the eval loop entirely, though; instead it will only return eval meta-info such as runtime. The solution is to specify label_names in TrainingArguments (see the sketch after this list).
- it sends tensors to the model's device by recursively iterating over all inputs until it reaches the basic data elements (which should normally be some Tensor), but the heuristic for stopping this recursion is hasattr(data, "to") (see source), so if you define a class that contains your custom data, it absolutely cannot have a to method that does something else.
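For the first gotcha, the fix is a one-line setting on TrainingArguments. A minimal sketch, assuming your dataset's target column is called my_target (a made-up name):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",
    # tell Trainer which input key holds the training target,
    # so evaluate() is not silently skipped
    label_names=["my_target"],
)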
And unfortunately, one such heuristic breaks PyTorch RNNs. Here's the premise:
- Transformers deal with variable-length sequences by padding inputs
  - This is usually done by a DataCollator, which gets a list of dicts and returns a dict of collated tensors (the act of "creating a batch" from samples)
  - Additionally, attention_mask helps the model zero out attention on padding tokens, so effectively the model does not "see" the padded tokens (see the snippet after this list)
- RNNs also need to deal with variable-length input sequences
  - It's best if we also delegate this task to a data-collating function
  - But RNNs can't deal with padding! There's no trivial parallel to something like attention_mask, especially because PyTorch RNNs are called with the entire sequence at once, as opposed to manually "unrolling" the model.
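For concreteness, here is what padding plus attention_mask looks like with a transformers tokenizer. A minimal sketch; the checkpoint name is just an example:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # any checkpoint with a pad token works
batch = tokenizer(
    ["a short sequence", "a noticeably longer input sequence"],
    padding=True,
    return_tensors="pt",
)
print(batch["input_ids"].shape)   # both rows are padded to the same length
print(batch["attention_mask"])    # 1 for real tokens, 0 for padding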
The solution to the above problem is to use a PackedSequence. The underlying idea is quite simple: instead of viewing the input as a batch of sequences, view it as a sequence of batches, where each batch can have a different batch size. The figure below illustrates it quite well¹:
[Figure: illustration of a PackedSequence, from @sgrvinod]

So the solution seems simple enough: we just need to define a data-collating function that creates a PackedSequence from a list of samples, like the one below, and life's good, right?
import torch
from torch.nn.utils.rnn import pack_padded_sequence

def collate_fn(examples):
    # first collate, e.g. using torch.stack
    examples = {k: torch.stack([e[k] for e in examples]) for k in examples[0]}
    # assume you previously tokenized the input with a transformers tokenizer,
    # so every example carries an attention_mask: summing it gives the true length
    input_lengths = torch.sum(examples["attention_mask"], dim=1)
    examples["input_ids"] = pack_padded_sequence(
        examples["input_ids"],
        input_lengths,
        batch_first=True,
        enforce_sorted=False,
    )
    return examples
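To make the "sequence of batches" view concrete, here is a minimal, self-contained sketch (toy tensors, not real tokenized data) showing what packing produces:

import torch
from torch.nn.utils.rnn import pack_padded_sequence

# three right-padded sequences of lengths 3, 2, 1
padded = torch.tensor([[1, 2, 3],
                       [4, 5, 0],
                       [6, 0, 0]])
packed = pack_padded_sequence(padded, lengths=torch.tensor([3, 2, 1]), batch_first=True)
print(packed.data)         # tensor([1, 4, 6, 2, 5, 3]) -- one "batch" per time step, concatenated
print(packed.batch_sizes)  # tensor([3, 2, 1]) -- the batch size at each time step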
Why Trainer cannot process PackedSequence's
If only life were so easy. I invite you to re-read the title of this post and realize that we've only just gotten to the issue. If you try trainer.train() with a collate_fn like the one above, you will get the following cryptic error message:
0%| | 0/12520 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/<project-dir>/src/train.py", line 243, in <module>
main()
File "/<project-dir>/src/train.py", line 173, in main
trainer.train()
File "/<project-dir>/.venv/lib64/python3.11/site-packages/transformers/trainer.py", line 2123, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/<project-dir>/.venv/lib64/python3.11/site-packages/transformers/trainer.py", line 2481, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/<project-dir>/.venv/lib64/python3.11/site-packages/transformers/trainer.py", line 3573, in training_step
inputs = self._prepare_inputs(inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/<project-dir>/.venv/lib64/python3.11/site-packages/transformers/trainer.py", line 3520, in _prepare_inputs
inputs = self._prepare_input(inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/<project-dir>/.venv/lib64/python3.11/site-packages/transformers/trainer.py", line 3502, in _prepare_input
return type(data)({k: self._prepare_input(v) for k, v in data.items()})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/<project-dir>/.venv/lib64/python3.11/site-packages/transformers/trainer.py", line 3502, in <dictcomp>
return type(data)({k: self._prepare_input(v) for k, v in data.items()})
^^^^^^^^^^^^^^^^^^^^^^
File "/<project-dir>/.venv/lib64/python3.11/site-packages/transformers/trainer.py", line 3504, in _prepare_input
return type(data)(self._prepare_input(v) for v in data)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/<project-dir>/.venv/lib64/python3.11/site-packages/torch/nn/utils/rnn.py", line 93, in __new__
*_packed_sequence_init_args(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/<project-dir>/.venv/lib64/python3.11/site-packages/torch/nn/utils/rnn.py", line 254, in _packed_sequence_init_args
assert isinstance(data, (list, tuple)) and len(data) == 2
^^^^^^^^^^^^^^
AssertionError:
In call to configurable 'main' (<function main at 0x148164687240>)
What happened?? If you look at the call stack at this point, it's roughly the following:
- Trainer dispatches a batch (a list of examples) to our collator
- The collator does its job, returning a dict where the value corresponding to input_ids is a PackedSequence
- The collated batch (now one dict) gets sent to _prepare_inputs, which then sends the batch to _prepare_input to map the inputs onto the right devices
- Since the collated batch can have arbitrary nesting (think a dict of lists of tensors), _prepare_input recursively calls itself until it reaches the bottom level, i.e. tensors, and puts them on the right device. See below:
def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
    """
    Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
    """
    if isinstance(data, Mapping):
        return type(data)({k: self._prepare_input(v) for k, v in data.items()})
    elif isinstance(data, (tuple, list)):
        return type(data)(self._prepare_input(v) for v in data)
    elif isinstance(data, torch.Tensor):
        kwargs = {"device": self.args.device}
        if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
            # NLP models inputs are int/uint and those get adjusted to the right dtype of the
            # embedding. Other models such as wav2vec2's inputs are already float and thus
            # may need special handling to match the dtypes of the model
            kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
        return data.to(**kwargs)
    return data
If you look at the error message, the PackedSequence constructor here is complaining that it didn't get enough arguments: there need to be at least two, the padded tensor and the lengths of each example. If you use a debugger, you'll also find that the data getting passed here is only one tensor. Why?
It turns out, PackedSequence inherits from NamedTuple, which in turn is a tuple!
$ python
Python 3.12.7 | packaged by conda-forge | (main, Oct 4 2024, 15:57:01) [Clang 17.0.6 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> from torch.nn.utils.rnn import PackedSequence
>>> a = PackedSequence(torch.tensor([[1, 2], [1, 1]]), torch.tensor([1, 2]))
>>> isinstance(a, tuple)
True
So in the second elif of _prepare_input, the Huggingface trainer incorrectly iterates over it, thinking it's a list of some sort, and then proceeds to attempt to instantiate a new PackedSequence from the resulting generator. All this fuss because of a slightly wrong heuristic.
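To see the failure in isolation, here is a minimal sketch of what that second elif effectively does to a PackedSequence (toy data, built the same way as above):

import torch
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence

packed = pack_padded_sequence(
    torch.tensor([[1, 2], [3, 0]]), lengths=torch.tensor([2, 1]), batch_first=True
)
# isinstance(packed, tuple) is True, so Trainer takes the tuple/list branch:
# type(data)(self._prepare_input(v) for v in data), which boils down to
try:
    PackedSequence(v for v in packed)  # a generator, not a (data, batch_sizes) pair
except AssertionError:
    print("same AssertionError as in the traceback above")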
Fixing the issue
Fixing the problem once we know what happened is fairly easy: a specific problem calls for a specific solution. Just define a mixin for Trainer classes that overrides the default behavior when the data is a PackedSequence, and subsequently define new Trainer's that inherit from the mixin.
If you have this exact issue, adding the code block below should be a simple fix (notice that it replaces Trainer and Seq2SeqTrainer by subclassing them).
from typing import Any, Union

import torch
from torch.nn.utils.rnn import PackedSequence
from transformers import (
    Seq2SeqTrainer as HFSeq2SeqTrainer,
    Trainer as HFTrainer,
)

class PrepareInputMixin:
    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
        if isinstance(data, PackedSequence):
            # move only the underlying data tensor; pass the remaining fields through unchanged
            return PackedSequence(
                self._prepare_input(data.data),
                data.batch_sizes,
                data.sorted_indices,
                data.unsorted_indices,
            )
        else:
            return super()._prepare_input(data)

class Seq2SeqTrainer(PrepareInputMixin, HFSeq2SeqTrainer):
    pass

class Trainer(PrepareInputMixin, HFTrainer):
    pass
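Usage is then identical to the stock classes; you simply instantiate the subclass defined above. A rough sketch, where model, training_args, and train_dataset are placeholders for your own objects:

trainer = Trainer(
    model=model,                  # your recurrent model (placeholder)
    args=training_args,           # a TrainingArguments instance (placeholder)
    train_dataset=train_dataset,  # placeholder dataset
    data_collator=collate_fn,     # the PackedSequence-producing collator from earlier
)
trainer.train()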
The code should now run! (or, at least, you should now see a different bug!)
Afterword
Only after I fixed this bug did I realize that it is totally preventable: an even better way to train RNNs is to do the packing (and unpacking) of tensors within the model's forward method. This has a few advantages: it's more compatible with Huggingface's API (you can, for example, sum attention_mask's to infer the sequence lengths, or add an input_lengths argument), and it also makes embedding and encoder-decoder structures more intuitive. So, something like the following:
class RecurrentEncoder(PreTrainedModel):
    config_class = RecurrentEncoderConfig

    # ... other methods ...

    def forward(
        self,
        input_ids: torch.LongTensor,
        input_lengths: Optional[torch.LongTensor],
        return_hidden_states: bool = False,
    ) -> BaseModelOutputWithNoAttention:
        embedded = self.embedding(input_ids)
        packed_embedded = torch.nn.utils.rnn.pack_padded_sequence(
            embedded, input_lengths, batch_first=True, enforce_sorted=True
        )
        packed_output, hidden = self.recurrent_unit(packed_embedded)
        hidden_states, _ = torch.nn.utils.rnn.pad_packed_sequence(
            packed_output, batch_first=True, padding_value=0.0
        )
        return BaseModelOutputWithNoAttention(
            hidden_states=hidden_states,
        )
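With this design, the collator no longer needs to build a PackedSequence at all; it can return a plain dict of padded tensors plus an input_lengths entry, which Trainer's device-placement heuristic handles without fuss. A minimal sketch, assuming each example was tokenized with a transformers tokenizer and carries input_ids and attention_mask:

import torch

def collate_for_recurrent(examples):
    batch = {k: torch.stack([e[k] for e in examples]) for k in examples[0]}
    # true sequence lengths, recovered from the attention mask
    # (note: pack_padded_sequence expects lengths on the CPU, so the model's
    # forward may need input_lengths.cpu() if Trainer has moved them to the GPU)
    batch["input_lengths"] = batch["attention_mask"].sum(dim=1)
    return batch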
I should probably tidy up and make a release for the recurrent models I wrote at some point.
Credits
Thumbnail image: Stanford CS 230
PackedSequence's: this GitHub demo, and this StackOverflow answer
Footnotes
¹ In addition, I liked this StackOverflow answer explaining how it works.