#!/usr/bin/env python3
# SPDX-License-Identifier: MPL-2.0
# Copyright ijl (2026)

import asyncio
import datetime
import re
import subprocess
from collections import defaultdict
from pathlib import Path

# Directory two levels above this script. Every include pattern and exclude
# path below is resolved relative to it.
REPOSITORY = Path(__file__).parent.parent

# The two SPDX license expressions a generated file header can carry.
LICENSE_APACHE = "(Apache-2.0 OR MIT)"
LICENSE_MPL2 = "MPL-2.0"

# Matches a run of two or more spaces; used to collapse padding in blame output
# down to a single space before it is split or prefix-matched.
SPACES = re.compile(r"[ ]{2,}")

# Glob patterns (relative to REPOSITORY) selecting the files to process.
TO_INCLUDE = {
    ".github/**/*.yaml",
    "bench/*.py",
    "bench/run",
    "integration/*",
    "pysrc/orjson/*",
    "script/*",
    "src/**/*.rs",
    "test/**/*.py",
}

# Paths (relative to REPOSITORY) dropped from the glob results.
# aggregate_files() removes each entry with set.remove(), so an entry that was
# not matched by any pattern above raises KeyError.
TO_EXCLUDE = {
    "integration/http",
    "script/cargo",
    "script/debug",
    "script/develop",
    "script/install-fedora",
    "script/lint",
    "script/profile",
    "script/pybench",
    "script/pytest",
    "script/valgrind",
    "src/ffi/atomiculong.rs",
}


def aggregate_files() -> list[Path]:
    """Return the sorted list of paths whose headers should be rewritten.

    The result is ``build.rs`` plus everything matched by the ``TO_INCLUDE``
    glob patterns under ``REPOSITORY``, minus paths with an ignored suffix and
    minus every entry of ``TO_EXCLUDE``.

    Raises:
        KeyError: if a ``TO_EXCLUDE`` entry is not among the collected paths.
    """
    ignored_suffixes = ("py.typed", "__pycache__", ".txt")

    matched = [REPOSITORY / Path("build.rs")]
    matched += (hit for pattern in TO_INCLUDE for hit in REPOSITORY.glob(pattern))

    # De-duplicate while dropping paths that must never receive a header.
    selected = set()
    for path in matched:
        if str(path).endswith(ignored_suffixes):
            continue
        selected.add(path)

    for relative in TO_EXCLUDE:
        selected.remove(REPOSITORY / Path(relative))

    return sorted(selected)


# Prefixes of blamed source lines that get_contributor_and_date() ignores:
# existing copyright/SPDX header lines (in "#" and "//" comment syntax) and
# shebang lines, so they do not count towards anyone's credit.
SKIP_FRAGMENTS = (
    "# Copyright",
    "# SPDX",
    "#!/usr",
    "// Copyright",
    "// SPDX",
)


def get_contributor_and_date(line: str) -> tuple[str, datetime.date] | None:
    if not line or "Not Committed Yet" in line:
        return None

    end = line.index(")")
    diff = line[end + 2 :]
    diff = SPACES.sub(r" ", diff)

    # skip blank and ' };' etc
    if len(diff) <= 3:
        return None

    # skip headers and imports
    for fragment in SKIP_FRAGMENTS:
        if diff.startswith(fragment):
            return None

    line = line[line.index("(") + 1 : end]
    line = SPACES.sub(r" ", line)
    segments = line.split(" ")[0:-2]
    contributor = " ".join(segments[:-1])
    date = datetime.date.fromtimestamp(int(segments[-1])).year
    return (contributor, date)


