From 3e4021ec0dd030d60fdf28019c9914d14455f88c Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 20 Sep 2022 22:43:20 +0800 Subject: [PATCH] return seed --- lama_cleaner/app/src/adapters/inpainting.ts | 20 ++++++----- .../app/src/components/Editor/Editor.tsx | 8 ++++- .../Settings/NumberInputSetting.tsx | 2 +- .../app/src/components/shared/NumberInput.tsx | 11 +++++- lama_cleaner/app/src/store/Atoms.tsx | 12 +++++++ lama_cleaner/server.py | 35 +++++++++++-------- 6 files changed, 62 insertions(+), 26 deletions(-) diff --git a/lama_cleaner/app/src/adapters/inpainting.ts b/lama_cleaner/app/src/adapters/inpainting.ts index 38bb884..7f48a21 100644 --- a/lama_cleaner/app/src/adapters/inpainting.ts +++ b/lama_cleaner/app/src/adapters/inpainting.ts @@ -50,17 +50,19 @@ export default async function inpaint( fd.append('sizeLimit', sizeLimit) } - const res = await fetch(`${API_ENDPOINT}/inpaint`, { - method: 'POST', - body: fd, - }).then(async r => { - if (r.ok) { - return r.blob() + try { + const res = await fetch(`${API_ENDPOINT}/inpaint`, { + method: 'POST', + body: fd, + }) + if (res.ok) { + const blob = await res.blob() + const seed = res.headers.get('x-seed') + return { blob: URL.createObjectURL(blob), seed } } + } catch { throw new Error('Something went wrong on server side.') - }) - - return URL.createObjectURL(res) + } } export function switchModel(name: string) { diff --git a/lama_cleaner/app/src/components/Editor/Editor.tsx b/lama_cleaner/app/src/components/Editor/Editor.tsx index 154658b..b84c35b 100644 --- a/lama_cleaner/app/src/components/Editor/Editor.tsx +++ b/lama_cleaner/app/src/components/Editor/Editor.tsx @@ -35,6 +35,7 @@ import { isSDState, propmtState, runManuallyState, + seedState, settingState, toastState, } from '../../store/Atoms' @@ -86,6 +87,7 @@ export default function Editor(props: EditorProps) { const { file } = props const promptVal = useRecoilValue(propmtState) const settings = useRecoilValue(settingState) + const [seedVal, setSeed] = useRecoilState(seedState) const croperRect = useRecoilValue(croperState) const [toastVal, setToastState] = useRecoilState(toastState) const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState) @@ -220,8 +222,12 @@ export default function Editor(props: EditorProps) { if (!res) { throw new Error('empty response') } + const { blob, seed } = res + if (seed) { + setSeed(parseInt(seed, 10)) + } const newRender = new Image() - await loadImage(newRender, res) + await loadImage(newRender, blob) const newRenders = [...renders, newRender] setRenders(newRenders) draw(newRender, []) diff --git a/lama_cleaner/app/src/components/Settings/NumberInputSetting.tsx b/lama_cleaner/app/src/components/Settings/NumberInputSetting.tsx index a71aa8f..554279d 100644 --- a/lama_cleaner/app/src/components/Settings/NumberInputSetting.tsx +++ b/lama_cleaner/app/src/components/Settings/NumberInputSetting.tsx @@ -44,7 +44,7 @@ function NumberInputSetting(props: NumberInputSettingProps) { diff --git a/lama_cleaner/app/src/components/shared/NumberInput.tsx b/lama_cleaner/app/src/components/shared/NumberInput.tsx index ead35aa..2efda06 100644 --- a/lama_cleaner/app/src/components/shared/NumberInput.tsx +++ b/lama_cleaner/app/src/components/shared/NumberInput.tsx @@ -1,4 +1,9 @@ -import React, { FormEvent, InputHTMLAttributes, useState } from 'react' +import React, { + FormEvent, + InputHTMLAttributes, + useEffect, + useState, +} from 'react' import TextInput from './Input' interface NumberInputProps extends InputHTMLAttributes { @@ -12,6 +17,10 @@ const NumberInput = React.forwardRef( const { value, allowFloat, onValue, ...itemProps } = props const [innerValue, setInnerValue] = useState(value) + useEffect(() => { + setInnerValue(value) + }, [value]) + const handleOnInput = (evt: FormEvent) => { const target = evt.target as HTMLInputElement let val = target.value diff --git a/lama_cleaner/app/src/store/Atoms.tsx b/lama_cleaner/app/src/store/Atoms.tsx index ed8890b..4ee63f9 100644 --- a/lama_cleaner/app/src/store/Atoms.tsx +++ b/lama_cleaner/app/src/store/Atoms.tsx @@ -270,6 +270,18 @@ export const settingState = atom({ effects: [localStorageEffect(ROOT_STATE_KEY)], }) +export const seedState = selector({ + key: 'seed', + get: ({ get }) => { + const settings = get(settingState) + return settings.sdSeed + }, + set: ({ get, set }, newValue: any) => { + const settings = get(settingState) + set(settingState, { ...settings, sdSeed: newValue }) + }, +}) + export const hdSettingsState = selector({ key: 'hdSettings', get: ({ get }) => { diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index 0c7a404..27c2785 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -26,7 +26,7 @@ try: except: pass -from flask import Flask, request, send_file, cli +from flask import Flask, request, send_file, cli, make_response # Disable ability for Flask to display warning about using a development server in a production environment. # https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356 @@ -112,14 +112,12 @@ def process(): hd_strategy_crop_margin=form["hdStrategyCropMargin"], hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"], hd_strategy_resize_limit=form["hdStrategyResizeLimit"], - - prompt=form['prompt'], - use_croper=form['useCroper'], - croper_x=form['croperX'], - croper_y=form['croperY'], - croper_height=form['croperHeight'], - croper_width=form['croperWidth'], - + prompt=form["prompt"], + use_croper=form["useCroper"], + croper_x=form["croperX"], + croper_y=form["croperY"], + croper_height=form["croperHeight"], + croper_width=form["croperWidth"], sd_strength=form["sdStrength"], sd_steps=form["sdSteps"], sd_guidance_scale=form["sdGuidanceScale"], @@ -153,10 +151,15 @@ def process(): ) ext = get_image_ext(origin_image_bytes) - return send_file( - io.BytesIO(numpy_to_bytes(res_np_img, ext)), - mimetype=f"image/{ext}", + + response = make_response( + send_file( + io.BytesIO(numpy_to_bytes(res_np_img, ext)), + mimetype=f"image/{ext}", + ) ) + response.headers["X-Seed"] = str(config.sd_seed) + return response @app.route("/model") @@ -210,8 +213,12 @@ def main(args): device = torch.device(args.device) input_image_path = args.input - model = ModelManager(name=args.model, device=device, hf_access_token=args.hf_access_token, - callbacks=[diffuser_callback]) + model = ModelManager( + name=args.model, + device=device, + hf_access_token=args.hf_access_token, + callbacks=[diffuser_callback], + ) if args.gui: app_width, app_height = args.gui_size