Reaction Verifier

The ReactionVerifier computes rewards for chemical reaction and retro-synthesis tasks. It validates synthesis paths, SMARTS predictions, and reaction products.

Overview

The Reaction Verifier supports various reaction-related tasks:

  • Retro-synthesis Planning: Validate multi-step synthesis routes
  • SMARTS Prediction: Evaluate predicted reaction SMARTS patterns
  • Product/Reactant Prediction: Compare predicted molecules to ground truth
  • Analog Generation: Generate molecular analogs via synthesis

Supported Task Types

The following task types are supported by the Reaction Verifier:

Task Type               Description
final_product           Predict the final product of a reaction
reactant                Predict a single reactant
all_reactants           Predict all reactants for a reaction
smarts                  Predict the SMARTS pattern for a reaction
full_path               Provide a complete retro-synthesis path
full_path_bb_ref        Synthesis path with building block constraints
full_path_smarts_ref    Synthesis path with SMARTS constraints
analog_gen              Generate molecular analogs
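
For the smarts task type, the `reward_smarts` docstring in the source listing on this page describes three reward tiers: 1.0 for an exact match with the ground truth SMARTS, 0.1 when the proposed SMARTS is valid and yields the expected product, and 0.0 otherwise. The sketch below mirrors only that tiering. It is not the library code: `smarts_reward_sketch`, `apply_reaction`, and `toy_apply` are hypothetical names, and the toy "reaction" simply concatenates strings so the example runs without a chemistry toolkit.

```python
from typing import Callable, Dict, List, Set, Tuple


def smarts_reward_sketch(
    predicted: str,
    ground_truth: str,
    reactants: List[str],
    product: str,
    apply_reaction: Callable[[str, List[str]], Set[str]],
) -> Tuple[float, Dict[str, bool]]:
    """Tiered reward: 1.0 exact match, 0.1 correct product, 0.0 otherwise."""
    failed = {"Reactants_contained": False, "Products_contained": False}
    predicted = predicted.strip()
    if predicted == ground_truth:
        return 1.0, {"Reactants_contained": True, "Products_contained": True}
    try:
        # Hypothetical callable standing in for building and running a reaction.
        products = apply_reaction(predicted, reactants)
    except Exception:
        return 0.0, failed
    hit = product in products
    return (0.1 if hit else 0.0), {
        "Reactants_contained": True,
        "Products_contained": hit,
    }


def toy_apply(template: str, reactants: List[str]) -> Set[str]:
    """Toy stand-in: a template 'A+B>>AB' just concatenates its reactants."""
    lhs, _ = template.split(">>")
    if len(lhs.split("+")) != len(reactants):
        raise ValueError("wrong number of reactants")
    return {"".join(reactants)}


exact = smarts_reward_sketch("A+B>>AB", "A+B>>AB", ["C", "O"], "CO", toy_apply)
partial = smarts_reward_sketch(" X+Y>>XY ", "A+B>>AB", ["C", "O"], "CO", toy_apply)
broken = smarts_reward_sketch("X>>Y", "A+B>>AB", ["C", "O"], "CO", toy_apply)
```

The large gap between 1.0 and 0.1 means a functionally equivalent but differently written SMARTS earns only a small fraction of the reward for reproducing the reference string.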

ReactionVerifierConfigModel

Bases: BaseModel

Pydantic model for reaction verifier configuration.

This model defines the configuration parameters for the ReactionVerifier class, providing validation and documentation for all configuration options.

Attributes:

Name Type Description
parsing_method Literal["none", "answer_tags", "boxed"]

Method used to parse model completions for SMILES or property values. Default is "answer_tags".

reward Literal["property", "valid_smiles"]

Reward type: "property" for property-based rewards or "valid_smiles" for validity-based rewards. Default is "property".

reaction_matrix_path str

Path to the reaction matrix pickle file used for reaction verification. Default is "data/rxn_matrix.pkl". The path must exist when the model is created.

reaction_reward_type Literal["binary", "tanimoto"]

For retro-synthesis, whether the reward is based on an exact match ("binary") or on the Tanimoto similarity of the last product ("tanimoto"). Default is "tanimoto".

Source code in mol_gen_docking/reward/verifiers/reaction_reward/reaction_verifier_pydantic_model.py
class ReactionVerifierConfigModel(BaseModel):
    """Pydantic model for reaction verifier configuration.

    This model defines the configuration parameters for the ReactionVerifier class,
    providing validation and documentation for all configuration options.

    Attributes:
        parsing_method: Method used to parse model completions for SMILES or property values.
        reward: Reward type, "property" for property-based or "valid_smiles" for validity-based rewards.
        reaction_matrix_path: Path to the reaction matrix pickle file used for reaction verification.
        reaction_reward_type: For retro-synthesis, assign the reward based on an exact match ("binary")
                              or on the Tanimoto similarity of the last product ("tanimoto").
    """

    parsing_method: Literal["none", "answer_tags", "boxed"] = Field(
        default="answer_tags",
        description="Method to parse model completions for SMILES or property values.",
    )

    reward: Literal["property", "valid_smiles"] = Field(
        default="property",
        description='Reward type: "property" for property-based or "valid_smiles" for validity-based rewards',
    )

    reaction_matrix_path: str = Field(
        default="data/rxn_matrix.pkl",
        description="Path to the reaction matrix pickle file for reaction verification",
    )

    reaction_reward_type: Literal["binary", "tanimoto"] = Field(
        default="tanimoto",
        description="For retro-synthesis, assign reward based on the exact match (binary) or Tanimoto similarity of the last product",
    )

    class Config:
        """Pydantic configuration."""

        arbitrary_types_allowed = True
        json_schema_extra = {
            "example": {
                "reward": "property",
                "reaction_matrix_path": "data/rxn_matrix.pkl",
                "reaction_reward_type": "tanimoto",
            }
        }

    @model_validator(mode="after")
    def check_reaction_matrix_path(self) -> "ReactionVerifierConfigModel":
        """Validate that the reaction matrix path exists."""
        if not os.path.exists(self.reaction_matrix_path):
            raise ValueError(
                f"Reaction matrix path {self.reaction_matrix_path} does not exist."
            )
        return self

Config

Pydantic configuration.

Source code in mol_gen_docking/reward/verifiers/reaction_reward/reaction_verifier_pydantic_model.py
class Config:
    """Pydantic configuration."""

    arbitrary_types_allowed = True
    json_schema_extra = {
        "example": {
            "reward": "property",
            "reaction_matrix_path": "data/rxn_matrix.pkl",
            "reaction_reward_type": "tanimoto",
        }
    }

check_reaction_matrix_path()

Validate that the reaction matrix path exists.

Source code in mol_gen_docking/reward/verifiers/reaction_reward/reaction_verifier_pydantic_model.py
@model_validator(mode="after")
def check_reaction_matrix_path(self) -> "ReactionVerifierConfigModel":
    """Validate that the reaction matrix path exists."""
    if not os.path.exists(self.reaction_matrix_path):
        raise ValueError(
            f"Reaction matrix path {self.reaction_matrix_path} does not exist."
        )
    return self
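
Because the validator runs after model construction, a configuration that points at a missing file fails immediately rather than when the verifier later tries to unpickle the matrix. The default path is relative ("data/rxn_matrix.pkl"), so it resolves against the current working directory. The sketch below reproduces the same check with a standard-library dataclass so that it runs without Pydantic. `ConfigSketch` and `try_build` are hypothetical stand-ins, not part of the library.

```python
import os
import tempfile
from dataclasses import dataclass


@dataclass
class ConfigSketch:
    """Hypothetical stand-in for the config model; not the library class."""

    reaction_matrix_path: str = "data/rxn_matrix.pkl"
    reaction_reward_type: str = "tanimoto"

    def __post_init__(self) -> None:
        # Same check as check_reaction_matrix_path: fail early on a missing file.
        if not os.path.exists(self.reaction_matrix_path):
            raise ValueError(
                f"Reaction matrix path {self.reaction_matrix_path} does not exist."
            )


def try_build(path: str) -> str:
    """Return "ok" if the config builds, else the validation message."""
    try:
        ConfigSketch(reaction_matrix_path=path)
        return "ok"
    except ValueError as err:
        return str(err)


# A file that exists passes; a path that does not exist is rejected.
with tempfile.NamedTemporaryFile(suffix=".pkl") as handle:
    existing = try_build(handle.name)
missing = try_build("/nonexistent/rxn_matrix.pkl")
```

With the real model, passing an absolute `reaction_matrix_path` is one way to avoid depending on the working directory.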

ReactionVerifierInputMetadataModel

Bases: BaseModel

Input metadata model for reaction verifier.

Defines the verification criteria for chemical reaction and retro-synthesis tasks, including objectives, target molecules, reactants, products, and validation constraints.

Attributes:

Name Type Description
objectives List[ReactionObjT]

List of objective types for the reaction verification. Valid values:

  • "final_product": Verify only the final product matches the target
  • "reactant": Verify a single reactant is valid
  • "all_reactants": Verify all reactants are chemically valid
  • "all_reactants_bb_ref": Verify all reactants are in the building blocks list
  • "smarts": Verify reaction follows the given SMARTS pattern
  • "full_path": Verify complete synthesis path to target molecule
  • "full_path_bb_ref": Verify synthesis path with building block constraints
  • "full_path_smarts_ref": Verify synthesis path matches SMARTS patterns
  • "full_path_smarts_bb_ref": Verify synthesis path with both SMARTS and building block constraints
  • "analog_gen": Verify analog generation task

target List[str]

List of target molecules (SMILES strings) for verification. For synthesis tasks, this is the desired final product molecule; for SMARTS tasks, it is the expected product after the reaction. Empty list if not applicable to the objective type.

reactants List[List[str]]

List of reactant lists, one per reaction step. Each inner list contains the SMILES strings of the reactants in one reaction. Used for ground truth verification in SMARTS prediction tasks. Empty list if not applicable to the objective type.

intermediate_products List[str]

The intermediate product molecules of the reaction steps. Empty list by default.

products List[str]

List of product molecules (SMILES strings), one per reaction step. Used for ground truth verification in multi-step synthesis. Empty list if not applicable to the objective type.

building_blocks List[str] | None

List of valid building block SMILES strings. Optional constraint for synthesis tasks requiring specific building blocks. None if no building block constraints apply.

smarts List[str]

Reference SMARTS strings for the reaction steps. Used to verify that reactions follow specific reaction templates. Empty list if not applicable to the objective type.

or_smarts List[str]

Original reference SMARTS strings for the reaction steps. Alternative SMARTS patterns that can also be valid. Empty list if not applicable to the objective type.

n_steps_max int

Maximum number of reaction steps allowed in the synthesis route. Default is 5. Only applies to full_path objectives.

idx_chosen int

Index of the chosen reaction for multi-reaction tasks. Default is 0. Used for tracking in batch processing.

Source code in mol_gen_docking/reward/verifiers/reaction_reward/input_metadata.py
class ReactionVerifierInputMetadataModel(BaseModel):
    """Input metadata model for reaction verifier.

    Defines the verification criteria for chemical reaction and retro-synthesis tasks,
    including objectives, target molecules, reactants, products, and validation constraints.

    Attributes:
        objectives: List of objective types for the reaction verification.
            Valid values:

            - "final_product": Verify only the final product matches the target
            - "reactant": Verify a single reactant is valid
            - "all_reactants": Verify all reactants are chemically valid
            - "all_reactants_bb_ref": Verify all reactants are in the building blocks list
            - "smarts": Verify reaction follows the given SMARTS pattern
            - "full_path": Verify complete synthesis path to target molecule
            - "full_path_bb_ref": Verify synthesis path with building block constraints
            - "full_path_smarts_ref": Verify synthesis path matches SMARTS patterns
            - "full_path_smarts_bb_ref": Verify synthesis path with both SMARTS and building block constraints
            - "analog_gen": Verify analog generation task

        target: List of target molecules (SMILES strings) for verification.
            For synthesis tasks: The desired final product molecule
            For SMARTS tasks: The expected product after reaction
            Empty list if not applicable to the objective type.

        reactants: List of reactant lists for each reaction step.
            Each inner list contains SMILES strings for reactants in one reaction.
            For ground truth verification in SMARTS prediction tasks.
            Empty list if not applicable to the objective type.

        products: List of product molecules (SMILES strings) for each reaction step.
            For ground truth verification in multi-step synthesis.
            Empty list if not applicable to the objective type.

        building_blocks: List of valid building block SMILES strings.
            Optional constraint for synthesis tasks requiring specific building blocks.
            None if no building block constraints apply.

        smarts: Reference SMARTS strings for the reaction steps.
            Used to verify that reactions follow specific reaction templates.
            Empty list if not applicable to the objective type.

        or_smarts: Original reference SMARTS strings for the reaction steps.
            Alternative SMARTS patterns that can also be valid.
            Empty list if not applicable to the objective type.

        n_steps_max: Maximum number of reaction steps allowed in the synthesis route.
            Default is 5. Only applies to full_path objectives.

        idx_chosen: Index of the chosen reaction for multi-reaction tasks.
            Default is 0. Used for tracking in batch processing.
    """

    objectives: List[ReactionObjT] = Field(
        ...,
        description="The type of objective for the reaction verification.",
    )
    target: List[str] = Field(
        default_factory=list,
        description="The target molecule or SMARTS string for verification.",
    )
    reactants: List[List[str]] = Field(
        default_factory=list,
        description="List of reactants in a reaction.",
    )
    intermediate_products: List[str] = Field(
        default_factory=list,
        description="The intermediate product molecules of the reaction steps.",
    )
    products: List[str] = Field(
        default_factory=list,
        description="The product molecule of the reaction.",
    )
    building_blocks: List[str] | None = Field(
        None,
        description="List of valid building blocks for the reaction.",
    )
    smarts: List[str] = Field(
        default_factory=list,
        description="Reference SMARTS strings for the reaction steps.",
    )
    or_smarts: List[str] = Field(
        default_factory=list,
        description="Original Reference SMARTS strings for the reaction steps.",
    )
    n_steps_max: int = Field(
        default=5,
        gt=0,
        description="Maximum number of reaction steps allowed in the synthesis route.",
    )
    idx_chosen: int = Field(
        0,
        description="Index of the chosen reaction.",
    )
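
To make the field layout concrete, the sketch below builds a plain-dictionary payload with the same keys and checks two of the constraints visible in the source above: objective names and `n_steps_max > 0`. It is an illustration only. `check_metadata` is a hypothetical helper, the SMILES values are made-up example data, and the set of objective names is copied from the attribute list on this page, so the real `ReactionObjT` type may accept additional values.

```python
from typing import Any, Dict, List

# Objective names copied from the attribute list; ReactionObjT may allow more.
VALID_OBJECTIVES = {
    "final_product", "reactant", "all_reactants", "all_reactants_bb_ref",
    "smarts", "full_path", "full_path_bb_ref", "full_path_smarts_ref",
    "full_path_smarts_bb_ref", "analog_gen",
}


def check_metadata(meta: Dict[str, Any]) -> List[str]:
    """Return a list of problems; an empty list means no problem was found."""
    problems = []
    if "objectives" not in meta:
        problems.append("objectives is required")
    else:
        unknown = sorted(set(meta["objectives"]) - VALID_OBJECTIVES)
        if unknown:
            problems.append(f"unknown objectives: {unknown}")
    # Mirrors Field(default=5, gt=0) on n_steps_max.
    if meta.get("n_steps_max", 5) <= 0:
        problems.append("n_steps_max must be greater than 0")
    return problems


# Illustrative payload: acetic acid + ethanol giving ethyl acetate.
example = {
    "objectives": ["final_product"],
    "target": ["CC(=O)OCC"],
    "reactants": [["CC(=O)O", "CCO"]],
    "products": ["CC(=O)OCC"],
    "n_steps_max": 5,
}

ok_problems = check_metadata(example)
bad_problems = check_metadata({"objectives": ["bogus"], "n_steps_max": 0})
```

Only `objectives` has no default in the model; every other field falls back to an empty list, `None`, 5, or 0 as shown in the source.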

ReactionVerifierOutputModel

Bases: VerifierOutputModel

Output model for reaction verifier results.

Attributes:

Name Type Description
reward float

The computed reward for the reaction verification.

parsed_answer str

The parsed answer extracted from the model completion.

verifier_metadata ReactionVerifierMetadataModel

Metadata related to the reaction verification process.

Source code in mol_gen_docking/reward/verifiers/reaction_reward/reaction_verifier_pydantic_model.py
class ReactionVerifierOutputModel(VerifierOutputModel):
    """Output model for reaction verifier results.

    Attributes:
        reward: The computed reward for the reaction verification.
        parsed_answer: The parsed answer extracted from the model completion.
        verifier_metadata: Metadata related to the reaction verification process.
    """

    reward: float = Field(
        ...,
        description="The computed reward for the reaction verification.",
    )
    parsed_answer: str = Field(
        ..., description="The parsed answer extracted from the model completion."
    )
    verifier_metadata: ReactionVerifierMetadataModel = Field(
        ...,
        description="Metadata related to the reaction verification process.",
    )
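
For molecule prediction tasks, the `ground_truth_reward_mol` docstring on this page states that the answer is a JSON object with an "answer" key holding either a single SMILES string or a list of SMILES strings. A minimal sketch of that normalization step follows. `smiles_from_answer` is a hypothetical helper, and unlike the library code it uses `dict.get`, so a missing key yields `None` instead of raising `KeyError`.

```python
from typing import Any, Dict, List, Optional


def smiles_from_answer(answer: Dict[str, Any]) -> Optional[List[str]]:
    """Return the SMILES list from a parsed answer, or None if unusable."""
    if not answer:
        return None  # empty answer: the verifier returns a reward of 0.0
    value = answer.get("answer")
    if isinstance(value, str):
        return [value]  # a single SMILES is wrapped in a list
    if isinstance(value, list):
        return value
    return None  # any other type is rejected


single = smiles_from_answer({"answer": "CCO"})
several = smiles_from_answer({"answer": ["CCO", "CC(=O)O"]})
rejected = smiles_from_answer({"answer": 42})
```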

ReactionVerifierMetadataModel

Bases: BaseModel

Metadata model for reaction verifier results.

Contains detailed information about the reaction verification process, including validity, product correctness, and reactant validation.

Attributes:

Name Type Description
valid float

Proportion of valid reaction steps (0.0 to 1.0). For single reaction tasks: 1.0 if valid, 0.0 if invalid. For synthesis route tasks (full_path): proportion of reaction steps that are chemically valid.

correct_product float

Whether the product is correct, or its similarity to the target molecule. For SMARTS prediction tasks: 1.0 if the reaction produces the correct product, 0.0 otherwise. For synthesis tasks with Tanimoto similarity: the Tanimoto similarity score (0.0 to 1.0) between the final product and the target molecule. For exact match tasks: 1.0 if exact match, 0.0 otherwise.

correct_reactant bool

Whether all reactants are correct. For building block constrained tasks: True if all reactants are in the allowed building blocks list. For unconstrained tasks: True if all reactants are chemically valid. False if any reactant is invalid or not allowed.

Source code in mol_gen_docking/reward/verifiers/reaction_reward/reaction_verifier_pydantic_model.py
class ReactionVerifierMetadataModel(BaseModel):
    """Metadata model for reaction verifier results.

    Contains detailed information about the reaction verification process,
    including validity, product correctness, and reactant validation.

    Attributes:
        valid: Proportion of valid reaction steps (0.0 to 1.0).
            For single reaction tasks: 1.0 if valid, 0.0 if invalid.
            For synthesis route tasks (full_path): proportion of reaction steps that are chemically valid.

        correct_product: Whether the product is correct or similarity to the target molecule.
            For SMARTS prediction tasks: 1.0 if reaction produces the correct product, 0.0 otherwise.
            For synthesis tasks with tanimoto similarity: Tanimoto similarity score (0.0 to 1.0)
            between the final product and the target molecule.
            For exact match tasks: 1.0 if exact match, 0.0 otherwise.

        correct_reactant: Whether all reactants are correct.
            For building block constrained tasks: True if all reactants are in the allowed building blocks list.
            For unconstrained tasks: True if all reactants are chemically valid.
            False if any reactant is invalid or not allowed.
    """

    valid: float = Field(
        default=0.0,
        description="Is the answer valid. If the task is to propose a synthesis route, this is the proportion of valid reaction steps (0.0 to 1.0).",
    )
    correct_product: float = Field(
        default=0.0,
        description="Whether the product is correct. For synthesis tasks, if we use tanimoto similarity, similarity to the target molecule, for SMARTS prediction, do both of the chemical reactions lead to the correct product.",
    )
    correct_reactant: bool = Field(
        default=False,
        description="Whether all reactants are correct.",
    )
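
The two modes of `correct_product` can be illustrated with the standard Tanimoto (Jaccard) coefficient, which for two sets of "on" fingerprint bits A and B is |A ∩ B| / |A ∪ B|. The sketch below works on plain Python sets. Which fingerprint the verifier actually uses is not shown on this page, and in the binary branch equality of the bit sets merely stands in for an exact molecule match, so treat `tanimoto` and `product_score` as hypothetical helpers.

```python
from typing import Set


def tanimoto(bits_a: Set[int], bits_b: Set[int]) -> float:
    """Tanimoto (Jaccard) similarity between two sets of 'on' fingerprint bits."""
    if not bits_a and not bits_b:
        return 1.0  # convention chosen for this sketch only
    return len(bits_a & bits_b) / len(bits_a | bits_b)


def product_score(
    pred_bits: Set[int], target_bits: Set[int], reward_type: str = "tanimoto"
) -> float:
    """Binary mode gives 0.0 or 1.0; Tanimoto mode gives partial credit."""
    if reward_type == "binary":
        return 1.0 if pred_bits == target_bits else 0.0
    return tanimoto(pred_bits, target_bits)


# Two of four distinct bits are shared: 2 / 4 = 0.5.
partial_credit = product_score({1, 2, 3}, {2, 3, 4}, "tanimoto")
no_credit = product_score({1, 2, 3}, {2, 3, 4}, "binary")
```

The Tanimoto mode gives a graded signal for near-miss products, while the binary mode rewards only an exact match.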

Reaction verifier for chemical reaction and retro-synthesis tasks.

This module provides the ReactionVerifier class which computes rewards for chemical reaction tasks including retro-synthesis planning, SMARTS prediction, and reaction product verification.

ReactionVerifier

Bases: Verifier

Verifier for chemical reaction and retro-synthesis tasks.

This verifier computes rewards for various reaction-related tasks including:

  • Final product prediction
  • Reactant identification
  • SMARTS pattern prediction
  • Full retro-synthesis path validation

The verifier uses a reaction matrix to validate synthesis steps and supports both binary and Tanimoto-based reward computation.
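
One detail visible in the source listing on this page is that the verifier does not assume reactants arrive in the order a reaction template expects: it runs the reaction over every permutation of the reactants and pools the resulting products. The sketch below shows that pattern with `itertools.permutations`. `products_over_orderings` and `toy_esterification` are hypothetical, and the toy reaction is a hard-coded lookup rather than real template matching.

```python
import itertools
from typing import Callable, List, Sequence, Set


def products_over_orderings(
    reactants: Sequence[str],
    run_reaction: Callable[[List[str]], List[str]],
) -> Set[str]:
    """Collect the products obtained from every ordering of the reactants."""
    found: Set[str] = set()
    for ordering in itertools.permutations(reactants):
        found.update(run_reaction(list(ordering)))
    return found


def toy_esterification(ordered: List[str]) -> List[str]:
    """Toy order-sensitive template: only fires as (acid, alcohol)."""
    if ordered == ["CC(=O)O", "CCO"]:
        return ["CC(=O)OCC"]
    return []


# The reactants are given in the "wrong" order, yet the product is still found.
pooled = products_over_orderings(["CCO", "CC(=O)O"], toy_esterification)
direct = toy_esterification(["CCO", "CC(=O)O"])
```

The number of orderings grows factorially with the number of reactants, which stays cheap for reactions with two or three reactants.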

Attributes:

Name Type Description
verifier_config ReactionVerifierConfigModel

Configuration for the reaction verifier.

rxn_matrix ReactantReactionMatrix

Pre-loaded reaction matrix for validation.

check_ground_truth_tasks

List of task types requiring ground truth comparison.

run_validation_tasks

List of task types requiring path validation.

logger

Logger instance for the verifier.
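
The two task lists partition objectives into families that are scored differently. The sketch below copies the list contents from `__init__` in the source listing on this page and classifies an objective accordingly. How the verifier routes objectives outside both lists (for example "smarts" or "analog_gen") is not shown in this part of the page, so the hypothetical `reward_family` helper simply labels them "other".

```python
# List contents copied from ReactionVerifier.__init__ as shown on this page.
CHECK_GROUND_TRUTH_TASKS = [
    "final_product",
    "reactant",
    "all_reactants",
    "all_reactants_bb_ref",
]
RUN_VALIDATION_TASKS = [
    "full_path",
    "full_path_bb_ref",
    "full_path_smarts_ref",
    "full_path_smarts_bb_ref",
    "full_path_intermediates_gt_reactants",
    "full_path_intermediates",
]


def reward_family(objective: str) -> str:
    """Classify an objective by the task list it belongs to (sketch only)."""
    if objective in CHECK_GROUND_TRUTH_TASKS:
        return "ground_truth"
    if objective in RUN_VALIDATION_TASKS:
        return "path_validation"
    return "other"  # routing for these objectives is an open assumption here


families = {name: reward_family(name) for name in ("reactant", "full_path", "smarts")}
```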

Source code in mol_gen_docking/reward/verifiers/reaction_reward/reaction_verifier.py
class ReactionVerifier(Verifier):
    """Verifier for chemical reaction and retro-synthesis tasks.

    This verifier computes rewards for various reaction-related tasks including:
    - Final product prediction
    - Reactant identification
    - SMARTS pattern prediction
    - Full retro-synthesis path validation

    The verifier uses a reaction matrix to validate synthesis steps and supports
    both binary and Tanimoto-based reward computation.

    Attributes:
        verifier_config: Configuration for the reaction verifier.
        rxn_matrix: Pre-loaded reaction matrix for validation.
        check_ground_truth_tasks: List of task types requiring ground truth comparison.
        run_validation_tasks: List of task types requiring path validation.
        logger: Logger instance for the verifier.
    """

    def __init__(
        self,
        verifier_config: ReactionVerifierConfigModel,
    ):
        """Initialize the ReactionVerifier.

        Args:
            verifier_config: Configuration containing reaction matrix path
                and reward type settings.
        """
        super().__init__(verifier_config)
        self.verifier_config: ReactionVerifierConfigModel = verifier_config

        self.rxn_matrix: ReactantReactionMatrix
        with open(verifier_config.reaction_matrix_path, "rb") as f:
            self.rxn_matrix = pickle.load(f)

        self.reactants_csi: List[str] = [m.csmiles for m in self.rxn_matrix.reactants]
        self.check_ground_truth_tasks = [
            "final_product",
            "reactant",
            "all_reactants",
            "all_reactants_bb_ref",
        ]
        self.run_validation_tasks = [
            "full_path",
            "full_path_bb_ref",
            "full_path_smarts_ref",
            "full_path_smarts_bb_ref",
            "full_path_intermediates_gt_reactants",
            "full_path_intermediates",  # We treat this task as a normal path validation
        ]
        self.logger = logging.getLogger("ReactionVerifier")

    def r_ground_truth_mols(
        self,
        mol_y: List[Molecule],
        mol_label: List[Molecule],
        reactants: List[str],
        product: str,
        smarts: str,
        objective: str,
    ) -> float:
        """Compute reward for molecule prediction against ground truth."""
        self.logger.info(
            f"Computed molecules: {[mol.smiles for mol in mol_y]} vs labels: {[mol.smiles for mol in mol_label]}"
        )
        smi_y = {mol.csmiles for mol in mol_y}
        smi_y_true = {mol.csmiles for mol in mol_label}
        if smi_y == smi_y_true:
            return 1.0

        else:
            # Check if the reaction still works
            rxn = Reaction(smarts)
            if objective == "final_product":
                if len(mol_y) != 1:
                    return 0.0
                react_mols = [Molecule(smi) for smi in reactants]
                possible_products = []
                for r_ in itertools.permutations(react_mols):
                    possible_products.extend(rxn(list(r_)))
                # Compare canonical SMILES strings to avoid unexpected behavior with Python's `in` operator
                possible_products_csi = {m.csmiles for m in possible_products}
                if mol_y[0].csmiles in possible_products_csi:  # Can be only one product
                    return 1.0
                else:
                    return 0.0
            elif objective == "reactant":
                react_mols = [Molecule(smi) for smi in reactants]
                react_mols = [m for m in react_mols if not m == mol_label[0]]
                react_mols += mol_y
            elif objective in ["all_reactants", "all_reactants_bb_ref"]:
                react_mols = mol_y
            # Check that the number of reactants matches
            if len(react_mols) != rxn.num_reactants:
                return 0.0
            # Check if all reactants are in the building blocks
            if any(m.csmiles not in self.reactants_csi for m in react_mols):
                return 0.0
            possible_products = []
            for r_ in itertools.permutations(react_mols):
                possible_products.extend(rxn(list(r_)))
            # Compare canonical SMILES strings to avoid unexpected behavior with Python's `in` operator
            possible_products_csi = {m.csmiles for m in possible_products}
            prod_mol = Molecule(product)
            if prod_mol.csmiles in possible_products_csi:
                return 1.0
            return 0.0

    def ground_truth_reward_mol(
        self,
        answer: Dict[str, Any],
        labels: List[str],
        reactants: List[str],
        product: str,
        smarts: str,
        objective: str,
    ) -> float:
        """Compute reward for molecule prediction tasks.

        Notes
        The answer must be contained in a JSON object with an "answer" key, which can be either a single SMILES string or a list of SMILES strings.

        Args:
            answer: Model completion containing the answer.
            labels: List of ground truth SMILES strings.
            reactants: List of reactant SMILES strings.
            product: Expected product SMILES string.
            smarts: Reaction SMARTS string used to re-check the reaction when the
                prediction differs from the labels.
            objective: The type of objective for the reaction verification.

        Returns:
            Reward value between 0.0 and 1.0.
        """
        if answer == {}:
            return 0.0
        mol_label = [Molecule(smi) for smi in labels]
        if not all([m.is_valid for m in mol_label]):
            self.logger.error("Invalid ground truth molecule")
            return 0.0
        smiles = answer["answer"]
        if isinstance(smiles, str):
            smiles_list = [smiles]
        elif isinstance(smiles, list):
            smiles_list = smiles
        else:
            return 0.0
        mols = [Molecule(smi) for smi in smiles_list]

        if any([not m.is_valid for m in mols]):
            self.logger.info("Invalid molecule found in prediction")
            return 0.0

        return self.r_ground_truth_mols(
            mols,
            mol_label,
            reactants=reactants,
            product=product,
            smarts=smarts,
            objective=objective,
        )

    def reward_smarts(
        self,
        answer: Dict[str, Any],
        labels: List[str],
        reactants: List[str],
        product: str,
    ) -> Tuple[float, Dict[str, Any]]:
        """Compute reward for SMARTS prediction tasks.

        Notes
        The answer must be contained in a JSON object with an "answer" key,
        which should be a SMARTS string representing the reaction. The reward is computed based
        on whether the proposed SMARTS can produce the expected product from the given reactants.
        A reward of 1.0 is given for an exact match with the ground truth SMARTS, 0.1 if the SMARTS
        is valid and produces the correct product, and 0.0 otherwise.

        Args:
            answer: Model completion containing the SMARTS answer.
            labels: List containing the ground truth SMARTS string.
            reactants: List of reactant SMILES strings.
            product: Expected product SMILES string.

        Returns:
            Tuple of (reward, metadata_dict) where metadata contains
            'Reactants_contained' and 'Products_contained' flags.
        """
        if answer == {}:
            return 0.0, {"Reactants_contained": False, "Products_contained": False}
        gt_smarts = labels[0]
        smarts_pred = answer["answer"]
        if not isinstance(smarts_pred, str):
            return 0.0, {"Reactants_contained": False, "Products_contained": False}

        if smarts_pred.strip() == gt_smarts:
            return 1.0, {"Reactants_contained": True, "Products_contained": True}
        self.logger.info(
            f"Proposed SMARTS: {smarts_pred.strip()} | GT SMARTS: {gt_smarts}, checking reaction..."
        )
        try:
            rxnB = Reaction(smarts_pred.strip())
            if rxnB.num_reactants != len(reactants):
                return 0.0, {"Reactants_contained": False, "Products_contained": False}
            p = rxnB([Molecule(r) for r in reactants])
            reward = 0.0
            if Molecule(product).csmiles in {prod.csmiles for prod in p}:
                reward = 0.1
            return reward, {
                "Reactants_contained": True,
                "Products_contained": reward == 0.1,
            }
        except Exception as e:
            self.logger.info(
                f"Error in reaction SMARTS parsing: {e} (proposed: {smarts_pred} | gt: {gt_smarts})"
            )
            return 0.0, {"Reactants_contained": False, "Products_contained": False}

    def _find_reaction_smarts(
        self,
        reactants_step: List[Molecule],
        products_step: List[Molecule],
        allowed_smarts: ReactionContainer,
    ) -> List[Reaction]:
        """Find valid reaction SMARTS that can produce products from reactants.

        Args:
            reactants_step: List of reactant molecules for this step.
            products_step: List of expected product molecules.
            allowed_smarts: Container of allowed reaction SMARTS patterns.

        Returns:
            List of Reaction objects that successfully produce the expected products.
        """
        found_reactions: List[Reaction] = []
        id_poss_smarts: List[Dict[int, tuple[int, ...]]] = []
        for r in reactants_step:
            id_poss_smarts.append(allowed_smarts.match_reactions(r))

        for id_reaction in id_poss_smarts[0]:
            # Check if reaction can take the correct number of reactants
            if allowed_smarts[id_reaction].num_reactants != len(reactants_step):
                continue
            if any(
                id_reaction not in id_poss_smarts[i]
                for i in range(1, len(reactants_step))
            ):
                continue

            # Generate all permutations of reactants and test the reaction
            possible_products = []
            for reactants_ord in itertools.permutations(reactants_step):
                possible_products.extend(
                    allowed_smarts[id_reaction](list(reactants_ord))
                )
            possible_products_csi = {m.csmiles for m in possible_products}
            all_found: bool = all(
                p.csmiles in possible_products_csi for p in products_step
            )
            if all_found:
                found_reactions.append(allowed_smarts[id_reaction])
        return found_reactions

    def _check_valid_step(
        self,
        reactants_step: List[Molecule],
        products_step: List[Molecule],
        possible_reactants: List[str],
        allowed_smarts: ReactionContainer,
    ) -> Tuple[bool, str]:
        """Check if a synthesis step is valid.
        Used in the reward_run_path method to validate each step of the proposed synthesis path.

        Validates that:
        1. All reactants and products are valid molecules
        2. All reactants are in building blocks or previous products
        3. At least one reaction can produce the products from the reactants

        Args:
            reactants_step: List of reactant molecules for this step.
            products_step: List of expected product molecules.
            possible_reactants: List of valid starting materials (building blocks + previous products).
            allowed_smarts: Container of allowed reaction SMARTS patterns.

        Returns:
            Tuple of (is_valid, fail_reason), where fail_reason is an empty string if the
            step is valid, or one of "reactants", "products", "reaction" indicating what failed.
        """
        # 1. Check that all reactants and products are valid molecules
        if not all([r.is_valid for r in reactants_step]):
            self.logger.info(
                "Reactants not valid in {}".format([r.smiles for r in reactants_step])
            )
            return False, "reactants"
        if not all([p.is_valid for p in products_step]):
            self.logger.info(
                "Products not valid in {}".format([p.smiles for p in products_step])
            )
            return False, "products"
        # 2. Check that all reactants are in building blocks or previous products
        for r in reactants_step:
            if r.csmiles not in possible_reactants:
                self.logger.info(
                    "Reactant {} not in building blocks or previous products".format(
                        r.smiles
                    )
                )
                return False, "reactants"
        if products_step == []:
            self.logger.info("No products in step")
            return False, "products"

        # 3. Check that there is at least one reaction that can produce the products from the reactants
        found_reactions = self._find_reaction_smarts(
            reactants_step, products_step, allowed_smarts
        )
        if len(found_reactions) == 0:
            self.logger.info(
                "No reaction found for step: {} -> {}".format(
                    [r.smiles for r in reactants_step],
                    [p.smiles for p in products_step],
                )
            )
            return False, "reaction"
        # Log success
        self.logger.info(
            "Found valid reaction for step: {} -> {}".format(
                [r.smiles for r in reactants_step],
                [p.smiles for p in products_step],
            )
        )
        return True, ""

    def reward_run_path(
        self,
        answer: Dict[str, Any],
        label: str,
        building_blocks: List[str],
        smarts: List[str],
        n_steps_max: int,
        reward_type: Literal["binary", "tanimoto"] = "binary",
    ) -> Tuple[float, Dict[str, Any]]:
        """Compute reward for retro-synthesis path validation.

        Validates a multi-step synthesis path by checking:
        1. All reactants are valid building blocks or previous products
        2. Each reaction step has a valid SMARTS pattern
        3. The final product matches the target (exactly or by Tanimoto similarity)

        Notes:
            The answer must be a JSON object with an "answer" key holding a
            list of steps. Each step is a dictionary with the keys:

            - reactants: List of reactant SMILES strings for this step
            - product: List of product SMILES strings for this step

        Args:
            answer: Model completion containing the synthesis path.
            label: Target product SMILES string.
            building_blocks: List of starting building block SMILES. Currently not
                enforced; reactants are checked against the reaction matrix reactants instead.
            smarts: List of allowed SMARTS patterns (empty = use reaction matrix).
            n_steps_max: Maximum allowed number of synthesis steps.
            reward_type: "binary" for exact match or "tanimoto" for similarity-based.

        Returns:
            Tuple of (reward, metadata_dict) containing validation results.
        """
        if answer == {}:
            self.logger.info("No synthesis path found in completion")
            return 0.0, {
                "valid": 0.0,
                "correct_product": 0.0,
                "correct_reactant": False,
            }
        steps: List[Dict[str, Any]] = answer["answer"]
        if not isinstance(steps, list) or len(steps) == 0 or len(steps) > n_steps_max:
            self.logger.info("Synthesis path answer is not a list or is too long/short")
            return 0.0, {
                "valid": 0.0,
                "correct_product": 0.0,
                "correct_reactant": False,
            }
        # Validate each step structure
        for step in steps:
            if "reactants" not in step or "product" not in step:
                self.logger.info("Synthesis step missing reactants or product")
                return 0.0, {
                    "valid": 0.0,
                    "correct_product": 0.0,
                    "correct_reactant": False,
                }
            if step["reactants"] == [] or step["product"] == []:
                self.logger.info("Synthesis step has empty reactants or product")
                return 0.0, {
                    "valid": 0.0,
                    "correct_product": 0.0,
                    "correct_reactant": False,
                }
            if not isinstance(step["reactants"], list) or not all(
                isinstance(r, str) for r in step["reactants"]
            ):
                self.logger.info("One or more reactants in step are not strings")
                return 0.0, {
                    "valid": 0.0,
                    "correct_product": 0.0,
                    "correct_reactant": False,
                }
            if not isinstance(step["product"], list) or not all(
                isinstance(p, str) for p in step["product"]
            ):
                self.logger.info("One or more products in step are not strings")
                return 0.0, {
                    "valid": 0.0,
                    "correct_product": 0.0,
                    "correct_reactant": False,
                }

        reactants = [
            [Molecule(smi.strip()) for smi in step["reactants"]] for step in steps
        ]
        products = [
            [Molecule(smi.strip()) for smi in step["product"]] for step in steps
        ]
        n_steps = len(reactants)
        label_mol = Molecule(label)
        reward_mult: List[float] = [1.0 for _ in products]
        if reward_type == "binary" and not any(
            [label_mol == last_p for last_p in products[-1]]
        ):
            self.logger.info("Product not found")
            return 0.0, {
                "valid": 0.0,
                "correct_product": 0.0,
                "correct_reactant": False,
            }
        elif reward_type == "tanimoto":
            # Compute the tanimoto similarity between the label and products at each step
            for i, product in enumerate(products):
                all_sims = label_mol.tanimoto_similarity(product)
                reward_mult[i] = max(all_sims) ** 3

        reactions: ReactionContainer
        if smarts == []:
            reactions = self.rxn_matrix.reactions
        else:
            reactions = ReactionContainer([Reaction(sma) for sma in smarts])

        building_blocks_csi = (
            self.reactants_csi
        )  # For the moment, the building blocks passed in the metadata
        # only serve as hints to the model; they are not enforced as constraints.

        n_valid = 0
        fail_reason = ""
        for i_reac, (reactant, product) in enumerate(zip(reactants, products)):
            is_valid, fail_reason = self._check_valid_step(
                reactant,
                product,
                building_blocks_csi
                + [p.csmiles for step in products[:i_reac] for p in step],
                reactions,
            )
            if not is_valid:
                self.logger.info(f"Invalid step at index {i_reac} due to {fail_reason}")
                break
            else:
                n_valid += 1
        if n_valid < n_steps:
            return reward_mult[n_valid - 1] * (n_valid / n_steps) ** 2, {
                "valid": n_valid / n_steps,
                "correct_product": reward_mult[n_valid - 1],
                "correct_reactant": fail_reason != "reactants",
            }

        return reward_mult[n_valid - 1], {
            "valid": 1.0,
            "correct_product": reward_mult[n_valid - 1],
            "correct_reactant": True,
        }

    def parse_json_content(self, content: str) -> Dict[str, Any]:
        """Parse JSON content from the model completion.

        Args:
            content: The extracted answer content from the model completion.

        Returns:
            Parsed JSON as a dictionary, or an empty dictionary if no JSON object
            is found, decoding fails, or the "answer" key is missing.
        """
        # Greedily match from the first "{" to the last "}" to extract
        # a JSON object of the form { "answer": ... }
        possible_json = re.search(r"\{.*\}", content, re.DOTALL)
        parsed: Dict[str, Any]
        if possible_json is None:
            return {}
        try:
            parsed = json.loads(possible_json.group(0))
        except json.JSONDecodeError as e:
            self.logger.info(f"JSON decode error: {e}")
            return {}
        if "answer" not in parsed:
            return {}
        return parsed

    def get_score(
        self, inputs: BatchVerifiersInputModel
    ) -> List[ReactionVerifierOutputModel]:
        """Compute reaction rewards for a batch of completions.

        This method routes each completion to the appropriate reward function
        based on the objective type specified in the metadata.

        Args:
            inputs: Batch of completions and metadata for verification.

        Returns:
            List of ReactionVerifierOutputModel containing rewards and metadata.

        Notes:
            - Ground truth tasks: final_product, reactant, all_reactants
            - SMARTS tasks: smarts prediction with reaction validation
            - Path tasks: full_path with step-by-step validation
        """
        completions = inputs.completions
        assert all(
            isinstance(meta, ReactionVerifierInputMetadataModel)
            for meta in inputs.metadatas
        )
        metadatas: List[ReactionVerifierInputMetadataModel] = inputs.metadatas  # type: ignore

        output_models = []
        for answer, meta in zip(completions, metadatas):
            objective = meta.objectives[0]
            reward = 0.0
            reward_metadata = {
                "valid": 0.0,
                "correct_product": 0.0,
                "correct_reactant": False,
            }
            extracted_answer = self.parse_answer(answer)
            json_answer = self.parse_json_content(extracted_answer)

            if objective in self.check_ground_truth_tasks:
                reward = self.ground_truth_reward_mol(
                    json_answer,
                    meta.target,
                    reactants=meta.reactants[meta.idx_chosen],
                    product=meta.products[meta.idx_chosen],
                    smarts=meta.or_smarts[meta.idx_chosen],
                    objective=objective,
                )
                reward_metadata = {
                    "valid": float(reward > 0.0),
                    "correct_product": reward > 0.0,
                    "correct_reactant": reward > 0.0,
                }
            elif objective == "smarts":
                assert len(meta.reactants) > 0, (
                    "Reactants must be provided for SMARTS objective"
                )
                assert len(meta.products) > 0, (
                    "Product must be provided for SMARTS objective"
                )
                reward, raw_metadata = self.reward_smarts(
                    json_answer,
                    meta.target,
                    meta.reactants[0],
                    meta.products[0],
                )
                reward_metadata = {
                    "valid": reward,
                    "correct_product": raw_metadata.get("Products_contained", False),
                    "correct_reactant": raw_metadata.get("Reactants_contained", False),
                }
            elif objective in self.run_validation_tasks:
                assert len(meta.target) > 0, (
                    "Target must be provided for run validation tasks"
                )
                # If smarts constraint, add it
                if "smarts" in objective:
                    assert len(meta.smarts) > 0, (
                        "SMARTS must be provided for smarts reference path validation"
                    )
                    smarts = meta.smarts
                else:
                    smarts = []
                reward, raw_metadata = self.reward_run_path(
                    json_answer,
                    meta.target[0],
                    meta.building_blocks if meta.building_blocks else [],
                    smarts=smarts,
                    n_steps_max=meta.n_steps_max,
                    reward_type=self.verifier_config.reaction_reward_type,
                )
                reward_metadata = raw_metadata
            else:
                raise ValueError(
                    "Unknown objective {} type for reaction verifier".format(objective)
                )

            if self.verifier_config.reward == "valid_smiles":
                reward = float(reward > 0.0)

            # Create the output model
            output_model = ReactionVerifierOutputModel(
                reward=reward,
                parsed_answer=f"{json_answer}",
                verifier_metadata=ReactionVerifierMetadataModel(
                    valid=reward_metadata["valid"],
                    correct_product=reward_metadata["correct_product"],
                    correct_reactant=reward_metadata["correct_reactant"],
                ),
            )
            output_models.append(output_model)

        return output_models

__init__(verifier_config)

Initialize the ReactionVerifier.

Parameters:

Name Type Description Default
verifier_config ReactionVerifierConfigModel

Configuration containing reaction matrix path and reward type settings.

required
Source code in mol_gen_docking/reward/verifiers/reaction_reward/reaction_verifier.py
def __init__(
    self,
    verifier_config: ReactionVerifierConfigModel,
):
    """Initialize the ReactionVerifier.

    Args:
        verifier_config: Configuration containing reaction matrix path
            and reward type settings.
    """
    super().__init__(verifier_config)
    self.verifier_config: ReactionVerifierConfigModel = verifier_config

    self.rxn_matrix: ReactantReactionMatrix
    with open(verifier_config.reaction_matrix_path, "rb") as f:
        self.rxn_matrix = pickle.load(f)

    self.reactants_csi: List[str] = [m.csmiles for m in self.rxn_matrix.reactants]
    self.check_ground_truth_tasks = [
        "final_product",
        "reactant",
        "all_reactants",
        "all_reactants_bb_ref",
    ]
    self.run_validation_tasks = [
        "full_path",
        "full_path_bb_ref",
        "full_path_smarts_ref",
        "full_path_smarts_bb_ref",
        "full_path_intermediates_gt_reactants",
        "full_path_intermediates",  # We treat this task as a normal path validation
    ]
    self.logger = logging.getLogger("ReactionVerifier")

get_score(inputs)

Compute reaction rewards for a batch of completions.

This method routes each completion to the appropriate reward function based on the objective type specified in the metadata.

Parameters:

Name Type Description Default
inputs BatchVerifiersInputModel

Batch of completions and metadata for verification.

required

Returns:

Type Description
List[ReactionVerifierOutputModel]

List of ReactionVerifierOutputModel containing rewards and metadata.

Notes
  • Ground truth tasks: final_product, reactant, all_reactants
  • SMARTS tasks: smarts prediction with reaction validation
  • Path tasks: full_path with step-by-step validation
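
The routing described above can be sketched as a small standalone function. This is a minimal illustration that assumes the task lists set in `__init__`; `route_objective` and `uses_smarts_constraint` are hypothetical names and not part of the verifier's API.

```python
from typing import List

# Task lists copied from ReactionVerifier.__init__.
GROUND_TRUTH_TASKS: List[str] = [
    "final_product",
    "reactant",
    "all_reactants",
    "all_reactants_bb_ref",
]
PATH_TASKS: List[str] = [
    "full_path",
    "full_path_bb_ref",
    "full_path_smarts_ref",
    "full_path_smarts_bb_ref",
    "full_path_intermediates_gt_reactants",
    "full_path_intermediates",
]


def route_objective(objective: str) -> str:
    """Return the name of the reward method an objective is sent to."""
    if objective in GROUND_TRUTH_TASKS:
        return "ground_truth_reward_mol"
    if objective == "smarts":
        return "reward_smarts"
    if objective in PATH_TASKS:
        return "reward_run_path"
    # Anything else is rejected, mirroring the ValueError in get_score.
    raise ValueError(f"Unknown objective {objective} type for reaction verifier")


def uses_smarts_constraint(objective: str) -> bool:
    """Path objectives whose name contains "smarts" pass meta.smarts along."""
    return objective in PATH_TASKS and "smarts" in objective


print(route_objective("full_path_smarts_ref"))     # reward_run_path
print(uses_smarts_constraint("full_path_bb_ref"))  # False
```

Only the first entry of `meta.objectives` is used for routing, so each completion is scored by exactly one reward method.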
Source code in mol_gen_docking/reward/verifiers/reaction_reward/reaction_verifier.py
def get_score(
    self, inputs: BatchVerifiersInputModel
) -> List[ReactionVerifierOutputModel]:
    """Compute reaction rewards for a batch of completions.

    This method routes each completion to the appropriate reward function
    based on the objective type specified in the metadata.

    Args:
        inputs: Batch of completions and metadata for verification.

    Returns:
        List of ReactionVerifierOutputModel containing rewards and metadata.

    Notes:
        - Ground truth tasks: final_product, reactant, all_reactants
        - SMARTS tasks: smarts prediction with reaction validation
        - Path tasks: full_path with step-by-step validation
    """
    completions = inputs.completions
    assert all(
        isinstance(meta, ReactionVerifierInputMetadataModel)
        for meta in inputs.metadatas
    )
    metadatas: List[ReactionVerifierInputMetadataModel] = inputs.metadatas  # type: ignore

    output_models = []
    for answer, meta in zip(completions, metadatas):
        objective = meta.objectives[0]
        reward = 0.0
        reward_metadata = {
            "valid": 0.0,
            "correct_product": 0.0,
            "correct_reactant": False,
        }
        extracted_answer = self.parse_answer(answer)
        json_answer = self.parse_json_content(extracted_answer)

        if objective in self.check_ground_truth_tasks:
            reward = self.ground_truth_reward_mol(
                json_answer,
                meta.target,
                reactants=meta.reactants[meta.idx_chosen],
                product=meta.products[meta.idx_chosen],
                smarts=meta.or_smarts[meta.idx_chosen],
                objective=objective,
            )
            reward_metadata = {
                "valid": float(reward > 0.0),
                "correct_product": reward > 0.0,
                "correct_reactant": reward > 0.0,
            }
        elif objective == "smarts":
            assert len(meta.reactants) > 0, (
                "Reactants must be provided for SMARTS objective"
            )
            assert len(meta.products) > 0, (
                "Product must be provided for SMARTS objective"
            )
            reward, raw_metadata = self.reward_smarts(
                json_answer,
                meta.target,
                meta.reactants[0],
                meta.products[0],
            )
            reward_metadata = {
                "valid": reward,
                "correct_product": raw_metadata.get("Products_contained", False),
                "correct_reactant": raw_metadata.get("Reactants_contained", False),
            }
        elif objective in self.run_validation_tasks:
            assert len(meta.target) > 0, (
                "Target must be provided for run validation tasks"
            )
            # If smarts constraint, add it
            if "smarts" in objective:
                assert len(meta.smarts) > 0, (
                    "SMARTS must be provided for smarts reference path validation"
                )
                smarts = meta.smarts
            else:
                smarts = []
            reward, raw_metadata = self.reward_run_path(
                json_answer,
                meta.target[0],
                meta.building_blocks if meta.building_blocks else [],
                smarts=smarts,
                n_steps_max=meta.n_steps_max,
                reward_type=self.verifier_config.reaction_reward_type,
            )
            reward_metadata = raw_metadata
        else:
            raise ValueError(
                "Unknown objective {} type for reaction verifier".format(objective)
            )

        if self.verifier_config.reward == "valid_smiles":
            reward = float(reward > 0.0)

        # Create the output model
        output_model = ReactionVerifierOutputModel(
            reward=reward,
            parsed_answer=f"{json_answer}",
            verifier_metadata=ReactionVerifierMetadataModel(
                valid=reward_metadata["valid"],
                correct_product=reward_metadata["correct_product"],
                correct_reactant=reward_metadata["correct_reactant"],
            ),
        )
        output_models.append(output_model)

    return output_models

ground_truth_reward_mol(answer, labels, reactants, product, smarts, objective)

Compute reward for molecule prediction tasks.

Notes: The answer must be a JSON object with an "answer" key whose value is either a single SMILES string or a list of SMILES strings.

Parameters:

Name Type Description Default
answer Dict[str, Any]

Model completion containing the answer.

required
labels List[str]

List of ground truth SMILES strings.

required
reactants List[str]

List of reactant SMILES strings.

required
product str

Expected product SMILES string.

required
smarts str

Reaction SMARTS string used to re-check the reaction when the prediction differs from the labels.

required
objective str

The type of objective for the reaction verification.

required

Returns:

Type Description
float

Reward value between 0.0 and 1.0.
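
The two accepted answer shapes can be illustrated with a short sketch. `to_smiles_list` is a hypothetical helper that mirrors only the normalization step of the method; validity checks on the molecules and the reward computation are left out.

```python
from typing import Any, Dict, List, Optional


def to_smiles_list(answer: Dict[str, Any]) -> Optional[List[Any]]:
    """Normalize the "answer" value to a list, or return None if unusable."""
    if answer == {}:
        # Empty dict means no JSON answer was parsed from the completion.
        return None
    smiles = answer["answer"]
    if isinstance(smiles, str):
        return [smiles]
    if isinstance(smiles, list):
        return smiles
    # Any other type (number, nested object, ...) yields a reward of 0.0.
    return None


print(to_smiles_list({"answer": "CCO"}))               # ['CCO']
print(to_smiles_list({"answer": ["CCO", "CC(=O)O"]}))  # ['CCO', 'CC(=O)O']
print(to_smiles_list({"answer": 42}))                  # None
```

In the method itself, a `None` result here corresponds to an immediate reward of 0.0.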

Source code in mol_gen_docking/reward/verifiers/reaction_reward/reaction_verifier.py
def ground_truth_reward_mol(
    self,
    answer: Dict[str, Any],
    labels: List[str],
    reactants: List[str],
    product: str,
    smarts: str,
    objective: str,
) -> float:
    """Compute reward for molecule prediction tasks.

    Notes:
        The answer must be a JSON object with an "answer" key whose value is
        either a single SMILES string or a list of SMILES strings.

    Args:
        answer: Model completion containing the answer.
        labels: List of ground truth SMILES strings.
        reactants: List of reactant SMILES strings.
        product: Expected product SMILES string.
        smarts: Reaction SMARTS string used to re-check the reaction when the
            prediction differs from the labels.
        objective: The type of objective for the reaction verification.

    Returns:
        Reward value between 0.0 and 1.0.
    """
    if answer == {}:
        return 0.0
    mol_label = [Molecule(smi) for smi in labels]
    if not all([m.is_valid for m in mol_label]):
        self.logger.error("Invalid ground truth molecule")
        return 0.0
    smiles = answer["answer"]
    if isinstance(smiles, str):
        smiles_list = [smiles]
    elif isinstance(smiles, list):
        smiles_list = smiles
    else:
        return 0.0
    mols = [Molecule(smi) for smi in smiles_list]

    if any([not m.is_valid for m in mols]):
        self.logger.info("Invalid molecule found in prediction")
        return 0.0

    return self.r_ground_truth_mols(
        mols,
        mol_label,
        reactants=reactants,
        product=product,
        smarts=smarts,
        objective=objective,
    )

parse_json_content(content)

Parse JSON content from the model completion.

Parameters:

Name Type Description Default
content str

The extracted answer content from the model completion.

required

Returns: Parsed JSON as a dictionary, or an empty dictionary if no JSON object is found, decoding fails, or the "answer" key is missing.
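
The extraction behaviour can be tried in isolation with the sketch below. `extract_answer_json` is a hypothetical standalone copy of the method's logic without the logger, intended only to show what the greedy pattern accepts and rejects.

```python
import json
import re
from typing import Any, Dict


def extract_answer_json(content: str) -> Dict[str, Any]:
    """Standalone sketch of the JSON extraction (hypothetical helper)."""
    # Greedy match: spans from the first "{" to the last "}" in the text.
    match = re.search(r"\{.*\}", content, re.DOTALL)
    if match is None:
        return {}
    try:
        parsed = json.loads(match.group(0))
    except json.JSONDecodeError:
        return {}
    if "answer" not in parsed:
        return {}
    return parsed


print(extract_answer_json('The product is {"answer": "CCO"}.'))  # {'answer': 'CCO'}
print(extract_answer_json("no json here"))                       # {}
# Two separate objects: the greedy span covers both and fails to decode.
print(extract_answer_json('{"a": 1} then {"answer": "CCO"}'))    # {}
```

Because the match is greedy, a completion that contains more than one top-level JSON object is rejected rather than partially parsed.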

Source code in mol_gen_docking/reward/verifiers/reaction_reward/reaction_verifier.py
def parse_json_content(self, content: str) -> Dict[str, Any]:
    """Parse JSON content from the model completion.

    Args:
        content: The extracted answer content from the model completion.

    Returns:
        Parsed JSON as a dictionary, or an empty dictionary if no JSON object
        is found, decoding fails, or the "answer" key is missing.
    """
    # Greedily match from the first "{" to the last "}" to extract
    # a JSON object of the form { "answer": ... }
    possible_json = re.search(r"\{.*\}", content, re.DOTALL)
    parsed: Dict[str, Any]
    if possible_json is None:
        return {}
    try:
        parsed = json.loads(possible_json.group(0))
    except json.JSONDecodeError as e:
        self.logger.info(f"JSON decode error: {e}")
        return {}
    if "answer" not in parsed:
        return {}
    return parsed

r_ground_truth_mols(mol_y, mol_label, reactants, product, smarts, objective)

Compute reward for molecule prediction against ground truth.

Source code in mol_gen_docking/reward/verifiers/reaction_reward/reaction_verifier.py
def r_ground_truth_mols(
    self,
    mol_y: List[Molecule],
    mol_label: List[Molecule],
    reactants: List[str],
    product: str,
    smarts: str,
    objective: str,
) -> float:
    """Compute reward for molecule prediction against ground truth."""
    self.logger.info(
        f"Computed molecules: {[mol.smiles for mol in mol_y]} vs labels: {[mol.smiles for mol in mol_label]}"
    )
    smi_y = {mol.csmiles for mol in mol_y}
    smi_y_true = {mol.csmiles for mol in mol_label}
    if smi_y == smi_y_true:
        return 1.0

    else:
        # Check if the reaction still works
        rxn = Reaction(smarts)
        if objective == "final_product":
            if len(mol_y) != 1:
                return 0.0
            react_mols = [Molecule(smi) for smi in reactants]
            possible_products = []
            for r_ in itertools.permutations(react_mols):
                possible_products.extend(rxn(list(r_)))
            # Compare canonical SMILES to avoid unexpected behavior with Python's `in` operator
            possible_products_csi = {m.csmiles for m in possible_products}
            if mol_y[0].csmiles in possible_products_csi:  # Can be only one product
                return 1.0
            else:
                return 0.0
        elif objective == "reactant":
            react_mols = [Molecule(smi) for smi in reactants]
            react_mols = [m for m in react_mols if not m == mol_label[0]]
            react_mols += mol_y
        elif objective in ["all_reactants", "all_reactants_bb_ref"]:
            react_mols = mol_y
        # Check that the number of reactants matches
        if len(react_mols) != rxn.num_reactants:
            return 0.0
        # Check if all reactants are in the building blocks
        if any(m.csmiles not in self.reactants_csi for m in react_mols):
            return 0.0
        possible_products = []
        for r_ in itertools.permutations(react_mols):
            possible_products.extend(rxn(list(r_)))
        # Compare canonical SMILES to avoid unexpected behavior with Python's `in` operator
        possible_products_csi = {m.csmiles for m in possible_products}
        prod_mol = Molecule(product)
        if prod_mol.csmiles in possible_products_csi:
            return 1.0
        return 0.0

reward_run_path(answer, label, building_blocks, smarts, n_steps_max, reward_type='binary')

Compute reward for retro-synthesis path validation.

Validates a multi-step synthesis path by checking:

1. All reactants are valid building blocks or previous products
2. Each reaction step has a valid SMARTS pattern
3. The final product matches the target (exactly or by Tanimoto similarity)

Notes: The answer must be a JSON object with an "answer" key holding a list of steps. Each step is a dictionary with the keys:

- reactants: List of reactant SMILES strings for this step
- product: List of product SMILES strings for this step
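
A payload in this format, together with the structural checks applied before any chemistry is evaluated, might look like the sketch below. `has_valid_structure` is a hypothetical helper that condenses those checks; it also verifies that each step is a dictionary, which is an addition of this sketch. The single-step route is illustrative only.

```python
from typing import Any, Dict

# Illustrative one-step route (esterification of acetic acid with ethanol).
example_answer: Dict[str, Any] = {
    "answer": [
        {"reactants": ["CC(=O)O", "OCC"], "product": ["CC(=O)OCC"]},
    ]
}


def has_valid_structure(answer: Dict[str, Any], n_steps_max: int) -> bool:
    """Check the shape of a synthesis path without touching the chemistry."""
    steps = answer.get("answer")
    # The path must be a non-empty list with at most n_steps_max steps.
    if not isinstance(steps, list) or not 0 < len(steps) <= n_steps_max:
        return False
    for step in steps:
        if not isinstance(step, dict):  # extra guard, not in the original
            return False
        for key in ("reactants", "product"):
            values = step.get(key)
            # Both keys must hold non-empty lists of strings.
            if not isinstance(values, list) or not values:
                return False
            if not all(isinstance(v, str) for v in values):
                return False
    return True


print(has_valid_structure(example_answer, n_steps_max=3))  # True
bad = {"answer": [{"reactants": ["CCO"], "product": "CC=O"}]}
print(has_valid_structure(bad, n_steps_max=3))             # False
```

Note that `product` must be a list even when a step yields a single molecule; a bare string fails the structural check.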

Parameters:

Name Type Description Default
answer Dict[str, Any]

Model completion containing the synthesis path.

required
label str

Target product SMILES string.

required
building_blocks List[str]

List of starting building block SMILES. Currently not enforced; reactants are checked against the reaction matrix reactants instead.

required
smarts List[str]

List of allowed SMARTS patterns (empty = use reaction matrix).

required
n_steps_max int

Maximum allowed number of synthesis steps.

required
reward_type Literal['binary', 'tanimoto']

"binary" for exact match or "tanimoto" for similarity-based.

'binary'

Returns:

Type Description
Tuple[float, Dict[str, Any]]

Tuple of (reward, metadata_dict) containing validation results.
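
The reward arithmetic can be summarized in a short sketch. Per the source listing, the multiplier of the last valid step is scaled by the squared fraction of valid steps, and in "tanimoto" mode each multiplier is the cubed maximum similarity at that step. `tanimoto_multipliers` and `path_reward` are hypothetical helper names, and the similarity values are made up for illustration.

```python
from typing import List


def tanimoto_multipliers(max_sims: List[float]) -> List[float]:
    """Cube the best per-step similarity, as done in "tanimoto" mode."""
    return [s ** 3 for s in max_sims]


def path_reward(reward_mult: List[float], n_valid: int) -> float:
    """Reward for a path whose first n_valid steps passed validation."""
    n_steps = len(reward_mult)
    if n_valid < n_steps:
        # Partial credit: quadratic penalty on the fraction of valid steps.
        return reward_mult[n_valid - 1] * (n_valid / n_steps) ** 2
    return reward_mult[n_valid - 1]


# Binary mode: every multiplier is 1.0 once the final product matches.
print(path_reward([1.0, 1.0, 1.0], n_valid=2))  # (2/3)**2, about 0.444

# Tanimoto mode with made-up per-step similarities.
mult = tanimoto_multipliers([0.2, 0.5, 0.9])
print(path_reward(mult, n_valid=3))  # 0.9**3, about 0.729
print(path_reward(mult, n_valid=0))  # 0.0, no valid step
```

Cubing a similarity in [0, 1] shrinks mid-range values sharply, so only near-matches to the target retain a substantial reward.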

Source code in mol_gen_docking/reward/verifiers/reaction_reward/reaction_verifier.py
def reward_run_path(
    self,
    answer: Dict[str, Any],
    label: str,
    building_blocks: List[str],
    smarts: List[str],
    n_steps_max: int,
    reward_type: Literal["binary", "tanimoto"] = "binary",
) -> Tuple[float, Dict[str, Any]]:
    """Compute reward for retro-synthesis path validation.

    Validates a multi-step synthesis path by checking:
    1. All reactants are valid building blocks or previous products
    2. Each reaction step has a valid SMARTS pattern
    3. The final product matches the target (exactly or by Tanimoto similarity)

    Notes:
        The answer must be a JSON object with an "answer" key holding a
        list of steps. Each step is a dictionary with the keys:

        - reactants: List of reactant SMILES strings for this step
        - product: List of product SMILES strings for this step

    Args:
        answer: Model completion containing the synthesis path.
        label: Target product SMILES string.
        building_blocks: List of starting building block SMILES. Currently not
            enforced; reactants are checked against the reaction matrix reactants instead.
        smarts: List of allowed SMARTS patterns (empty = use reaction matrix).
        n_steps_max: Maximum allowed number of synthesis steps.
        reward_type: "binary" for exact match or "tanimoto" for similarity-based.

    Returns:
        Tuple of (reward, metadata_dict) containing validation results.
    """
    if answer == {}:
        self.logger.info("No synthesis path found in completion")
        return 0.0, {
            "valid": 0.0,
            "correct_product": 0.0,
            "correct_reactant": False,
        }
    steps: List[Dict[str, Any]] = answer["answer"]
    if not isinstance(steps, list) or len(steps) == 0 or len(steps) > n_steps_max:
        self.logger.info("Synthesis path answer is not a list or is too long/short")
        return 0.0, {
            "valid": 0.0,
            "correct_product": 0.0,
            "correct_reactant": False,
        }
    # Validate each step structure
    for step in steps:
        if "reactants" not in step or "product" not in step:
            self.logger.info("Synthesis step missing reactants or product")
            return 0.0, {
                "valid": 0.0,
                "correct_product": 0.0,
                "correct_reactant": False,
            }
        if step["reactants"] == [] or step["product"] == []:
            self.logger.info("Synthesis step has empty reactants or product")
            return 0.0, {
                "valid": 0.0,
                "correct_product": 0.0,
                "correct_reactant": False,
            }
        if not isinstance(step["reactants"], list) or not all(
            isinstance(r, str) for r in step["reactants"]
        ):
            self.logger.info("One or more reactants in step are not strings")
            return 0.0, {
                "valid": 0.0,
                "correct_product": 0.0,
                "correct_reactant": False,
            }
        if not isinstance(step["product"], list) or not all(
            isinstance(p, str) for p in step["product"]
        ):
            self.logger.info("One or more products in step are not strings")
            return 0.0, {
                "valid": 0.0,
                "correct_product": 0.0,
                "correct_reactant": False,
            }

    reactants = [
        [Molecule(smi.strip()) for smi in step["reactants"]] for step in steps
    ]
    products = [
        [Molecule(smi.strip()) for smi in step["product"]] for step in steps
    ]
    n_steps = len(reactants)
    label_mol = Molecule(label)
    reward_mult: List[float] = [1.0 for _ in products]
    if reward_type == "binary" and not any(
        [label_mol == last_p for last_p in products[-1]]
    ):
        self.logger.info("Product not found")
        return 0.0, {
            "valid": 0.0,
            "correct_product": 0.0,
            "correct_reactant": False,
        }
    elif reward_type == "tanimoto":
        # Compute the tanimoto similarity between the label and products at each step
        for i, product in enumerate(products):
            all_sims = label_mol.tanimoto_similarity(product)
            reward_mult[i] = max(all_sims) ** 3

    reactions: ReactionContainer
    if smarts == []:
        reactions = self.rxn_matrix.reactions
    else:
        reactions = ReactionContainer([Reaction(sma) for sma in smarts])

    building_blocks_csi = (
        self.reactants_csi
    )  # For the moment, the building blocks passed in the metadata
    # only serve as hints to the model; they are not enforced as constraints.

    n_valid = 0
    fail_reason = ""
    for i_reac, (reactant, product) in enumerate(zip(reactants, products)):
        is_valid, fail_reason = self._check_valid_step(
            reactant,
            product,
            building_blocks_csi
            + [p.csmiles for step in products[:i_reac] for p in step],
            reactions,
        )
        if not is_valid:
            self.logger.info(f"Invalid step at index {i_reac} due to {fail_reason}")
            break
        else:
            n_valid += 1
    if n_valid < n_steps:
        return reward_mult[n_valid - 1] * (n_valid / n_steps) ** 2, {
            "valid": n_valid / n_steps,
            "correct_product": reward_mult[n_valid - 1],
            "correct_reactant": fail_reason != "reactants",
        }

    return reward_mult[n_valid - 1], {
        "valid": 1.0,
        "correct_product": reward_mult[n_valid - 1],
        "correct_reactant": True,
    }
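
The final scoring in the listing above can be read as a partial-credit rule: the multiplier of the last valid step is scaled by the squared fraction of valid steps, and a fully valid path receives the multiplier unscaled. The sketch below isolates that arithmetic. `path_reward` is a hypothetical helper written for illustration; it is not part of the library, and it assumes `reward_mult` holds one multiplier per step (1.0 in binary mode, the cubed Tanimoto similarity in tanimoto mode).

```python
from typing import List, Tuple


def path_reward(
    n_valid: int, n_steps: int, reward_mult: List[float]
) -> Tuple[float, float]:
    """Return (reward, valid_fraction) for a synthesis path.

    Hypothetical helper that mirrors the arithmetic at the end of the
    listing above; it is an illustration, not the library's API.
    """
    if n_steps <= 0 or len(reward_mult) != n_steps:
        raise ValueError("reward_mult must hold one multiplier per step")
    if not 0 <= n_valid <= n_steps:
        raise ValueError("n_valid must lie between 0 and n_steps")

    # As in the listing, n_valid == 0 indexes reward_mult[-1]; the reward
    # is still 0.0 because the squared valid fraction is zero.
    mult = reward_mult[n_valid - 1]
    valid_fraction = n_valid / n_steps
    if n_valid < n_steps:
        # Partial credit: quadratic penalty on the fraction of valid steps.
        return mult * valid_fraction**2, valid_fraction
    return mult, 1.0


if __name__ == "__main__":
    # Three steps, the first two valid, with per-step Tanimoto similarities
    # of 0.5, 0.8 and 1.0 cubed into multipliers.
    multipliers = [s**3 for s in (0.5, 0.8, 1.0)]
    print(path_reward(2, 3, multipliers))
```

With two of three steps valid, the reward is the second step's multiplier (0.8 cubed) times 4/9, so a path that fails early is penalized both for its shorter valid prefix and for how far its last valid product is from the target.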

reward_smarts(answer, labels, reactants, product)

Compute reward for SMARTS prediction tasks.

Notes:

The answer must be a JSON object with an "answer" key whose value is a SMARTS string representing the reaction. The reward is 1.0 when the stripped prediction matches the ground truth SMARTS string exactly, 0.1 when the prediction differs from the ground truth but parses as a reaction that produces the expected product from the given reactants, and 0.0 otherwise.
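
The three reward tiers can be sketched without any chemistry dependency. In the block below, `tiered_smarts_reward`, `toy_run_reaction`, and the placeholder strings such as `X.Y>>XY` are all hypothetical and exist only for illustration; a real backend would parse the SMARTS and apply it to the reactants. The sketch also leaves the reactant-count check to the injected backend, whereas the library performs it explicitly.

```python
import json
from typing import Any, Callable, Dict, List, Set, Tuple

Metadata = Dict[str, bool]


def tiered_smarts_reward(
    answer: Dict[str, Any],
    gt_smarts: str,
    reactants: List[str],
    product: str,
    run_reaction: Callable[[str, List[str]], Set[str]],
) -> Tuple[float, Metadata]:
    """Illustrative re-statement of the 1.0 / 0.1 / 0.0 tiers.

    run_reaction is an assumed callable that returns the set of product
    identifiers for a SMARTS and raises if the SMARTS cannot be used.
    """
    failed: Metadata = {"Reactants_contained": False, "Products_contained": False}
    pred = answer.get("answer")
    if not isinstance(pred, str):
        return 0.0, failed
    pred = pred.strip()
    if pred == gt_smarts:
        # Tier 1: exact string match with the ground truth.
        return 1.0, {"Reactants_contained": True, "Products_contained": True}
    try:
        produced = run_reaction(pred, reactants)
    except Exception:
        # Tier 3: the prediction could not be parsed or applied.
        return 0.0, failed
    # Tier 2: a different SMARTS that still yields the expected product.
    hit = product in produced
    return (0.1 if hit else 0.0), {
        "Reactants_contained": True,
        "Products_contained": hit,
    }


def toy_run_reaction(smarts: str, reactants: List[str]) -> Set[str]:
    """Toy stand-in for a reaction engine, backed by a lookup table."""
    table = {"X.Y>>XY": {"xy"}, "X.Y>>YX": {"xy"}, "X.Y>>Z": {"z"}}
    if smarts not in table:
        raise ValueError(f"unparseable SMARTS: {smarts}")
    return table[smarts]


if __name__ == "__main__":
    completion = json.loads('{"answer": " X.Y>>YX "}')
    print(tiered_smarts_reward(completion, "X.Y>>XY", ["x", "y"], "xy", toy_run_reaction))
```

Injecting the reaction backend as a callable keeps the tier logic testable on its own; the exact-match tier never touches the backend, which matches the early return in the source listing.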

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| answer | Dict[str, Any] | Model completion containing the SMARTS answer. | required |
| labels | List[str] | List containing the ground truth SMARTS string. | required |
| reactants | List[str] | List of reactant SMILES strings. | required |
| product | str | Expected product SMILES string. | required |

Returns:

| Type | Description |
| --- | --- |
| Tuple[float, Dict[str, Any]] | Tuple of (reward, metadata_dict), where metadata_dict contains the 'Reactants_contained' and 'Products_contained' flags. |

Source code in mol_gen_docking/reward/verifiers/reaction_reward/reaction_verifier.py
def reward_smarts(
    self,
    answer: Dict[str, Any],
    labels: List[str],
    reactants: List[str],
    product: str,
) -> Tuple[float, Dict[str, Any]]:
    """Compute reward for SMARTS prediction tasks.

    Notes
    The answer must be contained in a JSON object with an "answer" key,
    which should be a SMARTS string representing the reaction. The reward is computed based
    on whether the proposed SMARTS can produce the expected product from the given reactants.
    A reward of 1.0 is given for an exact match with the ground truth SMARTS, 0.1 if the SMARTS
    is valid and produces the correct product, and 0.0 otherwise.

    Args:
        answer: Model completion containing the SMARTS answer.
        labels: List containing the ground truth SMARTS string.
        reactants: List of reactant SMILES strings.
        product: Expected product SMILES string.

    Returns:
        Tuple of (reward, metadata_dict) where metadata contains
        'Reactants_contained' and 'Products_contained' flags.
    """
    if answer == {}:
        return 0.0, {"Reactants_contained": False, "Products_contained": False}
    gt_smarts = labels[0]
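    # Use .get so a completion without an "answer" key scores 0.0 below
    # instead of raising KeyError.
    smarts_pred = answer.get("answer")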
    if not isinstance(smarts_pred, str):
        return 0.0, {"Reactants_contained": False, "Products_contained": False}

    if smarts_pred.strip() == gt_smarts:
        return 1.0, {"Reactants_contained": True, "Products_contained": True}
    self.logger.info(
        f"Proposed SMARTS: {smarts_pred.strip()} | GT SMARTS: {gt_smarts}, checking reaction..."
    )
    try:
        rxnB = Reaction(smarts_pred.strip())
        if rxnB.num_reactants != len(reactants):
            return 0.0, {"Reactants_contained": False, "Products_contained": False}
        p = rxnB([Molecule(r) for r in reactants])
        # Partial credit when the predicted SMARTS yields the expected product.
        product_found = Molecule(product).csmiles in {prod.csmiles for prod in p}
        return (0.1 if product_found else 0.0), {
            "Reactants_contained": True,
            "Products_contained": product_found,
        }
    except Exception as e:
        self.logger.info(
            f"Error in reaction SMARTS parsing: {e} (proposed: {smarts_pred} | gt: {gt_smarts})"
        )
        return 0.0, {"Reactants_contained": False, "Products_contained": False}