Word2Vec (6): PyTorch Implementation of Skipgram with Negative Sampling

Skipgram with Negative Sampling

The idea of skipgram is to use the center word to predict the context words on both sides. Notation:

  • $V$: the vocabulary size
  • $N$: the embedding dimension
  • $W$: the input-side matrix of size $V \times N$
    • each row is an $N$-dimensional vector
    • $\text{v}_{w_i}$ is the representation of the input word $w_i$
  • $W'$: the output-side matrix of size $N \times V$
    • each column is an $N$-dimensional vector
    • $\text{v}'_{w_j}$ is the $j$-th column of $W'$, representing $w_j$

Objective Function

Let $w_I$ denote the input center word, and $w_{O,j}$ the $j$-th target context word.

The objective function under Negative Sampling is then

$$E = -\log \sigma(\text{v}^\top_{w_I} \text{v}'_{w_{O,j}}) - \sum^{M}_{\substack{i=1 \\ \tilde{w}_i \sim Q}} \log \sigma(-\text{v}^\top_{w_I} \text{v}'_{\tilde{w}_i})$$

  • $\tilde{w}_i$ is a word sampled from the distribution $Q$
  • $M$ is the number of words $\tilde{w}$ sampled from $Q$

The first term is the loss contributed by the input center word $w_I$ and the target context word $w_{O,j}$.

The second term is the loss contributed by the negative samples, $M$ words in total.

If you are interested in the derivation from the softmax objective to NEG, see Word2Vec (3): The Math Behind Negative Sampling.

Negative Sample (NEG)

The goal is to sample words $\tilde{w}$ from a distribution $Q$.

In practice, the probability $P(w_i)$ of sampling word $w_i$ from the vocabulary $V$ is

$$P(w_i) = \frac{f(w_i)^{\alpha}}{\sum^{V}_{j=1} f(w_j)^{\alpha}}$$

  • $f(w_i)$ is the frequency count of $w_i$ in the corpus
  • $\alpha$ is a factor, usually set to $0.75$; it increases the probability of less frequent words and decreases the probability of more frequent words

Each word $w_i$ thus has a probability $P(w_i)$ of being sampled, and the goal is to draw $M$ words from $P(w)$ as the negative terms.

A common implementation found online is to call

np.random.multinomial(sample_size, pvals)

This presumably samples words via the inverse CDF; calling it once for every training example is not very efficient.
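Note that np.random.multinomial returns one count per vocabulary word rather than the samples themselves, so turning a draw into negative-sample indices might look like the following minimal sketch (the toy frequency counts are made up):

import numpy as np

# toy frequency counts raised to alpha = 0.75, then normalized to P(w)
pvals = np.array([2, 6, 10, 5, 7], dtype=float) ** 0.75
pvals /= pvals.sum()

counts = np.random.multinomial(6, pvals)                  # how many times each word was drawn
negative_idx = np.repeat(np.arange(len(pvals)), counts)   # expand the counts into word indices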

Word2Vec author Tomas Mikolov used an approximation in his C code. The idea is that, given an extremely large number of slots ($M = 10^8$), a word with a higher probability, i.e., a higher frequency, takes up more shares of those $M$ slots.

For example, yellow has the largest probability, so it should take the most shares in a table of $M = 30$:

  • $P(\text{blue}) = \frac{2}{30}$
  • $P(\text{green}) = \frac{6}{30}$
  • $P(\text{yellow}) = \frac{10}{30}$
  • $P(\text{red}) = \frac{5}{30}$
  • $P(\text{gray}) = \frac{7}{30}$

So a sufficiently large table ($M = 10^8$) is built in advance, with each word given shares proportional to its frequency. When a word actually needs to be sampled, simply draw an index $m$ uniformly at random over $M$; the word at index $m$ is the sampled word $\tilde{w}$. It is a space-for-time trade-off.
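A small sketch of the table idea, using the toy numbers above (not the original implementation):

import numpy as np

# toy table of size M = 30: each word gets shares proportional to its probability
words  = ['blue', 'green', 'yellow', 'red', 'gray']
shares = [2, 6, 10, 5, 7]
table  = np.repeat(words, shares)          # len(table) == 30

m = np.random.randint(0, len(table))       # uniform random index into the table
sampled_word = table[m]                    # the sampled negative word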

Seeing is Believing

In a quick test with 10,000 iterations, drawing 6 negative samples each time, Tomas Mikolov's approximation is more efficient, and by an overwhelming margin.

But when many words are sampled in a single call, multinomial is more efficient, probably because NumPy parallelizes the work internally.
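The benchmark code itself is not reproduced here; a rough sketch of such a comparison could look like the following (the vocabulary size, frequencies, and table size are all assumptions):

import time
import numpy as np

rng = np.random.default_rng(0)
freqs = rng.integers(1, 1000, size=10_000).astype(float)   # assumed word frequencies
pvals = freqs ** 0.75
pvals /= pvals.sum()

# pre-built unigram table: each word index gets shares proportional to its probability
table_size = 1_000_000
table = np.repeat(np.arange(len(pvals)), np.round(pvals * table_size).astype(int))

n_iter, n_neg = 10_000, 6

start = time.perf_counter()
for _ in range(n_iter):
    counts = np.random.multinomial(n_neg, pvals)            # counts per word
    neg = np.repeat(np.arange(len(pvals)), counts)          # expand to word indices
t_multinomial = time.perf_counter() - start

start = time.perf_counter()
for _ in range(n_iter):
    neg = table[np.random.randint(0, len(table), n_neg)]    # uniform lookup into the table
t_table = time.perf_counter() - start

print(f"multinomial: {t_multinomial:.2f}s, table lookup: {t_table:.2f}s")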

Pytorch Skipgram with Negative Sampling

Negative Sample

import itertools
from collections import Counter

import numpy as np


class NegativeSampler:
    def __init__(self, corpus, sample_ratio=0.75):
        self.sample_ratio = sample_ratio
        self.sample_table = self.__build_sample_table(corpus)
        self.table_size = len(self.sample_table)

    def __build_sample_table(self, corpus):
        # frequency count f(w) for every word in the corpus
        counter = dict(Counter(list(itertools.chain.from_iterable(corpus))))
        words = np.array(list(counter.keys()))
        # f(w)^alpha, then normalize to get P(w)
        probs = np.power(np.array(list(counter.values())), self.sample_ratio)
        normalizing_factor = probs.sum()
        probs = np.divide(probs, normalizing_factor)

        sample_table = []

        table_size = 1e8
        word_share_list = np.round(probs * table_size)
        '''
        the higher prob, the more shares in sample_table
        '''
        for w_idx, w_share in enumerate(word_share_list):
            sample_table += [words[w_idx]] * int(w_share)

        # sample_table = np.array(sample_table)  # too slow
        return sample_table

    def generate(self, sample_size=6):
        # uniform random indices into the table; the words at those indices are the negatives
        negative_samples = [self.sample_table[idx] for idx in np.random.randint(0, self.table_size, sample_size)]
        return np.array(negative_samples)

In:

sampler = NegativeSampler(corpus)
sampler.generate()

Out:

array(['visiting', 'defiled', 'thieves', 'beyond', 'lord', 'fill'],
dtype='<U18')

Skipgram + NEG

import torch
import torch.nn as nn
import torch.nn.functional as F


class SkipGramNEG(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.syn0 = nn.Embedding(vocab_size, embedding_dim)      # input side W,  |V| x |N|
        self.neg_syn1 = nn.Embedding(vocab_size, embedding_dim)  # output side W', |V| x |N|
        torch.nn.init.constant_(self.neg_syn1.weight.data, val=0)

    def forward(self, center: torch.Tensor, context: torch.Tensor, negative_samples: torch.Tensor):
        # center : [b_size, 1]
        # context: [b_size, 1]
        # negative_samples: [b_size, negative_sample_num]
        embd_center = self.syn0(center)                          # [b_size, 1, embedding_dim]
        embd_context = self.neg_syn1(context)                    # [b_size, 1, embedding_dim]
        embd_negative_sample = self.neg_syn1(negative_samples)   # [b_size, negative_sample_num, embedding_dim]

        # positive pair: dot product of center and context vectors
        prod_p = (embd_center * embd_context).sum(dim=2).squeeze()   # [b_size]
        loss_p = F.logsigmoid(prod_p).mean()                         # scalar

        # negative pairs: dot product of center vector with every sampled word
        prod_n = (embd_center * embd_negative_sample).sum(dim=2)     # [b_size, negative_sample_num]
        loss_n = F.logsigmoid(-prod_n).sum(dim=1).mean()             # scalar

        return -(loss_p + loss_n)
  • syn0 corresponds to the input-side matrix $W$
  • neg_syn1 corresponds to the output-side matrix $W'$
    • Tomas Mikolov initializes it to 0 in his word2vec C code
  • loss function
    • loss_p corresponds to $\log \sigma(\text{v}^\top_{w_I} \text{v}'_{w_{O,j}})$
    • loss_n corresponds to $\sum^M_{\substack{i=1 \\ \tilde{w}_i \sim Q}} \log \sigma(-\text{v}^\top_{w_I} \text{v}'_{\tilde{w}_i})$
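A quick smoke test of the forward pass, with made-up sizes (batch of 4, vocabulary of 100, 6 negatives per pair), just to confirm the tensor shapes:

import torch

model = SkipGramNEG(vocab_size=100, embedding_dim=16)
center    = torch.randint(0, 100, (4, 1))   # [b_size, 1]
context   = torch.randint(0, 100, (4, 1))   # [b_size, 1]
negatives = torch.randint(0, 100, (4, 6))   # [b_size, negative_sample_num]

loss = model(center, context, negatives)    # scalar
loss.backward()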

Training Skipgram + Negative Sampling

The training procedure is omitted here; see the notebook:

seed9D/hands-on-machine-learning
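The notebook's loop is not reproduced here, but a rough sketch could look like the following; training_pairs (center/context index pairs), word_to_idx, sampler, and the hyperparameters are assumptions, not the notebook's exact code:

import torch
from torch import optim

model = SkipGramNEG(vocab_size=len(word_to_idx), embedding_dim=100)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(5):
    total_loss = 0.0
    for center_idx, context_idx in training_pairs:
        center  = torch.LongTensor([[center_idx]])    # [1, 1]
        context = torch.LongTensor([[context_idx]])   # [1, 1]
        # sampler.generate() returns words, so map them back to indices
        negatives = torch.LongTensor([[word_to_idx[w] for w in sampler.generate(6)]])  # [1, 6]

        optimizer.zero_grad()
        loss = model(center, context, negatives)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"epoch {epoch}: avg loss {total_loss / len(training_pairs):.4f}")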

Evaluation

Retrieving the embeddings

Simply average syn0 and neg_syn1:

syn0 = model.syn0.weight.data
neg_syn1 = model.neg_syn1.weight.data

# average the input-side and output-side vectors, then L2-normalize each row
w2v_embedding = (syn0 + neg_syn1) / 2
w2v_embedding = w2v_embedding.numpy()
l2norm = np.linalg.norm(w2v_embedding, 2, axis=1, keepdims=True)
w2v_embedding = w2v_embedding / l2norm

Cosine similarity

import numpy as np


class CosineSimilarity:
    def __init__(self, word_embedding, idx_to_word_dict, word_to_idx_dict):
        self.word_embedding = word_embedding  # already L2-normalized
        self.idx_to_word_dict = idx_to_word_dict
        self.word_to_idx_dict = word_to_idx_dict

    def get_synonym(self, word, topK=10):
        idx = self.word_to_idx_dict[word]
        embed = self.word_embedding[idx]

        # since the embeddings are normalized, the dot product is the cosine similarity
        cos_similarity = self.word_embedding @ embed

        topK_index = np.argsort(-cos_similarity)[:topK]
        pairs = []
        for i in topK_index:
            w = self.idx_to_word_dict[i]
            pairs.append((w, cos_similarity[i]))
        return pairs

The training corpus is the Bible; let's look at the words closest to jesus and christ.

In:

cosinSim = CosineSimilarity(w2v_embedding, idx_to_word, word_to_idx)
cosinSim.get_synonym('christ')

Out:

[('christ', 1.0),
('jesus', 0.7170907),
('gospel', 0.4621805),
('peter', 0.39412546),
('disciples', 0.3873747),
('noise', 0.28152165),
('asleep', 0.26372147),
('taught', 0.2422184),
('zarhites', 0.24168596),
('nobles', 0.23950878)]

In:

cosinSim.get_synonym('jesus')

Out:

[('jesus', 1.0),
('christ', 0.7170907),
('gospel', 0.5360588),
('peter', 0.3603956),
('disciples', 0.3460646),
('church', 0.2755898),
('passed', 0.24744174),
('noise', 0.23768528),
('preach', 0.23454829),
('send', 0.2337867)]

Reference

Word2Vec (6): PyTorch Implementation of Skipgram with Negative Sampling

https://seed9d.github.io/Pytorch-Implement-Skipgram-with-Negative-Sampling/
