在 Sentence Transformer 微调期间发生内存不足 (OOM) 错误通常是因为您的 GPU 运行内存不足,无法存储模型、数据和中间计算结果。 这通常是由于以下三个因素之一造成的:批量大小过大、模型复杂度过高或数据处理效率低下。 例如,较大的批量大小需要更多的内存来存储激活和梯度,而具有许多层或高维嵌入的模型可能会超出 GPU 限制。 同样,预处理不当的数据(例如,过长的文本序列)会增加内存使用量。 解决这些问题需要在资源约束和训练效率之间取得平衡。
为了缓解 OOM 错误,首先要减少批量大小。 例如,如果您使用 32 的批量大小,请尝试将其降低到 16 或 8。 这直接减少了正向和反向传递所需的内存。 如果较小的批量大小会损害训练稳定性,请使用梯度累积(例如,在更新权重之前累积 4 个批次的梯度)。 另一种方法是混合精度训练,它对某些操作使用 16 位浮点数,从而将内存使用量减少近一半。 像 PyTorch 的 torch.cuda.amp
这样的库可以自动执行此操作。 此外,冻结模型的部分(例如,transformer 的较低层)以避免更新其参数,从而减少内存开销。
优化数据处理和模型配置。 使用 max_seq_length
将输入文本截断或填充到固定长度(例如,128 个 token),而不是动态调整到最长序列。 确保数据管道(通过 DataLoader
)使用高效的批处理,并避免内存中的冗余副本 - 设置 pin_memory=True
并调整 num_workers
。 对于模型调整,请考虑切换到较小的预训练架构(例如,all-MiniLM-L6-v2
而不是 all-mpnet-base-v2
)。 最后,使用诸如 nvidia-smi
或 PyTorch 的 torch.cuda.memory_summary()
之类的工具监视 GPU 使用情况,以识别瓶颈。 如果所有其他方法均失败,请使用具有更多内存的基于云的 GPU(例如,A100 而不是 T4)或跨多个 GPU 的分布式训练。