NLP - 基于bert预训练模型的文本多分类示例

项目说明

项目名称

基于DistilBERT的标题多分类任务

项目概述

本项目旨在使用DistilBERT模型对给定的标题文本进行多分类任务。项目包括从数据处理、模型训练、模型评估到最终的API部署。该项目采用模块化设计,以便于理解和维护。

项目结构

.
├── bert_data
│   ├── train.txt
│   ├── dev.txt
│   └── test.txt
├── saved_model
├── results
├── logs
├── data_processing.py
├── dataset.py
├── training.py
├── app.py
└── main.py

文件说明

  1. bert_data/:存放训练集、验证集和测试集的数据文件。

    • train.txt
    • dev.txt
    • test.txt
  2. saved_model/:存放训练好的模型和tokenizer。

  3. results/:存放训练结果。

  4. logs/:存放训练日志。

  5. data_processing.py:数据处理模块,负责读取和预处理数据。

  6. dataset.py:数据集类模块,定义了用于训练和评估的数据集类。

  7. training.py:模型训练模块,定义了训练和评估模型的过程。

  8. app.py:模型部署模块,使用FastAPI创建API服务。

  9. main.py:主脚本,运行整个流程,包括数据处理、模型训练和部署。

数据集数据规范

为了确保数据处理和模型训练的顺利进行,请按照以下规范准备数据集文件。每个文件包含的标题和标签分别使用制表符(\t)分隔。以下是一个示例数据集的格式。

数据文件格式

数据文件应为纯文本文件,扩展名为.txt,文件内容的每一行应包含一个文本标题和一个对应的分类标签,用制表符分隔。数据文件不应包含表头。

数据示例
探索神秘的海底世界    7
如何在家中制作美味披萨    2
全球气候变化的原因和影响    1
最新的智能手机评测    8
健康饮食:如何搭配均衡的膳食    5
最受欢迎的电影和电视剧推荐    3
了解宇宙的奥秘:天文学入门    0
如何种植和照顾多肉植物    9
时尚潮流:今年夏天的必备单品    6
如何有效管理个人财务    4

注意事项

  • 标签规范:确保每个标题文本的标签是一个整数,表示类别。
  • 文本编码:确保数据文件使用UTF-8编码,避免中文字符乱码。
  • 数据一致性:确保训练、验证和测试数据格式一致,便于数据加载和处理。

通过以上规范和示例数据文件创建方法,可以确保数据文件符合项目需求,并顺利进行数据处理和模型训练。

模块说明

1. 数据处理模块 (data_processing.py)

功能:读取数据文件并进行预处理。

  • load_data(file_path): 读取指定路径的数据文件,并返回一个包含文本和标签的数据框。
  • tokenize_data(data, tokenizer, max_length=128): 使用BERT的tokenizer对数据进行tokenize处理。
  • main(): 加载数据、tokenize数据并返回处理后的数据。
2. 数据集类模块 (dataset.py)

功能:定义数据集类,便于模型训练。

  • TextDataset: 将tokenized数据和标签封装成PyTorch的数据集格式,便于Trainer进行训练和评估。
3. 模型训练模块 (training.py)

功能:定义训练和评估模型的过程。

  • train_model(): 加载数据和tokenizer,创建数据集,加载模型,设置训练参数,定义Trainer,训练和评估模型,保存训练好的模型和tokenizer。
4. 模型部署模块 (app.py)

功能:使用FastAPI进行模型部署。

  • predict(item: Item): 接收POST请求的文本输入,使用训练好的模型进行预测并返回分类结果。
  • FastAPI应用启动配置。
5. 主脚本 (main.py)

功能:运行整个流程,包括数据处理、模型训练和部署。

  • main(): 运行模型训练流程,并输出训练完成的提示。

运行步骤

  1. 安装依赖
pip install pandas torch transformers fastapi uvicorn scikit-learn
  1. 数据处理

确保bert_data文件夹下包含train.txtdev.txttest.txt文件,每个文件包含文本和标签,使用制表符分隔。

  1. 训练模型

运行main.py脚本,进行数据处理和模型训练:

python main.py

训练完成后,模型和tokenizer将保存在saved_model文件夹中。

  1. 部署模型

运行app.py脚本,启动API服务:

uvicorn app:app --reload

服务启动后,可以通过POST请求访问预测接口,进行文本分类预测。

