# -*- coding:utf-8 -*-
from __future__ import division, print_function, absolute_import
import os
import re
import sys
import nagisa_utils as utils
import nagisa.model as model
base = os.path.dirname(os.path.abspath(__file__))
sys.path.append(base)
class Tagger(object):
    """
    This class has a word segmentation function and a POS-tagging function for Japanese.
    """

    def __init__(self, vocabs=None, params=None, hp=None, single_word_list=None):
        """Load vocabularies, hyper-parameters and model weights.

        args:
            - vocabs (str): Path to the vocabulary file (defaults to the bundled dict).
            - params (str): Path to the model-parameter file.
            - hp (str): Path to the hyper-parameter file.
            - single_word_list (list): Words that are always segmented as single tokens.
        """
        if vocabs is None:
            vocabs = base + '/data/nagisa_v001.dict'
        if params is None:
            params = base + '/data/nagisa_v001.model'
        if hp is None:
            hp = base + '/data/nagisa_v001.hp'
        # Load vocabulary files
        vocabs = utils.load_data(vocabs)
        self._uni2id, self._bi2id, self._word2id, self._pos2id, self._word2postags = vocabs
        self._id2pos = {v: k for k, v in self._pos2id.items()}
        self.id2pos = self._id2pos
        self.postags = list(self._pos2id.keys())
        # Load a hyper-parameter file
        self._hp = utils.load_data(hp)
        # Construct a word segmentation model and a POS-tagging model
        self._model = model.Model(self._hp, params)
        # If a word is included in the single_word_list,
        # it is recognized as a single word forcibly.
        self.pattern = None
        if single_word_list:
            single_word_list = [utils.preprocess(w)
                                for w in single_word_list if len(w) > 1]
            # BUGFIX: the original escaped only '(' and ')' (with deprecated
            # '\(' string escapes). re.escape also neutralizes '.', '+', '*',
            # '[', '|', etc., which would otherwise be interpreted as regex
            # metacharacters and corrupt the alternation pattern.
            single_word_list = [re.escape(w) for w in single_word_list]
            # Longest entries first so the alternation prefers the longest match.
            single_word_list = sorted(single_word_list, key=lambda x: -len(x))
            if len(single_word_list) > 0:
                self.pattern = re.compile('|'.join(single_word_list))
        # If use_noun_heuristic is True, nouns are more likely to appear.
        self.use_noun_heuristic = u'名詞' in self._pos2id

    def wakati(self, text, lower=False):
        """Word segmentation function. Return the segmented words.

        args:
            - text (str): An input sentence.
            - lower (bool): If lower is True, all uppercase characters in a list \
                            of the words are converted into lowercase characters.

        return:
            - words (list): A list of the words.
        """
        text = utils.preprocess(text)
        lower_text = text.lower()
        feats = utils.feature_extraction(text=lower_text,
                                         uni2id=self._uni2id,
                                         bi2id=self._bi2id,
                                         dictionary=self._word2id,
                                         window_size=self._hp['WINDOW_SIZE'])
        obs = self._model.encode_ws(feats)
        obs = [ob.npvalue() for ob in obs]
        tags = utils.np_viterbi(self._model.trans_array, obs)
        # A word can be recognized as a single word forcibly.
        # Overwrite the BMES tags inside each forced-word span, then repair the
        # tags adjacent to the span so the whole sequence stays consistent.
        if self.pattern:
            for match in self.pattern.finditer(text):
                span_s, span_e = match.span()
                if (span_e - span_s) == 1:
                    tags[span_s:span_e] = [3]  # 3 is SINGLE tag
                else:
                    tags[span_s:span_e] = [0] + [1] * ((span_e - span_s) - 2) + [2]
                if span_s != 0:
                    previous_tag = tags[span_s - 1]
                    if previous_tag == 0:      # 0 is BEGIN tag
                        tags[span_s - 1] = 3   # 3 is SINGLE tag
                    elif previous_tag == 1:    # 1 is MIDDLE tag
                        tags[span_s - 1] = 2   # 2 is END tag
                if span_e != len(text):
                    next_tag = tags[span_e]
                    if next_tag == 1:          # 1 is MIDDLE tag
                        tags[span_e] = 0       # 0 is BEGIN tag
                    elif next_tag == 2:        # 2 is END tag
                        tags[span_e] = 3       # 3 is SINGLE tag
        if lower is True:
            words = utils.segmenter_for_bmes(lower_text, tags)
        else:
            words = utils.segmenter_for_bmes(text, tags)
        return words

    def _postagging(self, words, lower=False):
        # Internal POS tagging over an already-segmented word list.
        if lower is True:
            words = [w.lower() for w in words]
        wids = utils.conv_tokens_to_ids(words, self._word2id)
        cids = [utils.conv_tokens_to_ids([c for c in w], self._uni2id) for w in words]
        # Candidate POS ids per word; 0 is the unknown-POS placeholder.
        tids = []
        for w in words:
            w2p = set(self._word2postags.get(w, [0]))
            if self.use_noun_heuristic and w.isalnum():
                # Alphanumeric out-of-vocabulary words are nudged toward the
                # noun (名詞) tag instead of staying fully unknown.
                if 0 in w2p:
                    w2p.remove(0)
                w2p.add(self._pos2id[u'名詞'])
            tids.append(list(w2p))
        X = [cids, wids, tids]
        postags = [self._id2pos[pid] for pid in self._model.POStagging(X)]
        return postags

    def postagging(self, words, lower=False):
        """ Return the words with POS-tags of the given words.

        args:
            - words (list): Input words.
            - lower (bool): If lower is True, all uppercase characters in a list \
                            of the words are converted into lowercase characters.

        return:
            - object : The object of the words with POS-tags.
        """
        return self.decode(words, lower)

    def decode(self, words, lower=False):
        """ Return the words with tags of the given words.

        args:
            - words (list): Input words.
            - lower (bool, optional): If lower is True, all uppercase characters in a list \
                                      of the words are converted into lowercase characters.

        return:
            - object : The object of the words with tags.
        """
        if not isinstance(words, list):
            raise AssertionError("Please input a list of words.")
        # Half-width and full-width spaces must keep their trailing whitespace,
        # so they bypass the rstrip step of the normal preprocessing.
        words = [utils.preprocess_without_rstrip(w) if w == " " or w == "　"
                 else utils.preprocess(w) for w in words]
        postags = self._postagging(words, lower)
        return postags

    def tagging(self, text, lower=False):
        """ Return the words with POS-tags of the given sentence.

        args:
            - text (str): An input sentence.
            - lower (bool): If lower is True, all uppercase characters in a list \
                            of the words are converted into lowercase characters.

        return:
            - object : The object of the words with POS-tags.
        """
        return self._Token(text, lower, self.wakati, self._postagging)

    def filter(self, text, lower=False, filter_postags=None):
        """Return the filtered words with POS-tags of the given sentence.

        args:
            - text (str): An input sentence.
            - lower (bool): If lower is True, all uppercase characters in a list \
                            of the words are converted into lowercase characters.
            - filter_postags (list): Filtering the word with the POS-tag in \
                                     filter_postags from a text.

        return:
            - object : The object of the words with POS-tags.
        """
        if filter_postags is None:
            filter_postags = []
        words = []
        postags = []
        tokens = self.tagging(text, lower)
        for word, postag in zip(tokens.words, tokens.postags):
            if postag not in filter_postags:
                words.append(word)
                postags.append(postag)
        return self._Token(text, lower, self.wakati, self._postagging,
                           _words=words, _postags=postags)

    def extract(self, text, lower=False, extract_postags=None):
        """Return the extracted words with POS-tags of the given sentence.

        args:
            - text (str): An input sentence.
            - lower (bool): If lower is True, all uppercase characters in a list \
                            of the words are converted into lowercase characters.
            - extract_postags (list): Extracting the word with the POS-tag in \
                                      extract_postags from a text.

        return:
            - object : The object of the words with POS-tags.
        """
        if extract_postags is None:
            extract_postags = []
        words = []
        postags = []
        tokens = self.tagging(text, lower)
        for word, postag in zip(tokens.words, tokens.postags):
            if postag in extract_postags:
                words.append(word)
                postags.append(postag)
        return self._Token(text, lower, self.wakati, self._postagging,
                           _words=words, _postags=postags)

    class _Token(object):
        # Lazily-evaluated container for a tagged sentence: words and POS tags
        # are computed on first access via the supplied callables, unless they
        # were precomputed and passed in (_words/_postags).
        def __init__(self, text, lower, wakati, postagging, _words=None, _postags=None):
            self.text = text
            self.__lower = lower
            self.__words = _words
            self.__postags = _postags
            self.__wakati = wakati
            self.__postagging = postagging

        @property
        def words(self):
            # Segment the text on first access and cache the result.
            if self.__words is None:
                self.__words = self.__wakati(self.text, self.__lower)
            return self.__words

        @property
        def postags(self):
            # POS-tag the (possibly cached) words on first access.
            if self.__postags is None:
                self.__postags = self.__postagging(self.words, self.__lower)
            return self.__postags

        def __str__(self):
            return ' '.join([w + '/' + p for w, p in zip(self.words, self.postags)])