From 03dc72f4c179da90df62be627c8f3096d89ac7f4 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Sun, 26 Jan 2025 18:41:58 -0500 Subject: [PATCH] ci: add python codegen test (#96) --- .github/workflows/codegen.yaml | 25 ++++++ .github/workflows/pr-test.yaml | 2 +- codegen/codegen.py | 150 +++++++++++++++++++-------------- 3 files changed, 112 insertions(+), 65 deletions(-) create mode 100644 .github/workflows/codegen.yaml diff --git a/.github/workflows/codegen.yaml b/.github/workflows/codegen.yaml new file mode 100644 index 0000000..ad11208 --- /dev/null +++ b/.github/workflows/codegen.yaml @@ -0,0 +1,25 @@ +name: Codegen Test + +on: + pull_request: + branches: [main] + +jobs: + codegen: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' # caching pip dependencies + - working-directory: codegen + run: pip install -r requirements.txt && python codegen.py + diff --git a/.github/workflows/pr-test.yaml b/.github/workflows/pr-test.yaml index f8544c7..2b80933 100644 --- a/.github/workflows/pr-test.yaml +++ b/.github/workflows/pr-test.yaml @@ -13,7 +13,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Tailscale uses: tailscale/github-action@v2 diff --git a/codegen/codegen.py b/codegen/codegen.py index 51c95ce..a338b04 100644 --- a/codegen/codegen.py +++ b/codegen/codegen.py @@ -1,7 +1,7 @@ from cxxheaderparser.simple import parse_file, ParsedData, ParserOptions from cxxheaderparser.preprocessor import make_gcc_preprocessor from cxxheaderparser.types import Type, Pointer, Parameter, Function, Array -from typing import Optional +from typing import Optional, Union from dataclasses import dataclass import os import glob @@ -112,7 +112,11 @@ def client_rpc_write(self, f): ) def client_unified_copy(self, f, direction, error): - f.write(" if (maybe_copy_unified_arg(0, (void*){name}, cudaMemcpyDeviceToHost) < 0)\n".format(name=self.parameter.name)) + f.write( + " if (maybe_copy_unified_arg(0, (void*){name}, cudaMemcpyDeviceToHost) < 0)\n".format( + name=self.parameter.name + ) + ) f.write(" return {error};\n".format(error=error)) @property @@ -195,9 +199,8 @@ class ArrayOperation: recv: bool parameter: Parameter ptr: Pointer - length: ( - int | Parameter - ) # if int, it's a constant length, if Parameter, it's a variable length. + # if int, it's a constant length, if Parameter, it's a variable length. + length: Union[int, Parameter] def client_rpc_write(self, f): if not self.send: @@ -305,7 +308,7 @@ def server_declaration(self) -> str: s = f" {self.ptr.format()} {self.parameter.name};\n" self.ptr.ptr_to.const = c return s - + def server_rpc_read(self, f, index) -> Optional[str]: if not self.send: # if this parameter is recv only and it's a type pointer, it needs to be malloc'd. @@ -317,7 +320,9 @@ def server_rpc_read(self, f, index) -> Optional[str]: param_name=self.parameter.name, param_type=self.ptr.ptr_to.format(), server_type=self.ptr.format(), - length=self.length if isinstance(self.length, int) else self.length.name, + length=self.length + if isinstance(self.length, int) + else self.length.name, ) ) f.write(" if(") @@ -326,7 +331,7 @@ def server_rpc_read(self, f, index) -> Optional[str]: elif isinstance(self.length, int): f.write( " rpc_read(conn, &{param_name}, {size}) < 0 ||\n".format( - param_name=self.parameter.name, + param_name=self.parameter.name, size=self.length, ) ) @@ -502,7 +507,7 @@ class OpaqueTypeOperation: send: bool recv: bool parameter: Parameter - type_: Type | Pointer + type_: Union[Type, Pointer] def client_rpc_write(self, f): if not self.send: @@ -523,8 +528,8 @@ def server_declaration(self) -> str: # but "const cudnnTensorDescriptor_t *xDesc" IS valid. This subtle change carries reprecussions. elif ( "const " in self.type_.format() - and not "void" in self.type_.format() - and not "*" in self.type_.format() + and "void" not in self.type_.format() + and "*" not in self.type_.format() ): return f" {self.type_.format().replace('const', '')} {self.parameter.name};\n" else: @@ -653,13 +658,13 @@ def client_rpc_read(self, f): ) -Operation = ( - NullableOperation - | ArrayOperation - | NullTerminatedOperation - | OpaqueTypeOperation - | DereferenceOperation -) +Operation = Union[ + NullableOperation, + ArrayOperation, + NullTerminatedOperation, + OpaqueTypeOperation, + DereferenceOperation, +] # parses a function annotation. if disabled is encountered, returns True for short circuiting. @@ -811,72 +816,87 @@ def parse_annotation( parameter=param, ptr=param.type, length=length_param, - )) + ) + ) elif size_arg: # if it has a size, it's an array operation with constant length - operations.append(ArrayOperation( - send=send, - recv=recv, - parameter=param, - ptr=param.type, - length=int(size_arg.split(":")[1]), - )) + operations.append( + ArrayOperation( + send=send, + recv=recv, + parameter=param, + ptr=param.type, + length=int(size_arg.split(":")[1]), + ) + ) elif null_terminated: # if it's null terminated, it's a null terminated operation - operations.append(NullTerminatedOperation( - send=send, - recv=recv, - parameter=param, - ptr=param.type, - )) + operations.append( + NullTerminatedOperation( + send=send, + recv=recv, + parameter=param, + ptr=param.type, + ) + ) elif nullable: # if it's nullable, it's a nullable operation - operations.append(NullableOperation( - send=send, - recv=recv, - parameter=param, - ptr=param.type, - )) + operations.append( + NullableOperation( + send=send, + recv=recv, + parameter=param, + ptr=param.type, + ) + ) else: # otherwise, it's a pointer to a single value or another pointer if recv: if param.type.ptr_to.format() == "void": raise NotImplementedError("Cannot dereference a void pointer") # this is an out parameter so use the base type as the server declaration - operations.append(DereferenceOperation( - send=send, - recv=recv, - parameter=param, - type_=param.type, - )) + operations.append( + DereferenceOperation( + send=send, + recv=recv, + parameter=param, + type_=param.type, + ) + ) else: # otherwise, treat it as an opaque type - operations.append(OpaqueTypeOperation( - send=send, - recv=recv, - parameter=param, - type_=param.type, - )) + operations.append( + OpaqueTypeOperation( + send=send, + recv=recv, + parameter=param, + type_=param.type, + ) + ) elif isinstance(param.type, Type): if param.type.const: recv = False - operations.append(OpaqueTypeOperation( - send=send, - recv=recv, - parameter=param, - type_=param.type, - )) + operations.append( + OpaqueTypeOperation( + send=send, + recv=recv, + parameter=param, + type_=param.type, + ) + ) elif isinstance(param.type, Array): length_param = next(p for p in params if p.name == length_arg.split(":")[1]) if param.type.array_of.const: recv = False - operations.append(ArrayOperation( - send=send, - recv=recv, - parameter=param, - ptr=param.type, - length=length_param, - )) + operations.append( + ArrayOperation( + send=send, + recv=recv, + parameter=param, + ptr=param.type, + length=length_param, + ) + ) else: raise NotImplementedError("Unknown type") return operations, False @@ -1249,7 +1269,9 @@ def main(): f.write(" if (\n") for operation in operations: - if isinstance(operation, NullTerminatedOperation) or isinstance(operation, ArrayOperation): + if isinstance(operation, NullTerminatedOperation) or isinstance( + operation, ArrayOperation + ): if error := operation.server_rpc_read(f, len(defers)): defers.append(error) else: