-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataParser.py
More file actions
196 lines (169 loc) · 7.28 KB
/
Copy pathdataParser.py
File metadata and controls
196 lines (169 loc) · 7.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import collections
import torch
from tqdm import tqdm
import numpy as np
import os
def read_tokens(path:str):
"""读取数据
返回一个二维列表,每个元素为token字符
同一行的属于同一个对话,同一对话内用'sep'分隔说话者
"""
with open(path, 'r') as f:
data = f.read()
tokens = data.split("\n\n")
for i, d in enumerate(tokens):
d = list(d)
for j, k in enumerate(d):
if k == '\n':
d[j] = '<sep>'
tokens[i]=d
return tokens
def split_train_test(tokens, train_percent):
total_dialog = len(tokens)
train_dialog = int(total_dialog * train_percent)
train_tokens = tokens[:train_dialog]
test_tokens = tokens[train_dialog:]
return train_tokens, test_tokens
def read_tokens_idx(tokens, vocab, seq_len=None, show_bar=True):
"""
获取标号数据集
返回一个2维tensor,每个元素为token字符对应的标号
ret.shape = (len(tokens), seq_len)
"""
if seq_len is None:
seq_len = max([len(token) for token in tokens])
ret = np.zeros((len(tokens), seq_len), dtype=np.int16)
pad_idx = vocab['<pad>'] # 大多数token为<pad>,避免重复查询vocab
# 在分布式训练中,只有主进程显示进度条
def is_main_process():
# 检查是否在分布式环境中
if 'RANK' in os.environ:
return int(os.environ['RANK']) == 0
elif 'LOCAL_RANK' in os.environ:
return int(os.environ['LOCAL_RANK']) == 0
else:
# 尝试使用torch.distributed
try:
import torch.distributed as dist
if dist.is_initialized():
return dist.get_rank() == 0
except:
pass
return True # 非分布式环境,显示进度条
# 只有主进程或非分布式环境才显示进度条
should_show_bar = show_bar and is_main_process()
if should_show_bar:
iterator = tqdm(enumerate(tokens), desc='Loading tokens idx', total=len(tokens))
else:
iterator = enumerate(tokens)
for i, sentence in iterator:
for j in range(seq_len):
if j < len(sentence):
ret[i][j] = vocab[sentence[j]]
ret[i][len(sentence):] = pad_idx
return ret
def build_vocab(tokens):
"""构建词表"""
reserved_tokens = ['<sep>', '<pad>']
vocab = Vocab(tokens, min_freq=0, reserved_tokens=reserved_tokens)
return vocab
class Vocab:
"""文本词表"""
def __init__(self, tokens, min_freq=0, reserved_tokens=None):
if reserved_tokens is None:
reserved_tokens = []
# 按出现频率排序
counter = count_corpus(tokens)
self._token_freqs = sorted(counter.items(), key=lambda x: x[1],
reverse=True)
# 未知词元的索引为0
self.idx_to_token = ['<unk>'] + reserved_tokens
self.token_to_idx = {token: idx
for idx, token in enumerate(self.idx_to_token)}
for token, freq in self._token_freqs:
if freq < min_freq:
break
if token not in self.token_to_idx:
self.idx_to_token.append(token)
self.token_to_idx[token] = len(self.idx_to_token) - 1
def __len__(self):
return len(self.idx_to_token)
def __getitem__(self, tokens):
if not isinstance(tokens, (list, tuple)):
return self.token_to_idx.get(tokens, self.unk)
return [self.__getitem__(token) for token in tokens]
def to_tokens(self, indices):
if not isinstance(indices, (list, tuple)):
return self.idx_to_token[indices]
return [self.idx_to_token[index] for index in indices]
def insert_token(self, token):
if isinstance(token, (list, tuple)):
for t in token:
self.insert_token(t)
else:
self.idx_to_token.append(token)
self.token_to_idx[token] = len(self.idx_to_token) - 1
@property
def unk(self): # 未知词元的索引为0
return 0
@property
def token_freqs(self):
return self._token_freqs
def count_corpus(tokens):
"""统计词元的频率"""
# 这里的tokens是1D列表或2D列表
# 将词元列表展平成一个列表
tokens = [token for line in tokens for token in line]
return collections.Counter(tokens)
def tokens_dataloader(tokens_idx, batch_size, pad_idx, shuffle=True, sampler=None):
"""构造一个PyTorch数据迭代器"""
from torch.utils import data
add = torch.full((tokens_idx.shape[0], 1), pad_idx, dtype=torch.int64)
y = torch.cat([tokens_idx[:, 1:], add], dim=-1)
dataset = data.TensorDataset(tokens_idx, y)
# 如果提供了sampler,则不能同时使用shuffle
if sampler is not None:
return data.DataLoader(dataset, batch_size, sampler=sampler)
else:
return data.DataLoader(dataset, batch_size, shuffle=shuffle)
if __name__ == "__main__":
path = 'data/train.txt'
print('读取tokens')
tokens = read_tokens(path)
print(tokens[:2])
v = build_vocab(tokens)
print('根据标号获取字符')
for i in range(10):
print(v.to_tokens(i), end=' ')
print()
print('根据字符获取标号')
print(v['乐'])
print('获取token出现频率')
print(v.token_freqs[0:20])
print('获取总token数')
print(len(v))
print('获取标号数据集')
tokens_idx = read_tokens_idx(tokens, v, 1024)
print(tokens_idx[:2])
print(tokens_idx.shape)
print('最大句子长度')
print(max([len(sentence) for sentence in tokens_idx]))
"""
读取tokens
[['谢', '谢', '你', '所', '做', '的', '一', '切', '<sep>', '你', '开', '心', '就', '好', '<sep>', '开', '心', '<sep>', '嗯', '因', '为', '你', '的', '心', '里', '只', '有', '学', '习', '<sep>', '某', '某', '某', ',', '还', '有', '你', '<sep>', '这', '个', '某', '某', '某', '用', '的', '好'], ['你', '们', '宿', '舍', '都', '是', '这', '么', '厉', '害', '的', '人', '吗', '<sep>', '眼', '睛', '特', '别', '搞', '笑', '这', '土', '也', '不', '好', '捏', '但', '就', '是', '觉', '得', '挺', '可', '爱', '<sep>', '特', '别', '可', '爱', '啊']]
根据标号获取字符
<unk> <sep> <pad> 我 的 , 你 了 哈 是
根据字符获取标号
219
获取token出现频率
[('<sep>', 1548664), ('我', 653547), ('的', 611458), (',', 597859), ('你', 554873), ('了', 520155), ('哈', 493861), ('是', 472772), ('不', 459928), ('好', 300303), ('一', 256255), ('有', 215029), ('这', 196945), ('么', 186024), ('!', 180269), ('个', 179892), ('就', 178807), ('啊', 175506), ('看', 172787), ('没', 163021)]
获取总token数
7587
获取标号数据集
Loading tokens idx: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 500001/500001 [00:16<00:00, 31116.49it/s]
[[ 72 72 6 ... 2 2 2]
[ 6 53 863 ... 2 2 2]]
(500001, 1024)
最大句子长度
1024
"""