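"""Booru tagger model shoot-out.

Pits a set of image interrogator models (wd14 variants via the sd-webui
tagger API, a DeepDanbooru-style endpoint, camie-tagger, joytag, plus a
"control" that echoes ground truth) against Danbooru's own tags.

Modes (sys.argv[1]):
    download_images [tag_query] [limit] [page]  -- fetch posts into ./posts
    fight                                       -- interrogate every post with every model
    scores                                      -- compute per-model F1 scores and plots

Expects a config.json in the working directory (see Config below) and keeps
state in ./data.db. scores mode prints extra per-post detail with SHOWOFF=1
and DEBUG=1.
"""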
import asyncio
import base64
import decimal
import json
import logging
import os
import sys
import time

import aiofiles
import aiohttp
import aiosqlite
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import plotly.io as pio
from aiolimiter import AsyncLimiter

from collections import defaultdict, Counter
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from urllib.parse import urlparse

log = logging.getLogger(__name__)


# from https://github.com/picobyte/stable-diffusion-webui-wd14-tagger/blob/49cfdcd6be77086e6cc0d26b5ec26605b702be6e/tagger/utils.py#L15C1-L58C1
DEFAULTS = [
    "wd14-vit.v1",
    "wd14-vit.v2",
    "wd14-convnext.v1",
    "wd14-convnext.v2",
    "wd14-convnextv2.v1",
    # "wd14-swinv2-v1",
    "wd-v1-4-moat-tagger.v2",
    "mld-caformer.dec-5-97527",
    # broken model: "mld-tresnetd.6-30000",
]

RATING = ["general", "explicit", "questionable", "sensitive", "safe"]


@dataclass
class InterrogatorPost:
    tags: List[str]
    time_taken: float


@dataclass
class Interrogator:
    model_id: str
    address: str
    threshold: float = 0.55
    _fucked_rating: bool = False

    def _process(self, lst):
        # subclasses override this to strip rating tags and the like
        return lst

    async def fetch(self, ctx, md5_hash):
        # read a cached interrogation result from the db (populated by fight mode)
        rows = await ctx.db.execute_fetchall(
            "select output_tag_string, time_taken from interrogated_posts where md5 = ? and model_name = ?",
            (md5_hash, self.model_id),
        )
        if not rows:
            raise AssertionError("run fight mode first, there are missing posts...")
        tag_string, time_taken = rows[0][0], rows[0][1]
        return InterrogatorPost(self._process(tag_string.split()), time_taken)


class DDInterrogator(Interrogator):
    def _process(self, lst):
        # strip rating tags ("rating:..." prefixes and bare rating words);
        # everything else passes through unchanged
        return [
            tag
            for tag in lst
            if not tag.startswith("rating:") and tag not in RATING
        ]

    async def interrogate(self, ctx, path):
        async with ctx.session.post(
            f"{self.address}/",
            params={
                # kept in sync with the threshold field above (default 0.55)
                "threshold": str(self.threshold),
            },
            headers={"Authorization": "Bearer sex"},
            data={"file": path.open("rb")},
        ) as resp:
            assert resp.status == 200
            tags = await resp.json()
            upstream_tags = [tag.replace(" ", "_") for tag in tags]
            return " ".join(upstream_tags)


class SDInterrogator(Interrogator):
    def _process(self, lst):
        # strip rating tags; everything else passes through unchanged
        return [
            tag
            for tag in lst
            if not tag.startswith("rating_") and tag not in RATING
        ]

    async def interrogate(self, ctx, path):
        async with aiofiles.open(path, "rb") as fd:
            as_base64 = base64.b64encode(await fd.read()).decode("utf-8")

        url = f"{self.address}/tagger/v1/interrogate"
        async with ctx.session.post(
            url,
            json={
                "model": self.model_id,
                "threshold": self.threshold,
                "image": as_base64,
            },
        ) as resp:
            log.info("%s got %d from %s", path, resp.status, url)
            assert resp.status == 200
            data = await resp.json()
            tags = []

            # "caption" is either a flat {tag: weight} map or a nested
            # {group: {tag: weight}} map; accept both shapes
            for maybe_tag, maybe_weight in data["caption"].items():
                if isinstance(maybe_weight, float):
                    tags.append(maybe_tag)
                elif isinstance(maybe_weight, dict):
                    for tag, weight in maybe_weight.items():
                        assert isinstance(weight, float)
                        tags.append(tag)
                else:
                    raise AssertionError(f"invalid weight: {maybe_weight!s}")

        upstream_tags = [tag.replace(" ", "_") for tag in tags]
        return " ".join(upstream_tags)


def tag_string_for(post: dict) -> str:
    # ground truth: general + copyright + character tags (artist/meta excluded)
    return (
        post["tag_string_general"]
        + " "
        + post["tag_string_copyright"]
        + " "
        + post["tag_string_character"]
    )


class ControlInterrogator(Interrogator):
    async def fetch(self, ctx, path):
        # callers pass an md5 hash here; Path(...).stem is a no-op on a bare
        # hash and strips the extension when given a filename
        md5_hash = Path(path).stem
        post = await fetch_post(ctx, md5_hash)
        tag_string = tag_string_for(post)
        return InterrogatorPost(tag_string.split(), 0)

    async def interrogate(self, ctx, path):
        md5_hash = Path(path).stem
        post = await fetch_post(ctx, md5_hash)
        return tag_string_for(post)
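

# ControlInterrogator replays Danbooru's own tags back as its "predictions",
# so in scores mode it should land at a perfect 1.0; it acts as a sanity
# check on the scoring pipeline rather than as a real model.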


@dataclass
class Config:
    sd_webui_address: str
    dd_address: str
    dd_model_name: str
    sd_webui_extras: Dict[str, str]
    camie_address: str
    joytag_address: str
    sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))

    @property
    def all_available_models(self) -> List[Any]:
        return (
            [
                SDInterrogator(sd_interrogator, self.sd_webui_address)
                for sd_interrogator in self.sd_webui_models
            ]
            + [
                SDInterrogator(sd_interrogator, url)
                for sd_interrogator, url in self.sd_webui_extras.items()
            ]
            + [
                DDInterrogator(self.dd_model_name, self.dd_address),
                SDInterrogator("camie-tagger-v1", self.camie_address, 0.5, True),
                SDInterrogator("joytag-v1", self.joytag_address, 0.5),
            ]
            + [ControlInterrogator("control", None)]
        )
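

# Illustrative config.json shape (the addresses below are placeholders, not
# real endpoints; sd_webui_models falls back to DEFAULTS when omitted):
#
# {
#     "sd_webui_address": "http://127.0.0.1:7860",
#     "dd_address": "http://127.0.0.1:9000",
#     "dd_model_name": "deepdanbooru",
#     "sd_webui_extras": {"some-extra-model": "http://127.0.0.1:7861"},
#     "camie_address": "http://127.0.0.1:5000",
#     "joytag_address": "http://127.0.0.1:5001"
# }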


@dataclass
class Context:
    db: Any
    session: aiohttp.ClientSession
    config: Config


@dataclass
class Booru:
    session: aiohttp.ClientSession
    limiter: AsyncLimiter
    tag_limiter: AsyncLimiter
    file_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))
    tag_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))

    fetch_type = "hash"

    @property
    def hash_style(self):
        # HashStyle is not defined or imported in this file; this property is
        # never used here and would raise NameError if accessed
        return HashStyle.md5


class Danbooru(Booru):
    title = "Danbooru"
    base_url = "https://danbooru.donmai.us"

    async def posts(self, tag_query: str, limit, page: int):
        log.info("%s: submit %r", self.title, tag_query)
        async with self.limiter:
            log.info(
                "%s: submit upstream query=%r limit=%r page=%r",
                self.title,
                tag_query,
                limit,
                page,
            )
            async with self.session.get(
                f"{self.base_url}/posts.json",
                params={"tags": tag_query, "limit": limit, "page": page},
            ) as resp:
                assert resp.status == 200
                rjson = await resp.json()
                return rjson
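

# aiolimiter's AsyncLimiter(max_rate, time_period) allows at most max_rate
# acquisitions per time_period seconds, so the AsyncLimiter(1, 10) built in
# download_images below throttles Danbooru queries to one every ten seconds.
# The second limiter (tag_limiter) is constructed but unused in this file.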


DOWNLOADS = Path.cwd() / "posts"
DOWNLOADS.mkdir(exist_ok=True)


ALLOWED_EXTENSIONS = (
    "jpg",
    "jpeg",
    "png",
)


async def download_images(ctx):
    try:
        tagquery = sys.argv[2]
    except IndexError:
        tagquery = ""

    try:
        limit = int(sys.argv[3])
    except IndexError:
        limit = 30

    try:
        pageskip = int(sys.argv[4])
    except IndexError:
        pageskip = 150

    danbooru = Danbooru(
        ctx.session,
        AsyncLimiter(1, 10),
        AsyncLimiter(1, 3),
    )

    posts = await danbooru.posts(tagquery, limit, pageskip)
    for post in posts:
        if "md5" not in post:
            continue
        md5 = post["md5"]
        log.info("processing post %r", md5)
        existing_post = await fetch_post(ctx, md5)
        if existing_post:
            log.info("already exists %r", md5)
            continue

        # download the post
        post_file_url = post["file_url"]
        post_url_path = Path(urlparse(post_file_url).path)
        # exact match on the normalized suffix, instead of the old substring
        # check that would also accept things like ".mjpg"
        good_extension = post_url_path.suffix.lstrip(".").lower() in ALLOWED_EXTENSIONS
        if not good_extension:
            log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.suffix)
            continue
        post_filename = post_url_path.name
        post_filepath = DOWNLOADS / post_filename
        if not post_filepath.exists():
            log.info("downloading %r to %r", post_file_url, post_filepath)
            async with ctx.session.get(post_file_url) as resp:
                assert resp.status == 200
                with post_filepath.open("wb") as fd:
                    async for chunk in resp.content.iter_chunked(1024):
                        fd.write(chunk)

        # when it's done successfully, insert it
        await insert_post(ctx, post)


async def fetch_post(ctx, md5) -> Optional[dict]:
    rows = await ctx.db.execute_fetchall("select data from posts where md5 = ?", (md5,))
    if not rows:
        return None
    assert len(rows) == 1
    post = json.loads(rows[0][0])
    post["tag_string"] = tag_string_for(post)
    return post


async def insert_post(ctx, post):
    await ctx.db.execute_insert(
        "insert into posts (md5, filepath, data) values (?, ?, ?)",
        (
            post["md5"],
            Path(urlparse(post["file_url"]).path).name,
            json.dumps(post),
        ),
    )
    await ctx.db.commit()


async def insert_interrogated_result(
    ctx, interrogator: Interrogator, md5: str, tag_string: str, time_taken: float
):
    await ctx.db.execute_insert(
        "insert into interrogated_posts (md5, model_name, output_tag_string, time_taken) values (?, ?, ?, ?)",
        (md5, interrogator.model_id, tag_string, time_taken),
    )
    await ctx.db.commit()


async def process_hash(ctx, interrogator, missing_hash, semaphore, index, total):
    async with semaphore:
        log.info("interrogating %r (%d/%d)", missing_hash, index, total)
        post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))

        start_ts = time.monotonic()
        tag_string = await interrogator.interrogate(ctx, post_filepath)
        end_ts = time.monotonic()
        time_taken = round(end_ts - start_ts, 10)

        log.info("took %.5fsec, got %r", time_taken, tag_string)
        await insert_interrogated_result(
            ctx, interrogator, missing_hash, tag_string, time_taken
        )


async def fight(ctx):
    interrogators = ctx.config.all_available_models

    all_rows = await ctx.db.execute_fetchall("select md5 from posts")
    all_hashes = set(r[0] for r in all_rows)

    for interrogator in interrogators:
        log.info("processing fight for %r", interrogator)
        # calculate the set of images we didn't interrogate yet
        interrogated_rows = await ctx.db.execute_fetchall(
            "select md5 from interrogated_posts where model_name = ?",
            (interrogator.model_id,),
        )
        interrogated_hashes = set(row[0] for row in interrogated_rows)
        missing_hashes = all_hashes - interrogated_hashes

        log.info("missing %d hashes", len(missing_hashes))

        semaphore = asyncio.Semaphore(3)

        tasks = [
            process_hash(
                ctx,
                interrogator,
                missing_hash,
                semaphore,
                index + 1,
                len(missing_hashes),
            )
            for index, missing_hash in enumerate(missing_hashes)
        ]

        # run all tasks concurrently, with the semaphore limiting to 3 at a time
        await asyncio.gather(*tasks)


def score(
    danbooru_tags: Set[str], interrogator_tags: Set[str]
) -> Tuple[decimal.Decimal, Set[str]]:
    # F1 score of the predicted tag set against the danbooru ground truth,
    # plus the set of predicted tags that danbooru doesn't have
    tags_not_in_danbooru = interrogator_tags - danbooru_tags

    # handle edge cases; returning early also avoids dividing by zero below
    if len(danbooru_tags) == 0 and len(interrogator_tags) == 0:
        # both empty means perfect match
        return decimal.Decimal("1.0"), tags_not_in_danbooru
    if len(danbooru_tags) == 0 or len(interrogator_tags) == 0:
        # one empty means no match
        return decimal.Decimal("0.0"), tags_not_in_danbooru

    # calculate true positives (tags that appear in both sets)
    true_positives = decimal.Decimal(len(danbooru_tags.intersection(interrogator_tags)))

    # precision: TP / (TP + FP)
    precision = true_positives / len(interrogator_tags)

    # recall: TP / (TP + FN)
    recall = true_positives / len(danbooru_tags)
    log.debug("recall %s", recall)

    # both precision and recall at 0 would make the f1 formula divide by zero
    if precision == 0 and recall == 0:
        f1 = decimal.Decimal("0.0")
    else:
        f1 = decimal.Decimal("2.0") * (precision * recall) / (precision + recall)

    return (
        round(f1, 10),
        tags_not_in_danbooru,
    )
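

# Worked example: with danbooru_tags {a, b, c} and interrogator_tags {b, c, d},
# true positives = 2, precision = 2/3, recall = 2/3, and
# f1 = 2 * (2/3 * 2/3) / (2/3 + 2/3) = 2/3, rounded to Decimal("0.6666666667").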


async def scores(ctx):
    interrogators = ctx.config.all_available_models

    all_rows = await ctx.db.execute_fetchall("select md5 from posts")
    all_hashes = set(r[0] for r in all_rows)

    # absolute_scores = defaultdict(decimal.Decimal)
    model_scores = defaultdict(dict)
    runtimes = defaultdict(list)
    incorrect_tags_counters = defaultdict(Counter)
    predicted_tags_counter = defaultdict(int)

    for md5_hash in all_hashes:
        log.info("processing score for %r", md5_hash)
        post = await fetch_post(ctx, md5_hash)
        danbooru_tags = set(post["tag_string"].split())
        for interrogator in interrogators:
            post_data = await interrogator.fetch(ctx, md5_hash)
            runtimes[interrogator.model_id].append(post_data.time_taken)
            interrogator_tags = set(post_data.tags)
            tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
            for tag in incorrect_tags:
                incorrect_tags_counters[interrogator.model_id][tag] += 1

            log.info(f"{interrogator.model_id} {tagging_score}")
            predicted_tags_counter[interrogator.model_id] += len(interrogator_tags)
            correct_tags = interrogator_tags.intersection(danbooru_tags)
            model_scores[interrogator.model_id][md5_hash] = {
                "score": tagging_score,
                "predicted_tags": interrogator_tags,
                "incorrect_tags": incorrect_tags,
                "correct_tags": correct_tags,
            }

    summed_scores = {
        model_id: sum(d["score"] for d in post_scores.values())
        for model_id, post_scores in model_scores.items()
    }

    normalized_scores = {
        model: round(summed_scores[model] / len(all_hashes), 10)
        for model in summed_scores
    }

    print("scores are [worst...best]")

    for model in sorted(
        normalized_scores.keys(),
        key=lambda model: normalized_scores[model],
        reverse=True,
    ):
        average_runtime = sum(runtimes[model]) / len(runtimes[model])
        print(model, normalized_scores[model], "runtime", average_runtime, "sec")
        if os.getenv("SHOWOFF", "0") == "1":
            print("[", end="")

            for bad_md5_hash in sorted(
                model_scores[model].keys(),
                key=lambda md5_hash: model_scores[model][md5_hash]["score"],
            )[:4]:
                data = model_scores[model][bad_md5_hash]
                if os.getenv("DEBUG", "0") == "1":
                    print(bad_md5_hash, data["score"], " ".join(data["incorrect_tags"]))
                else:
                    print(data["score"], end=",")
            print("...", end="")

            for good_md5_hash in sorted(
                model_scores[model].keys(),
                key=lambda md5_hash: model_scores[model][md5_hash]["score"],
                reverse=True,
            )[:4]:
                data = model_scores[model][good_md5_hash]
                print(data["score"], end=",")

            print("]")
        total_incorrect = sum(incorrect_tags_counters[model].values())
        print(
            "most incorrect tags from",
            total_incorrect,
            "incorrect tags",
            "predicted",
            predicted_tags_counter[model],
            "tags",
        )
        for t, c in incorrect_tags_counters[model].most_common(7):
            print("\t", t, c)

    PLOTS = Path.cwd() / "plots"
    PLOTS.mkdir(exist_ok=True)

    log.info("plotting score histogram...")

    data_for_df = {"scores": [], "model": []}

    for model in sorted(
        normalized_scores.keys(),
        key=lambda model: normalized_scores[model],
        reverse=True,
    ):
        for post_score in (d["score"] for d in model_scores[model].values()):
            data_for_df["scores"].append(post_score)
            data_for_df["model"].append(model)

    df = pd.DataFrame(data_for_df)
    fig = px.histogram(
        df,
        x="scores",
        color="model",
        histfunc="count",
        marginal="rug",
        histnorm="probability",
    )
    pio.write_image(fig, PLOTS / "score_histogram.png", width=1024, height=800)

    log.info("plotting positive histogram...")
    plot2(PLOTS / "positive_score_histogram.png", normalized_scores, model_scores)
    log.info("plotting error rates...")
    plot3(
        PLOTS / "error_rate.png",
        PLOTS / "score_avg.png",
        normalized_scores,
        model_scores,
    )

    print("md table")
    # time_taken is measured in seconds (time.monotonic deltas)
    print("| model | score | avg. runtime (s) |")
    print("| ---- | ---- | ---- |")
    for model in sorted(
        normalized_scores.keys(),
        key=lambda model: normalized_scores[model],
        reverse=True,
    ):
        model_score = normalized_scores[model]
        average_runtime = sum(runtimes[model]) / len(runtimes[model])
        print(
            "|", model, "|", round(model_score, 3), "|", round(average_runtime, 2), "|"
        )


def plot2(output_path, normalized_scores, model_scores):
    # same histogram as in scores(), but with negative scores filtered out
    data_for_df = {"scores": [], "model": []}

    for model in sorted(
        normalized_scores.keys(),
        key=lambda model: normalized_scores[model],
        reverse=True,
    ):
        for post_score in (d["score"] for d in model_scores[model].values()):
            if post_score < 0:
                continue
            data_for_df["scores"].append(post_score)
            data_for_df["model"].append(model)

    df = pd.DataFrame(data_for_df)
    fig = px.histogram(
        df,
        x="scores",
        color="model",
        histfunc="count",
        marginal="rug",
        histnorm="probability",
    )
    pio.write_image(fig, output_path, width=1024, height=800)


def plot3(output_path, output_score_avg_path, normalized_scores, model_scores):
    data_for_df = {
        "model": [],
        "score_avg": [],
        "predicted": [],
        "correct": [],
        "incorrect": [],
    }

    for model in sorted(
        normalized_scores.keys(),
        key=lambda model: normalized_scores[model],
        reverse=True,
    ):
        total_predicted_tags, total_incorrect_tags, total_correct_tags = 0, 0, 0
        for score_data in model_scores[model].values():
            total_predicted_tags += len(score_data["predicted_tags"])
            total_incorrect_tags += len(score_data["incorrect_tags"])
            total_correct_tags += len(score_data["correct_tags"])

        data_for_df["score_avg"].append(normalized_scores[model])
        data_for_df["predicted"].append(total_predicted_tags)
        data_for_df["incorrect"].append(total_incorrect_tags)
        data_for_df["correct"].append(total_correct_tags)
        data_for_df["model"].append(model)

    df = pd.DataFrame(data_for_df)

    fig = go.Figure(
        data=[
            go.Bar(name="predicted tags", x=df.model, y=df.predicted),
            go.Bar(name="incorrect tags", x=df.model, y=df.incorrect),
            go.Bar(name="correct tags", x=df.model, y=df.correct),
        ]
    )
    pio.write_image(fig, output_path, width=1024, height=800)
    fig2 = go.Figure(
        data=[
            go.Bar(name="score avg", x=df.model, y=df.score_avg),
        ]
    )
    pio.write_image(fig2, output_score_avg_path, width=1024, height=800)


async def realmain(ctx):
    await ctx.db.executescript(
        """
        create table if not exists posts (
            md5 text primary key,
            filepath text,
            data text
        );
        create table if not exists interrogated_posts (
            md5 text,
            model_name text not null,
            output_tag_string text not null,
            time_taken real not null,
            primary key (md5, model_name)
        );
        """
    )

    try:
        mode = sys.argv[1]
    except IndexError:
        raise Exception("must have mode") from None

    if mode == "download_images":
        await download_images(ctx)
    elif mode == "fight":
        await fight(ctx)
    elif mode == "scores":
        await scores(ctx)
    else:
        raise AssertionError(f"invalid mode {mode}")


async def main():
    CONFIG_PATH = Path.cwd() / "config.json"
    log.info("hewo")
    async with aiosqlite.connect("./data.db") as db:
        async with aiohttp.ClientSession() as session:
            with CONFIG_PATH.open() as config_fd:
                config_json = json.load(config_fd)

            config = Config(**config_json)
            ctx = Context(db, session, config)
            await realmain(ctx)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())