joennlae / halutmatmul

Hashed Lookup Table based Matrix Multiplication (halutmatmul) - Stella Nera accelerator

Train and test on model level

CanYing0913 opened this issue

Hi, thanks for your work! I want to know the standard way to train halutmatmul at the model level, as I am trying various model architectures with different datasets to examine the accuracy loss from applying PQ in inference. That is, train it with the same training dataset used to train the model, then run inference to collect accuracy. Currently, example.py only shows training and inference at the matrix-multiplication level. If you can point me to the right file(s) to look at, that would be great, thanks!

Hey :-)

Thank you for reaching out.

For example, the ResNet9 model can be found here:

class ResNet9(nn.Module):

So, for the default pre-training, you could run the standard training script:

python training/train.py --device cuda:0 --opt adam --model resnet9 --cifar10 --lr 0.001 --lr-scheduler cosineannealinglr --epochs 200 --amp --output-dir /path/to/model-checkpoint-output-base.pth

You can adapt the hyperparameters around here (for retraining + fine-tuning):

LR = 0.001 # 0.001/0.002 layer-per-layer, 0.0005 fine-tuning

Then, for the layer-per-layer retraining run:

python retraining.py 0 -single -testname resnet9-lpl-0.001 -checkpoint /path/to/model-checkpoint-output-base.pth

This will generate checkpoints for each layer:

args.checkpoint = f"{args_checkpoint.output_dir}/retrained_checkpoint_{i}_trained.pth" # type: ignore

Then, to run fine-tuning, update TRAIN_EPOCHS to 300 here:

TRAIN_EPOCHS = 25 # 25 layer-per-layer, 300 fine-tuning

and run with the checkpoint of the last layer:

python retraining.py 0 -single -testname resnet9-lpl-int8-lut-fine-tune -checkpoint /path/to/output-retraining/resnet9-lpl-0.001/checkpoints/retrained_checkpoint_7_trained.pth

I hope this helps :-) I know it could be cleaned up.

Thank you for your thorough answer! I took a look at HalutConv2d here:

class HalutConv2d(_ConvNd):
It seems you replaced every Conv2d and Linear with its Halut version while keeping good accuracy. I can't wait to see the accuracy of larger models using your conv2d and linear :-)
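(For illustration, that kind of swap usually follows the pattern below; the import path and the assumption that HalutConv2d/HalutLinear mirror the nn.Conv2d/nn.Linear constructors are mine, not the repo's exact transform code.)

import torch.nn as nn
from halutmatmul.modules import HalutConv2d, HalutLinear  # import path is an assumption

def swap_to_halut(model: nn.Module) -> None:
    # Replace every Conv2d/Linear child with its Halut counterpart, in place.
    for name, child in model.named_children():
        if isinstance(child, nn.Conv2d):
            new = HalutConv2d(
                child.in_channels, child.out_channels, child.kernel_size,
                stride=child.stride, padding=child.padding,
                bias=child.bias is not None,
            )
            new.weight = child.weight          # reuse the pretrained weights
            if child.bias is not None:
                new.bias = child.bias
            setattr(model, name, new)
        elif isinstance(child, nn.Linear):
            new = HalutLinear(child.in_features, child.out_features,
                              bias=child.bias is not None)
            new.weight, new.bias = child.weight, child.bias
            setattr(model, name, new)
        else:
            swap_to_halut(child)               # recurse into submodules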
In your halut_matmul_forward function:
def halut_matmul_forward(
the prototypes are referenced as self.P; however, I cannot find any place that assigns them, except:
self.P = Parameter(
state_dict[prefix + "P"]
.clone()
.to(str(self.weight.device))
.to(self.weight.dtype),
requires_grad=True,
)
but that is simply loading tensors from the state dict. How (and where) are the prototypes and the LUT trained?

I see how this can be confusing. The prototype is only used when we run it like in the LUT-NN paper (https://arxiv.org/abs/2302.03213). Back when this was written there was no public implementation of that :-) Now there is. I would highly recommend also checking out their implementation here: https://github.com/lutnn/blink-mm

So P is basically set to null when running plain halut_matmul. The initial LUT is trained according to the MADDNESS paper (https://arxiv.org/abs/2106.10860) and its reference implementation: https://github.com/dblalock/bolt
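To make the offline LUT construction concrete, here is a self-contained product-quantization sketch in NumPy (classic k-means PQ rather than MADDNESS's hash trees; all names and shapes are illustrative, not the repo's API):

import numpy as np
from sklearn.cluster import KMeans

# A: N x D inputs, B: D x M weights, C codebooks with K prototypes each.
rng = np.random.default_rng(0)
N, D, M, C, K = 4096, 64, 32, 8, 16
A, B = rng.standard_normal((N, D)), rng.standard_normal((D, M))
d = D // C                               # width of each subspace

P = np.zeros((C, K, d))                  # prototypes (the role of self.P)
lut = np.zeros((C, K, M))                # lut[c, k] = P[c, k] @ B_c
for c in range(C):
    sub = A[:, c * d:(c + 1) * d]
    km = KMeans(n_clusters=K, n_init=4, random_state=0).fit(sub)
    P[c] = km.cluster_centers_
    lut[c] = P[c] @ B[c * d:(c + 1) * d]

# Online phase: encode each row (nearest prototype per codebook), sum LUT rows.
out = np.zeros((N, M))
for c in range(C):
    sub = A[:, c * d:(c + 1) * d]
    dists = ((sub[:, None, :] - P[c][None]) ** 2).sum(-1)   # N x K
    out += lut[c][dists.argmin(1)]

print(np.abs(out - A @ B).mean())        # approximation error vs. exact matmul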

I heavily refactored their code and it is here: https://github.com/joennlae/halutmatmul/tree/master/src/python/halutmatmul

The learning algorithm is defined here:

def learn_halut(
l: str,
C: int,
data_path: str,
store_path: str,
K: int = 16,
loop_order: Literal["im2col", "kn2col"] = "im2col",
kernel_size: tuple[int, int] = (1, 1), # only needed for kn2col
stride: tuple[int, int] = (1, 1), # only needed for kn2col
padding: tuple[int, int] = (0, 0), # only needed for kn2col
niter=2,
nredo=1,
min_points_per_centroid=100,
max_points_per_centroid=1000,
codebook: int = -1,
) -> None:

and it is ultimately during retraining called from here:

def run_halut_offline_training(self, codebook: int = -1) -> None:
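For a single layer, the offline call then looks roughly like this (the import path and the placeholder paths are assumptions; learn_halut is the function shown above):

# Import path is an assumption -- locate learn_halut in src/python/halutmatmul.
from halutmatmul.learn import learn_halut

# Placeholder paths: data_path must already contain the recorded inputs
# for the layer named "layer1.conv1" (a hypothetical layer name).
learn_halut(
    l="layer1.conv1",
    C=32,                                  # number of codebooks
    data_path="/path/to/recorded-activations",
    store_path="/path/to/lut-store",
    K=16,                                  # prototypes per codebook
    loop_order="im2col",
)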

Thanks again! I noticed that the implementations for plain matrices and for convolution models are separate (i.e., halutmatmul and HalutConv2d/HalutLinear have entirely different training procedures). Where are the prototypes and the LUT trained for HalutConv2d? It seems the update_lut function only performs a rounding operation on the LUT itself, and I could not find anywhere that self.P and self.lut are trained.
I am also quite confused: is HalutConv2d trained along with the original model? If so, where are the functions that learn the prototypes and construct the LUT at the end of the iterations? My current approach is to train my PQConv2d offline (i.e., after the entire model is trained, which basically means iterating through the dataset batches again), but the two approaches should be similar.
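(For that offline approach, recording a layer's inputs with a forward hook over a few calibration batches is the usual pattern; this is a generic sketch, not repo code:)

import torch
import torch.nn as nn

# Record the inputs of one layer over a few batches, to be used later
# as the training sample for offline prototype/LUT learning.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
model.eval()

captured = []
hook = model[2].register_forward_hook(
    lambda mod, inp, out: captured.append(inp[0].detach().cpu())
)
with torch.no_grad():
    for _ in range(4):                    # a few calibration batches
        model(torch.randn(8, 3, 32, 32))
hook.remove()

acts = torch.cat(captured)                # offline sample for PQ/LUT learning
print(acts.shape)                         # torch.Size([32, 8, 30, 30])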

I'm not entirely certain where the misunderstanding lies, but I'll do my best to clarify :-)

The lut and P are initialized using the methods described in the previous response. During training, the updates are propagated back in FP16, but in the forward pass we use the update_lut function to quantize the lut and then pass the activations through it. This is a fairly typical quantization-aware-training setup.
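That forward-quantize / full-precision-backward pattern can be sketched as a generic straight-through estimator (illustrative only, not the repo's exact update_lut):

import torch

def quantize_lut_ste(lut: torch.Tensor, bits: int = 8) -> torch.Tensor:
    # Forward: round the LUT to int8-style levels; backward: gradients
    # pass through unchanged (straight-through estimator).
    qmax = 2 ** (bits - 1) - 1
    scale = lut.detach().abs().max().clamp_min(1e-8) / qmax
    lut_q = (lut / scale).round().clamp(-qmax - 1, qmax) * scale
    return lut + (lut_q - lut).detach()

lut = torch.randn(32, 16, 64, requires_grad=True)   # C x K x M table
quantize_lut_ste(lut).sum().backward()
print(lut.grad.mean())                   # gradients reach the FP table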