Skip to content

Paddle 的 to_tensor()方法和 torch 的 tensor()方法的行为不一致 #72484

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
YoctoHan opened this issue Apr 25, 2025 · 2 comments
Open
Assignees

Comments

@YoctoHan
Copy link

YoctoHan commented Apr 25, 2025

问题背景

在进行 Paddle 和 PyTorch 的精度对齐的过程中,为了模拟某些计算过程,需要使用随机生成的张量进行实验,为了消除随机性,采用 numpy 充当中间件,指定 seed 后生成随机 array ,再将其分别转换为两个框架的 tensor ,在此过程中发现了一个 Paddle 和 PyTorch 行为不一致的现象,现场如下:

Paddle 代码

import paddle
import numpy as np

# Function to save tensor as numpy array
def save_tensor_to_numpy(tensor_data, file_path):
    """Dump a Paddle tensor to ``file_path`` as a float32 ``.npy`` array.

    Args:
        tensor_data: Tensor to export; it is cast to ``paddle.float32`` first.
        file_path: Destination handed straight to ``np.save``.
    """
    # Cast first so the saved array is float32 whatever the tensor's dtype was.
    as_float32 = tensor_data.astype(paddle.float32)

    # Detach from the graph and copy to the CPU before converting to NumPy.
    host_array = as_float32.detach().cpu().numpy()

    np.save(file_path, host_array)


# Fix the NumPy seed so this script and its PyTorch counterpart draw the same array.
np.random.seed(12333)
# Draw a 1024 x 6144 array from a normal distribution (mean 0.49, std 0.0833).
A_np = np.random.normal(0.49, 0.0833, (1024, 6144))

# Path 1: convert without an explicit dtype, move to the GPU, then cast to bfloat16.
A_tensor_1 = paddle.to_tensor(A_np).cuda()
A_tensor_1 = A_tensor_1.astype(paddle.bfloat16)

# Path 2: request bfloat16 directly in to_tensor(), then move to the GPU.
A_tensor_2 = paddle.to_tensor(A_np, dtype=paddle.bfloat16).cuda()

# Persist both results (as float32 .npy files) for the offline comparison script.
save_tensor_to_numpy(A_tensor_1, 'A_matrix_origin_1.npy')
save_tensor_to_numpy(A_tensor_2, 'A_matrix_origin_2.npy')

PyTorch 代码

import torch
import numpy as np

# Function to save tensor as numpy array
def save_tensor_to_numpy(tensor_data, file_path):
    """Dump a PyTorch tensor to ``file_path`` as a float32 ``.npy`` array.

    Args:
        tensor_data: Tensor to export; it is converted to ``torch.float32`` first.
        file_path: Destination handed straight to ``np.save``.
    """
    # Convert first so the saved array is float32 whatever the tensor's dtype was.
    as_float32 = tensor_data.to(torch.float32)

    # Detach from the graph and copy to the CPU before converting to NumPy.
    host_array = as_float32.detach().cpu().numpy()

    np.save(file_path, host_array)


# Fix the NumPy seed so this script and its Paddle counterpart draw the same array.
np.random.seed(12333)
# Draw a 1024 x 6144 array from a normal distribution (mean 0.49, std 0.0833).
A_np = np.random.normal(0.49, 0.0833, (1024, 6144))

# Path 1: convert without an explicit dtype, move to the GPU, then cast to bfloat16.
A_tensor_1 = torch.tensor(A_np).cuda()
A_tensor_1 = A_tensor_1.to(torch.bfloat16)

# Path 2: request bfloat16 directly in torch.tensor(), then move to the GPU.
A_tensor_2 = torch.tensor(A_np, dtype=torch.bfloat16).cuda()

# Persist both results (as float32 .npy files) for the offline comparison script.
save_tensor_to_numpy(A_tensor_1, 'A_matrix_origin_1.npy')
save_tensor_to_numpy(A_tensor_2, 'A_matrix_origin_2.npy')

对比脚本