示例请求

curl -X POST "http://localhost:8000/predict" -H "Content-Type: application/json" -d '{"text": "你的文本"}'

返回示例:

{
    "prediction": 3
}

注意事项

  • 确保数据文件格式正确,每行包含一个文本和对应的标签,使用制表符分隔。
  • 调整训练参数(如batch size和训练轮数)以适应不同的GPU配置。
  • 使用nvidia-smi监控显存使用,避免显存溢出。

项目代码

1. 数据处理模块

功能:读取数据文件并进行预处理。

# data_processing.py
import pandas as pd
from transformers import DistilBertTokenizer

def load_data(file_path):
    data = pd.read_csv(file_path, delimiter='\t', header=None)
    data.columns = ['text', 'label']
    return data

def tokenize_data(data, tokenizer, max_length=128):
    encodings = tokenizer(list(data['text']), truncation=True, padding=True, max_length=max_length)
    return encodings

def main():
    # 加载Tokenizer
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-chinese')
    
    # 加载数据
    train_data = load_data('./bert_data/train.txt')
    dev_data = load_data('./bert_data/dev.txt')
    test_data = load_data('./bert_data/test.txt')
    
    # Tokenize数据
    train_encodings = tokenize_data(train_data, tokenizer)
    dev_encodings = tokenize_data(dev_data, tokenizer)
    test_encodings = tokenize_data(test_data, tokenizer)
    
    return train_encodings, dev_encodings, test_encodings, train_data['label'], dev_data['label'], test_data['label']

if __name__ == "__main__":
    main()

2. 数据集类模块

功能:定义数据集类,便于模型训练。

# dataset.py
import torch

class TextDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

3. 模型训练模块

功能:定义训练和评估模型的过程。

# training.py
import torch
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
from dataset import TextDataset
import data_processing

def train_model():
    # 加载数据和tokenizer
    train_encodings, dev_encodings, test_encodings, train_labels, dev_labels, test_labels = data_processing.main()

    # 创建数据集
    train_dataset = TextDataset(train_encodings, train_labels)
    dev_dataset = TextDataset(dev_encodings, dev_labels)
    test_dataset = TextDataset(test_encodings, test_labels)

    # 加载DistilBERT模型
    model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-chinese', num_labels=10)
    model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    # 设置训练参数
    training_args = TrainingArguments(
        output_dir='./results',          # 输出结果目录
        num_train_epochs=3,              # 训练轮数
        per_device_train_batch_size=16,  # 训练时每个设备的批量大小
        per_device_eval_batch_size=64,   # 验证时每个设备的批量大小
        warmup_steps=500,                # 训练步数
        weight_decay=0.01,               # 权重衰减
        logging_dir='./logs',            # 日志目录
        fp16=True,                       # 启用混合精度训练
    )

    # 定义Trainer
    trainer = Trainer(
        model=model,                         # 预训练模型
        args=training_args,                  # 训练参数
        train_dataset=train_dataset,         # 训练数据集
        eval_dataset=dev_dataset             # 验证数据集
    )

    # 训练模型
    trainer.train()

    # 评估模型
    eval_results = trainer.evaluate()
    print(eval_results)

    # 保存模型
    trainer.save_model('./saved_model')
    tokenizer = trainer.tokenizer
    tokenizer.save_pretrained('./saved_model')

if __name__ == "__main__":
    train_model()

4. 模型部署模块

功能:使用FastAPI进行模型部署。

# app.py
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch

app = FastAPI()

# 加载模型和tokenizer
model = DistilBertForSequenceClassification.from_pretrained('./saved_model')
tokenizer = DistilBertTokenizer.from_pretrained('./saved_model')

class Item(BaseModel):
    text: str

@app.post("/predict")
def predict(item: Item):
    inputs = tokenizer(item.text, return_tensors="pt", max_length=128, padding='max_length', truncation=True)
    outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=1)
    return {"prediction": prediction.item()}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

5. 主脚本

功能:运行整个流程,包括数据处理、模型训练和部署。

# main.py
import training

def main():
    # 训练模型
    training.train_model()
    print("模型训练完成并保存。")

if __name__ == "__main__":
    main()

其他:客户端调用案例

# client.py
import requests

