Word2Vec (5): Implementing CBOW with Hierarchical Softmax in PyTorch

CBOW with Hierarchical Softmax

The idea of CBOW is to use the context words on both sides to predict the center word in the middle.

  • $V$: the vocabulary size
  • $N$ : the embedding dimension
  • $W$: the input-side matrix, of shape $V \times N$
    • each row is an $N$-dimensional word vector
    • $\text{v}_{w_i}$ is the representation of the input word $w_i$
  • $W'$: the output-side matrix, of shape $N \times V$
    • $\text{v}'_j$ denotes the $j$-th column vector of $W'$
    • under Hierarchical Softmax, each column of $W'$ is the vector of a non-leaf node of the Huffman tree rather than of a leaf node, so the column vector $\text{v}'_j$ does not correspond directly to a word $w_i$ (see the sketch below)
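In the PyTorch code later in this post, $W$ and $W'$ become two nn.Embedding layers. A minimal preview of the shapes, with toy sizes (the actual implementation below adds one extra padding row to syn1):

import torch.nn as nn

V, N = 10000, 128          # vocabulary size and embedding dimension (toy values)
syn0 = nn.Embedding(V, N)  # input side W: row i is v_{w_i}
syn1 = nn.Embedding(V, N)  # output side W': its rows play the role of the columns v'_j,
                           # one per non-leaf node of the Huffman tree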

Objective Function

Huffman Tree

Let $w_{I,j}$ denote the $j$-th context word of the input, and $w_O$ the target center word.

Then the objective function under Hierarchical Softmax is:
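For one training pair, a standard way to write it (symbols as defined in the list below) is

$$E = -\log p(w_O \mid w_{I,1}, \dots, w_{I,C}) = -\sum_{j=1}^{L(w_O)-1} \log \sigma\big( [\cdot] \, {\text{v}'_j}^{\top} \text{h} \big)$$

where the sum runs over the non-leaf nodes on the path from the root down to the leaf of $w_O$.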

  • $L(w_O) - 1$ is the number of non-leaf nodes on the path from the root to the leaf node of $w_O$ in the Huffman tree, i.e. the number of branch decisions
  • $[\cdot]$ denotes the branch decision in the Huffman tree
    • $[\cdot] = 1$ means turn left
    • $[\cdot] = -1$ means turn right
  • $\text{h} = \frac {1}{C} \sum^{C}_{j=1}\text{v}_{w_{I,j}}$ is the average of all context word vectors $\text{v}_{w_{I,j}}$

For the detailed derivation, see Word2Vec (2): The Math Behind Hierarchical Softmax.

With Hierarchical Softmax, because the Huffman tree is a full binary tree, the time complexity drops from $O(|V|)$ to $O(\log_2|V|)$: for $|V| = 100{,}000$, that is about 17 binary decisions per prediction instead of a 100,000-way softmax.

PyTorch CBOW with Hierarchical Softmax

Building the Huffman Tree

The Huffman tree construction process:


class HuffmanNode:
    '''
    Minimal node definition, inferred from how HuffmanTree below uses it
    (the original post folds this class away).
    '''
    def __init__(self, is_leaf, value=None, fre=0):
        self.is_leaf = is_leaf
        self.value = value      # word id (meaningful for leaf nodes only)
        self.fre = fre          # frequency
        self.left = None
        self.right = None
        self.code = []          # top-down Huffman code of this leaf
        self.code_len = 0
        self.node_path = []     # non-leaf node indices from root to this leaf


class HuffmanTree:
    def __init__(self, fre_dict):
        self.root = None
        freq_dict = sorted(fre_dict.items(), key=lambda x: x[1], reverse=True)
        self.vocab_size = len(freq_dict)
        self.node_dict = {}
        self._build_tree(freq_dict)

    def _build_tree(self, freq_dict):
        '''
        freq_dict is in descending order.
        node_list has two parts: [leaf nodes :: non-leaf nodes];
        the leaf nodes are sorted by frequency in descending order.
        '''
        node_list = [HuffmanNode(is_leaf=True, value=w, fre=fre) for w, fre in freq_dict]  # leaf nodes
        node_list += [HuffmanNode(is_leaf=False, fre=1e10) for i in range(self.vocab_size)]  # non-leaf nodes

        parentNode = [0] * (self.vocab_size * 2)  # only 2 * vocab_size - 2 entries are used
        binary = [0] * (self.vocab_size * 2)      # records turning left or turning right

        # pos1 points to the leaf node currently being processed (left part of node_list);
        # pos2 points to the non-leaf node currently being processed (right part of node_list)
        pos1 = self.vocab_size - 1
        pos2 = self.vocab_size

        # Each iteration picks the two lowest-frequency nodes from node_list:
        # the first pick goes to min1i, the second to min2i,
        # so min2i's frequency is always >= min1i's.
        min1i = 0
        min2i = 0

        # the main loop of building the Huffman tree
        for a in range(self.vocab_size - 1):
            # first pick goes to min1i
            if pos1 >= 0:
                if node_list[pos1].fre < node_list[pos2].fre:
                    min1i = pos1
                    pos1 -= 1
                else:
                    min1i = pos2
                    pos2 += 1
            else:
                min1i = pos2
                pos2 += 1

            # second pick goes to min2i
            if pos1 >= 0:
                if node_list[pos1].fre < node_list[pos2].fre:
                    min2i = pos1
                    pos1 -= 1
                else:
                    min2i = pos2
                    pos2 += 1
            else:
                min2i = pos2
                pos2 += 1

            # fill in the new non-leaf node
            node_list[self.vocab_size + a].fre = node_list[min1i].fre + node_list[min2i].fre
            node_list[self.vocab_size + a].left = node_list[min1i]
            node_list[self.vocab_size + a].right = node_list[min2i]

            # the parent node is always a non-leaf node;
            # min1i becomes the left child and min2i the right child
            parentNode[min1i] = self.vocab_size + a  # max index = 2 * vocab_size - 2
            parentNode[min2i] = self.vocab_size + a
            binary[min2i] = 1

        # generate the Huffman code of each leaf node
        for a in range(self.vocab_size):
            b = a
            code = []
            point = []

            # backtrace the path from the current node up to the root (bottom up);
            # the root node's index in node_list is 2 * vocab_size - 2
            while b != self.vocab_size * 2 - 2:
                code.append(binary[b])
                b = parentNode[b]
                # point records the parent (non-leaf) node indices from the leaf
                # up to the root; it has the same length as code
                point.append(b)

            # the Huffman code should be top down, so reverse it
            node_list[a].code_len = len(code)
            node_list[a].code = list(reversed(code))

            # 1. Record the path from root to leaf node (top down).
            # 2. The actual index is shifted by self.vocab_size, because we want
            #    indices starting from zero over the non-leaf nodes only.
            # 3. For a full binary tree, the number of non-leaf nodes always equals
            #    vocab_size - 1. The root's index in node_list is 2 * vocab_size - 2,
            #    so after shifting by vocab_size the root's actual index is vocab_size - 2.
            node_list[a].node_path = list(reversed([p - self.vocab_size for p in point]))

            self.node_dict[node_list[a].value] = node_list[a]

        self.root = node_list[2 * self.vocab_size - 2]


The tree-building process follows the C code of Word2Vec author Tomas Mikolov. The idea:

  1. Build an array whose left half holds the leaf nodes and whose right half holds the non-leaf nodes
    • the leaf nodes are sorted by frequency in descending order
  2. Build the tree bottom up
    • fill in non-leaf nodes rightward from the middle of the array
    • each iteration picks the two lowest-frequency nodes among the remaining leaf nodes and the already-filled non-leaf nodes, and attaches them as children of the current non-leaf node (see the usage sketch below)
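A quick sanity check of the tree before wiring it into the model; a minimal sketch with hypothetical toy frequencies, keyed by integer word id because the softmax layer below looks nodes up by id:

fre_dict = {0: 50, 1: 30, 2: 12, 3: 8}  # word id -> frequency (hypothetical)
tree = HuffmanTree(fre_dict)

node = tree.node_dict[3]  # the rarest word gets the longest code
print(node.code)          # top-down Huffman code, e.g. [0, 0, 0]
print(node.node_path)     # non-leaf node indices on the root-to-leaf path, e.g. [2, 1, 0]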

Hierarchical Softmax

Implementing Hierarchical Softmax with the Huffman tree:

import torch
import torch.nn as nn
import torch.nn.functional as F


class HierarchicalSoftmaxLayer(nn.Module):
    def __init__(self, vocab_size, embedding_dim, freq_dict):
        super().__init__()
        # in the word2vec C implementation, syn1 is initialized to all zeros
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        # one extra row (index vocab_size) acts as an all-zero padding vector
        self.syn1 = nn.Embedding(
            num_embeddings=vocab_size + 1,
            embedding_dim=embedding_dim,
            padding_idx=vocab_size
        )
        torch.nn.init.constant_(self.syn1.weight.data, val=0)
        self.huffman_tree = HuffmanTree(freq_dict)

    def forward(self, neu1, target):
        # neu1: [b_size, embedding_dim]
        # target: [b_size, 1]

        # turns: [b_size, max_code_len_in_batch]
        # paths: [b_size, max_code_len_in_batch]
        turns, paths = self._get_turns_and_paths(target)
        paths_emb = self.syn1(paths)  # [b_size, max_code_len_in_batch, embedding_dim]

        loss = -F.logsigmoid(
            (turns.unsqueeze(2) * paths_emb * neu1.unsqueeze(1)).sum(2)).sum(1).mean()
        return loss

    def _get_turns_and_paths(self, target):
        turns = []  # turn right (1) or turn left (-1) in the Huffman tree
        paths = []
        max_len = 0

        for n in target:
            n = n.item()
            node = self.huffman_tree.node_dict[n]

            code = target.new_tensor(node.code).int()  # in code, a left turn is 0 and a right turn is 1
            turn = torch.where(code == 1, code, -torch.ones_like(code))

            turns.append(turn)
            paths.append(target.new_tensor(node.node_path))

            if node.code_len > max_len:
                max_len = node.code_len

        # padded positions point at the zero padding row of syn1, so they add only
        # a constant log sigma(0) per slot to the loss and contribute no gradient
        turns = [F.pad(t, pad=(0, max_len - len(t)), mode='constant', value=0) for t in turns]
        paths = [F.pad(p, pad=(0, max_len - p.shape[0]), mode='constant', value=self.vocab_size) for p in paths]
        return torch.stack(turns).int(), torch.stack(paths).long()
  • syn1 corresponds to $W'$; its vectors are the vectors of the Huffman tree's non-leaf nodes
    • in this implementation it is the row vectors of $W'$ that carry meaning
  • neu1, i.e. $\text{h}$, is the output of the hidden layer
  • target is the center word $w_O$
  • in the function _get_turns_and_paths
    • the implementation uses -1 for turn left and 1 for turn right; either convention works as long as the two signs are opposite, because for binary classification
      • $p(\text{true}) = \sigma(x)$ ⇒ $p(\text{false}) = 1- \sigma(x) = \sigma(-x)$
      • only the sign inside $\sigma$ flips (a quick check follows below)
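A quick numeric check of that identity, in plain PyTorch:

import torch

x = torch.tensor(0.7)
# p(false) = 1 - sigmoid(x) = sigmoid(-x), so flipping the turn sign
# is the same as swapping the two branch probabilities
assert torch.isclose(1 - torch.sigmoid(x), torch.sigmoid(-x))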

CBOW + Hierarchical Softmax

class CBOWHierarchicalSoftmax(nn.Module):
    def __init__(self, vocab_size, embedding_dim, fre_dict):
        super().__init__()
        self.syn0 = nn.Embedding(vocab_size, embedding_dim)
        self.hs = HierarchicalSoftmaxLayer(vocab_size, embedding_dim, fre_dict)
        torch.nn.init.xavier_uniform_(self.syn0.weight.data)

    def forward(self, context, target):
        # context: [b_size, 2 * window_size]
        # target: [b_size]
        neu1 = self.syn0(context.long()).mean(dim=1)  # [b_size, embedding_dim]
        loss = self.hs(neu1, target.long())
        return loss
  • neu1 is the average of the context word vectors
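A minimal smoke test of the whole model, with hypothetical toy sizes (the frequency dict keys must be the same integer ids used in context and target):

import torch

fre_dict = {0: 50, 1: 30, 2: 12, 3: 8}  # word id -> frequency (hypothetical)
model = CBOWHierarchicalSoftmax(vocab_size=4, embedding_dim=16, fre_dict=fre_dict)

context = torch.randint(0, 4, (2, 4))  # [b_size=2, 2 * window_size=4]
target = torch.randint(0, 4, (2,))     # [b_size=2]
loss = model(context, target)
loss.backward()
print(loss.item())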

Training

The training loop itself is omitted; if you are interested, see the notebook (a minimal sketch follows below):

seed9D/hands-on-machine-learning
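For completeness, a minimal training-loop sketch; the dataloader yielding (context, target) index batches is hypothetical, and the real details live in the notebook:

import torch

model = CBOWHierarchicalSoftmax(vocab_size, embedding_dim=100, fre_dict=fre_dict)
optimizer = torch.optim.SGD(model.parameters(), lr=0.025)

for epoch in range(5):
    for context, target in dataloader:  # hypothetical batches of word indices
        optimizer.zero_grad()
        loss = model(context, target)
        loss.backward()
        optimizer.step()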

Evaluation

The training corpus is the Bible; let's look at the nearest neighbors of jesus and christ.
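The CosineSimilarity helper is defined in the notebook; a minimal sketch of what it might look like, with names inferred from the calls below (the actual implementation may differ):

import numpy as np

class CosineSimilarity:
    def __init__(self, embedding, idx_to_word, word_to_idx):
        # embedding: [V, N] array of trained word vectors (syn0), L2-normalized here
        self.embedding = embedding / np.linalg.norm(embedding, axis=1, keepdims=True)
        self.idx_to_word = idx_to_word
        self.word_to_idx = word_to_idx

    def get_synonym(self, word, topk=10):
        v = self.embedding[self.word_to_idx[word]]
        sims = self.embedding @ v  # cosine similarity against every word
        return [(self.idx_to_word[i], sims[i]) for i in np.argsort(-sims)[:topk]]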

In:
cosinSim = CosineSimilarity(w2v_embedding, idx_to_word, word_to_idx)
cosinSim.get_synonym('christ')

Out:
[('christ', 1.0),
('hope', 0.78780156),
('gospel', 0.7656436),
('jesus', 0.74575657),
('faith', 0.7190881),
('godliness', 0.7005944),
('offences', 0.70045626),
('grace', 0.6946964),
('dear', 0.666232),
('willing', 0.66131693)]

In:
cosinSim.get_synonym('jesus')

Out:
[('jesus', 0.9999999),
('gospel', 0.8051339),
('grace', 0.75879383),
('church', 0.7542972),
('christ', 0.74575657),
('manifest', 0.7415799),
('believed', 0.7215627),
('faith', 0.7198993),
('godliness', 0.7091305),
('john', 0.7015951)]

