Skip to content

Commit

Permalink
Merge branch 'main' into sasha/uv/activated-venv
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh authored Dec 16, 2024
2 parents d590723 + 7c4cb63 commit f6ddebe
Show file tree
Hide file tree
Showing 17 changed files with 283 additions and 113 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/ci-lockfile.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: CI - Lockfile

on:
# Trigger the workflow on push or pull request,
# but only for the master branch
push:
branches:
- main
pull_request:

jobs:
build:

runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.12']

steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v3
with:
enable-cache: true
cache-dependency-glob: "uv.lock"

- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}

- name: Install the package locally and check for lockfile mismatch
run: |
# Install all default extras.
# Fail if the lockfile dependencies are out of date with pyproject.toml.
XDSL_VERSION_OVERRIDE="0+dynamic" uv sync --extra gui --extra dev --extra jax --extra riscv --locked
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ uv-installed:
# set up the venv with all dependencies for development
.PHONY: ${VENV_DIR}/
${VENV_DIR}/: uv-installed
uv sync ${VENV_EXTRAS}
XDSL_VERSION_OVERRIDE="0+dynamic" uv sync ${VENV_EXTRAS}

# make sure `make venv` also works correctly
.PHONY: venv
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ dev = [
"lit<19.0.0",
"marimo==0.9.34",
"pre-commit==4.0.1",
"ruff==0.8.2",
"ruff==0.8.3",
"asv<0.7",
"nbconvert>=7.7.2,<8.0.0",
"textual-dev==1.7.0",
"pytest-asyncio==0.24.0",
"pyright==1.1.390",
]
gui = ["textual==0.89.1", "pyclip==0.7"]
gui = ["textual==1.0.0", "pyclip==0.7"]
jax = ["jax==0.4.37", "numpy==2.2.0"]
onnx = ["onnx==1.17.0", "numpy==2.2.0"]
riscv = ["riscemu==2.2.7"]
Expand Down
9 changes: 8 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import os
from collections.abc import Mapping
from typing import cast

from setuptools import Command, find_packages, setup

import versioneer

if "XDSL_VERSION_OVERRIDE" in os.environ:
version = os.environ["XDSL_VERSION_OVERRIDE"]
else:
version = versioneer.get_version()


setup(
version=versioneer.get_version(),
version=version,
cmdclass=cast(Mapping[str, type[Command]], versioneer.get_cmdclass()),
packages=find_packages(),
)
16 changes: 16 additions & 0 deletions tests/dialects/test_arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
AnyVectorType,
FloatAttr,
IndexType,
IntegerAttr,
IntegerType,
TensorType,
VectorType,
Expand Down Expand Up @@ -124,6 +125,21 @@ def test_Cmpi_from_mnemonic(self, input: str):
_ = CmpiOp(self.a, self.b, input)


@pytest.mark.parametrize(
"value, truncated",
[
(-1, -1),
(1, 1),
(255, -1),
(256, 0),
],
)
def test_constant_truncation(value: int, truncated: int):
constant = ConstantOp.from_int_and_width(value, 8, truncate_bits=True)
assert isinstance(v := constant.value, IntegerAttr)
assert v.value.data == truncated


@pytest.mark.parametrize(
"lhs_type, rhs_type, sum_type, is_correct",
[
Expand Down
55 changes: 52 additions & 3 deletions tests/dialects/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from xdsl.dialects.arith import ConstantOp
from xdsl.dialects.builtin import (
AnyTensorType,
AnyVectorType,
ArrayAttr,
BFloat16Type,
BytesAttr,
Expand Down Expand Up @@ -119,6 +120,26 @@ def test_IntegerType_normalized():
assert ui8.normalized_value(255) == 255


def test_IntegerType_truncated():
si8 = IntegerType(8, Signedness.SIGNED)
ui8 = IntegerType(8, Signedness.UNSIGNED)

assert i8.normalized_value(-1, truncate_bits=True) == -1
assert i8.normalized_value(1, truncate_bits=True) == 1
assert i8.normalized_value(255, truncate_bits=True) == -1
assert i8.normalized_value(256, truncate_bits=True) == 0

assert si8.normalized_value(-1, truncate_bits=True) == -1
assert si8.normalized_value(1, truncate_bits=True) == 1
assert si8.normalized_value(255, truncate_bits=True) == -1
assert si8.normalized_value(256, truncate_bits=True) == 0

assert ui8.normalized_value(-1, truncate_bits=True) == 255
assert ui8.normalized_value(1, truncate_bits=True) == 1
assert ui8.normalized_value(255, truncate_bits=True) == 255
assert ui8.normalized_value(256, truncate_bits=True) == 0


def test_IntegerAttr_normalize():
"""
Test that the value within the accepted signless range is normalized to signed
Expand Down Expand Up @@ -217,7 +238,7 @@ def test_IntegerType_packing():


def test_DenseIntOrFPElementsAttr_fp_type_conversion():
check1 = DenseIntOrFPElementsAttr.tensor_from_list([4, 5], f32, [])
check1 = DenseIntOrFPElementsAttr.tensor_from_list([4, 5], f32, [2])

value1 = check1.get_attrs()[0].value.data
value2 = check1.get_attrs()[1].value.data
Expand All @@ -231,7 +252,7 @@ def test_DenseIntOrFPElementsAttr_fp_type_conversion():
t1 = FloatAttr(4.0, f32)
t2 = FloatAttr(5.0, f32)

check2 = DenseIntOrFPElementsAttr.tensor_from_list([t1, t2], f32, [])
check2 = DenseIntOrFPElementsAttr.tensor_from_list([t1, t2], f32, [2])

value3 = check2.get_attrs()[0].value.data
value4 = check2.get_attrs()[1].value.data
Expand All @@ -244,9 +265,37 @@ def test_DenseIntOrFPElementsAttr_fp_type_conversion():


def test_DenseIntOrFPElementsAttr_from_list():
# legal zero-rank tensor
attr = DenseIntOrFPElementsAttr.tensor_from_list([5.5], f32, [])

assert attr.type == AnyTensorType(f32, [])
assert len(attr) == 1

# illegal zero-rank tensor
with pytest.raises(
ValueError, match="A zero-rank tensor can only hold 1 value but 2 were given."
):
DenseIntOrFPElementsAttr.tensor_from_list([5.5, 5.6], f32, [])

# legal 1 element tensor
attr = DenseIntOrFPElementsAttr.tensor_from_list([5.5], f32, [1])
assert attr.type == AnyTensorType(f32, [1])
assert len(attr) == 1

# legal normal tensor
attr = DenseIntOrFPElementsAttr.tensor_from_list([5.5, 5.6], f32, [2])
assert attr.type == AnyTensorType(f32, [2])
assert len(attr) == 2

# splat initialization
attr = DenseIntOrFPElementsAttr.tensor_from_list([4], f32, [4])
assert attr.type == AnyTensorType(f32, [4])
assert tuple(attr.get_values()) == (4, 4, 4, 4)
assert len(attr) == 4

# vector with inferred shape
attr = DenseIntOrFPElementsAttr.vector_from_list([1, 2, 3, 4], f32)
assert attr.type == AnyVectorType(f32, [4])
assert len(attr) == 4


@pytest.mark.parametrize(
Expand Down
18 changes: 18 additions & 0 deletions tests/filecheck/transforms/individual_rewrite/add-same.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN:xdsl-opt %s --split-input-file -p 'apply-individual-rewrite{matched_operation_index=2 operation_name="arith.addi" pattern_name="AdditionOfSameVariablesToMultiplyByTwo"}'| filecheck %s


// CHECK: %v = "test.op"() : () -> i32
// CHECK-NEXT: %[[#two:]] = arith.constant 2 : i32
// CHECK-NEXT: %{{.*}} = arith.muli %v, %[[#two]] : i32

%v = "test.op"() : () -> (i32)
%1 = arith.addi %v, %v : i32

// -----

// CHECK: %v = "test.op"() : () -> i1
// CHECK-NEXT: %[[#zero:]] = arith.constant false
// CHECK-NEXT: %{{.*}} = arith.muli %v, %[[#zero]] : i1

%v = "test.op"() : () -> (i1)
%1 = arith.addi %v, %v : i1
2 changes: 1 addition & 1 deletion tests/interactive/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def callback(x: str):
assert app.condense_mode is True
rewrites = get_all_possible_rewrites(
expected_module,
individual_rewrite.REWRITE_BY_NAMES,
individual_rewrite.INDIVIDUAL_REWRITE_PATTERNS_BY_NAME,
)
assert app.available_pass_list == get_condensed_pass_list(
expected_module, app.all_passes
Expand Down
17 changes: 9 additions & 8 deletions tests/interactive/test_rewrites.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from xdsl.context import MLContext
from xdsl.dialects import get_all_dialects
from xdsl.dialects.builtin import (
Builtin,
StringAttr,
)
from xdsl.dialects.test import TestOp
from xdsl.dialects.test import Test, TestOp
from xdsl.interactive.passes import AvailablePass
from xdsl.interactive.rewrites import (
IndexedIndividualRewrite,
Expand Down Expand Up @@ -39,9 +39,10 @@ def test_get_all_possible_rewrite():
}
"""

ctx = MLContext(True)
for dialect_name, dialect_factory in get_all_dialects().items():
ctx.register_dialect(dialect_name, dialect_factory)
ctx = MLContext()
ctx.load_dialect(Builtin)
ctx.load_dialect(Test)

parser = Parser(ctx, prog)
module = parser.parse_module()

Expand Down Expand Up @@ -73,9 +74,9 @@ def test_convert_indexed_individual_rewrites_to_available_pass():
}
"""

ctx = MLContext(True)
for dialect_name, dialect_factory in get_all_dialects().items():
ctx.register_dialect(dialect_name, dialect_factory)
ctx = MLContext()
ctx.load_dialect(Builtin)
ctx.load_dialect(Test)
parser = Parser(ctx, prog)
module = parser.parse_module()

Expand Down
2 changes: 1 addition & 1 deletion tests/interpreters/test_ml_program_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_ml_program_global_load_constant():
interpreter.register_implementations(MLProgramFunctions())

(result,) = interpreter.run_op(fetch, ())
assert result == ShapedArray(TypedPtr.new_int32([4]), [4])
assert result == ShapedArray(TypedPtr.new_int32([4] * 4), [4])


def test_ml_program_global_load_constant_ex2():
Expand Down
Loading

0 comments on commit f6ddebe

Please sign in to comment.