import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from labelbox.schema.tool_building.tool_type import ToolType
@dataclass
class StepReasoningVariant:
    """A selectable option for evaluating a step, with its allowed actions.

    Attributes:
        id: Numeric identifier of the variant.
        name: Display name of the variant.
        actions: Action names attached to this variant (empty by default).
    """

    id: int
    name: str
    actions: List[str] = field(default_factory=list)

    def asdict(self) -> Dict[str, Any]:
        """Return the serialized form of this variant.

        The ``actions`` list is copied, so mutating the returned dict does
        not alter this instance.
        """
        # ``None`` can reach here if the instance was built from a payload
        # with a null "actions" entry; pass it through unchanged.
        actions = None if self.actions is None else list(self.actions)
        return {"id": self.id, "name": self.name, "actions": actions}
@dataclass
class IncorrectStepReasoningVariant:
    """The "incorrect" option for evaluating a step.

    Its available actions are modelled as individual boolean flags, all
    enabled by default. When serialized, each enabled flag becomes one
    entry in the ``actions`` list.

    Attributes:
        id: Numeric identifier of the variant.
        name: Display name of the variant.
        regenerate_steps: Enables the ``"regenerateSteps"`` action.
        generate_and_rate_alternative_steps: Enables the
            ``"generateAndRateAlternativeSteps"`` action.
        rewrite_step: Enables the ``"rewriteStep"`` action.
        justification: Enables the ``"justification"`` action.
    """

    id: int
    name: str
    regenerate_steps: Optional[bool] = True
    generate_and_rate_alternative_steps: Optional[bool] = True
    rewrite_step: Optional[bool] = True
    justification: Optional[bool] = True

    def asdict(self) -> Dict[str, Any]:
        """Serialize the variant, listing enabled actions in a fixed order."""
        flags = (
            (self.regenerate_steps, "regenerateSteps"),
            (
                self.generate_and_rate_alternative_steps,
                "generateAndRateAlternativeSteps",
            ),
            (self.rewrite_step, "rewriteStep"),
            (self.justification, "justification"),
        )
        actions = [action for enabled, action in flags if enabled]
        return {"id": self.id, "name": self.name, "actions": actions}

    @classmethod
    def from_dict(
        cls, dictionary: Dict[str, Any]
    ) -> "IncorrectStepReasoningVariant":
        """Rebuild the variant from its serialized form.

        Each flag is enabled exactly when its action name appears in the
        ``actions`` list. A missing or ``None`` list disables every flag.

        Raises:
            KeyError: If ``id`` or ``name`` is absent.
        """
        # Look the list up once. ``or []`` also tolerates an explicit None,
        # which would otherwise make the membership tests raise TypeError.
        actions = dictionary.get("actions") or []
        return cls(
            id=dictionary["id"],
            name=dictionary["name"],
            regenerate_steps="regenerateSteps" in actions,
            generate_and_rate_alternative_steps=(
                "generateAndRateAlternativeSteps" in actions
            ),
            rewrite_step="rewriteStep" in actions,
            justification="justification" in actions,
        )
def _create_correct_step() -> StepReasoningVariant:
    """Build the default "Correct" variant, which carries no actions."""
    step_id = StepReasoningVariants.CORRECT_STEP_ID
    return StepReasoningVariant(name="Correct", id=step_id)
def _create_neutral_step() -> StepReasoningVariant:
    """Build the default "Neutral" variant, which carries no actions."""
    step_id = StepReasoningVariants.NEUTRAL_STEP_ID
    return StepReasoningVariant(name="Neutral", id=step_id)
def _create_incorrect_step() -> IncorrectStepReasoningVariant:
    """Build the default "Incorrect" variant with every action flag enabled."""
    step_id = StepReasoningVariants.INCORRECT_STEP_ID
    return IncorrectStepReasoningVariant(name="Incorrect", id=step_id)
