import {
  Stability,
  StoreKey,
  ACCESS_CODE_PREFIX,
  ApiPath,
} from "@/app/constant";
import { getBearerToken } from "@/app/client/api";
import { createPersistStore } from "@/app/utils/store";
import { nanoid } from "nanoid";
import { uploadImage, base64Image2Blob } from "@/app/utils/chat";
import { models, getModelParamBasicData } from "@/app/components/sd/sd-panel";
import { useAccessStore } from "./access";

/** Model selected in a fresh store: name/value of the first entry of `models`. */
const defaultModel = {
  name: models[0].name,
  value: models[0].value,
};

/** Parameters for a fresh store, built from the first model's param definitions. */
const defaultParams = getModelParamBasicData(models[0].params({}), {});

/** Initial state handed to `createPersistStore`. */
const DEFAULT_SD_STATE = {
  currentId: 0,
  draw: [],
  currentModel: defaultModel,
  currentParams: defaultParams,
};

/**
 * Turns a caught value into text suitable for a draw record's `error` field.
 *
 * Values carrying a string `message` (e.g. `Error` instances) yield that
 * message; `JSON.stringify` on an `Error` would produce `"{}"` because its
 * fields are not enumerable. Anything else is JSON-serialized, falling back
 * to `String()` when serialization yields nothing (e.g. `undefined`).
 */
function errorToText(err: unknown): string {
  const message = (err as { message?: unknown } | null | undefined)?.message;
  if (typeof message === "string") {
    return message;
  }
  return JSON.stringify(err) ?? String(err);
}

/**
 * Persisted store holding the list of draw tasks (`draw`), the selected
 * model/params, and a counter (`currentId`) that is bumped whenever a task
 * is created or its request settles.
 */
export const useSdStore = createPersistStore<
  {
    currentId: number;
    draw: any[];
    currentModel: typeof defaultModel;
    currentParams: any;
  },
  {
    getNextId: () => number;
    sendTask: (data: any, okCall?: Function) => void;
    updateDraw: (draw: any) => void;
    setCurrentModel: (model: any) => void;
    setCurrentParams: (data: any) => void;
  }
>(
  DEFAULT_SD_STATE,
  (set, _get) => {
    // Methods call each other through `methods` rather than `this`, so they
    // keep working when a caller destructures them off the store.
    const methods = {
      /**
       * Increments `currentId` and returns the new value.
       * The new id is computed without mutating the object returned by
       * `_get()`; the change is published only through `set`.
       */
      getNextId() {
        const id = _get().currentId + 1;
        set({ currentId: id });
        return id;
      },

      /**
       * Registers a new task and starts its generation request.
       *
       * @param data    Task payload; it is copied and given a fresh `id`
       *                and `status: "running"` before being stored.
       * @param okCall  Optional callback invoked synchronously once the
       *                request has been started (not when it finishes).
       */
      sendTask(data: any, okCall?: Function) {
        data = { ...data, id: nanoid(), status: "running" };
        set({ draw: [data, ..._get().draw] });
        methods.getNextId();
        methods.stabilityRequestCall(data);
        okCall?.();
      },

      /**
       * POSTs `data.params` as multipart form data to the generation endpoint
       * for `data.model` and records the outcome on the matching draw record.
       *
       * Fire-and-forget: the result is reported only through `updateDraw`.
       */
      stabilityRequestCall(data: any) {
        const accessStore = useAccessStore.getState();
        let prefix: string = ApiPath.Stability as string;
        let bearerToken = "";
        if (accessStore.useCustomConfig) {
          prefix = accessStore.stabilityUrl || (ApiPath.Stability as string);
          bearerToken = getBearerToken(accessStore.stabilityApiKey);
        }
        // No custom key available: fall back to the access code when access
        // control is enabled.
        if (!bearerToken && accessStore.enabledAccessControl()) {
          bearerToken = getBearerToken(
            ACCESS_CODE_PREFIX + accessStore.accessCode,
          );
        }

        const headers = {
          Accept: "application/json",
          Authorization: bearerToken,
        };
        const path = `${prefix}/${Stability.GeneratePath}/${data.model}`;

        const formData = new FormData();
        // `?? {}` keeps the "no params" case a no-op, as the former `for..in` was.
        for (const [paramsKey, paramsValue] of Object.entries(
          data.params ?? {},
        )) {
          formData.append(paramsKey, paramsValue as string | Blob);
        }

        fetch(path, {
          method: "POST",
          headers,
          body: formData,
        })
          .then((response) => response.json())
          .then((resData) => {
            if (resData.errors && resData.errors.length > 0) {
              methods.updateDraw({
                ...data,
                status: "error",
                error: resData.errors[0],
              });
              methods.getNextId();
              return;
            }

            if (resData.finish_reason === "SUCCESS") {
              // The upload settles later; the record is updated from its
              // callbacks, after the `getNextId()` call below has already run.
              uploadImage(base64Image2Blob(resData.image, "image/png"))
                .then((img_data) => {
                  console.debug("uploadImage success", img_data);
                  methods.updateDraw({
                    ...data,
                    status: "success",
                    img_data,
                  });
                })
                .catch((e) => {
                  console.error("uploadImage error", e);
                  methods.updateDraw({
                    ...data,
                    status: "error",
                    error: errorToText(e),
                  });
                });
            } else {
              methods.updateDraw({
                ...data,
                status: "error",
                error: JSON.stringify(resData),
              });
            }
            methods.getNextId();
          })
          .catch((error) => {
            // Reached for network failures, non-JSON bodies, and exceptions
            // thrown synchronously by the handler above.
            methods.updateDraw({
              ...data,
              status: "error",
              error: errorToText(error),
            });
            console.error("Error:", error);
            methods.getNextId();
          });
      },

      /**
       * Replaces the stored draw record whose `id` equals `_draw.id`.
       * Does nothing (and does not call `set`) when no record matches.
       *
       * NOTE(review): the existing `draw` array is updated in place and then
       * re-published, so its identity does not change. Consumers holding a
       * reference to the array may depend on that — confirm before switching
       * to a copy-on-write update.
       */
      updateDraw(_draw: any) {
        const draw = _get().draw || [];
        const index = draw.findIndex((item) => item.id === _draw.id);
        if (index === -1) {
          return;
        }
        draw[index] = _draw;
        set(() => ({ draw }));
      },

      /** Stores `model` as the currently selected model. */
      setCurrentModel(model: any) {
        set({ currentModel: model });
      },

      /** Stores `data` as the current generation parameters. */
      setCurrentParams(data: any) {
        set({
          currentParams: data,
        });
      },
    };

    return methods;
  },
  {
    name: StoreKey.SdList,
    version: 1.0,
  },
);