def predict(text):
    url = "http://localhost:8000/predict"
    payload = {"text": text}
    headers = {"Content-Type": "application/json"}

    response = requests.post(url, json=payload, headers=headers)
    
    if response.status_code == 200:
        prediction = response.json()
        return prediction
    else:
        print(f"Error: {response.status_code}")
        print(response.text)
        return None

if __name__ == "__main__":
    text_to_predict = "探索神秘的海底世界"
    prediction = predict(text_to_predict)
    if prediction:
        print(f"Prediction: {prediction['prediction']}")

详细说明

  1. 数据处理模块

    • 读取训练集、验证集和测试集的数据文件。
    • 使用BERT的Tokenizer对数据进行tokenize处理,生成模型可接受的输入格式。
    • 提供主要的数据处理函数,包括加载数据和tokenize数据。
  2. 数据集类模块

    • 定义一个TextDataset类,用于将tokenized数据和标签封装成PyTorch的数据集格式,便于Trainer进行训练和评估。
  3. 模型训练模块

    • 使用数据处理模块加载和tokenize数据。
    • 创建训练和验证数据集。
    • 加载DistilBERT模型,并设置训练参数(包括启用混合精度训练)。
    • 使用Trainer进行模型训练和评估,并保存训练好的模型。
  4. 模型部署模块

    • 使用FastAPI创建一个简单的API服务。
    • 加载保存的模型和tokenizer。
    • 定义一个预测接口,通过POST请求接收文本输入并返回分类预测结果。
  5. 主脚本

    • 运行模型训练流程,并输出训练完成的提示。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/780782.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

盘点8款国内顶尖局域网监控软件(2024年国产局域网监控软件排名)

局域网监控软件对于企业网络管理至关重要,它们可以帮助IT部门维护网络安全,优化网络性能,同时监控和控制内部员工的网络使用行为。以下是八款备受推崇的局域网监控软件,每一款都有其独特的优势和适用场景。 1.安企神软件 试用版领…

【数据结构】(C语言):二叉搜索树(不使用递归)

二叉搜索树: 非线性的,树是层级结构。基本单位是节点,每个节点最多2个子节点。有序。每个节点,其左子节点都比它小,其右子节点都比它大。每个子树都是一个二叉搜索树。每个节点及其所有子节点形成子树。可以是空树。 …

Report Design Analysis报告之logic level详解

目录 一、前言 二、Logic Level distribution 2.1 logic level配置 2.2 Logic Level Distribution报告 2.3 Logic Level 报告详情查看 2.4 Route Distributions 报告详情查看 2.5 示例代码 一、前言 ​在工程设计中,如果需要了解路径的逻辑级数,可…

赚钱小思路,送给没有背景的辛辛苦苦努力的我们!

我是一个没有背景的普通人,主要靠勤奋和一股钻劲,这十几年来我的日常作息铁打不变,除了睡觉,不是在搞钱,就是在琢磨怎么搞钱。 ​ 可以说打拼了十几年,各种小生意都做过,以前一直是很乐观的&…

绘唐3最新版本哪里下载

绘唐3最新版本哪里下载 绘唐最新版本下载地址 推文视频创作设计是一种通过视频和文字的形式来进行推广的方式,可以通过一些专业的工具来进行制作。 以下是一些常用的小说推文视频创作设计工具: 视频剪辑软件:如Adobe Premiere Pro、Fina…

人工智能系列-Pandas基础

🌈个人主页:羽晨同学 💫个人格言:“成为自己未来的主人~” Pandas简介 Pandas是Python语言的拓展程序库,用于数据分析。 Pandas是一个开放源码,BSD许可的库,提供高性能,易于使用的数据结…

SSM养老院管理系统-计算机毕业设计源码02221

摘要 本篇论文旨在设计和实现一个基于SSM的养老院管理系统,旨在提供高效、便捷的养老院管理服务。该系统将包括老人档案信息管理、护工人员管理、房间信息管理、费用管理等功能模块,以满足养老院管理者和居民的不同需求。 通过引入SSM框架&#x…

ChatGPT对话:Scratch编程中一个单词,如balloon,每个字母行为一致,如何优化编程

【编者按】balloon 7个字母具有相同的行为,根据ChatGPT提供的方法,优化了代码,方便代码维护与复用。初学者可以使用7个字母精灵,复制代码到不同精灵,也能完成这个功能,但不是优化方法,也没有提高…

微信小程序毕业设计-医院挂号预约系统项目开发实战(附源码+论文)

大家好!我是程序猿老A,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:微信小程序毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 Python毕业设计…

高考志愿填报,选专业是看兴趣还是看就业?

对于结束高考的学生来说,选择专业的确是一个非常让人头疼的事情。因为很多人都不知道,选专业的时候究竟是应该看一下个人兴趣,还是看未来的就业方向,这也是让不少人都相当纠结的问题。这里分析一下关于专业选择的问题,…

动手学深度学习(Pytorch版)代码实践 -循环神经网络-54循环神经网络概述

54循环神经网络概述 1.潜变量自回归模型 使用潜变量h_t总结过去信息 2.循环神经网络概述 ​ 循环神经网络(recurrent neural network,简称RNN)源自于1982年由Saratha Sathasivam 提出的霍普菲尔德网络。循环神经网络,是指在全…

字符串——string类的常用接口

一、string类对象的常见构造 二、string类对象的容量操作 三、string类对象的访问及遍历操作 四、string类对象的修改操作 一、string类对象的常见构造 1.string() ——构造空的string类对象,也就是空字符串 2.string(const char* s) ——用字符串来初始化stri…

【微服务】springboot对接Prometheus指标监控使用详解

目录 一、前言 二、微服务监控概述 2.1 微服务常用监控指标 2.2 微服务常用指标监控工具 2.3 微服务使用Prometheus监控优势 三、环境准备 3.1 部署Prometheus服务 3.2 部署Grafana 服务 3.3 提前搭建springboot工程 3.3.1 引入基础依赖 3.3.2 配置Actuator 端点 3.…

代码随想录算法训练营第14天|226.翻转二叉树、101. 对称二叉树、104. 二叉树的最大深度、111.二叉树的最小深度

打卡Day14 1.226.翻转二叉树2.101. 对称二叉树扩展100. 相同的树572. 另一棵树的子树 3.104. 二叉树的最大深度扩展559.n叉树的最大深度 3.111.二叉树的最小深度 1.226.翻转二叉树 题目链接:226.翻转二叉树 文档讲解: 代码随想录 (1&#x…

博客的部署方法论

有了域名后,可以方便其他人记住并访问,历史上不乏大企业花大价钱购买域名的: 京东域名换成 JD.com,并且说是为了防止百度吸引流量,为什么?唯品会买下域名 VIP.COM 或花费千万 ‍ 域名提供商 如果想要域…

Xilinx FPGA:vivado关于真双端口的串口传输数据的实验

一、实验内容 用一个真双端RAM,端口A和端口B同时向RAM里写入数据0-99,A端口读出单数并存入单端口RAM1中,B端口读出双数并存入但端口RAM2中,当检测到按键1到来时将RAM1中的单数读出显示到PC端,当检测到按键2到来时&…

普通Java工程如何在代码中引用docker-compose.yml中的environment值

文章目录 一、概述二、常规做法1. 数据库配置分离2. 代码引用配置3. 编写启动类4. 支持打包成可执行包5. 支持可执行包打包成docker镜像6. docker运行 三、存在问题分析四、改进措施1. 包含environment 变量的编排文件2. 修改读取配置文件方式3. 为什么可以这样做 五、运行效果…

股票Level-2行情是什么,应该怎么使用,从哪里获取数据

行情接入方法 level2行情websocket接入方法-CSDN博客 相比传统的股票行情,Level-2行情为投资者打开了更广阔的视野,不仅限于买一卖一的表面数据,而是深入到市场的核心,提供了十档乃至千档的行情信息(沪市十档&#…

JavaWeb-【1】HTML

笔记系列持续更新,真正做到详细!!本次系列重点讲解后端,那么第一阶段先讲解前端 目录 1、Javaweb技术体系 2、BS架构说明 3、官方文档 4、网页组成 5、HTML 6、HTML快速入门 7、HTML基本结构 8、HTML标签 ​9、HTML标签使用细节 ①、font标签 ②、字符实体 ③、标…

图神经网络dgl和torch-geometric安装

文章目录 搭建环境dgl的安装torch-geometric安装 在跑论文代码过程中,许多小伙伴们可能会遇到一些和我一样的问题,就是文章所需要的一些库的版本比较老,而新版的环境跑代码会报错,这就需要我们手动的下载whl格式的文件来安装相应的…