Commit 3b2e470a authored by Brian McMahan's avatar Brian McMahan

small fixes; seed settings and readme instructions

parent c8ff14f7
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from annoy import AnnoyIndex\n",
"import numpy as np\n",
"import torch\n",
"from tqdm import tqdm_notebook\n",
"from argparse import Namespace"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"args = Namespace(\n",
" glove_filename='../data/glove.6B.100d.txt'\n",
")"
]
},
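{
"cell_type": "markdown",
"metadata": {},
"source": [
"This assumes the 100-dimensional GloVe 6B vectors (`glove.6B.100d.txt`, available from https://nlp.stanford.edu/projects/glove/) have been downloaded into `../data/`, as described in the repository README."
]
},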
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def load_word_vectors(filename):\n",
" word_to_index = {}\n",
" word_vectors = []\n",
" \n",
" with open(filename) as fp:\n",
" for line in tqdm_notebook(fp.readlines(), leave=False):\n",
" line = line.split(\" \")\n",
" \n",
" word = line[0]\n",
" word_to_index[word] = len(word_to_index)\n",
" \n",
" vec = np.array([float(x) for x in line[1:]])\n",
" word_vectors.append(vec)\n",
" \n",
" return word_to_index, word_vectors"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class PreTrainedEmbeddings(object):\n",
" def __init__(self, glove_filename):\n",
" self.word_to_index, self.word_vectors = load_word_vectors(glove_filename)\n",
" self.word_vector_size = len(self.word_vectors[0])\n",
" \n",
" self.index_to_word = {v: k for k, v in self.word_to_index.items()}\n",
" self.index = AnnoyIndex(self.word_vector_size, metric='euclidean')\n",
" print('Building Index')\n",
" for _, i in tqdm_notebook(self.word_to_index.items(), leave=False):\n",
" self.index.add_item(i, self.word_vectors[i])\n",
" self.index.build(50)\n",
" print('Finished!')\n",
" \n",
" def get_embedding(self, word):\n",
" return self.word_vectors[self.word_to_index[word]]\n",
" \n",
" def closest(self, word, n=1):\n",
" vector = self.get_embedding(word)\n",
" nn_indices = self.index.get_nns_by_vector(vector, n)\n",
" return [self.index_to_word[neighbor] for neighbor in nn_indices]\n",
" \n",
" def closest_v(self, vector, n=1):\n",
" nn_indices = self.index.get_nns_by_vector(vector, n)\n",
" return [self.index_to_word[neighbor] for neighbor in nn_indices]\n",
" \n",
" def sim(self, w1, w2):\n",
" return np.dot(self.get_embedding(w1), self.get_embedding(w2))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=400000), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Building Index\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=400000), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Finished!\n"
]
}
],
"source": [
"glove = PreTrainedEmbeddings(args.glove_filename)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"['apple', 'microsoft', 'dell', 'pc', 'compaq']"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"glove.closest('apple', n=5)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"['plane', 'airplane', 'jet', 'flight', 'crashed']"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"glove.closest('plane', n=5)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(26.873448266652, 16.501491855324)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"glove.sim('beer', 'wine'), glove.sim('beer', 'gasoline')"
]
},
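{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `sim` above is an unnormalized dot product, so vector magnitudes affect the score. A minimal sketch of a cosine-normalized variant (illustrative; `cosine_sim` below is not part of the original notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def cosine_sim(w1, w2):\n",
"    # normalize the dot product by both vector norms\n",
"    v1, v2 = glove.get_embedding(w1), glove.get_embedding(w2)\n",
"    return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))"
]
},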
{
"cell_type": "markdown",
"metadata": {},
"source": [
"** Lexical relationships uncovered by word embeddings **"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def SAT_analogy(w1, w2, w3):\n",
" '''\n",
" Solves problems of the type:\n",
" w1 : w2 :: w3 : __\n",
" '''\n",
" closest_words = []\n",
" try:\n",
" w1v = glove.get_embedding(w1)\n",
" w2v = glove.get_embedding(w2)\n",
" w3v = glove.get_embedding(w3)\n",
" w4v = w3v + (w2v - w1v)\n",
" closest_words = glove.closest_v(w4v, n=5)\n",
" closest_words = [w for w in closest_words if w not in [w1, w2, w3]]\n",
" except:\n",
" pass\n",
" if len(closest_words) == 0:\n",
" print(':-(')\n",
" else:\n",
" the_closest_word = closest_words[0]\n",
" print('{} : {} :: {} : {}'.format(w1, w2, w3, the_closest_word))"
]
},
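{
"cell_type": "markdown",
"metadata": {},
"source": [
"`SAT_analogy` uses the vector-offset method: the relationship w1 : w2 is approximated by the offset `w2v - w1v`, and applying that offset to `w3v` (i.e., `w4v = w3v + (w2v - w1v)`) should land near the missing fourth word. The nearest neighbors of `w4v`, minus the three query words, are the candidate answers."
]
},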
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Pronouns**"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"man : he :: woman : she\n"
]
}
],
"source": [
"SAT_analogy('man', 'he', 'woman')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"** Verb-Noun relationships **"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fly : plane :: sail : ship\n"
]
}
],
"source": [
"SAT_analogy('fly', 'plane', 'sail')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Noun-Noun relationships**"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cat : kitten :: dog : pug\n"
]
}
],
"source": [
"SAT_analogy('cat', 'kitten', 'dog')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"human : baby :: dog : puppy\n"
]
}
],
"source": [
"SAT_analogy('human', 'baby', 'dog')"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"human : babies :: dog : puppies\n"
]
}
],
"source": [
"SAT_analogy('human', 'babies', 'dog')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Hypernymy**"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"blue : color :: dog : animal\n"
]
}
],
"source": [
"SAT_analogy('blue', 'color', 'dog')"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"**Meronymy**"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"leg : legs :: hand : hands\n"
]
}
],
"source": [
"SAT_analogy('leg', 'legs', 'hand')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Troponymy**"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"talk : communicate :: read : correctly\n"
]
}
],
"source": [
"SAT_analogy('talk', 'communicate', 'read')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Metonymy**"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"blue : democrat :: red : republican\n"
]
}
],
"source": [
"SAT_analogy('blue', 'democrat', 'red')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Misc**"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"man : doctor :: woman : nurse\n"
]
}
],
"source": [
"SAT_analogy('man', 'doctor', 'woman')"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"man : leader :: woman : opposition\n"
]
}
],
"source": [
"SAT_analogy('man', 'leader', 'woman')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "magis",
"language": "python",
"name": "magis"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
from collections import Counter
import numpy as np
from torch.utils.data import Dataset
import six
import json

class Vocabulary(object):
    """
    Manages the two-way mapping between tokens in a dataset and the integer
    indices used by a machine learning algorithm.
    """
    def __init__(self, use_unks=False, unk_token="<UNK>",
                 use_mask=False, mask_token="<MASK>", use_start_end=False,
                 start_token="<START>", end_token="<END>"):
        """
        Args:
            use_unks (bool): The vocabulary will output UNK tokens for
                out-of-vocabulary items.
                [default=False]
            unk_token (str): The token used for unknown tokens.
                If `use_unks` is True, this will be added to the vocabulary.
                [default='<UNK>']
            use_mask (bool): The vocabulary will reserve the 0th index for a mask token.
                This is used to handle variable lengths in sequence models.
                [default=False]
            mask_token (str): The token used for the mask.
                Note: mostly a placeholder; it's unlikely the token will be seen.
                [default='<MASK>']
            use_start_end (bool): The vocabulary will reserve indices for two tokens
                that represent the start and end of a sequence.
                [default=False]
            start_token (str): The token used to indicate the start of a sequence.
                If `use_start_end` is True, this will be added to the vocabulary.
                [default='<START>']
            end_token (str): The token used to indicate the end of a sequence.
                If `use_start_end` is True, this will be added to the vocabulary.
                [default='<END>']
"""
        self._mapping = {}  # str -> int
        self._flip = {}     # int -> str
        self._i = 0
        self._frozen = False

        # mask token for use in masked recurrent networks;
        # usually needs to be the 0th index
        self.use_mask = use_mask
        self.mask_token = mask_token
        if self.use_mask:
            self.add(self.mask_token)

        # unk token for out-of-vocabulary tokens
        self.use_unks = use_unks
        self.unk_token = unk_token
        if self.use_unks:
            self.add(self.unk_token)

        # start and end tokens for sequence models
        self.use_start_end = use_start_end
        self.start_token = start_token
        self.end_token = end_token
        if self.use_start_end:
            self.add(self.start_token)
            self.add(self.end_token)
    def iterkeys(self):
        for k in self._mapping.keys():
            if k == self.unk_token or k == self.mask_token:
                continue
            else:
                yield k

    def keys(self):
        return list(self.iterkeys())

    def iteritems(self):
        for key, value in self._mapping.items():
            if key == self.unk_token or key == self.mask_token:
                continue
            yield key, value

    def items(self):
        return list(self.iteritems())

    def values(self):
        return [value for _, value in self.iteritems()]
    def __getitem__(self, k):
        if self._frozen:
            if k in self._mapping:
                out_index = self._mapping[k]
            elif self.use_unks:
                out_index = self.unk_index
            else:  # case: frozen, don't want unks, raise exception