#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
    Pre-process Data / features files and build vocabulary
"""
import codecs
import glob
import sys
import gc
import torch

from functools import partial
from collections import Counter, defaultdict

from onmt.utils.logging import init_logger, logger
from onmt.utils.misc import split_corpus
import onmt.inputters as inputters
import onmt.opts as opts
from onmt.utils.parse import ArgumentParser
from onmt.inputters.inputter import _build_fields_vocab,\
    _load_vocab

def check_existing_pt_files(opt):
    """ Check if there are existing .pt files to avoid overwriting them """
    pattern = opt.save_data + '.{}*.pt'
    for t in ['train', 'valid']:
        path = pattern.format(t)
        if glob.glob(path):
            sys.stderr.write("Please backup existing pt files: %s, "
                             "to avoid overwriting them!\n" % path)
            sys.exit(1)
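
# Illustrative note (the paths are examples, not part of the code): with
# -save_data data/demo the glob above matches artifacts from a previous run,
# such as data/demo.train.0.pt or data/demo.valid.0.pt. Passing -overwrite
# skips this check entirely (see main() below).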


def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt):
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        counters = defaultdict(Counter)
        srcs = opt.train_src
        tgts = opt.train_tgt
        ids = opt.train_ids
    else:
        srcs = [opt.valid_src]
        tgts = [opt.valid_tgt]
        ids = [None]

    for src, tgt, maybe_id in zip(srcs, tgts, ids):
        logger.info("Reading source and target files: %s %s." % (src, tgt))

        src_shards = split_corpus(src, opt.shard_size)
        tgt_shards = split_corpus(tgt, opt.shard_size)
        shard_pairs = zip(src_shards, tgt_shards)
        dataset_paths = []
        if (corpus_type == "train" or opt.filter_valid) and tgt is not None:
            filter_pred = partial(
                inputters.filter_example, use_src_len=opt.data_type == "text",
                max_src_len=opt.src_seq_length, max_tgt_len=opt.tgt_seq_length)
        else:
            filter_pred = None

        if corpus_type == "train":
            existing_fields = None
            if opt.src_vocab != "":
                try:
                    logger.info("Using existing vocabulary...")
                    existing_fields = torch.load(opt.src_vocab)
                except torch.serialization.pickle.UnpicklingError:
                    logger.info("Building vocab from text file...")
                    src_vocab, src_vocab_size = _load_vocab(
                        opt.src_vocab, "src", counters,
                        opt.src_words_min_frequency)
            else:
                src_vocab = None

            if opt.tgt_vocab != "":
                tgt_vocab, tgt_vocab_size = _load_vocab(
                    opt.tgt_vocab, "tgt", counters,
                    opt.tgt_words_min_frequency)
            else:
                tgt_vocab = None

        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            assert len(src_shard) == len(tgt_shard)
            logger.info("Building shard %d." % i)
            dataset = inputters.Dataset(
                fields,
                readers=([src_reader, tgt_reader]
                         if tgt_reader else [src_reader]),
                data=([("src", src_shard), ("tgt", tgt_shard)]
                      if tgt_reader else [("src", src_shard)]),
                dirs=([opt.src_dir, None]
                      if tgt_reader else [opt.src_dir]),
                sort_key=inputters.str2sortkey[opt.data_type],
                filter_pred=filter_pred
            )
            if corpus_type == "train" and existing_fields is None:
                # Update the vocabulary counters from the examples themselves,
                # unless a pre-built vocab was loaded above.
                for ex in dataset.examples:
                    for name, field in fields.items():
                        try:
                            f_iter = iter(field)
                        except TypeError:
                            f_iter = [(name, field)]
                            all_data = [getattr(ex, name, None)]
                        else:
                            all_data = getattr(ex, name)
                        for (sub_n, sub_f), fd in zip(
                                f_iter, all_data):
                            has_vocab = (sub_n == 'src' and src_vocab) or \
                                        (sub_n == 'tgt' and tgt_vocab)
                            if (hasattr(sub_f, 'sequential')
                                    and sub_f.sequential and not has_vocab):
                                val = fd
                                counters[sub_n].update(val)
            if maybe_id:
                shard_base = corpus_type + "_" + maybe_id
            else:
                shard_base = corpus_type
            data_path = "{:s}.{:s}.{:d}.pt".\
                format(opt.save_data, shard_base, i)
            dataset_paths.append(data_path)

            logger.info(" * saving %sth %s data shard to %s."
                        % (i, shard_base, data_path))

            dataset.save(data_path)

            # Free the shard before reading the next one.
            del dataset.examples
            gc.collect()
            del dataset
            gc.collect()

    if corpus_type == "train":
        vocab_path = opt.save_data + '.vocab.pt'
        if existing_fields is None:
            fields = _build_fields_vocab(
                fields, counters, opt.data_type,
                opt.share_vocab, opt.vocab_size_multiple,
                opt.src_vocab_size, opt.src_words_min_frequency,
                opt.tgt_vocab_size, opt.tgt_words_min_frequency)
        else:
            fields = existing_fields
        torch.save(fields, vocab_path)
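
# Sketch of the on-disk layout this produces (names are illustrative): with
# -save_data data/demo and two training shards, the data_path format string
# above yields
#     data/demo.train.0.pt   data/demo.train.1.pt   data/demo.vocab.pt
# and validation shards land in data/demo.valid.<i>.pt.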


def build_save_vocab(train_dataset, fields, opt):
    fields = inputters.build_vocab(
        train_dataset, fields, opt.data_type, opt.share_vocab,
        opt.src_vocab, opt.src_vocab_size, opt.src_words_min_frequency,
        opt.tgt_vocab, opt.tgt_vocab_size, opt.tgt_words_min_frequency,
        vocab_size_multiple=opt.vocab_size_multiple
    )
    vocab_path = opt.save_data + '.vocab.pt'
    torch.save(fields, vocab_path)
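
# Note: build_save_vocab is a standalone helper; main() below does not call
# it, since build_save_dataset() already builds and saves the vocabulary for
# the training corpus.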


def count_features(path):
    """
    path: location of a corpus file with whitespace-delimited tokens and
        │-delimited features within the token
    returns: the number of features in the dataset
    """
    with codecs.open(path, "r", "utf-8") as f:
        first_tok = f.readline().split(None, 1)[0]
        return len(first_tok.split(u"│")) - 1


def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    if not(opt.overwrite):
        check_existing_pt_files(opt)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    for src, tgt in zip(opt.train_src, opt.train_tgt):
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(
        opt.data_type,
        src_nfeats,
        tgt_nfeats,
        dynamic_dict=opt.dynamic_dict,
        src_truncate=opt.src_seq_length_trunc,
        tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset(
        'train', fields, src_reader, tgt_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader, opt)


def _get_parser():
    parser = ArgumentParser(description='preprocess.py')

    opts.config_opts(parser)
    opts.preprocess_opts(parser)
    return parser


if __name__ == "__main__":
    parser = _get_parser()

    opt = parser.parse_args()
    main(opt)
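
# Typical invocation (file names are illustrative; the flags come from
# onmt.opts.preprocess_opts):
#
#   python preprocess.py -train_src data/src-train.txt \
#       -train_tgt data/tgt-train.txt \
#       -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt \
#       -save_data data/demo
#
# This writes data/demo.train.0.pt, data/demo.valid.0.pt and
# data/demo.vocab.pt, which training later loads via its -data prefix.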