make brushnet work

This commit is contained in:
Qing
2024-04-12 11:07:41 +08:00
parent 35f12d5b9b
commit 0a262fa811
14 changed files with 3408 additions and 56 deletions

View File

@@ -3,6 +3,8 @@ from enum import Enum
from pathlib import Path
from typing import Optional, Literal, List
from loguru import logger
from iopaint.const import (
INSTRUCT_PIX2PIX_NAME,
KANDINSKY22_NAME,
@@ -11,9 +13,9 @@ from iopaint.const import (
SDXL_CONTROLNET_CHOICES,
SD2_CONTROLNET_CHOICES,
SD_CONTROLNET_CHOICES,
SD_BRUSHNET_CHOICES,
)
from loguru import logger
from pydantic import BaseModel, Field, field_validator, computed_field
from pydantic import BaseModel, Field, computed_field, model_validator
class ModelType(str, Enum):
@@ -63,6 +65,13 @@ class ModelInfo(BaseModel):
return SD_CONTROLNET_CHOICES
return []
@computed_field
@property
def brushnets(self) -> List[str]:
if self.model_type in [ModelType.DIFFUSERS_SD]:
return SD_BRUSHNET_CHOICES
return []
@computed_field
@property
def support_strength(self) -> bool:
@@ -103,6 +112,13 @@ class ModelInfo(BaseModel):
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_brushnet(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
]
@computed_field
@property
def support_freeu(self) -> bool:
@@ -369,6 +385,13 @@ class InpaintRequest(BaseModel):
"lllyasviel/control_v11p_sd15_canny", description="Controlnet method"
)
# BrushNet
enable_brushnet: bool = Field(False, description="Enable brushnet")
brushnet_method: str = Field(
SD_BRUSHNET_CHOICES[0], description="Brushnet method"
)
brushnet_conditioning_scale: float = Field(1.0, description="brushnet conditioning scale", ge=0.0, le=1.0)
# PowerPaint
powerpaint_task: PowerPaintTask = Field(
PowerPaintTask.text_guided, description="PowerPaint task"
@@ -380,31 +403,37 @@ class InpaintRequest(BaseModel):
le=1.0,
)
@field_validator("sd_seed")
@classmethod
def sd_seed_validator(cls, v: int) -> int:
if v == -1:
return random.randint(1, 99999999)
return v
@model_validator(mode='after')
def validate_field(cls, values: 'InpaintRequest'):
if values.sd_seed == -1:
values.sd_seed = random.randint(1, 99999999)
logger.info(f"Generate random seed: {values.sd_seed}")
@field_validator("controlnet_conditioning_scale")
@classmethod
def validate_field(cls, v: float, values):
use_extender = values.data["use_extender"]
enable_controlnet = values.data["enable_controlnet"]
if use_extender and enable_controlnet:
logger.info(f"Extender is enabled, set controlnet_conditioning_scale=0")
return 0
return v
if values.use_extender and values.enable_controlnet:
logger.info("Extender is enabled, set controlnet_conditioning_scale=0")
values.controlnet_conditioning_scale = 0
@field_validator("sd_strength")
@classmethod
def validate_sd_strength(cls, v: float, values):
use_extender = values.data["use_extender"]
if use_extender:
logger.info(f"Extender is enabled, set sd_strength=1")
return 1.0
return v
if values.use_extender:
logger.info("Extender is enabled, set sd_strength=1")
values.sd_strength = 1.0
if values.enable_brushnet:
logger.info("BrushNet is enabled, set enable_controlnet=False")
if values.enable_controlnet:
values.enable_controlnet = False
if values.sd_lcm_lora:
logger.info("BrushNet is enabled, set sd_lcm_lora=False")
values.sd_lcm_lora = False
if values.sd_freeu:
logger.info("BrushNet is enabled, set sd_freeu=False")
values.sd_freeu = False
if values.enable_controlnet:
logger.info("ControlNet is enabled, set enable_brushnet=False")
if values.enable_brushnet:
values.enable_brushnet = False
return values
class RunPluginRequest(BaseModel):