基于 PyTorch Snapshot 数据分析内存问题案例

案例背景

在训练场景中,模型在较小 batch size 下可以正常运行,但当 batch size 增大到某个阈值后却突然出现 OOM,这类问题通常不只是“显存不够”这么简单。实际分析时,除了关注整体显存峰值,还需要结合 profiling 数据判断是否存在异常的算子 workspace 申请、shape 放大、或者某个阶段的临时内存被成倍拉高。

本案例以一个 batch size 从 2 增加到 3 后触发 OOM 的问题为例,介绍如何先通过 profiling 数据观察显存占用趋势,再结合异常算子和调用栈信息定位到具体原因。

分析方法论

针对这类“batch size 增大后 OOM”的问题,建议按照以下顺序分析:

  1. 先看整体显存峰值:确认问题是单次峰值过高,还是某种持续累积导致的 OOM。

  2. 再看静态显存占用:区分基础占用、PTA 统计占用和算子临时申请,判断是否存在明显异常。

  3. 重点检查 Reserved 与 Allocated 差值:如果两者之间存在较大 gap,需要继续排查 workspace、碎片化或未统计到的组件占用。

  4. 回到异常算子和 timeline:通过 profiler 中的大内存申请事件,定位具体算子及其对应的 shape。

  5. 结合源码确认根因:最后检查算子前后的 shape 变化和类型转换顺序,判断是否存在不必要的大内存放大。

这类问题通常不是“模型整体太大”,而是某个中间张量在特定顺序下被放大,导致 workspace 成倍增加。

问题现象

某训练场景在 bs=2 时可以正常训练,但当 batch size 提升到 bs=3 后出现 OOM。按照模型规模和历史经验判断,该配置理论上应当可以承载 bs=3,因此需要进一步分析是否存在异常显存申请。

OOM 日志如下:

File "<path-to-py-file>/transformer.py", line xx, in forward
    attn_mask = attn_mask.bool()
RuntimeError: NPU out of memory. Tried to allocate 48.78 GiB (NPU 0; 60.96 GiB total capacity; 13.22 GiB already allocated; 13.22 GiB current active; 40.54 GiB free; 54.86 GiB allowed; 13.86 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.

从报错信息看,OOM 发生在 attn_mask.bool() 这样的类型转换操作上,而且申请量达到了 48GiB 级别,明显异常。

信息收集

在继续分析前,先补充两个关键信息:

  1. 训练任务开启了 task_queue_enable=2

  2. 需要对比 bs=2bs=3 两种场景下的 profiling 数据。

这里之所以要特别关注 task_queue_enable=2,是因为该模式下 profiling 中的 PTA reserved / allocated 统计不一定包含 workspace,需要结合其它视图继续确认是否存在独立的 workspace 占用。

Profiling 数据分析

先看 bs=1 和 bs=2 的基线

先从较小 batch size 的 profiling 数据入手,可以帮助确认问题不是由模型本身的基础占用持续增长引起的。

bs=1 时,profiling 中 observed 的 allocated 峰值约为 21G,其中静态显存占用大约 11G。

bs=2 时,静态显存占用仍然约为 11G,但 allocated 峰值上升到约 26G。这个结果说明模型的基础占用没有明显异常,但随着 batch size 增大,峰值显存开始快速上升。

此时在 profiling 图中还能看到一个值得注意的现象:APP ReservedOperator Reserved 之间存在较大的 gap,峰值接近 20G+。这通常意味着除了已知的算子申请之外,还存在其它异常占用来源,需要继续向下追查。

再看 bs=3 的 OOM 对应特征

当 batch size 增加到 bs=3 后,训练直接 OOM,且报错点出现在 attn_mask.bool() 这一类 cast 操作上。由于该位置本身不是大模型参数加载或长期缓存路径,而是一个中间张量类型转换点,因此更像是一次临时 workspace 申请失败。

结合 bs=2 的 profiling 结果和 bs=3 的 OOM 日志,可以初步判断:

  • 问题不是模型权重或静态显存本身爆掉;

  • 问题也不像是 CANN 组件独占了大量显存;

  • 更可能是某个算子在特定 shape 下申请了极大的 workspace。

排除非 PTA 占用

为了确认异常占用是否来自 PTA 之外的组件,继续查看 npu_module_mem。结果显示,CANN 组件占用不到 4G,整体并不异常。

接着查看算子申请表,也没有发现明显“未统计进 PTA”的大额占用。因此可以进一步缩小范围:问题更可能出在某个算子的临时申请,而不是外部组件长期占用。

关注 Workspace 统计是否缺失

这里还有一个关键线索:当 task_queue_enable=2 时,profiling 中的 PTA reserved / allocated 不一定覆盖所有 workspace 行为,而应该有单独的 workspace 统计项。若图中看不到明显的 Workspace 占用,就需要怀疑某个算子内部发生了非常大的临时申请,但没有在常规视图中直观体现出来。

这一步的意义在于:不要只盯着整体 reserved/allocated 数字,还要结合具体算子和调用栈去找“谁在申请这块内存”。

通过算子排序和 timeline 定位异常申请

为了进一步确认问题,先将 task_queue_enable 调整为 1,并重新采集 bs=2 的 profiling 数据。这样可以更清楚地观察单个算子申请和峰值尖刺。

在缩放后的 profiling 图中,按算子申请大小从大到小排序后,可以发现一个异常的 cast 算子,其总内存申请达到 32.5G。随后跳转到 timeline 查看对应 shape,可以发现该申请量与输入 shape 高度相关。

根据计算公式:

2 * 24 * 9536 * 9536 * 8 / 1024 / 1024 / 1024 = 32.5G

当 batch size 提升到 3 时,这个值会进一步上升到约 48G,与 OOM 日志中的申请量基本一致。

这说明问题并不是随机波动,而是某个中间张量的 shape 在进入 cast 时已经被放大到了极大规模。

分析结论

结合 profiling 数据和 OOM 日志可以得到结论:前向过程中,attention mask 先执行了 expand,将 shape 扩大到一个非常大的维度,然后再执行 bool() 转换,导致 cast 阶段申请了巨大的 workspace。随着 batch size 增大,这个 workspace 按 shape 线性膨胀,最终在 bs=3 时触发 OOM。

换句话说,问题的本质不是模型参数过大,而是中间张量在错误的顺序下被放大,使得一个本应较轻量的类型转换操作变成了超大内存申请。

修复建议

修复思路很直接:调整操作顺序,先执行 bool(),再执行 expand()

原始逻辑:

attn_mask = attn_mask.expand(batch_size, self.num_attention_heads, total_seq_len, -1)
attn_mask = attn_mask.bool()

建议调整为:

attn_mask = attn_mask.bool()
attn_mask = attn_mask.expand(batch_size, self.num_attention_heads, total_seq_len, -1)

这样可以避免在超大 shape 上执行 cast,从而显著降低 workspace 占用。

小结

这类问题的定位关键在于:先通过 profiling 识别峰值异常,再结合算子排序和 timeline 找到大内存申请发生的位置,最后回到源码检查 shape 放大顺序。对于 batch size 增大后突然 OOM 的场景,优先排查中间张量是否在类型转换、reshape、expand、broadcast 等操作前被放大,往往能更快定位根因。