Overview
An instruction-finetuned text embedding model that generates embeddings tailored to any task (e.g., classification, retrieval, clustering, text evaluation) or domain (e.g., science, finance) from a task instruction written in natural language.
Accepts custom text units (e.g., sentence, paragraph, document). A medium-sized model: it performs better than instructor-base but worse than instructor-xl.
Learn how to best use Instructor for specific tasks here.
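As a rough illustration of the instruction-plus-text input format used in the examples on this page, the sketch below builds instruction/text pairs for a given task. The wording follows the "Represent the [domain] [text type] for [task]:" pattern described by the Instructor project. The helper names `make_instruction` and `build_pairs` are hypothetical and not part of any library.

```python
def make_instruction(text_type, task=None, domain=None):
    """Assemble an instruction string (hypothetical helper, not a library API)."""
    parts = ["Represent the"]
    if domain:
        parts.append(domain)
    parts.append(text_type)
    if task:
        parts.append("for " + task)
    # The trailing ": " mirrors the instruction strings used on this page.
    return " ".join(parts) + ": "


def build_pairs(instruction, texts):
    """Pair one instruction with each text, the input shape passed to model.encode."""
    return [[instruction, text] for text in texts]


retrieval = make_instruction("document", task="retrieval", domain="Wikipedia")
pairs = build_pairs(
    retrieval,
    ["Apple is a popular fruit.", "Apple Inc. makes the iPhone."],
)
print(retrieval)
print(pairs[0])
```

The same helper can produce instructions for other tasks, such as `make_instruction("title", domain="Science")`. Which wording works best for a given task is something to verify against the Instructor guidance rather than assume.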
Using the model
Installation:
!pip install transformers==4.20.0 InstructorEmbedding pinecone sentence-transformers
Create Index
from pinecone import Pinecone, ServerlessSpec
pc = Pinecone(api_key="API_KEY")
index_name = "instructor-large"
if not pc.has_index(index_name):
    pc.create_index(
        name=index_name,
        dimension=768,
        metric="cosine",
        spec=ServerlessSpec(
            cloud='aws',
            region='us-east-1'
        )
    )
index = pc.Index(index_name)
Embed & Upsert
data = [
{"id": "vec1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."},
{"id": "vec2", "text": "The tech company Apple is known for its innovative products like the iPhone."},
{"id": "vec3", "text": "Many people enjoy eating apples as a healthy snack."},
{"id": "vec4", "text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."},
{"id": "vec5", "text": "An apple a day keeps the doctor away, as the saying goes."},
]
instruction = "Represent the following document for retrieval: "
from InstructorEmbedding import INSTRUCTOR
model = INSTRUCTOR('hkunlp/instructor-large')
instruction_embedding_pairs = [[instruction, d["text"]] for d in data]
embeddings = model.encode(instruction_embedding_pairs)
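# Pair each document with its embedding in the record format expected by upsert
vectors = []
for d, e in zip(data, embeddings):
    vectors.append({
        "id": d['id'],
        "values": e.tolist(),  # convert the NumPy array to a plain list of floats
        "metadata": {'text': d['text']}
    })
index.upsert(
    vectors=vectors,
    namespace="ns1"
)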
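The record-building step above can be wrapped in a small helper that also checks each embedding against the index dimension before anything is sent. This is a minimal sketch, not part of the Pinecone or InstructorEmbedding APIs: `to_records` and `EXPECTED_DIM` are hypothetical names, and the embeddings are zero-filled stand-ins so the example runs without loading the model.

```python
EXPECTED_DIM = 768  # assumed to match the `dimension` used when the index was created


def to_records(data, embeddings, expected_dim=EXPECTED_DIM):
    """Pair documents with embeddings and return upsert-ready dicts (hypothetical helper)."""
    if len(data) != len(embeddings):
        raise ValueError("data and embeddings must have the same length")
    records = []
    for doc, emb in zip(data, embeddings):
        # Works for plain lists and for NumPy arrays alike.
        values = [float(v) for v in emb]
        if len(values) != expected_dim:
            raise ValueError(
                f"{doc['id']}: expected {expected_dim} values, got {len(values)}"
            )
        records.append({
            "id": doc["id"],
            "values": values,
            "metadata": {"text": doc["text"]},
        })
    return records


# Stand-in embeddings (all zeros) instead of real model output.
docs = [{"id": "vec1", "text": "Apple is a popular fruit."}]
fake_embeddings = [[0.0] * EXPECTED_DIM]
records = to_records(docs, fake_embeddings)
print(records[0]["id"], len(records[0]["values"]))
```

Failing early on a dimension mismatch gives a clearer error than one returned by the index at upsert time.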
Query
query_instruction = "Represent this query for retrieving supporting documents: "
query = "Tell me about the tech company known as Apple"
x = model.encode([[query_instruction, query]])
results = index.query(
    namespace="ns1",
    vector=x[0].tolist(),
    top_k=3,
    include_values=False,
    include_metadata=True
)
print(results)
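Printing the raw response shows everything at once. To work with individual matches, something like the following sketch may help. It assumes the response exposes a `matches` list whose entries carry `id`, `score`, and (because `include_metadata=True`) `metadata`. The example operates on a plain dict with placeholder scores, not real query output; the actual client response may need converting first, for example with its `to_dict()` method if your client version provides one. `summarize_matches` is a hypothetical helper.

```python
def summarize_matches(response, min_score=0.0):
    """Return (id, score, text) tuples for matches at or above min_score."""
    rows = []
    for match in response["matches"]:
        if match["score"] >= min_score:
            text = match.get("metadata", {}).get("text", "")
            rows.append((match["id"], match["score"], text))
    return rows


# Placeholder response: the scores below are made up for illustration only.
sample = {
    "matches": [
        {"id": "vec2", "score": 0.9, "metadata": {"text": "The tech company Apple..."}},
        {"id": "vec4", "score": 0.8, "metadata": {"text": "Apple Inc. has revolutionized..."}},
        {"id": "vec1", "score": 0.5, "metadata": {"text": "Apple is a popular fruit..."}},
    ]
}

for match_id, score, text in summarize_matches(sample, min_score=0.6):
    print(f"{match_id}  {score:.2f}  {text}")
```

A score threshold like `min_score` is optional. With cosine similarity, a suitable cutoff depends on the data and is best chosen empirically.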