Skip to content

Commit

Permalink
Type hinting (#127)
Browse files Browse the repository at this point in the history
- implemented the basic infrastructure for type-hints

- added a couple of type hints. 

- defined a custom `numba.njit` wrapper, that preserves the type information and sets `cache=True` by default. (related to and relevant for numba/numba#7424)
  • Loading branch information
mcocdawc authored Jan 21, 2025
1 parent 651ddc0 commit cd29f44
Show file tree
Hide file tree
Showing 21 changed files with 221 additions and 106 deletions.
4 changes: 0 additions & 4 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@ run_tests: &run_tests
- run:
name: Running tests
command: python3 -m pytest -Werror --cov=./src/chemcoord tests/
- run:
name: Upload coverage reports to Codecov
command: |
bash <(curl -s https://codecov.io/bash)
- run:
name: Prepare documentation
command: pip3 install -r docs/requirements.txt
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,6 @@ cealign-0.8-RBS
#pytest
*.chache/*
.cache/


.mypy_cache/
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"sympy",
"six",
"pymatgen",
"typing_extensions",
]
KEYWORDS = [
"chemcoord",
Expand Down
14 changes: 7 additions & 7 deletions src/chemcoord/_cartesian_coordinates/_cart_transformation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import numba as nb
import numpy as np
from numba import njit
from numba.extending import overload
from numpy import arccos, arctan2, sqrt

import chemcoord.constants as constants
from chemcoord._cartesian_coordinates.xyz_functions import (
_jit_normalize,
)
from chemcoord._utilities._decorators import njit
from chemcoord.exceptions import ERR_CODE_OK, ERR_CODE_InvalidReference


Expand Down Expand Up @@ -44,12 +44,12 @@ def f(X, indices):
raise AssertionError("Should not be here")


@njit(cache=True)
@njit
def get_ref_pos(X, indices):
return _stub_get_ref_pos(X, indices)


@njit(cache=True)
@njit
def get_B(X, c_table, j):
B = np.empty((3, 3))
ref_pos = get_ref_pos(X, c_table[:, j])
Expand All @@ -66,7 +66,7 @@ def get_B(X, c_table, j):
return (ERR_CODE_OK, B)


@njit(cache=True)
@njit
def get_grad_B(X, c_table, j):
grad_B = np.empty((3, 3, 3, 3))
ref_pos = get_ref_pos(X, c_table[:, j])
Expand Down Expand Up @@ -1130,7 +1130,7 @@ def get_grad_S_inv(v):
return grad_S_inv


@njit(cache=True)
@njit
def get_T(X, c_table, j):
err, B = get_B(X, c_table, j)
if err == ERR_CODE_OK:
Expand All @@ -1141,7 +1141,7 @@ def get_T(X, c_table, j):
return err, result


@njit(cache=True)
@njit
def get_C(X, c_table):
C = np.empty((3, c_table.shape[1]))

Expand All @@ -1154,7 +1154,7 @@ def get_C(X, c_table):
return (ERR_CODE_OK, C)


@njit(cache=True)
@njit
def get_grad_C(X, c_table):
n_atoms = X.shape[1]
grad_C = np.zeros((3, n_atoms, n_atoms, 3))
Expand Down
Loading

0 comments on commit cd29f44

Please sign in to comment.