这里简单记录 dynamically_quantize_per_channel 中对称量化实现。更详细的量化知识请阅读 A Visual Guide to Quantization。
import torch | |
from typing import Optional | |
# https://github.com/pytorch/torchchat/blob/main/quantization/quantize.py | |
def dynamically_quantize_per_channel(
    x,
    quant_min,
    quant_max,
    target_dtype,
    groupsize: Optional[int] = None,
    *,
    scales_dtype=torch.float16,
    enable_non_multiple_groups=True,
):
    """
    Dynamically quantize per channel. This function is used for quantizing weights,
    for linear and embedding layers.

    Each row of ``x`` is split into groups of ``groupsize`` consecutive elements
    (or treated as a single group when ``groupsize`` is ``None`` or ``0``) and
    every group gets its own scale.

    Arguments:
        x: input tensor; it is indexed as 2-D (``x.shape[0]`` rows,
            ``x.shape[1]`` elements per row),
        quant_min: minimum value after quantization,
        quant_max: maximum value after quantization,
        target_dtype: target data type for weights after quantization,
        groupsize: number of elements of the channel to quantize together

    Keyword arguments:
        scales_dtype: data type of scale,
        enable_non_multiple_groups: if True, allow the rowsize to not be a multiple
            of group size; the row is zero-padded on the right up to the next
            multiple, and the padding is cut off again from the returned values.

    Returns:
        A tuple ``(quant, scales, zero_points)``:
        quant: quantized values in ``target_dtype`` with shape
            ``(x.shape[0], x.shape[1])`` of the original (unpadded) input,
        scales: per-group scales in ``scales_dtype`` with shape
            ``(x.shape[0], number_of_groups)``,
        zero_points: all-zero ``torch.int64`` tensor with the same shape as ``scales``.

    Raises:
        AssertionError: if ``groupsize`` is negative, or if the row size is not a
            multiple of ``groupsize`` while ``enable_non_multiple_groups`` is False.

    Assumptions:
        This function assumes symmetric quantization, axis ==0 and a dense memory format.
    """
    # TODO(future): relax the assumptions above as needed
    x_shape_1 = x.shape[1]

    if groupsize is None or groupsize == 0:
        # No grouping requested: the whole row is one group.
        items = x_shape_1
    elif ((x_shape_1 % groupsize) == 0) or not enable_non_multiple_groups:
        assert groupsize > 0, "group size must be positive"
        assert (x_shape_1 % groupsize) == 0, f"weights dimension 1 = {x_shape_1} must be a multiple of group size {groupsize}"
        items = groupsize
    else:
        assert groupsize > 0, "group size must be positive"
        # Imported here so the function also works when this module is imported;
        # otherwise `F` is only bound when the file is executed as a script.
        import torch.nn.functional as F

        print(f"row-size of weight matrix {x_shape_1} is not divisible by group size {groupsize}, using nearest neighbor rounding")
        # Right-pad every row with zeros so it splits into whole groups.
        padding = groupsize - (x_shape_1 % groupsize)
        x = F.pad(x, (0, padding))
        items = groupsize

    # Lower bound for the scales so the division below never divides by zero.
    eps = torch.finfo(torch.float32).eps

    # (rows, row_size) -> (rows, number_of_groups, items)
    x = x.view(x.shape[0], x.shape[1] // items, items)

    # Per-group minimum and maximum.
    min_val, max_val = torch.aminmax(x, dim=2)

    # Clamp the range so that it always contains zero.
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # Symmetric quantization: use the largest absolute value of each group.
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scales = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scales is the same dtype as the original tensor
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # Quantize based on qmin/qmax/scales/zp.
    x_div = x / scales.unsqueeze(-1)
    x_round = torch.round(x_div)
    x_zp = x_round + zero_points.unsqueeze(-1)
    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype).view(x.shape[0], -1)

    scales = scales.to(dtype=scales_dtype)
    # Drop the columns that were only added as padding.
    quant = quant[:, :x_shape_1]
    return quant, scales, zero_points
if __name__ == '__main__':
    import torch
    import torch.nn.functional as F

    # Demo input: two rows of eight values, ascending and descending.
    weights = torch.tensor(
        [
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
        ],
        dtype=torch.float32,
    )

    # Quantize to signed 8-bit values, four elements per group.
    results = dynamically_quantize_per_channel(
        weights,
        quant_min=-128,
        quant_max=127,
        target_dtype=torch.int8,
        groupsize=4,
    )

    # Print the quantized tensor, the scales and the zero points, each under its label.
    labels = ("量化后的张量:", "尺度因子:", "零点:")
    for label, tensor in zip(labels, results):
        print(label)
        print(tensor)
###################### | |
量化后的张量:tensor([[ 32, 64, 96, 127, 80, 96, 112, 127], | |
[127, 112, 96, 80, 127, 96, 64, 32]], dtype=torch.int8) | |
尺度因子:tensor([[0.0314, 0.0627], | |
[0.0627, 0.0314]], dtype=torch.float16) | |
零点:tensor([[0, 0], | |
[0, 0]]) |
下面是这个量化过程的公式和步骤:
- 输入张量重塑 :
  - 将输入张量 `x` 变形为 `(N, C', G)`,其中 `N` 是 `x` 的第一维大小(即 `x.shape[0]`,权重矩阵的行数),`C'` 是 `x` 的第二维度(通常是通道数)除以 `items` 得到的分组数(如果需要的话先填充,以确保这个维度是 `items` 的倍数),`G` 是每组的大小(等于 `items`)。
- 计算最小值和最大值 :
  - 对每一组 `G`,计算每个通道的最小值 `min_val` 和最大值 `max_val`:

    $$\text{min\_val} = \min(x,\ \text{dim}=2), \qquad \text{max\_val} = \max(x,\ \text{dim}=2)$$
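- 计算尺度因子和零点 :
  - 对于每个通道的负最小值 `min_val_neg`(即 $\min(\text{min\_val}, 0)$)和正最大值 `max_val_pos`(即 $\max(\text{max\_val}, 0)$),计算 `max_val_pos`:

    $$\text{max\_val\_pos} = \max(-\text{min\_val\_neg},\ \text{max\_val\_pos})$$
  - 计算尺度因子 `scales`:

    $$\text{scales} = \frac{\text{max\_val\_pos}}{(\text{quant\_max} - \text{quant\_min}) / 2}$$

    其中 `quant_min` 和 `quant_max` 是量化后的最小值和最大值。确保 `scales` 的最小值为 `eps`(一个很小的常数,防止除以零)。
  - 由于是对称量化,零点 `zero_points` 全部为 0。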
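- 量化计算 :
  - 对输入张量 `x` 进行量化:

    $$
    \begin{aligned}
    x_{\text{div}} &= \frac{x}{\text{scales.unsqueeze}(-1)} \\
    x_{\text{round}} &= \text{round}(x_{\text{div}}) \\
    x_{\text{zp}} &= x_{\text{round}} + \text{zero\_points.unsqueeze}(-1) \\
    \text{quant} &= \text{clamp}(x_{\text{zp}},\ \text{quant\_min},\ \text{quant\_max})
    \end{aligned}
    $$
  - 将量化后的结果转换为目标数据类型 `target_dtype`。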
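总结来说,量化的公式可以表示为:

$$\text{quant} = \text{clamp}\left(\text{round}\left(\frac{x}{\text{scales}}\right) + \text{zero\_points},\ \text{quant\_min},\ \text{quant\_max}\right)$$

其中:
- $\text{scales} = \dfrac{\text{max\_val\_pos}}{(\text{quant\_max} - \text{quant\_min}) / 2}$
- $\text{zero\_points} = 0$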
这段代码实现了对每个通道进行量化,能够处理不同的组大小,并允许处理最后一组大小不等于组大小的情况。
正文完