Overview
A strong choice when you want good retrieval performance from an open-source model with minimal hardware requirements. It works well on messy data and is well suited to short queries that should return medium-length passages of text (1–2 paragraphs).
See here for an example.
Using the model
Installation:
!pip install -qU transformers==4.35.0 pinecone
Create Index
from pinecone import Pinecone, ServerlessSpec
import os

# Connect to Pinecone. Read the key from the PINECONE_API_KEY environment
# variable so it never has to be hard-coded in the notebook/script; fall back
# to the "API KEY" placeholder (replace it) when the variable is not set.
pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY", "API KEY"))

# Create the index only when it does not already exist, so re-running this
# cell/script does not fail on an existing index.
# dimension must equal the length of the vectors produced by the embedding
# model used below; metric="cosine" compares them by angle.
index_name = "gte-base"
if not pc.has_index(index_name):
    pc.create_index(
        name=index_name,
        dimension=768,
        metric="cosine",
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
    )
index = pc.Index(index_name)
Embed & Upsert
# Sample corpus to embed and upsert: short sentences that use "Apple" both as
# the fruit and as the company, so the query below has to disambiguate.
# Each record is a dict with an "id" and the raw "text".
_sample_records = (
    ("vec1", "Apple is a popular fruit known for its sweetness and crisp texture."),
    ("vec2", "The tech company Apple is known for its innovative products like the iPhone."),
    ("vec3", "Many people enjoy eating apples as a healthy snack."),
    ("vec4", "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."),
    ("vec5", "An apple a day keeps the doctor away, as the saying goes."),
)
data = [{"id": rec_id, "text": rec_text} for rec_id, rec_text in _sample_records]
import torch
from torch.nn.functional import normalize
from transformers import AutoModel, AutoTokenizer
# Run on a CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Hugging Face Hub checkpoint for the embedding model. The tokenizer and the
# model weights are loaded from the same checkpoint.
model_id = "thenlper/gte-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
model = model.to(device)
# Evaluation mode (e.g. disables dropout): the model is only used for inference.
model.eval()
def embed(docs: list[str], batch_size: int = 32) -> list[list[float]]:
    """Embed texts with the loaded model using attention-masked mean pooling.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        docs: Texts to embed. Each text is truncated to at most 512 tokens.
        batch_size: Number of texts sent through the model per forward pass.
            Bounds peak memory when embedding a large corpus. Defaults to 32.

    Returns:
        One L2-normalised embedding (a list of floats) per input text, in the
        same order as ``docs``. An empty ``docs`` yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    embeddings: list[list[float]] = []
    # range() yields nothing for an empty docs list, so [] is returned without
    # ever calling the tokenizer (which rejects empty input).
    for start in range(0, len(docs), batch_size):
        batch = docs[start:start + batch_size]
        # Tokenize; padding only needs to reach the longest text in this batch.
        tokens = tokenizer(
            batch, padding=True, max_length=512, truncation=True, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            # Token-level embeddings from the final layer.
            out = model(**tokens)
            attention_mask = tokens["attention_mask"]
            # Zero padding positions so they do not contribute to the sum.
            last_hidden = out.last_hidden_state.masked_fill(
                ~attention_mask[..., None].bool(), 0.0
            )
            # Number of real tokens per text; clamp guards against division
            # by zero.
            token_counts = attention_mask.sum(dim=1, keepdim=True).clamp(min=1)
            pooled = last_hidden.sum(dim=1) / token_counts
            # Unit-length vectors: cosine similarity equals the dot product,
            # so scores agree under Pinecone's cosine and dotproduct metrics.
            pooled = normalize(pooled, p=2, dim=1)
        embeddings.extend(pooled.cpu().tolist())
    return embeddings
embeddings = embed([d["text"] for d in data])
# embed() should return exactly one vector per input text. Fail loudly rather
# than let zip() silently drop records if the counts ever disagree.
if len(embeddings) != len(data):
    raise ValueError(
        f"expected {len(data)} embeddings, got {len(embeddings)}"
    )
# Pinecone record format: id, dense values, and the source text kept as
# metadata so it can be returned with query matches (include_metadata=True).
vectors = [
    {"id": d["id"], "values": e, "metadata": {"text": d["text"]}}
    for d, e in zip(data, embeddings)
]
index.upsert(
    vectors=vectors,
    namespace="ns1",
)
Query
# Embed a natural-language question the same way the documents were embedded.
# embed() takes a list of texts, so wrap the query and take the single row back.
query = "Tell me about the tech company known as Apple"
x = embed([query])
query_vector = x[0]

# Ask the "ns1" namespace for the 3 nearest records. Return the stored
# metadata (the original text) but not the raw vector values.
results = index.query(
    vector=query_vector,
    namespace="ns1",
    top_k=3,
    include_metadata=True,
    include_values=False,
)
print(results)