基于imdb的LoRA微调分类模型 Xlnet[简单例子,小白可上手]

2024-03-18 1072阅读

温馨提示:这篇文章已超过370天没有更新,请注意相关的内容是否还可用!

LORA微调简单例子

  • 什么是LORA微调?
    • 为什么我会接触LoRA?
      • 1. 安装库
      • 2.开始LoRA微调
        • Train with LoRA
        • inference
        • 说明下target_modules这个参数(可选)
        • 谈一下LoRA的当下的应用

          什么是LORA微调?

          LoRA 全称为 low-rank adaptation, 低秩自适应微调方法。

          paper:LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS

          大概思路:借鉴知乎 https://www.zhihu.com/tardis/zm/art/623543497?source_id=1005 的这个图说明一下:

          • LoRA微调什么? 预训练模型是 左边, Lora微调得到的模型是 右边。

            LoRA相当于训练一个新的额外的参数,来学习原参数的知识。

          • LoRA优点:相比于全量微调(full finetune)方法的优点是 LoRA通常只需要更新不到1%的参数就能达到与全量微调相当的效果。

            full finetune:更新所有的预训练网络参数。

            基于imdb的LoRA微调分类模型 Xlnet[简单例子,小白可上手]

            例如在下面的例子中,peftmodel 显示lora微调训练的参数量是原来的0.62倍,十分小。

            model.print_trainable_parameters()
            

            基于imdb的LoRA微调分类模型 Xlnet[简单例子,小白可上手]

            为什么我会接触LoRA?

            最近在2张40G的GPU上,想要实现基于imdb数据集微调Xlnet-large模型,发现直接load pre-trained model 全量微调压根跑不起来,内存直接爆炸。

            • 内存/存储空间不够

              因此,想通过一些高效低参的微调的方法来实现,如LoRA, bitfit(只微调网络中的bias 参数,比lora效果差得多)。

              接下来直接上代码介绍怎么基于LoRA微调xlnet-base-cased吧。很简单的过程,但是你要能连接Huggingface,不能连接的话就把数据和模型下到本地吧。

              依赖库提供方手册
              transformerHuggingfacehttps://huggingface.co/docs/transformers/pipeline_tutorial
              datasetsHuggingfacehttps://huggingface.co/docs/datasets/load_hub
              evaluateHuggingfacehttps://huggingface.co/docs/evaluate/transformers_integrations
              peftHuggingfacehttps://huggingface.co/docs/peft/quicktour

              建议在conda 虚拟环境里跑哈 使用的环境如下:

              Framework versions
              PEFT 0.9.0
              Transformers 4.38.2
              Pytorch 2.2.1
              Datasets 2.18.0
              Tokenizers 0.15.2
              python 3.10
              

              1. 安装库

              执行以下代码安装相关的库函数,需要翻墙,能够连接上Huggingface :

              pip install transformers datasets evaluate peft
              

              2.开始LoRA微调

              下载xlnet-base-cased模型-- 用peft封装模型- - 训练参数 — 训练模型 – 上传模型到hub

              Train with LoRA
              def train_lora():
              	# 如果需要将模型推送到hub上的话,需要登录一下hub
                  login(token='hf_XXX')  ##hf_XXX 替换成你自己的hub账号的token
                  peft_config = LoraConfig(task_type=TaskType.SEQ_CLS,
                                           target_modules=['layer_1','layer_2'],
                                            inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
                 #下载xlnet-base-cased模型
                  model_base = 'xlnet-base-cased'
                  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                  model  = AutoModelForSequenceClassification.from_pretrained(model_base, num_labels=2)
                  print(model)
              	#用peft封装模型
                  model = get_peft_model(model,peft_config)
                  model.print_trainable_parameters()
                  model = model.to(device)
                
                  tokenizer = AutoTokenizer.from_pretrained(model_base,max_length=350)
                  def preocess_tokenize(exam):
                      return tokenizer(exam['text'],truncation=True,padding=True)
                  #下载数据集
                  imdb = load_dataset('imdb')
                  tokenized_imdb = imdb.map(preocess_tokenize,batched=True)
                  #设置评估指标
                  accuracy = evaluate.load("accuracy")
                  def compute_metrics(eval_pred):
                      predictions, labels = eval_pred
                      predictions = np.argmax(predictions, axis=1)
                      return accuracy.compute(predictions=predictions, references=labels)
                 #设置训练参数
                  bts = 1
                  accumulated_step = 2
                  training_args = TrainingArguments(
                      output_dir=f"imdb_{model_base.replace('-','_')}",
                      learning_rate=2e-5,
                      per_device_train_batch_size=bts,
                      per_device_eval_batch_size=bts,
                      num_train_epochs=1,
                      weight_decay=0.01,
                      evaluation_strategy="epoch",
                      save_strategy="epoch",
                      load_best_model_at_end=True,
                      push_to_hub=True,
                      gradient_accumulation_steps=accumulated_step,
                  )
                 #配置训练器
                  trainer = Trainer(
                  model=model,
                  args=training_args,
                  train_dataset=tokenized_imdb["train"],
                  eval_dataset=tokenized_imdb["test"],
                  tokenizer=tokenizer,
                  compute_metrics=compute_metrics,
                  )
                  #训练并上传
                  print(f'Start to Train {model_base}')
                  trainer.train()
                  print('Success!')
                  trainer.push_to_hub()
                  print('Success to push the model to huggingface')
              if __name__ =='__main__':
                  train_lora()
              
              inference

              以下是我训练号一个peftmodel = Siki-77/imdb_xlnet_base_cased,使用它去判断其他样本的过程如下

              from peft import PeftModel, PeftConfig
              from transformers import AutoModelForSequenceClassification
              config = PeftConfig.from_pretrained("Siki-77/imdb_xlnet_base_cased")
              model = AutoModelForSequenceClassification.from_pretrained("xlnet-base-cased")
              model = PeftModel.from_pretrained(model, "Siki-77/imdb_xlnet_base_cased")
              tokenizer = AutoTokenizer.from_pretrained(model_base,max_length=350)
              inputs = tokenizer("Preheat the oven to 350 degrees and place the cookie dough", return_tensors="pt")
              with torch.no_grad():
                  logits = model(**inputs).logits
              label = logits.argmax().item()  # 这里是2分类,因此 label=0 或1 
              

              以上例子结束,代码完全跑的通,很简单的代码流程。

              说明下target_modules这个参数(可选)

              这一小节回答以下问题:

              • 任意的模型进行LoRA微调的话,需要怎么设置target modules 这个参数呢??
              • 省流回答:target module 只能选择以下几类层的名称

                -torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D

                本节主要补充一下 peft_config两个关键参数说明:

                  peft_config = LoraConfig(
                      task_type=TaskType.SEQ_CLS,
                      target_modules=['layer_1','layer_2'],
                      inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
                
                • task_type表示要微调的模型类型,TaskType.SEQ_CLS表示文本分类模型,
                • target_modules 表示用更少参数代替的层(只允许 linear
                • 其他参数默认

                  基于imdb的LoRA微调分类模型 Xlnet[简单例子,小白可上手]

                  如,我用的pre-trained model = xlnet-base-cased,它的模型结构如上图所示,一般lora压缩的是堆叠好几层的模块的 linear/conv层,这里只有名为“layer-1”“layer-2”的线性层,因此

                  target_modules=['layer_1','layer_2'],
                  

                  如果我设置了unsupport layer的名称为 target modules,将会报错,如下

                  peft_config = LoraConfig(
                                 r=16,
                                 lora_alpha=32,
                                 lora_dropout=0.05,
                                 target_modules=["layer_norm"],## "layer_norm" is non-linear layer
                                 task_type=TaskType.SEQ_CLS, # this is necessary
                                 )
                  

                  报错信息如下

                  Currently, only the following modules are supported: torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D

                  谈一下LoRA的当下的应用

                  随着GPT-3.5以来,越来越多大模型发布,甚至开源。如果要在本地微调这些大模型的话,基本全量微调是希望渺茫的(除非你的GPU内存真的超足),这种情况下,LoRA简直是天选之子,微调的首选方式。

                  真的十分有效,当然LoRA也不一定能解决存储不足的问题。

                  还有一个流行的“微调” promt,不微调参数,而是直接引导模型为我所用,很玄也很难定义怎么prompt才能“为我所用”。

                  相比于prompt微调,个人感觉LoRA还是可控性比较高的微调。

VPS购买请点击我

免责声明:我们致力于保护作者版权,注重分享,被刊用文章因无法核实真实出处,未能及时与作者取得联系,或有版权异议的,请联系管理员,我们会立即处理! 部分文章是来自自研大数据AI进行生成,内容摘自(百度百科,百度知道,头条百科,中国民法典,刑法,牛津词典,新华词典,汉语词典,国家院校,科普平台)等数据,内容仅供学习参考,不准确地方联系删除处理! 图片声明:本站部分配图来自人工智能系统AI生成,觅知网授权图片,PxHere摄影无版权图库和百度,360,搜狗等多加搜索引擎自动关键词搜索配图,如有侵权的图片,请第一时间联系我们,邮箱:ciyunidc@ciyunshuju.com。本站只作为美观性配图使用,无任何非法侵犯第三方意图,一切解释权归图片著作权方,本站不承担任何责任。如有恶意碰瓷者,必当奉陪到底严惩不贷!

目录[+]