You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
import paddle
import numpy as np


# Function to save tensor as numpy array
def save_tensor_to_numpy(tensor_data, file_path):
    """Save a Paddle tensor to ``file_path`` as a float32 NumPy array.

    The tensor is cast to float32 before conversion because NumPy has no
    bfloat16 dtype, so a bfloat16 tensor cannot be exported directly.

    Args:
        tensor_data: Paddle tensor to export.
        file_path: Destination path, passed unchanged to ``np.save``.
    """
    # Cast to float32, detach from the computation graph, move to CPU,
    # and convert to a NumPy array.
    tensor_data_cpu = tensor_data.astype(paddle.float32).detach().cpu().numpy()
    # Save the NumPy array to a file.
    np.save(file_path, tensor_data_cpu)
np.random.seed(12333)
# Generate a random float64 matrix with NumPy; the fixed seed makes the
# Paddle and PyTorch scripts start from identical data.
A_np = np.random.normal(0.49, 0.0833, (1024, 6144))

# Path 1: build a float64 tensor, move it to the GPU, then cast to bfloat16.
A_tensor_1 = paddle.to_tensor(A_np).cuda()
A_tensor_1 = A_tensor_1.astype(paddle.bfloat16)
# Path 2: request bfloat16 directly in to_tensor, then move to the GPU.
# NOTE(review): this is the path reported to deviate. Presumably the
# float64 -> bfloat16 conversion here runs on the host with a different
# rounding rule than astype; confirm against the paddle.to_tensor source.
A_tensor_2 = paddle.to_tensor(A_np, dtype=paddle.bfloat16).cuda()

save_tensor_to_numpy(A_tensor_1, 'A_matrix_origin_1.npy')
save_tensor_to_numpy(A_tensor_2, 'A_matrix_origin_2.npy')
PyTorch 代码
import torch
import numpy as np


# Function to save tensor as numpy array
def save_tensor_to_numpy(tensor_data, file_path):
    """Save a PyTorch tensor to ``file_path`` as a float32 NumPy array.

    The tensor is cast to float32 before conversion because NumPy has no
    bfloat16 dtype, so a bfloat16 tensor cannot be exported directly.

    Args:
        tensor_data: PyTorch tensor to export.
        file_path: Destination path, passed unchanged to ``np.save``.
    """
    # Cast to float32, detach from the computation graph, move to CPU,
    # and convert to a NumPy array.
    tensor_data_cpu = tensor_data.to(torch.float32).detach().cpu().numpy()
    # Save the NumPy array to a file.
    np.save(file_path, tensor_data_cpu)
np.random.seed(12333)
# Generate a random float64 matrix with NumPy; the fixed seed makes the
# Paddle and PyTorch scripts start from identical data.
A_np = np.random.normal(0.49, 0.0833, (1024, 6144))

# Path 1: build a float64 tensor, move it to the GPU, then cast to bfloat16.
A_tensor_1 = torch.tensor(A_np).cuda()
A_tensor_1 = A_tensor_1.to(torch.bfloat16)
# Path 2: request bfloat16 directly in torch.tensor, then move to the GPU.
A_tensor_2 = torch.tensor(A_np, dtype=torch.bfloat16).cuda()

save_tensor_to_numpy(A_tensor_1, 'A_matrix_origin_1.npy')
save_tensor_to_numpy(A_tensor_2, 'A_matrix_origin_2.npy')
对比脚本
import numpy as np


def compare(torch_tensor: np.ndarray, paddle_tensor: np.ndarray) -> dict:
    """Summarize how closely two exported arrays agree.

    Args:
        torch_tensor: Array exported from PyTorch.
        paddle_tensor: Array exported from Paddle; must match
            ``torch_tensor`` in dtype and shape.

    Returns:
        A dict with the following keys:

        - ``"torch_mean"``, ``"paddle_mean"``: mean of each array.
        - ``"torch_variance"``, ``"paddle_variance"``: variance of each
          array (``np.var``, i.e. ddof=0).
        - ``"mae"``: mean absolute error.
        - ``"mae%"``: mean relative error in percent, taken only over
          positions where both values exceed 1e-6 in magnitude. It is NaN
          when no such position exists.

    Raises:
        AssertionError: if the dtypes or shapes differ.
    """
    # Ensure both tensors have the same dtype and shape
    assert torch_tensor.dtype == paddle_tensor.dtype, \
        f"Data type mismatch: torch_tensor dtype={torch_tensor.dtype}, paddle_tensor dtype= {paddle_tensor.dtype}"
    assert torch_tensor.shape == paddle_tensor.shape, \
        f"Shape mismatch: torch_tensor shape={torch_tensor.shape}, paddle_tensor shape= {paddle_tensor.shape}"

    # Calculate mean and variance for both tensors
    torch_mean = np.mean(torch_tensor)
    paddle_mean = np.mean(paddle_tensor)
    torch_variance = np.var(torch_tensor)
    paddle_variance = np.var(paddle_tensor)

    # Calculate Mean Absolute Error (MAE)
    mae = np.mean(np.abs(torch_tensor.squeeze() - paddle_tensor.squeeze()))

    # Only compare relative error where both values are clearly non-zero,
    # so near-zero denominators do not blow the percentage up.
    mask = (np.abs(torch_tensor) > 1e-6) & (np.abs(paddle_tensor) > 1e-6)
    if np.any(mask):
        # The 1e-8 term is an extra guard against division by zero.
        mae_percentage = np.mean(
            np.abs(torch_tensor[mask] - paddle_tensor[mask])
            / (np.abs(torch_tensor[mask]) + 1e-8)
        ) * 100
    else:
        # No eligible positions: report NaN explicitly instead of taking
        # the mean of an empty selection (which warns).
        mae_percentage = np.nan

    return {
        "torch_mean": torch_mean,
        "paddle_mean": paddle_mean,
        "torch_variance": torch_variance,
        "paddle_variance": paddle_variance,
        "mae": mae,
        "mae%": mae_percentage,
    }