def process_blame(filename: Path, blame: str) -> list[str | list[tuple[int, str]]]:
    """Derive the license and per-contributor credit for a file from its blame.

    Args:
        filename: path of the blamed file (currently unused).
        blame: complete stdout of ``git blame`` for the file.

    Returns:
        A two-element list ``[file_license, file_credit]``. ``file_license`` is
        an SPDX expression. ``file_credit`` is a list of
        ``(num_lines, "Contributor (year)")`` or
        ``(num_lines, "Contributor (first-last)")`` tuples, sorted descending
        so the contributor with the most attributed lines comes first.
    """
    years_by_contributor = defaultdict(list)
    for line in blame.split("\n"):
        ret = get_contributor_and_date(line)
        if ret is not None:
            years_by_contributor[ret[0]].append(ret[1])

    # Sentinel larger than any real year; stays 9999 when nothing is attributed.
    overall_earliest = 9999
    file_credit = []
    for contributor, years in years_by_contributor.items():
        earliest = min(years)
        latest = max(years)
        # Track the oldest attributed line across all contributors. This must
        # fold in `earliest`: using `latest` would treat a file with old lines
        # as brand new whenever every contributor also touched it recently.
        overall_earliest = min(earliest, overall_earliest)
        num_lines = len(years)
        if earliest == latest:
            file_credit.append((num_lines, f"{contributor} ({earliest})"))
        else:
            file_credit.append((num_lines, f"{contributor} ({earliest}-{latest})"))

    # MPL-2.0 applies when "ijl" is the sole contributor or when no attributed
    # line predates 2026; everything else keeps the dual Apache/MIT license.
    sole_ijl = len(years_by_contributor) == 1 and "ijl" in years_by_contributor
    if sole_ijl or overall_earliest == 2026:
        file_license = LICENSE_MPL2
    else:
        file_license = LICENSE_APACHE

    file_credit.sort(reverse=True)

    return [file_license, file_credit]


def _apply_header(
    contents: list[str], prefix: str, spdx: str, credit: str
) -> list[str]:
    """Return a copy of *contents* with the SPDX and copyright lines set.

    Existing header lines are replaced in place; missing ones are inserted
    after an optional shebang. A blank line is inserted after the leading
    comment block when the next line is not already empty.
    """
    lines = list(contents)

    # Keep a shebang as the very first line. A Rust inner attribute ("#![...]")
    # also starts with "#!" but is not a shebang, so the header goes above it.
    has_shebang = (
        bool(lines)
        and lines[0].startswith("#!")
        and not lines[0].startswith("#![")
    )
    start_idx = 1 if has_shebang else 0

    if start_idx < len(lines) and lines[start_idx].startswith(
        f"{prefix} SPDX-License-Identifier"
    ):
        lines[start_idx] = spdx
    else:
        lines.insert(start_idx, spdx)

    credit_idx = start_idx + 1
    if credit_idx < len(lines) and lines[credit_idx].startswith(
        f"{prefix} Copyright"
    ):
        lines[credit_idx] = credit
    else:
        lines.insert(credit_idx, credit)

    # separate by blank line; the bounds checks cover files that end inside
    # the leading comment block
    first_line = start_idx + 2
    while first_line < len(lines) and lines[first_line].startswith(prefix):
        first_line += 1
    if first_line < len(lines) and lines[first_line]:
        lines.insert(first_line, "")

    return lines


async def handle_file(filename: Path):
    """Blame one file and rewrite its SPDX/copyright header in place.

    Args:
        filename: path of the file to process, located under ``REPOSITORY``.

    The file is left untouched when ``git blame`` fails, when no header could
    be derived, or when it is an ``__init__.py`` with no attributable lines.
    """
    # Pass the path as its own argv entry (no shell) so spaces or shell
    # metacharacters in the path cannot alter the command.
    blame = await asyncio.create_subprocess_exec(
        "git",
        "blame",
        "-C",
        "-C",
        "-C",
        "-M",
        "--date=raw",
        "--",
        str(filename),
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    # communicate() drains both pipes and waits for the child to exit.
    stdout, stderr = await blame.communicate()
    if blame.returncode != 0:
        reason = stderr.decode("utf-8", "replace").strip()
        print(f"{filename.relative_to(REPOSITORY)} skipping: {reason}")
        return

    header = process_blame(filename, stdout.decode("utf-8"))
    # defensive: skip files for which no header could be derived
    if header is None:
        print(f"{filename.relative_to(REPOSITORY)} skipping")
        return
    if not header[1] and str(filename).endswith("__init__.py"):
        return

    if str(filename).endswith(".rs"):
        prefix = "//"
    else:
        prefix = "#"

    spdx = f"{prefix} SPDX-License-Identifier: {header[0]}"
    credit = f"{prefix} Copyright {', '.join(each[1] for each in header[1])}"

    # Read and write as bytes so line endings are not translated.
    contents = filename.read_bytes().decode("utf-8").split("\n")
    contents = _apply_header(contents, prefix, spdx, credit)

    print(f"{filename.relative_to(REPOSITORY)} {spdx}")

    filename.write_bytes("\n".join(contents).encode("utf-8"))


async def main():
    """Process every file returned by ``aggregate_files()`` concurrently.

    One task per file is started inside a single task group, which waits for
    all of them before returning.
    """
    async with asyncio.TaskGroup() as group:
        for path in aggregate_files():
            job = handle_file(path)
            group.create_task(job)


# Run the header rewrite only when executed as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(main())
