// Copyright 2024 FP6-LLM authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh
#ifndef UTILS_GMEM_CUH
#define UTILS_GMEM_CUH

#include <assert.h>
#include "configs.h"
#include "ptx_cp.async.cuh"

/*
* Copying A1/A2 from global memory to shared memory.
* Usually 1024 or 2048 Bytes
*/
template<int SMEM_SIZE_IN_BYTES_PER_WARP>
__device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR,
const uint4* GPTR,
bool pred_guard = true) {
#ifdef DEBUG_MODE
static_assert(SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE % 16 == 0);
#endif
int lane_id = threadIdx.x % WARP_SIZE;
half* SPTR_HALF = reinterpret_cast<half*>(SPTR);
const half* GPTR_HALF = reinterpret_cast<const half*>(GPTR);
SPTR_HALF += lane_id*8;
GPTR_HALF += lane_id*8;
#pragma unroll
for(int i=0; i<SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE/16; i++) {
cp_async<16>( SPTR_HALF, GPTR_HALF, pred_guard);
SPTR_HALF += 256; // Forward 512 Bytes
GPTR_HALF += 256; // Forward 512 Bytes
}
}
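
/*
 * Usage sketch (illustrative, not part of the original file): each warp stages
 * its own tile of packed weights into shared memory. The 2048-byte tile size
 * and the kernel-side names (smem, weight_gptr) are assumptions, not
 * definitions from this header.
 *
 *   constexpr int SMEM_BYTES_PER_WARP = 2048;
 *   extern __shared__ uint32_t smem[];                 // dynamic shared memory
 *   const int    warp_id   = threadIdx.x / WARP_SIZE;
 *   uint32_t*    warp_smem = smem + warp_id * SMEM_BYTES_PER_WARP / sizeof(uint32_t);
 *   const uint4* warp_gmem = weight_gptr + warp_id * (SMEM_BYTES_PER_WARP / sizeof(uint4));
 *   CopyFromGlobalToShared_A<SMEM_BYTES_PER_WARP>(warp_smem, warp_gmem);
 *   // cp_async is asynchronous: commit the group and wait on it (helpers
 *   // presumably provided by ptx_cp.async.cuh) before any thread reads warp_smem.
 */
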
/*
* Copying 64 Quant Scales (FP16) from global memory to shared memory.
*/
__device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantScales,
const half* GPTR_A_Scales) {
int lane_id = threadIdx.x % WARP_SIZE;
int Offset_Shared = lane_id*2;
int Offset_Global = lane_id/4 + (lane_id%4)*16;
for(int i=0; i<2; i++) SPTR_QuantScales[Offset_Shared+i] = GPTR_A_Scales[Offset_Global+i*8];
}
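
/*
 * Usage sketch (illustrative, not part of the original file): one warp loads
 * the 64 FP16 scales for the current tile. "scales_gptr" and "tile_id" are
 * assumed kernel-side names. Each of the 32 lanes writes two scales; the
 * permuted global offsets above rearrange them, presumably to match the
 * register layout expected by the compute stage.
 *
 *   __shared__ half smem_scales[64];
 *   CopyFromGlobalToShared_Scales(smem_scales, scales_gptr + tile_id * 64);
 *   __syncthreads();   // plain stores here, so a barrier is sufficient
 */
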
// MODIFICATION NOTE: to support MSVC, the parameter type half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to the form below.
/*
* (1) Copying X rows * 64 columns of FP16 values, originally in row-major order
* (2) Copying 64 rows * X columns of FP16 values, originally in column-major order
* 16 Bytes per thread -> 512 Bytes per WARP = 4 lines per WARP = 1 line per 8 threads
*/
template<int MaxNumOfLinesToCopy, int BLOCK_WARPS>
__device__ __forceinline__ void CopyFromGlobalToShared(half (* __restrict__ SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
const half* GlobalPTR,
const int GlobalStride,
const int NumOfLinesLeft, // To support arbitrary N dimensions.
bool Pred = true) {
// static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time
const int NumOfThreads = BLOCK_WARPS * WARP_SIZE;
const int NumOfGroups = NumOfThreads / 8;
const int MaxIteration = (MaxNumOfLinesToCopy-1) / NumOfGroups + 1;
// runtime variables
const int line_id = threadIdx.x / 8;
const int line_offset = (threadIdx.x%8) * 8;
// PTR for source global memory and target shared memory
GlobalPTR += line_id * GlobalStride + line_offset;
SharedPTR += line_id;
#pragma unroll
for (int i = 0; i < MaxIteration; i++) {
bool AsyncCopyPred = (line_id+i*NumOfGroups) < NumOfLinesLeft && Pred;
cp_async<16>( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred);
//
GlobalPTR += NumOfGroups * GlobalStride;
SharedPTR += NumOfGroups;
}
}
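
/*
 * Usage sketch (illustrative, not part of the original file): staging a tile of
 * up to TILE_N lines of 64 FP16 values each, where the last tile may have fewer
 * valid lines. TILE_N, NUM_WARPS and the kernel-side pointers/strides below are
 * assumptions, not definitions from this header.
 *
 *   __shared__ half smem_B[TILE_N][WARP_K + PADDING_SHARED_MEM_FOR_B_8];
 *   CopyFromGlobalToShared<TILE_N, NUM_WARPS>(smem_B,
 *                                             B_gptr,       // first element of this tile
 *                                             K_Global,     // stride between lines in global memory
 *                                             NumLinesLeft  // valid lines remaining
 *                                             );
 *   // Like CopyFromGlobalToShared_A, this issues cp.async transfers; commit and
 *   // wait on the async group (helpers presumably in ptx_cp.async.cuh) before
 *   // reading smem_B.
 */
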
#endif