之前怀疑是由于 Paddle 这头的张量的 Place 不同导致的表现不一致,但是排查之后 Place 应该一直是一致的,复现场景如下:
Paddle精度分析
import paddle
import numpy as np


def compare(tensor_1: paddle.Tensor, tensor_2: paddle.Tensor) -> dict:
    """Summarize how closely two Paddle tensors agree.

    Float16/bfloat16 inputs are promoted to float32 before any statistic is
    computed. Otherwise the reported means, variances and errors would be
    rounded to the inputs' low precision, masking the differences being
    measured. Other dtypes are used as-is.

    Args:
        tensor_1: First tensor.
        tensor_2: Second tensor; must match ``tensor_1`` in dtype, shape
            and place.

    Returns:
        A dict of 0-d Paddle tensors with the following keys:

        - ``"tensor_1_mean"``, ``"tensor_2_mean"``: mean of each tensor.
        - ``"tensor_1_variance"``, ``"tensor_2_variance"``: variance of
          each tensor (``paddle.var`` defaults).
        - ``"mae"``: mean absolute error.
        - ``"mae%"``: mean relative error in percent over positions where
          both values exceed 1e-6 in magnitude.

    Raises:
        AssertionError: if dtype, shape or place differ.
    """
    # Ensure both tensors have the same dtype, shape and place.
    assert tensor_1.dtype == tensor_2.dtype, \
        f"Data type mismatch: tensor_1 dtype={tensor_1.dtype}, tensor_2 dtype={tensor_2.dtype}"
    assert tensor_1.shape == tensor_2.shape, \
        f"Shape mismatch: tensor_1 shape={tensor_1.shape}, tensor_2 shape={tensor_2.shape}"
    assert tensor_1.place._equals(tensor_2.place), \
        f"Device mismatch: tensor_1 device={tensor_1.place}, tensor_2 device={tensor_2.place}"

    # Promote half-precision inputs so the metrics themselves keep float32
    # precision (the dtypes are equal at this point).
    if tensor_1.dtype in (paddle.float16, paddle.bfloat16):
        t1 = tensor_1.astype(paddle.float32)
        t2 = tensor_2.astype(paddle.float32)
    else:
        t1, t2 = tensor_1, tensor_2

    # Calculate mean and variance for both tensors
    tensor_1_mean = paddle.mean(t1)
    tensor_2_mean = paddle.mean(t2)
    tensor_1_variance = paddle.var(t1)
    tensor_2_variance = paddle.var(t2)

    # Calculate Mean Absolute Error (MAE)
    mae = paddle.mean(paddle.abs(t1.squeeze() - t2.squeeze()))

    # Only compare relative error where both values are clearly non-zero.
    mask = (paddle.abs(t1) > 1e-6) & (paddle.abs(t2) > 1e-6)
    # The 1e-8 term is an extra guard against division by zero.
    mae_percentage = paddle.mean(
        paddle.abs(t1[mask] - t2[mask]) / (paddle.abs(t1[mask]) + 1e-8)
    ) * 100

    return {
        "tensor_1_mean": tensor_1_mean,
        "tensor_2_mean": tensor_2_mean,
        "tensor_1_variance": tensor_1_variance,
        "tensor_2_variance": tensor_2_variance,
        "mae": mae,
        "mae%": mae_percentage,
    }
np.random.seed(12333)
# Generate a random float64 matrix with NumPy (fixed seed for reproducibility).
A_np = np.random.normal(0.49, 0.0833, (1024, 6144))

# Path 1: float64 tensor on the GPU, then cast to bfloat16 with astype.
A_tensor_1 = paddle.to_tensor(A_np).cuda()
A_tensor_1 = A_tensor_1.astype(paddle.bfloat16)
# Path 2: bfloat16 requested directly in to_tensor, then moved to the GPU.
# Both tensors end up on the same place, which compare() asserts.
A_tensor_2 = paddle.to_tensor(A_np, dtype=paddle.bfloat16).cuda()

print(compare(A_tensor_1, A_tensor_2))
import torch
import numpy as np


