Paddle2.0实现中⽂新闻⽂本标题分类
Paddle2.0实现中⽂新闻⽂本标题分类
项⽬说明,本项⽬是李宏毅⽼师在飞桨授权课程的作业解析
课程
该项⽬AiStudio项⽬
数据集
本项⽬仅⽤于参考,提供思路和想法并⾮标准答案!请谨慎抄袭!
中⽂新闻⽂本标题分类Paddle2.0版本基线(⾮官⽅)
⾮官⽅,三岁出品!(虽⽔必精)
调优⼩建议
本项⽬基线的值不会很⾼,需要⾃⾏调参来提⾼效果。
优化建议:
修改模型 现在是线性模型可以尝试修改更为复杂的
对于nlp项⽬更加友好的(具体的我也不是很清楚)
调整学习率来调整我们最好效果的查
可以通过对已有模型进⼀步训练得到较好的效果
……
数据集地址
任务描述
基于THUCNews数据集的⽂本分类, THUCNews是根据新浪新闻RSS订阅频道2005~2011年间的历史数据筛选过滤⽣成,包含74万篇新闻⽂档,参赛者需要根据新闻标题的内容⽤算法来判断该新闻属于哪⼀类别
数据说明
THUCNews是根据新浪新闻RSS订阅频道2005~2011年间的历史数据筛选过滤⽣成,包含74万篇新闻⽂档(2.19 GB),均为UTF-8纯⽂本格式。在原始新浪新闻分类体系的基础上,重新整合划分出14个候选分类类别:财经、、房产、股票、家居、教育、科技、社会、时尚、时政、体育、星座、游戏、娱乐。
已将训练集按照“标签ID+\t+标签+\t+原⽂标题”的格式抽取出来,可以直接根据新闻标题进⾏⽂本分类任务,希望答题者能够给出⾃⼰的解决⽅案。
训练集格式 标签ID+\t+标签+\t+原⽂标题 测试集格式 原⽂标题
提交答案
考试提交,需要提交模型代码项⽬版本和结果⽂件。结果⽂件为TXT⽂件格式,命名为,⽂件内的字段需要按照指定格式写⼊。
1.每个类别的⾏数和测试集原始数据⾏数应⼀⼀对应,不可乱序
2.输出结果应检查是否为83599⾏数据,否则成绩⽆效
3.输出结果⽂件命名为,⼀⾏⼀个类别,样例如下:
···
游戏
财经
时政
股票
家居
科技
社会
房产
教育
星座
科技
股票
游戏
财经
时政
股票
家居
科技
社会python官方文档中文版
房产
教育
·
··
代码思路说明
根据题⽬可以知道这个是⼀个经典的nlp任务。
根据nlp任务处理的⼀般流程,我们需要进⾏以下⼏个步骤:数据处理并转换成词向量
模型的搭建
数据的训练
模型读取并推理数据得到结果
那么话不多说我们开始!
数据集解压
! pip install -U paddlepaddle==2.0.1
! unzip -oq /home/aistudio/data/data75812/新闻⽂本标签分类.zip
import paddle
import numpy as np
import matplotlib.pyplot as plt
as nn
import os
import numpy as np
print(paddle.__version__)# 查看当前版本
# cpu/gpu环境选择,在 paddle.set_device() 输⼊对应运⾏设备。
# device = paddle.set_device('gpu')
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'c ollections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'col lections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'coll ections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import Sized
2021-03-27 12:21:25,020 - INFO - font search path ['/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf', '/o pt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/afm', '/opt/conda/envs/python35-paddle120-env/lib/python3.7 /site-packages/matplotlib/mpl-data/fonts/pdfcorefonts']
2021-03-27 12:21:25,357 - INFO - generated new fontManager
2.0.1
数据处理
⾸先我们考虑词向量的书写⽅式。
我们先制作词典(此处词典已经制作完成,我们直接读取就好了,词典制作过程会放在留⾔中)
我们把词典和我们的数据集进⾏对应,制作完成⼀个纯数字的对应码
得到对应码以后进⾏输出测试是否正确。
数据⽆误进⾏填充,把数据码⽤特殊标签进⾏替代完成数据长度相同的内容
检验数据长度
数据读取(字典、数据集)
# 字典读取
def get_dict_len(d_path):
with open(d_path,'r', encoding='utf-8')as f:
line =adlines()[0])
return line
word_dict = get_dict_len('新闻⽂本标签分类/')
# 训练集和验证集读取
set=[]
def dataset(datapath):# 数据集读取代码
with open(datapath)as f:
for i adlines():
data =[]
dataset = i[:i.rfind('\t')].split(',')# 获取⽂字内容
dataset = np.array(dataset)
data.append(dataset)
label = np.array(i[i.rfind('\t')+1:-1])# 获取标签
data.append(label)
set.append(data)
return set
train_dataset = dataset('新闻⽂本标签分类/')
val_dataset = dataset('新闻⽂本标签分类/')
数据初始化
定义⼀些需要的值
# 初始数据准备
vocab_size =len(word_dict)+1# 字典长度加1
print(vocab_size)
emb_size =256# 神经⽹络长度
seq_len =30# 数据集长度(需要扩充的长度)
batch_size =32# 批处理⼤⼩
epochs =2# 训练轮数
pad_id = word_dict['<unk>']# 空的填充内容值
nu=["财经","","房产","股票","家居","教育","科技","社会","时尚","时政","体育","星座","游戏","娱乐"] # ⽣成句⼦列表(数据码⽣成⽂本)
def ids_to_str(ids):
# print(ids)
words =[]
for k in ids:
w =list(word_dict)[eval(k)]
words.append(w if isinstance(w,str)else w.decode('ASCII'))
return" ".join(words)
5308
数据查看
查看数据是否正确如有异常及时修改
# 查看数据内容
for i in  train_dataset:
sent = i[0]
label =int(i[1])
print('sentence list id is:', sent)# 数据内容
print('sentence label id is:', label)# 对应标签
print('--------------------------')# 分隔线
print('sentence list is: ', ids_to_str(sent))# 转换后的数据
print('sentence label is: ', nu[label])# 转换后的标签
break
sentence list id is: ['2976' '385' '2050' '3757' '1147' '3296' '1585' '688' '1180' '2608'
'4280' '1887']
sentence label id is: 0
--------------------------
sentence list is:  上证 5 0 E T F 净申购突增
sentence label is:  财经
数据扩充
把数据扩充成⼀样的长度
# 数据扩充并查看
def create_padded_dataset(dataset):
padded_sents =[]
labels =[]
for batch_id, data in enumerate(dataset):# 读取数据
sent, label = data[0], data[1]# 标签和数据拆分
padded_sent = np.concatenate([sent[:seq_len],[pad_id]*(seq_len -len(sent))]).astype('int32')# 数据拼接
# print(padded_sent)
padded_sents.append(padded_sent)# 写⼊数据
labels.append(label)# 写⼊标签
# print(padded_sents)
return np.array(padded_sents), np.array(labels).astype('int64')# 转换成数组并返回
# 对train、val数据进⾏实例化
train_sents, train_labels = create_padded_dataset(train_dataset)# 实例化训练集
val_sents, val_labels = create_padded_dataset(val_dataset)# 实例化测试集
train_labels = shape(832475,1)# 标签数据⼤⼩转换
val_labels = shape(832475,1)
# 查看数据⼤⼩及举例内容
print(train_sents.shape)
print(train_labels.shape)
print(val_sents.shape)
print(val_labels.shape)
(832475, 30)
(832475, 1)
(832475, 30)
(832475, 1)
数据封装
通过继承paddle.io.Dataset类,把数据封装然后⽣成可以训练的数据格式

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。