#!/usr/bin/env python3
import sys, argparse
from PIL import Image, ImageEnhance
import numpy as np, cv2

def to_1bit(pil_img):
    return pil_img.convert("1")

def mode_edge(arr):
    blur = cv2.GaussianBlur(arr, (3,3), 0)
    edges = cv2.Canny(blur, 60, 180)            # viền
    edges = cv2.bitwise_not(edges)              # nền trắng, nét đen
    return Image.fromarray(edges)

def mode_threshold(arr, thr=None):
    # Otsu tự động nếu không truyền thr
    if thr is None:
        _, mask = cv2.threshold(arr, 150, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
    else:
        _, mask = cv2.threshold(arr, int(thr), 255, cv2.THRESH_BINARY)
    return Image.fromarray(mask)

def mode_adaptive(arr):
    mask = cv2.adaptiveThreshold(arr, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                 cv2.THRESH_BINARY, 21, 5)
    return Image.fromarray(mask)

def mode_dither(pil_gray, contrast=1.0):
    if contrast != 1.0:
        pil_gray = ImageEnhance.Contrast(pil_gray).enhance(float(contrast))
    # Floyd-Steinberg dither chuẩn của Pillow khi convert("1")
    return pil_gray.convert("1")

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("inp")
    ap.add_argument("out")
    ap.add_argument("--mode", choices=["edge","threshold","adaptive","dither"], default="threshold")
    ap.add_argument("--thr", type=int, help="ngưỡng cố định cho mode=threshold (0..255); bỏ để Otsu")
    ap.add_argument("--contrast", type=float, default=1.0, help="độ tương phản cho dither")
    args = ap.parse_args()

    pil = Image.open(args.inp).convert("L")       # giữ nguyên kích thước
    arr = np.array(pil)

    if args.mode == "edge":
        out = mode_edge(arr)
        out = to_1bit(out)                        # chốt về 1-bit
    elif args.mode == "threshold":
        out = mode_threshold(arr, args.thr)
        out = to_1bit(out)
    elif args.mode == "adaptive":
        out = mode_adaptive(arr)
        out = to_1bit(out)
    else:  # dither
        out = mode_dither(pil, args.contrast)

    out.save(args.out, optimize=True)

if __name__ == "__main__":
    main()