def compare(tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> dict:
    """Summarize how closely two PyTorch tensors agree.

    Float16/bfloat16 inputs are promoted to float32 before any statistic is
    computed. Otherwise the reported means, variances and errors would be
    rounded to the inputs' low precision, masking the differences being
    measured. Other dtypes are used as-is.

    Args:
        tensor_1: First tensor.
        tensor_2: Second tensor; must match ``tensor_1`` in dtype, shape
            and device.

    Returns:
        A dict of 0-d tensors with the following keys:

        - ``"torch_1_mean"``, ``"torch_2_mean"``: mean of each tensor.
        - ``"torch_1_variance"``, ``"torch_2_variance"``: variance of each
          tensor (``torch.var`` defaults).
        - ``"mae"``: mean absolute error.
        - ``"mae%"``: mean relative error in percent over positions where
          both values exceed 1e-6 in magnitude.

    Raises:
        AssertionError: if dtype, shape or device differ.
    """
    # Ensure both tensors have the same dtype, shape and device.
    assert tensor_1.dtype == tensor_2.dtype, \
        f"Data type mismatch: tensor_1 dtype={tensor_1.dtype}, tensor_2 dtype={tensor_2.dtype}"
    assert tensor_1.shape == tensor_2.shape, \
        f"Shape mismatch: tensor_1 shape={tensor_1.shape}, tensor_2 shape={tensor_2.shape}"
    assert tensor_1.device == tensor_2.device, \
        f"Device mismatch: tensor_1 device={tensor_1.device}, tensor_2 device={tensor_2.device}"

    # Promote half-precision inputs so the metrics themselves keep float32
    # precision (the dtypes are equal at this point).
    if tensor_1.dtype in (torch.float16, torch.bfloat16):
        stat_dtype = torch.float32
    else:
        stat_dtype = tensor_1.dtype
    t1 = tensor_1.to(stat_dtype)
    t2 = tensor_2.to(stat_dtype)

    # Calculate mean and variance for both tensors
    torch_1_mean = torch.mean(t1)
    torch_2_mean = torch.mean(t2)
    torch_1_variance = torch.var(t1)
    torch_2_variance = torch.var(t2)

    # Calculate Mean Absolute Error (MAE)
    mae = torch.mean(torch.abs(t1.squeeze() - t2.squeeze()))

    # Only compare relative error where both values are clearly non-zero.
    mask = (torch.abs(t1) > 1e-6) & (torch.abs(t2) > 1e-6)
    # The 1e-8 term is an extra guard against division by zero.
    mae_percentage = torch.mean(
        torch.abs(t1[mask] - t2[mask]) / (torch.abs(t1[mask]) + 1e-8)
    ) * 100

    return {
        "torch_1_mean": torch_1_mean,
        "torch_2_mean": torch_2_mean,
        "torch_1_variance": torch_1_variance,
        "torch_2_variance": torch_2_variance,
        "mae": mae,
        "mae%": mae_percentage,
    }
np.random.seed(12333)
# Generate a random float64 matrix with NumPy (fixed seed for reproducibility).
A_np = np.random.normal(0.49, 0.0833, (1024, 6144))

# Path 1: float64 tensor on the GPU, then cast to bfloat16.
A_tensor_1 = torch.tensor(A_np).cuda()
A_tensor_1 = A_tensor_1.to(torch.bfloat16)
# Path 2: bfloat16 requested directly in torch.tensor, then moved to the GPU.
A_tensor_2 = torch.tensor(A_np, dtype=torch.bfloat16).cuda()

print(compare(A_tensor_1, A_tensor_2))
问题背景
在进行 Paddle 和 PyTorch 的精度对齐的过程中,为了模拟某些计算过程,需要使用随机生成的张量进行实验,为了消除随机性,采用 numpy 充当中间件,指定 seed 后生成随机 array ,再将其分别转换为两个框架的 tensor ,在此过程中发现了一个 Paddle 和 PyTorch 行为不一致的现象,现场如下:
Paddle 代码
PyTorch 代码
对比脚本
实验结果
结论
上述代码中的 `A_np` 为 `float64` 类型的 `numpy.array`,使用两种不同的方式将其转为 `tensor`:第一种方式是先转为 `float64` 类型的 `tensor`,再将其转为目标数据类型 `bfloat16`;第二种方式是直接转为目标数据类型为 `bfloat16` 的 `tensor`。PyTorch 的两种方式结果保持一致,但 Paddle 的两种方式结果不一致,且对比分析表明应是第二种转换方式存在精度误差。

原因分析
之前怀疑是由于 Paddle 这头的张量的 Place 不同导致的表现不一致,但是排查之后 Place 应该一直是一致的,复现场景如下:
Paddle精度分析
输出为:
PyTorch 精度分析
输出为:
感觉自己对于 Paddle 的 API 还是不够熟悉,会继续看看相关的源码分析一下,如果有哪位大佬知道问题的原因务请指出。
The text was updated successfully, but these errors were encountered: