
SFT Dataset Creation

Overview

The SFT extraction module provides tools for creating Supervised Fine-Tuning (SFT) datasets from generated completions. It filters, ranks, and formats model completions into structured conversation samples suitable for training, with support for reward-based filtering, diversity-aware selection, and customizable prompt templates.

Key Concepts

Extraction Pipeline

The SFT extraction pipeline follows these steps:

  1. Correctness filtering — Remove completions that are invalid (e.g., malformed SMILES, failed extractions)
  2. Grouping — Group completions by their originating prompt ID
  3. Diversity-aware filtering (optional) — Deduplicate chemically similar completions using molecular fingerprints (see Diversity-Aware Top-k)
  4. Reward threshold filtering (optional) — Keep only completions whose reward exceeds a minimum threshold
  5. Post-processing — Optionally inject reward and source information into the prompt messages
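The ordering of these steps can be sketched in plain Python. This is a hypothetical simplification (the real logic lives in `SFTExtractor`); completions are modeled here as dicts with illustrative `is_valid`, `prompt_id`, and `reward` keys.

```python
# Hypothetical sketch of the pipeline order above; the real implementation
# lives in SFTExtractor. Completions are plain dicts for illustration only.
from typing import Any, Dict, List


def extract_sketch(
    completions: List[Dict[str, Any]], min_reward: float = 0.5
) -> Dict[str, List[Dict[str, Any]]]:
    # 1. Correctness filtering: drop invalid completions
    valid = [c for c in completions if c["is_valid"]]
    # 2. Grouping by originating prompt ID
    groups: Dict[str, List[Dict[str, Any]]] = {}
    for c in valid:
        groups.setdefault(c["prompt_id"], []).append(c)
    # 3. (Diversity-aware filtering would run here, per group)
    # 4. Reward threshold filtering
    return {
        pid: [c for c in group if c["reward"] >= min_reward]
        for pid, group in groups.items()
    }
```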

Reward & Source Templating

After building conversation samples, the extractor can enrich prompts with reward and source information using configurable templates. For example, with a reward_info_template keyed on the system role:

# Original system message:
"Generate a molecule with high binding affinity."

# After reward templating (reward=0.85):
"Generate a molecule with high binding affinity.\nPropose an answer whose reward is: 0.85"

Usage Examples

Basic SFT Extraction

from mol_gen_docking.evaluation.sft_extraction import (
    SFTExtractionConfig,
    SFTExtractor,
    Completion,
)
from mol_gen_docking.data.pydantic_dataset import Sample, Conversation, Message

# Define the extraction configuration
config = SFTExtractionConfig(
    min_reward_threshold=0.5,
    div_threshold=None,  # No diversity filtering
    reward_info_template={},  # No reward injection
    source_info_template={},  # No source injection
)

extractor = SFTExtractor(config)

# Build a prompt sample
prompt = Sample(
    identifier="prompt_0",
    conversations=[
        Conversation(
            messages=[
                Message(role="system", content="You are a molecular generation assistant."),
                Message(role="user", content="Generate a molecule with high docking score."),
            ]
        )
    ],
)

# Build completions (e.g., from a generation run)
completions = [
    Completion(
        output="<answer>CCO</answer>",
        reward=0.8,
        metadata={"prompt_id": "prompt_0"},
        reward_meta={
            "generation_verifier_metadata": {"all_smi": ["CCO"]}
        },
        source="my_model_v1",
    ),
    Completion(
        output="<answer>invalid</answer>",
        reward=0.1,
        metadata={"prompt_id": "prompt_0"},
        reward_meta={
            "generation_verifier_metadata": {"all_smi": ["invalid"]}
        },
        source="my_model_v1",
    ),
]

# Extract SFT samples
samples = extractor.extract(completions, [prompt])
print(len(samples))
# Output:
# >>> 1

With Diversity-Aware Filtering

config = SFTExtractionConfig(
    min_reward_threshold=0.3,
    div_threshold=0.7,            # Tanimoto similarity threshold
    fingerprint_name="ecfp4-1024",
    reward_info_template={},
    source_info_template={},
)

extractor = SFTExtractor(config)
samples = extractor.extract(completions, [prompt])
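The underlying selection is a greedy scheme: visit candidates from highest to lowest reward and keep one only if it is less similar than the threshold to everything already kept. The sketch below illustrates that idea with a generic similarity function; the function name and signature are hypothetical, while the real `diversity_aware_top_k` operates on Tanimoto similarity over molecular fingerprints.

```python
# Hypothetical sketch of greedy diversity-aware selection. The actual
# diversity_aware_top_k uses Tanimoto similarity over molecular fingerprints;
# here, sim is any user-supplied similarity in [0, 1].
from typing import Callable, List


def greedy_diverse_select(
    items: List[str],
    scores: List[float],
    sim: Callable[[str, str], float],
    t: float,
) -> List[int]:
    # Visit candidates from the highest score to the lowest
    order = sorted(range(len(items)), key=lambda i: -scores[i])
    kept: List[int] = []
    for i in order:
        # Keep only candidates below the similarity threshold to all kept items
        if all(sim(items[i], items[j]) < t for j in kept):
            kept.append(i)
    return kept
```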

With Reward-Conditioned Prompts

config = SFTExtractionConfig(
    min_reward_threshold=0.5,
    div_threshold=None,
    reward_info_template={
        "system": "{content}\nPropose an answer whose reward is: {reward:.2f}"
    },
    source_info_template={
        "system": "{content}\nThe source of this conversation is: {source}"
    },
)

extractor = SFTExtractor(config)
samples = extractor.extract(completions, [prompt])

# The system message of each conversation now includes reward and source info
print(samples[0].conversations[0].messages[0].content)
# Output:
# >>> "You are a molecular generation assistant.
# >>> Propose an answer whose reward is: 0.80
# >>> The source of this conversation is: my_model_v1"

Using a Custom System Prompt File

config = SFTExtractionConfig(
    system_prompt_path="system_prompts/vanilla.json",
    min_reward_threshold=0.5,
    div_threshold=None,
    reward_info_template={},
    source_info_template={},
)

extractor = SFTExtractor(config)
# The system prompt from the JSON file will replace the existing system message
samples = extractor.extract(completions, [prompt])

Inspecting Extraction Metadata

After extraction, the SFTExtractor tracks metadata about the retained completions:

extractor = SFTExtractor(config)
samples = extractor.extract(completions, [prompt])

# Prompt IDs of all retained completions
print(extractor.metadata.prompt_ids)

# Rewards of retained completions
print(extractor.metadata.rewards)

# Estimated token counts
print(extractor.metadata.n_tokens)
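Since the metadata fields are plain Python lists, summary statistics can be computed directly. A sketch with illustrative values standing in for the attributes above:

```python
# The metadata fields are plain lists; the values below are illustrative
# stand-ins for extractor.metadata.rewards and extractor.metadata.n_tokens.
rewards = [0.8, 0.6, 0.9]
n_tokens = [120, 95, 140]

mean_reward = sum(rewards) / max(len(rewards), 1)
total_tokens = sum(n_tokens)
print(f"mean reward: {mean_reward:.2f}, total tokens: {total_tokens}")
```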

Class & Function Reference

Completion

Bases: BaseModel

A single model completion with associated reward and metadata.

Represents one generated output for a given prompt, together with the reward score assigned by the reward model and any verifier metadata needed for correctness filtering and diversity-aware selection.

Attributes:

  - output (str): The raw generated text (e.g., containing <answer>...</answer> tags).
  - reward (Optional[float]): Reward score assigned to this completion, or None if not yet scored.
  - metadata (Dict[str, Any]): Arbitrary metadata dict; must contain a "prompt_id" key that links the completion back to its originating prompt.
  - reward_meta (Dict[str, Any]): Verifier metadata produced by the reward pipeline. Expected to contain one of generation_verifier_metadata, mol_prop_verifier_metadata, or reaction_verifier_metadata.
  - source (Optional[str]): Optional label identifying the model or method that produced the completion (e.g., "Qwen3-0.6B-sft-v2").

Example
comp = Completion(
    output="<answer>CCO</answer>",
    reward=0.85,
    metadata={"prompt_id": "prompt_42"},
    reward_meta={
        "generation_verifier_metadata": {"all_smi": ["CCO"]}
    },
    source="my_model_v1",
)
Source code in mol_gen_docking/evaluation/sft_extraction.py
class Completion(BaseModel):
    """A single model completion with associated reward and metadata.

    Represents one generated output for a given prompt, together with the
    reward score assigned by the reward model and any verifier metadata
    needed for correctness filtering and diversity-aware selection.

    Attributes:
        output: The raw generated text (e.g., containing ``<answer>...</answer>`` tags).
        reward: Reward score assigned to this completion, or ``None`` if not yet scored.
        metadata: Arbitrary metadata dict; must contain a ``"prompt_id"`` key that
            links the completion back to its originating prompt.
        reward_meta: Verifier metadata produced by the reward pipeline. Expected to
            contain one of ``generation_verifier_metadata``,
            ``mol_prop_verifier_metadata``, or ``reaction_verifier_metadata``.
        source: Optional label identifying the model or method that produced the
            completion (e.g., ``"Qwen3-0.6B-sft-v2"``).

    Example:
        ```python
        comp = Completion(
            output="<answer>CCO</answer>",
            reward=0.85,
            metadata={"prompt_id": "prompt_42"},
            reward_meta={
                "generation_verifier_metadata": {"all_smi": ["CCO"]}
            },
            source="my_model_v1",
        )
        ```
    """

    output: str = Field(
        ...,
        description="The generated completion text.",
    )
    reward: Optional[float] = Field(
        None,
        description="The reward score associated with the completion.",
    )
    metadata: Dict[str, Any] = Field(
        ...,
        description="Metadata associated with the completion, including generation verifier metadata and other relevant information.",
    )
    reward_meta: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional metadata related to the reward computation.",
    )
    source: Optional[str] = Field(
        None,
        description="The source of the completion (e.g., model name, generation method).",
    )

SFTExtractionConfig

Bases: BaseModel

Configuration for SFT dataset extraction.

Controls how completions are filtered, selected, and formatted into training samples. Supports reward thresholding, diversity-aware deduplication, and prompt enrichment via reward/source templates.

Attributes:

  - system_prompt_path (Optional[str]): Path to a JSON file containing a system prompt template. If specified, the content replaces the system message in all conversations.
  - min_reward_threshold (Optional[float]): Minimum reward score a completion must achieve to be retained. Set to None to disable reward filtering.
  - div_threshold (Optional[float]): Tanimoto similarity threshold for diversity-aware filtering. Completions more similar than this threshold to an already-selected completion are discarded. Set to None to disable diversity filtering.
  - fingerprint_name (Optional[str]): Molecular fingerprint type used for diversity computation (e.g., "ecfp6-2048"). Only relevant when div_threshold is set.
  - reward_info_template (Dict[str, str]): Per-role templates for injecting reward information into prompt messages. Keys are message roles (e.g., "system"), values are format strings with {content} and {reward} placeholders.
  - source_info_template (Dict[str, str]): Per-role templates for injecting source information into prompt messages. Keys are message roles, values are format strings with {content} and {source} placeholders. Applied after reward_info_template.
  - boxed (bool): Whether to force the extracted answer to be wrapped in \boxed{...} inside the <answer> tags.

Example
config = SFTExtractionConfig(
    min_reward_threshold=0.5,
    div_threshold=0.7,
    fingerprint_name="ecfp4-1024",
    reward_info_template={
        "system": "{content}\nPropose an answer whose reward is: {reward:.2f}"
    },
    source_info_template={},
)
Source code in mol_gen_docking/evaluation/sft_extraction.py
class SFTExtractionConfig(BaseModel):
    """Configuration for SFT dataset extraction.

    Controls how completions are filtered, selected, and formatted into
    training samples. Supports reward thresholding, diversity-aware deduplication,
    and prompt enrichment via reward/source templates.

    Attributes:
        system_prompt_path: Path to a JSON file containing a system prompt template.
            If specified, the content replaces the system message in all conversations.
        min_reward_threshold: Minimum reward score a completion must achieve to be
            retained. Set to ``None`` to disable reward filtering.
        div_threshold: Tanimoto similarity threshold for diversity-aware filtering.
            Completions more similar than this threshold to an already-selected
            completion are discarded. Set to ``None`` to disable diversity filtering.
        fingerprint_name: Molecular fingerprint type used for diversity computation
            (e.g., ``"ecfp6-2048"``). Only relevant when ``div_threshold`` is set.
        reward_info_template: Per-role templates for injecting reward information into
            prompt messages. Keys are message roles (e.g., ``"system"``), values are
            format strings with ``{content}`` and ``{reward}`` placeholders.
        source_info_template: Per-role templates for injecting source information into
            prompt messages. Keys are message roles, values are format strings with
            ``{content}`` and ``{source}`` placeholders. Applied after
            ``reward_info_template``.

    Example:
        ```python
        config = SFTExtractionConfig(
            min_reward_threshold=0.5,
            div_threshold=0.7,
            fingerprint_name="ecfp4-1024",
            reward_info_template={
                "system": "{content}\\nPropose an answer whose reward is: {reward:.2f}"
            },
            source_info_template={},
        )
        ```
    """

    system_prompt_path: Optional[str] = Field(
        None,
        description="Path to a text file containing the system prompt template. \
            If specified, the content of this file will be used as the system prompt template for all \
            conversations during SFT extraction. The template should include a placeholder for the content (e.g., {content}).",
    )
    min_reward_threshold: Optional[float] = Field(
        ...,
        description="Minimum reward threshold for filtering completions during SFT extraction.",
    )
    div_threshold: Optional[float] = Field(
        ...,
        description="Diversity threshold for diversity-aware top-k selection during SFT extraction.",
    )
    fingerprint_name: Optional[str] = Field(
        "ecfp6-2048",
        description="Fingerprint type to use for diversity-aware top-k selection. \
            This should be a valid fingerprint type supported by the underlying implementation.",
    )
    reward_info_template: Dict[str, str] = Field(
        default_factory=lambda: dict(
            user="{content}\n(Propose an answer whose reward is: {reward:.2f})"
        ),
        description="Template used to add the reward information to the prompt (specified by the key). \
            The template should include placeholders for the content and reward values.",
    )
    source_info_template: Dict[str, str] = Field(
        default_factory=lambda: dict(user="{content}\n(Source model: {source})"),
        description="Template used to add the source information to the prompt (specified by the key). \
            The template should include placeholders for the content and source values. \
            This pattern is always applied after the reward_info_template if both are specified.",
    )
    boxed: bool = Field(True, description="Whether to force the answer to be boxed")

SFTExtractionMetadata

Bases: BaseModel

Metadata accumulated during SFT extraction.

Tracks per-completion statistics (prompt IDs, rewards, estimated token counts) for all completions that passed filtering and were included in the final dataset.

Attributes:

  - prompt_ids (List[str]): Prompt identifiers for each retained completion, preserving the order in which completions were processed.
  - rewards (List[float]): Reward scores for each retained completion. Defaults to 0.0 when the original reward is None.
  - n_tokens (List[int]): Rough token-count estimates for each retained completion, computed as len(output) // 4.
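Unlike the other classes, no Example block accompanies this one; a minimal sketch of how the extractor populates the three lists for one retained completion (standalone lists stand in for an SFTExtractionMetadata instance; values are illustrative):

```python
# Sketch of the per-completion bookkeeping described above; plain lists stand
# in for the SFTExtractionMetadata fields.
prompt_ids: list = []
rewards: list = []
n_tokens: list = []

output = "<answer>CCO</answer>"
reward = None  # unscored completions are recorded as 0.0

prompt_ids.append("prompt_0")
rewards.append(reward if reward is not None else 0.0)
n_tokens.append(len(output) // 4)  # rough estimate: characters // 4

print(prompt_ids, rewards, n_tokens)
# ['prompt_0'] [0.0] [5]
```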

Source code in mol_gen_docking/evaluation/sft_extraction.py
class SFTExtractionMetadata(BaseModel):
    """Metadata accumulated during SFT extraction.

    Tracks per-completion statistics (prompt IDs, rewards, estimated token counts)
    for all completions that passed filtering and were included in the final dataset.

    Attributes:
        prompt_ids: Prompt identifiers for each retained completion, preserving
            the order in which completions were processed.
        rewards: Reward scores for each retained completion. Defaults to ``0.0``
            when the original reward is ``None``.
        n_tokens: Rough token-count estimates for each retained completion,
            computed as ``len(output) // 4``.
    """

    prompt_ids: List[str] = Field(
        default_factory=list,
        description="List of prompt IDs associated with the completions. \
            Used to group completions that were generated from the same prompt for filtering during SFT extraction.",
    )
    rewards: List[float] = Field(
        default_factory=list, description="List of rewards extracted for each prompt."
    )
    n_tokens: List[int] = Field(
        default_factory=list,
        description="List of token counts for each completion (number of characters divided by 4 as a rough estimate).",
    )

SFTExtractor

Extracts SFT training samples from a collection of completions and prompts.

Orchestrates the full extraction pipeline: correctness filtering, per-prompt grouping, diversity-aware selection, reward thresholding, conversation formatting, and post-processing (reward/source injection).

Attributes:

  - config: The SFTExtractionConfig controlling extraction behaviour.
  - metadata: An SFTExtractionMetadata instance accumulating statistics about retained completions across calls to extract().

Example
from mol_gen_docking.evaluation.sft_extraction import (
    SFTExtractionConfig, SFTExtractor, Completion,
)

config = SFTExtractionConfig(
    min_reward_threshold=0.5,
    div_threshold=None,
    reward_info_template={},
    source_info_template={},
)
extractor = SFTExtractor(config)
samples = extractor.extract(completions, prompts)
Source code in mol_gen_docking/evaluation/sft_extraction.py
class SFTExtractor:
    """Extracts SFT training samples from a collection of completions and prompts.

    Orchestrates the full extraction pipeline: correctness filtering, per-prompt
    grouping, diversity-aware selection, reward thresholding, conversation
    formatting, and post-processing (reward/source injection).

    Attributes:
        config: The `SFTExtractionConfig` controlling extraction behaviour.
        metadata: An `SFTExtractionMetadata` instance accumulating statistics
            about retained completions across calls to :meth:`extract`.

    Example:
        ```python
        from mol_gen_docking.evaluation.sft_extraction import (
            SFTExtractionConfig, SFTExtractor, Completion,
        )

        config = SFTExtractionConfig(
            min_reward_threshold=0.5,
            div_threshold=None,
            reward_info_template={},
            source_info_template={},
        )
        extractor = SFTExtractor(config)
        samples = extractor.extract(completions, prompts)
        ```
    """

    def __init__(self, config: SFTExtractionConfig) -> None:
        """Initialise the SFT extractor.

        Args:
            config: Extraction configuration controlling filtering thresholds,
                diversity parameters, and prompt templates.
        """
        self.config = config
        self.metadata = SFTExtractionMetadata()

    def filter_is_correct(self, completions: List[Completion]) -> List[Completion]:
        """Filter completions to keep only those that are semantically correct.

        Automatically detects the task type from the verifier metadata and applies
        the appropriate correctness check:

        - **Molecular generation** (``generation_verifier_metadata``): keeps
          completions with exactly one extracted SMILES.
        - **Property prediction** (``mol_prop_verifier_metadata``): keeps
          completions where value extraction succeeded.
        - **Retro-synthesis** (``reaction_verifier_metadata``): keeps completions
          whose validity score is strictly positive.

        If any completion has empty ``reward_meta``, all completions are returned
        unfiltered (graceful fallback).

        Args:
            completions: List of completions to filter.

        Returns:
            A new list containing only the completions that passed the
            task-specific correctness check.

        Raises:
            ValueError: If none of the known verifier metadata keys are found
                in the completions.
        """
        if any(len(c.reward_meta) == 0 for c in completions):
            return completions
        # First if molecular generation:
        if all(
            c.reward_meta.get("generation_verifier_metadata", None) is not None
            for c in completions
        ):
            return [
                c
                for c in completions
                if len(c.reward_meta["generation_verifier_metadata"]["all_smi"]) == 1
            ]
        # Molecular property prediction
        if all(
            c.reward_meta.get("mol_prop_verifier_metadata", None) is not None
            for c in completions
        ):
            return [
                c
                for c in completions
                if c.reward_meta["mol_prop_verifier_metadata"]["extraction_success"]
            ]
        if all(
            c.reward_meta.get("reaction_verifier_metadata", None) is not None
            for c in completions
        ):
            return [
                c
                for c in completions
                if c.reward_meta["reaction_verifier_metadata"]["valid"] > 0.0
            ]
        raise ValueError("Unknown verifier metadata format for filtering completions.")

    def filter_completions_single_id(
        self, completions: List[Completion]
    ) -> List[Completion]:
        """Filter and rank completions that share a single prompt ID.

        Applies the following filters in order:

        1. **Diversity-aware selection** — if ``config.div_threshold`` is set, uses
           :func:`~mol_gen_docking.evaluation.diversity_aware_top_k.diversity_aware_top_k`
           to remove chemically redundant completions.
        2. **Reward threshold** — if ``config.min_reward_threshold`` is set, discards
           completions whose reward is below the threshold.

        Side-effects: appends per-completion statistics (reward, token count,
        prompt ID) to ``self.metadata``.

        Args:
            completions: List of completions that must all share the same
                ``metadata["prompt_id"]``.

        Returns:
            The filtered list of completions.

        Raises:
            AssertionError: If completions do not share the same prompt ID, or if
                diversity filtering is requested but reward metadata is missing.
        """
        prompt_id = completions[0].metadata["prompt_id"]
        assert all(c.metadata["prompt_id"] == prompt_id for c in completions), (
            "All completions must have the same prompt_id for filtering."
        )

        # Filter completions based on diversity-aware top-k selection
        if self.config.div_threshold is not None:
            assert all(
                len(c.reward_meta) > 0 and c.reward is not None for c in completions
            ), (
                "Reward metadata must be available for all completions when min_reward_threshold is set."
            )
            assert all(
                c.reward_meta["generation_verifier_metadata"] is not None
                for c in completions
            ), (
                "Diversity aware sampling can only be performed on molecular generation tasks"
            )

            if len(completions) > 1:
                _, keep_indices = diversity_aware_top_k(
                    mols=[
                        comp.reward_meta["generation_verifier_metadata"]["all_smi"][0]
                        for comp in completions
                    ],
                    scores=[c.reward for c in completions],  # type: ignore
                    k=len(completions),
                    t=self.config.div_threshold,
                    fingerprint_name=self.config.fingerprint_name,
                    return_idxs=True,
                )
                completions = [completions[i] for i in keep_indices]

        # Filter completions based on minimum reward threshold
        if self.config.min_reward_threshold is not None:
            assert all(c.reward is not None for c in completions)
            completions = [
                c
                for c in completions
                if c.reward is not None and c.reward >= self.config.min_reward_threshold
            ]

        # Update metadata
        for c in completions:
            if c.reward is not None:
                self.metadata.rewards.append(c.reward)
            else:
                self.metadata.rewards.append(0.0)
            self.metadata.n_tokens.append(
                len(c.output) // 4
            )  # Rough estimate of token count
            self.metadata.prompt_ids.append(prompt_id)

        return completions

    def completion_to_conv(
        self,
        completions: List[Completion],
        prompt: Sample,
    ) -> Sample | None:
        """Convert filtered completions into a multi-conversation Sample.

        For each completion, creates a new `Conversation` by copying the
        prompt messages and appending the completion text as an assistant message.
        Metadata such as ``source``, ``rating``, and ``training_masks_strategy``
        are carried over from the original conversation / completion.

        Args:
            completions: Filtered completions to convert. All must have
                ``metadata["prompt_id"]`` equal to ``prompt.identifier``.
            prompt: The original prompt `Sample` whose first conversation
                provides the base messages.

        Returns:
            A new `Sample` containing one `Conversation` per
            completion, or ``None`` if an error occurs.

        Raises:
            AssertionError: If any completion's prompt ID does not match
                ``prompt.identifier``.
        """
        assert all(c.metadata["prompt_id"] == prompt.identifier for c in completions), (
            "All completions must have the same prompt_id as the prompt sample for conversion."
        )

        conversation = prompt.conversations[0]
        new_conversation_messages = []
        for completion in completions:
            # stop the completion after the last </answer> tag
            assistant_content = completion.output
            assistant_content = assistant_content.split("</answer>")[0] + "</answer>"
            if self.config.boxed:
                # If there is no \boxed{...} in the content, we add it around the answer
                parsed_answer = re.search(
                    r"<answer>(.*?)</answer>", assistant_content, re.DOTALL
                )
                if parsed_answer is not None:
                    if not re.search(r"\\boxed{.*?}", parsed_answer.group(1)):
                        # Get the extracted answer from the metadata
                        if (
                            completion.reward_meta.get("generation_verifier_metadata")
                            is not None
                        ):
                            extracted_answer = completion.reward_meta[
                                "generation_verifier_metadata"
                            ]["all_smi"][0]
                        elif (
                            completion.reward_meta.get("mol_prop_verifier_metadata")
                            is not None
                        ):
                            # For property prediction, we can try to extract the value from the metadata
                            extracted_answer = completion.reward_meta[
                                "mol_prop_verifier_metadata"
                            ]["extracted_value"]
                        else:
                            raise NotImplementedError(
                                "Boxing is only implemented for molecular generation and property prediction tasks."
                            )
                        new_parsed_answer = f"\\boxed{{{extracted_answer}}}"
                        assistant_content = assistant_content.replace(
                            f"<answer>{parsed_answer.group(1)}</answer>",
                            f"<answer>{new_parsed_answer}</answer>",
                        )

            assistant_message = Message(role="assistant", content=assistant_content)
            prompt_message = deepcopy(conversation.messages)
            prompt_message.append(assistant_message)
            new_conversation = Conversation(
                messages=prompt_message,
                source=completion.source,
                rating=completion.reward,
                meta=conversation.meta,
                identifier=conversation.identifier,
                training_masks_strategy=conversation.training_masks_strategy,
            )
            new_conversation_messages.append(new_conversation)
        new_sample = Sample(
            conversations=new_conversation_messages,
            identifier=conversation.identifier,
        )
        return new_sample

    def post_process_sample(self, sample: Sample) -> Sample:
        """Enrich a sample's messages with reward and source information.

        Iterates over every conversation and message in *sample* and applies:

        1. ``config.reward_info_template`` — injects the conversation's reward
           into matching messages (keyed by role).
        2. ``config.source_info_template`` — injects the conversation's source
           into matching messages (keyed by role). Always applied **after** the
           reward template.

        Args:
            sample: The `Sample` to post-process. Modified **in-place**.

        Returns:
            The same `Sample` instance, with message contents updated.
        """
        # Add reward information to the system prompt if a template is provided
        if self.config.reward_info_template:
            for conv in sample.conversations:
                for msg in conv.messages:
                    if msg.role in self.config.reward_info_template:
                        reward = conv.rating if conv.rating is not None else 0.0
                        msg.content = self.config.reward_info_template[msg.role].format(
                            content=msg.content,
                            reward=reward,
                        )
        # Add source information to the system prompt if a template is provided
        if self.config.source_info_template:
            for conv in sample.conversations:
                for msg in conv.messages:
                    if msg.role in self.config.source_info_template:
                        source = conv.source if conv.source is not None else "unknown"
                        msg.content = self.config.source_info_template[msg.role].format(
                            content=msg.content,
                            source=source,
                        )
        return sample

    def get_sample(
        self, completions: List[Completion], prompt: Sample
    ) -> Sample | None:
        """Build a single SFT sample from completions and their originating prompt.

        Convenience method that chains :meth:`filter_completions_single_id`,
        :meth:`completion_to_conv`, and :meth:`post_process_sample`.

        Args:
            completions: Completions for one prompt ID.
            prompt: The corresponding prompt `Sample`.

        Returns:
            A fully processed `Sample` ready for SFT training, or ``None``
            if no completions survive filtering.
        """
        filtered_completions = self.filter_completions_single_id(completions)
        if len(filtered_completions) == 0:
            return None
        sample = self.completion_to_conv(filtered_completions, prompt)
        if sample is not None:
            sample = self.post_process_sample(sample)
        return sample

    def extract(
        self, completions: List[Completion], prompts: List[Sample]
    ) -> List[Sample]:
        """Run the full SFT extraction pipeline.

        Executes the following steps:

        1. **Correctness filtering** — removes invalid completions via
           :meth:`filter_is_correct`.
        2. **Grouping** — groups surviving completions by ``prompt_id``.
        3. **System prompt override** — if ``config.system_prompt_path`` is set,
           replaces or inserts the system message in every prompt conversation.
        4. **Per-prompt extraction** — calls :meth:`get_sample` for each prompt
           that has at least one completion.

        Args:
            completions: All completions across all prompts.
            prompts: The original prompt `Sample` objects. Only prompts
                whose ``identifier`` appears in the completions will produce
                output samples.

        Returns:
            List of `Sample` objects suitable for SFT training. Each
            sample may contain multiple conversations (one per retained
            completion).
        """
        samples = []
        completions = self.filter_is_correct(completions)

        id_to_completions: Dict[str, List[Completion]] = {}
        for c in completions:
            prompt_id = c.metadata["prompt_id"]
            if prompt_id not in id_to_completions:
                id_to_completions[prompt_id] = []
            id_to_completions[prompt_id].append(c)
        for prompt in prompts:
            # Add the system prompt template to the prompt conversations if specified in the config
            if self.config.system_prompt_path is not None:
                with open(self.config.system_prompt_path) as f:
                    system_prompt = json.load(f).copy()
                for conv in prompt.conversations:
                    if conv.messages[0].role == "system":
                        conv.messages[0].content = system_prompt["content"]
                    else:
                        conv.messages.insert(
                            0, Message(role="system", content=system_prompt["content"])
                        )

            prompt_id = prompt.identifier
            if prompt_id not in id_to_completions:
                continue
            sample = self.get_sample(id_to_completions[prompt_id], prompt)
            if sample is not None:
                samples.append(sample)

        return samples

__init__(config)

Initialise the SFT extractor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `SFTExtractionConfig` | Extraction configuration controlling filtering thresholds, diversity parameters, and prompt templates. | *required* |
Source code in mol_gen_docking/evaluation/sft_extraction.py
def __init__(self, config: SFTExtractionConfig) -> None:
    """Initialise the SFT extractor.

    Args:
        config: Extraction configuration controlling filtering thresholds,
            diversity parameters, and prompt templates.
    """
    self.config = config
    self.metadata = SFTExtractionMetadata()

completion_to_conv(completions, prompt)

Convert filtered completions into a multi-conversation Sample.

For each completion, creates a new Conversation by copying the prompt messages and appending the completion text as an assistant message. Metadata such as source, rating, and training_masks_strategy are carried over from the original conversation / completion.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `completions` | `List[Completion]` | Filtered completions to convert. All must have `metadata["prompt_id"]` equal to `prompt.identifier`. | *required* |
| `prompt` | `Sample` | The original prompt `Sample` whose first conversation provides the base messages. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Sample \| None` | A new `Sample` containing one `Conversation` per completion, or `None` if an error occurs. |

Raises:

| Type | Description |
| --- | --- |
| `AssertionError` | If any completion's prompt ID does not match `prompt.identifier`. |

Source code in mol_gen_docking/evaluation/sft_extraction.py
def completion_to_conv(
    self,
    completions: List[Completion],
    prompt: Sample,
) -> Sample | None:
    """Convert filtered completions into a multi-conversation Sample.

    For each completion, creates a new `Conversation` by copying the
    prompt messages and appending the completion text as an assistant message.
    Metadata such as ``source``, ``rating``, and ``training_masks_strategy``
    are carried over from the original conversation / completion.

    Args:
        completions: Filtered completions to convert. All must have
            ``metadata["prompt_id"]`` equal to ``prompt.identifier``.
        prompt: The original prompt `Sample` whose first conversation
            provides the base messages.

    Returns:
        A new `Sample` containing one `Conversation` per
        completion, or ``None`` if an error occurs.

    Raises:
        AssertionError: If any completion's prompt ID does not match
            ``prompt.identifier``.
    """
    assert all(c.metadata["prompt_id"] == prompt.identifier for c in completions), (
        "All completions must have the same prompt_id as the prompt sample for conversion."
    )

    conversation = prompt.conversations[0]
    new_conversation_messages = []
    for completion in completions:
        # stop the completion after the first </answer> tag
        assistant_content = completion.output
        assistant_content = assistant_content.split("</answer>")[0] + "</answer>"
        if self.config.boxed:
            # If there is no \boxed{...} in the content, we add it around the answer
            parsed_answer = re.search(
                r"<answer>(.*?)</answer>", assistant_content, re.DOTALL
            )
            if parsed_answer is not None:
                if not re.search(r"\\boxed{.*?}", parsed_answer.group(1)):
                    # Get the extracted answer from the metadata
                    if (
                        completion.reward_meta.get("generation_verifier_metadata")
                        is not None
                    ):
                        extracted_answer = completion.reward_meta[
                            "generation_verifier_metadata"
                        ]["all_smi"][0]
                    elif (
                        completion.reward_meta.get("mol_prop_verifier_metadata")
                        is not None
                    ):
                        # For property prediction, we can try to extract the value from the metadata
                        extracted_answer = completion.reward_meta[
                            "mol_prop_verifier_metadata"
                        ]["extracted_value"]
                    else:
                        raise NotImplementedError(
                            "Boxing is only implemented for molecular generation and property prediction tasks."
                        )
                    new_parsed_answer = f"\\boxed{{{extracted_answer}}}"
                    assistant_content = assistant_content.replace(
                        f"<answer>{parsed_answer.group(1)}</answer>",
                        f"<answer>{new_parsed_answer}</answer>",
                    )

        assistant_message = Message(role="assistant", content=assistant_content)
        prompt_message = deepcopy(conversation.messages)
        prompt_message.append(assistant_message)
        new_conversation = Conversation(
            messages=prompt_message,
            source=completion.source,
            rating=completion.reward,
            meta=conversation.meta,
            identifier=conversation.identifier,
            training_masks_strategy=conversation.training_masks_strategy,
        )
        new_conversation_messages.append(new_conversation)
    new_sample = Sample(
        conversations=new_conversation_messages,
        identifier=conversation.identifier,
    )
    return new_sample
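
The answer truncation and boxing applied above can be sketched in isolation. The helper below is a hypothetical standalone version of that logic (the name `truncate_and_box` and the sample strings are illustrative, not part of the library):

```python
import re

def truncate_and_box(output: str, extracted_answer: str) -> str:
    # Keep only the text up to (and including) the first </answer> tag,
    # then wrap the answer in \boxed{...} if it is not already boxed.
    content = output.split("</answer>")[0] + "</answer>"
    match = re.search(r"<answer>(.*?)</answer>", content, re.DOTALL)
    if match is not None and not re.search(r"\\boxed{.*?}", match.group(1)):
        boxed = f"\\boxed{{{extracted_answer}}}"
        content = content.replace(
            f"<answer>{match.group(1)}</answer>",
            f"<answer>{boxed}</answer>",
        )
    return content
```

Note that `split("</answer>")[0]` keeps the text up to the first closing tag, so any second answer block is discarded along with the trailing text.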

extract(completions, prompts)

Run the full SFT extraction pipeline.

Executes the following steps:

  1. Correctness filtering — removes invalid completions via :meth:filter_is_correct.
  2. Grouping — groups surviving completions by prompt_id.
  3. System prompt override — if config.system_prompt_path is set, replaces or inserts the system message in every prompt conversation.
  4. Per-prompt extraction — calls :meth:get_sample for each prompt that has at least one completion.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `completions` | `List[Completion]` | All completions across all prompts. | *required* |
| `prompts` | `List[Sample]` | The original prompt `Sample` objects. Only prompts whose `identifier` appears in the completions will produce output samples. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `List[Sample]` | List of `Sample` objects suitable for SFT training. Each sample may contain multiple conversations (one per retained completion). |

Source code in mol_gen_docking/evaluation/sft_extraction.py
def extract(
    self, completions: List[Completion], prompts: List[Sample]
) -> List[Sample]:
    """Run the full SFT extraction pipeline.

    Executes the following steps:

    1. **Correctness filtering** — removes invalid completions via
       :meth:`filter_is_correct`.
    2. **Grouping** — groups surviving completions by ``prompt_id``.
    3. **System prompt override** — if ``config.system_prompt_path`` is set,
       replaces or inserts the system message in every prompt conversation.
    4. **Per-prompt extraction** — calls :meth:`get_sample` for each prompt
       that has at least one completion.

    Args:
        completions: All completions across all prompts.
        prompts: The original prompt `Sample` objects. Only prompts
            whose ``identifier`` appears in the completions will produce
            output samples.

    Returns:
        List of `Sample` objects suitable for SFT training. Each
        sample may contain multiple conversations (one per retained
        completion).
    """
    samples = []
    completions = self.filter_is_correct(completions)

    id_to_completions: Dict[str, List[Completion]] = {}
    for c in completions:
        prompt_id = c.metadata["prompt_id"]
        if prompt_id not in id_to_completions:
            id_to_completions[prompt_id] = []
        id_to_completions[prompt_id].append(c)
    for prompt in prompts:
        # Add the system prompt template to the prompt conversations if specified in the config
        if self.config.system_prompt_path is not None:
            with open(self.config.system_prompt_path) as f:
                system_prompt = json.load(f).copy()
            for conv in prompt.conversations:
                if conv.messages[0].role == "system":
                    conv.messages[0].content = system_prompt["content"]
                else:
                    conv.messages.insert(
                        0, Message(role="system", content=system_prompt["content"])
                    )

        prompt_id = prompt.identifier
        if prompt_id not in id_to_completions:
            continue
        sample = self.get_sample(id_to_completions[prompt_id], prompt)
        if sample is not None:
            samples.append(sample)

    return samples
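
The grouping step (2) can be sketched without the library types. This stand-in uses plain dicts in place of `Completion` objects (an assumption for illustration):

```python
from collections import defaultdict

def group_by_prompt_id(completions):
    # Bucket completions by metadata["prompt_id"], preserving input order
    # within each bucket, as the loop in extract() does.
    groups = defaultdict(list)
    for c in completions:
        groups[c["metadata"]["prompt_id"]].append(c)
    return dict(groups)

completions = [
    {"metadata": {"prompt_id": "p0"}, "reward": 0.9},
    {"metadata": {"prompt_id": "p1"}, "reward": 0.2},
    {"metadata": {"prompt_id": "p0"}, "reward": 0.7},
]
groups = group_by_prompt_id(completions)
```

Prompts whose identifier never appears as a key in this mapping are simply skipped, which is why they produce no output samples.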

filter_completions_single_id(completions)

Filter and rank completions that share a single prompt ID.

Applies the following filters in order:

  1. Diversity-aware selection — if config.div_threshold is set, uses :func:~mol_gen_docking.evaluation.diversity_aware_top_k.diversity_aware_top_k to remove chemically redundant completions.
  2. Reward threshold — if config.min_reward_threshold is set, discards completions whose reward is below the threshold.

Side-effects: appends per-completion statistics (reward, token count, prompt ID) to self.metadata.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `completions` | `List[Completion]` | List of completions that must all share the same `metadata["prompt_id"]`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `List[Completion]` | The filtered list of completions. |

Raises:

| Type | Description |
| --- | --- |
| `AssertionError` | If completions do not share the same prompt ID, or if diversity filtering is requested but reward metadata is missing. |

Source code in mol_gen_docking/evaluation/sft_extraction.py
def filter_completions_single_id(
    self, completions: List[Completion]
) -> List[Completion]:
    """Filter and rank completions that share a single prompt ID.

    Applies the following filters in order:

    1. **Diversity-aware selection** — if ``config.div_threshold`` is set, uses
       :func:`~mol_gen_docking.evaluation.diversity_aware_top_k.diversity_aware_top_k`
       to remove chemically redundant completions.
    2. **Reward threshold** — if ``config.min_reward_threshold`` is set, discards
       completions whose reward is below the threshold.

    Side-effects: appends per-completion statistics (reward, token count,
    prompt ID) to ``self.metadata``.

    Args:
        completions: List of completions that must all share the same
            ``metadata["prompt_id"]``.

    Returns:
        The filtered list of completions.

    Raises:
        AssertionError: If completions do not share the same prompt ID, or if
            diversity filtering is requested but reward metadata is missing.
    """
    prompt_id = completions[0].metadata["prompt_id"]
    assert all(c.metadata["prompt_id"] == prompt_id for c in completions), (
        "All completions must have the same prompt_id for filtering."
    )

    # Filter completions based on diversity-aware top-k selection
    if self.config.div_threshold is not None:
        assert all(
            len(c.reward_meta) > 0 and c.reward is not None for c in completions
        ), (
            "Reward metadata must be available for all completions when min_reward_threshold is set."
        )
        assert all(
            c.reward_meta["generation_verifier_metadata"] is not None
            for c in completions
        ), (
            "Diversity aware sampling can only be performed on molecular generation tasks"
        )

        if len(completions) > 1:
            _, keep_indices = diversity_aware_top_k(
                mols=[
                    comp.reward_meta["generation_verifier_metadata"]["all_smi"][0]
                    for comp in completions
                ],
                scores=[c.reward for c in completions],  # type: ignore
                k=len(completions),
                t=self.config.div_threshold,
                fingerprint_name=self.config.fingerprint_name,
                return_idxs=True,
            )
            completions = [completions[i] for i in keep_indices]

    # Filter completions based on minimum reward threshold
    if self.config.min_reward_threshold is not None:
        assert all(c.reward is not None for c in completions)
        completions = [
            c
            for c in completions
            if c.reward is not None and c.reward >= self.config.min_reward_threshold
        ]

    # Update metadata
    for c in completions:
        if c.reward is not None:
            self.metadata.rewards.append(c.reward)
        else:
            self.metadata.rewards.append(0.0)
        self.metadata.n_tokens.append(
            len(c.output) // 4
        )  # Rough estimate of token count
        self.metadata.prompt_ids.append(prompt_id)

    return completions
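
The reward-threshold stage of the filter above reduces to a short standalone sketch (plain dicts stand in for `Completion`; `filter_by_reward` is a hypothetical name):

```python
def filter_by_reward(completions, min_reward_threshold):
    # A None threshold disables filtering, matching the config semantics;
    # completions without a reward are dropped when a threshold is set.
    if min_reward_threshold is None:
        return list(completions)
    return [
        c
        for c in completions
        if c["reward"] is not None and c["reward"] >= min_reward_threshold
    ]

kept = filter_by_reward([{"reward": 0.9}, {"reward": 0.3}, {"reward": None}], 0.5)
```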

filter_is_correct(completions)

Filter completions to keep only those that are semantically correct.

Automatically detects the task type from the verifier metadata and applies the appropriate correctness check:

- Molecular generation (`generation_verifier_metadata`): keeps completions with exactly one extracted SMILES.
- Property prediction (`mol_prop_verifier_metadata`): keeps completions where value extraction succeeded.
- Retro-synthesis (`reaction_verifier_metadata`): keeps completions whose validity score is strictly positive.

If any completion has empty reward_meta, all completions are returned unfiltered (graceful fallback).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `completions` | `List[Completion]` | List of completions to filter. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `List[Completion]` | A new list containing only the completions that passed the task-specific correctness check. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If none of the known verifier metadata keys are found in the completions. |

Source code in mol_gen_docking/evaluation/sft_extraction.py
def filter_is_correct(self, completions: List[Completion]) -> List[Completion]:
    """Filter completions to keep only those that are semantically correct.

    Automatically detects the task type from the verifier metadata and applies
    the appropriate correctness check:

    - **Molecular generation** (``generation_verifier_metadata``): keeps
      completions with exactly one extracted SMILES.
    - **Property prediction** (``mol_prop_verifier_metadata``): keeps
      completions where value extraction succeeded.
    - **Retro-synthesis** (``reaction_verifier_metadata``): keeps completions
      whose validity score is strictly positive.

    If any completion has empty ``reward_meta``, all completions are returned
    unfiltered (graceful fallback).

    Args:
        completions: List of completions to filter.

    Returns:
        A new list containing only the completions that passed the
        task-specific correctness check.

    Raises:
        ValueError: If none of the known verifier metadata keys are found
            in the completions.
    """
    if any(len(c.reward_meta) == 0 for c in completions):
        return completions
    # First if molecular generation:
    if all(
        c.reward_meta.get("generation_verifier_metadata", None) is not None
        for c in completions
    ):
        return [
            c
            for c in completions
            if len(c.reward_meta["generation_verifier_metadata"]["all_smi"]) == 1
        ]
    # Molecular property prediction
    if all(
        c.reward_meta.get("mol_prop_verifier_metadata", None) is not None
        for c in completions
    ):
        return [
            c
            for c in completions
            if c.reward_meta["mol_prop_verifier_metadata"]["extraction_success"]
        ]
    if all(
        c.reward_meta.get("reaction_verifier_metadata", None) is not None
        for c in completions
    ):
        return [
            c
            for c in completions
            if c.reward_meta["reaction_verifier_metadata"]["valid"] > 0.0
        ]
    raise ValueError("Unknown verifier metadata format for filtering completions.")
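
The key-detection order above can be mirrored in a minimal sketch (plain dicts stand in for `Completion`; `detect_task_type` is a hypothetical helper, not a library function):

```python
def detect_task_type(completions):
    # Infer the task type from whichever verifier-metadata key is present in
    # every completion's reward_meta, in the same order as filter_is_correct.
    for key in (
        "generation_verifier_metadata",
        "mol_prop_verifier_metadata",
        "reaction_verifier_metadata",
    ):
        if all(c["reward_meta"].get(key) is not None for c in completions):
            return key
    raise ValueError("Unknown verifier metadata format for filtering completions.")

task = detect_task_type(
    [{"reward_meta": {"mol_prop_verifier_metadata": {"extraction_success": True}}}]
)
```

Because the check uses `all(...)`, a batch that mixes task types falls through every branch and raises, which is the intended fail-fast behaviour.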

get_sample(completions, prompt)

Build a single SFT sample from completions and their originating prompt.

Convenience method that chains :meth:filter_completions_single_id, :meth:completion_to_conv, and :meth:post_process_sample.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `completions` | `List[Completion]` | Completions for one prompt ID. | *required* |
| `prompt` | `Sample` | The corresponding prompt `Sample`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Sample \| None` | A fully processed `Sample` ready for SFT training, or `None` if no completions survive filtering. |

Source code in mol_gen_docking/evaluation/sft_extraction.py
def get_sample(
    self, completions: List[Completion], prompt: Sample
) -> Sample | None:
    """Build a single SFT sample from completions and their originating prompt.

    Convenience method that chains :meth:`filter_completions_single_id`,
    :meth:`completion_to_conv`, and :meth:`post_process_sample`.

    Args:
        completions: Completions for one prompt ID.
        prompt: The corresponding prompt `Sample`.

    Returns:
        A fully processed `Sample` ready for SFT training, or ``None``
        if no completions survive filtering.
    """
    filtered_completions = self.filter_completions_single_id(completions)
    if len(filtered_completions) == 0:
        return None
    sample = self.completion_to_conv(filtered_completions, prompt)
    if sample is not None:
        sample = self.post_process_sample(sample)
    return sample

post_process_sample(sample)

Enrich a sample's messages with reward and source information.

Iterates over every conversation and message in sample and applies:

  1. config.reward_info_template — injects the conversation's reward into matching messages (keyed by role).
  2. config.source_info_template — injects the conversation's source into matching messages (keyed by role). Always applied after the reward template.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `sample` | `Sample` | The `Sample` to post-process. Modified in-place. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Sample` | The same `Sample` instance, with message contents updated. |

Source code in mol_gen_docking/evaluation/sft_extraction.py
def post_process_sample(self, sample: Sample) -> Sample:
    """Enrich a sample's messages with reward and source information.

    Iterates over every conversation and message in *sample* and applies:

    1. ``config.reward_info_template`` — injects the conversation's reward
       into matching messages (keyed by role).
    2. ``config.source_info_template`` — injects the conversation's source
       into matching messages (keyed by role). Always applied **after** the
       reward template.

    Args:
        sample: The `Sample` to post-process. Modified **in-place**.

    Returns:
        The same `Sample` instance, with message contents updated.
    """
    # Add reward information to the system prompt if a template is provided
    if self.config.reward_info_template:
        for conv in sample.conversations:
            for msg in conv.messages:
                if msg.role in self.config.reward_info_template:
                    reward = conv.rating if conv.rating is not None else 0.0
                    msg.content = self.config.reward_info_template[msg.role].format(
                        content=msg.content,
                        reward=reward,
                    )
    # Add source information to the system prompt if a template is provided
    if self.config.source_info_template:
        for conv in sample.conversations:
            for msg in conv.messages:
                if msg.role in self.config.source_info_template:
                    source = conv.source if conv.source is not None else "unknown"
                    msg.content = self.config.source_info_template[msg.role].format(
                        content=msg.content,
                        source=source,
                    )
    return sample
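
The templating relies on `str.format` with `{content}` and `{reward}` placeholders keyed by message role. A minimal sketch, assuming the default-style system template shown in the overview (`inject_reward` is a hypothetical helper):

```python
reward_info_template = {
    "system": "{content}\nPropose an answer whose reward is: {reward}"
}

def inject_reward(role, content, rating):
    # Apply the role-keyed template; fall back to a reward of 0.0 when the
    # conversation carries no rating, as post_process_sample does.
    template = reward_info_template.get(role)
    if template is None:
        return content
    reward = rating if rating is not None else 0.0
    return template.format(content=content, reward=reward)

out = inject_reward("system", "Generate a molecule with high binding affinity.", 0.85)
```

Roles absent from the template dict pass through unchanged, so a template keyed only on `"system"` leaves user and assistant messages untouched.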