added civitai api connection

This commit is contained in:
2023-09-21 11:15:42 -04:00
parent a40cc83882
commit 22690dbb21
4 changed files with 349 additions and 99 deletions

79
src/civitai.py Normal file
View File

@@ -0,0 +1,79 @@
import httpx
from rich.pretty import pprint
class TagStore:
def __init__(self) -> None:
self.tag_set = set()
self.tag_count = 0
def add_tag(self, tag: int) -> None:
if tag not in self.tag_set:
self.tag_set.add(tag)
self.tag_count += 1
def __str__(self) -> str:
return str(self.tag_set)
def __repr__(self) -> str:
return str(self.tag_set)
tag_db: dict[str, set[str]] = {}
model_db: dict[str, set[str]] = {}
def update_tag_db(model_path: str, tags: list[str]) -> None:
for tag in tags:
if tag not in tag_db.keys():
tag_db[tag] = set()
tag_db[tag].add(model_path)
if model_path not in model_db.keys():
model_db[model_path] = set()
model_db[model_path].add(tag)
pprint(tag_db)
def api_get_json(url: str) -> dict | None:
"""
Fetches JSON data from the specified URL using an HTTP GET request.
Args:
url (str): The URL from which to retrieve the JSON data.
Returns:
dict | None: A dictionary representing the JSON data if the HTTP response status code is 200 (OK),
otherwise returns None.
Example:
data = api_get_json("https://api.example.com/data")
if data is not None:
print("JSON data:", data)
else:
print("Failed to fetch JSON data.")
"""
response = httpx.get(url)
if response.status_code == 200:
return response.json()
return None
def civitai_api_model(model_id: int) -> dict | None:
"""
Fetches JSON data from the specified Civitai API model.
Args:
model_id (int): The ID of the Civitai API model.
Returns:
dict: A dictionary representing the JSON data if the HTTP response status code is 200 (OK),
otherwise returns None.
Example:
data = civitai_api_model(1)
if data is not None:
print("JSON data:", data)
else:
print("Failed to fetch JSON data.")
"""
return api_get_json(f"https://civitai.com/api/v1/models/{model_id}")

View File

@@ -1,40 +1,93 @@
import glob
import io
import json
import logging
import os
from os import path
from datetime import datetime
from os import path
import PySimpleGUI as sg
from PIL import Image
from rich.console import Console
from rich.logging import RichHandler
from rich.traceback import install
from civitai import civitai_api_model, update_tag_db
install(show_locals=True)
console = Console()
logging.basicConfig(
level="NOTSET",
format="%(message)s",
datefmt="[%X]",
handlers=[RichHandler(rich_tracebacks=True)],
)
log = logging.getLogger("rich")
def get_civitai_text(civitai_path: str) -> str:
with open(civitai_path, "r") as f:
civitai_json = json.load(f)
def get_civitai_text(civitai_path: str, model_path: str) -> str:
with open(civitai_path, "rb") as f:
# civitai_json = json.load(f)
civitai_json = json.loads(f.read().decode("utf-8", errors="ignore"))
created_at: datetime = datetime.fromisoformat(civitai_json["createdAt"])
return f"""Model Name: {civitai_json["model"]["name"]}
model_id = civitai_json["modelId"]
model_data = civitai_api_model(model_id)
if "tags" in model_data:
tags = model_data["tags"]
update_tag_db(model_path, tags)
try:
model_name = model_data["model"]["name"]
except KeyError:
model_name = model_data["name"]
return f"""Model Name: {model_name}
Checkpoint Name: {civitai_json["name"]}
Created at: {created_at}"""
Created at: {created_at}
Tags: {tags if "tags" in model_data else "None"}"""
def get_extra_filenames(model_path: str):
def get_extra_filenames(model_path: str) -> tuple[str | None, str | None]:
"""
Retrieves additional filenames associated with a given model file path.
Args:
model_path (str): The path of the model file.
Returns:
tuple[str | None, str | None]: A tuple containing two strings representing additional filenames,
or None if the corresponding file does not exist. The first element of the tuple represents
the preview image filename, and the second element represents the Civitai information filename.
Example:
preview, info = get_extra_filenames("path/to/model_file.pth")
if preview is not None:
print("Preview image file exists:", preview)
if info is not None:
print("Civitai information file exists:", info)
"""
base, ext = os.path.splitext(model_path)
preview_image: str = base + ".preview.png"
civitai_info: str = base + ".civitai.info"
return preview_image if path.exists(preview_image) else None, civitai_info if path.exists(civitai_info) else None
return (
preview_image if path.exists(preview_image) else None,
civitai_info if path.exists(civitai_info) else None,
)
def get_models_from_folder(model_folder: str):
return glob.glob(path.join(model_folder, "**/*.safetensors"), recursive=True) + glob.glob(
path.join(model_folder, "**/*.ckpt"), recursive=True)
return glob.glob(
path.join(model_folder, "**/*.safetensors"), recursive=True
) + glob.glob(path.join(model_folder, "**/*.ckpt"), recursive=True)
layout = [
[sg.Text("Enter a folder:")],
[sg.InputText(key="model_folder", enable_events=True), sg.FolderBrowse()],
[sg.Combo(key="main_folder", enable_events=True, values=[], size=60),
sg.Combo(key="models", enable_events=True, values=[], size=60)],
[
sg.Combo(key="main_folder", enable_events=True, values=[], size=60),
sg.Combo(key="models", enable_events=True, values=[], size=60),
],
[sg.Button("Index Models", key="index_models", disabled=True)],
[sg.Image(key="preview_image", enable_events=True)],
[sg.Text(key="civitai_info", enable_events=True)],
]
@@ -46,7 +99,9 @@ while True:
if event == "model_folder":
model_folder = values["model_folder"]
if model_folder:
list_subfolders_with_paths = [f.path for f in os.scandir(model_folder) if f.is_dir()]
list_subfolders_with_paths = [
f.path for f in os.scandir(model_folder) if f.is_dir()
]
if list_subfolders_with_paths:
window["main_folder"].update(values=list_subfolders_with_paths)
if event == "main_folder":
@@ -54,19 +109,23 @@ while True:
models = get_models_from_folder(model_subfolder)
if model_subfolder and models:
window["models"].update(values=models)
window["index_models"].update(disabled=False)
else:
window["models"].update(values=[])
window["index_models"].update(disabled=True)
if event == "models":
model_path = values["models"]
preview_image, civitai_info = get_extra_filenames(model_path)
if preview_image:
img = Image.open(preview_image)
img.thumbnail((512, 512))
bio = io.BytesIO()
img.save(bio, format="PNG")
window["preview_image"].update(data=bio.getvalue())
with Image.open(preview_image) as img:
img.thumbnail((512, 512))
bio = io.BytesIO()
img.save(bio, format="PNG")
window["preview_image"].update(data=bio.getvalue())
else:
window["preview_image"].update(data=None)
if civitai_info:
window["civitai_info"].update(get_civitai_text(civitai_info))
window["civitai_info"].update(get_civitai_text(civitai_info, model_path))
else:
window["civitai_info"].update(data=None)