-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathfile_shard_spliter.py
50 lines (38 loc) · 1.56 KB
/
file_shard_spliter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from pathlib import Path
from tqdm import tqdm
def read_lines_from_input(input_files, encoding='utf-8'):
    """Yield the stripped lines of each input file, one list per file.

    Files are read in the order given, wrapped in a ``tqdm`` progress bar
    that advances once per file.  Every line has surrounding whitespace
    removed with ``str.strip`` (this also drops the trailing newline);
    blank lines are kept as empty strings.

    Args:
        input_files: Iterable of paths to text files.
        encoding: Text encoding used to decode the input files.  Defaults to
            UTF-8 instead of the locale's preferred encoding so the result
            does not depend on the platform the script runs on.

    Yields:
        list[str]: All (stripped) lines of one input file.
    """
    for fname in tqdm(input_files):
        # An explicit encoding avoids platform-dependent decoding
        # (e.g. cp1252 on Windows) of the UTF-8 text corpus.
        with open(fname, encoding=encoding) as infile:
            yield [line.strip() for line in infile]
def _write_shard(path, lines, append_trailing_newline, encoding):
    """Write *lines* joined by newlines to *path*, optionally newline-terminated."""
    with open(path, 'wt', encoding=encoding) as outfile:
        outfile.write('\n'.join(lines))
        if append_trailing_newline:
            outfile.write('\n')


def split(input_files, file_path_prefix,
          file_name_suffix='',
          append_trailing_newlines=True,
          mini_line_length=100000,
          encoding='utf-8'):
    """Concatenate the lines of *input_files* into numbered shard files.

    Whole input files are appended to the current shard until it holds at
    least ``mini_line_length`` lines, at which point it is written out.  An
    input file is never split across two shards, so a shard can be larger
    than ``mini_line_length``.  Whatever remains at the end becomes a final,
    possibly smaller, shard.  Shard files are named
    ``'{file_path_prefix}-{shard_id}{file_name_suffix}'`` with ``shard_id``
    counting up from 0.

    Args:
        input_files: Paths of the text files to read, in order.
        file_path_prefix: Path prefix of every shard file.  Its parent
            directory must already exist.
        file_name_suffix: Text appended after the shard id.
        append_trailing_newlines: If true, each shard ends with a newline.
        mini_line_length: Minimum number of lines per shard (except the last).
        encoding: Text encoding used to write the shard files.

    Returns:
        list[str]: Names of the shard files written, in order.  No file is
        created when there are no lines left to write.
    """
    shard_names = []
    line_buffer = []

    def _flush():
        # Name the shard only when it is actually written, so no empty
        # trailing shard file is ever created.
        output_file_name = '{}-{}{}'.format(file_path_prefix, len(shard_names),
                                            file_name_suffix)
        _write_shard(output_file_name, line_buffer,
                     append_trailing_newlines, encoding)
        shard_names.append(output_file_name)

    for file_lines in read_lines_from_input(input_files):
        line_buffer.extend(file_lines)
        if len(line_buffer) >= mini_line_length:
            _flush()
            line_buffer = []

    # Remainder that never reached the threshold.
    if line_buffer:
        _flush()
    return shard_names
if __name__ == "__main__":
    # Sort so shard contents are reproducible: Path.glob yields entries in
    # an arbitrary, filesystem-dependent order.
    input_files = sorted(
        str(i.absolute())
        for i in Path('token_cleaned_plain_files').glob('*wiki*')
        if i.is_file()
    )
    # split() opens shard files directly inside this directory, so it must
    # exist beforehand.
    Path('./sharded_files').mkdir(parents=True, exist_ok=True)
    split(input_files, './sharded_files/data.txt')