ggml-cpu : add chunking support to mul_mat_id #11666

Open: slaren wants to merge 6 commits into master
Conversation

slaren (Collaborator) commented Feb 5, 2025

Adds the same dynamic scheduling to mul_mat_id as is currently implemented for mul_mat. May improve performance slightly of MoE models on CPUs with heterogeneous cores.
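
For readers unfamiliar with the scheme, here is a minimal sketch of dynamic chunked scheduling with a shared atomic chunk counter; the names and structure are illustrative, not the actual ggml-cpu code. The point is that faster cores simply claim more chunks instead of waiting at a fixed per-thread partition.

    #include <stdatomic.h>
    #include <stdbool.h>
    #include <stdint.h>

    // hypothetical scheduler state shared by all worker threads
    typedef struct {
        atomic_int current_chunk; // index of the next chunk to hand out
        int64_t    nrows;         // total output rows of the matmul
        int64_t    chunk_size;    // rows per chunk (tunable per CPU)
    } chunk_sched;

    // each worker calls this in a loop until it returns false
    static bool claim_next_chunk(chunk_sched * s, int64_t * row0, int64_t * row1) {
        const int chunk = atomic_fetch_add_explicit(&s->current_chunk, 1, memory_order_relaxed);
        const int64_t start = (int64_t) chunk * s->chunk_size;
        if (start >= s->nrows) {
            return false; // all chunks have been handed out
        }
        *row0 = start;
        *row1 = start + s->chunk_size < s->nrows ? start + s->chunk_size : s->nrows;
        return true;
    }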

github-actions bot added the testing (Everything test related) and ggml (changes relating to the ggml tensor library for machine learning) labels on Feb 5, 2025
max-krasnyansky (Collaborator)

Looks good, but it seems to cost a few t/s on X-Elite for me.

before

| model | size | params | backend | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| olmoe A1.7B Q4_0 | 3.66 GiB | 6.92 B | CPU | 10 | pp128 | 544.64 ± 4.73 |
| olmoe A1.7B Q4_0 | 3.66 GiB | 6.92 B | CPU | 10 | tg64 | 112.80 ± 0.85 |

after

| model | size | params | backend | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| olmoe A1.7B Q4_0 | 3.66 GiB | 6.92 B | CPU | 10 | pp128 | 534.85 ± 4.62 |
| olmoe A1.7B Q4_0 | 3.66 GiB | 6.92 B | CPU | 10 | tg64 | 112.15 ± 1.33 |

will recheck again tomorrow

slaren (Collaborator, Author) commented Feb 5, 2025

The chunk size could probably be tuned for different processors; it has a significant effect on performance. From what I could find, the X-Elite does not actually have heterogeneous cores, so I wouldn't expect an improvement from these changes, and increasing the chunk size should reduce the performance loss from the additional synchronization that this change requires.
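
For a rough sense of the trade-off (illustrative numbers, not measurements from this PR): with nr0 = 2048 output rows, a chunk size of 16 yields ceil(2048 / 16) = 128 chunks to schedule, while a chunk size of 64 yields only 32, so roughly 4x fewer atomic counter increments per matmul, at the cost of coarser load balancing across cores.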

slaren (Collaborator, Author) commented Feb 5, 2025

I have changed it to use a different atomic chunk counter for each matrix, which should remove the need for most of the synchronization.
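
A minimal sketch of the per-matrix counter idea, with illustrative names rather than the actual ggml-cpu code: one atomic counter per expert matrix lets threads claim chunks from each expert independently, so no barrier is needed just to reset a single shared counter between experts.

    #include <stdatomic.h>
    #include <stdint.h>

    #define MAX_EXPERTS 64 // hypothetical upper bound for this sketch

    typedef struct {
        atomic_int current_chunk[MAX_EXPERTS]; // one chunk counter per expert matrix
        int64_t    nchunks[MAX_EXPERTS];       // number of chunks in each expert
    } mmid_sched;

    // claim the next chunk of expert `e`; returns -1 when that expert is finished
    static int64_t claim_chunk(mmid_sched * s, int e) {
        const int64_t chunk = atomic_fetch_add_explicit(&s->current_chunk[e], 1, memory_order_relaxed);
        return chunk < s->nchunks[e] ? chunk : -1;
    }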

With a 13900K:

| Model | Threads | Test | t/s master | t/s sl/mmid-cpu-perf | Speedup |
| --- | --- | --- | --- | --- | --- |
| llama 8x7B Q3_K_S | 8 | pp64 | 23.36 | 27.11 | 1.16 |
| llama 8x7B Q3_K_S | 8 | tg32 | 12.30 | 14.13 | 1.15 |
| llama 8x7B Q3_K_S | 16 | pp64 | 18.42 | 35.99 | 1.95 |
| llama 8x7B Q3_K_S | 16 | tg32 | 10.42 | 15.04 | 1.44 |
| llama 8x7B Q3_K_S | 24 | pp64 | 26.21 | 41.35 | 1.58 |
| llama 8x7B Q3_K_S | 24 | tg32 | 11.23 | 13.13 | 1.17 |
| llama 8x7B Q3_K_S | 32 | pp64 | 32.68 | 40.74 | 1.25 |
| llama 8x7B Q3_K_S | 32 | tg32 | 10.99 | 11.46 | 1.04 |

slaren marked this pull request as draft on February 5, 2025, 15:44
max-krasnyansky (Collaborator)

> The chunk size could probably be tuned for different processors; it has a significant effect on performance. From what I could find, the X-Elite does not actually have heterogeneous cores, so I wouldn't expect an improvement from these changes, and increasing the chunk size should reduce the performance loss from the additional synchronization that this change requires.

Yep. X-Elite has uniform cores, so I wasn't expecting a gain; I just wanted to make sure we don't see regressions.
I will check your latest update a bit later and also test on S24 and S25, which have heterogeneous cores.

slaren (Collaborator, Author) commented Feb 7, 2025

@max-krasnyansky I realized that Q4_0 does not use the same mat mul function due to the aarch64 conversion. So if you test this again, do not use a Q4_0 model; it uses the same code as on master, and any differences should be just noise.

max-krasnyansky (Collaborator)

> @max-krasnyansky I realized that Q4_0 does not use the same mat mul function due to the aarch64 conversion. So if you test this again, do not use a Q4_0 model; it uses the same code as on master, and any differences should be just noise.

I'll try Q2_K or Q4_K shortly

max-krasnyansky (Collaborator)

Here is another set of runs with Q4_K on X-Elite.
Looks like some further tuning will be needed.
master is definitely better; chunk size 64 helps with fewer threads but hurts with more threads.

I'll play with it some more tomorrow.
By the way, I realized the regular matmul has the same issue. I vaguely remember that the original chunking support caused a regression, but with Q4_0_4_8 we got a big bump and I forgot about it.

master\bin\llama-bench.exe -m C:\Users\maxk\src\gguf\OLMoE-1B-7B-0924.Q4_K.gguf -p 128 -n 0 -t 2,4,6,8,10 -ngl 0 --delay 15

| model | size | params | backend | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| olmoe A1.7B Q4_K - Medium | 3.92 GiB | 6.92 B | CPU | 2 | pp128 | 56.82 ± 1.04 |
| olmoe A1.7B Q4_K - Medium | 3.92 GiB | 6.92 B | CPU | 4 | pp128 | 110.64 ± 0.43 |
| olmoe A1.7B Q4_K - Medium | 3.92 GiB | 6.92 B | CPU | 6 | pp128 | 163.05 ± 0.81 |
| olmoe A1.7B Q4_K - Medium | 3.92 GiB | 6.92 B | CPU | 8 | pp128 | 214.60 ± 0.66 |
| olmoe A1.7B Q4_K - Medium | 3.92 GiB | 6.92 B | CPU | 10 | pp128 | 263.76 ± 1.97 |

build: 55ac8c7 (4675)

chunk-size-16
mmid-cs16\bin\llama-bench.exe -m C:\Users\maxk\src\gguf\OLMoE-1B-7B-0924.Q4_K.gguf -p 128 -n 0 -t 2,4,6,8,10 -ngl 0 --delay 15

| model | size | params | backend | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| olmoe A1.7B Q4_K - Medium | 3.92 GiB | 6.92 B | CPU | 2 | pp128 | 49.48 ± 0.44 |
| olmoe A1.7B Q4_K - Medium | 3.92 GiB | 6.92 B | CPU | 4 | pp128 | 100.48 ± 0.30 |
| olmoe A1.7B Q4_K - Medium | 3.92 GiB | 6.92 B | CPU | 6 | pp128 | 152.68 ± 0.25 |
| olmoe A1.7B Q4_K - Medium | 3.92 GiB | 6.92 B | CPU | 8 | pp128 | 203.39 ± 0.95 |
| olmoe A1.7B Q4_K - Medium | 3.92 GiB | 6.92 B | CPU | 10 | pp128 | 250.23 ± 2.08 |

build: 45cede2 (4629)

chunk-size-64
mmid-cs64\bin\llama-bench.exe -m C:\Users\maxk\src\gguf\OLMoE-1B-7B-0924.Q4_K.gguf -p 128 -n 0 -t 2,4,6,8,10 -ngl 0 --delay 15

| model | size | params | backend | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| olmoe A1.7B Q4_K - Medium | 3.92 GiB | 6.92 B | CPU | 2 | pp128 | 57.30 ± 1.37 |
| olmoe A1.7B Q4_K - Medium | 3.92 GiB | 6.92 B | CPU | 4 | pp128 | 111.17 ± 0.30 |
| olmoe A1.7B Q4_K - Medium | 3.92 GiB | 6.92 B | CPU | 6 | pp128 | 164.98 ± 0.56 |
| olmoe A1.7B Q4_K - Medium | 3.92 GiB | 6.92 B | CPU | 8 | pp128 | 199.64 ± 1.03 |
| olmoe A1.7B Q4_K - Medium | 3.92 GiB | 6.92 B | CPU | 10 | pp128 | 219.92 ± 1.55 |

build: 45cede2 (4629)

slaren (Collaborator, Author) commented Feb 9, 2025

Thanks for testing. It's surprising that increasing the chunk size decreases performance; I would expect a larger chunk size to bring the behavior closer to the previous behavior. I am not sure what could cause this; maybe it is some cache effect that I am missing?

I don't have the hardware to test this, so all I can do is disable this for ARM. On M3 Max the change is mostly neutral, with better performance above 12 threads (though still lower than at 12 threads).

| CPU | Model | Threads | Test | t/s master | t/s sl/mmid-cpu-perf | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| Apple M3 Max | olmoe A1.7B Q4_K_M | 2 | pp128 | 97.00 | 92.91 | 0.96 |
| Apple M3 Max | olmoe A1.7B Q4_K_M | 4 | pp128 | 181.95 | 178.70 | 0.98 |
| Apple M3 Max | olmoe A1.7B Q4_K_M | 8 | pp128 | 320.09 | 317.74 | 0.99 |
| Apple M3 Max | olmoe A1.7B Q4_K_M | 12 | pp128 | 386.02 | 381.51 | 0.99 |
| Apple M3 Max | olmoe A1.7B Q4_K_M | 13 | pp128 | 244.26 | 355.65 | 1.46 |
| Apple M3 Max | olmoe A1.7B Q4_K_M | 14 | pp128 | 241.09 | 355.24 | 1.47 |
| Apple M3 Max | olmoe A1.7B Q4_K_M | 15 | pp128 | 225.00 | 327.30 | 1.45 |
| Apple M3 Max | olmoe A1.7B Q4_K_M | 16 | pp128 | 184.14 | 260.59 | 1.42 |

parallelize src1 quantization by column to allow parallelization even when there is only one row
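
A hedged sketch of what the commit above describes, with made-up names rather than the actual ggml-cpu code: the quantization of a single src1 row is split into per-thread column ranges aligned to the quantization block size, so every thread gets a share of the work even when there is only one row to quantize.

    #include <stdint.h>

    // compute the [*c0, *c1) column range that thread `ith` of `nth` should quantize,
    // keeping the split aligned to the quantization block size (e.g. 32 floats for Q8_0)
    static void src1_col_range(int64_t ne10, int64_t qblock, int ith, int nth,
                               int64_t * c0, int64_t * c1) {
        const int64_t nblocks       = ne10 / qblock;              // quant blocks in one row
        const int64_t blocks_per_th = (nblocks + nth - 1) / nth;  // rounded up
        int64_t start = (int64_t) ith * blocks_per_th * qblock;
        int64_t end   = start + blocks_per_th * qblock;
        if (start > ne10) start = ne10; // threads past the end get an empty range
        if (end   > ne10) end   = ne10;
        *c0 = start;
        *c1 = end;
    }
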
slaren marked this pull request as ready for review on February 9, 2025, 15:28
Comment on lines +7832 to +7846
#if defined(__aarch64__)
    // disable for ARM
    const bool disable_chunking = true;
#else
    // disable for NUMA
    const bool disable_chunking = ggml_is_numa();
#endif // defined(__aarch64__)

    int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
    int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;

    if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) {
        nchunk0 = nr0 > nr1 ? nth : 1;
        nchunk1 = nr0 > nr1 ? 1 : nth;
    }
slaren (Collaborator, Author)

The same change could be applied to the regular mul_mat if it causes a regression on some ARM devices.

max-krasnyansky (Collaborator)

> Thanks for testing. It's surprising that increasing the chunk size decreases performance; I would expect a larger chunk size to bring the behavior closer to the previous behavior. I am not sure what could cause this; maybe it is some cache effect that I am missing?
>
> I don't have the hardware to test this, so all I can do is disable this for ARM. On M3 Max the change is mostly neutral, with better performance above 12 threads (though still lower than at 12 threads).

I agree it's odd.
I'm going to try it on an M4 Pro, and on S24 and S25 with Snapdragon Gen 3 and 8 Elite.
Will try to do some profiling with Linux perf to see if something jumps out at me.
