1 Star 1 Fork 2

hyzycczz/masr

forked from 东方佑/masr 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
decoder.py 4.96 KB
一键复制 编辑 原始数据 按行查看 历史
me 提交于 2020-05-18 15:29 . final
import Levenshtein as Lev
import torch
from six.moves import xrange
class Decoder(object):
    """
    Basic decoder class from which all other decoders inherit. Implements several
    helper functions. Subclasses should implement the decode() method.

    Arguments:
        labels (string): mapping from integers to characters; the character at
            position i is the label for class index i.
        blank_index (int, optional): index for the blank '_' character. Defaults to 0.
    """

    def __init__(self, labels, blank_index=0):
        # e.g. labels = "_'ABCDEFGHIJKLMNOPQRSTUVWXYZ#"
        self.labels = labels
        self.int_to_char = dict(enumerate(labels))
        self.blank_index = blank_index
        # NOTE: space-index tracking is intentionally disabled here (it used to
        # live in an inert string literal), so no `space_index` attribute is set.

    def wer(self, s1, s2):
        """
        Computes the word-level edit distance between the two provided
        sentences after tokenizing them to words on whitespace.

        The value returned is whatever Lev.distance yields for the encoded
        word sequences; it is NOT divided by the number of reference words,
        so callers that want a rate must normalise it themselves.

        Arguments:
            s1 (string): space-separated sentence
            s2 (string): space-separated sentence
        Returns:
            edit distance between the two word sequences
        """
        words1, words2 = s1.split(), s2.split()
        # Map every distinct word to a unique single character, because the
        # Levenshtein package only accepts strings.
        vocabulary = set(words1 + words2)
        word2char = {word: chr(index) for index, word in enumerate(vocabulary)}
        encoded1 = "".join(word2char[word] for word in words1)
        encoded2 = "".join(word2char[word] for word in words2)
        return Lev.distance(encoded1, encoded2)

    def cer(self, s1, s2):
        """
        Computes the character-level edit distance between the two provided
        sentences, ignoring space characters.

        As with wer(), the raw Lev.distance result is returned without being
        normalised by the reference length.

        Arguments:
            s1 (string): space-separated sentence
            s2 (string): space-separated sentence
        Returns:
            edit distance between the two sentences with spaces removed
        """
        return Lev.distance(s1.replace(" ", ""), s2.replace(" ", ""))

    def decode(self, probs, sizes=None):
        """
        Given a matrix of character probabilities, returns the decoder's
        best guess of the transcription.

        Arguments:
            probs: Tensor of character probabilities, where probs[c,t]
                is the probability of character c at time t
            sizes(optional): Size of each sequence in the mini-batch
        Returns:
            string: sequence of the model's best guess for the transcription
        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError
class GreedyDecoder(Decoder):
    """
    Decoder that takes the highest-scoring label at every time step, then
    drops blanks and (optionally) consecutive repetitions.

    Arguments:
        labels (string): mapping from integers to characters.
        blank_index (int, optional): index for the blank character. Defaults to 0.
    """

    def __init__(self, labels, blank_index=0):
        super(GreedyDecoder, self).__init__(labels, blank_index)

    def convert_to_strings(
        self, sequences, sizes=None, remove_repetitions=False, return_offsets=False
    ):
        """
        Given a list of numeric sequences, returns the corresponding strings.

        Arguments:
            sequences: indexable collection of label-index sequences, one per
                utterance
            sizes(optional): number of leading entries to use from each
                sequence; when None the full length of each sequence is used
            remove_repetitions (bool): forwarded to process_string()
            return_offsets (bool): when True, also return the per-character
                time-step offsets
        Returns:
            strings: list with one single-element list per sequence (only one
                path is returned per utterance)
            offsets: same nesting as strings, holding the offsets tensors;
                only returned when return_offsets is True
        """
        strings = []
        offsets = [] if return_offsets else None
        for index, sequence in enumerate(sequences):
            seq_len = sizes[index] if sizes is not None else len(sequence)
            string, string_offsets = self.process_string(
                sequence, seq_len, remove_repetitions
            )
            strings.append([string])  # We only return one path
            if return_offsets:
                offsets.append([string_offsets])
        if return_offsets:
            return strings, offsets
        return strings

    def process_string(self, sequence, size, remove_repetitions=False):
        """
        Turns the first `size` label indices of `sequence` into a string.

        Blank labels are always dropped. When remove_repetitions is True, a
        label equal to the one at the immediately preceding time step is
        dropped as well (a blank between two equal labels therefore keeps
        both of them).

        Arguments:
            sequence: indexable sequence whose elements expose .item()
            size: number of leading time steps to decode
            remove_repetitions (bool): collapse consecutive duplicates
        Returns:
            string: the decoded text
            offsets: int tensor with the time step of every emitted character
        """
        blank_char = self.int_to_char[self.blank_index]
        chars = []
        offsets = []
        previous_char = None
        for i in range(size):
            char = self.int_to_char[sequence[i].item()]
            # A repetition is judged against the previous time step's
            # character, whether or not that character was emitted.
            is_repetition = remove_repetitions and i != 0 and char == previous_char
            previous_char = char
            if char != blank_char and not is_repetition:
                chars.append(char)
                offsets.append(i)
        return "".join(chars), torch.tensor(offsets, dtype=torch.int)

    def decode(self, probs, sizes=None):
        """
        Returns the argmax decoding given the probability matrix. Removes
        repeated elements in the sequence, as well as blanks.

        Arguments:
            probs: Tensor of character probabilities from the network. Expected shape of batch x seq_length x output_dim
            sizes(optional): Size of each sequence in the mini-batch
        Returns:
            strings: sequences of the model's best guess for the transcription on inputs
            offsets: time step per character predicted
        """
        # Reducing over dim 2 already yields a (batch, seq_length) index
        # tensor, so it can be handed over without any reshaping.
        _, best_labels = torch.max(probs, 2)
        strings, offsets = self.convert_to_strings(
            best_labels,
            sizes,
            remove_repetitions=True,
            return_offsets=True,
        )
        return strings, offsets
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/731228516/masr.git
[email protected]:731228516/masr.git
731228516
masr
masr
master

搜索帮助