Feat (GPxQ): new init mechanism
fabianandresgrob authored Nov 13, 2023
1 parent 7750ea8 commit 7e5d5ac
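
The "new init mechanism" in the title replaces class-attribute configuration of the per-layer optimizers (the old self.class_implementation = GPFQ with GPFQ.p = p, and likewise GPTQ.num_blocks = num_blocks) with an initialize_module_optimizer factory method: each mode overrides it and passes its options as ordinary constructor arguments, so nothing is stored on the optimizer class itself. A minimal sketch of the pattern in Python, using simplified hypothetical names rather than the actual Brevitas classes:

# Minimal sketch of the new init mechanism (hypothetical names, not the real
# Brevitas classes): the base mode asks a factory method for each per-layer
# optimizer; subclasses override it and forward their own per-instance options.


class BaseModeSketch:

    def attach(self, layer, name):
        # Old approach: call self.class_implementation(layer, name) and smuggle
        # options in through class attributes such as GPFQ.p = p.
        return self.initialize_module_optimizer(layer, name)

    def initialize_module_optimizer(self, layer, name):
        raise NotImplementedError


class GPFQOptimizerSketch:

    def __init__(self, layer, name, p=0.25):
        self.layer = layer
        self.name = name
        self.p = p  # per-instance value, no class-level default required


class GPFQModeSketch(BaseModeSketch):

    def __init__(self, p=0.25):
        self.p = p  # lives on the mode instance only

    def initialize_module_optimizer(self, layer, name):
        # Options travel explicitly through the constructor call.
        return GPFQOptimizerSketch(layer, name, p=self.p)

With this shape, GPFQModeSketch(p=0.1).attach(some_layer, "fc1") yields an optimizer carrying p=0.1 without affecting any other mode instance, which is what the per-file changes below implement for GPFQ and GPTQ.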
Showing 3 changed files with 51 additions and 20 deletions.
29 changes: 22 additions & 7 deletions src/brevitas/graph/gpfq.py
@@ -59,8 +59,7 @@ def __init__(
 
         self.orig_forward = self.model.forward
         self.model.forward = self.catch_stopfwd
-        self.class_implementation = GPFQ
-        GPFQ.p = p
+        self.p = p
 
     def catch_stopfwd(self, *args, **kwargs):
         # Collect quant input
@@ -95,23 +94,39 @@ def catch_stopfwd(self, *args, **kwargs):
                 gpxq_class.disable_pre_forward_hook = False
         return out
 
+    def initialize_module_optimizer(
+            self, layer, name, act_order, len_parallel_layers, create_weight_orig):
+        return GPFQ(
+            layer=layer,
+            name=name,
+            act_order=act_order,
+            len_parallel_layers=len_parallel_layers,
+            create_weight_orig=create_weight_orig,
+            p=self.p)
+
 
 class GPFQ(GPxQ):
     """
     Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main
     """
-    p = 0.25
 
-    def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None:
+    def __init__(
+            self,
+            layer,
+            name,
+            act_order,
+            len_parallel_layers=1,
+            create_weight_orig=True,
+            p=0.25) -> None:
 
         if act_order:
             raise ValueError("Act_order is not supported in GPFQ")
 
-        super().__init__(layer, name, act_order, parallel_layers, create_weight_orig)
+        super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)
         self.float_input = None
         self.quantized_input = None
         self.index_computed = False
-        self.p = GPFQ.p
+        self.p = p
 
     def update_batch(self, module, input, current_layer):
         if self.disable_pre_forward_hook:
@@ -188,7 +203,7 @@ def update_batch(self, module, input, current_layer):
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
         current_layer.forward_count += 1
-        if current_layer.forward_count == len(self.parallel_layers):
+        if current_layer.forward_count == self.len_parallel_layers:
             current_layer.forward_count = 0
             raise StopFwdException
 
28 changes: 21 additions & 7 deletions src/brevitas/graph/gptq.py
@@ -67,8 +67,6 @@ def __init__(
         self.model.forward = self.catch_stopfwd
         # How many subblock to use during GPTQ for each layer
         self.num_blocks = num_blocks
-        self.class_implementation = GPTQ
-        GPTQ.num_blocks = num_blocks
 
     def catch_stopfwd(self, *args, **kwargs):
         try:
@@ -85,6 +83,16 @@ def catch_stopfwd(self, *args, **kwargs):
                 gpxq_class.disable_pre_forward_hook = False
         return out
 
+    def initialize_module_optimizer(
+            self, layer, name, act_order, len_parallel_layers, create_weight_orig):
+        return GPTQ(
+            layer=layer,
+            name=name,
+            act_order=act_order,
+            len_parallel_layers=len_parallel_layers,
+            create_weight_orig=create_weight_orig,
+            num_blocks=self.num_blocks)
+
 
 class GPTQ(GPxQ):
     """
@@ -104,15 +112,21 @@ class GPTQ(GPxQ):
     See the License for the specific language governing permissions and
     limitations under the License.
     """
-    num_blocks = 100
 
-    def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None:
-        super().__init__(layer, name, act_order, parallel_layers, create_weight_orig)
+    def __init__(
+            self,
+            layer,
+            name,
+            act_order,
+            len_parallel_layers=1,
+            create_weight_orig=True,
+            num_blocks=100) -> None:
+        super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)
 
         dev = self.layer.weight.device
 
         # Define how many columns to update in each mini-block
-        self.blocksize = math.ceil(self.columns / GPTQ.num_blocks)
+        self.blocksize = math.ceil(self.columns / num_blocks)
 
         # Initialize Hessian matrix and counter. We need it in float32 to compute the inverse
         self.H = torch.zeros((self.groups, self.columns, self.columns),
@@ -170,7 +184,7 @@ def update_batch(self, module, input, current_layer):
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
         current_layer.forward_count += 1
-        if current_layer.forward_count == len(self.parallel_layers):
+        if current_layer.forward_count == self.len_parallel_layers:
             current_layer.forward_count = 0
             raise StopFwdException
 
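A small worked instance of the blocksize computation shown above (the layer width of 768 is illustrative; the formula and the num_blocks default of 100 come from the diff):

import math

# With 768 input channels and the default num_blocks=100, GPTQ updates the
# weight columns in mini-blocks of ceil(768 / 100) = 8 columns at a time.
columns, num_blocks = 768, 100
blocksize = math.ceil(columns / num_blocks)
assert blocksize == 8
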
14 changes: 8 additions & 6 deletions src/brevitas/graph/gpxq.py
@@ -98,15 +98,16 @@ def __enter__(self):
 
                 # Attach hooks for GPTQ
                 if self._is_module_supported(module):
-                    gpxq = self.class_implementation(
+                    gpxq_module_optimizer = self.initialize_module_optimizer(
                         module,
                         name,
                         act_order=self.act_order,
-                        parallel_layers=parallel_layers,
+                        len_parallel_layers=len(parallel_layers),
                         create_weight_orig=self.create_weight_orig)
-                    hook_fn = partial(gpxq.update_batch, current_layer=self.current_layer)
+                    hook_fn = partial(
+                        gpxq_module_optimizer.update_batch, current_layer=self.current_layer)
                     self.hook_dict[name] = module.register_forward_pre_hook(hook_fn)
-                    self.gpxq_layers[name] = gpxq
+                    self.gpxq_layers[name] = gpxq_module_optimizer
         if not self.use_quant_activations:
             self.disable_quant_inference.disable_act_quantization(
                 self.model, is_training=self.model.training)
@@ -137,7 +138,8 @@ def catch_stopfwd(self, *args, **kwargs):
 
 class GPxQ(ABC):
 
-    def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None:
+    def __init__(
+            self, layer, name, act_order, len_parallel_layers=1, create_weight_orig=True) -> None:
         self.layer = layer
         self.name = name
         self.act_order = act_order
@@ -159,7 +161,7 @@ def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig
         self.rows = weight.shape[0]
         # Number of columns is equal to the input channels (IC)
         self.columns = weight.shape[1]
-        self.parallel_layers = parallel_layers
+        self.len_parallel_layers = len_parallel_layers
 
         self.disable_pre_forward_hook = False
         # Some layers require knowledge from quant inputs to compute quant weights
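At the usage level, the point of the mechanism is that settings such as p and num_blocks now live on each mode instance and are forwarded explicitly, so constructing a second mode with different settings no longer rewrites a class attribute that earlier instances read from. A hedged sketch, assuming gpfq_mode is the public context manager whose patched __init__ appears in the first gpfq.py hunk and that it accepts the model plus a p keyword; the toy models are illustrative only:

import torch.nn as nn

import brevitas.nn as qnn
from brevitas.graph.gpfq import gpfq_mode  # assumed public entry point

model_a = nn.Sequential(qnn.QuantLinear(8, 8, bias=False))
model_b = nn.Sequential(qnn.QuantLinear(8, 8, bias=False))

mode_a = gpfq_mode(model_a, p=0.25)
mode_b = gpfq_mode(model_b, p=1.0)
# Before this commit, creating mode_b executed GPFQ.p = 1.0, so GPFQ layer
# optimizers later built on behalf of mode_a would read p=1.0 via
# self.p = GPFQ.p. Now each mode forwards its own self.p through
# initialize_module_optimizer, and the two configurations stay independent.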