@dataclass
class StepReasoningVariants:
    """The set of options available when evaluating a single step.

    There are currently exactly three options: correct, neutral, and
    incorrect. Each is identified by a fixed numeric id (the ``*_STEP_ID``
    constants) and is serialized as one entry of a list.
    """

    # Unannotated, so ``@dataclass`` treats these as plain class constants
    # rather than instance fields.
    CORRECT_STEP_ID = 0
    NEUTRAL_STEP_ID = 1
    INCORRECT_STEP_ID = 2

    correct_step: StepReasoningVariant = field(
        default_factory=_create_correct_step
    )
    neutral_step: StepReasoningVariant = field(
        default_factory=_create_neutral_step
    )
    incorrect_step: IncorrectStepReasoningVariant = field(
        default_factory=_create_incorrect_step
    )

    def asdict(self) -> List[Dict[str, Any]]:
        """Serialize the variants as a list ordered correct, neutral, incorrect."""
        return [
            self.correct_step.asdict(),
            self.neutral_step.asdict(),
            self.incorrect_step.asdict(),
        ]

    @staticmethod
    def _step_from_dict(variant: Dict[str, Any]) -> StepReasoningVariant:
        """Build a plain variant from a payload, ignoring undeclared keys."""
        from dataclasses import fields

        # Passing the raw payload via ``**`` would raise TypeError on any key
        # the dataclass does not declare; keep only the declared fields.
        known = {f.name for f in fields(StepReasoningVariant)}
        return StepReasoningVariant(
            **{key: val for key, val in variant.items() if key in known}
        )

    @classmethod
    def from_dict(
        cls, dictionary: List[Dict[str, Any]]
    ) -> "StepReasoningVariants":
        """Rebuild the variants from their serialized list form.

        Args:
            dictionary: Serialized variants. Despite the parameter name, this
                is a list of dicts. Each entry's ``id`` selects the variant it
                describes. Entries with unrecognized ids are ignored. If an
                id repeats, the last entry wins.

        Raises:
            ValueError: If any of the three variants is absent.
        """
        correct_step = None
        neutral_step = None
        incorrect_step = None
        for variant in dictionary:
            variant_id = variant["id"]
            if variant_id == cls.CORRECT_STEP_ID:
                correct_step = cls._step_from_dict(variant)
            elif variant_id == cls.NEUTRAL_STEP_ID:
                neutral_step = cls._step_from_dict(variant)
            elif variant_id == cls.INCORRECT_STEP_ID:
                incorrect_step = IncorrectStepReasoningVariant.from_dict(
                    variant
                )

        missing = [
            label
            for label, step in (
                ("correct", correct_step),
                ("neutral", neutral_step),
                ("incorrect", incorrect_step),
            )
            if step is None
        ]
        if missing:
            raise ValueError(
                "Invalid step reasoning variants: missing "
                + ", ".join(missing)
            )
        return cls(
            correct_step=correct_step,  # type: ignore
            neutral_step=neutral_step,  # type: ignore
            incorrect_step=incorrect_step,  # type: ignore
        )
@dataclass
class StepReasoningDefinition:
    """Definition payload for a step-reasoning tool.

    Attributes:
        variants: Options offered when evaluating each step.
        version: Version number of the definition (defaults to 1).
        title: Optional title; serialized only when set.
        value: Optional value; serialized only when set.
    """

    variants: StepReasoningVariants = field(
        default_factory=StepReasoningVariants
    )
    version: int = field(default=1)
    title: Optional[str] = None
    value: Optional[str] = None

    def asdict(self) -> Dict[str, Any]:
        """Serialize the definition, omitting ``title``/``value`` when None."""
        result: Dict[str, Any] = {
            "variants": self.variants.asdict(),
            "version": self.version,
        }
        if self.title is not None:
            result["title"] = self.title
        if self.value is not None:
            result["value"] = self.value
        return result

    @classmethod
    def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningDefinition":
        """Rebuild a definition from the form produced by :meth:`asdict`.

        Raises:
            KeyError: If ``variants`` is absent.
            ValueError: If the serialized variants are incomplete.
        """
        kwargs: Dict[str, Any] = {
            "variants": StepReasoningVariants.from_dict(dictionary["variants"]),
            "title": dictionary.get("title"),
            "value": dictionary.get("value"),
        }
        # ``asdict`` always emits "version"; read it back so round-trips keep
        # it. When absent, fall through to the field default.
        if "version" in dictionary:
            kwargs["version"] = dictionary["version"]
        return cls(**kwargs)