Building a Semantic Highlighter: Understanding Search Result Presentation Through Machine Learning - Part 2

Part 2: Building a Semantic Highlighter from Scratch
Chapter 1: The Foundation - Understanding Our Tools
Let's build our semantic highlighter piece by piece. I'll be your guide, and we'll construct each component while understanding why it exists.
First, let's understand what we're building. Imagine you're a teacher grading essays, but instead of grading the whole essay, you're highlighting the sentences that actually answer the essay question. That's our task.
import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel
We're standing on the shoulders of giants here. BERT gives us the ability to understand language. PyTorch gives us the tools to build neural networks. Let's start with our core model class:
class BertTaggerForSentenceExtractionWithBackoff(BertPreTrainedModel):
    """Our sentence highlighter - a teacher that reads and grades each sentence"""
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = 2  # Binary: relevant or not relevant
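Why binary classification? Each sentence either helps answer the query or it doesn't, so two labels match the task, and each training example only needs a yes/no judgment per sentence. The model still produces a relevance probability between 0 and 1; a threshold turns that probability into the final yes/no decision.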
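The training code is not shown in this article, so here is a minimal sketch of what a binary objective could look like, under stated assumptions: each sentence slot gets a 0/1 label, padded sentence slots are labeled -100 (the same value the appendix's collator uses for `sentence_labels`), and the loss is PyTorch's standard cross-entropy. The logits and labels are made-up numbers.

```python
import torch
import torch.nn as nn

# Hypothetical logits: 2 examples, up to 3 sentence slots each, 2 classes
# ([not_relevant, relevant]). These are NOT outputs of the real model.
logits = torch.tensor([
    [[2.0, -1.0], [-0.5, 1.5], [0.0, 0.0]],   # example 0: 3 real sentences
    [[1.0, 0.0], [0.3, 0.2], [9.0, -9.0]],    # example 1: 2 real sentences + 1 padded slot
])
# Assumed labeling scheme: 1 = relevant, 0 = not relevant, -100 = padded slot
labels = torch.tensor([
    [0, 1, 0],
    [0, 1, -100],
])

# ignore_index=-100 drops padded slots from the averaged loss
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
# CrossEntropyLoss expects [N, C] logits and [N] targets, so flatten batch x sentences
loss = loss_fn(logits.view(-1, 2), labels.view(-1))
print(f"loss = {loss.item():.4f}")
```

Because the padded slot is ignored, changing its logits does not change the loss, which is exactly what padding should do.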
Chapter 2: The Architecture - Building Our Reader
Now let's add BERT, our reading comprehension engine:
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        self.init_weights()
Let me explain each component:
BERT: Reads and understands the text
Dropout: Like occasionally closing your eyes while studying - forces the model to not rely on any single feature
Classifier: Makes the final decision - relevant or not?
The hidden size (768 for BERT-base) is like the bandwidth of understanding - how much nuance BERT can capture about each token.
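A quick shape check makes these roles concrete. This sketch builds only the dropout and classifier layers with BERT-base sizes (768 hidden units, and dropout probability 0.1, which is BertConfig's default `hidden_dropout_prob`) and feeds them random stand-ins for four pooled sentence vectors. No pretrained weights are involved.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

hidden_size = 768   # BERT-base hidden size
num_labels = 2      # [not relevant, relevant]

dropout = nn.Dropout(p=0.1)  # BertConfig's default hidden_dropout_prob
classifier = nn.Linear(hidden_size, num_labels)

# Random stand-in for pooled sentence vectors: 1 document, 4 sentences
sentence_vectors = torch.randn(1, 4, hidden_size)

# Training mode: dropout zeroes ~10% of features and rescales the rest by 1/0.9
dropout.train()
train_out = dropout(sentence_vectors)

# Eval mode: dropout is the identity, so inference is deterministic
dropout.eval()
eval_out = dropout(sentence_vectors)

logits = classifier(eval_out)  # nn.Linear acts on the last dimension
print(logits.shape)  # torch.Size([1, 4, 2])
```

Each sentence ends up with two raw scores, one per label; the "closing your eyes" behavior of dropout only happens in training mode.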
Chapter 3: The Forward Pass - How Reading Happens
Here's where the magic happens. Let's build our forward method:
    def forward(self, input_ids, attention_mask, token_type_ids, sentence_ids):
        # Step 1: Read the query and document together with BERT
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        sequence_output = self.dropout(outputs[0])
Think of this like reading a document with a highlighter in hand, but not highlighting yet - just understanding everything first.
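The sentence_ids parameter is crucial - it tells us which tokens belong to which sentence, like sentence markers in a text. Tokens that belong to no document sentence (the query and special tokens such as [CLS] and [SEP]) are marked -100.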
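To make the layout concrete, here is a hypothetical `sentence_ids` row for a tiny window with a two-token query and a five-token document, following the convention used later in data preparation: query and special tokens get -100, and every document token carries its sentence number.

```python
import torch

# Hypothetical window: [CLS] q1 q2 [SEP] d1 d2 d3 d4 d5 [SEP]
# Query and special tokens get -100; document tokens carry their sentence ID.
sentence_ids = torch.tensor([-100, -100, -100, -100, 0, 0, 0, 1, 1, -100])

groups = {}
for sid in sentence_ids[sentence_ids != -100].unique().tolist():
    # Positions of every token that belongs to sentence `sid`
    groups[sid] = (sentence_ids == sid).nonzero(as_tuple=True)[0].tolist()
    print(f"sentence {sid}: token positions {groups[sid]}")
# sentence 0: token positions [4, 5, 6]
# sentence 1: token positions [7, 8]
```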
Chapter 4: The Aggregation Problem - From Words to Sentences
Here's our challenge: BERT understands individual tokens, but we need to classify entire sentences. It's like having opinions about individual words but needing to judge complete thoughts.
        def _get_agg_output(ids, seq_out):
            """Transform token-level understanding to sentence-level understanding"""
            max_sentences = torch.max(ids) + 1
            d_model = seq_out.size(-1)
            agg_out, global_offsets, num_sents = [], [], []
We're preparing to group tokens by sentence. Imagine you have a bag of colored marbles (tokens) and you need to sort them into cups (sentences).
            for i, sen_ids in enumerate(ids):
                out, local_ids = [], sen_ids.clone()
                mask = local_ids != -100  # -100 marks "ignore" tokens
                # Normalize sentence IDs to start from 0
                offset = local_ids[mask].min()
                global_offsets.append(offset)
                local_ids[mask] -= offset
                n_sent = local_ids.max() + 1
                num_sents.append(n_sent)
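The normalization step is subtle but important. A long document is split into several overlapping windows, so a later window may begin at, say, sentence 7. Subtracting the window's smallest sentence ID renumbers its sentences from 0, like resetting a counter, and the saved offset lets us translate predictions back to the document's original sentence numbers.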
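Here is the normalization applied to a hypothetical second window that begins at global sentence 3. The offset is kept so that predictions can be mapped back to the document's own sentence numbering at the end (the `+ off` in Chapter 7).

```python
import torch

# Hypothetical later window of a long document: its first document
# sentence has global ID 3. -100 marks query/special tokens.
sen_ids = torch.tensor([-100, -100, -100, 3, 3, 4, 4, 4, 5, -100])

local_ids = sen_ids.clone()
mask = local_ids != -100
offset = local_ids[mask].min()      # smallest global sentence ID in this window
local_ids[mask] -= offset           # local IDs now run 0, 1, 2
n_sent = int(local_ids.max()) + 1   # number of sentences in this window

print(local_ids.tolist())    # [-100, -100, -100, 0, 0, 1, 1, 1, 2, -100]
print(offset.item(), n_sent)  # 3 3

# After classification, adding the offset back recovers global sentence IDs
local_hits = torch.tensor([1, 2])
print((local_hits + offset).tolist())  # [4, 5]
```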
Chapter 5: The Pooling Strategy - Summarizing Sentences
Now comes a crucial decision - how do we summarize all tokens in a sentence into one representation?
                for j in range(int(n_sent)):
                    # Find all tokens in sentence j
                    sentence_mask = local_ids == j
                    # Average their representations
                    sentence_repr = seq_out[i, sentence_mask].mean(dim=-2, keepdim=True)
                    out.append(sentence_repr)
Why averaging (mean pooling)? Let me illustrate with an analogy:
Max pooling: Like judging a choir by its loudest singer
First token: Like judging a book by its first word
Mean pooling: Like listening to the whole choir's harmony
Mean pooling ensures every token contributes to our understanding of the sentence.
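Applied to one toy sentence (three tokens with made-up 2-dimensional embeddings), the three strategies give noticeably different summaries:

```python
import torch

# One sentence of 3 tokens with hypothetical 2-dim embeddings
tokens = torch.tensor([[1.0, 0.0],
                       [3.0, 0.0],
                       [2.0, 6.0]])

mean_pooled = tokens.mean(dim=-2)        # average over all tokens
max_pooled = tokens.max(dim=-2).values   # per-dimension maximum
first_token = tokens[0]                  # first token only

print(mean_pooled.tolist())   # [2.0, 2.0]
print(max_pooled.tolist())    # [3.0, 6.0]
print(first_token.tolist())   # [1.0, 0.0]
```

Max pooling stitches together values from different tokens, and the first-token summary ignores the third token's strong second feature entirely; only the mean reflects all three tokens.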
Chapter 6: The Classification - Making Decisions
With sentence representations ready, we can now classify:
        agg_output, offsets, num_sents_item = _get_agg_output(sentence_ids, sequence_output)
        # Classify each sentence
        logits = self.classifier(agg_output)
        # Convert to probabilities
        probs = torch.softmax(logits, dim=-1)[:, :, 1]
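The softmax function turns the two raw scores for each sentence into probabilities that sum to 1, and [:, :, 1] keeps the probability of the "relevant" class. For example, if a sentence's relevant score is 2.5 higher than its not-relevant score, softmax says the sentence is about 92% likely to be relevant.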
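The 92% figure can be checked directly. With two classes, softmax depends only on the difference between the two logits, so P(relevant) equals the sigmoid of (relevant logit minus not-relevant logit). The logits below are hypothetical.

```python
import torch

# Hypothetical logits for one sentence: [not_relevant, relevant]
logits = torch.tensor([0.0, 2.5])
probs = torch.softmax(logits, dim=-1)
p_relevant = probs[1].item()
print(f"{p_relevant:.3f}")  # 0.924

# Two-class softmax is a sigmoid of the logit difference
print(f"{torch.sigmoid(logits[1] - logits[0]).item():.3f}")  # 0.924
```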
Chapter 7: The Backoff Strategy - Handling Uncertainty
Here's where we add intelligence to handle edge cases:
        def _get_preds(pp, offs, num_s, threshold=0.5, alpha=0.05):
            """Smart prediction with fallback logic"""
            preds = []
            for p, off, ns in zip(pp, offs, num_s):
                rel_probs = p[:ns]
                hits = (rel_probs >= threshold).int()
                # The clever bit: if nothing is clearly relevant,
                # but something shows promise, highlight it
                if hits.sum() == 0 and rel_probs.max().item() >= alpha:
                    hits[rel_probs.argmax()] = 1
                preds.append(torch.where(hits == 1)[0] + off)
            return preds
This is like a teacher who, when no answer is perfect, still marks the best attempt. The rule reads: if no sentence reaches the 0.5 threshold, highlight the single most probable sentence, provided its probability is at least alpha=0.05. If even the best sentence falls below 5%, nothing is highlighted.
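The three possible outcomes are easiest to see side by side. The sketch below copies the rule from `_get_preds` into a standalone `get_preds` function and runs it on made-up probabilities; the fourth column plays the role of a padded sentence slot and is excluded by `num_sents`.

```python
import torch

def get_preds(pp, offs, num_s, threshold=0.5, alpha=0.05):
    """Standalone copy of the threshold-plus-backoff rule in _get_preds."""
    preds = []
    for p, off, ns in zip(pp, offs, num_s):
        rel_probs = p[:ns]                        # drop padded sentence slots
        hits = (rel_probs >= threshold).int()     # primary threshold
        if hits.sum() == 0 and rel_probs.max().item() >= alpha:
            hits[rel_probs.argmax()] = 1          # backoff: keep the best sentence
        preds.append(torch.where(hits == 1)[0] + off)
    return preds

probs = torch.tensor([
    [0.80, 0.10, 0.60, 0.00],  # two sentences clear the 0.5 threshold
    [0.20, 0.30, 0.10, 0.00],  # none clears 0.5; backoff picks the best (0.30)
    [0.01, 0.02, 0.03, 0.00],  # best is 0.03 < alpha: nothing highlighted
])
offsets = [0, 0, 0]
num_sents = [3, 3, 3]  # the fourth column is a padded slot

results = get_preds(probs, offsets, num_sents)
for row in results:
    print(row.tolist())
# [0, 2]
# [1]
# []
```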
Chapter 8: Data Preparation - Teaching the Model to Read
Now let's look at how we prepare data for our model. This is like teaching someone to read by showing them properly formatted examples:
def prepare_input_features(tokenizer, examples, max_seq_length=510, stride=128, padding=False):
    """Prepare text for BERT processing"""
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=max_seq_length,
        stride=stride,
        return_overflowing_tokens=True,
        padding=padding,
        is_split_into_words=True,
    )
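Key insights:
max_seq_length=510: Caps each window at 510 tokens. The tokenizer counts [CLS] and [SEP] toward this limit, so the value simply keeps every window safely under BERT's 512-position maximum.
stride=128: Consecutive windows overlap by 128 tokens, so text near a window boundary also appears, with more surrounding context, in the neighboring window.
truncation="only_second": Keep the full query in every window and split only the document.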
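A simplified model of how the windows fall helps when choosing these numbers. This is not the tokenizer's code: `doc_windows` is a hypothetical helper that assumes each window holds a fixed budget of document tokens (the 510 cap minus the query and the three special tokens [CLS], [SEP], [SEP]) and that consecutive windows share `stride` tokens. The 12-token query and 1,200-token document are made-up numbers.

```python
def doc_windows(n_doc_tokens, budget, stride):
    """Hypothetical model of overflow windows: each holds up to `budget`
    document tokens, and consecutive windows share `stride` tokens."""
    if stride >= budget:
        raise ValueError("stride must be smaller than the per-window budget")
    windows, start = [], 0
    while True:
        end = min(start + budget, n_doc_tokens)
        windows.append((start, end))   # [start, end) document-token span
        if end == n_doc_tokens:
            break
        start = end - stride           # step back to create the overlap
    return windows

# Assumed numbers: 510-token cap, a 12-token query, 3 special tokens
budget = 510 - 12 - 3   # 495 document tokens per window
print(doc_windows(1200, budget, stride=128))
# [(0, 495), (367, 862), (734, 1200)]
```

Under these assumptions a 1,200-token document needs three windows, and each boundary region (tokens 367-495 and 734-862) is seen twice. A larger stride buys more overlap at the cost of more windows.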
Chapter 9: Sentence Tracking - The Clever Bookkeeping
The trickiest part is tracking which tokens belong to which sentences after tokenization:
    # Which original example each window came from
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    for i, sample_index in enumerate(sample_mapping):
        word_ids = tokenized_examples.word_ids(i)
        word_level_sentence_ids = examples["word_level_sentence_ids"][sample_index]
        sequence_ids = tokenized_examples.sequence_ids(i)  # 0 = query, 1 = document, None = special
        # Find where document starts (after query)
        token_start_index = 0
        while sequence_ids[token_start_index] != 1:
            token_start_index += 1
        # Mark query tokens as -100 (ignore)
        sentences_ids = [-100] * token_start_index
        # Map document tokens to sentences
        for word_idx in word_ids[token_start_index:]:
            if word_idx is not None:
                sentences_ids.append(word_level_sentence_ids[word_idx])
            else:
                sentences_ids.append(-100)
This bookkeeping ensures that when BERT splits "don't" into ["don", "'", "t"], we still know all three pieces belong to the same sentence: word_ids() gives all three pieces the same word index, and that word maps to one sentence ID.
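The mapping is easy to trace by hand. Below, `word_ids` and `sequence_ids` are hypothetical values of the kind the tokenizer's `word_ids()` and `sequence_ids()` return for a tiny window whose query is "when?" and whose document is "don't panic. it works." (two sentences, four whitespace-separated words). The loop is the same logic as the snippet above.

```python
# Hypothetical tokens: [CLS] when ? [SEP] don ' t panic . it works . [SEP]
# word_ids: index of the source word (None for special tokens)
word_ids     = [None, 0, 0, None, 0, 0, 0, 1, 1, 2, 3, 3, None]
# sequence_ids: 0 = query, 1 = document, None = special token
sequence_ids = [None, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, 1, None]
# Document words: "don't" "panic." -> sentence 0, "it" "works." -> sentence 1
word_level_sentence_ids = [0, 0, 1, 1]

# Skip past [CLS], the query, and the first [SEP]
token_start_index = 0
while sequence_ids[token_start_index] != 1:
    token_start_index += 1

sentence_ids = [-100] * token_start_index
for word_idx in word_ids[token_start_index:]:
    sentence_ids.append(-100 if word_idx is None else word_level_sentence_ids[word_idx])

print(sentence_ids)
# [-100, -100, -100, -100, 0, 0, 0, 0, 0, 1, 1, 1, -100]
```

All three pieces of "don't" (positions 4 to 6) land in sentence 0, and the closing [SEP] is ignored.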
Chapter 10: Putting It All Together - The Complete System
Let's see our highlighter in action:
# Real example
import nltk

query = "When does OpenSearch use text reanalysis for highlighting?"
document = "..."  # Your document here
# Prepare the data: split into sentences, then tag every word with its sentence ID
doc_sents = nltk.sent_tokenize(document)
sentence_ids, context = [], []
for sid, sent in enumerate(doc_sents):
    words = sent.split()
    context.extend(words)
    sentence_ids.extend([sid] * len(words))
# Create model and process
model = BertTaggerForSentenceExtractionWithBackoff.from_pretrained(
    "opensearch-project/opensearch-semantic-highlighter-v1"
)
# Tokenize with prepare_input_features and collate into `batch` with
# DataCollatorWithPadding (the appendix shows the full pipeline)
# Get highlights: one tensor of sentence indices per window
highlights = model(batch["input_ids"], batch["attention_mask"],
                   batch["token_type_ids"], batch["sentence_ids"])
highlighted_sentences = [doc_sents[i] for i in highlights[0]]
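One detail matters for long documents: the model returns one tensor of sentence indices per window, and the snippet above only reads `highlights[0]`. (The appendix likewise takes only the first batch from a `batch_size=1` loader, i.e. the first window.) If a document spans several windows, you would run every batch and combine the results. `merge_highlights` below is a hypothetical helper, not part of OpenSearch's code, that takes the union of the per-window indices; the tensors in the usage example are made up.

```python
import torch

def merge_highlights(highlights, doc_sents):
    """Hypothetical helper: combine per-window predictions into one
    ordered, de-duplicated list of sentences.

    `highlights` holds one tensor of global sentence indices per window.
    Overlapping windows can flag the same sentence twice, so take the union.
    """
    indices = sorted({int(i) for window in highlights for i in window})
    return [doc_sents[i] for i in indices]

# Made-up output for a document split into two overlapping windows
doc_sents = ["S0.", "S1.", "S2.", "S3."]
highlights = (torch.tensor([1, 2]), torch.tensor([2, 3]))
print(merge_highlights(highlights, doc_sents))  # ['S1.', 'S2.', 'S3.']
```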
The Beautiful Simplicity
What we've built is remarkably elegant:
One forward pass per window scores every sentence in that window
Context-aware understanding of every sentence
Intelligent fallbacks for edge cases
Overlapping, batchable windows for documents longer than BERT's 512-token limit
The model doesn't just match keywords - it models the relationship between questions and answers. When you search for "symptoms of dehydration," a model like this can recognize that "feeling dizzy and tired" is relevant even though the word "symptom" never appears.
Final Thoughts: The Power of Hierarchical Understanding
Our semantic highlighter represents a fundamental principle in modern NLP: hierarchical processing mirrors human understanding. We don't read words in isolation, and neither should our models.
By building understanding from tokens → sentences → relevance decisions, we've created a system that identifies which parts of a document answer a user's question. It's not just highlighting - it's understanding made visible.
The next time you see highlighted search results that actually show you what you're looking for, remember: behind that simple yellow highlight is a sophisticated system that learned, from labeled examples, to read a query and a document together and pick out the sentences that answer it.
This implementation is based on OpenSearch's semantic highlighter. The code examples are simplified for clarity while maintaining technical accuracy. For production use, refer to the official OpenSearch documentation.
Appendix: Full Annotated Code
import nltk
import torch
import numpy as np
from datasets import Dataset
from functools import partial
from torch.utils.data import DataLoader
from dataclasses import dataclass, field
from typing import Any, Dict, List, Union
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, BertModel, BertPreTrainedModel
import torch.nn as nn
class BertTaggerForSentenceExtractionWithBackoff(BertPreTrainedModel):
    """
    ML Architecture: Hierarchical Classification Model
    This model performs sentence-level classification by:
    1. Encoding query-document pairs at token level (using BERT)
    2. Aggregating token representations to sentence representations
    3. Classifying each sentence as relevant/irrelevant
    4. Applying confidence-based backoff so that at least one sentence is
       highlighted whenever any sentence is reasonably confident
    Key ML insight: Rather than treating this as N independent classifications,
    it processes the entire document with query context, preserving
    cross-sentence attention patterns.
    """
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels  # Typically 2: [not_relevant, relevant]
        # Pre-trained BERT encoder - provides contextualized token embeddings
        # Design choice: Using full BERT allows tokens to attend to the entire
        # query-document context, not just within-sentence context
        self.bert = BertModel(config)
        # Dropout for regularization - reduces overfitting to training data
        # Applied before classification to add noise during training
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Linear projection: hidden_size (768 for BERT-base) → num_labels (2)
        # This is the actual "classification head" that learns to map
        # sentence representations to relevance scores
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        # Initialize weights using BERT's initialization scheme
        # (typically normal distribution with std=0.02)
        self.init_weights()
    def forward(
        self,
        input_ids=None,  # Token IDs: [batch_size, seq_len]
        attention_mask=None,  # Attention mask: 1 for real tokens, 0 for padding
        token_type_ids=None,  # Segment IDs: 0 for query, 1 for document
        sentence_ids=None,  # Custom input: maps each token to its sentence ID
    ):
        """
        Forward pass implements the hierarchical classification:
        Token embeddings → Sentence embeddings → Sentence classifications
        """
        # Step 1: Get contextualized token embeddings from BERT
        # outputs[0] shape: [batch_size, seq_len, hidden_size]
        # Each token has a 768-dim representation that considers all other tokens
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Extract the last hidden states and apply dropout
        # ML insight: Dropout during training adds noise, improving generalization
        sequence_output = self.dropout(outputs[0])
        def _get_agg_output(ids, seq_out):
            """
            Aggregation Strategy: Token → Sentence Representations
            ML Rationale: Sentences are variable-length sequences. We need a
            fixed-size representation for each sentence for classification.
            Mean pooling is chosen over alternatives like:
            - Max pooling: Would lose information, keeping only the strongest activation per dimension
            - First/Last token: Would be position-dependent
            - Attention-weighted: Would add complexity, mean pooling often sufficient
            """
            # Find max sentence ID to determine tensor dimensions
            max_sentences = torch.max(ids) + 1
            d_model = seq_out.size(-1)  # Hidden dimension (768 for BERT-base)
            agg_out, global_offsets, num_sents = [], [], []
            # Process each example in the batch
            for i, sen_ids in enumerate(ids):
                out, local_ids = [], sen_ids.clone()
                # -100 is a special padding value (from HuggingFace convention)
                # Here it marks tokens excluded from sentence aggregation
                # (query, special, and padding tokens)
                mask = local_ids != -100
                # Normalize sentence IDs to start from 0 for this example
                # (a later sliding window can begin partway through the document)
                offset = local_ids[mask].min()
                global_offsets.append(offset)
                local_ids[mask] -= offset
                # Count actual sentences in this example
                n_sent = local_ids.max() + 1
                num_sents.append(n_sent)
                # Aggregate tokens by sentence using mean pooling
                # ML insight: Mean pooling preserves information from all tokens
                # while providing a fixed-size representation
                for j in range(int(n_sent)):
                    # Get all tokens belonging to sentence j
                    sentence_mask = local_ids == j
                    # Mean pool over the token dimension (dim=-2)
                    # keepdim=True maintains shape for concatenation
                    sentence_repr = seq_out[i, sentence_mask].mean(dim=-2, keepdim=True)
                    out.append(sentence_repr)
                # Padding: Ensures all examples have the same number of sentences
                # so they can be stacked into one batch tensor
                # Padded sentences are excluded later (rel_probs = p[:ns])
                if max_sentences - n_sent:
                    padding = torch.zeros(
                        (int(max_sentences - n_sent), d_model),
                        device=seq_out.device,
                        dtype=seq_out.dtype,  # match BERT's output dtype (e.g. fp16)
                    )
                    out.append(padding)
                # Concatenate all sentence representations
                agg_out.append(torch.cat(out, dim=0))
            # Stack into a batch tensor: [batch_size, max_sentences, hidden_size]
            return torch.stack(agg_out), global_offsets, num_sents
        # Perform the aggregation
        agg_output, offsets, num_sents_item = _get_agg_output(sentence_ids, sequence_output)
        # Step 2: Classify each sentence
        # logits shape: [batch_size, max_sentences, num_labels]
        logits = self.classifier(agg_output)
        # Step 3: Convert logits to probabilities
        # Softmax ensures probabilities sum to 1 across classes
        # We extract [:, :, 1] to get P(relevant|sentence)
        probs = torch.softmax(logits, dim=-1)[:, :, 1]
        def _get_preds(pp, offs, num_s, threshold=0.5, alpha=0.05):
            """
            Confidence-Based Backoff Strategy
            ML Rationale: Pure threshold-based classification might result in
            no sentences being highlighted, leading to poor user experience.
            The backoff rule ensures at least one sentence is highlighted if
            ANY sentence shows reasonable confidence (≥ alpha).
            Parameters:
            - threshold: Primary classification threshold (default 0.5)
            - alpha: Minimum confidence for backoff (default 0.05)
            This is a post-processing decision rule (the probabilities
            themselves are not recalibrated) that trades some precision for recall.
            """
            preds = []
            for p, off, ns in zip(pp, offs, num_s):
                # Get probabilities for actual sentences (exclude padding)
                rel_probs = p[:ns]
                # Primary classification: threshold at 0.5
                hits = (rel_probs >= threshold).int()
                # Backoff mechanism: If no sentences pass threshold
                # but at least one has confidence ≥ alpha (0.05),
                # select the highest confidence sentence
                if hits.sum() == 0 and rel_probs.max().item() >= alpha:
                    hits[rel_probs.argmax()] = 1
                # Convert local sentence indices back to global indices
                # and return positions of relevant sentences
                preds.append(torch.where(hits == 1)[0] + off)
            return preds
        # Apply the prediction logic with backoff
        return tuple(_get_preds(probs, offsets, num_sents_item))
# ==================== DATA PROCESSING COMPONENTS ====================
@dataclass
class DataCollatorWithPadding:
    """
    Custom Data Collator for Batch Processing
    ML Purpose: Neural networks require fixed-size tensors for batch processing.
    This collator handles variable-length sequences by padding them to the same length.
    Design Choice: Using -100 for sentence_ids padding aligns with HuggingFace's
    convention for ignored tokens in loss computation.
    """
    pad_kvs: Dict[str, Union[int, float]] = field(default_factory=dict)

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        first = features[0]
        batch = {}
        # Pad sequences to max length in batch
        # pad_sequence right-pads each sequence to the longest one in the batch
        for key, pad_value in self.pad_kvs.items():
            if key in first and first[key] is not None:
                batch[key] = pad_sequence(
                    [torch.tensor(f[key]) for f in features],
                    batch_first=True,  # Standard for transformer models
                    padding_value=pad_value,
                )
        # Stack other tensors (assumes they're already same size)
        for k, v in first.items():
            if k not in self.pad_kvs and v is not None and isinstance(v, torch.Tensor):
                batch[k] = torch.stack([f[k] for f in features])
        return batch
def prepare_input_features(
    tokenizer, examples, max_seq_length=510, stride=128, padding=False
):
    """
    Feature Engineering for Question-Document Pairs
    ML Insights:
    1. max_seq_length=510: Keeps each window under BERT's 512-position limit
       (the tokenizer counts [CLS] and [SEP] toward max_length)
    2. stride=128: Sliding window approach for long documents
       - Consecutive windows overlap by 128 tokens, preserving context at boundaries
       - Trade-off: More computation vs. better coverage
    3. Sentence ID mapping: Critical for aggregation layer
    """
    # Tokenize with sliding window for long documents
    # truncation="only_second": Keep full query, split only the document
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=max_seq_length,
        stride=stride,  # Overlap between chunks
        return_overflowing_tokens=True,  # Create multiple examples if needed
        padding=padding,
        is_split_into_words=True,  # Input is already word-tokenized
    )
    # Track which original example each tokenized example came from
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    tokenized_examples["example_id"] = []
    tokenized_examples["word_ids"] = []
    tokenized_examples["sentence_ids"] = []
    # Map tokens back to sentences
    for i, sample_index in enumerate(sample_mapping):
        word_ids = tokenized_examples.word_ids(i)
        word_level_sentence_ids = examples["word_level_sentence_ids"][sample_index]
        sequence_ids = tokenized_examples.sequence_ids(i)
        # Find where document tokens start (after query tokens)
        token_start_index = 0
        while sequence_ids[token_start_index] != 1:  # 1 indicates document tokens
            token_start_index += 1
        # -100 for query tokens (won't be aggregated into sentences)
        sentences_ids = [-100] * token_start_index
        # Map each document token to its sentence
        for word_idx in word_ids[token_start_index:]:
            if word_idx is not None:
                sentences_ids.append(word_level_sentence_ids[word_idx])
            else:
                # Special tokens ([SEP], padding) have no word index;
                # sub-word pieces share their word's index and take the branch above
                sentences_ids.append(-100)
        tokenized_examples["sentence_ids"].append(sentences_ids)
        tokenized_examples["example_id"].append(examples["id"][sample_index])
        tokenized_examples["word_ids"].append(word_ids)
    # Ensure we don't exceed BERT's positional embedding limit
    # ML Note: BERT uses learned positional embeddings up to position 512
    for key in ("input_ids", "token_type_ids", "attention_mask", "sentence_ids"):
        tokenized_examples[key] = [seq[:max_seq_length] for seq in tokenized_examples[key]]
    return tokenized_examples
# ==================== USAGE EXAMPLE WITH ML ANNOTATIONS ====================
# Example demonstrates the full ML pipeline
query = "When does OpenSearch use text reanalysis for highlighting?"
document = """To highlight the search terms, the highlighter needs the start and end character offsets of each term. The offsets mark the term's position in the original text. The highlighter can obtain the offsets from the following sources: Postings: When documents are indexed, OpenSearch creates an inverted search index—a core data structure used to search for documents. Postings represent the inverted search index and store the mapping of each analyzed term to the list of documents in which it occurs. If you set the index_options parameter to offsets when mapping a text field, OpenSearch adds each term's start and end character offsets to the inverted index. During highlighting, the highlighter reruns the original query directly on the postings to locate each term. Thus, storing offsets makes highlighting more efficient for large fields because it does not require reanalyzing the text. Storing term offsets requires additional disk space, but uses less disk space than storing term vectors. Text reanalysis: In the absence of both postings and term vectors, the highlighter reanalyzes text in order to highlight it. For every document and every field that needs highlighting, the highlighter creates a small in-memory index and reruns the original query through Lucene's query execution planner to access low-level match information for the current document. Reanalyzing the text works well in most use cases. However, this method is more memory and time intensive for large fields."""
# Sentence segmentation - Critical for defining classification targets
# (requires NLTK's Punkt models: nltk.download("punkt"); recent NLTK releases use "punkt_tab")
doc_sents = nltk.sent_tokenize(document)
# Create word-level sentence mappings
# ML Note: Sentence IDs are assigned per whitespace-separated word; the tokenizer's
# word_ids() later propagates each word's ID to all of its sub-word tokens
sentence_ids, context = [], []
for sid, sent in enumerate(doc_sents):
    words = sent.split()
    context.extend(words)
    sentence_ids.extend([sid] * len(words))
# Create dataset in HuggingFace format
example_dataset = Dataset.from_dict(
    {
        # is_split_into_words=True expects each question as a list of words;
        # here the whole query is passed as a single entry
        "question": [[query]],
        "context": [context],
        "word_level_sentence_ids": [sentence_ids],
        "id": [0],
    }
)
# Initialize tokenizer - BERT uses WordPiece tokenization
# ML insight: Subword tokenization handles OOV words better than word-level
base_model_id = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
# Configure padding strategy
collator = DataCollatorWithPadding(
    pad_kvs={
        "input_ids": 0,  # [PAD] token ID
        "token_type_ids": 0,  # Padding belongs to first segment
        "attention_mask": 0,  # Don't attend to padding
        "sentence_ids": -100,  # Ignore in sentence aggregation
        "sentence_labels": -100,  # Ignore in loss computation
    }
)
# Feature extraction function
preprocess_fn = partial(prepare_input_features, tokenizer)
# Apply feature engineering
example_dataset = example_dataset.map(
    preprocess_fn,
    batched=True,  # Process multiple examples at once
    remove_columns=example_dataset.column_names,  # Remove raw features
    desc="Preparing model inputs",
)
# Create DataLoader for batch processing
# ML Note: Even with batch_size=1, using DataLoader ensures consistent tensor formatting
loader = DataLoader(example_dataset, batch_size=1, collate_fn=collator)
# Get single batch for inference
# With batch_size=1 this is the first window; longer documents produce more windows
batch = next(iter(loader))
# Load pre-trained model
# ML Note: This model has been fine-tuned on query-document-relevance data
model = BertTaggerForSentenceExtractionWithBackoff.from_pretrained(
    "opensearch-project/opensearch-semantic-highlighter-v1"
)
# Handle variable sequence lengths - Critical for BERT's positional embeddings
max_len = model.config.max_position_embeddings  # 512 for BERT
for key in ("input_ids", "token_type_ids", "attention_mask", "sentence_ids"):
    batch[key] = batch[key][:, :max_len]
# Run inference (no gradients needed)
# Returns indices of relevant sentences
with torch.no_grad():
    highlights = model(
        batch["input_ids"],
        batch["attention_mask"],
        batch["token_type_ids"],
        batch["sentence_ids"],
    )
# Extract highlighted sentences
highlighted_sentences = [doc_sents[i] for i in highlights[0]]
print(highlighted_sentences)
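To see what the collator's `pad_kvs` settings produce, the sketch below applies the same `pad_sequence` call by hand to two hypothetical windows of unequal length. Apart from [CLS] = 101 and [SEP] = 102 in bert-base-uncased's vocabulary, the token IDs are illustrative.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two hypothetical windows: 6 tokens and 5 tokens
features = [
    {"input_ids": [101, 2054, 102, 2009, 2573, 102],
     "attention_mask": [1, 1, 1, 1, 1, 1],
     "sentence_ids": [-100, -100, -100, 0, 0, -100]},
    {"input_ids": [101, 2054, 102, 2009, 102],
     "attention_mask": [1, 1, 1, 1, 1],
     "sentence_ids": [-100, -100, -100, 1, -100]},
]
pad_kvs = {"input_ids": 0, "attention_mask": 0, "sentence_ids": -100}

# Right-pad every key to the longest window, each with its own pad value
batch = {
    key: pad_sequence(
        [torch.tensor(f[key]) for f in features],
        batch_first=True,
        padding_value=pad_value,
    )
    for key, pad_value in pad_kvs.items()
}
for key, tensor in batch.items():
    print(key, tensor[1].tolist())
# input_ids [101, 2054, 102, 2009, 102, 0]
# attention_mask [1, 1, 1, 1, 1, 0]
# sentence_ids [-100, -100, -100, 1, -100, -100]
```

The shorter window's extra position gets 0 in `input_ids` and `attention_mask`, so BERT ignores it, and -100 in `sentence_ids`, so `_get_agg_output` leaves it out of every sentence.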


