Skip to content

Commit

Permalink
ci: add python codegen test (#96)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevmo314 authored Jan 26, 2025
1 parent cd58f7c commit 03dc72f
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 65 deletions.
25 changes: 25 additions & 0 deletions .github/workflows/codegen.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
name: Codegen Test

on:
pull_request:
branches: [main]

jobs:
codegen:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']

steps:
- name: Checkout code
uses: actions/checkout@v4

- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'pip' # caching pip dependencies
- working-directory: codegen
run: pip install -r requirements.txt && python codegen.py

2 changes: 1 addition & 1 deletion .github/workflows/pr-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:

steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Set up Tailscale
uses: tailscale/github-action@v2
Expand Down
150 changes: 86 additions & 64 deletions codegen/codegen.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from cxxheaderparser.simple import parse_file, ParsedData, ParserOptions
from cxxheaderparser.preprocessor import make_gcc_preprocessor
from cxxheaderparser.types import Type, Pointer, Parameter, Function, Array
from typing import Optional
from typing import Optional, Union
from dataclasses import dataclass
import os
import glob
Expand Down Expand Up @@ -112,7 +112,11 @@ def client_rpc_write(self, f):
)

def client_unified_copy(self, f, direction, error):
f.write(" if (maybe_copy_unified_arg(0, (void*){name}, cudaMemcpyDeviceToHost) < 0)\n".format(name=self.parameter.name))
f.write(
" if (maybe_copy_unified_arg(0, (void*){name}, cudaMemcpyDeviceToHost) < 0)\n".format(
name=self.parameter.name
)
)
f.write(" return {error};\n".format(error=error))

@property
Expand Down Expand Up @@ -195,9 +199,8 @@ class ArrayOperation:
recv: bool
parameter: Parameter
ptr: Pointer
length: (
int | Parameter
) # if int, it's a constant length, if Parameter, it's a variable length.
# if int, it's a constant length, if Parameter, it's a variable length.
length: Union[int, Parameter]

def client_rpc_write(self, f):
if not self.send:
Expand Down Expand Up @@ -305,7 +308,7 @@ def server_declaration(self) -> str:
s = f" {self.ptr.format()} {self.parameter.name};\n"
self.ptr.ptr_to.const = c
return s

def server_rpc_read(self, f, index) -> Optional[str]:
if not self.send:
# if this parameter is recv only and it's a type pointer, it needs to be malloc'd.
Expand All @@ -317,7 +320,9 @@ def server_rpc_read(self, f, index) -> Optional[str]:
param_name=self.parameter.name,
param_type=self.ptr.ptr_to.format(),
server_type=self.ptr.format(),
length=self.length if isinstance(self.length, int) else self.length.name,
length=self.length
if isinstance(self.length, int)
else self.length.name,
)
)
f.write(" if(")
Expand All @@ -326,7 +331,7 @@ def server_rpc_read(self, f, index) -> Optional[str]:
elif isinstance(self.length, int):
f.write(
" rpc_read(conn, &{param_name}, {size}) < 0 ||\n".format(
param_name=self.parameter.name,
param_name=self.parameter.name,
size=self.length,
)
)
Expand Down Expand Up @@ -502,7 +507,7 @@ class OpaqueTypeOperation:
send: bool
recv: bool
parameter: Parameter
type_: Type | Pointer
type_: Union[Type, Pointer]

def client_rpc_write(self, f):
if not self.send:
Expand All @@ -523,8 +528,8 @@ def server_declaration(self) -> str:
# but "const cudnnTensorDescriptor_t *xDesc" IS valid. This subtle change carries reprecussions.
elif (
"const " in self.type_.format()
and not "void" in self.type_.format()
and not "*" in self.type_.format()
and "void" not in self.type_.format()
and "*" not in self.type_.format()
):
return f" {self.type_.format().replace('const', '')} {self.parameter.name};\n"
else:
Expand Down Expand Up @@ -653,13 +658,13 @@ def client_rpc_read(self, f):
)


Operation = (
NullableOperation
| ArrayOperation
| NullTerminatedOperation
| OpaqueTypeOperation
| DereferenceOperation
)
Operation = Union[
NullableOperation,
ArrayOperation,
NullTerminatedOperation,
OpaqueTypeOperation,
DereferenceOperation,
]


# parses a function annotation. if disabled is encountered, returns True for short circuiting.
Expand Down Expand Up @@ -811,72 +816,87 @@ def parse_annotation(
parameter=param,
ptr=param.type,
length=length_param,
))
)
)
elif size_arg:
# if it has a size, it's an array operation with constant length
operations.append(ArrayOperation(
send=send,
recv=recv,
parameter=param,
ptr=param.type,
length=int(size_arg.split(":")[1]),
))
operations.append(
ArrayOperation(
send=send,
recv=recv,
parameter=param,
ptr=param.type,
length=int(size_arg.split(":")[1]),
)
)
elif null_terminated:
# if it's null terminated, it's a null terminated operation
operations.append(NullTerminatedOperation(
send=send,
recv=recv,
parameter=param,
ptr=param.type,
))
operations.append(
NullTerminatedOperation(
send=send,
recv=recv,
parameter=param,
ptr=param.type,
)
)
elif nullable:
# if it's nullable, it's a nullable operation
operations.append(NullableOperation(
send=send,
recv=recv,
parameter=param,
ptr=param.type,
))
operations.append(
NullableOperation(
send=send,
recv=recv,
parameter=param,
ptr=param.type,
)
)
else:
# otherwise, it's a pointer to a single value or another pointer
if recv:
if param.type.ptr_to.format() == "void":
raise NotImplementedError("Cannot dereference a void pointer")
# this is an out parameter so use the base type as the server declaration
operations.append(DereferenceOperation(
send=send,
recv=recv,
parameter=param,
type_=param.type,
))
operations.append(
DereferenceOperation(
send=send,
recv=recv,
parameter=param,
type_=param.type,
)
)
else:
# otherwise, treat it as an opaque type
operations.append(OpaqueTypeOperation(
send=send,
recv=recv,
parameter=param,
type_=param.type,
))
operations.append(
OpaqueTypeOperation(
send=send,
recv=recv,
parameter=param,
type_=param.type,
)
)
elif isinstance(param.type, Type):
if param.type.const:
recv = False
operations.append(OpaqueTypeOperation(
send=send,
recv=recv,
parameter=param,
type_=param.type,
))
operations.append(
OpaqueTypeOperation(
send=send,
recv=recv,
parameter=param,
type_=param.type,
)
)
elif isinstance(param.type, Array):
length_param = next(p for p in params if p.name == length_arg.split(":")[1])
if param.type.array_of.const:
recv = False
operations.append(ArrayOperation(
send=send,
recv=recv,
parameter=param,
ptr=param.type,
length=length_param,
))
operations.append(
ArrayOperation(
send=send,
recv=recv,
parameter=param,
ptr=param.type,
length=length_param,
)
)
else:
raise NotImplementedError("Unknown type")
return operations, False
Expand Down Expand Up @@ -1249,7 +1269,9 @@ def main():

f.write(" if (\n")
for operation in operations:
if isinstance(operation, NullTerminatedOperation) or isinstance(operation, ArrayOperation):
if isinstance(operation, NullTerminatedOperation) or isinstance(
operation, ArrayOperation
):
if error := operation.server_rpc_read(f, len(defers)):
defers.append(error)
else:
Expand Down

0 comments on commit 03dc72f

Please sign in to comment.