fused_rms_norm causes a nullptr error when training on H-series GPUs #72452

Open
westfish opened this issue Apr 24, 2025 · 1 comment

@westfish
Contributor

Describe the Bug

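The model contains an RMSNorm layer implemented with fused_rms_norm. When training on H-series GPUs, loss.backward() fails with a nullptr error. With a hand-written (unfused) RMSNorm, training runs normally.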

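```python
import os

import paddle
import paddle.nn as nn


def str2bool(v: str) -> bool:
    # Stand-in for the str2bool helper used by the training script (its definition
    # is not shown in this issue); treats common "true" spellings as True.
    return v.strip().lower() in ("yes", "true", "t", "y", "1")


class RMSNorm(nn.Layer):
    def __init__(self, dim, epsilon: float, elementwise_affine: bool = True, bias: bool = False):
        super().__init__()
        self.epsilon = epsilon
        self.elementwise_affine = elementwise_affine

        self.dim = dim

        if elementwise_affine:
            # Learnable scale (initialized to 1) and optional shift (initialized to 0).
            self.weight = self.create_parameter(
                shape=[dim], default_initializer=nn.initializer.Constant(1.0)
            )
            if bias:
                self.bias = self.create_parameter(
                    shape=[dim], default_initializer=nn.initializer.Constant(0.0)
                )
            else:
                self.bias = None
        else:
            self.weight, self.bias = None, None

    def forward(self, hidden_states: paddle.Tensor, begin_norm_axis=None) -> paddle.Tensor:
        # The hand-written path runs when FLAGS_use_fused_rmsnorm is truthy; with the
        # default ("no"), the fused_rms_norm kernel in the else-branch is used.
        if str2bool(os.getenv("FLAGS_use_fused_rmsnorm", "no")):
            # Hand-written RMSNorm: the mean of squares is computed in float32.
            x_dtype = hidden_states.dtype
            variance = paddle.mean(paddle.pow(hidden_states.astype("float32"), 2), axis=-1, keepdim=True)
            hidden_states = hidden_states * paddle.rsqrt(variance + self.epsilon)

            if self.weight is not None:
                if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
                    hidden_states = paddle.cast(hidden_states, self.weight.dtype)
                hidden_states = hidden_states * self.weight
            if self.bias is not None:
                hidden_states = hidden_states + self.bias

            if not self.elementwise_affine and x_dtype in [paddle.float16, paddle.bfloat16]:
                hidden_states = paddle.cast(hidden_states, x_dtype)

            return hidden_states
        else:
            # Fused kernel, normalizing over the last axis by default. Its backward
            # (RmsNormGradNode) is where the nullptr error below is raised.
            # Note that self.bias is not passed to the kernel (norm_bias=None).
            return paddle.incubate.nn.functional.fused_rms_norm(
                x=hidden_states,
                norm_weight=self.weight,
                norm_bias=None,
                epsilon=self.epsilon,
                begin_norm_axis=len(hidden_states.shape) - 1 if begin_norm_axis is None else begin_norm_axis,
            )[0]
```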
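
Because the flag name suggests the opposite of what the condition does, the branch selection can be checked in isolation. A minimal sketch, assuming `str2bool` accepts common truthy spellings (the real helper is not shown in the issue); `uses_fused_path` is a hypothetical name introduced only for this check:

```python
import os


def str2bool(value: str) -> bool:
    # Hypothetical stand-in: assumes these spellings mean True, everything else False.
    return value.strip().lower() in ("yes", "true", "t", "y", "1")


def uses_fused_path() -> bool:
    # Mirrors RMSNorm.forward: fused_rms_norm runs in the else-branch,
    # i.e. when FLAGS_use_fused_rmsnorm is NOT truthy (default "no").
    return not str2bool(os.getenv("FLAGS_use_fused_rmsnorm", "no"))
```

With the variable unset, `uses_fused_path()` is `True`, so the failing fused kernel is what runs by default; setting `FLAGS_use_fused_rmsnorm=1` selects the hand-written path that trains normally.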

```
Steps:   0%|          | 0/500 [00:00<?, ?it/s]
W0423 20:00:01.601589  2777 backward.cc:437] While running Node (RmsNormGradNode) raises an EnforceNotMet exception
DEBUG loss: Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=False,
       0.57071793) <class 'paddle.Tensor'>
Has trainable params: True
Traceback (most recent call last):
  File "/root/paddlejob/workspace/env_run/zhangxu/flux/PaddleMIX-westfish_flux_lora/ppdiffusers/examples/dreambooth/train_dreambooth_lora_flux.py", line 1790, in <module>
    main(args)
  File "/root/paddlejob/workspace/env_run/zhangxu/flux/PaddleMIX-westfish_flux_lora/ppdiffusers/examples/dreambooth/train_dreambooth_lora_flux.py", line 1641, in main
    loss.backward()
  File "/root/miniconda3/envs/zx2/lib/python3.10/site-packages/decorator.py", line 235, in fun
    return caller(func, *(extras + args), **kw)
  File "/root/miniconda3/envs/zx2/lib/python3.10/site-packages/paddle/base/wrapped_decorator.py", line 40, in __impl__
    return wrapped_func(*args, **kwargs)
  File "/root/miniconda3/envs/zx2/lib/python3.10/site-packages/paddle/base/framework.py", line 704, in __impl__
    return func(*args, **kwargs)
  File "/root/miniconda3/envs/zx2/lib/python3.10/site-packages/paddle/base/dygraph/tensor_patch_methods.py", line 357, in backward
    core.eager.run_backward([self], grad_tensor, retain_graph)
ValueError: (InvalidArgument) Required tensor shall not be nullptr, but received nullptr.
  [Hint: tensor should not be null.] (at ../paddle/phi/core/device_context.cc:141)
```
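
As a numerical cross-check for the failing RmsNormGradNode, the gradient that a single-row RMSNorm over the last axis should produce can be written out and compared against finite differences. This is a plain-Python sketch of the math only, assuming y = w * x / sqrt(mean(x^2) + eps) without bias (matching the hand-written path); it is not Paddle's kernel and does not explain the nullptr. The function names are hypothetical.

```python
import math
from typing import List, Tuple


def _rms(x: List[float], eps: float) -> float:
    # Root mean square of one row, with epsilon added inside the square root.
    return math.sqrt(sum(v * v for v in x) / len(x) + eps)


def rms_norm(x: List[float], w: List[float], eps: float) -> List[float]:
    r = _rms(x, eps)
    return [xi / r * wi for xi, wi in zip(x, w)]


def rms_norm_backward(
    x: List[float], w: List[float], dy: List[float], eps: float
) -> Tuple[List[float], List[float]]:
    n = len(x)
    r = _rms(x, eps)
    g = [dyi * wi for dyi, wi in zip(dy, w)]  # upstream grad scaled by the weight
    dot = sum(gi * xi for gi, xi in zip(g, x))
    # dx_j = g_j / r - x_j * sum_i(g_i * x_i) / (n * r^3)
    dx = [gi / r - xi * dot / (n * r ** 3) for gi, xi in zip(g, x)]
    # dw_i = dy_i * x_i / r
    dw = [dyi * xi / r for dyi, xi in zip(dy, x)]
    return dx, dw


def numeric_dx(x, w, dy, eps, h=1e-6):
    # Central finite differences of L = sum(dy * y) with respect to each x_j.
    def loss(xs):
        return sum(d * y for d, y in zip(dy, rms_norm(xs, w, eps)))

    grads = []
    for j in range(len(x)):
        xp, xm = list(x), list(x)
        xp[j] += h
        xm[j] -= h
        grads.append((loss(xp) - loss(xm)) / (2 * h))
    return grads
```

If the analytic and numeric `dx` agree for a given row, the same inputs can be fed to the hand-written Paddle layer to confirm its gradients independently of the fused kernel.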


Additional Supplementary Information

No response

@liuruyan
Contributor

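Hello, we will have someone follow up on this issue. Could you provide the following reproduction information?

  1. Local environment (Windows/Linux/Mac)
  2. Whether Paddle was installed and deployed with Docker
  3. Python environment and version, CUDA version
  4. Name of the failing model and steps to reproduce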
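
The requested details can be gathered with a short script. A minimal sketch: `collect_env_info` is a hypothetical helper, the `/.dockerenv` check is only a heuristic for running inside Docker, and the Paddle fields are filled in only if `paddle` imports. Paddle's own `paddle.utils.run_check()` can additionally confirm that the GPU install works.

```python
import os
import platform


def collect_env_info() -> dict:
    info = {
        "os": f"{platform.system()} {platform.release()}",
        "python": platform.python_version(),
        # Heuristic only: Docker usually creates this file in the container root.
        "in_docker": os.path.exists("/.dockerenv"),
        "paddle": None,
    }
    try:
        import paddle
    except ImportError:
        return info  # Paddle not installed; report OS/Python only
    info["paddle"] = paddle.__version__
    info["cuda"] = paddle.version.cuda()
    info["cudnn"] = paddle.version.cudnn()
    if paddle.device.is_compiled_with_cuda() and paddle.device.cuda.device_count() > 0:
        # Name of GPU 0, e.g. to confirm which H-series card is in use.
        info["gpu"] = paddle.device.cuda.get_device_name(0)
    return info


if __name__ == "__main__":
    for key, value in collect_env_info().items():
        print(f"{key}: {value}")
```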
