-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathhlu!.jl
87 lines (79 loc) · 2.51 KB
/
hlu!.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
hlu!(A::AbstractMatrix{T}) where {T}
Return LU factorization of A
C. T. Kelley, 2023
This function is a hack of generic_lufact! which is part of
https://github.com/JuliaLang/julia/blob/master/stdlib/LinearAlgebra/src/lu.jl
I "fixed" the code to be Float16 only and fixed pivoting
to only MaxRow.
All I did in the factorization
was thread the critical loop with OhMyThreads:tforeach and
put @simd in the inner loop. For larger problems (n > 128)
these changes got me a 2-10x speedup
on my Mac M2 Pro with 8 performance cores. I'm happy.
"""
function hlu!(A::AbstractMatrix{T}) where {T}
pivot = RowMaximum()
T == Float16 || @warn("Use hlu for half precision only!")
LAPACK.chkfinite(A)
# Extract values and make sure the problem is square
m, n = size(A)
# Small n? Revert to normal lu
(n < 128) && (AF = lu!(A); return AF)
minmn = min(m, n)
# Initialize variables
info = 0
BlasInt = LinearAlgebra.BLAS.BlasInt
ipiv = Vector{BlasInt}(undef, minmn)
@inbounds begin
for k = 1:minmn
# find index max
kp = k
if k < m
amax = abs(A[k, k])
for i = k+1:m
absi = abs(A[i, k])
if absi > amax
kp = i
amax = absi
end
end
end
ipiv[k] = kp
if !iszero(A[kp, k])
if k != kp
# Interchange
for i = 1:n
tmp = A[k, i]
A[k, i] = A[kp, i]
A[kp, i] = tmp
end
end
# Scale first column
Akkinv = inv(A[k, k])
for i = k+1:m
A[i, k] *= Akkinv
end
elseif info == 0
info = k
end
# Update the rest
ntasks = min(nthreads(), 1 + floor(Int, (n - k) / 8))
tforeach(k+1:n; ntasks = ntasks) do j
Akj = -A[k, j]
@inbounds @simd ivdep for i = k+1:m
A[i, j] += A[i, k] * Akj
end # i loop
end #j loop
end
end
checknonsingular(info, pivot)
return LU{T,typeof(A),typeof(ipiv)}(A, ipiv, convert(BlasInt, info))
end
"""
    hlu(A)

Non-mutating companion to `hlu!`: factor a copy of `A` with `hlu!` and
return the resulting factorization, leaving `A` itself untouched.
"""
hlu(A) = hlu!(copy(A))
# Helper borrowed from Base/LinearAlgebra: turn a nonzero `info` flag into a
# `SingularException`; return `true` when the factorization hit no zero pivot.
function checknonsingular(info, ::RowMaximum)
    if info != 0
        throw(SingularException(info))
    end
    return true
end