THUDM / SwissArmyTransformer

SwissArmyTransformer is a flexible and powerful library to develop your own Transformer variants.

Home Page: https://THUDM.github.io/SwissArmyTransformer

Error when testing the qlora.py provided in the source

shituo123456 opened this issue

Running the qlora.py from the source directly throws an error:
[error screenshot]
After changing model.child = LoraLinear(100, 200, 10) to model.child = LoraLinear(100, 200, 10, 10, 2), another error is raised:
[error screenshot]

That is the old version of the __main__ function; you'll need to update it yourself.

This is the execution code in qlora.py, how should I change it? I've been doing CV all along and am only just getting into multimodal large models.

import torch
import torch.nn as nn
# LoraLinear is defined earlier in qlora.py

if __name__ == '__main__':
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.child = nn.Linear(100, 200)

        def forward(self, x):
            return self.child(x)

    model = Model()
    # Baseline: save the plain Linear weights and run one forward pass.
    torch.save(model.state_dict(), "linear.pt")
    x = torch.randn(2, 100)
    out1 = model(x)
    # Swap the Linear for a LoraLinear, then reload the original weights.
    # This call matches an older constructor and fails with the current one.
    model.child = LoraLinear(100, 200, 10)
    model.load_state_dict(torch.load("linear.pt"), strict=False)
    out2 = model(x)
    # Round-trip the LoRA-wrapped state dict.
    torch.save(model.state_dict(), "lora.pt")
    ckpt = torch.load("lora.pt")
    breakpoint()
    model.load_state_dict(ckpt, strict=False)
    out3 = model(x)
    breakpoint()

I've forgotten too, it's been too long. Read the source yourself; it isn't long.

OK, I'll give it a try first.

After this change it still complains that quant_state cannot be None. How should this quant_state be added?

import torch
import torch.nn as nn
# LoraLinear is defined earlier in qlora.py

if __name__ == '__main__':
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.child = nn.Linear(100, 200)

        def forward(self, x):
            return self.child(x)

    model = Model()
    torch.save(model.state_dict(), "linear.pt")
    x = torch.randn(2, 100)
    out1 = model(x)
    # Updated call for the current constructor, with quantization enabled.
    model.child = LoraLinear(nn.Linear, 5, 100, 200, 10, qlora=True)
    model.load_state_dict(torch.load("linear.pt"), strict=False)
    out2 = model(x)
    torch.save(model.state_dict(), "lora.pt")
    ckpt = torch.load("lora.pt")
    breakpoint()
    model.load_state_dict(ckpt, strict=False)
    out3 = model(x)
    breakpoint()

[error screenshot]

You only get a quant_state when running on the GPU. In other words, you need model = model.cuda() and x = x.cuda().

Also note that model.cuda() may only be called once, or it will error (this is how bitsandbytes implements it and is out of my control; they override the .cuda() method).
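A minimal sketch of how that advice slots into the snippet above; the LoraLinear constructor arguments mirror the user's working call earlier in this thread and are an assumption, not a documented API:

import torch
import torch.nn as nn
# LoraLinear as defined in qlora.py; arguments copied from the working
# call in this thread (an assumption, not a documented API reference).

if __name__ == '__main__':
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.child = nn.Linear(100, 200)

        def forward(self, x):
            return self.child(x)

    model = Model()
    torch.save(model.state_dict(), "linear.pt")

    model.child = LoraLinear(nn.Linear, 5, 100, 200, 10, qlora=True)
    model.load_state_dict(torch.load("linear.pt"), strict=False)

    # Move everything to the GPU exactly once, after the LoraLinear swap:
    # bitsandbytes overrides .cuda() to quantize the weights (this is when
    # quant_state is created), so calling .cuda() a second time, or on the
    # LoraLinear before its weights are loaded, breaks.
    model = model.cuda()
    x = torch.randn(2, 100).cuda()
    out = model(x)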

Indeed .cuda() can only be called once; calling .cuda() on the LoraLinear ahead of time raises a dimension error.
It's all debugged now, thanks a lot for the patient replies.