This post briefly documents the symmetric quantization implemented in `dynamically_quantize_per_channel`. For a more detailed introduction to quantization, see A Visual Guide to Quantization.
```python
import torch
import torch.nn.functional as F  # needed for F.pad below
from typing import Optional


# https://github.com/pytorch/torchchat/blob/main/quantization/quantize.py
def dynamically_quantize_per_channel(
    x,
    quant_min,
    quant_max,
    target_dtype,
    groupsize: Optional[int] = None,
    *,
    scales_dtype=torch.float16,
    enable_non_multiple_groups=True,
):
    """
    Dynamically quantize per channel. This function is used for quantizing
    weights, for linear and embedding layers.

    Arguments:
        x: input tensor,
        quant_min: minimum value after quantization,
        quant_max: maximum value after quantization,
        target_dtype: target data type for weights after quantization,
        groupsize: number of elements of the channel to quantize together

    Keyword arguments:
        scales_dtype: data type of scale,
        enable_non_multiple_groups: if True, allow the rowsize to not be a
            multiple of group size, with a final group of a size less than
            group size.

    Assumptions:
        This function assumes symmetric quantization, axis == 0 and a dense
        memory format.
    """
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    x_shape_1 = x.shape[1]

    if groupsize is None or groupsize == 0:
        items = x_shape_1
    elif ((x_shape_1 % groupsize) == 0) or not enable_non_multiple_groups:
        assert groupsize > 0, "group size must be positive"
        assert (
            x_shape_1 % groupsize
        ) == 0, f"weights dimension 1 = {x_shape_1} must be a multiple of group size {groupsize}"
        items = groupsize
    else:
        assert groupsize > 0, "group size must be positive"
        print(
            f"row-size of weight matrix {x_shape_1} is not divisible by group size {groupsize}, using nearest neighbor rounding"
        )
        assert (
            x_shape_1 % groupsize != 0
        ), f"expected x.shape[1] to not be a multiple of group size {groupsize}, but got {x_shape_1}"
        padding = groupsize - (x_shape_1 % groupsize)
        x = F.pad(x, (0, padding))
        items = groupsize

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # reshape to (rows, groups per row, group size)
    x = x.view(x.shape[0], x.shape[1] // items, items)

    # get min and max per group
    min_val, max_val = torch.aminmax(x, dim=2)

    # calculate scales and zero_points based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scales = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scales is the same dtype as the original tensor
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    # symmetric quantization: all zero points are zero
    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scales/zp
    # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x / scales.unsqueeze(-1)
    x_round = torch.round(x_div)
    x_zp = x_round + zero_points.unsqueeze(-1)
    quant = (
        torch.clamp(x_zp, quant_min, quant_max).to(target_dtype).view(x.shape[0], -1)
    )

    scales = scales.to(dtype=scales_dtype)
    # drop any padded columns so the output width matches the input
    quant = quant[:, :x_shape_1]

    return quant, scales, zero_points


if __name__ == "__main__":
    x = torch.tensor(
        [
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
        ],
        dtype=torch.float32,
    )

    # quantization parameters
    quant_min = -128
    quant_max = 127
    target_dtype = torch.int8
    groupsize = 4

    # run the quantization
    quant, scales, zero_points = dynamically_quantize_per_channel(
        x, quant_min, quant_max, target_dtype, groupsize
    )

    print("Quantized tensor:")
    print(quant)
    print("Scales:")
    print(scales)
    print("Zero points:")
    print(zero_points)
```
Running the script prints:

```
Quantized tensor:
tensor([[ 32,  64,  96, 127,  80,  96, 112, 127],
        [127, 112,  96,  80, 127,  96,  64,  32]], dtype=torch.int8)
Scales:
tensor([[0.0314, 0.0627],
        [0.0627, 0.0314]], dtype=torch.float16)
Zero points:
tensor([[0, 0],
        [0, 0]])
```
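As a quick sanity check (not part of the original script), the quantized result can be dequantized and compared against the input. This sketch assumes the `x`, `quant`, `scales`, `zero_points`, and `groupsize` variables from the demo above:

```python
# Dequantize: regroup, subtract zero points, rescale with the per-group scales.
q = quant.view(quant.shape[0], -1, groupsize).to(torch.float32)
zp = zero_points.to(torch.float32).unsqueeze(-1)
s = scales.to(torch.float32).unsqueeze(-1)
x_hat = ((q - zp) * s).view(quant.shape[0], -1)

# The worst-case rounding error is about scale / 2 per group.
print((x - x_hat).abs().max())
```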
The formulas and steps of this quantization process are as follows:

- Reshape the input tensor:
  - Reshape the input tensor `x` into shape `(N, C', G)`, where `N` is the batch size (the number of rows), `C'` is the second dimension of `x` (typically the number of channels) divided by `items` (padded if necessary so that this dimension is a multiple of `items`), and `G` is the size of each group (equal to `items`). A minimal sketch of this step follows below.
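The shapes and `groupsize` here mirror the demo; the padding is a no-op because 8 is already a multiple of 4:

```python
import torch
import torch.nn.functional as F

x = torch.arange(16, dtype=torch.float32).reshape(2, 8)  # (N, C) = (2, 8)
groupsize = 4

# Pad the row dimension up to a multiple of groupsize (0 extra columns here).
padding = (-x.shape[1]) % groupsize
x_padded = F.pad(x, (0, padding))

# View as (N, C' = C / groupsize, G = groupsize).
x_grouped = x_padded.view(x.shape[0], -1, groupsize)
print(x_grouped.shape)  # torch.Size([2, 2, 4])
```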
- Compute the minimum and maximum:
  - For each group, compute the per-group minimum `min_val` and maximum `max_val` along the group dimension:

$$
\text{min\_val} = \min(x, \text{dim}=2), \quad \text{max\_val} = \max(x, \text{dim}=2) \tag{1}
$$
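In the implementation this is a single `torch.aminmax` call over the last dimension; a self-contained sketch:

```python
import torch

x_grouped = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
                           [5.0, 6.0, 7.0, 8.0]]])  # (N, C', G) = (1, 2, 4)

# One minimum and one maximum per (row, group) pair.
min_val, max_val = torch.aminmax(x_grouped, dim=2)
print(min_val)  # tensor([[1., 5.]])
print(max_val)  # tensor([[4., 8.]])
```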
- Compute the scale factors and zero points:
  - Clamp the per-group minimum at zero to get `min_val_neg`, clamp the per-group maximum at zero to get `max_val_pos`, then keep the larger magnitude of the two sides:

$$
\text{max\_val\_pos} = \max(-\text{min\_val\_neg}, \text{max\_val\_pos}) \tag{2}
$$

  - Compute the scale factors `scales`:

$$
\text{scales} = \frac{\text{max\_val\_pos}}{(\text{quant\_max} - \text{quant\_min})/2} \tag{3}
$$

where `quant_min` and `quant_max` are the minimum and maximum values after quantization. `scales` is clamped from below at `eps` (a tiny constant that prevents division by zero). Because the quantization is symmetric, the zero points are all zero.
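A standalone sketch of equations (2) and (3), using the per-group minima and maxima of the demo's first row as input:

```python
import torch

quant_min, quant_max = -128, 127
eps = torch.finfo(torch.float32).eps

min_val = torch.tensor([[1.0, 5.0]])  # per-group minima (demo, first row)
max_val = torch.tensor([[4.0, 8.0]])  # per-group maxima

# Clamp toward zero so the quantization range always contains 0.
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

# Symmetric range: keep the larger magnitude on either side of zero.
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scales = torch.clamp(max_val_pos / (float(quant_max - quant_min) / 2), min=eps)
print(scales)  # tensor([[0.0314, 0.0627]])
```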
- Quantize:
  - Quantize the input tensor `x` (in the code, `scales` and `zero_points` are unsqueezed so that they broadcast across the group dimension):

$$
\begin{aligned}
x_{\text{div}} &= \frac{x}{\text{scales}} \\
x_{\text{round}} &= \operatorname{round}(x_{\text{div}}) \\
x_{\text{zp}} &= x_{\text{round}} + \text{zero\_points} \\
\text{quant} &= \operatorname{clamp}(x_{\text{zp}}, \text{quant\_min}, \text{quant\_max})
\end{aligned} \tag{4}
$$

  - Finally, cast the quantized result to the target data type `target_dtype`.
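Applying equation (4) to the second group of the demo's first row, `[5, 6, 7, 8]`, reproduces the corresponding entries of the output above (note how 127.5 rounds to 128 and is then clamped to 127):

```python
import torch

quant_min, quant_max = -128, 127
group = torch.tensor([5.0, 6.0, 7.0, 8.0])
scale = torch.tensor(8.0 / 127.5)  # from equation (3); the zero point is 0

x_round = torch.round(group / scale)  # tensor([ 80.,  96., 112., 128.])
quant = torch.clamp(x_round, quant_min, quant_max).to(torch.int8)
print(quant)  # tensor([ 80,  96, 112, 127], dtype=torch.int8)
```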
In summary, the quantization formula can be written as:

$$
\text{quant} = \operatorname{clamp}\left(\operatorname{round}\!\left(\frac{x}{\text{scales}}\right) + \text{zero\_points},\ \text{quant\_min},\ \text{quant\_max}\right) \tag{5}
$$

where:
- $\text{scales} = \dfrac{\text{max\_val\_pos}}{(\text{quant\_max} - \text{quant\_min})/2}$
- $\text{zero\_points} = 0$ (symmetric quantization)
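As a concrete check against the demo output, plugging the first group of the first row, `[1, 2, 3, 4]`, into equations (3) and (5) gives:

$$
\text{scales} = \frac{4}{\bigl(127 - (-128)\bigr)/2} = \frac{4}{127.5} \approx 0.0314, \qquad
\operatorname{clamp}\!\left(\operatorname{round}\!\left(\frac{[1, 2, 3, 4]}{4/127.5}\right), -128, 127\right) = [32, 64, 96, 127]
$$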
This code quantizes each channel group-wise, handles different group sizes, and allows the final group to be smaller than the group size.