Commit 3b2e470a authored by Brian McMahan's avatar Brian McMahan
Browse files

small fixes; seed settings and readme instructions

parent c8ff14f7
%% Cell type:code id: tags:
``` python
from annoy import AnnoyIndex
import numpy as np
import torch
from tqdm import tqdm_notebook
from argparse import Namespace
```
%% Cell type:code id: tags:
``` python
# Notebook-wide settings gathered into one namespace object.
args = Namespace(glove_filename='../data/glove.6B.100d.txt')
```
%% Cell type:code id: tags:
``` python
def load_word_vectors(filename):
    """Load GloVe-format word vectors from a space-separated text file.

    Each line has the form ``word v1 v2 ... vN``.

    Args:
        filename (str): path to the GloVe vector file.
    Returns:
        tuple: ``(word_to_index, word_vectors)`` where ``word_to_index``
            maps each word to its row index and ``word_vectors`` is a list
            of 1-d float numpy arrays, one per word, in insertion order.
    """
    word_to_index = {}
    word_vectors = []
    # GloVe distributions are UTF-8; be explicit so this does not depend
    # on the platform's default encoding.
    with open(filename, encoding='utf-8') as fp:
        # Iterate the file object lazily instead of fp.readlines(), which
        # would materialize the entire (potentially multi-GB) file in memory.
        for line in tqdm_notebook(fp, leave=False):
            parts = line.rstrip().split(" ")
            word = parts[0]
            word_to_index[word] = len(word_to_index)
            word_vectors.append(np.array([float(x) for x in parts[1:]]))
    return word_to_index, word_vectors
```
%% Cell type:code id: tags:
``` python
class PreTrainedEmbeddings(object):
    """Pre-trained word embeddings plus an Annoy index for fast
    approximate nearest-neighbor lookups in embedding space."""

    def __init__(self, glove_filename):
        """Load vectors from ``glove_filename`` and build the search index.

        Args:
            glove_filename (str): path to a GloVe-format vector file.
        """
        self.word_to_index, self.word_vectors = load_word_vectors(glove_filename)
        self.word_vector_size = len(self.word_vectors[0])
        self.index_to_word = {v: k for k, v in self.word_to_index.items()}
        self.index = AnnoyIndex(self.word_vector_size, metric='euclidean')
        print('Building Index')
        # Only the integer row indices are needed here (the original code
        # iterated .items() and discarded the word); the word is always
        # recoverable through index_to_word.
        for i in tqdm_notebook(self.word_to_index.values(), leave=False):
            self.index.add_item(i, self.word_vectors[i])
        # 50 trees: more trees -> better recall at query time, slower build.
        self.index.build(50)
        print('Finished!')

    def get_embedding(self, word):
        """Return the embedding vector (1-d numpy array) for ``word``.

        Raises:
            KeyError: if ``word`` is not in the vocabulary.
        """
        return self.word_vectors[self.word_to_index[word]]

    def closest(self, word, n=1):
        """Return the ``n`` vocabulary words nearest to ``word``."""
        # Delegate to closest_v instead of duplicating the query logic.
        return self.closest_v(self.get_embedding(word), n=n)

    def closest_v(self, vector, n=1):
        """Return the ``n`` vocabulary words nearest to ``vector``."""
        nn_indices = self.index.get_nns_by_vector(vector, n)
        return [self.index_to_word[neighbor] for neighbor in nn_indices]

    def sim(self, w1, w2):
        """Similarity score between the embeddings of ``w1`` and ``w2``.

        NOTE(review): this is an unnormalized dot product, not cosine
        similarity, so vector magnitudes influence the score.
        """
        return np.dot(self.get_embedding(w1), self.get_embedding(w2))
```
%% Cell type:code id: tags:
``` python
# Load the GloVe vectors and build the Annoy index (prints progress).
glove = PreTrainedEmbeddings(args.glove_filename)
```
%% Output
Building Index
Finished!
%% Cell type:code id: tags:
``` python
# Nearest neighbors of 'apple' — the corpus context is the tech company.
glove.closest('apple', n=5)
```
%% Output
['apple', 'microsoft', 'dell', 'pc', 'compaq']
%% Cell type:code id: tags:
``` python
# Nearest neighbors of 'plane' in embedding space.
glove.closest('plane', n=5)
```
%% Output
['plane', 'airplane', 'jet', 'flight', 'crashed']
%% Cell type:code id: tags:
``` python
# Related words ('beer'/'wine') score higher than unrelated ones.
glove.sim('beer', 'wine'), glove.sim('beer', 'gasoline')
```
%% Output
(26.873448266652, 16.501491855324)
%% Cell type:markdown id: tags:
**Lexical relationships uncovered by word embeddings**
%% Cell type:code id: tags:
``` python
def SAT_analogy(w1, w2, w3):
'''
Solves problems of the type:
w1 : w2 :: w3 : __