make brushnet work
This commit is contained in:
@@ -3,6 +3,8 @@ from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Literal, List
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from iopaint.const import (
|
||||
INSTRUCT_PIX2PIX_NAME,
|
||||
KANDINSKY22_NAME,
|
||||
@@ -11,9 +13,9 @@ from iopaint.const import (
|
||||
SDXL_CONTROLNET_CHOICES,
|
||||
SD2_CONTROLNET_CHOICES,
|
||||
SD_CONTROLNET_CHOICES,
|
||||
SD_BRUSHNET_CHOICES,
|
||||
)
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, field_validator, computed_field
|
||||
from pydantic import BaseModel, Field, computed_field, model_validator
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
@@ -63,6 +65,13 @@ class ModelInfo(BaseModel):
|
||||
return SD_CONTROLNET_CHOICES
|
||||
return []
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def brushnets(self) -> List[str]:
|
||||
if self.model_type in [ModelType.DIFFUSERS_SD]:
|
||||
return SD_BRUSHNET_CHOICES
|
||||
return []
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_strength(self) -> bool:
|
||||
@@ -103,6 +112,13 @@ class ModelInfo(BaseModel):
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_brushnet(self) -> bool:
|
||||
return self.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_freeu(self) -> bool:
|
||||
@@ -369,6 +385,13 @@ class InpaintRequest(BaseModel):
|
||||
"lllyasviel/control_v11p_sd15_canny", description="Controlnet method"
|
||||
)
|
||||
|
||||
# BrushNet
|
||||
enable_brushnet: bool = Field(False, description="Enable brushnet")
|
||||
brushnet_method: str = Field(
|
||||
SD_BRUSHNET_CHOICES[0], description="Brushnet method"
|
||||
)
|
||||
brushnet_conditioning_scale: float = Field(1.0, description="brushnet conditioning scale", ge=0.0, le=1.0)
|
||||
|
||||
# PowerPaint
|
||||
powerpaint_task: PowerPaintTask = Field(
|
||||
PowerPaintTask.text_guided, description="PowerPaint task"
|
||||
@@ -380,31 +403,37 @@ class InpaintRequest(BaseModel):
|
||||
le=1.0,
|
||||
)
|
||||
|
||||
@field_validator("sd_seed")
|
||||
@classmethod
|
||||
def sd_seed_validator(cls, v: int) -> int:
|
||||
if v == -1:
|
||||
return random.randint(1, 99999999)
|
||||
return v
|
||||
@model_validator(mode='after')
|
||||
def validate_field(cls, values: 'InpaintRequest'):
|
||||
if values.sd_seed == -1:
|
||||
values.sd_seed = random.randint(1, 99999999)
|
||||
logger.info(f"Generate random seed: {values.sd_seed}")
|
||||
|
||||
@field_validator("controlnet_conditioning_scale")
|
||||
@classmethod
|
||||
def validate_field(cls, v: float, values):
|
||||
use_extender = values.data["use_extender"]
|
||||
enable_controlnet = values.data["enable_controlnet"]
|
||||
if use_extender and enable_controlnet:
|
||||
logger.info(f"Extender is enabled, set controlnet_conditioning_scale=0")
|
||||
return 0
|
||||
return v
|
||||
if values.use_extender and values.enable_controlnet:
|
||||
logger.info("Extender is enabled, set controlnet_conditioning_scale=0")
|
||||
values.controlnet_conditioning_scale = 0
|
||||
|
||||
@field_validator("sd_strength")
|
||||
@classmethod
|
||||
def validate_sd_strength(cls, v: float, values):
|
||||
use_extender = values.data["use_extender"]
|
||||
if use_extender:
|
||||
logger.info(f"Extender is enabled, set sd_strength=1")
|
||||
return 1.0
|
||||
return v
|
||||
if values.use_extender:
|
||||
logger.info("Extender is enabled, set sd_strength=1")
|
||||
values.sd_strength = 1.0
|
||||
|
||||
if values.enable_brushnet:
|
||||
logger.info("BrushNet is enabled, set enable_controlnet=False")
|
||||
if values.enable_controlnet:
|
||||
values.enable_controlnet = False
|
||||
if values.sd_lcm_lora:
|
||||
logger.info("BrushNet is enabled, set sd_lcm_lora=False")
|
||||
values.sd_lcm_lora = False
|
||||
if values.sd_freeu:
|
||||
logger.info("BrushNet is enabled, set sd_freeu=False")
|
||||
values.sd_freeu = False
|
||||
|
||||
if values.enable_controlnet:
|
||||
logger.info("ControlNet is enabled, set enable_brushnet=False")
|
||||
if values.enable_brushnet:
|
||||
values.enable_brushnet = False
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class RunPluginRequest(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user