Source code for noisemaker.composer

"""Extremely high-level interface for composable noise presets. See `detailed docs <composer.html>`_."""

from collections import UserDict
from enum import Enum, EnumMeta
from functools import partial

import random

import tensorflow as tf

from noisemaker.effects_registry import EFFECTS
from noisemaker.generators import multires
from noisemaker.util import logger, save

DEFAULT_SHAPE = [1024, 1024, 3]

SETTINGS_KEY = "settings"

ALLOWED_KEYS = ["layers", SETTINGS_KEY, "generator", "octaves", "post", "final", "ai", "unique"]

# These correspond to https://platform.stability.ai/docs/api-reference#tag/v1generation/operation/textToImage
ALLOWED_AI_KEYS = ["prompt", "image_strength", "cfg_scale", "style_preset", "model"]

# Don't raise an exception if the following keys are unused in settings
UNUSED_OKAY = ["ai", "angle", "palette_alpha", "palette_name", "speed"]

# Populated by reload_presets() after setting random seed
GENERATOR_PRESETS = {}
EFFECT_PRESETS = {}

_STASH = {}


class Preset:
    def __init__(self, preset_name, presets, settings=None):
        """Look up the named preset in the given presets dict and flatten its inherited layers and settings."""

        self.name = preset_name

        prototype = presets.get(preset_name)

        if prototype is None:
            raise ValueError(f"Preset \"{preset_name}\" was not found among the available presets.")

        if not isinstance(prototype, dict):
            raise ValueError(f"Preset \"{preset_name}\" should be a dict, not \"{type(prototype)}\"")

        self.layers = prototype.get("layers", [])

        # To avoid mistakes in presets, unknown top-level keys are disallowed.
        for key in prototype:
            if key not in ALLOWED_KEYS:
                raise ValueError(f"Preset \"{preset_name}\": Key \"{key}\" is not permitted. "
                                 f"Allowed keys are: {ALLOWED_KEYS}")

        # Build a flat list of parent preset names, in topological order.
        self.flattened_layers = []
        _flatten_ancestors(self.name, presets, {}, self.flattened_layers)

        # self.settings provides overridable args which can be consumed by generator, octaves, post, ai, and final.
        # SettingsDict is a custom dict class that enforces no unused extra keys, to minimize human error.
        self.settings = SettingsDict(_flatten_ancestor_metadata(self, None, SETTINGS_KEY, {}, presets))

        if settings:  # Inline overrides from caller (such as from CLI)
            self.settings.update(settings)

        # These args will be sent to generators.multires() to create the noise basis
        self.generator_kwargs = _flatten_ancestor_metadata(self, self.settings, "generator", {}, presets)

        # A list of callable effects functions, to be applied per-octave, in order
        self.octave_effects = _flatten_ancestor_metadata(self, self.settings, "octaves", [], presets)

        # A list of callable effects functions, to be applied post-reduce, in order
        self.post_effects = _flatten_ancestor_metadata(self, self.settings, "post", [], presets)

        # A list of callable effects functions, to be applied in order after everything else
        self.final_effects = _flatten_ancestor_metadata(self, self.settings, "final", [], presets)

        try:
            # To avoid mistakes in presets, unused keys are disallowed.
            self.settings.raise_if_unaccessed(unused_okay=UNUSED_OKAY)

        except Exception as e:
            raise ValueError(f"Preset \"{preset_name}\": {e}")

        # AI post-processing settings are not currently inherited, but are specified on a per-preset basis.
        _ai_settings = prototype.get("ai", {})

        for k in _ai_settings:
            if k not in ALLOWED_AI_KEYS:
                raise ValueError(f"Preset \"{preset_name}\": Disallowed key in \"ai\" section: "
                                 f"\"{k}\"")

        self.ai_settings = {
            "prompt": _ai_settings.get("prompt", self.name.replace('-', ' ') + ", psychedelic fractal imagery"),
            "image_strength": _ai_settings.get("image_strength", 0.5),
            "cfg_scale": _ai_settings.get("cfg_scale", 15),
            "style_preset": _ai_settings.get("style_preset", "digital-art"),
            "model": _ai_settings.get("model", "stable-diffusion-v1-5"),
        }

        self.ai_settings.update(prototype.get("ai_settings", {}))

        # This will be set to True if the call to Stable Diffusion succeeds
        self.ai_success = False

    def __str__(self):
        return f"<Preset \"{self.name}\">"

    def is_generator(self):
        return self.generator_kwargs

    def is_effect(self):
        return self.post_effects or self.final_effects

    def render(self, seed, tensor=None, shape=DEFAULT_SHAPE, time=0.0, speed=1.0, filename="art.png",
               with_alpha=False, with_supersample=False, with_ai=False, with_upscale=False, stability_model=None):
        """Render the preset to an image file."""

        try:
            tensor = multires(self, seed, tensor=tensor, shape=shape, with_supersample=with_supersample,
                              octave_effects=self.octave_effects, post_effects=self.post_effects, with_ai=with_ai,
                              final_effects=self.final_effects, with_alpha=with_alpha, with_upscale=with_upscale,
                              stability_model=stability_model, time=time, speed=speed,
                              **self.generator_kwargs)

            save(tensor, filename)

        except Exception as e:
            logger.error(f"Error rendering preset named {self.name}: {e}")

            raise
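
# Usage sketch for Preset (hypothetical preset data, not part of the library). Values in
# "settings" must be wrapped in zero-arg lambdas so they are re-evaluated per seed, and
# "generator"/"octaves"/"post"/"final" lambdas receive the flattened settings dict:
#
#     demo_presets = {
#         "demo-basic": {
#             "settings": lambda: {"freq": 3},
#             "generator": lambda settings: {"freq": settings["freq"]},
#         },
#     }
#
#     preset = Preset("demo-basic", demo_presets)
#     assert preset.is_generator()
#     # preset.render(seed=1, shape=[256, 256, 3], filename="demo.png")  # assumes multires() accepts these generator kwargs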

def Effect(effect_name, **kwargs):
    """Return a partial effects function. Invoke the wrapped function with params "tensor", "shape", "time", and "speed." """

    if effect_name not in EFFECTS:
        raise ValueError(f'"{effect_name}" is not a registered effect name.')

    for k in kwargs:
        if k not in EFFECTS[effect_name]:
            raise ValueError(f'Effect "{effect_name}" does not accept a parameter named "{k}"')

    return partial(EFFECTS[effect_name]["func"], **kwargs)
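
# Usage sketch for Effect: the returned partial is typically placed in a preset's "octaves",
# "post", or "final" list and later invoked by the renderer with tensor, shape, time, and speed.
# The effect name and parameter below are hypothetical; any name registered in EFFECTS works:
#
#     blur = Effect("demo-blur", amount=2.5)   # raises ValueError if the name or param isn't registered
#     # result = blur(tensor=tensor, shape=[256, 256, 3], time=0.0, speed=1.0)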

def _flatten_ancestors(preset_name, presets, unique, ancestors):
    for ancestor_name in presets[preset_name].get("layers", []):
        if ancestor_name not in presets:
            raise ValueError(f"\"{ancestor_name}\" was not found among the available presets.")

        # "unique" layers may only be inherited once
        if ancestor_name in unique:
            continue

        if presets[ancestor_name].get("unique"):
            unique[ancestor_name] = True

        _flatten_ancestors(ancestor_name, presets, unique, ancestors)

    ancestors.append(preset_name)


def _flatten_ancestor_metadata(preset, settings, key, default, presets):
    """Flatten ancestor preset metadata"""

    if isinstance(default, dict):
        flattened_metadata = {}
    else:
        flattened_metadata = []

    for ancestor_name in preset.flattened_layers:
        if key == SETTINGS_KEY:
            ancestor = presets[ancestor_name].get(key, lambda: default)
        else:
            ancestor = presets[ancestor_name].get(key, lambda _: default)

        if callable(ancestor):
            try:
                if key == SETTINGS_KEY:
                    ancestor = ancestor()
                else:
                    ancestor = ancestor(settings)

            except Exception as e:
                if preset.name == ancestor_name:
                    raise
                else:
                    raise ValueError(f"In ancestor \"{ancestor_name}\": {e}")

        else:
            raise ValueError(f"{ancestor_name}: Key \"{key}\" wasn't wrapped in a lambda. "
                             + "This can cause unexpected results for the given seed.")

        if not isinstance(ancestor, type(default)):
            raise ValueError(f"{ancestor_name}: Key \"{key}\" should be {type(default)}, not {type(ancestor)}.")

        if isinstance(ancestor, dict):
            flattened_metadata.update(ancestor)
        else:
            flattened_metadata += ancestor

    return flattened_metadata
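
# Flattening sketch (hypothetical presets): ancestors listed in "layers" are resolved
# depth-first, parents before children, so later (more specific) settings win on update():
#
#     demo = {
#         "base": {"settings": lambda: {"freq": 2, "saturation": 1.0}},
#         "child": {
#             "layers": ["base"],
#             "settings": lambda: {"freq": 5},
#             "generator": lambda settings: {"freq": settings["freq"], "saturation": settings["saturation"]},
#         },
#     }
#
#     # Preset("child", demo).flattened_layers == ["base", "child"]
#     # The child's settings override the parent's: generator_kwargs == {"freq": 5, "saturation": 1.0}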

def random_member(*collections):
    """Return a random member from a collection, enum list, or enum. Ensures deterministic ordering."""

    collection = []

    for c in collections:
        if not hasattr(c, "__iter__"):
            raise ValueError("random_member(arg) should be iterable (collection, enum list, or enum)")

        if isinstance(c, EnumMeta):
            collection += list(c)

        # maybe it's a list of enum members
        elif isinstance(next(iter(c), None), Enum):
            collection += [s[1] for s in sorted([(m.name if m is not None else "", m) for m in c])]

        else:
            # make sure order is deterministic
            collection += sorted(c)

    return collection[random.randint(0, len(collection) - 1)]
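
# Usage sketch for random_member (hypothetical enum; the selection depends on the current random seed):
#
#     class Distance(Enum):   # hypothetical
#         euclidean = 0
#         manhattan = 1
#         chebyshev = 2
#
#     random.seed(1)
#     metric = random_member(Distance)                                    # whole enum
#     metric = random_member([Distance.euclidean, Distance.chebyshev])    # list of enum members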

def coin_flip():
    return bool(random.randint(0, 1))

def enum_range(a, b):
    """Return a list of enum members within the specified inclusive numeric value range."""

    enum_class = type(a)

    members = []

    for i in range(a.value, b.value + 1):
        members.append(enum_class(i))

    return members
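
# Usage sketch for enum_range (reusing the hypothetical Distance enum above; assumes
# contiguous integer values between a and b):
#
#     enum_range(Distance.euclidean, Distance.chebyshev)
#     # -> [Distance.euclidean, Distance.manhattan, Distance.chebyshev]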

def reload_presets(presets):
    """Re-evaluate presets after changing the interpreter's random seed."""

    GENERATOR_PRESETS.clear()
    EFFECT_PRESETS.clear()

    presets = presets()

    for preset_name in presets:
        try:
            preset = Preset(preset_name, presets)

            if preset.is_generator():
                GENERATOR_PRESETS[preset_name] = preset

            if preset.is_effect():
                EFFECT_PRESETS[preset_name] = preset

        except Exception as e:
            if f"Preset \"{preset_name}\"" in str(e):
                raise
            else:
                raise ValueError(f"Preset \"{preset_name}\": {e}")
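
# Usage sketch for reload_presets: the argument is a callable returning the full preset dict,
# so every preset's lambdas are re-evaluated under the new seed. The loader below is
# hypothetical; the library supplies its own preset-loading callable:
#
#     def demo_loader():
#         return {"demo-solid": {"generator": lambda settings: {"freq": 2}}}
#
#     random.seed(42)
#     reload_presets(demo_loader)
#     # GENERATOR_PRESETS now maps "demo-solid" to a Preset instance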

def stash(key, value=None):
    """Hold on to a variable for reference within the same lambda. Returns the stashed value."""

    global _STASH

    if value is not None:
        _STASH[key] = value

    return _STASH[key]
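
# Usage sketch for stash: lets a lambda compute a value once and reuse it elsewhere in the
# same expression, e.g. inside a preset's settings (key name below is hypothetical):
#
#     "settings": lambda: {
#         "angle": stash("demo-angle", random.uniform(0, 360)),   # store and return the value
#         "angle_copy": stash("demo-angle"),                      # retrieve the same value again
#     }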

class UnusedKeys(Exception):
    """Exception raised when a preset has keys that aren't being used."""

    pass

class SettingsDict(UserDict):
    """dict, but it makes sure the caller eats everything on their plate."""

    def __init__(self, *args, **kwargs):
        self.__accessed__ = {}

        super().__init__(*args, **kwargs)

    def __getitem__(self, key):
        self.__accessed__[key] = True

        return super().__getitem__(key)

    def was_accessed(self, key):
        return key in self.__accessed__

    def raise_if_unaccessed(self, unused_okay=None):
        keys = []

        for key in self:
            if not self.was_accessed(key) and (unused_okay is None or key not in unused_okay):
                keys.append(key)

        if keys:
            if len(keys) == 1:
                raise UnusedKeys(f"Settings key \"{keys[0]}\" (value: {self[keys[0]]}) is unused. "
                                 "This is usually human error.")
            else:
                raise UnusedKeys(f"Settings keys {keys} are unused. This is usually human error.")
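
# Behavior sketch for SettingsDict: keys must be read via [] before raise_if_unaccessed() is
# called, otherwise UnusedKeys is raised (the Preset constructor relies on this to catch typos):
#
#     d = SettingsDict({"freq": 3, "speed": 1.0})
#     _ = d["freq"]
#     d.raise_if_unaccessed(unused_okay=["speed"])   # ok: "freq" was read, "speed" may go unused
#     d.raise_if_unaccessed()                        # raises UnusedKeys for "speed"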