Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Doc] Remove the limit on kernel argument size #8408

Merged
merged 7 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions docs/lang/articles/kernels/kernel_function.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,16 @@ You can also use argument packs if you want to pass many arguments to a kernel.

When defining the arguments of a kernel in Taichi, please make sure that each of the arguments has type hint.

:::caution WARNING

We have removed the limit on the size of the argument in Taichi v1.7.0.
However, please keep in mind that the size of arguments in a kernel should be small.
When you pass a large argument to a kernel, the compile time will increase significantly.
If you find yourself passing a large argument to a kernel, you may want to consider using a `ti.field()` or a `ti.types.ndarray()` instead.

We have not tested arguments with a very large size (>4KB), and we do not guarantee that it will work properly.
:::

### Return value

In Taichi, a kernel can have multiple return values, and the return values can either be a scalar, `ti.types.matrix()`, or `ti.types.vector()`.
Expand Down Expand Up @@ -188,6 +198,16 @@ When defining the return value of a kernel in Taichi, it is important to follow
- Use type hint to specify the return value of a kernel.
- Make sure that you have at most one return statement in a kernel.

:::caution WARNING

We have removed the limit on the size of the return values in Taichi v1.7.0.
However, please keep in mind that the size of return values in a kernel should be small.
When the return value of the kernel is very large, the compile time will increase significantly.
If you find your return value is very large, you may want to consider using a `ti.field()` or a `ti.types.ndarray()` instead.

We have not tested return values with a very large size (>4KB), and we do not guarantee that it will work properly.
:::

#### Automatic type cast

In the following code snippet, the return value is automatically cast into the hinted type:
Expand Down Expand Up @@ -264,11 +284,8 @@ All Taichi inline functions are force-inlined. This means that if you call a Tai

### Arguments

A Taichi inline function can accept multiple arguments, which may include scalar, `ti.types.matrix()`, `ti.types.vector()`, `ti.types.struct()`, `ti.types.ndarray()`, `ti.field()`, and `ti.template()` types. Note that some of the restrictions on kernel arguments do not apply to Taichi functions:

- It is not strictly required to type hint the function arguments (but it is still recommended).
- You can pass an unlimited number of elements in the function arguments.

A Taichi inline function can accept multiple arguments, which may include scalar, `ti.types.matrix()`, `ti.types.vector()`, `ti.types.struct()`, `ti.types.ndarray()`, `ti.field()`, and `ti.template()` types.
Note that unlike Taichi kernels, it is not strictly required to type hint the function arguments (but it is still recommended).

### Return values

Expand All @@ -281,10 +298,8 @@ Return values of a Taichi inline function can be scalars, `ti.types.matrix()`, `

### Arguments

A Taichi real function can accept multiple arguments, which may include scalar, `ti.types.matrix()`, `ti.types.vector()`, `ti.types.struct()`, `ti.types.ndarray()`, `ti.field()`, and `ti.template()` types. Note the following:

- You must type hint the function arguments.
- You can pass an unlimited number of elements in the function arguments.
A Taichi real function can accept multiple arguments, which may include scalar, `ti.types.matrix()`, `ti.types.vector()`, `ti.types.struct()`, `ti.types.ndarray()`, `ti.field()`, and `ti.template()` types.
Note that you must type hint the function arguments.

### Return values

Expand All @@ -302,7 +317,6 @@ Return values of a Taichi inline function can be scalars, `ti.types.matrix()`, `
| Type hint arguments | Mandatory | Recommended | Mandatory |
| Type hint return values | Mandatory | Recommended | Mandatory |
| Return type | <ul><li>Scalar</li><li>`ti.types.matrix()`</li><li>`ti.types.vector()`</li><li>`ti.types.struct()`(Only on LLVM-based backends)</li></ul> | <ul><li>Scalar</li><li>`ti.types.matrix()`</li><li>`ti.types.vector()`</li><li>`ti.types.struct()`</li><li>...</li></ul> | <ul><li>Scalar</li><li>`ti.types.matrix()`</li><li>`ti.types.vector()`</li><li>`ti.types.struct()`</li><li>...</li></ul> |
| Maximum number of elements in arguments | <ul><li>Unlimited (CPU and CUDA)</li><li>32 (OpenGL)</li><li>64 (otherwise)</li></ul> | Unlimited | Unlimited |
| Maximum number of return statements | 1 | 1 | Unlimited |


Expand Down
16 changes: 16 additions & 0 deletions tests/python/test_argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,19 @@ def foo(a: ti.f32) -> ti.f32:
return bar(a)

assert foo(1.5) == 1.0


@test_utils.test(exclude=[ti.amdgpu])
def test_arg_4k():
vec1024 = ti.types.vector(1024, ti.i32)

@ti.kernel
def bar(a: vec1024) -> ti.i32:
ret = 0
for i in range(1024):
ret += a[i]

return ret

a = vec1024([i for i in range(1024)])
assert bar(a) == 523776
16 changes: 16 additions & 0 deletions tests/python/test_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,19 @@ def foo() -> tp:
return bar()

assert foo().a == 0


@test_utils.test(exclude=[ti.amdgpu])
def test_ret_4k():
vec1024 = ti.types.vector(1024, ti.i32)

@ti.kernel
def foo() -> vec1024:
ret = vec1024(0)
for i in range(1024):
ret[i] = i
return ret

ret = foo()
for i in range(1024):
assert ret[i] == i
Loading