def compare(torch_tensor: np.ndarray, paddle_tensor: np.ndarray) -> dict:
    """Summarise how closely two NumPy arrays agree.

    Args:
        torch_tensor: Array exported from the PyTorch run; used as the
            reference for the relative error.
        paddle_tensor: Array exported from the Paddle run.

    Returns:
        A dict with the keys ``torch_mean``, ``paddle_mean``,
        ``torch_variance``, ``paddle_variance``, ``mae`` (mean absolute
        error) and ``mae%`` (mean relative error in percent, taken only
        where both arrays exceed 1e-6 in magnitude).

    Raises:
        AssertionError: If the two arrays differ in dtype or in shape.
    """
    # The statistics are only comparable when dtype and shape line up.
    assert torch_tensor.dtype == paddle_tensor.dtype, \
        f"Data type mismatch: torch_tensor dtype={torch_tensor.dtype}, paddle_tensor dtype= {paddle_tensor.dtype}"
    assert torch_tensor.shape == paddle_tensor.shape, \
        f"Shape mismatch: torch_tensor shape={torch_tensor.shape}, paddle_tensor shape= {paddle_tensor.shape}"

    # Per-array summary statistics.
    report = {
        "torch_mean": np.mean(torch_tensor),
        "paddle_mean": np.mean(paddle_tensor),
        "torch_variance": np.var(torch_tensor),
        "paddle_variance": np.var(paddle_tensor),
    }

    # Mean absolute error over every element.
    squeezed_gap = torch_tensor.squeeze() - paddle_tensor.squeeze()
    report["mae"] = np.mean(np.abs(squeezed_gap))

    # Relative error only where both magnitudes exceed 1e-6; the extra 1e-8
    # in the denominator is a second guard against dividing by zero.
    significant = (np.abs(torch_tensor) > 1e-6) & (np.abs(paddle_tensor) > 1e-6)
    torch_kept = torch_tensor[significant]
    paddle_kept = paddle_tensor[significant]
    relative_error = np.abs(torch_kept - paddle_kept) / (np.abs(torch_kept) + 1e-8)
    report["mae%"] = np.mean(relative_error) * 100

    return report

实验结果

| Matrix | Torch Mean | Paddle Mean | Torch Variance | Paddle Variance | MAE | MAE% |
| --- | --- | --- | --- | --- | --- | --- |
| A_matrix_origin_1 | 0.49004754 | 0.49004754 | 0.0069367127 | 0.0069367127 | 0.0 | 0.0 |
| A_matrix_origin_2 | 0.49004754 | 0.4886367 | 0.0069367127 | 0.006872112 | 0.001410692 | 0.282440148293972 |

结论

上述代码中的 A_np 是 float64 类型的 numpy.ndarray,使用两种不同的方式将其转为 tensor:第一种方式是先转为 float64 类型的 tensor,再将其转为目标数据类型 bfloat16;第二种方式是直接转为目标数据类型 bfloat16 的 tensor。PyTorch 的两种方式结果保持一致,但是 Paddle 的两种方式结果不一致,且对比分析表明应该是第二种转换方式存在精度误差。

原因分析

之前怀疑是由于 Paddle 这头的张量的 Place 不同导致的表现不一致,但是排查之后 Place 应该一直是一致的,复现场景如下:

Paddle精度分析

import paddle
import numpy as np

def compare(tensor_1: paddle.Tensor, tensor_2: paddle.Tensor) -> dict:
    """Summarise how closely two Paddle tensors agree.

    No dtype conversion happens here, so every reduction runs on the
    tensors exactly as they are passed in.

    Args:
        tensor_1: Reference tensor; its magnitude is the denominator of the
            relative error.
        tensor_2: Tensor compared against ``tensor_1``.

    Returns:
        A dict with the keys ``tensor_1_mean``, ``tensor_2_mean``,
        ``tensor_1_variance``, ``tensor_2_variance``, ``mae`` (mean absolute
        error) and ``mae%`` (mean relative error in percent, taken only
        where both tensors exceed 1e-6 in magnitude).

    Raises:
        AssertionError: If the tensors differ in dtype, shape or place.
    """
    # Ensure both tensors have the same dtype, shape and place. The messages
    # name the operands after the parameters (both are Paddle tensors) rather
    # than the torch/paddle labels copied from the NumPy comparison script.
    assert tensor_1.dtype == tensor_2.dtype, \
        f"Data type mismatch: tensor_1 dtype={tensor_1.dtype}, tensor_2 dtype= {tensor_2.dtype}"
    assert tensor_1.shape == tensor_2.shape, \
        f"Shape mismatch: tensor_1 shape={tensor_1.shape}, tensor_2 shape= {tensor_2.shape}"
    assert tensor_1.place._equals(tensor_2.place), \
        f"Device mismatch: tensor_1 device={tensor_1.place}, tensor_2 device= {tensor_2.place}"

    # Calculate mean and variance for both tensors
    tensor_1_mean = paddle.mean(tensor_1)
    tensor_2_mean = paddle.mean(tensor_2)
    tensor_1_variance = paddle.var(tensor_1)
    tensor_2_variance = paddle.var(tensor_2)

    # Calculate Mean Absolute Error (MAE)
    mae = paddle.mean(paddle.abs(tensor_1.squeeze() - tensor_2.squeeze()))
    # Create a mask for values where both tensors have values greater than 1e-6
    mask = (paddle.abs(tensor_1) > 1e-6) & (paddle.abs(tensor_2) > 1e-6)
    # Calculate MAE percentage for positions where both values are greater than 1e-6
    # Avoid division by zero with the mask and 1e-8 for safety
    mae_percentage = paddle.mean(paddle.abs(tensor_1[mask] - tensor_2[mask]) / (paddle.abs(tensor_1[mask]) + 1e-8)) * 100

    return {
        "tensor_1_mean": tensor_1_mean,
        "tensor_2_mean": tensor_2_mean,
        "tensor_1_variance": tensor_1_variance,
        "tensor_2_variance": tensor_2_variance,
        "mae": mae,
        "mae%": mae_percentage,
    }

# Fix the NumPy seed so this script and its PyTorch counterpart draw the same array.
np.random.seed(12333)
# Draw a 1024 x 6144 array from a normal distribution (mean 0.49, std 0.0833).
A_np = np.random.normal(0.49, 0.0833, (1024, 6144))

# Path 1: convert without an explicit dtype, move to the GPU, then cast to bfloat16.
A_tensor_1 = paddle.to_tensor(A_np).cuda()
A_tensor_1 = A_tensor_1.astype(paddle.bfloat16)

# Path 2: request bfloat16 directly in to_tensor(), then move to the GPU.
A_tensor_2 = paddle.to_tensor(A_np, dtype=paddle.bfloat16).cuda()

# Compare the two conversion paths directly, without going through .npy files.
print(compare(A_tensor_1, A_tensor_2))

输出为:

{
    'tensor_1_mean': Tensor(shape=[], dtype=bfloat16, place=Place(gpu:0), stop_gradient=True, 0.49023438), 
    'tensor_2_mean': Tensor(shape=[], dtype=bfloat16, place=Place(gpu:0), stop_gradient=True, 0.48828125), 
    'tensor_1_variance': Tensor(shape=[], dtype=bfloat16, place=Place(gpu:0), stop_gradient=True, 0.00692749), 
    'tensor_2_variance': Tensor(shape=[], dtype=bfloat16, place=Place(gpu:0), stop_gradient=True, 0.00686646), 
    'mae': Tensor(shape=[], dtype=bfloat16, place=Place(gpu:0), stop_gradient=True, 0.00141144), 
    'mae%': Tensor(shape=[], dtype=bfloat16, place=Place(gpu:0), stop_gradient=True, 0.28320312)
}

PyTorch 精度分析

import torch
import numpy as np

def compare(tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> dict:
    """Summarise how closely two PyTorch tensors agree.

    No dtype conversion happens here, so every reduction runs on the
    tensors exactly as they are passed in.

    Args:
        tensor_1: Reference tensor; its magnitude is the denominator of the
            relative error.
        tensor_2: Tensor compared against ``tensor_1``.

    Returns:
        A dict with the keys ``torch_1_mean``, ``torch_2_mean``,
        ``torch_1_variance``, ``torch_2_variance``, ``mae`` (mean absolute
        error) and ``mae%`` (mean relative error in percent, taken only
        where both tensors exceed 1e-6 in magnitude).

    Raises:
        AssertionError: If the tensors differ in dtype, shape or device.
    """
    # Ensure both tensors have the same dtype, shape and device. The messages
    # name the operands after the parameters (both are torch tensors) rather
    # than the torch/paddle labels copied from the NumPy comparison script.
    assert tensor_1.dtype == tensor_2.dtype, \
        f"Data type mismatch: tensor_1 dtype={tensor_1.dtype}, tensor_2 dtype= {tensor_2.dtype}"
    assert tensor_1.shape == tensor_2.shape, \
        f"Shape mismatch: tensor_1 shape={tensor_1.shape}, tensor_2 shape= {tensor_2.shape}"
    assert tensor_1.device == tensor_2.device, \
        f"Device mismatch: tensor_1 device={tensor_1.device}, tensor_2 device= {tensor_2.device}"

    # Calculate mean and variance for both tensors
    torch_1_mean = torch.mean(tensor_1)
    torch_2_mean = torch.mean(tensor_2)
    torch_1_variance = torch.var(tensor_1)
    torch_2_variance = torch.var(tensor_2)

    # Calculate Mean Absolute Error (MAE)
    mae = torch.mean(torch.abs(tensor_1.squeeze() - tensor_2.squeeze()))
    # Create a mask for values where both tensors have values greater than 1e-6
    mask = (torch.abs(tensor_1) > 1e-6) & (torch.abs(tensor_2) > 1e-6)
    # Calculate MAE percentage for positions where both values are greater than 1e-6
    # Avoid division by zero with the mask and 1e-8 for safety
    mae_percentage = torch.mean(torch.abs(tensor_1[mask] - tensor_2[mask]) / (torch.abs(tensor_1[mask]) + 1e-8)) * 100

    return {
        "torch_1_mean": torch_1_mean,
        "torch_2_mean": torch_2_mean,
        "torch_1_variance": torch_1_variance,
        "torch_2_variance": torch_2_variance,
        "mae": mae,
        "mae%": mae_percentage,
    }

# Fix the NumPy seed so this script and its Paddle counterpart draw the same array.
np.random.seed(12333)
# Draw a 1024 x 6144 array from a normal distribution (mean 0.49, std 0.0833).
A_np = np.random.normal(0.49, 0.0833, (1024, 6144))

# Path 1: convert without an explicit dtype, move to the GPU, then cast to bfloat16.
A_tensor_1 = torch.tensor(A_np).cuda()
A_tensor_1 = A_tensor_1.to(torch.bfloat16)

# Path 2: request bfloat16 directly in torch.tensor(), then move to the GPU.
A_tensor_2 = torch.tensor(A_np, dtype=torch.bfloat16).cuda()

# Compare the two conversion paths directly, without going through .npy files.
print(compare(A_tensor_1, A_tensor_2))

输出为:

{
    'torch_1_mean': tensor(0.4902, device='cuda:0', dtype=torch.bfloat16), 
    'torch_2_mean': tensor(0.4902, device='cuda:0', dtype=torch.bfloat16), 
    'torch_1_variance': tensor(0.0069, device='cuda:0', dtype=torch.bfloat16), 
    'torch_2_variance': tensor(0.0069, device='cuda:0', dtype=torch.bfloat16), 
    'mae': tensor(0., device='cuda:0', dtype=torch.bfloat16), 
    'mae%': tensor(0., device='cuda:0', dtype=torch.bfloat16)
}

感觉自己对于 Paddle 的 API 还是不够熟悉,会继续看看相关的源码分析一下,如果有哪位大佬知道问题的原因务请指出。

@ZhangX-21
Copy link
Contributor

@zhwesky2010 可以帮忙看看吗?

@zhwesky2010
Copy link
Contributor

zhwesky2010 commented Apr 25, 2025

@YoctoHan paddle.to_tensor(dtype='bfloat16') 实现有Bug,其他三种情况结果是正确的。
我们后续会修复这个问题,已记录。

@paddle-bot paddle-bot bot removed the type/question 用户提问 label Apr 25, 2025
@ZhangX-21 ZhangX-21 removed the Bug label Apr 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

3 participants