A simple symmetric quantization implementation in torchchat

This post briefly records the symmetric quantization implemented in dynamically_quantize_per_channel. For more background on quantization, see A Visual Guide to Quantization.

import torch
import torch.nn.functional as F  # needed for F.pad in the padding branch below
from typing import Optional
# https://github.com/pytorch/torchchat/blob/main/quantization/quantize.py


def dynamically_quantize_per_channel(
    x,
    quant_min,
    quant_max,
    target_dtype,
    groupsize: Optional[int] = None,
    *,
    scales_dtype=torch.float16,
    enable_non_multiple_groups=True,
):
    """
    Dynamically quantize per channel.  This function is used for quantizing weights,
    for linear and embedding layers.

    Arguments:
        x: input tensor,
        quant_min: minimum value after quantization,
        quant_max: maximum value after quantization,
        target_dtype: target data type for weights after quantization,
        groupsize: number of elements of the channel to quantize together

    Keyword arguments:
        scales_dtype: data type of scale,
        enable_non_multiple_groups: if True, allow the rowsize to not be a multiple of group size,
                        with a final group of a size less than group size.

    Assumptions:
        This function assumes symmetric quantization, axis ==0 and a dense memory format.
    """

    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    x_shape_1 = x.shape[1]

    if groupsize is None or groupsize == 0:
        items = x_shape_1
    elif ((x_shape_1 % groupsize) == 0) or not enable_non_multiple_groups:
        assert groupsize > 0, "group size must be positive"
        assert (x_shape_1 % groupsize) == 0, f"weights dimension 1 = {x_shape_1} must be a multiple of group size {groupsize}"
        items = groupsize
    else:
        assert groupsize > 0, "group size must be positive"
        print(f"row-size of weight matrix {x_shape_1} is not divisible by group size {groupsize}, using nearest neighbor rounding"
        )
        assert (x_shape_1 % groupsize != 0), f"expected x.shape[1] to not be a multiple of group size {groupsize}, but got {x_shape_1}"
        padding = groupsize - (x_shape_1 % groupsize)
        x = F.pad(x, (0, padding))
        items = groupsize

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    x = x.view(x.shape[0], x.shape[1] // items, items)
    # get min and max
    min_val, max_val = torch.aminmax(x, dim=2)
    # print(f"min_val {min_val}")
    # print(f"max_val {max_val}")

    # calculate scales and zero_points based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scales = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scales is the same dtype as the original tensor
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scales/zp
    # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x / scales.unsqueeze(-1)
    x_round = torch.round(x_div)
    x_zp = x_round + zero_points.unsqueeze(-1)
    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype).view(x.shape[0], -1)

    scales = scales.to(dtype=scales_dtype)
    quant = quant[:, :x_shape_1]

    return quant, scales, zero_points


if __name__ == '__main__':
    import torch
    import torch.nn.functional as F

    x = torch.tensor([
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]
    ], dtype=torch.float32)

    # quantization parameters
    quant_min = -128
    quant_max = 127
    target_dtype = torch.int8
    groupsize = 4

    # call the quantization function
    quant, scales, zero_points = dynamically_quantize_per_channel(
        x,
        quant_min,
        quant_max,
        target_dtype,
        groupsize
    )

    print("Quantized tensor:")
    print(quant)
    print("Scales:")
    print(scales)
    print("Zero points:")
    print(zero_points)


######################
Quantized tensor:
tensor([[ 32,  64,  96, 127,  80,  96, 112, 127],
        [127, 112,  96,  80, 127,  96,  64,  32]], dtype=torch.int8)
Scales:
tensor([[0.0314, 0.0627],
        [0.0627, 0.0314]], dtype=torch.float16)
Zero points:
tensor([[0, 0],
        [0, 0]])
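As a sanity check on the output above, the two groups of the first row, [1, 2, 3, 4] and [5, 6, 7, 8], can be worked through by hand. The sketch below uses plain Python instead of torch; the helper name `quantize_group` and the constant `EPS` are hypothetical and only mirror the arithmetic of the function for a single group.

```python
EPS = 2.0 ** -23  # float32 machine epsilon, the value of torch.finfo(torch.float32).eps


def quantize_group(values, quant_min=-128, quant_max=127):
    """Symmetric quantization of one group (hypothetical helper for illustration)."""
    # Clamp the extremes at zero, then keep the larger magnitude.
    min_val_neg = min(min(values), 0.0)
    max_val_pos = max(max(values), 0.0)
    max_abs = max(-min_val_neg, max_val_pos)
    # Half of the integer range: (127 - (-128)) / 2 = 127.5.
    scale = max(max_abs / ((quant_max - quant_min) / 2), EPS)
    # Python's round() rounds halves to even, as torch.round does.
    quantized = [min(max(round(v / scale), quant_min), quant_max) for v in values]
    return quantized, scale


q0, s0 = quantize_group([1.0, 2.0, 3.0, 4.0])
q1, s1 = quantize_group([5.0, 6.0, 7.0, 8.0])
print(q0, round(s0, 4))  # [32, 64, 96, 127] 0.0314
print(q1, round(s1, 4))  # [80, 96, 112, 127] 0.0627
```

The largest positive value of a group maps to about 127.5 before rounding, which ends up above quant_max and is clamped to 127. That is why every group in the output contains a 127.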

The formulas and steps of this quantization process are as follows:

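  1. Reshape the input tensor
  • Reshape the input tensor x to (N, C', G), where N is the number of rows of x (x.shape[0], one per channel of the weight matrix), C' is the second dimension of x divided by items (padded first, if necessary, so that this dimension is a multiple of items), and G is the size of each group (equal to items).
  2. Compute the minimum and maximum
  • For each group of G elements, compute the minimum min_val and the maximum max_val along the last dimension:
    $$
    \begin{aligned}
    \text{min\_val} &= \min(x, \text{dim}=2) \\
    \text{max\_val} &= \max(x, \text{dim}=2)
    \end{aligned} \tag{1}
    $$
  3. Compute the scales and zero points
  • First clamp the extremes at zero, min_val_neg = min(min_val, 0) and max_val_pos = max(max_val, 0), then take the larger of the two magnitudes as the new max_val_pos:
    $$
    \text{max\_val\_pos} = \max(-\text{min\_val\_neg}, \text{max\_val\_pos}) \tag{2}
    $$
  • Compute the scale factor scales:
    $$
    \text{scales} = \frac{\text{max\_val\_pos}}{\frac{\text{quant\_max} - \text{quant\_min}}{2}} \tag{3}
    $$
    where quant_min and quant_max are the minimum and maximum values after quantization. scales is clamped to be at least eps (the float32 machine epsilon, a very small constant that prevents division by zero). Because the quantization is symmetric, zero_points is set to 0 for every group.
  4. Quantize
  • Quantize the input tensor x:
    $$
    \begin{aligned}
    x_{\text{div}} &= \frac{x}{\text{scales.unsqueeze}(-1)} \\
    x_{\text{round}} &= \text{round}(x_{\text{div}}) \\
    x_{\text{zp}} &= x_{\text{round}} + \text{zero\_points.unsqueeze}(-1) \\
    \text{quant} &= \text{clamp}(x_{\text{zp}}, \text{quant\_min}, \text{quant\_max})
    \end{aligned} \tag{4}
    $$
  • Convert the quantized result to the target data type target_dtype.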
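The four steps above can be traced one at a time on a small tensor, printing the shape produced at each stage. This is a minimal sketch rather than the torchchat code itself: the input values (with mixed signs) are made up for illustration, and the variable names `groups` and `max_abs` are not from the original function.

```python
import torch

x = torch.tensor([[1.0, -2.0, 3.0, 4.0, 5.0, -6.0, 7.0, 8.0]])  # shape (1, 8)
quant_min, quant_max, items = -128, 127, 4

# Step 1: reshape to (N, C', G) = (1, 2, 4).
groups = x.view(x.shape[0], x.shape[1] // items, items)

# Step 2: per-group minimum and maximum, each of shape (N, C') = (1, 2).
min_val, max_val = torch.aminmax(groups, dim=2)

# Step 3: symmetric scale from the largest magnitude; zero point fixed at 0.
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
max_abs = torch.max(-min_val_neg, max_val_pos)
eps = torch.finfo(torch.float32).eps
scales = torch.clamp(max_abs / ((quant_max - quant_min) / 2), min=eps)
zero_points = torch.zeros_like(scales, dtype=torch.int64)

# Step 4: divide, round, shift, clamp, cast, and flatten back to (N, C).
x_zp = torch.round(groups / scales.unsqueeze(-1)) + zero_points.unsqueeze(-1)
quant = torch.clamp(x_zp, quant_min, quant_max).to(torch.int8).view(x.shape[0], -1)

print(groups.shape, scales.shape, quant.shape)
print(quant)
```

Negative inputs simply become negative integers with the same scale, which is the point of keeping the zero point at 0.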

In summary, the quantization formula can be written as:

$$
\text{quant} = \text{clamp}\left(\text{round}\left(\frac{x}{\text{scales}}\right) + \text{zero\_points}, \text{quant\_min}, \text{quant\_max}\right) \tag{5}
$$
where:

  • $\text{scales} = \frac{\text{max\_val\_pos}}{\frac{\text{quant\_max} - \text{quant\_min}}{2}}$
  • $\text{zero\_points} = 0$
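Equation (5) only covers the forward direction. Dequantization is not part of the function discussed here, so the following is a sketch under the assumption that the usual inverse, (quant - zero_points) * scales, is applied per group. The tensors are copied from the first row of the example earlier in the article, with the scales kept in float32 for simplicity.

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]])
quant = torch.tensor([[32, 64, 96, 127, 80, 96, 112, 127]], dtype=torch.int8)
scales = torch.tensor([[4.0 / 127.5, 8.0 / 127.5]])
zero_points = torch.zeros(1, 2, dtype=torch.int64)
groupsize = 4

# Assumed inverse of Eq. (5): undo the zero-point shift, then rescale per group.
q = quant.view(quant.shape[0], -1, groupsize).to(torch.float32)
x_hat = ((q - zero_points.unsqueeze(-1)) * scales.unsqueeze(-1)).view(quant.shape[0], -1)

# Rounding moves each value by at most half a step, so the error of every
# element should stay within about scales / 2 for its group.
err = (x - x_hat).abs().view(1, -1, groupsize)
print(x_hat)
print(err.max(dim=2).values, scales / 2)
```

In this example the worst error in each group comes from the clamped maximum, which sits half a quantization step away from its original value.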

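This code quantizes each channel (row) of a weight matrix, optionally in groups of groupsize elements. When enable_non_multiple_groups is True, it also handles a row length that is not a multiple of the group size by zero-padding the last group and slicing the padded columns off the quantized result.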
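The padding path is easy to check in isolation. The sketch below assumes a single row of 6 values and a group size of 4 (both made up for illustration) and reproduces only the F.pad and view steps, not the full function.

```python
import torch
import torch.nn.functional as F

x = torch.arange(1.0, 7.0).view(1, 6)  # one row holding 1..6
groupsize = 4

# 6 is not a multiple of 4, so pad the row on the right with zeros up to 8.
padding = groupsize - (x.shape[1] % groupsize)  # 2
x_padded = F.pad(x, (0, padding))               # shape (1, 8)

# The padded row now splits cleanly into 2 groups of 4.
groups = x_padded.view(x_padded.shape[0], -1, groupsize)

# Zeros cannot raise a group's largest magnitude, so the scale of the
# last group is still determined by its real values.
max_abs = groups.abs().amax(dim=2)
print(groups)
print(max_abs)
```

The function returns one scale per group, including the partial one, while quant[:, :x_shape_1] restores the original row length of the quantized tensor.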

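Copyright notice: original article from this site, published by admin on 2024-08-03.
Reprint notice: unless otherwise stated, articles on this site are released under the CC-4.0 license; to reprint, please contact tensortimes@gmail.com.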