Tutorial

Train a Japanese word segmentation and POS tagging model for Universal Dependencies

This tutorial shows how to train a joint word segmentation and POS tagging model on the Japanese Universal Dependencies treebank. Through this tutorial, you will learn how to build your own sequence labeling model with nagisa.

Download the dataset

Before getting started, run $ pip install nagisa to install the nagisa library. After installing it, download the Japanese UD treebank from the UD_Japanese-GSD repository.

mkdir work
cd work
pip install nagisa
git clone https://github.com/UniversalDependencies/UD_Japanese-GSD

Preprocess the dataset and train a model

First, convert the downloaded data into the input format for nagisa. The train/dev/test files are in TSV format: each line contains a word and its tag separated by a tab (word \t tag), and an EOS line is inserted between sentences. Refer to the tiny sample datasets.
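For example, the sentence used in the prediction step later in this tutorial (福岡・博多の観光情報) would be encoded as:

福岡	PROPN
・	SYM
博多	PROPN
の	ADP
観光	NOUN
情報	NOUN
EOS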

Next, train a joint word segmentation and POS tagging model by using the nagisa.fit() function. After training finishes, the three model files (ja_gsd_ud.vocabs, ja_gsd_ud.params, ja_gsd_ud.hp) are saved in the current directory.

tutorial_train_ud.py
import nagisa

def write_file(fn_in, fn_out):
    with open(fn_in, "r") as f:
        data = []
        words = []
        postags = []
        for line in f:
            line = line.strip()

            if len(line) > 0:
                prefix = line[0]
                # Skip CoNLL-U comment lines.
                if prefix != "#":
                    tokens = line.split("\t")
                    word = tokens[1]    # FORM column
                    postag = tokens[3]  # UPOS column
                    words.append(word)
                    postags.append(postag)

            else:
                # A blank line marks the end of a sentence.
                if (len(words) > 0) and (len(postags) > 0):
                    data.append([words, postags])
                    words = []
                    postags = []

    with open(fn_out, "w") as f:
        for words, postags in data:
            for word, postag in zip(words, postags):
                f.write("\t".join([word, postag]) + "\n")
            f.write("EOS\n")


if __name__ == "__main__":
    # files
    fn_in_train = "UD_Japanese-GSD/ja_gsd-ud-train.conllu"
    fn_in_dev = "UD_Japanese-GSD/ja_gsd-ud-dev.conllu"
    fn_in_test = "UD_Japanese-GSD/ja_gsd-ud-test.conllu"

    fn_out_train = "ja_gsd_ud.train"
    fn_out_dev = "ja_gsd_ud.dev"
    fn_out_test = "ja_gsd_ud.test"

    fn_out_model = "ja_gsd_ud"

    # write files for nagisa
    write_file(fn_in_train, fn_out_train)
    write_file(fn_in_dev, fn_out_dev)
    write_file(fn_in_test, fn_out_test)

    # start training
    nagisa.fit(train_file=fn_out_train, dev_file=fn_out_dev,
               test_file=fn_out_test, model_name=fn_out_model)

Run the script with $ python tutorial_train_ud.py. Below is a log of the training process. Each epoch reports the word segmentation F1 score (WS_f1) and the POS tagging F1 score (POS_f1) on the dev and test sets.

[dynet] random seed: 1234
[dynet] allocating memory: 32MB
[dynet] memory allocation done.
[nagisa] LAYERS: 1
[nagisa] THRESHOLD: 3
[nagisa] DECAY: 1
[nagisa] EPOCH: 10
[nagisa] WINDOW_SIZE: 3
[nagisa] DIM_UNI: 32
[nagisa] DIM_BI: 16
[nagisa] DIM_WORD: 16
[nagisa] DIM_CTYPE: 8
[nagisa] DIM_TAGEMB: 16
[nagisa] DIM_HIDDEN: 100
[nagisa] LEARNING_RATE: 0.1
[nagisa] DROPOUT_RATE: 0.3
[nagisa] SEED: 1234
[nagisa] TRAINSET: ja_gsd_ud.train
[nagisa] TESTSET: ja_gsd_ud.test
[nagisa] DEVSET: ja_gsd_ud.dev
[nagisa] DICTIONARY: None
[nagisa] EMBEDDING: None
[nagisa] HYPERPARAMS: ja_gsd_ud.hp
[nagisa] MODEL: ja_gsd_ud.params
[nagisa] VOCAB: ja_gsd_ud.vocabs
[nagisa] EPOCH_MODEL: ja_gsd_ud_epoch.params
[nagisa] NUM_TRAIN: 7133
[nagisa] NUM_TEST: 551
[nagisa] NUM_DEV: 511
[nagisa] VOCAB_SIZE_UNI: 2352
[nagisa] VOCAB_SIZE_BI: 25108
[nagisa] VOCAB_SIZE_WORD: 9143
[nagisa] VOCAB_SIZE_POSTAG: 17
Epoch       LR      Loss    Time_m  DevWS_f1        DevPOS_f1       TestWS_f1       TestPOS_f1
1           0.100   13.37   1.462   91.84           87.75           91.63           87.35
2           0.100   6.280   1.473   92.57           89.67           92.44           89.15
3           0.100   4.961   1.535   93.54           90.98           93.62           90.18
4           0.050   4.256   1.430   92.52           90.19           93.62           90.18
5           0.025   3.200   1.443   93.46           91.06           93.62           90.18
6           0.025   2.581   1.512   93.56           91.49           93.88           91.29
7           0.025   2.379   1.475   93.58           91.50           93.73           91.15
8           0.025   2.218   1.476   93.63           91.57           93.92           91.31
9           0.025   2.122   1.475   93.78           91.63           94.09           91.40
10          0.012   1.985   1.434   93.55           91.39           94.09           91.40

Predict

You can build the tagger simply by passing the three trained model files (ja_gsd_ud.vocabs, ja_gsd_ud.params, ja_gsd_ud.hp) as arguments to nagisa.Tagger().

tutorial_predict_ud.py
import nagisa

if __name__ == "__main__":
    # Build the tagger by loading the trained model files.
    ud_tagger = nagisa.Tagger(vocabs='ja_gsd_ud.vocabs',
                              params='ja_gsd_ud.params',
                              hp='ja_gsd_ud.hp')

    text = "福岡・博多の観光情報"
    words = ud_tagger.tagging(text)
    print(words)
    #> 福岡/PROPN ・/SYM 博多/PROPN の/ADP 観光/NOUN 情報/NOUN
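If you need the tokens and tags as separate lists, a short sketch, assuming the result object exposes the same words and postags attributes as nagisa's default tagger:

# Assumption: the tagging result follows the same interface as
# nagisa's default tagger, with .words and .postags attributes.
print(words.words)    #> ['福岡', '・', '博多', 'の', '観光', '情報']
print(words.postags)  #> ['PROPN', 'SYM', 'PROPN', 'ADP', 'NOUN', 'NOUN']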

Error analysis

By checking a confusion matrix, you can see which tags the model tends to confuse. The following code shows how to create a confusion matrix by comparing the predicted tags with the gold-standard tags.

tutorial_error_analysis_ud.py
import nagisa
import pandas as pd

from sklearn.metrics import confusion_matrix


def load_file(filename):
    X = []
    Y = []
    words = []
    tags = []
    with open(filename, "r") as f:
        for line in f:
            line = line.rstrip()
            if line == "EOS":
                assert len(words) == len(tags)
                X.append(words)
                Y.append(tags)
                words = []
                tags = []
            else:
                line = line.split("\t")
                word = " ".join(line[:-1])
                tag = line[-1]
                words.append(word)
                tags.append(tag)
    return X, Y


def create_confusion_matrix(tagger, X, Y):
    true_cm = []
    pred_cm = []
    label2id = {}
    for i in range(len(X)):
        words = X[i]
        true_tags = Y[i]
        pred_tags = tagger.decode(words)  # decoding

        # Collect only the mispredicted tags.
        if true_tags != pred_tags:
            for true_tag, pred_tag in zip(true_tags, pred_tags):
                if true_tag != pred_tag:
                    if true_tag not in label2id:
                        label2id[true_tag] = len(label2id)

                    if pred_tag not in label2id:
                        label2id[pred_tag] = len(label2id)

                    true_cm.append(label2id[true_tag])
                    pred_cm.append(label2id[pred_tag])

    cm = confusion_matrix(true_cm, pred_cm)
    labels = list(label2id.keys())
    cm_labeled = pd.DataFrame(cm, columns=labels, index=labels)
    return cm_labeled


if __name__ == "__main__":
    # Load the test set.
    test_X, test_Y = load_file("ja_gsd_ud.test")

    # Build the tagger for UD.
    ud_tagger = nagisa.Tagger(vocabs='ja_gsd_ud.vocabs',
                              params='ja_gsd_ud.params',
                              hp='ja_gsd_ud.hp')

    # Create a confusion matrix of the tagger's prediction errors.
    cm_labeled = create_confusion_matrix(ud_tagger, test_X, test_Y)
    print(cm_labeled)

Below is the resulting confusion matrix of the tagger's prediction errors, where rows are gold-standard tags and columns are predicted tags. It shows that the tagger most often mislabels “NOUN” as “PROPN” (101 errors) in the UD_Japanese-GSD dataset.

        AUX  VERB  NOUN  ADV  PRON  PART  PUNCT  SYM  ADJ  PROPN  CCONJ  SCONJ  ADP  NUM  INTJ
AUX      0    16     2    0     0     0      0    0    2      0      0      1   25    0     0
VERB    14     0    23    0     1     0      0    0    2      0      0      1    0    0     0
NOUN     0    12     0    5     1     0      1    0   16    101      0      1    1    2     0
ADV      0     2     8    0     0     1      0    0    2      1      2      0    0    0     0
PRON     0     3     6    1     0     0      0    0    1      0      0      0    0    0     0
PART     1     0     4    0     0     0      0    0    0      0      0      0    0    0     0
PUNCT    0     0     2    0     0     0      0    2    0      0      0      0    0    0     0
SYM      0     0     0    0     0     0      0    0    0      0      0      0    0    1     0
ADJ      8     6    41    3     0     1      0    0    0      4      0      0    0    0     0
PROPN    0     2    65    0     0     0      0    0    0      0      1      0    0    1     0
CCONJ    0     0     1    2     0     0      0    0    0      0      0      0    0    0     0
SCONJ    1     0     1    0     0     0      0    0    0      0      0      0    2    0     0
ADP      4     0     0    0     0     0      0    0    0      0      0      7    0    0     0
NUM      0     0     1    0     0     0      0    0    0      0      0      0    0    0     0
INTJ     0     0     0    1     0     0      0    0    0      0      0      0    0    0     0
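To inspect these errors directly, here is a minimal sketch that reuses test_X, test_Y, and ud_tagger from tutorial_error_analysis_ud.py and prints every token whose gold tag is NOUN but which the tagger predicts as PROPN:

# A minimal sketch, reusing test_X, test_Y, and ud_tagger from
# tutorial_error_analysis_ud.py: print each token whose gold tag
# is NOUN but which the tagger predicts as PROPN.
for words, true_tags in zip(test_X, test_Y):
    pred_tags = ud_tagger.decode(words)
    for word, true_tag, pred_tag in zip(words, true_tags, pred_tags):
        if true_tag == "NOUN" and pred_tag == "PROPN":
            print(word)

Listing the mislabeled tokens this way can reveal whether the NOUN/PROPN confusions concentrate on particular words, which is a useful starting point for improving the training data or the model.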