added civitai api connection
This commit is contained in:
79
src/civitai.py
Normal file
79
src/civitai.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import httpx
|
||||
from rich.pretty import pprint
|
||||
|
||||
|
||||
class TagStore:
|
||||
def __init__(self) -> None:
|
||||
self.tag_set = set()
|
||||
self.tag_count = 0
|
||||
|
||||
def add_tag(self, tag: int) -> None:
|
||||
if tag not in self.tag_set:
|
||||
self.tag_set.add(tag)
|
||||
self.tag_count += 1
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.tag_set)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(self.tag_set)
|
||||
|
||||
|
||||
tag_db: dict[str, set[str]] = {}
|
||||
model_db: dict[str, set[str]] = {}
|
||||
|
||||
|
||||
def update_tag_db(model_path: str, tags: list[str]) -> None:
|
||||
for tag in tags:
|
||||
if tag not in tag_db.keys():
|
||||
tag_db[tag] = set()
|
||||
tag_db[tag].add(model_path)
|
||||
if model_path not in model_db.keys():
|
||||
model_db[model_path] = set()
|
||||
model_db[model_path].add(tag)
|
||||
pprint(tag_db)
|
||||
|
||||
|
||||
def api_get_json(url: str) -> dict | None:
|
||||
"""
|
||||
Fetches JSON data from the specified URL using an HTTP GET request.
|
||||
|
||||
Args:
|
||||
url (str): The URL from which to retrieve the JSON data.
|
||||
|
||||
Returns:
|
||||
dict | None: A dictionary representing the JSON data if the HTTP response status code is 200 (OK),
|
||||
otherwise returns None.
|
||||
|
||||
Example:
|
||||
data = api_get_json("https://api.example.com/data")
|
||||
if data is not None:
|
||||
print("JSON data:", data)
|
||||
else:
|
||||
print("Failed to fetch JSON data.")
|
||||
"""
|
||||
response = httpx.get(url)
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
return None
|
||||
|
||||
|
||||
def civitai_api_model(model_id: int) -> dict | None:
|
||||
"""
|
||||
Fetches JSON data from the specified Civitai API model.
|
||||
|
||||
Args:
|
||||
model_id (int): The ID of the Civitai API model.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary representing the JSON data if the HTTP response status code is 200 (OK),
|
||||
otherwise returns None.
|
||||
|
||||
Example:
|
||||
data = civitai_api_model(1)
|
||||
if data is not None:
|
||||
print("JSON data:", data)
|
||||
else:
|
||||
print("Failed to fetch JSON data.")
|
||||
"""
|
||||
return api_get_json(f"https://civitai.com/api/v1/models/{model_id}")
|
||||
@@ -1,40 +1,93 @@
|
||||
import glob
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from os import path
|
||||
from datetime import datetime
|
||||
from os import path
|
||||
|
||||
import PySimpleGUI as sg
|
||||
from PIL import Image
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
from rich.traceback import install
|
||||
|
||||
from civitai import civitai_api_model, update_tag_db
|
||||
|
||||
install(show_locals=True)
|
||||
console = Console()
|
||||
logging.basicConfig(
|
||||
level="NOTSET",
|
||||
format="%(message)s",
|
||||
datefmt="[%X]",
|
||||
handlers=[RichHandler(rich_tracebacks=True)],
|
||||
)
|
||||
|
||||
log = logging.getLogger("rich")
|
||||
|
||||
|
||||
def get_civitai_text(civitai_path: str) -> str:
|
||||
with open(civitai_path, "r") as f:
|
||||
civitai_json = json.load(f)
|
||||
def get_civitai_text(civitai_path: str, model_path: str) -> str:
|
||||
with open(civitai_path, "rb") as f:
|
||||
# civitai_json = json.load(f)
|
||||
civitai_json = json.loads(f.read().decode("utf-8", errors="ignore"))
|
||||
created_at: datetime = datetime.fromisoformat(civitai_json["createdAt"])
|
||||
return f"""Model Name: {civitai_json["model"]["name"]}
|
||||
model_id = civitai_json["modelId"]
|
||||
model_data = civitai_api_model(model_id)
|
||||
if "tags" in model_data:
|
||||
tags = model_data["tags"]
|
||||
update_tag_db(model_path, tags)
|
||||
try:
|
||||
model_name = model_data["model"]["name"]
|
||||
except KeyError:
|
||||
model_name = model_data["name"]
|
||||
return f"""Model Name: {model_name}
|
||||
Checkpoint Name: {civitai_json["name"]}
|
||||
Created at: {created_at}"""
|
||||
Created at: {created_at}
|
||||
Tags: {tags if "tags" in model_data else "None"}"""
|
||||
|
||||
|
||||
def get_extra_filenames(model_path: str):
|
||||
def get_extra_filenames(model_path: str) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
Retrieves additional filenames associated with a given model file path.
|
||||
|
||||
Args:
|
||||
model_path (str): The path of the model file.
|
||||
|
||||
Returns:
|
||||
tuple[str | None, str | None]: A tuple containing two strings representing additional filenames,
|
||||
or None if the corresponding file does not exist. The first element of the tuple represents
|
||||
the preview image filename, and the second element represents the Civitai information filename.
|
||||
|
||||
Example:
|
||||
preview, info = get_extra_filenames("path/to/model_file.pth")
|
||||
if preview is not None:
|
||||
print("Preview image file exists:", preview)
|
||||
if info is not None:
|
||||
print("Civitai information file exists:", info)
|
||||
"""
|
||||
base, ext = os.path.splitext(model_path)
|
||||
preview_image: str = base + ".preview.png"
|
||||
civitai_info: str = base + ".civitai.info"
|
||||
return preview_image if path.exists(preview_image) else None, civitai_info if path.exists(civitai_info) else None
|
||||
return (
|
||||
preview_image if path.exists(preview_image) else None,
|
||||
civitai_info if path.exists(civitai_info) else None,
|
||||
)
|
||||
|
||||
|
||||
def get_models_from_folder(model_folder: str):
|
||||
return glob.glob(path.join(model_folder, "**/*.safetensors"), recursive=True) + glob.glob(
|
||||
path.join(model_folder, "**/*.ckpt"), recursive=True)
|
||||
return glob.glob(
|
||||
path.join(model_folder, "**/*.safetensors"), recursive=True
|
||||
) + glob.glob(path.join(model_folder, "**/*.ckpt"), recursive=True)
|
||||
|
||||
|
||||
layout = [
|
||||
[sg.Text("Enter a folder:")],
|
||||
[sg.InputText(key="model_folder", enable_events=True), sg.FolderBrowse()],
|
||||
[sg.Combo(key="main_folder", enable_events=True, values=[], size=60),
|
||||
sg.Combo(key="models", enable_events=True, values=[], size=60)],
|
||||
[
|
||||
sg.Combo(key="main_folder", enable_events=True, values=[], size=60),
|
||||
sg.Combo(key="models", enable_events=True, values=[], size=60),
|
||||
],
|
||||
[sg.Button("Index Models", key="index_models", disabled=True)],
|
||||
[sg.Image(key="preview_image", enable_events=True)],
|
||||
[sg.Text(key="civitai_info", enable_events=True)],
|
||||
]
|
||||
@@ -46,7 +99,9 @@ while True:
|
||||
if event == "model_folder":
|
||||
model_folder = values["model_folder"]
|
||||
if model_folder:
|
||||
list_subfolders_with_paths = [f.path for f in os.scandir(model_folder) if f.is_dir()]
|
||||
list_subfolders_with_paths = [
|
||||
f.path for f in os.scandir(model_folder) if f.is_dir()
|
||||
]
|
||||
if list_subfolders_with_paths:
|
||||
window["main_folder"].update(values=list_subfolders_with_paths)
|
||||
if event == "main_folder":
|
||||
@@ -54,19 +109,23 @@ while True:
|
||||
models = get_models_from_folder(model_subfolder)
|
||||
if model_subfolder and models:
|
||||
window["models"].update(values=models)
|
||||
window["index_models"].update(disabled=False)
|
||||
else:
|
||||
window["models"].update(values=[])
|
||||
window["index_models"].update(disabled=True)
|
||||
if event == "models":
|
||||
model_path = values["models"]
|
||||
preview_image, civitai_info = get_extra_filenames(model_path)
|
||||
if preview_image:
|
||||
img = Image.open(preview_image)
|
||||
img.thumbnail((512, 512))
|
||||
bio = io.BytesIO()
|
||||
img.save(bio, format="PNG")
|
||||
window["preview_image"].update(data=bio.getvalue())
|
||||
with Image.open(preview_image) as img:
|
||||
img.thumbnail((512, 512))
|
||||
bio = io.BytesIO()
|
||||
img.save(bio, format="PNG")
|
||||
window["preview_image"].update(data=bio.getvalue())
|
||||
else:
|
||||
window["preview_image"].update(data=None)
|
||||
if civitai_info:
|
||||
window["civitai_info"].update(get_civitai_text(civitai_info))
|
||||
window["civitai_info"].update(get_civitai_text(civitai_info, model_path))
|
||||
else:
|
||||
window["civitai_info"].update(data=None)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user