-
Notifications
You must be signed in to change notification settings - Fork 109
feat: Proposed SIMBAUQ Sampling Strategy #785
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
radum2275
wants to merge
23
commits into
generative-computing:main
Choose a base branch
from
radum2275:feat/simba_uq
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 15 commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
c5236f0
feat: initial commit for the SIMBAUQSamplingStrategy
ea51043
chore: added a separate filed to mot.meta for the similarity matrix
5c23a58
chore: added a second aggregation by classification CE algorithm
d7f3b6a
refactor: revised and moved the SIMBAUQSamplingStrategy in docs/examples
908258c
Update test/stdlib/sampling/test_simbauq.py
radum2275 8b8c336
Update docs/examples/simbauq/simbauq_example.py
radum2275 865e85f
Update .gitignore
radum2275 a6b356a
Update docs/examples/simbauq/README.md
radum2275 cbae30c
Update docs/examples/simbauq/README.md
radum2275 a3c51a8
Update mellea/stdlib/sampling/simbauq.py
radum2275 e9b05f1
Update mellea/stdlib/sampling/simbauq.py
radum2275 372046a
Update mellea/stdlib/sampling/simbauq.py
radum2275 af55899
refactor: refactored the simbauq sampling strategy
da1440d
fix: added the ollama backend in simbauq example
11b180f
chore: set aggregation by mean in simbauq example
6c6c099
chore: fixed a typo in the simbauq README.md file
78fe6c7
chore: added scikit-learn as required dependency for simbauq strategy
65a1268
Update test/stdlib/sampling/test_simbauq.py
radum2275 41728a5
Update test/stdlib/sampling/test_simbauq.py
radum2275 f90a466
Update mellea/stdlib/sampling/simbauq.py
radum2275 1cd588c
Update mellea/stdlib/sampling/simbauq.py
radum2275 c8bd228
Update mellea/stdlib/sampling/simbauq.py
radum2275 e0b5952
chore: revised the dependencies for simbauq strategy
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,153 @@ | ||
| # SIMBA-UQ Sampling Strategy | ||
|
|
||
| Confidence-aware sample selection using the SIMBA-UQ framework | ||
| (Bhattacharjya et al., 2025). Generates multiple samples across a range of | ||
| temperatures and selects the one with the highest estimated confidence. | ||
|
|
||
| **Paper:** [SIMBA UQ: Similarity-Based Aggregation for Uncertainty Quantification in Large Language Models](https://arxiv.org/abs/2510.13836) | ||
|
|
||
| ## Files | ||
|
|
||
| ### simbauq_example.py | ||
|
|
||
| Complete example demonstrating both confidence estimation methods with | ||
| a RITSBackend and granite-4.0-micro. | ||
|
|
||
| ## Architecture | ||
|
|
||
| ``` | ||
| User Query | ||
| | | ||
| v | ||
| Generate N samples (across temperatures) | ||
| | | ||
| v | ||
| Compute pairwise similarity matrix (N x N) | ||
| | | ||
| +---> [Aggregation] Aggregate similarities per sample -> confidence | ||
| | | ||
| +---> [Classifier] Extract features per sample -> RF predicts P(correct) | ||
| | | ||
| v | ||
| Select sample with highest confidence | ||
| | | ||
| v | ||
| Result (with confidence metadata in mot.meta["simba_uq"]) | ||
| ``` | ||
|
|
||
| ## Confidence Methods | ||
|
|
||
| ### 1. Aggregation (data-free) | ||
|
|
||
| No training data required. For each sample, computes its similarity to every | ||
| other sample, then aggregates those values into a confidence score. Samples | ||
| that are more similar to the majority get higher confidence. | ||
|
|
||
| ```python | ||
| from mellea.stdlib.sampling.simbauq import SIMBAUQSamplingStrategy | ||
|
|
||
| strategy = SIMBAUQSamplingStrategy( | ||
| temperatures=[0.3, 0.5, 0.7, 1.0], | ||
| n_per_temp=3, | ||
| similarity_metric="rouge", | ||
| confidence_method="aggregation", | ||
| aggregation="mean", | ||
| ) | ||
|
|
||
| result = m.instruct("Your query here", strategy=strategy, return_sampling_results=True) | ||
| ``` | ||
|
|
||
| ### 2. Classifier (trained) | ||
|
|
||
| Uses a random forest classifier trained on labeled examples. The classifier | ||
| learns to predict P(correct) from pairwise similarity features. Provide | ||
| either training data or a pre-trained sklearn classifier. | ||
|
|
||
| **With training data:** | ||
|
|
||
| ```python | ||
| strategy = SIMBAUQSamplingStrategy( | ||
| temperatures=[0.3, 0.5, 0.7, 1.0], | ||
| n_per_temp=3, | ||
| similarity_metric="rouge", | ||
| confidence_method="classifier", | ||
| training_samples=[ | ||
| ["correct answer 1", "correct answer 2", ..., "wrong answer"], # group 1 | ||
| ["correct answer 1", "correct answer 2", ..., "wrong answer"], # group 2 | ||
| ], | ||
| training_labels=[ | ||
| [1, 1, ..., 0], # labels for group 1 | ||
| [1, 1, ..., 0], # labels for group 2 | ||
| ], | ||
| ) | ||
| ``` | ||
|
|
||
| Each training group must have exactly `len(temperatures) * n_per_temp` samples | ||
| so the feature vectors match at inference time. | ||
|
|
||
| **With pre-trained classifier:** | ||
|
|
||
| ```python | ||
| strategy = SIMBAUQSamplingStrategy( | ||
| temperatures=[0.3, 0.5, 0.7, 1.0], | ||
| n_per_temp=3, | ||
| confidence_method="classifier", | ||
| classifier=my_pretrained_sklearn_clf, | ||
| ) | ||
| ``` | ||
|
|
||
| ## Constructor Parameters | ||
|
|
||
| | Parameter | Type | Default | Description | | ||
| |-----------|------|---------|-------------| | ||
| | `temperatures` | `list[float]` | `[0.3, 0.5, 0.7, 1.0]` | Temperature values to sample at | | ||
| | `n_per_temp` | `int` | `4` | Number of samples per temperature | | ||
| | `similarity_metric` | `"rouge"`, `"jaccard"`, `"sbert"` | `"rouge"` | Pairwise similarity metric | | ||
| | `confidence_method` | `"aggregation"`, `"classifier"` | `"aggregation"` | Confidence estimation method | | ||
| | `aggregation` | `"mean"`, `"geometric_mean"`, `"harmonic_mean"`, `"median"`, `"max"`, `"min"` | `"mean"` | Aggregation function (for `aggregation` method) | | ||
| | `classifier` | sklearn classifier | `None` | Pre-trained classifier with `predict_proba` | | ||
| | `training_samples` | `list[list[str]]` | `None` | Training data for classifier | | ||
| | `training_labels` | `list[list[int]]` | `None` | Binary correctness labels (0/1) | | ||
| | `clf_max_depth` | `int` | `4` | Max tree depth for random forest | | ||
| | `rouge_type` | `str` | `"rougeL"` | Rouge variant | | ||
| | `sbert_model` | `str` | `"all-MiniLM-L6-v2"` | Sentence-BERT model name | | ||
| | `requirements` | `list[Requirement]` | `None` | Requirements to validate the selected sample | | ||
|
|
||
| ## Similarity Metrics | ||
|
|
||
| - **rouge** (default): RougeL F-measure. Good general-purpose text similarity. | ||
| No extra dependencies beyond `rouge-score` (already in mellea). | ||
| - **jaccard**: Word-level set overlap (intersection / union). Fast, no | ||
| external dependencies, works well for short structured answers. | ||
| - **sbert**: Cosine similarity of Sentence-BERT embeddings. Best semantic | ||
| similarity but requires `sentence-transformers` (`pip install | ||
| mellea[granite_retriever]`). | ||
|
|
||
| ## Inspecting Results | ||
|
|
||
| The selected sample's `ModelOutputThunk` stores confidence metadata: | ||
|
|
||
| ```python | ||
| result = m.instruct(..., strategy=strategy, return_sampling_results=True) | ||
|
|
||
| # Best sample | ||
| best_mot = result.result | ||
| meta = best_mot._meta["simba_uq"] | ||
|
|
||
| meta["confidence"] # float: confidence of the selected sample | ||
| meta["all_confidences"] # list[float]: confidence for every sample | ||
| meta["similarity_matrix"] # list[list[float]]: N x N pairwise similarity matrix | ||
| meta["temperatures_used"] # list[float]: temperature used for each sample | ||
| meta["confidence_method"] # "aggregation" or "classifier" | ||
| meta["similarity_metric"] # "rouge", "jaccard", or "sbert" | ||
| meta["aggregation"] # aggregation function name | ||
|
|
||
| # All generated samples | ||
| for i, mot in enumerate(result.sample_generations): | ||
| print(f"Sample {i}: {mot.value}") | ||
| ``` | ||
|
|
||
| ## Related Files | ||
|
|
||
| - `mellea/stdlib/sampling/simbauq.py` -- Strategy implementation | ||
| - `test/stdlib/sampling/test_simbauq.py` -- Unit and integration tests | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,202 @@ | ||
| # pytest: ollama, llm, qualitative | ||
|
|
||
| """SIMBA-UQ Sampling Strategy Example. | ||
|
|
||
| This example demonstrates the SIMBAUQSamplingStrategy using both confidence | ||
| estimation methods: | ||
|
|
||
| 1. **Aggregation** (data-free) - Computes pairwise similarity between all | ||
| generated samples and aggregates them into per-sample confidence scores. | ||
| The sample with the highest confidence is selected. | ||
|
|
||
| 2. **Classifier** (trained) - Uses a random forest classifier trained on | ||
| labeled examples to predict P(correct) for each sample based on its | ||
| pairwise similarity features. | ||
|
|
||
| Both methods generate multiple samples across different temperature values, | ||
| compute a similarity matrix, and select the most confident response. | ||
|
|
||
| The example uses OllamaModelBackend with granite4:micro. To run: | ||
|
|
||
| ollama serve | ||
| uv run python docs/examples/simbauq/simbauq_example.py | ||
| """ | ||
|
|
||
| import numpy as np | ||
|
|
||
| from mellea import MelleaSession | ||
| from mellea.backends import ModelOption | ||
| from mellea.backends.ollama import OllamaModelBackend | ||
| from mellea.core import SamplingResult | ||
| from mellea.stdlib.context import ChatContext | ||
| from mellea.stdlib.sampling.simbauq import SIMBAUQSamplingStrategy | ||
|
|
||
|
|
||
| def make_session() -> MelleaSession: | ||
| """Create a MelleaSession with OllamaModelBackend.""" | ||
| backend = OllamaModelBackend(model_options={ModelOption.MAX_NEW_TOKENS: 100}) | ||
| return MelleaSession(backend, ctx=ChatContext()) | ||
|
|
||
|
|
||
| def print_results(result: SamplingResult) -> None: | ||
| """Print detailed results from a SIMBA-UQ sampling run.""" | ||
| meta = result.result._meta["simba_uq"] | ||
| confidences = meta["all_confidences"] | ||
| temperatures = meta["temperatures_used"] | ||
| sim_matrix = np.array(meta["similarity_matrix"]) | ||
|
|
||
| # --- Best response --- | ||
| print("=" * 70) | ||
| print("BEST RESPONSE") | ||
| print("=" * 70) | ||
| print(f" Index: {result.result_index}") | ||
| print(f" Confidence: {meta['confidence']:.4f}") | ||
| print(f" Method: {meta['confidence_method']}") | ||
| print(f" Metric: {meta['similarity_metric']}") | ||
| print(f" Aggregation: {meta['aggregation']}") | ||
| print(f" Text:\n {result.result!s}") | ||
| print() | ||
|
|
||
| # --- All samples --- | ||
| print("=" * 70) | ||
| print("ALL SAMPLES") | ||
| print("=" * 70) | ||
| print(f"{'Idx':>4} {'Temp':>5} {'Conf':>8} {'Text'}") | ||
| print("-" * 70) | ||
| for i, mot in enumerate(result.sample_generations): | ||
| text = str(mot).replace("\n", " ") | ||
| truncated = (text[:100] + "...") if len(text) > 100 else text | ||
| marker = " <-- best" if i == result.result_index else "" | ||
| print( | ||
| f"{i:>4} {temperatures[i]:>5.2f} {confidences[i]:>8.4f} " | ||
| f"{truncated}{marker}" | ||
| ) | ||
| print() | ||
|
|
||
| # --- Similarity matrix --- | ||
| n = sim_matrix.shape[0] | ||
| print("=" * 70) | ||
| print("SIMILARITY MATRIX") | ||
| print("=" * 70) | ||
| header = " " + "".join(f" [{i:>2}] " for i in range(n)) | ||
| print(header) | ||
| for i in range(n): | ||
| row = f"[{i:>2}] " + "".join(f" {sim_matrix[i, j]:.3f} " for j in range(n)) | ||
| print(row) | ||
| print() | ||
|
|
||
|
|
||
| def run_aggregation_example() -> None: | ||
| """Run SIMBA-UQ with data-free similarity aggregation.""" | ||
| print("\n>>> AGGREGATION CONFIDENCE METHOD <<<\n") | ||
|
|
||
| m = make_session() | ||
|
|
||
| strategy = SIMBAUQSamplingStrategy( | ||
| temperatures=[0.3, 0.5, 0.7, 1.0], | ||
| n_per_temp=3, | ||
| similarity_metric="rouge", | ||
| confidence_method="aggregation", | ||
| aggregation="mean", | ||
| ) | ||
|
|
||
| result: SamplingResult = m.instruct( | ||
| "Which magazine was started first Arthur's Magazine or First for Women?", | ||
| strategy=strategy, | ||
| return_sampling_results=True, | ||
| ) | ||
|
|
||
| print(f"Total samples generated: {len(result.sample_generations)}") | ||
| print_results(result) | ||
|
|
||
| del m | ||
|
|
||
|
|
||
| def run_classifier_example() -> None: | ||
| """Run SIMBA-UQ with a trained random forest classifier.""" | ||
| print("\n>>> CLASSIFIER CONFIDENCE METHOD <<<\n") | ||
|
|
||
| m = make_session() | ||
|
|
||
| # Synthetic training data: 3 groups of 12 samples (4 temps * 3 per temp). | ||
| # Each group has mostly "correct" similar answers and a few outliers. | ||
| training_samples = [ | ||
| [ | ||
| "Paris is the capital of France.", | ||
| "The capital of France is Paris.", | ||
| "France's capital city is Paris.", | ||
| "Paris, the capital of France.", | ||
| "The capital city of France is Paris.", | ||
| "France has Paris as its capital.", | ||
| "Paris serves as France's capital.", | ||
| "In France, Paris is the capital.", | ||
| "The French capital is Paris.", | ||
| "Bananas are a yellow fruit.", | ||
| "Dogs are loyal pets.", | ||
| "The ocean is very deep.", | ||
| ], | ||
| [ | ||
| "Water boils at 100 degrees Celsius.", | ||
| "At 100C water reaches boiling point.", | ||
| "The boiling point of water is 100 degrees.", | ||
| "Water boils when heated to 100C.", | ||
| "100 degrees Celsius is water's boiling point.", | ||
| "Boiling occurs at 100C for water.", | ||
| "Water starts boiling at one hundred degrees.", | ||
| "At 100 degrees water boils.", | ||
| "The temperature for boiling water is 100C.", | ||
| "Cats like to sleep a lot.", | ||
| "Mountains can be very high.", | ||
| "Stars shine in the night sky.", | ||
| ], | ||
| [ | ||
| "Python is a programming language.", | ||
| "Python is a popular programming language.", | ||
| "The Python programming language is widely used.", | ||
| "Python is used for programming.", | ||
| "Programming in Python is common.", | ||
| "Python is a well-known language for coding.", | ||
| "Many developers use Python.", | ||
| "Python is a general-purpose language.", | ||
| "The language Python is popular.", | ||
| "Pizza originated in Italy.", | ||
| "Rain falls from clouds.", | ||
| "Books contain many pages.", | ||
| ], | ||
| ] | ||
| training_labels = [ | ||
| [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], | ||
| [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], | ||
| [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], | ||
| ] | ||
|
|
||
| strategy = SIMBAUQSamplingStrategy( | ||
| temperatures=[0.3, 0.5, 0.7, 1.0], | ||
| n_per_temp=3, | ||
| similarity_metric="rouge", | ||
| confidence_method="classifier", | ||
| training_samples=training_samples, | ||
| training_labels=training_labels, | ||
| ) | ||
|
|
||
| result: SamplingResult = m.instruct( | ||
| "Which magazine was started first Arthur's Magazine or First for Women?", | ||
| strategy=strategy, | ||
| return_sampling_results=True, | ||
| ) | ||
|
|
||
| print(f"Total samples generated: {len(result.sample_generations)}") | ||
| print_results(result) | ||
|
|
||
| del m | ||
|
|
||
|
|
||
| def main(): | ||
| """Run both SIMBA-UQ confidence estimation examples.""" | ||
| run_aggregation_example() | ||
| print("\n" + "=" * 70 + "\n") | ||
| run_classifier_example() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.