ggml-cpu : add chunking support to mul_mat_id #11666
base: master
Conversation
Looks good but seems to eat a few t/s on X-Elite for me.
before
after
will recheck again tomorrow
The chunk size could probably be tuned for different processors, it has a significant effect on performance. From what I could find, the X-Elite does not actually have heterogeneous cores, so I wouldn't expect an improvement from these changes, and increasing the chunk size should reduce the performance loss from the additional synchronization that this change requires.
(force-pushed from b6bd497 to 1f060ab)
I have changed it to use a different atomic chunk counter for each matrix, which should remove the need for most of the synchronization. With a 13900K:
(force-pushed from 1f060ab to 6bcb537)
Yep. X-Elite has uniform cores, so I wasn't expecting a gain, just wanted to make sure we don't see regressions.
(force-pushed from 6bcb537 to 45cede2)
@max-krasnyansky I realized that Q4_0 does not use the same matmul function due to the aarch64 conversion. So if you test this again, do not use a Q4_0 model: it uses the same code as on master, and any differences should be just noise.
I'll try Q2_K or Q4_K shortly
Here is another set of runs with Q4_K on X-Elite. I'll play with it some more tomorrow.
build: 55ac8c7 (4675) chunk-size-16
build: 45cede2 (4629) chunk-size-64
build: 45cede2 (4629)
Thanks for testing. It's surprising that increasing the chunk size decreases performance; I would expect a larger chunk size to bring the behavior closer to what it was before this change. I am not sure what could cause this, maybe it is some cache effect that I am missing. I don't have the hardware to test this, so all I can do is disable this for ARM. On M3 Max the change is mostly neutral, with better performance above 12 threads (but still worse than at 12 threads).
parallelize src1 quantization by column to allow parallelization even when there is only one row
(force-pushed from 45cede2 to 1b90527)
#if defined(__aarch64__)
    // disable for ARM
    const bool disable_chunking = true;
#else
    // disable for NUMA
    const bool disable_chunking = ggml_is_numa();
#endif // defined(__aarch64__)
// attempt to reduce false-sharing (does not seem to make a difference)
float tmp[16];

int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
    for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
        for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
            const int64_t _i12 = ir1; // logical row index for this expert

if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) {
    nchunk0 = nr0 > nr1 ? nth : 1;
    nchunk1 = nr0 > nr1 ? 1 : nth;
}
The same change could be applied to the regular mul_mat if it causes a regression on some ARM devices.
I agree it's odd.
Adds the same dynamic scheduling to mul_mat_id that is currently implemented for mul_mat. May slightly improve the performance of MoE models on CPUs with heterogeneous cores.