Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix fp16 dtype checking for argmin op #1

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from

Conversation

hua-zi
Copy link
Owner

@hua-zi hua-zi commented Feb 23, 2023

PR types

Bug fixes

PR changes

APIs

Describe

问题描述:在静态图模式下,输入为FP16类型时,argmin会报TypeError。

import paddle
import numpy as np

paddle.enable_static()

x_np = np.random.random((10, 16)).astype('float16')
x = paddle.static.data(shape=[10, 16], name='x', dtype='float16')
out = paddle.argmin(x)

exe = paddle.static.Executor()
exe.run(paddle.static.default_startup_program())
out = exe.run(feed={'x': x_np},
            fetch_list=[out])

报错:

Traceback (most recent call last):
  File ".\test.py", line 11, in <module>
    out = paddle.argmin(x)
  File "c:\app\anaconda3\envs\pytorch\lib\site-packages\paddle\tensor\search.py", line 271, in argmin
    check_variable_and_dtype(
  File "c:\app\anaconda3\envs\pytorch\lib\site-packages\paddle\fluid\data_feeder.py", line 86, in check_variable_and_dtype  
    check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message)
  File "c:\app\anaconda3\envs\pytorch\lib\site-packages\paddle\fluid\data_feeder.py", line 147, in check_dtype
    raise TypeError(
TypeError: The data type of 'x' in paddle.argmin must be ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], but received float16.

修复方案:在argmin API静态图模式下的类型检查中增加fp16支持

# PR types
Bug fixes

# PR changes
APIs

# Describe
# 问题描述:在静态图模式下,输入为FP16类型时,argmin会报TypeError。
```
import paddle
import numpy as np

paddle.enable_static()

x_np = np.random.random((10, 16)).astype('float16')
x = paddle.static.data(shape=[10, 16], name='x', dtype='float16')
out = paddle.argmin(x)

exe = paddle.static.Executor()
exe.run(paddle.static.default_startup_program())
out = exe.run(feed={'x': x_np},
            fetch_list=[out])
```
报错:
```
Traceback (most recent call last):
  File ".\test.py", line 11, in <module>
    out = paddle.argmin(x)
  File "c:\app\anaconda3\envs\pytorch\lib\site-packages\paddle\tensor\search.py", line 271, in argmin
    check_variable_and_dtype(
  File "c:\app\anaconda3\envs\pytorch\lib\site-packages\paddle\fluid\data_feeder.py", line 86, in check_variable_and_dtype  
    check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message)
  File "c:\app\anaconda3\envs\pytorch\lib\site-packages\paddle\fluid\data_feeder.py", line 147, in check_dtype
    raise TypeError(
TypeError: The data type of 'x' in paddle.argmin must be ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], but received float16.
```
修复方案:在argmin API静态图模式下的类型检查中增加fp16支持
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant