From 88776da450e83f4393e26e5f41a69c2f222f1042 Mon Sep 17 00:00:00 2001 From: ulises-jeremias Date: Sun, 28 Apr 2024 02:35:53 -0300 Subject: [PATCH] Refactor dgetrf function to use blocked algorithm --- ...ck_lapacke copy.v => cflags_d_vsl_lapack_lapacke.v} | 0 lapack/lapack64/dgetrf.v | 10 +++++++++- ml/knn.v | 2 +- 3 files changed, 10 insertions(+), 2 deletions(-) rename lapack/{cflags_d_vsl_lapack_lapacke copy.v => cflags_d_vsl_lapack_lapacke.v} (100%) diff --git a/lapack/cflags_d_vsl_lapack_lapacke copy.v b/lapack/cflags_d_vsl_lapack_lapacke.v similarity index 100% rename from lapack/cflags_d_vsl_lapack_lapacke copy.v rename to lapack/cflags_d_vsl_lapack_lapacke.v diff --git a/lapack/lapack64/dgetrf.v b/lapack/lapack64/dgetrf.v index 9a1a10a25..89aee7a63 100644 --- a/lapack/lapack64/dgetrf.v +++ b/lapack/lapack64/dgetrf.v @@ -73,7 +73,15 @@ pub fn dgetrf(m int, n int, mut a []f64, lda int, ipiv []int) { // apply interchanges to columns 1..j-1. mut slice := unsafe { a[j + jb..] } dlaswp(j, mut slice, lda, j, j + jb, ipiv[..j + jb], 1) - // + + blas.dtstrf(.left, .lower, .notrans, .unit, jb, n - j - jb, 1, a[j * lda + j..], + lda, a[j * lda + j + jb..], lda) + + if j + jb < m { + blas.dgemm(.notrans, .notrans, m - j - jb, n - j - jb, jb, -1, a[(j + jb) * lda + j..], + lda, a[j * lda + j + jb..], lda, 1, a[(j + jb) * lda + j + jb..], + lda) + } } } } diff --git a/ml/knn.v b/ml/knn.v index 407cbfec0..c84db9397 100644 --- a/ml/knn.v +++ b/ml/knn.v @@ -112,7 +112,7 @@ pub struct PredictConfig { pub: max_iter int k int - to_pred []f64 + to_pred []f64 } // predict will find the `k` points nearest to the specified `to_pred`.