
A Simple Symmetric Quantization Implementation in torchchat

This post briefly documents the symmetric quantization implemented in dynamically_quantize_per_channel. For a more thorough introduction to quantization, read A Visual Guide to Quantization.
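As background (these are standard definitions, not taken from the torchchat source): symmetric quantization fixes the zero point at 0 and maps a real value $x$ to an integer $q$ using a single scale $s$ derived from the largest magnitude in the tensor (or group):

$$s = \frac{\max(|x|)}{(\text{quant\_max} - \text{quant\_min})/2}, \qquad q = \mathrm{clamp}\left(\mathrm{round}\left(\frac{x}{s}\right), \text{quant\_min}, \text{quant\_max}\right)$$

This is the scheme the function below implements, computed independently per channel (and per group within a channel).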

import torch
import torch.nn.functional as F  # needed for F.pad below
from typing import Optional

# https://github.com/pytorch/torchchat/blob/main/quantization/quantize.py
def dynamically_quantize_per_channel(
    x,
    quant_min,
    quant_max,
    target_dtype,
    groupsize: Optional[int] = None,
    *,
    scales_dtype=torch.float16,
    enable_non_multiple_groups=True,
):
    """
    Dynamically quantize per channel. This function is used for quantizing weights,
    for linear and embedding layers.

    Arguments:
        x: input tensor,
        quant_min: minimum value after quantization,
        quant_max: maximum value after quantization,
        target_dtype: target data type for weights after quantization,
        groupsize: number of elements of the channel to quantize together

    Keyword arguments:
        scales_dtype: data type of scale,
        enable_non_multiple_groups: if True, allow the rowsize to not be a multiple
            of group size, with a final group of a size less than group size.

    Assumptions:
        This function assumes symmetric quantization, axis == 0 and a dense memory format.
    """
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    x_shape_1 = x.shape[1]

    if groupsize is None or groupsize == 0:
        # no grouping: the whole row is one group
        items = x_shape_1
    elif ((x_shape_1 % groupsize) == 0) or not enable_non_multiple_groups:
        assert groupsize > 0, "group size must be positive"
        assert (
            x_shape_1 % groupsize
        ) == 0, f"weights dimension 1 = {x_shape_1} must be a multiple of group size {groupsize}"
        items = groupsize
    else:
        assert groupsize > 0, "group size must be positive"
        print(
            f"row-size of weight matrix {x_shape_1} is not divisible by group size {groupsize}, using nearest neighbor rounding"
        )
        assert (
            x_shape_1 % groupsize != 0
        ), f"expected x.shape[1] to not be a multiple of group size {groupsize}, but got {x_shape_1}"
        # zero-pad each row up to the next multiple of groupsize
        padding = groupsize - (x_shape_1 % groupsize)
        x = F.pad(x, (0, padding))
        items = groupsize

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # reshape each row into groups of `items` elements
    x = x.view(x.shape[0], x.shape[1] // items, items)

    # get min and max of each group
    min_val, max_val = torch.aminmax(x, dim=2)
    # print(f"min_val {min_val}")
    # print(f"max_val {max_val}")

    # calculate scales and zero_points based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scales = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scales is the same dtype as the original tensor
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    # symmetric quantization: all zero points are zero
    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scales/zp
    # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x / scales.unsqueeze(-1)
    x_round = torch.round(x_div)
    x_zp = x_round + zero_points.unsqueeze(-1)
    quant = (
        torch.clamp(x_zp, quant_min, quant_max).to(target_dtype).view(x.shape[0], -1)
    )

    scales = scales.to(dtype=scales_dtype)
    # drop any padded columns so the output matches the original width
    quant = quant[:, :x_shape_1]

    return quant, scales, zero_points
if __name__ == '__main__':
    x = torch.tensor([
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]
    ], dtype=torch.float32)

    # quantization parameters
    quant_min = -128
    quant_max = 127
    target_dtype = torch.int8
    groupsize = 4

    # call the quantization function
    quant, scales, zero_points = dynamically_quantize_per_channel(
        x,
        quant_min,
        quant_max,
        target_dtype,
        groupsize
    )

    print("Quantized tensor:")
    print(quant)
    print("Scale factors:")
    print(scales)
    print("Zero points:")
    print(zero_points)
######################
Quantized tensor:
tensor([[ 32,  64,  96, 127,  80,  96, 112, 127],
        [127, 112,  96,  80, 127,  96,  64,  32]], dtype=torch.int8)
Scale factors:
tensor([[0.0314, 0.0627],
        [0.0627, 0.0314]], dtype=torch.float16)
Zero points:
tensor([[0, 0],
        [0, 0]])

The formulas and steps of this quantization process are as follows:

  1. Reshape the input tensor
  • Reshape x to (N, C', G), where N is the first dimension of x (the per-channel axis), C' is the second dimension of x divided by items (padded first, if necessary, so that this dimension is a multiple of items), and G is the group size (equal to items).
  2. Compute the minimum and maximum
  • For each group, compute the per-group minimum min_val and maximum max_val:
    $$\text{min\_val} = \min(x, \text{dim}=2), \qquad \text{max\_val} = \max(x, \text{dim}=2)$$
  3. Compute the scale factors and zero points
  • From the non-positive minimum min_val_neg and non-negative maximum max_val_pos of each group, take the larger magnitude:
    $$\text{max\_val\_pos} = \max(-\text{min\_val\_neg}, \text{max\_val\_pos})$$
  • Compute the scale factor scales:
    $$\text{scales} = \frac{\text{max\_val\_pos}}{(\text{quant\_max} - \text{quant\_min})/2}$$
    where quant_min and quant_max are the minimum and maximum values after quantization. scales is clamped from below by eps (a tiny constant that prevents division by zero).
  4. Quantize
  • Quantize the input tensor x:
    $$x_{\text{div}} = \frac{x}{\text{scales}}, \qquad x_{\text{round}} = \mathrm{round}(x_{\text{div}}), \qquad x_{\text{zp}} = x_{\text{round}} + \text{zero\_points}, \qquad \text{quant} = \mathrm{clamp}(x_{\text{zp}}, \text{quant\_min}, \text{quant\_max})$$
  • Cast the quantized result to the target data type target_dtype. A numeric walk-through of these steps follows this list.
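As a concrete walk-through, take the first group [1, 2, 3, 4] of the example input (these numbers follow directly from the formulas above):

$$s = \frac{4}{(127 - (-128))/2} = \frac{4}{127.5} \approx 0.0314$$

$$\mathrm{round}\left(\frac{[1, 2, 3, 4]}{s}\right) = [32, 64, 96, 128] \xrightarrow{\ \mathrm{clamp}\ } [32, 64, 96, 127]$$

which matches the first four entries of the printed output.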

In summary, the quantization formula can be written as:

$$\text{quant} = \mathrm{clamp}\left(\mathrm{round}\left(\frac{x}{\text{scales}}\right) + \text{zero\_points},\ \text{quant\_min},\ \text{quant\_max}\right)$$

where:

  • $\text{scales} = \dfrac{\text{max\_val\_pos}}{(\text{quant\_max} - \text{quant\_min})/2}$
  • $\text{zero\_points} = 0$
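Because the zero points vanish, the dequantization (inverse) step, implied by the formulas above though not shown in the code, is simply:

$$\hat{x} = \text{scales} \cdot (\text{quant} - \text{zero\_points}) = \text{scales} \cdot \text{quant}$$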

This code implements per-channel (and per-group) quantization, supports different group sizes, and can handle a final group that is smaller than the group size, as the sketch below illustrates.
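When x.shape[1] is not a multiple of groupsize, the function zero-pads each row up to the next multiple, quantizes, and then trims the padded columns from the result. A minimal sketch of that path (this input is my own example, not from torchchat):

# row size 6 is not divisible by groupsize 4: the function pads each row
# to 8 columns, quantizes two groups per row, then trims back to 6 columns
x6 = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]])
quant6, scales6, zp6 = dynamically_quantize_per_channel(
    x6, -128, 127, torch.int8, groupsize=4
)
print(quant6.shape)   # torch.Size([1, 6]) -- padding trimmed from the output
print(scales6.shape)  # torch.Size([1, 2]) -- one scale per (padded) group

Note that the second group's scale is computed over [5, 6, 0, 0], so the zero padding does not distort the scale under symmetric quantization (the maximum magnitude is unchanged).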
