from flask import Flask, request, jsonify
from openai import OpenAI
import base64
import io
import hashlib
import textwrap
import ipaddress
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
from japanese_to_kana import *
import traceback

# --- User-tunable parameters -------------------------------------------------

# Base URL and API key handed to the OpenAI-compatible client created below.
ocr_base_url = "http://127.0.0.1:8888/v1"
ocr_api_key = ""
# Model name sent with every OCR chat-completion request (see ocr_image).
model = "nanonets/Nanonets-OCR2-3B"

# Instruction text sent alongside the screenshot in each OCR request.
OCR_PROMPT = "Extract all Japanese text from this image. Preserve line breaks and spacing."
# TCP port the Flask server listens on (see main).
DEFAULT_PORT = 4404

# Passed to generate_overlay_image as max_chars_per_line; there it is only
# used by the textwrap-based wrapping of lines without Japanese characters.
MAX_CHARS_PER_LINE = 60

# --- End of user-tunable parameters ------------------------------------------

app = Flask(__name__)

# Client used by ocr_image to talk to the OCR server.
client = OpenAI(api_key=ocr_api_key, base_url=ocr_base_url)

# Directory holding cached OCR/kana results (<hash>.txt) and the matching
# screenshots (<hash>.png). The path is relative to the working directory.
CACHE_DIR = Path("cache_kana")

# When True, check_local_network rejects requests from non-local addresses.
LOCAL_NETWORK_ONLY = True

# Import-time side effect: make sure the cache directory exists.
CACHE_DIR.mkdir(exist_ok=True)

@app.before_request
def check_local_network():
    """Reject requests from non-local clients while LOCAL_NETWORK_ONLY is set.

    Returns:
        None to let the request continue, or a ``(JSON response, 403)`` pair
        when the client address fails ``is_local_network``.
    """
    if not LOCAL_NETWORK_ONLY:
        return None

    remote_ip = request.remote_addr
    if is_local_network(remote_ip):
        return None

    print(f"\nBLOCKED: Non-local IP address attempted connection: {remote_ip}")
    denial = {
        'error': 'Access denied: This server only accepts connections from local network'
    }
    return jsonify(denial), 403

def read_from_cache(image_hash):
    """Load the cached OCR text and kana text for an image hash.

    The cache file is ``CACHE_DIR/<image_hash>.txt``. A line starting with
    ``japanese:`` opens the Japanese section and a line starting with
    ``kana:`` opens the kana section; every following line belongs to the
    section that is currently open. Lines before the first marker are ignored.
    Each line is whitespace-stripped, and empty lines at the start and end of
    a section are dropped, so the trailing newline written by
    ``write_to_cache`` does not come back as an extra blank line.

    Note: a text line that itself begins with one of the markers is
    indistinguishable from a section header in this format.

    Args:
        image_hash: Hash string used as the cache file stem.

    Returns:
        ``(japanese_text, kana_text)`` on a hit, or ``(None, None)`` when the
        file is missing, unreadable, or either section is absent or empty.
    """
    cache_file = CACHE_DIR / f"{image_hash}.txt"

    if not cache_file.exists():
        return None, None

    # Only the read itself can fail; the parsing below works on plain strings.
    try:
        content = cache_file.read_text(encoding='utf-8')
    except (OSError, UnicodeError) as e:
        print(f"Warning: Error reading cache file {cache_file}: {e}")
        return None, None

    markers = ('japanese:', 'kana:')
    sections = {marker: [] for marker in markers}
    current = None

    for line in content.split('\n'):
        for marker in markers:
            if line.startswith(marker):
                current = marker
                line = line[len(marker):]
                break
        if current is not None:
            sections[current].append(line.strip())

    # Every collected line is already stripped, so stripping '\n' from the
    # joined text removes exactly the leading/trailing empty lines.
    japanese_text = '\n'.join(sections['japanese:']).strip('\n')
    kana_text = '\n'.join(sections['kana:']).strip('\n')

    if not japanese_text or not kana_text:
        return None, None

    return japanese_text, kana_text

def write_to_cache(image_hash, japanese_text, kana_text):
    """Store the OCR text and its kana rendering in ``CACHE_DIR/<image_hash>.txt``.

    The file holds a ``japanese: ...`` record followed by a ``kana: ...``
    record, each terminated by a newline. Any failure is reported on stdout
    and otherwise ignored (best-effort cache).
    """
    target = CACHE_DIR / f"{image_hash}.txt"

    try:
        with target.open('w', encoding='utf-8') as handle:
            handle.write(f"japanese: {japanese_text}\n")
            handle.write(f"kana: {kana_text}\n")

        print(f"  Cached to: {target.name}")

    except Exception as e:
        print(f"Warning: Error writing cache file {target}: {e}")

def process_japanese_to_kana(ocr_text):
    """Convert OCR output to kana, one line at a time.

    Blank lines are skipped. Each remaining line is stripped, passed through
    ``japanese_to_kana_mixed_spaced``, has double spaces replaced by single
    ones, and is stripped again. If the conversion raises, a warning is
    printed and the stripped source line is kept instead.

    Returns:
        The converted lines joined with newlines.
    """
    converted = []

    for raw_line in ocr_text.split('\n'):
        source = raw_line.strip()
        if not source:
            continue

        try:
            reading = japanese_to_kana_mixed_spaced(source).replace("  ", " ").strip()
        except Exception as e:
            print(f"Warning: Error converting line to kana: {e}")
            # Fall back to the untouched source line rather than dropping it.
            converted.append(source)
            continue

        if reading:
            converted.append(reading)

    return '\n'.join(converted)

def generate_overlay_image(text, width=1920, height=1080, position=1, max_chars_per_line=60):
    """Render text as a subtitle-style overlay and return it as a base64 PNG.

    A fully transparent RGBA canvas of ``width`` x ``height`` is created. The
    text is wrapped, a semi-transparent black box is drawn behind it, and each
    line is drawn horizontally centred in white with a black outline.

    Args:
        text: Text to render; lines are separated by newline characters.
        width: Canvas width in pixels.
        height: Canvas height in pixels; also drives the font size
            (``max(24, height / 30)``).
        position: ``2`` places the box at the top; any other value places it
            at the bottom.
        max_chars_per_line: Column limit used when wrapping an over-wide line
            that contains no Japanese characters.

    Returns:
        The PNG encoded as a base64 string, or None when ``text`` contains
        nothing but whitespace (there is nothing to draw).
    """
    # Splitting always yields at least one line, so an "is the line list
    # empty" check can never fire; test for visible content up front instead.
    if not text.strip():
        return None

    img = Image.new('RGBA', (width, height), (0, 0, 0, 0))
    draw = ImageDraw.Draw(img)

    # Computed outside any try block so it is always bound when the
    # blank-line height below needs it.
    font_size = max(24, int(height / 30))

    font_paths = [
        "NotoSansCJKBold.otf"
    ]

    font = None
    for font_path in font_paths:
        try:
            font = ImageFont.truetype(font_path, font_size)
            break
        except Exception:
            # Best effort: try the next candidate. Unlike a bare except, this
            # lets KeyboardInterrupt/SystemExit propagate.
            continue

    if font is None:
        print(f"Warning: Could not load any font from {font_paths}; using PIL default font")
        font = ImageFont.load_default()

    def contains_japanese(candidate):
        # Hiragana, Katakana, CJK Unified Ideographs, CJK Extension A.
        return any(
            0x3040 <= ord(ch) <= 0x309F
            or 0x30A0 <= ord(ch) <= 0x30FF
            or 0x4E00 <= ord(ch) <= 0x9FFF
            or 0x3400 <= ord(ch) <= 0x4DBF
            for ch in candidate
        )

    def wrap_text_by_width(line_text, max_width, font, draw):
        if not line_text.strip():
            return ['']

        if not contains_japanese(line_text):
            # Non-Japanese lines are wrapped by character count, not pixels.
            return textwrap.wrap(
                line_text,
                width=max_chars_per_line,
                break_long_words=False,
                break_on_hyphens=False,
            )

        # Japanese text has no reliable word boundaries: grow the current
        # segment one character at a time and cut when it would get too wide.
        segments = []
        current_segment = ''
        for char in line_text:
            bbox = draw.textbbox((0, 0), current_segment + char, font=font)
            if bbox[2] - bbox[0] > max_width:
                if current_segment:
                    segments.append(current_segment)
                current_segment = char
            else:
                current_segment += char

        if current_segment:
            segments.append(current_segment)

        return segments

    # Widest a single line may be before it gets wrapped (40 px per side).
    max_text_width = width - (40 * 2)

    wrapped_lines = []
    for line in text.split('\n'):
        if not line.strip():
            wrapped_lines.append(line)
            continue

        bbox = draw.textbbox((0, 0), line, font=font)
        if bbox[2] - bbox[0] > max_text_width:
            wrapped_lines.extend(wrap_text_by_width(line, max_text_width, font, draw))
        else:
            wrapped_lines.append(line)

    padding = 20
    line_spacing = 5
    line_heights = []
    line_widths = []

    for line in wrapped_lines:
        if line.strip():
            bbox = draw.textbbox((0, 0), line, font=font)
            line_widths.append(bbox[2] - bbox[0])
            line_heights.append(bbox[3] - bbox[1])
        else:
            # Blank lines take half a font size of vertical space.
            line_widths.append(0)
            line_heights.append(font_size // 2)

    total_text_height = sum(line_heights) + (len(line_heights) - 1) * line_spacing
    max_text_width = max(line_widths) if line_widths else 0
    box_width = min(max_text_width + (padding * 2), width - (padding * 2))
    box_height = total_text_height + (padding * 2)

    if position == 2:
        box_y = padding
    else:
        box_y = height - box_height - padding

    box_x = (width - box_width) // 2

    draw.rectangle(
        [box_x, box_y, box_x + box_width, box_y + box_height],
        fill=(0, 0, 0, 180)
    )

    # Every offset in a 5x5 neighbourhood except the centre; drawing the text
    # in black at each of them produces the outline.
    outline_range = 2
    outline_offsets = [
        (ox, oy)
        for ox in range(-outline_range, outline_range + 1)
        for oy in range(-outline_range, outline_range + 1)
        if ox != 0 or oy != 0
    ]

    current_y = box_y + padding
    for line, line_width, line_height in zip(wrapped_lines, line_widths, line_heights):
        if not line.strip():
            current_y += line_height
            continue

        text_x = box_x + (box_width - line_width) // 2

        for ox, oy in outline_offsets:
            draw.text((text_x + ox, current_y + oy), line, font=font, fill=(0, 0, 0, 255))

        draw.text((text_x, current_y), line, font=font, fill=(255, 255, 255, 255))

        current_y += line_height + line_spacing

    buffer = io.BytesIO()
    img.save(buffer, format='PNG')
    img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

    return img_base64

@app.route('/', methods=['GET', 'POST'])
@app.route('/translate', methods=['POST'])
def handle_request():
    """Entry point for ``/`` and ``/translate``.

    A GET request is answered with an empty body; every other request is
    handed to ``translate``.
    """
    is_get = request.method == 'GET'
    return "" if is_get else translate()

def translate():
    """Process one screenshot request and reply with kana text and/or an overlay.

    Reads a JSON body containing a base64 ``image`` (plus optional ``format``,
    ``viewport`` and ``state``) and the query parameters ``source_lang``,
    ``target_lang`` and ``output``. The image is hashed, looked up in the
    cache, and on a miss sent to OCR, converted to kana and cached.

    Returns:
        A JSON response. On success it carries ``auto`` plus ``image`` and/or
        ``text``/``text_position`` depending on the requested output formats.
        Client errors (missing or unparsable JSON, missing or invalid image
        data) yield ``{'error': ...}`` with status 400; any other failure
        yields ``{'error': ...}`` with status 500.
    """
    try:
        # silent=True turns an unparsable body into None so it is reported as
        # a 400 below instead of surfacing as a 500 from the handler at the end.
        data = request.get_json(force=True, silent=True)

        if not data:
            return jsonify({
                'error': 'No JSON data provided in request body'
            }), 400

        if not isinstance(data, dict) or 'image' not in data:
            return jsonify({
                'error': 'No image data provided in JSON payload'
            }), 400

        image_base64 = data['image']
        if not isinstance(image_base64, str):
            return jsonify({
                'error': 'Image data must be a base64-encoded string'
            }), 400

        image_format = data.get('format', 'png')

        source_lang = request.args.get('source_lang', 'ja')
        target_lang = request.args.get('target_lang', 'en')
        output_format = request.args.get('output', 'text')

        output_formats = [fmt.strip() for fmt in output_format.split(',')]
        wants_image = 'image' in output_formats or any(fmt in ['png', 'png-a', 'bmp'] for fmt in output_formats)
        wants_text = 'text' in output_formats or 'subs' in output_formats

        print(f"\n{'='*60}")
        print(f"Processing RetroArch request:")
        print(f"  Source Language: {source_lang}")
        print(f"  Target Language: {target_lang}")
        print(f"  Output Formats: {output_formats}")
        print(f"  Wants Image: {wants_image}")
        print(f"  Wants Text: {wants_text}")
        print(f"  Image Format: {image_format}")
        print(f"  Image Size: {len(image_base64)} bytes (base64)")

        # Fall back per component so a missing, short or malformed viewport
        # cannot raise; defaults are 1920x1080.
        viewport = data.get('viewport')
        if not isinstance(viewport, (list, tuple)):
            viewport = []
        viewport_width = viewport[0] if len(viewport) > 0 else 1920
        viewport_height = viewport[1] if len(viewport) > 1 else 1080
        print(f"  Viewport: {viewport_width}x{viewport_height}")

        state = data.get('state')
        is_paused = state.get('paused', 0) if isinstance(state, dict) else 0
        print(f"  Game Paused: {bool(is_paused)}")

        print(f"\n  Calculating image hash...")
        try:
            image_hash = calculate_image_hash(image_base64)
        except ValueError:
            # base64 decoding errors (binascii.Error) are ValueError subclasses.
            return jsonify({
                'error': 'Image data is not valid base64'
            }), 400
        print(f"  Image SHA256: {image_hash}")

        print(f"  Checking cache...")
        cached_japanese, cached_kana = read_from_cache(image_hash)

        if cached_japanese is not None and cached_kana is not None:
            print(f"  Cache HIT! Using cached results")
            print(f"  Cached Japanese: {cached_japanese[:100]}{'...' if len(cached_japanese) > 100 else ''}")
            print(f"  Cached kana: {cached_kana[:100]}{'...' if len(cached_kana) > 100 else ''}")

            ocr_result = cached_japanese
            kana_text = cached_kana
        else:
            print(f"  Cache MISS - Performing OCR...")

            # Treat a missing OCR result as "no text" instead of crashing on
            # the slicing below.
            ocr_result = ocr_image(image_base64, OCR_PROMPT) or ""

            ocr_preview = ocr_result[:200].replace('\n', ' ')
            print(f"  OCR Result: {ocr_preview}{'...' if len(ocr_result) > 200 else ''}")

            print(f"  Converting to kana...")
            kana_text = process_japanese_to_kana(ocr_result)

            kana_preview = kana_text[:200].replace('\n', ' ')
            print(f"  kana Result: {kana_preview}{'...' if len(kana_text) > 200 else ''}")

            write_to_cache(image_hash, ocr_result, kana_text)
            save_image_to_cache(image_hash, image_base64)

        print(f"{'='*60}\n")

        response = {}

        if not kana_text.strip():
            # Nothing recognised: only a text-only request gets a placeholder
            # message; image requests get no overlay at all.
            if wants_text and not wants_image:
                response['text'] = '[No Japanese text detected]'
                response['text_position'] = 1
            response['auto'] = 'continue'

            return jsonify(response)

        if wants_image:
            print(f"  Generating overlay image...")
            overlay_img = generate_overlay_image(
                kana_text,
                width=viewport_width,
                height=viewport_height,
                position=1,
                max_chars_per_line=MAX_CHARS_PER_LINE
            )
            if overlay_img:
                response['image'] = overlay_img
                print(f"  Generated overlay image: {len(overlay_img)} bytes (base64)")

        if wants_text or not wants_image:
            response['text'] = kana_text
            response['text_position'] = 1

        response['auto'] = 'auto'

        return jsonify(response)

    except Exception as e:
        # Top-level boundary: log the traceback and report a 500 to the client.
        error_msg = f'Error processing image: {str(e)}'
        print(f"\nERROR: {error_msg}")
        print(traceback.format_exc())

        return jsonify({
            'error': error_msg
        }), 500

@app.route('/health', methods=['GET'])
def health():
    """Report service status and the configured OCR backend as JSON.

    Returns:
        A JSON response with ``status``, ``service``, ``ocr_server`` and
        ``ocr_model``.
    """
    return jsonify({
        'status': 'ok',
        'service': 'RetroArch AI Translation Service (Japanese to Kana)',
        # The OpenAI client exposes base_url as a URL object rather than a
        # plain string; str() keeps the payload JSON-serializable either way.
        'ocr_server': str(client.base_url),
        'ocr_model': model
    })

def is_local_network(ip_str):
    """Return True when ``ip_str`` is a private, loopback or link-local address.

    IPv4-mapped IPv6 addresses (``::ffff:a.b.c.d``) are classified by the IPv4
    address they embed. Without that step, Python versions before 3.13 report
    every address in ``::ffff:0:0/96`` as private, which would let a public
    address such as ``::ffff:8.8.8.8`` pass as local.

    Args:
        ip_str: Textual IPv4 or IPv6 address. Anything that does not parse as
            an IP address (including None) is treated as non-local.

    Returns:
        bool: True for local addresses, False otherwise.
    """
    try:
        ip = ipaddress.ip_address(ip_str)
    except ValueError:
        return False

    # Only IPv6 addresses carry ipv4_mapped; it is None unless the address is
    # an IPv4-mapped one.
    mapped = getattr(ip, 'ipv4_mapped', None)
    if mapped is not None:
        ip = mapped

    return (
        ip.is_private or
        ip.is_loopback or
        ip.is_link_local
    )

def ocr_image(img_base64, prompt_text):
    """Send a base64 image plus a text prompt to the OCR model and return its reply.

    The image is embedded as a ``data:image/png;base64,...`` URL (the PNG
    media type is fixed here regardless of the actual image format) in a
    single user message, followed by ``prompt_text``. The request uses
    ``temperature=0.0`` and ``max_tokens=15000``.

    Args:
        img_base64: Base64-encoded image data.
        prompt_text: Instruction text sent with the image.

    Returns:
        The text content of the first choice, or an empty string when the
        server returns no choices or no content.

    Raises:
        Whatever the client raises on request failure; nothing is caught here.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{img_base64}"},
                    },
                    {
                        "type": "text",
                        "text": prompt_text,
                    },
                ],
            }
        ],
        temperature=0.0,
        max_tokens=15000
    )

    # Normalise "nothing came back" to an empty string so callers can slice
    # and strip the result without checking for None.
    if not response.choices:
        return ""
    return response.choices[0].message.content or ""

def calculate_image_hash(img_base64):
    """Return the SHA-256 hex digest of the bytes encoded in ``img_base64``."""
    hasher = hashlib.sha256()
    hasher.update(base64.b64decode(img_base64))
    return hasher.hexdigest()

def save_image_to_cache(image_hash, img_base64):
    """Decode a base64 image and save it as ``CACHE_DIR/<image_hash>.png``.

    The decoded bytes are opened with PIL and re-saved in PNG format. Any
    failure is reported on stdout and otherwise ignored (best-effort archive).
    """
    destination = CACHE_DIR / f"{image_hash}.png"

    try:
        decoded = io.BytesIO(base64.b64decode(img_base64))
        Image.open(decoded).save(destination, format='PNG')

        print(f"  Saved image: {destination.name}")

    except Exception as e:
        print(f"Warning: Error saving image file {destination}: {e}")

def main():
    """Print the startup banner and run the Flask server until interrupted.

    The server binds to ``0.0.0.0`` on ``DEFAULT_PORT`` with threading
    enabled. ``KeyboardInterrupt`` ends the run with a shutdown message; any
    other exception from ``app.run`` is reported together with a hint about
    the configured port.
    """
    print("="*70)
    print("RetroArch AI Translation Service - Japanese to Kana")
    print("By baka-neko / neko.works")
    print("="*70)
    print("\n Server Configuration:")
    print(f"   Endpoint: http://localhost:{DEFAULT_PORT}/ (or /translate)")
    print(f"   OCR Server: {client.base_url}")
    print(f"   OCR Model: {model}")

    print("\n Security Configuration:")
    if LOCAL_NETWORK_ONLY:
        print("   Local Network Only: ENABLED")
        print("   Allowed networks: 192.168.x.x, 10.x.x.x, 172.16-31.x.x, 127.x.x.x")
        print("   External access: BLOCKED")
    else:
        print("   Local Network Only: DISABLED")
        print("   WARNING: Server accessible from any network!")
        print("   Set LOCAL_NETWORK_ONLY = True for security")

    print("\n   RetroArch Setup Instructions:")
    print("   1. Go to Settings -> AI Service")
    print("   2. Set 'AI Service Enabled' to ON")
    print(f"   3. Set 'AI Service URL' to: http://localhost:{DEFAULT_PORT}/ (or /translate)")
    print("   4. Set 'AI Service Output' to: Image Mode")
    print("   5. Set 'Source Language' to: Japanese")
    print("   6. Assign a hotkey in Settings -> Input -> Hotkey Binds")

    print("\n   Usage:")
    print("   - Press the AI Service hotkey once to START translation")

    print("\n   Requirements:")
    print(f"   - Nanonets OCR server must be running at {client.base_url}")
    print("   - pykakasi must be installed for kana conversion")
    print("   - flask must be installed for web server")

    print("\n   Features:")
    print("   - SHA256-based caching (10-50x faster for repeated frames)")
    print("   - Screenshot archiving (.png files saved to cache_kana/)")
    print(f"   - Intelligent word wrapping (max {MAX_CHARS_PER_LINE} chars/line)")

    print(f"\n{'='*70}")
    print(f"Starting server on 0.0.0.0:{DEFAULT_PORT}...")
    if LOCAL_NETWORK_ONLY:
        print("  Security: Local network access only")
    else:
        # The original literal began with a stray variation-selector
        # character; it now matches the sibling message's two-space indent.
        print("  Security: Open to all networks (NOT RECOMMENDED)")
    print(f"{'='*70}\n")

    try:
        app.run(host='0.0.0.0', port=DEFAULT_PORT, debug=False, threaded=True)
    except KeyboardInterrupt:
        print("\n\nShutting down server...")
    except Exception as e:
        print(f"\n\nError starting server: {e}")
        # Report the configured port rather than a hard-coded number.
        print(f"Make sure port {DEFAULT_PORT} is not already in use.")

# Start the server only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()