"""Classes for generating random recipes/selections"""
import mrich
import json
from pathlib import Path
from .tools import dt_hash
from .recipe import Recipe
from .cset import CompoundSet, IngredientSet
class RRGMixin:
    """Mixin class for shared properties

    The properties below are read-only views of private attributes
    (``_db``, ``_db_path``, ``_starting_recipe``, ``_suppliers``,
    ``_max_lead_time``, ``_data_path``, ``_recipe_dir``) that the concrete
    generator class is responsible for setting.

    Concrete classes must also define ``generate()`` (used by ``__call__``)
    and ``__str__()``: ``__repr__`` and ``__rich__`` format ``str(self)``, and
    the default ``object.__str__`` would call back into ``__repr__``.
    """

    @property
    def db(self) -> "Database":
        """Get the linked HIPPO Database object"""
        return self._db

    @property
    def db_path(self) -> str:
        """Get the path of the linked Database"""
        return self._db_path

    @property
    def starting_recipe(self):
        """Get the starting recipe used in all generations"""
        return self._starting_recipe

    @property
    def suppliers_str(self) -> str:
        """SQL formatted tuple of suppliers

        Each supplier name is emitted as a single-quoted SQL string literal
        with embedded single quotes doubled, e.g. ``('Enamine', 'MCule')``.
        A single supplier gives ``('Enamine')`` and an empty list gives ``()``.

        :raises TypeError: if :attr:`suppliers` is ``None`` (not iterable)
        """
        literals = []
        for supplier in self.suppliers:
            if isinstance(supplier, str):
                # standard SQL escaping: a quote inside a literal is doubled
                escaped = supplier.replace("'", "''")
                literals.append(f"'{escaped}'")
            else:
                # non-string entries keep their Python representation
                literals.append(repr(supplier))
        return f"({', '.join(literals)})"

    @property
    def suppliers(self) -> list[str]:
        """List of suppliers"""
        return self._suppliers

    @property
    def max_lead_time(self) -> float:
        """Maximum lead-time constraint"""
        return self._max_lead_time

    @property
    def data_path(self):
        """File path for the JSON data export"""
        return self._data_path

    @property
    def recipe_dir(self):
        """File path for the JSON recipe export"""
        return self._recipe_dir

    def __repr__(self) -> str:
        """ANSI Formatted string representation"""
        # imported lazily so that mcol is only needed when a repr is requested
        import mcol

        return f"{mcol.bold}{mcol.underline}{self}{mcol.unbold}{mcol.ununderline}"

    def __call__(self, *args, **kwargs) -> "Recipe":
        """Generate Recipe

        All arguments are forwarded unchanged to ``self.generate()``.
        """
        return self.generate(*args, **kwargs)

    def __rich__(self) -> str:
        """Rich Formatted string representation"""
        return f"[bold underline]{self}"
[docs]
class RandomRecipeGenerator(RRGMixin):
    """Class to create randomly sampled Recipe from a HIPPO Database

    On construction the generator resolves a pool of routes (either the one
    provided, pruned to the given suppliers, or one solved from the database)
    and dumps its configuration to ``<out_key>_rgen.json``. Each call to
    :meth:`generate` samples routes from a copy of that pool and writes the
    resulting Recipe as JSON into ``<out_key>_recipes/``.
    """

    def __init__(
        self,
        db,
        *,
        max_lead_time=None,
        suppliers: list | None = None,
        start_with: Recipe | CompoundSet | IngredientSet | None = None,
        route_pool: "RouteSet | None" = None,
        out_key: str | None = None,
    ):
        """RandomRecipeGenerator initialisation

        :param db: HIPPO Database to generate recipes from
        :param max_lead_time: maximum lead-time constraint (Default value = None)
        :param suppliers: list of permitted suppliers (Default value = None)
        :param start_with: starting point shared by all generated recipes, an empty
            Recipe is created when this is falsy (Default value = None)
        :param route_pool: pre-computed pool of routes, pruned with
            ``prune_unavailable(suppliers=...)`` before use. When falsy the pool is
            solved with :meth:`get_route_pool` (Default value = None)
        :param out_key: path prefix for the output files, defaults to the database
            file name without its ``.sqlite`` suffix (Default value = None)
        :raises AssertionError: if the resulting route pool is empty
        """

        mrich.debug("RandomRecipeGenerator.__init__()")

        if not start_with:
            start_with = Recipe(db)

        # Static parameters
        self._db_path = db.path
        self._max_lead_time = max_lead_time
        self._suppliers = suppliers
        self._starting_recipe = start_with

        mrich.var("database", self.db_path)
        mrich.var("max_lead_time", self.max_lead_time)
        mrich.var("suppliers", self.suppliers)

        # Database set up
        self._db = db

        if not out_key:
            out_key = str(self.db_path.name).removesuffix(".sqlite")

        mrich.var("out_key", out_key)

        # exist_ok avoids failing if the directory appears between a check and the mkdir
        Path(out_key).parent.mkdir(parents=True, exist_ok=True)

        # JSON I/O set up
        self._data_path = Path(f"{out_key}_rgen.json")
        if self.data_path.exists():
            mrich.warning(f"Will overwrite existing rgen data file: {self.data_path}")

        # Recipe I/O set up
        path = Path(f"{out_key}_recipes")
        if not path.exists():
            mrich.writing(f"{path}/")
            path.mkdir()
        self._recipe_dir = path

        # Route pool
        if route_pool:
            route_pool = route_pool.prune_unavailable(suppliers=suppliers)
            self._route_pool = route_pool
        else:
            mrich.debug("Solving route pool...")
            self._route_pool = self.get_route_pool()

        assert len(self._route_pool), "Route pool is empty!"

        # dump data
        self.dump_data()

    ### FACTORIES

    @classmethod
    def from_json(cls, db: "Database", path: "Path | str"):
        """Construct the RandomRecipeGenerator from a JSON file

        ``__init__`` is bypassed: no route pool is solved and no files are
        written, all state comes from the JSON written by :meth:`dump_data`.

        :param db: HIPPO Database to link the generator to
        :param path: path to the JSON file, also used as :attr:`data_path`
        :returns: the reconstructed generator
        """

        # context manager guarantees the file handle is closed
        with open(path, "rt") as f:
            data = json.load(f)

        self = cls.__new__(cls)

        self._db_path = Path(data["db_path"])
        self._recipe_dir = Path(data["recipe_dir"])
        self._max_lead_time = data["max_lead_time"]
        self._suppliers = data["suppliers"]
        self._starting_recipe = Recipe.from_json(
            db=db,
            path=None,
            data=data["starting_recipe"],
            allow_db_mismatch=True,
        )

        mrich.var("database", self.db_path)
        mrich.var("max_lead_time", self.max_lead_time)
        mrich.var("suppliers", self.suppliers)

        self._db = db

        # JSON I/O set up
        self._data_path = Path(path)

        # Route pool
        from .recipe import RouteSet

        self._route_pool = RouteSet.from_json(path=None, data=data["route_pool"], db=db)

        return self

    ### PROPERTIES

    @property
    def route_pool(self):
        """Get the RouteSet of all product reaction routes considered by this generator"""
        return self._route_pool

    ### POOL METHODS

    def get_route_pool(self, mini_test=False):
        """Construct the pool of routes that will be randomly sampled from

        Explainer for SQL query:

        - get table of quoted compounds with a count of the valid suppliers
        - join routes, components, and the new table together and grouped by
          route count the unavailable reactants
        - return route ids where no reactants are unavailable

        :param mini_test: only keep the first 100 routes (Default value = False)
        :returns: RouteSet built from the selected route IDs
        :raises NotImplementedError: if the database has no route table, or if a
            maximum lead time is set
        """

        if "route" not in self.db.table_names:
            mrich.error("route table not in Database")
            raise NotImplementedError

        assert self.suppliers_str

        if self.max_lead_time:
            raise NotImplementedError

        ### EXCLUDE PRODUCTS OF ROUTES IN STARTING RECIPE!!!

        sql = f"""
        WITH possible_reactants AS (
            SELECT quote_compound, COUNT(CASE WHEN quote_supplier IN {self.suppliers_str} THEN 1 END) AS [count_valid]
            FROM quote
            GROUP BY quote_compound
        ),

        route_reactants AS (
            SELECT route_id, route_product,
            COUNT(
                CASE
                    WHEN count_valid = 0 THEN 1
                    WHEN count_valid IS NULL THEN 1
                END)
            AS [count_unavailable] FROM route
            INNER JOIN component ON component_route = route_id
            LEFT JOIN possible_reactants ON quote_compound = component_ref
            WHERE component_type = 2
            GROUP BY route_id
        )

        SELECT route_id FROM route_reactants
        WHERE count_unavailable = 0
        """

        records = self.db.execute(sql).fetchall()

        # each record is a 1-tuple: (route_id,)
        route_ids = [route_id for route_id, in records]

        if mini_test:
            route_ids = route_ids[:100]

        from .recipe import RouteSet

        return RouteSet.from_ids(self.db, route_ids)

    ### FILE I/O METHODS

    def dump_data(self):
        """Dump data to JSON

        Writes the generator configuration, starting recipe, and route pool
        to :attr:`data_path` in the layout read back by :meth:`from_json`.
        """

        data = {}

        data["db_path"] = str(self.db_path.resolve())
        data["recipe_dir"] = str(self.recipe_dir.resolve())
        data["max_lead_time"] = self.max_lead_time
        data["suppliers"] = self.suppliers
        data["starting_recipe"] = self.starting_recipe.get_dict(serialise_price=True)
        data["route_pool"] = self.route_pool.get_dict()

        mrich.writing(self.data_path)

        # context manager guarantees the JSON is flushed and the handle closed
        with open(self.data_path, "wt") as f:
            json.dump(data, f, indent=4)

    def generate(
        self,
        budget: float = 10000,
        currency: str = "EUR",
        max_products: int = 1000,
        max_reactions: int = 1000,
        debug: bool = False,
        max_iter: int | None = None,
        shuffle: bool = True,
        balance_clusters: bool = False,
        permitted_clusters: None | set = None,
    ):
        """Generate random recipe

        Routes are popped one at a time from a copy of the route pool and added
        to a copy of the starting recipe. A route that takes the price over
        budget is rejected and the loop moves on to the next route. Generation
        stops when the pool is depleted, a product/reaction limit is exceeded,
        or ``max_iter`` iterations have run. The recipe is written as JSON to
        :attr:`recipe_dir` together with the generation metadata.

        :param budget: maximum budget (Default value = 10000)
        :param currency: currency (Default value = 'EUR')
        :param max_products: maximum number of products (Default value = 1000)
        :param max_reactions: maximum number of reactions (Default value = 1000)
        :param debug: increase verbosity for debugging (Default value = False)
        :param max_iter: maximum number of iterations, defaults to
            ``max_products + max_reactions`` and is capped at the size of the route
            pool (Default value = None)
        :param shuffle: randomly shuffle recipe pool (Default value = True)
        :param balance_clusters: balance selection across scaffold clusters (Default value = False)
        :param permitted_clusters: restrict selection to provided set of clusters,
            only used when ``balance_clusters`` is set (Default value = None)
        :returns: the generated Recipe
        :raises AssertionError: if the route pool is empty, or re-raised from the
            recipe price calculation
        """

        # construct filename
        out_file = self.recipe_dir / f"Recipe_{dt_hash()}.json"

        from .price import Price

        if not max_iter:
            max_iter = max_products + max_reactions

        max_iter = min(max_iter, len(self.route_pool))

        budget = Price(budget, currency)

        recipe = self.starting_recipe.copy()

        recipe.reactants._supplier = self.suppliers

        # get the RouteSet
        pool = self.route_pool.copy()

        assert len(pool), "Route pool is empty!"

        if shuffle:
            mrich.debug("Shuffling Route pool")
            pool.shuffle()

        # last accepted state, restored whenever a candidate is rejected
        old_recipe = recipe.copy()

        mrich.var("route pool", len(pool))
        mrich.var("max_iter", max_iter)

        # keeps the summary and metadata below valid if the loop body never runs
        i = 0

        for i in mrich.track(range(max_iter), prefix="Generating Recipe..."):

            if debug:
                mrich.title(f"Iteration {i}")

            price = recipe.price

            mrich.set_progress_field("price", str(price))
            mrich.set_progress_field("#products", len(recipe.products))

            if debug:
                mrich.var("price", price)

            # pop a route
            if balance_clusters:
                candidate_route = pool.balanced_pop(
                    permitted_clusters=permitted_clusters
                )
            else:
                candidate_route = pool.pop()

            if debug:
                mrich.var("candidate_route", candidate_route)
                mrich.var("candidate_route.reactants", candidate_route.reactants.ids)

            if candidate_route.product in recipe.products:
                continue

            # add the route to the recipe
            if debug:
                mrich.var("#recipe.reactants", len(recipe.reactants))

            recipe += candidate_route

            if debug:
                mrich.var("#recipe.reactants", len(recipe.reactants))

            # calculate the new price
            try:
                new_price = recipe.price
            except AssertionError:
                mrich.error(
                    f"Something went wrong while calculating the price after adding {candidate_route=} to recipe"
                )
                raise

            if debug:
                mrich.var("new price", new_price)

            # reject an over-budget candidate *before* the depletion check, so the
            # final route in the pool cannot push the returned recipe over budget
            over_budget = new_price > budget
            if over_budget:
                recipe = old_recipe.copy()

            # Break if product pool depleted
            if not len(pool):
                stop_reason = "Product pool depleted"
                mrich.success(stop_reason)
                break

            # check breaking conditions
            if over_budget:
                continue

            # NOTE(review): on the two breaks below the returned recipe still
            # contains the route that exceeded the limit - confirm this is intended
            if len(recipe.reactions) > max_reactions:
                stop_reason = "Max #reactions exceeded"
                mrich.success(stop_reason)
                break

            if len(recipe.products) > max_products:
                stop_reason = "Max #products exceeded"
                mrich.success(stop_reason)
                break

            # accept change
            old_recipe = recipe.copy()

        else:
            # loop ran to completion without hitting a break
            stop_reason = "Max #iterations reached"
            mrich.warning(stop_reason)

        ### recalculate the products to see if any extra can be had for free?

        # i is the index of the last iteration that ran
        mrich.success(f"Completed after {i} iterations")

        metadict = {
            "rgen_data_path": str(self.data_path.resolve()),
            "rgen_db_path": str(self.db_path.resolve()),
            "rgen_recipe_dir": str(self.recipe_dir.resolve()),
            "rgen_max_lead_time": self.max_lead_time,
            "rgen_suppliers": self.suppliers,
            "gen_budget": budget.amount,
            "gen_currency": budget.currency,
            "gen_max_products": max_products,
            "gen_max_reactions": max_reactions,
            "gen_max_iter": max_iter,
            "gen_shuffle": shuffle,
            "gen_iterations": i,
            "gen_stop_reason": stop_reason,
            "gen_recipe_path": str(out_file.resolve()),
        }

        # write the Recipe JSON
        recipe.write_json(out_file, extra=metadict)

        return recipe

    ### DUNDERS

    def __str__(self) -> str:
        """Unformatted string representation"""
        return f"RandomRecipeGenerator(recipe_dir={self.recipe_dir})"
[docs]
class RandomSelectionGenerator(RRGMixin):
    """Class to create randomly sampled (no-chemistry) Recipe from a HIPPO Database

    On construction the generator builds a pool of compounds (optionally only
    those with a sufficiently large quote) and dumps its configuration to
    ``<db name>_sgen.json``. Each call to :meth:`generate` samples compounds
    from a copy of that pool and writes the resulting Recipe as JSON into
    ``<db name>_selections/``.
    """

    def __init__(
        self,
        db,
        *,
        # max_lead_time=None,
        suppliers: list | None = None,
        amount: float = 1.0,  # in mg
        start_with: Recipe | CompoundSet | IngredientSet = None,
        compounds: CompoundSet | None = None,
        quoted_only: bool = True,
    ):
        """RandomSelectionGenerator initialisation

        :param db: HIPPO Database to generate selections from
        :param suppliers: list of permitted suppliers, not yet supported by
            :meth:`get_compound_pool` (Default value = None)
        :param amount: amount to quote each compound for, in mg (Default value = 1.0)
        :param start_with: starting point shared by all generated selections
            (Default value = None)
        :param compounds: restrict the compound pool to this set, the whole
            database is used when ``None`` (Default value = None)
        :param quoted_only: only consider compounds with quotes (Default value = True)
        """

        mrich.debug("RandomSelectionGenerator.__init__()")

        # Static parameters
        self._db_path = db.path
        self._suppliers = suppliers
        self._amount = amount
        self._quoted_only = quoted_only
        self._db = db

        mrich.var("database", self.db_path)
        mrich.var("suppliers", self.suppliers)
        mrich.var("amount per compound", self.amount, unit="mg")
        mrich.var("quoted_only", self.quoted_only)

        self.get_starting_recipe(start_with)

        mrich.var("starting recipe", self.starting_recipe)

        # Strip the suffix and append, rather than str.replace: with replace a
        # database name without ".sqlite" would resolve to the database file
        # itself, which dump_data() would then overwrite
        out_key = str(self.db_path.name).removesuffix(".sqlite")

        # JSON I/O set up
        self._data_path = Path(f"{out_key}_sgen.json")
        if self.data_path.exists():
            mrich.warning(f"Will overwrite existing sgen data file: {self.data_path}")

        # Recipe I/O set up
        path = Path(f"{out_key}_selections")
        mrich.writing(f"{path}/")
        path.mkdir(exist_ok=True)
        self._recipe_dir = path

        with mrich.spinner("Getting compound pool"):
            self.get_compound_pool(compounds)

        mrich.var("compound pool", self.compound_pool)

        # dump data
        self.dump_data()

    ### FACTORIES

    @classmethod
    def from_json(
        cls, db: "Database", path: "Path | str"
    ) -> "RandomSelectionGenerator":
        """Construct the RandomSelectionGenerator from a JSON file

        ``__init__`` is bypassed: no compound pool is queried and no files are
        written, all state comes from the JSON written by :meth:`dump_data`.

        :param db: HIPPO Database to link the generator to
        :param path: path to the JSON file, also used as :attr:`data_path`
        :returns: the reconstructed generator
        """

        # context manager guarantees the file handle is closed
        with open(path, "rt") as f:
            data = json.load(f)

        self = cls.__new__(cls)

        self._db_path = Path(data["db_path"])
        self._recipe_dir = Path(data["recipe_dir"])
        # self._max_lead_time = data["max_lead_time"]
        self._suppliers = data["suppliers"]
        self._amount = data["amount"]
        # files written before this key was exported fall back to the __init__ default
        self._quoted_only = data.get("quoted_only", True)
        self._starting_recipe = Recipe.from_json(
            db=db,
            path=None,
            data=data["starting_recipe"],
            allow_db_mismatch=True,
        )

        mrich.var("database", self.db_path)
        mrich.var("suppliers", self.suppliers)
        mrich.var("amount", self.amount)
        mrich.var("starting_recipe", self.starting_recipe)

        self._db = db

        # JSON I/O set up
        self._data_path = Path(path)

        # Compound pool
        self._compound_pool = IngredientSet.from_json(
            path=None, data=data["compound_pool"]["data"], db=db
        )

        mrich.var("compound_pool", self.compound_pool)

        return self

    ### PROPERTIES

    @property
    def amount(self) -> float:
        """Amount to quote each compound for"""
        return self._amount

    @property
    def quoted_only(self) -> bool:
        """Only consider compounds with quotes"""
        return self._quoted_only

    @property
    def compound_pool(self) -> "CompoundTable | CompoundSet":
        """The pool of compounds that will be chosen from"""
        return self._compound_pool

    ### METHODS

    def get_starting_recipe(
        self, start_with: "Recipe | CompoundSet | IngredientSet"
    ) -> Recipe:
        """Process start_with into Recipe object

        A Recipe is used as-is (it must be of type ``NOCHEM``). Otherwise an
        empty Recipe is created and every item of ``start_with`` is added to
        its compounds; Compound items are first converted to ingredients with
        the generator's amount. ``None`` gives an empty Recipe.

        :param start_with: Recipe, iterable of compounds/ingredients, or ``None``
        :returns: the starting recipe, also stored on the generator
        :raises NotImplementedError: if a Recipe with a type other than ``NOCHEM`` is given
        """

        if isinstance(start_with, Recipe):
            if start_with.type != "NOCHEM":
                raise NotImplementedError("Only NOCHEM recipes are supported")
            self._starting_recipe = start_with
            return self._starting_recipe

        from .compound import Compound

        self._starting_recipe = Recipe(self.db)

        if start_with is not None:
            for item in start_with:
                if isinstance(item, Compound):
                    item = item.as_ingredient(amount=self._amount)
                self._starting_recipe.compounds.add(item)

        return self._starting_recipe

    def get_compound_pool(
        self, compounds: CompoundSet | None
    ) -> "CompoundTable | CompoundSet":
        """Get pool of compounds to select from

        When :attr:`quoted_only` is unset the pool contains every requested
        compound at the generator's amount. Otherwise it contains one
        ingredient per compound that has a quote for at least that amount.

        :param compounds: restrict the pool to these compounds, all compounds in
            the database are considered when ``None``
        :returns: the compound pool, also stored on the generator
        :raises NotImplementedError: if suppliers have been specified
        """

        if self.suppliers:
            raise NotImplementedError

        # ignore quoting
        if not self.quoted_only:

            if compounds is None:
                # all compounds, each record is a 1-tuple: (compound_id,)
                records = self.db.select(
                    table="compound", query="compound_id", multiple=True
                )
                ids = [compound_id for compound_id, in records]
            else:
                ids = compounds.ids

            self._compound_pool = IngredientSet.from_compounds(
                db=self.db, ids=ids, amount=self.amount
            )
            return self._compound_pool

        # get all compounds that have a quote
        if compounds is None:
            compound_filter = ""
        else:
            compound_filter = f"AND quote_compound IN {compounds.str_ids}"

        # NOTE(review): presumably relies on SQLite returning the other columns
        # from the row holding MIN(quote_price) - confirm against the schema
        sql = f"""
        SELECT quote_id, quote_compound, quote_amount, quote_supplier, MIN(quote_price)
        FROM quote
        WHERE quote_amount >= {self.amount}
        {compound_filter}
        GROUP BY quote_compound
        """

        records = self.db.execute(sql).fetchall()

        # supplier and price columns are fetched but not stored on the ingredient
        ingredients = [
            dict(
                quote_id=quote_id,
                compound_id=compound_id,
                amount=self.amount,
                quoted_amount=quoted_amount,
                supplier=None,
                max_lead_time=None,
            )
            for quote_id, compound_id, quoted_amount, _supplier, _price in records
        ]

        self._compound_pool = IngredientSet.from_ingredient_dicts(
            self.db, ingredients
        )

        return self._compound_pool

    def dump_data(self):
        """Dump data to JSON

        Writes the generator configuration, starting recipe, and compound
        pool to :attr:`data_path` in the layout read back by :meth:`from_json`.
        """

        data = {}

        data["db_path"] = str(self.db_path.resolve())
        data["recipe_dir"] = str(self.recipe_dir.resolve())
        # data["max_lead_time"] = self.max_lead_time
        data["amount"] = self.amount
        data["quoted_only"] = self.quoted_only
        data["suppliers"] = self.suppliers
        data["starting_recipe"] = self.starting_recipe.get_dict(serialise_price=True)
        data["compound_pool"] = self.compound_pool.get_dict()

        mrich.writing(self.data_path)

        # context manager guarantees the JSON is flushed and the handle closed
        with open(self.data_path, "wt") as f:
            json.dump(data, f, indent=4)

    def generate(
        self,
        budget: float = 10000,
        currency: str = "EUR",
        max_iter: int | None = None,
        max_compounds: int = 1000,
        debug: bool = False,
        shuffle: bool = True,
    ):
        """Generate random selection

        Compounds are popped one at a time from a copy of the compound pool
        and added to a copy of the starting recipe. A compound that takes the
        price over budget is rejected and the loop moves on to the next one.
        Generation stops when the pool is depleted, the compound limit is
        exceeded, or ``max_iter`` iterations have run. The recipe is written
        as JSON to :attr:`recipe_dir` together with the generation metadata.

        :param budget: maximum budget (Default value = 10000)
        :param currency: currency (Default value = 'EUR')
        :param max_iter: maximum number of iterations, defaults to
            ``3 * max_compounds`` (Default value = None)
        :param max_compounds: maximum number of compounds (Default value = 1000)
        :param debug: Increase verbosity for debugging (Default value = False)
        :param shuffle: Randomise order of compound pool (Default value = True)
        :returns: the generated Recipe
        :raises AssertionError: if the compound pool is empty, or re-raised from
            the recipe price calculation
        """

        # construct filename
        out_file = self.recipe_dir / f"Recipe_{dt_hash()}.json"

        from .price import Price

        budget = Price(budget, currency)

        recipe = self.starting_recipe.copy()

        recipe.compounds._supplier = self.suppliers

        # get the compound pool
        pool = self.compound_pool.copy()

        assert len(pool), "Compound pool is empty!"

        if shuffle:
            mrich.debug("Shuffling compound pool")
            pool.shuffle()

        # last accepted state, restored whenever a candidate is rejected
        old_recipe = recipe.copy()

        if not max_iter:
            max_iter = max_compounds * 3

        mrich.var("compound pool", pool)
        mrich.var("max_compounds", max_compounds)
        mrich.var("max_iter", max_iter)

        # keeps the summary and metadata below valid if the loop body never runs
        i = 0

        for i in mrich.track(range(max_iter), prefix="Generating Recipe..."):

            if debug:
                mrich.title(f"Iteration {i}")

            price = recipe.price

            mrich.set_progress_field("price", str(price))
            mrich.set_progress_field("#compounds", len(recipe.compounds))

            if debug:
                mrich.var("price", price)

            # pop a compound
            candidate = pool.pop()
            pool_depleted = not len(pool)

            if debug:
                mrich.var("candidate", candidate)

            if candidate in recipe.compounds:
                # max_iter is not capped by the pool size, so without this check
                # the next iteration could pop from an empty pool
                if pool_depleted:
                    stop_reason = "Compound pool depleted"
                    mrich.success(stop_reason)
                    break
                continue

            # add the compound to the recipe
            recipe.compounds.add(candidate)

            # calculate the new price
            try:
                new_price = recipe.price
            except AssertionError:
                mrich.error(
                    f"Something went wrong while calculating the price after adding {candidate=} to recipe"
                )
                raise

            if debug:
                mrich.var("#compounds", recipe.num_compounds)
                mrich.var("new price", new_price)

            # reject an over-budget candidate *before* the depletion check, so the
            # final compound in the pool cannot push the returned recipe over budget
            over_budget = new_price > budget
            if over_budget:
                recipe = old_recipe.copy()

            # Break if compound pool depleted
            if pool_depleted:
                stop_reason = "Compound pool depleted"
                mrich.success(stop_reason)
                break

            # check breaking conditions
            if over_budget:
                continue

            # NOTE(review): on this break the returned recipe still contains the
            # compound that exceeded the limit - confirm this is intended
            if len(recipe.compounds) > max_compounds:
                stop_reason = "Max #compounds exceeded"
                mrich.success(stop_reason)
                break

            # accept change
            old_recipe = recipe.copy()

        else:
            # loop ran to completion without hitting a break
            stop_reason = "Max #iterations reached"
            mrich.warning(stop_reason)

        ### recalculate the products to see if any extra can be had for free?

        # i is the index of the last iteration that ran
        mrich.success(f"Completed after {i} iterations")

        metadict = {
            "rgen_data_path": str(self.data_path.resolve()),
            "rgen_db_path": str(self.db_path.resolve()),
            "rgen_recipe_dir": str(self.recipe_dir.resolve()),
            "rgen_suppliers": self.suppliers,
            "rgen_amount": self.amount,
            "gen_budget": budget.amount,
            "gen_currency": budget.currency,
            "gen_max_compounds": max_compounds,
            "gen_max_iter": max_iter,
            "gen_shuffle": shuffle,
            "gen_iterations": i,
            "gen_stop_reason": stop_reason,
            "gen_recipe_path": str(out_file.resolve()),
        }

        # write the Recipe JSON
        recipe.write_json(out_file, extra=metadict)

        return recipe

    ### DUNDERS

    def __str__(self) -> str:
        """Unformatted string representation"""
        return f"RandomSelectionGenerator(recipe_dir={self.recipe_dir})"