class BPE(object):
- def __init__(self, codes, separator='@@', vocab=None, glossaries=None):
+ def __init__(self, codes, merges=-1, separator='@@', vocab=None, glossaries=None):
# check version information
firstline = codes.readline()
self.version = (0, 1)
codes.seek(0)
- self.bpe_codes = [tuple(item.split()) for item in codes]
+ self.bpe_codes = [tuple(item.split()) for (n, item) in enumerate(codes) if (n < merges or merges == -1)]
# some hacking to deal with duplicates (only consider first instance)
self.bpe_codes = dict([(code,i) for (i,code) in reversed(list(enumerate(self.bpe_codes)))])
'--codes', '-c', type=argparse.FileType('r'), metavar='PATH',
required=True,
help="File with BPE codes (created by learn_bpe.py).")
+ parser.add_argument(
+ '--merges', '-m', type=int, default=-1,
+ metavar='INT',
+ help="Use this many BPE operations (<= number of learned symbols)"+
+ "default: Apply all the learned merge operations")
parser.add_argument(
'--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
metavar='PATH',
else:
vocabulary = None
- bpe = BPE(args.codes, args.separator, vocabulary, args.glossaries)
+ bpe = BPE(args.codes, args.merges, args.separator, vocabulary, args.glossaries)
for line in args.input:
args.output.write(bpe.segment(line).strip())