Extreme Classification
This demo shows how to label new texts automatically when the number of possible labels is enormous. This scenario is known as extreme classification, a supervised learning variant that deals with multi-class and multi-label problems involving a very large label space.
Examples of extreme classification include labeling a new article with Wikipedia's topical labels, matching web content with a set of relevant advertisements, classifying product descriptions with catalog labels, and matching a resume against a collection of pertinent job titles.
Here's how we'll perform extreme classification:
- We'll transform 250,000 labels into vector embeddings using a publicly available embedding model and upload them into a managed vector index.
- Then we'll take an article that requires labeling and transform it into a vector embedding using the same model.
- We'll use that article's vector embedding as the query to search the vector index. In effect, this will retrieve the most similar labels to the article's semantic content.
- With the most relevant labels retrieved, we can automatically apply them to the article.
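The steps above boil down to a nearest-neighbor search over label vectors. As a minimal sketch, with made-up 3-dimensional vectors and label names standing in for real embeddings (the actual demo uses a 300-dimensional model and a managed vector index):

```python
import numpy as np

def top_k_labels(query_vec, label_vecs, label_names, k=2):
    """Return the k labels whose vectors are most similar (cosine) to the query."""
    q = query_vec / np.linalg.norm(query_vec)
    m = label_vecs / np.linalg.norm(label_vecs, axis=1, keepdims=True)
    scores = m @ q                      # cosine similarity of each label to the query
    order = np.argsort(-scores)[:k]     # indices of the k highest scores
    return [(label_names[i], float(scores[i])) for i in order]

# Toy data: three hypothetical labels and a query article embedding
labels = ["Politics", "Sports", "Science"]
label_vecs = np.array([[1.0, 0.1, 0.0],
                       [0.0, 1.0, 0.1],
                       [0.1, 0.0, 1.0]])
query = np.array([0.9, 0.2, 0.1])
print(top_k_labels(query, label_vecs, labels))
```

The vector index service performs this same kind of search, but at the scale of hundreds of thousands of label vectors.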
Let's get started!
Dependencies
!pip install -qU pinecone-client ipywidgets "setuptools>=36.2.1" wikitextparser unidecode
!pip install -qU sentence-transformers --no-cache-dir
import os
import re
import gzip
import json
import pandas as pd
import numpy as np
from wikitextparser import remove_markup, parse
from sentence_transformers import SentenceTransformer
from unidecode import unidecode
Setting up Pinecone's Similarity Search Service
Here we set up our similarity search service. We assume you are familiar with Pinecone's quick start tutorial.
import pinecone
# Load Pinecone API key
api_key = os.getenv("PINECONE_API_KEY") or "YOUR_API_KEY"
pinecone.init(api_key=api_key, environment='YOUR_ENVIRONMENT')
# List all existing indexes for your API key
pinecone.list_indexes()
[]
Get a Pinecone API key if you don’t have one. You can find your environment in the Pinecone console under API Keys.
# Pick a name for the new index
index_name = 'extreme-ml'

# Check whether an index with the same name already exists
if index_name in pinecone.list_indexes():
    pinecone.delete_index(index_name)

# Create a new vector index
pinecone.create_index(name=index_name, dimension=300)

# Connect to the created index
index = pinecone.Index(index_name)

# Print index statistics
index.describe_index_stats()
{'dimension': 300, 'namespaces': {'': {'vector_count': 0}}}
Data Preparation
In this demo, we classify Wikipedia articles using a standard dataset from an extreme classification benchmarking resource. The data used in this example is Wikipedia-500K, which contains around 500,000 labels. Here, we download the raw data and prepare it for the classification task.
# Download train dataset
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=10RBSf6nC9C38wUMwqWur2Yd8mCwtup5K' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=10RBSf6nC9C38wUMwqWur2Yd8mCwtup5K" -O 'trn.raw.json.gz' && rm -rf /tmp/cookies.txt
# Download test dataset
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1pEyKXtkwHhinuRxmARhtwEQ39VIughDf' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1pEyKXtkwHhinuRxmARhtwEQ39VIughDf" -O 'tst.raw.json.gz' && rm -rf /tmp/cookies.txt
# Download categories labels file
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1ZYTZPlnkPBCMcNqRRO-gNx8EPgtV-GL3' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1ZYTZPlnkPBCMcNqRRO-gNx8EPgtV-GL3" -O 'Yf.txt' && rm -rf /tmp/cookies.txt
# Create and move downloaded files to data folder
!mkdir data
!mv 'trn.raw.json.gz' 'tst.raw.json.gz' 'Yf.txt' data
2022-02-09 17:14:42 (89.0 MB/s) - ‘trn.raw.json.gz’ saved [5292805889/5292805889]
2022-02-09 17:14:59 (141 MB/s) - ‘tst.raw.json.gz’ saved [2297151115/2297151115]
2022-02-09 17:15:02 (248 MB/s) - ‘Yf.txt’ saved [33740692/33740692]
# Define paths
ROOT_PATH = os.getcwd()
TRAIN_DATA_PATH = (os.path.join(ROOT_PATH, 'data/trn.raw.json.gz'))
TEST_DATA_PATH = (os.path.join(ROOT_PATH, 'data/tst.raw.json.gz'))
# Load categories
with open('./data/Yf.txt', encoding='utf-8') as f:
    categories = f.readlines()

# Clean values
categories = [cat.split('->')[1].strip('\n') for cat in categories]

# Show first few categories
categories[:3]
['!!!_albums', '+/-_(band)_albums', '+44_(band)_songs']
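The split/strip above implies that each line of Yf.txt is a `->`-separated pair whose second field is the label name. On a couple of made-up lines in that format, the cleanup works like this:

```python
# Hypothetical lines in the Yf.txt format: "<index>-><label>\n"
raw_lines = ["0->!!!_albums\n", "1->+/-_(band)_albums\n"]

# Keep only the label text after the '->' separator, without the newline
cats = [ln.split('->')[1].strip('\n') for ln in raw_lines]
print(cats)  # ['!!!_albums', '+/-_(band)_albums']
```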
Using a Subset of the Data
For this example, we select and use a subset of Wikipedia articles. This saves processing time and consumes much less memory than the complete dataset.
We will select a sample of 200,000 articles that together contain around 250,000 distinct labels.
Feel free to run the notebook with more data.
WIKI_ARTICLES_INDEX = range(0, 1000000, 5)

lines = []
with gzip.open(TRAIN_DATA_PATH) as f:
    for e, line in enumerate(f):
        if e >= 1000000:
            break
        if e in WIKI_ARTICLES_INDEX:
            lines.append(json.loads(line))

df = pd.DataFrame.from_dict(lines)
df = df[['title', 'content', 'target_ind']]
df.head()
|  | title | content | target_ind |
|---|---|---|---|
| 0 | Anarchism | {{redirect2\|anarchist\|anarchists\|the fictional... | [81199, 83757, 83805, 193030, 368811, 368937, ... |
| 1 | Academy_Awards | {{redirect2\|oscars\|the oscar\|the film\|the osca... | [19080, 65864, 78208, 96051] |
| 2 | Anthropology | {{about\|the social science}} {{use dmy dates\|d... | [83605, 423943] |
| 3 | American_Football_Conference | {{refimprove\|date=september 2014}} {{use dmy d... | [76725, 314198, 334093] |
| 4 | Analysis_of_variance | {{use dmy dates\|date=june 2013}} '''analysis o... | [81170, 168516, 338198, 441529] |
print(df.shape)
(200000, 3)
Remove Wikipedia Markup Format
We will use only the first part of each article to make them comparable in length. Wikipedia articles also carry markup that hurts readability, so we strip it to make the content as clean as possible.
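For intuition, here is a much-simplified, regex-based sketch of what markup stripping does; the notebook itself relies on wikitextparser's remove_markup, which handles far more wikitext constructs than this toy version:

```python
import re

def strip_wiki_markup_naive(text):
    """Naive illustration only: drop simple {{...}} templates and quote-run styling."""
    text = re.sub(r"\{\{[^{}]*\}\}", "", text)  # remove non-nested templates
    text = re.sub(r"'{2,}", "", text)           # drop '''bold''' / ''italic'' quote runs
    return text.strip()

sample = "{{refimprove|date=june 2014}} '''anarchism''' is a political philosophy"
print(strip_wiki_markup_naive(sample))  # anarchism is a political philosophy
```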
# Reduce content to first 3000 characters
df['content_short'] = df.content.apply(lambda x: x[:3000])
# Remove wiki articles markup
df['content_cleaned'] = df.content_short.apply(lambda x: remove_markup(x))
# Keep only certain columns
df = df[['title', 'content_cleaned', 'target_ind']]
# Show data
df.head()
|  | title | content_cleaned | target_ind |
|---|---|---|---|
| 0 | Anarchism | anarchism is a political philosophy that a... | [81199, 83757, 83805, 193030, 368811, 368937, ... |
| 1 | Academy_Awards | the academy awards or the oscars (the offi... | [19080, 65864, 78208, 96051] |
| 2 | Anthropology | anthropology is the scientific study of hu... | [83605, 423943] |
| 3 | American_Football_Conference | the american football conference (afc) is o... | [76725, 314198, 334093] |
| 4 | Analysis_of_variance | analysis of variance (anova) is a collection ... | [81170, 168516, 338198, 441529] |
# Keep all labels in a single list
all_categories = []
for i, row in df.iterrows():
    all_categories.extend(row.target_ind)

print('Number of labels: ', len(set(all_categories)))
Number of labels: 256899
Create Article Vector Embeddings
Recall that we want to index and search all possible (~250,000) labels. We do that by representing each label as the average of the vector embeddings of the articles tagged with it.
Let's first create the article vector embeddings, using an average-word-embeddings model from Sentence-Transformers. In the next section, we aggregate these vectors into the final label embeddings.
# Load the model
model = SentenceTransformer('average_word_embeddings_komninos')
# Create embeddings
encoded_articles = model.encode(df['content_cleaned'], show_progress_bar=True)
df['content_vector'] = pd.Series(encoded_articles.tolist())
Upload Labels
Using article embeddings on their own turned out not to provide good enough accuracy, so we chose to index and search the labels directly.
Each label embedding is simply the average of the embeddings of all articles tagged with that label.
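The explode/groupby step can be previewed on toy data: explode gives one row per (article, label) pair, and the per-label mean of the article vectors becomes the label vector. The 2-dimensional vectors below are invented purely for illustration:

```python
import numpy as np
import pandas as pd

# Three toy articles with 2-d "embeddings" and lists of label indices
toy = pd.DataFrame({
    'content_vector': [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    'target_ind': [[0, 1], [1], [0]],
})

# One row per (article, label) pair
exploded = toy.explode('target_ind')

# Per-label mean of the article vectors
label_vecs = exploded.groupby('target_ind')['content_vector'].apply(
    lambda vs: np.vstack(vs).mean(axis=0).tolist())
print(label_vecs.to_dict())  # {0: [1.0, 0.5], 1: [0.5, 0.5]}
```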
# Explode the target indicator column
df_explode = df.explode('target_ind')
# Group by label and define a unique vector for each label
result = df_explode.groupby('target_ind').agg(mean=('content_vector', lambda x: np.vstack(x).mean(axis=0).tolist()))
result['target_ind'] = result.index
result.columns = ['content_vector', 'ind']
result.head()
| target_ind | content_vector | ind |
|---|---|---|
| 2 | [0.0704750344157219, -0.007719345390796661, 0.... | 2 |
| 3 | [0.05894148722290993, -0.03119848482310772, 0.... | 3 |
| 5 | [0.18302207440137863, 0.061663837544620036, 0.... | 5 |
| 6 | [0.1543595753610134, 0.03904660418629646, 0.03... | 6 |
| 9 | [0.22310754656791687, 0.1524289846420288, 0.09... | 9 |
# Create a list of items to upsert
items_to_upsert = [(unidecode(categories[int(row.ind)])[:64], row.content_vector) for i, row in result.iterrows()]
import itertools
def chunks(iterable, batch_size=100):
it = iter(iterable)
chunk = tuple(itertools.islice(it, batch_size))
while chunk:
yield chunk
chunk = tuple(itertools.islice(it, batch_size))
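To see the batching behavior in isolation (restating the helper so the snippet runs standalone), note that the final batch is simply smaller when the iterable doesn't divide evenly:

```python
import itertools

def chunks(iterable, batch_size=100):
    """Yield successive tuples of up to batch_size items from iterable."""
    it = iter(iterable)
    chunk = tuple(itertools.islice(it, batch_size))
    while chunk:
        yield chunk
        chunk = tuple(itertools.islice(it, batch_size))

batches = list(chunks(range(7), batch_size=3))
print(batches)  # [(0, 1, 2), (3, 4, 5), (6,)]
```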
# Upsert data in batches of 250
for batch in chunks(items_to_upsert, 250):
    index.upsert(vectors=batch)
Let's validate the number of indexed labels.
index.describe_index_stats()
{'dimension': 300, 'namespaces': {'': {'vector_count': 256899}}}
Query
Now let's test the vector index and examine the classifier's results. Note that here we retrieve a fixed number of labels; in a real application, you might want to size the retrieved label set dynamically.
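One simple way to size the label set dynamically (a hypothetical post-processing step, not part of this notebook) is to keep only matches above an absolute score threshold and within a margin of the best score. The match list and thresholds below are invented for illustration:

```python
def select_labels(matches, min_score=0.95, margin=0.01):
    """Keep matches above min_score and within `margin` of the top score.

    `matches` is a list of (label_id, score) pairs sorted by descending score.
    """
    if not matches:
        return []
    best = matches[0][1]
    return [lab for lab, s in matches
            if s >= min_score and best - s <= margin]

# Hypothetical query results, sorted by descending score
hits = [("Discrimination", 0.973), ("Sociological_terminology", 0.972),
        ("Identity_politics", 0.970), ("Human_behavior", 0.940)]
print(select_labels(hits, min_score=0.95, margin=0.005))
```

Tightening `margin` keeps only labels nearly tied with the best match; loosening it admits more of the tail.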
NUM_OF_WIKI_ARTICLES = 3
WIKI_ARTICLES_INDEX = range(1111, 100000, 57)[:NUM_OF_WIKI_ARTICLES]

lines = []
with gzip.open(TEST_DATA_PATH) as f:
    for e, line in enumerate(f):
        if e in WIKI_ARTICLES_INDEX:
            lines.append(json.loads(line))
        if e > max(WIKI_ARTICLES_INDEX):
            break

df_test = pd.DataFrame.from_dict(lines)
df_test = df_test[['title', 'content', 'target_ind']]
df_test.head()
|  | title | content | target_ind |
|---|---|---|---|
| 0 | Discrimination | {{otheruses}} {{discrimination sidebar}} '''di... | [170479, 423902] |
| 1 | Erfurt | {{refimprove\|date=june 2014}} {{use dmy dates\|... | [142638, 187156, 219262, 294479, 329185, 38243... |
| 2 | ETA | {{about\|the basque organization\|other uses\|eta... | [83681, 100838, 100849, 100868, 176034, 188979... |
# Reduce content to first 3000 characters
df_test['content_short'] = df_test.content.apply(lambda x: x[:3000])
# Remove wiki articles markup
df_test['content_cleaned'] = df_test.content_short.apply(lambda x: remove_markup(x))
# Keep only certain columns
df_test = df_test[['title', 'content_cleaned', 'target_ind']]
# Show data
df_test.head()
|  | title | content_cleaned | target_ind |
|---|---|---|---|
| 0 | Discrimination | discrimination is action that denies social ... | [170479, 423902] |
| 1 | Erfurt | erfurt () is the capital city of thuringia ... | [142638, 187156, 219262, 294479, 329185, 38243... |
| 2 | ETA | eta (, ), an acronym for euskadi ta askatas... | [83681, 100838, 100849, 100868, 176034, 188979... |
# Create embeddings for test articles
test_vectors = model.encode(df_test['content_cleaned'], show_progress_bar=True).tolist()
# Query the vector index
query_results = []
for xq in test_vectors:
    query_res = index.query(xq, top_k=10)
    query_results.append(query_res)
# Show results
for term, labs, res in zip(df_test.title.tolist(), df_test.target_ind.tolist(), query_results):
    print()
    print('Term queried: ', term)
    print('Original labels: ')
    for l in labs:
        if l in all_categories:
            print('\t', categories[l])
    print('Predicted: ')
    df_result = pd.DataFrame({
        'id': [match.id for match in res.matches],
        'score': [match.score for match in res.matches],
    })
    display(df_result)
Term queried: Discrimination
Original labels:
Discrimination
Social_justice
Predicted:
|  | id | score |
|---|---|---|
| 0 | Discrimination | 0.972958 |
| 1 | Sociological_terminology | 0.971606 |
| 2 | Identity_politics | 0.970097 |
| 3 | Social_concepts | 0.967534 |
| 4 | Sexism | 0.967476 |
| 5 | Affirmative_action | 0.967288 |
| 6 | Political_correctness | 0.966926 |
| 7 | Human_behavior | 0.966475 |
| 8 | Persecution | 0.965421 |
| 9 | Social_movements | 0.964394 |
Term queried: Erfurt
Original labels:
Erfurt
German_state_capitals
Members_of_the_Hanseatic_League
Oil_Campaign_of_World_War_II
Province_of_Saxony
University_towns_in_Germany
Predicted:
|  | id | score |
|---|---|---|
| 0 | University_towns_in_Germany | 0.966058 |
| 1 | Province_of_Saxony | 0.959731 |
| 2 | Populated_places_on_the_Rhine | 0.958737 |
| 3 | Imperial_free_cities | 0.957159 |
| 4 | Hildesheim_(district) | 0.956927 |
| 5 | History_of_the_Electoral_Palatinate | 0.956800 |
| 6 | Towns_in_Saxony-Anhalt | 0.956501 |
| 7 | Towns_in_Lower_Saxony | 0.955259 |
| 8 | Halle_(Saale) | 0.954934 |
| 9 | Cities_in_Saxony-Anhalt | 0.954934 |
Term queried: ETA
Original labels:
Anti-Francoism
Basque_conflict
Basque_history
Basque_politics
ETA
European_Union_designated_terrorist_organizations
Far-left_politics
Francoist_Spain
Government_of_Canada_designated_terrorist_organizations
Irregular_military
Military_wings_of_political_parties
National_liberation_movements
Nationalist_terrorism
Organizations_designated_as_terrorist_by_the_United_States_government
Organizations_designated_as_terrorist_in_Europe
Organizations_established_in_1959
Politics_of_Spain
Resistance_movements
Secession_in_Spain
Secessionist_organizations_in_Europe
Terrorism_in_Spain
United_Kingdom_Home_Office_designated_terrorist_groups
Predicted:
|  | id | score |
|---|---|---|
| 0 | Organizations_designated_as_terrorist_in_Europe | 0.948875 |
| 1 | Terrorism_in_Spain | 0.948431 |
| 2 | Basque_politics | 0.942670 |
| 3 | Politics_of_Spain | 0.941830 |
| 4 | European_Union_designated_terrorist_organizations | 0.940194 |
| 5 | Irregular_military | 0.938163 |
| 6 | Political_parties_disestablished_in_1977 | 0.936437 |
| 7 | Algerian_Civil_War | 0.936311 |
| 8 | Republicanism_in_Spain | 0.935577 |
| 9 | Guerrilla_organizations | 0.935506 |
Summary
We demonstrated a similarity-search approach to extreme classification of texts. We took a simple approach, representing each label as the average of its corresponding articles' vector embeddings. At classification time, we match a new article's embedding against its nearest label embeddings. Our example results indicate the usefulness of this approach.
You can take this further by exploring more advanced ideas. For example, you could exploit the hierarchical relationships between labels or improve the label representations. Have fun, and feel free to share your thoughts.
Delete the index
Delete the index once you no longer need it. Once an index is deleted, it cannot be used again.
pinecone.delete_index(index_name)