Text Preprocessing

Introduction

When working on deep learning tasks in NLP, a key question is how to build the model input. This post shows how to preprocess a large-scale dataset with limited memory, covering:

  • Building the vocabulary
  • Converting words to ids
  • Splitting the data into training and validation sets

How To Do It

Raw Dataset

11@成都 高新技术 产业 开发区 人民 检察院 指控 , 201532923时 许 , 被告人 刘某 某 饮 酒后 驾驶川 A * * * 84 北京 现代牌 小型 轿车 , 从 成都市 桐梓林 附近 出发 上 人民 南 路 出 城 , 当 车 行驶 至 成都 高新区 天府 大道 与 府城 大道 交叉 路口处 时 ,  公诉 机关 认为 , 被告人 刘 某某 在 道路 上 醉 酒 驾驶 机动车 , 危害 公共 安全 , 其 行为 应当 以 ×× 追究 其 刑事 责任 。

11@黑龙江省 尚志市 人民 检察院 指控 : ×× 201492220时 许 , 被告人 矫 2 某 在 尚志市 苇河镇 阿里郎歌厅 对面 停放 的 货车 的 副 驾驶 座位 上 , 将 被害人 李某 甲 的 蓝色 女式 拎 包 盗 走 , 包 内 有 人民币 57000 元 , 红色 钱包 一个 , 农业 银行卡 一 张 , 身份证 一 张 、 驾驶证 一 本 、 账本 一 册 。 案 发 前 , 被告人 矫 2 某 将 盗走 的 财物 返还 被害人 。 在 ×× 到 五 年 幅度 内 量刑 , 并 处 罚金 ; 对 所 犯 的 ×× 在 ×× 到 六 个 月 幅度 内 量刑 , 并 处 罚金 。 针对 上述 指控 , 公诉 机关 提供 了 相应 的 证据 。
...

These lines come from the text dataset of the 法研杯 (CAIL) competition. Each line has the format label@text.

The text has already been word-segmented, with tokens separated by single spaces.
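
As a minimal sketch of that format, the snippet below parses one such line into its label and token list (the sample line is shortened and purely illustrative):

# Minimal sketch: parse a "label@text" line whose text is already space-separated.
# The sample line is illustrative only.
line = "11@成都 高新技术 产业 开发区 人民 检察院 指控\n"

label, text = line.strip('\n').split('@', 1)  # split on the first '@'
words = text.split(' ')

print(label)   # -> '11'
print(words)   # -> ['成都', '高新技术', '产业', '开发区', '人民', '检察院', '指控']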

Building the Vocabulary

def sent_label_split(line):
    """
    Split a raw line into tokens and label.
    :param line: raw line, in the format "label@text"
    :return: tokens, label
    """
    line = line.strip('\n').split('@')
    label = line[0]
    sent = line[1].split(' ')
    return sent, label

def word_to_id(word, word2id):
    """
    Word --> id, falling back to the <unk> id for out-of-vocabulary words.
    :param word: word
    :param word2id: word2id @type: dict
    :return: word id
    """
    return word2id[word] if word in word2id else word2id['<unk>']


def bulid_vocab(vocab_size, min_freq=3, stop_word_list=None,
                is_debug=False):
    """
    Build the vocabulary.
    :param vocab_size: maximum vocabulary size
    :param min_freq: minimum word frequency
    :param stop_word_list: stop words @type: file_path
    :param is_debug: debug mode @type: bool; True runs the code on a very small subset of the data
    :return: word2id
    """
    size = 0
    count = Counter()

    with open(os.path.join(config.ROOT_DIR, config.RAW_DATA), 'r') as fr:
        logger.info('Building vocab')
        for line in tqdm(fr, desc='Build vocab'):
            words, label = sent_label_split(line)
            count.update(words)
            size += 1
            if is_debug:
                limit_train_size = 10000
                if size > limit_train_size:
                    break
    if stop_word_list:
        stop_list = {}
        with open(os.path.join(config.ROOT_DIR, config.STOP_WORD_LIST), 'r') as fr:
            for i, line in enumerate(fr):
                word = line.strip('\n')
                if stop_list.get(word) is None:
                    stop_list[word] = i
        count = {k: v for k, v in count.items() if k not in stop_list}
    # Sort by frequency, most frequent first
    count = sorted(count.items(), key=operator.itemgetter(1), reverse=True)
    # Vocabulary
    vocab = [w[0] for w in count if w[1] >= min_freq]
    if vocab_size < len(vocab):
        vocab = vocab[:vocab_size]
    vocab = config.flag_words + vocab
    logger.info('vocab_size is %d' % len(vocab))
    # Mapping from word to id
    word2id = {k: v for k, v in zip(vocab, range(0, len(vocab)))}
    assert word2id['<pad>'] == 0, "ValueError: '<pad>' id is not 0"
    with open(config.WORD2ID_FILE, 'wb') as fw:
        pickle.dump(word2id, fw)
    return word2id
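
A quick usage sketch; the vocab_size value here is just an illustrative choice:

# Illustrative call; the argument values are examples, not taken from the post.
word2id = bulid_vocab(vocab_size=50000, min_freq=3, stop_word_list=None, is_debug=True)

# The mapping is also pickled to config.WORD2ID_FILE and can be reloaded later:
with open(config.WORD2ID_FILE, 'rb') as fr:
    word2id = pickle.load(fr)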

Mapping Text to Ids

def text2id(word2id, maxlen=None, valid_size=0.3, random_state=2018, shuffle=True, is_debug=False):
    """
    Convert the training text to ids and write the train/valid files.
    :param valid_size: fraction of the data used for validation
    """
    if len(glob(os.path.join(config.ROOT_DIR, config.TRAIN_FILE))) > 0:
        logger.info('Text to id file existed')
        return
    logger.info('Text to id')
    sentences, labels, lengths = [], [], []
    size = 0
    with open(os.path.join(config.ROOT_DIR, config.RAW_DATA), 'r') as fr:
        for line in tqdm(fr, desc='text_to_id'):
            words, label = sent_label_split(line)
            sent = [word_to_id(word=word, word2id=word2id) for word in words]
            if maxlen:
                sent = sent[:maxlen]
            length = len(sent)
            sentences.append(sent)
            labels.append(label)
            lengths.append(length)
            size += 1
            if is_debug:
                limit_train_size = 10000
                if size > limit_train_size:
                    break

    train, valid = train_val_split(sentences, labels,
                                   valid_size=valid_size,
                                   random_state=random_state,
                                   shuffle=shuffle)
    del sentences, labels, lengths

    with open(config.TRAIN_FILE, 'w') as fw:
        for sent, label in train:
            sent = [str(s) for s in sent]
            line = "\t".join([str(label), " ".join(sent)])
            fw.write(line + '\n')
    logger.info('Writing train to file done')

    with open(config.VALID_FILE, 'w') as fw:
        for sent, label in valid:
            sent = [str(s) for s in sent]
            line = "\t".join([str(label), " ".join(sent)])
            fw.write(line + '\n')
    logger.info('Writing valid to file done')
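
Each written line in train.tsv / valid.tsv is tab-separated: the label first, then the space-joined token ids. A minimal sketch of reading one line back (the id values are made up):

# Read one written line back; the ids below are illustrative only.
line = "11\t482 17 930 5 1\n"

label, id_str = line.strip('\n').split('\t')
ids = [int(i) for i in id_str.split(' ')]

print(label)  # -> '11'
print(ids)    # -> [482, 17, 930, 5, 1]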

Train/Validation Split

def train_val_split(X, y, valid_size=0.3, random_state=2018, shuffle=True):
    """
    Split the data into training and validation sets.
    :param X: sentences
    :param y: labels
    :param random_state: random seed
    """
    logger.info('train val split')
    data = [(data_x, data_y) for data_x, data_y in zip(X, y)]
    N = len(data)
    test_size = int(N * valid_size)

    if shuffle:
        random.seed(random_state)
        random.shuffle(data)

    valid = data[:test_size]
    train = data[test_size:]
    return train, valid
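
For instance, with valid_size=0.3 and ten sentences, the shuffled data ends up as seven training pairs and three validation pairs (the inputs below are dummy data):

# Tiny illustrative split on dummy data.
X = [[i, i + 1] for i in range(10)]   # 10 dummy "sentences" (lists of ids)
y = [str(i % 3) for i in range(10)]   # 10 dummy labels

train, valid = train_val_split(X, y, valid_size=0.3, random_state=2018, shuffle=True)
print(len(train), len(valid))  # -> 7 3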

Complete Code

import os
import random
import pickle
import operator
from glob import glob
from tqdm import tqdm
from collections import Counter

import config
from Logginger import init_logger

logger = init_logger("torch", logging_path=config.LOG_PATH)


def sent_label_split(line):
    """
    Split a raw line into tokens and label.
    :param line: raw line, in the format "label@text"
    :return: tokens, label
    """
    line = line.strip('\n').split('@')
    label = line[0]
    sent = line[1].split(' ')
    return sent, label

def word_to_id(word, word2id):
    """
    Word --> id, falling back to the <unk> id for out-of-vocabulary words.
    :param word: word
    :param word2id: word2id @type: dict
    :return: word id
    """
    return word2id[word] if word in word2id else word2id['<unk>']


def bulid_vocab(vocab_size, min_freq=3, stop_word_list=None,
                is_debug=False):
    """
    Build the vocabulary.
    :param vocab_size: maximum vocabulary size
    :param min_freq: minimum word frequency
    :param stop_word_list: stop words @type: file_path
    :param is_debug: debug mode @type: bool; True runs the code on a very small subset of the data
    :return: word2id
    """
    size = 0
    count = Counter()

    with open(os.path.join(config.ROOT_DIR, config.RAW_DATA), 'r') as fr:
        logger.info('Building vocab')
        for line in tqdm(fr, desc='Build vocab'):
            words, label = sent_label_split(line)
            count.update(words)
            size += 1
            if is_debug:
                limit_train_size = 10000
                if size > limit_train_size:
                    break
    if stop_word_list:
        stop_list = {}
        with open(os.path.join(config.ROOT_DIR, config.STOP_WORD_LIST), 'r') as fr:
            for i, line in enumerate(fr):
                word = line.strip('\n')
                if stop_list.get(word) is None:
                    stop_list[word] = i
        count = {k: v for k, v in count.items() if k not in stop_list}
    # Sort by frequency, most frequent first
    count = sorted(count.items(), key=operator.itemgetter(1), reverse=True)
    # Vocabulary
    vocab = [w[0] for w in count if w[1] >= min_freq]
    if vocab_size < len(vocab):
        vocab = vocab[:vocab_size]
    vocab = config.flag_words + vocab
    logger.info('vocab_size is %d' % len(vocab))
    # Mapping from word to id
    word2id = {k: v for k, v in zip(vocab, range(0, len(vocab)))}
    assert word2id['<pad>'] == 0, "ValueError: '<pad>' id is not 0"
    with open(config.WORD2ID_FILE, 'wb') as fw:
        pickle.dump(word2id, fw)
    return word2id


def train_val_split(X, y, valid_size=0.3, random_state=2018, shuffle=True):
    """
    Split the data into training and validation sets.
    :param X: sentences
    :param y: labels
    :param random_state: random seed
    """
    logger.info('train val split')
    data = [(data_x, data_y) for data_x, data_y in zip(X, y)]
    N = len(data)
    test_size = int(N * valid_size)

    if shuffle:
        random.seed(random_state)
        random.shuffle(data)

    valid = data[:test_size]
    train = data[test_size:]
    return train, valid


def text2id(word2id, maxlen=None, valid_size=0.3, random_state=2018, shuffle=True, is_debug=False):
    """
    Convert the training text to ids and write the train/valid files.
    :param valid_size: fraction of the data used for validation
    """
    if len(glob(os.path.join(config.ROOT_DIR, config.TRAIN_FILE))) > 0:
        logger.info('Text to id file existed')
        return
    logger.info('Text to id')
    sentences, labels, lengths = [], [], []
    size = 0
    with open(os.path.join(config.ROOT_DIR, config.RAW_DATA), 'r') as fr:
        for line in tqdm(fr, desc='text_to_id'):
            words, label = sent_label_split(line)
            sent = [word_to_id(word=word, word2id=word2id) for word in words]
            if maxlen:
                sent = sent[:maxlen]
            length = len(sent)
            sentences.append(sent)
            labels.append(label)
            lengths.append(length)
            size += 1
            if is_debug:
                limit_train_size = 10000
                if size > limit_train_size:
                    break

    train, valid = train_val_split(sentences, labels,
                                   valid_size=valid_size,
                                   random_state=random_state,
                                   shuffle=shuffle)
    del sentences, labels, lengths

    with open(config.TRAIN_FILE, 'w') as fw:
        for sent, label in train:
            sent = [str(s) for s in sent]
            line = "\t".join([str(label), " ".join(sent)])
            fw.write(line + '\n')
    logger.info('Writing train to file done')

    with open(config.VALID_FILE, 'w') as fw:
        for sent, label in valid:
            sent = [str(s) for s in sent]
            line = "\t".join([str(label), " ".join(sent)])
            fw.write(line + '\n')
    logger.info('Writing valid to file done')


# Integrates the steps above; this is the interface exposed to external callers
def data_helper(vocab_size, min_freq=3, stop_list=None,
                valid_size=0.3, random_state=2018, shuffle=True, is_debug=False):
    # Check whether the word2id file already exists
    if len(glob(os.path.join(config.ROOT_DIR, config.WORD2ID_FILE))) > 0:
        logger.info('Word to id file existed')
        with open(os.path.join(config.ROOT_DIR, config.WORD2ID_FILE), 'rb') as fr:
            word2id = pickle.load(fr)
    else:
        word2id = bulid_vocab(vocab_size=vocab_size, min_freq=min_freq, stop_word_list=stop_list,
                              is_debug=is_debug)
    text2id(word2id, valid_size=valid_size, random_state=random_state, shuffle=shuffle, is_debug=is_debug)
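
A typical entry-point call might look like the following; the module name is an assumption (save the file above under whatever name you prefer) and the argument values are illustrative:

# Illustrative call of the external interface.
from data_helper_module import data_helper   # assumed module name for the file above

data_helper(vocab_size=50000, min_freq=3, stop_list=None,
            valid_size=0.3, random_state=2018, shuffle=True, is_debug=False)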

config.py

# ---------PATH------------
ROOT_DIR = '/home/daizelin/pytorch/'
RAW_DATA = 'data/data_for_test.csv'
TRAIN_FILE = 'outputs/intermediate/train.tsv'
VALID_FILE = 'outputs/intermediate/valid.tsv'
LOG_PATH = 'outputs/logs'
is_debug = False
flag_words = ['<pad>', '<unk>']
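
Note that the code above also reads config.WORD2ID_FILE (and config.STOP_WORD_LIST when a stop-word list is supplied), which this config does not show. Entries along these lines would be needed; the paths are assumptions, not taken from the original:

# Assumed additions, not part of the config shown above
WORD2ID_FILE = 'outputs/intermediate/word2id.pkl'   # pickled word2id mapping
STOP_WORD_LIST = 'data/stop_words.txt'              # optional stop-word file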

Logginger.py

import logging
from logging import Logger
from logging.handlers import TimedRotatingFileHandler

'''
Usage:
from Logginger import init_logger
logger = init_logger("dataset", logging_path='')
def your_function():
    logger.info()
    logger.error()
'''
'''
Logging module
1. Writes log messages to both the console and a file
2. Keeps the last 7 days of log files by default
'''
def init_logger(logger_name, logging_path):
    if logger_name not in Logger.manager.loggerDict:
        logger = logging.getLogger(logger_name)
        logger.setLevel(logging.DEBUG)
        handler = TimedRotatingFileHandler(filename=logging_path + "/all.log", when='D', backupCount=7)
        datefmt = '%Y-%m-%d %H:%M:%S'
        format_str = '[%(asctime)s]: %(name)s %(filename)s[line:%(lineno)s] %(levelname)s %(message)s'
        formatter = logging.Formatter(format_str, datefmt)
        handler.setFormatter(formatter)
        handler.setLevel(logging.INFO)
        logger.addHandler(handler)
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        console.setFormatter(formatter)
        logger.addHandler(console)

        handler = TimedRotatingFileHandler(filename=logging_path + "/error.log", when='D', backupCount=7)
        datefmt = '%Y-%m-%d %H:%M:%S'
        format_str = '[%(asctime)s]: %(name)s %(filename)s[line:%(lineno)s] %(levelname)s %(message)s'
        formatter = logging.Formatter(format_str, datefmt)
        handler.setFormatter(formatter)
        handler.setLevel(logging.ERROR)
        logger.addHandler(handler)
    logger = logging.getLogger(logger_name)
    return logger