#!/usr/bin/env python3
"""Auto-select a preferred VMess domain from a JSON API feed.

Pipeline: fetch payload -> extract candidate domains -> filter -> optionally
score/rank structured records -> optionally TLS health-check -> pick a domain,
write it to runtime files, render a v2ray template and run a notify hook.

NOTE(review): this file arrived whitespace-mangled (all code on a few physical
lines) and one chunk was missing entirely (the regex constants and several
small I/O helpers). Formatting has been restored and the missing pieces
reconstructed from their call sites; every reconstruction is marked with a
``NOTE(review): reconstructed`` comment — confirm them against the original.
"""

import argparse
import datetime as dt
import functools
import json
import math
import os
import re
import socket
import ssl
import subprocess
import sys
import time
import urllib.parse
import urllib.request

# RFC-1035-style hostname: 1-63 char labels, no leading/trailing hyphen,
# total length <= 253, at least two labels.
# NOTE(review): only the prefix of this pattern survived in the mangled
# source; the tail is reconstructed — confirm against the original.
DOMAIN_RE = re.compile(
    r"^(?=.{1,253}$)(?!-)[A-Za-z0-9-]{1,63}(?<!-)"
    r"(?:\.(?!-)[A-Za-z0-9-]{1,63}(?<!-))+$"
)

# Dotted-quad IPv4 literal with 0-255 octets.
# NOTE(review): reconstructed — the definition was missing from the source.
IPV4_RE = re.compile(
    r"^(?:(?:25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)\.){3}"
    r"(?:25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)$"
)

# Sentinel so read_json_file can distinguish "no default" from default=None.
_MISSING = object()


def read_json_file(path, default=_MISSING):
    """Load JSON from *path*.

    On any failure, return *default* when one was supplied, otherwise
    re-raise.  NOTE(review): reconstructed from call sites — main() reads the
    config with no default (must raise) and the state file with default={}.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        if default is _MISSING:
            raise
        return default


def write_json_file(path, obj):
    """Write *obj* as indented ASCII JSON, creating parent directories.

    NOTE(review): reconstructed — missing from the mangled source.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=True, indent=2)
        f.write("\n")


def write_text_file(path, text):
    """Write *text* to *path*, creating parent directories.

    NOTE(review): reconstructed — missing from the mangled source.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)


def utc_now_iso():
    """Return the current UTC time as an ISO-8601 string.

    NOTE(review): reconstructed — missing from the mangled source.
    """
    return dt.datetime.now(dt.timezone.utc).isoformat()


def fetch_api_json(cfg):
    """Fetch and decode the JSON payload described by ``cfg["api"]``.

    Supports ``url`` (required), optional ``params`` (appended as a query
    string), ``headers`` and ``timeout_sec``.
    NOTE(review): reconstructed — missing from the mangled source; confirm
    the exact config keys against the original.
    """
    api_cfg = cfg.get("api", {})
    url = str(api_cfg.get("url", "")).strip()
    if not url:
        raise ValueError("api.url is required")
    params = api_cfg.get("params") or {}
    if params:
        sep = "&" if "?" in url else "?"
        url = url + sep + urllib.parse.urlencode(params)
    headers = {str(k): str(v) for k, v in (api_cfg.get("headers") or {}).items()}
    timeout_sec = float(api_cfg.get("timeout_sec", 10))
    req = urllib.request.Request(url, headers=headers)
    with urllib.request.urlopen(req, timeout=timeout_sec) as resp:
        body = resp.read()
    return json.loads(body.decode("utf-8"))


def flatten_values(value):
    """Flatten nested lists/tuples/dicts into a flat list of scalar values.

    ``None`` disappears; a bare scalar becomes a one-element list.
    NOTE(review): reconstructed from call sites (parse_domains and
    rule_matches treat the result as a flat list of scalars).
    """
    if isinstance(value, (list, tuple)):
        out = []
        for item in value:
            out.extend(flatten_values(item))
        return out
    if isinstance(value, dict):
        out = []
        for item in value.values():
            out.extend(flatten_values(item))
        return out
    if value is None:
        return []
    return [value]


def get_by_json_path(data, path):
    """Resolve a dotted path (``"a.b.c"``) through nested dicts.

    Returns ``None`` when any segment is absent.
    NOTE(review): reconstructed — missing from the mangled source.
    """
    cur = data
    for part in (p for p in str(path).split(".") if p):
        if isinstance(cur, dict) and part in cur:
            cur = cur[part]
        else:
            return None
    return cur


def get_values_by_path(data, path):
    """Collect all values at a dotted path; a ``key[]`` segment fans out
    over every element of the list stored under ``key``.

    Returns ``[]`` when the path does not resolve.
    NOTE(review): the function header was lost in the mangled source and is
    reconstructed; the ``walk`` body below is the original logic.
    """
    parts = [p for p in str(path).split(".") if p]

    def walk(cur, idx):
        if idx >= len(parts):
            return [cur]
        part = parts[idx]
        if part.endswith("[]"):
            key = part[:-2]
            if isinstance(cur, dict):
                arr = cur.get(key)
            else:
                arr = None
            if not isinstance(arr, list):
                return []
            out = []
            for item in arr:
                out.extend(walk(item, idx + 1))
            return out
        if isinstance(cur, dict) and part in cur:
            return walk(cur[part], idx + 1)
        return []

    return walk(data, 0)


def parse_domains(payload, parser_cfg):
    """Extract, normalize and de-duplicate domain/IPv4 strings from *payload*.

    Tries configured ``field_paths`` and ``json_paths`` first; if neither
    yields anything, regex-scans the serialized payload as a last resort.
    Only strings matching DOMAIN_RE or IPV4_RE survive, lowercased and with
    any trailing dot removed, in first-seen order.
    """
    domains = []
    for p in parser_cfg.get("field_paths", []):
        values = get_values_by_path(payload, p)
        domains.extend(flatten_values(values))
    for p in parser_cfg.get("json_paths", []):
        v = get_by_json_path(payload, p)
        if v is not None:
            domains.extend(flatten_values(v))
    if not domains:
        # Fallback: scan the whole JSON text with a (configurable) regex.
        regex_s = parser_cfg.get("regex", r"[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
        text = json.dumps(payload, ensure_ascii=True)
        domains.extend(re.findall(regex_s, text))
    clean = []
    seen = set()
    for d in domains:
        d = str(d).strip().lower().rstrip(".")
        if (DOMAIN_RE.match(d) or IPV4_RE.match(d)) and d not in seen:
            seen.add(d)
            clean.append(d)
    return clean


def parse_timezone(tz_raw):
    """Parse a timezone spec like ``UTC``, ``Z``, ``+08:00`` or ``-0530``
    into a ``dt.timezone``.  Raises ValueError on anything else.
    """
    if tz_raw is None:
        return dt.timezone.utc
    s = str(tz_raw).strip().upper()
    if s in {"", "UTC", "Z", "+00:00", "+0000"}:
        return dt.timezone.utc
    m = re.match(r"^([+-])(\d{2}):?(\d{2})$", s)
    if not m:
        raise ValueError(f"invalid created_time_timezone: {tz_raw}")
    sign = 1 if m.group(1) == "+" else -1
    hh = int(m.group(2))
    mm = int(m.group(3))
    if hh > 23 or mm > 59:
        raise ValueError(f"invalid created_time_timezone offset: {tz_raw}")
    return dt.timezone(sign * dt.timedelta(hours=hh, minutes=mm))


def parse_created_time(value, formats, timezone):
    """Parse a timestamp using the configured strptime *formats*, falling
    back to ISO-8601.  Naive results get *timezone* attached; the return
    value is always UTC-normalized, or ``None`` when unparseable/empty.
    """
    if value is None:
        return None
    s = str(value).strip()
    if not s:
        return None
    for fmt in formats:
        try:
            parsed = dt.datetime.strptime(s, fmt)
            if parsed.tzinfo is None:
                parsed = parsed.replace(tzinfo=timezone)
            return parsed.astimezone(dt.timezone.utc)
        except Exception:
            continue
    try:
        # fromisoformat() (pre-3.11) cannot parse a trailing "Z".
        iso_text = s.replace("Z", "+00:00")
        parsed = dt.datetime.fromisoformat(iso_text)
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone)
        return parsed.astimezone(dt.timezone.utc)
    except Exception:
        return None


def normalize_domain(value):
    """Lowercase, strip whitespace and drop a trailing dot; '' for None."""
    if value is None:
        return ""
    return str(value).strip().lower().rstrip(".")


def to_float_or_none(value):
    """Coerce to a finite float, or return ``None`` (NaN/inf count as None)."""
    try:
        f = float(value)
        if not math.isfinite(f):
            return None
        return f
    except Exception:
        return None


def resolve_field(record, field_name, field_map):
    """Look up *field_name* inside *record* via the registered path in
    *field_map*.  Raises ValueError for unregistered fields; returns None
    for non-dict records or missing paths.
    """
    path = field_map.get(field_name)
    if not path:
        raise ValueError(
            f"field '{field_name}' is not registered in record_mapping.field_map"
        )
    if not isinstance(record, dict):
        return None
    return get_by_json_path(record, path)


def extract_records(payload, record_mapping):
    """Return the dict records found at ``record_mapping.records_path``."""
    records_path = str(record_mapping.get("records_path", "")).strip()
    raw = get_values_by_path(payload, records_path)
    return [x for x in raw if isinstance(x, dict)]


def validate_config(cfg):
    """Validate the record_mapping / record_filter / scoring sections of
    *cfg*, raising ValueError with a precise message on the first problem.
    """
    record_mapping = cfg.get("record_mapping")
    if not isinstance(record_mapping, dict):
        raise ValueError("record_mapping is required and must be an object")
    records_path = str(record_mapping.get("records_path", "")).strip()
    if not records_path:
        raise ValueError("record_mapping.records_path is required")
    field_map = record_mapping.get("field_map")
    if not isinstance(field_map, dict) or not field_map:
        raise ValueError(
            "record_mapping.field_map is required and must be a non-empty object"
        )
    for key, path in field_map.items():
        if not str(key).strip() or not str(path).strip():
            raise ValueError(
                "record_mapping.field_map contains empty field name or path"
            )
    for required in ["domain", "created_at"]:
        if required not in field_map:
            raise ValueError(f"record_mapping.field_map.{required} is required")
    created_time_formats = record_mapping.get("created_time_formats")
    if not isinstance(created_time_formats, list) or not created_time_formats:
        raise ValueError(
            "record_mapping.created_time_formats is required and must be a non-empty array"
        )
    for fmt in created_time_formats:
        if not str(fmt).strip():
            raise ValueError("record_mapping.created_time_formats contains empty format")
    # Raises ValueError itself if the timezone spec is malformed.
    parse_timezone(record_mapping.get("created_time_timezone", "UTC"))

    def ensure_field_registered(field_name, where):
        # Every field referenced by filters/scoring must exist in field_map.
        if field_name not in field_map:
            raise ValueError(
                f"{where}: field '{field_name}' is not in record_mapping.field_map"
            )

    record_filter = cfg.get("record_filter", {})
    if record_filter.get("enabled", False):
        rules = record_filter.get("exclude_if_any", [])
        if not isinstance(rules, list):
            raise ValueError("record_filter.exclude_if_any must be an array")
        for i, rule in enumerate(rules):
            if not isinstance(rule, dict):
                raise ValueError(f"record_filter.exclude_if_any[{i}] must be an object")
            field_name = str(rule.get("field", "")).strip()
            if not field_name:
                raise ValueError(f"record_filter.exclude_if_any[{i}].field is required")
            ensure_field_registered(field_name, f"record_filter.exclude_if_any[{i}]")
            has_matcher = any(k in rule for k in ["contains", "equals", "regex"])
            if not has_matcher:
                raise ValueError(
                    f"record_filter.exclude_if_any[{i}] must include one of contains/equals/regex"
                )

    scoring = cfg.get("scoring", {})
    if scoring.get("enabled", False):
        strategy = str(scoring.get("strategy", "")).strip()
        if strategy not in {"weighted_average", "lexicographic"}:
            raise ValueError(
                "scoring.strategy must be 'weighted_average' or 'lexicographic'"
            )
        within_hours = to_float_or_none(scoring.get("within_hours", 24))
        if within_hours is None or within_hours <= 0:
            raise ValueError("scoring.within_hours must be a positive number")
        if strategy == "weighted_average":
            weighted_fields = scoring.get("weighted_fields")
            if not isinstance(weighted_fields, list) or not weighted_fields:
                raise ValueError(
                    "scoring.weighted_fields is required for weighted_average strategy"
                )
            for i, item in enumerate(weighted_fields):
                if not isinstance(item, dict):
                    raise ValueError(f"scoring.weighted_fields[{i}] must be an object")
                field_name = str(item.get("field", "")).strip()
                if not field_name:
                    raise ValueError(f"scoring.weighted_fields[{i}].field is required")
                ensure_field_registered(field_name, f"scoring.weighted_fields[{i}]")
                weight = to_float_or_none(item.get("weight"))
                if weight is None or weight <= 0:
                    raise ValueError(f"scoring.weighted_fields[{i}].weight must be > 0")
        if strategy == "lexicographic":
            lex_fields = scoring.get("lexicographic_fields")
            if not isinstance(lex_fields, list) or not lex_fields:
                raise ValueError(
                    "scoring.lexicographic_fields is required for lexicographic strategy"
                )
            for i, item in enumerate(lex_fields):
                # Entries may be bare field names or {field, order} objects.
                if isinstance(item, str):
                    field_name = item.strip()
                    order = ""
                elif isinstance(item, dict):
                    field_name = str(item.get("field", "")).strip()
                    order = str(item.get("order", "")).strip().lower()
                else:
                    raise ValueError(
                        f"scoring.lexicographic_fields[{i}] must be string or object"
                    )
                if not field_name:
                    raise ValueError(f"scoring.lexicographic_fields[{i}] field is required")
                ensure_field_registered(field_name, f"scoring.lexicographic_fields[{i}]")
                if order and order not in {"asc", "desc"}:
                    raise ValueError(
                        f"scoring.lexicographic_fields[{i}].order must be asc or desc"
                    )
        tie_breakers = scoring.get("tie_breakers", [])
        if tie_breakers is not None:
            if not isinstance(tie_breakers, list):
                raise ValueError("scoring.tie_breakers must be an array")
            for i, item in enumerate(tie_breakers):
                if not isinstance(item, dict):
                    raise ValueError(f"scoring.tie_breakers[{i}] must be an object")
                field_name = str(item.get("field", "")).strip()
                order = str(item.get("order", "")).strip().lower()
                if not field_name:
                    raise ValueError(f"scoring.tie_breakers[{i}].field is required")
                if order not in {"asc", "desc"}:
                    raise ValueError(f"scoring.tie_breakers[{i}].order must be asc or desc")
                ensure_field_registered(field_name, f"scoring.tie_breakers[{i}]")


def rule_matches(value, rule):
    """Return True when *value* matches an exclude rule.

    A rule carries exactly one matcher — ``contains``, ``equals`` or
    ``regex`` — plus an optional ``case_sensitive`` flag (default False).
    List/nested values match if ANY flattened element matches.
    """
    if value is None or not isinstance(rule, dict):
        return False
    values = flatten_values(value)
    if not values:
        values = [value]
    case_sensitive = bool(rule.get("case_sensitive", False))
    if "contains" in rule:
        needle = str(rule.get("contains", ""))
        if not needle:
            return False
        for item in values:
            hay = str(item)
            if case_sensitive:
                if needle in hay:
                    return True
            else:
                if needle.lower() in hay.lower():
                    return True
        return False
    if "equals" in rule:
        target = str(rule.get("equals", ""))
        for item in values:
            item_s = str(item)
            if case_sensitive:
                if item_s == target:
                    return True
            else:
                if item_s.lower() == target.lower():
                    return True
        return False
    if "regex" in rule:
        pattern = str(rule.get("regex", ""))
        if not pattern:
            return False
        flags = 0 if case_sensitive else re.IGNORECASE
        try:
            rx = re.compile(pattern, flags)
        except Exception:
            # An invalid user-supplied pattern simply never matches.
            return False
        for item in values:
            if rx.search(str(item)):
                return True
        return False
    return False


def collect_excluded_domains(records, field_map, record_filter_cfg):
    """Return the set of normalized domains whose record matches ANY of the
    configured ``exclude_if_any`` rules.  Empty set when filtering is off.
    """
    if not record_filter_cfg.get("enabled", False):
        return set()
    rules = record_filter_cfg.get("exclude_if_any", [])
    if not rules:
        return set()
    blocked = set()
    for record in records:
        domain = normalize_domain(resolve_field(record, "domain", field_map))
        if not domain:
            continue
        for rule in rules:
            field_name = str(rule.get("field", "")).strip()
            if not field_name:
                continue
            value = resolve_field(record, field_name, field_map)
            if rule_matches(value, rule):
                blocked.add(domain)
                break  # one matching rule is enough for this record
    return blocked


def build_lexicographic_descriptors(scoring_cfg, prefer_lower):
    """Normalize ``lexicographic_fields`` entries (strings or objects) into
    uniform ``{"field", "order"}`` dicts; unspecified order defaults to
    asc when *prefer_lower* else desc.
    """
    out = []
    for item in scoring_cfg.get("lexicographic_fields", []):
        if isinstance(item, str):
            field_name = item.strip()
            order = "asc" if prefer_lower else "desc"
        else:
            field_name = str(item.get("field", "")).strip()
            order = str(item.get("order", "")).strip().lower()
            if not order:
                order = "asc" if prefer_lower else "desc"
        out.append({"field": field_name, "order": order})
    return out


def parse_scored_records(records, field_map, record_mapping_cfg, scoring_cfg):
    """Turn raw records into scoring entries.

    Each entry carries the normalized domain, UTC created_at, the raw
    created value, per-field scores, an aggregate ``score_value`` (weighted
    average; None if any weighted field is missing/non-numeric), the
    lexicographic value list, and any extra field values tie-breakers need.
    Returns [] when scoring is disabled.
    """
    if not scoring_cfg.get("enabled", False):
        return []
    strategy = str(scoring_cfg.get("strategy", "weighted_average")).strip()
    prefer_lower = bool(scoring_cfg.get("prefer_lower", False))
    timezone = parse_timezone(record_mapping_cfg.get("created_time_timezone", "UTC"))
    time_formats = [str(x) for x in record_mapping_cfg.get("created_time_formats", [])]
    weighted_fields = (
        scoring_cfg.get("weighted_fields", []) if strategy == "weighted_average" else []
    )
    lex_descriptors = (
        build_lexicographic_descriptors(scoring_cfg, prefer_lower)
        if strategy == "lexicographic"
        else []
    )
    # Fields whose raw values must be kept around for tie-breaking.
    needed_fields = set()
    for item in weighted_fields:
        needed_fields.add(str(item.get("field", "")).strip())
    for item in lex_descriptors:
        needed_fields.add(str(item.get("field", "")).strip())
    for item in scoring_cfg.get("tie_breakers", []):
        needed_fields.add(str(item.get("field", "")).strip())
    # domain/created_at are stored top-level, not in field_values.
    needed_fields.discard("domain")
    needed_fields.discard("created_at")
    out = []
    for record in records:
        domain = normalize_domain(resolve_field(record, "domain", field_map))
        if not domain:
            continue
        created_raw = resolve_field(record, "created_at", field_map)
        created_at = parse_created_time(created_raw, time_formats, timezone)
        field_values = {}
        for field_name in needed_fields:
            field_values[field_name] = resolve_field(record, field_name, field_map)
        score_value = None
        scores = []
        lex_values = []
        if strategy == "weighted_average":
            total = 0.0
            total_weight = 0.0
            missing = False
            for item in weighted_fields:
                field_name = str(item.get("field", "")).strip()
                weight = float(item.get("weight"))
                raw_v = resolve_field(record, field_name, field_map)
                val = to_float_or_none(raw_v)
                scores.append(val)
                if val is None:
                    # Any missing component disqualifies the average.
                    missing = True
                    continue
                total += val * weight
                total_weight += weight
            if not missing and total_weight > 0:
                score_value = total / total_weight
        if strategy == "lexicographic":
            for item in lex_descriptors:
                field_name = item["field"]
                order = item["order"]
                raw_v = resolve_field(record, field_name, field_map)
                num_v = to_float_or_none(raw_v)
                # Prefer numeric comparison when the value parses as a number.
                v = num_v if num_v is not None else raw_v
                lex_values.append({"field": field_name, "value": v, "order": order})
                scores.append(v)
        out.append(
            {
                "domain": domain,
                "created_at": created_at,
                "created_raw": created_raw,
                "scores": scores,
                "score_value": score_value,
                "lex_values": lex_values,
                "field_values": field_values,
            }
        )
    return out


def cmp_scalar(a, b, order):
    """Three-way compare two scalars for sorting.

    ``None`` always sorts last regardless of *order*; datetimes compare as
    timestamps; values that both parse as numbers compare numerically,
    otherwise case-insensitively as strings.  *order* == "desc" flips the
    result (except for the None cases).
    """
    a_none = a is None
    b_none = b is None
    if a_none and b_none:
        return 0
    if a_none:
        return 1
    if b_none:
        return -1
    if isinstance(a, dt.datetime):
        a = a.timestamp()
    if isinstance(b, dt.datetime):
        b = b.timestamp()
    a_num = to_float_or_none(a)
    b_num = to_float_or_none(b)
    if a_num is not None and b_num is not None:
        if a_num < b_num:
            base = -1
        elif a_num > b_num:
            base = 1
        else:
            base = 0
    else:
        a_s = str(a).lower()
        b_s = str(b).lower()
        if a_s < b_s:
            base = -1
        elif a_s > b_s:
            base = 1
        else:
            base = 0
    return base if order == "asc" else -base


def get_sort_field_value(record, field_name):
    """Fetch a tie-breaker value from a scored-record entry; domain and
    created_at live top-level, everything else under field_values."""
    if field_name == "domain":
        return record.get("domain")
    if field_name == "created_at":
        return record.get("created_at")
    return record.get("field_values", {}).get(field_name)


def rank_scored_records(records, scoring_cfg):
    """Sort scored records best-first.

    Records created within ``within_hours`` are preferred; if none qualify,
    all records compete.  Primary key is the strategy (weighted score or
    lexicographic fields), then configured tie_breakers, then domain asc.
    """
    if not records:
        return []
    within_hours = float(scoring_cfg.get("within_hours", 24))
    strategy = str(scoring_cfg.get("strategy", "weighted_average")).strip()
    prefer_lower = bool(scoring_cfg.get("prefer_lower", False))
    tie_breakers = scoring_cfg.get("tie_breakers", [])
    now = dt.datetime.now(dt.timezone.utc)
    cutoff = now - dt.timedelta(hours=within_hours)
    recent = [
        r for r in records if r.get("created_at") is not None and r["created_at"] >= cutoff
    ]
    candidates = recent if recent else records
    default_lex_order = "asc" if prefer_lower else "desc"

    def compare(a, b):
        if strategy == "weighted_average":
            order = "asc" if prefer_lower else "desc"
            c = cmp_scalar(a.get("score_value"), b.get("score_value"), order)
            if c != 0:
                return c
        elif strategy == "lexicographic":
            a_lex = a.get("lex_values", [])
            b_lex = b.get("lex_values", [])
            n = max(len(a_lex), len(b_lex))
            for i in range(n):
                av = a_lex[i]["value"] if i < len(a_lex) else None
                bv = b_lex[i]["value"] if i < len(b_lex) else None
                order = default_lex_order
                if i < len(a_lex) and a_lex[i].get("order"):
                    order = a_lex[i]["order"]
                c = cmp_scalar(av, bv, order)
                if c != 0:
                    return c
        for item in tie_breakers:
            field_name = str(item.get("field", "")).strip()
            order = str(item.get("order", "asc")).strip().lower()
            av = get_sort_field_value(a, field_name)
            bv = get_sort_field_value(b, field_name)
            c = cmp_scalar(av, bv, order)
            if c != 0:
                return c
        # Deterministic final tie-break.
        return cmp_scalar(a.get("domain"), b.get("domain"), "asc")

    return sorted(candidates, key=functools.cmp_to_key(compare))


def apply_filter(domains, filter_cfg):
    """Keep domains matching any include_suffixes (when configured) and
    matching none of the exclude_regex patterns."""
    include_suffixes = [s.lower() for s in filter_cfg.get("include_suffixes", []) if s]
    exclude_regex = [re.compile(x) for x in filter_cfg.get("exclude_regex", []) if x]
    out = []
    for d in domains:
        if include_suffixes and not any(d.endswith(s) for s in include_suffixes):
            continue
        if any(rx.search(d) for rx in exclude_regex):
            continue
        out.append(d)
    return out


def single_tls_check(domain, timeout_ms, port, tls_verify=True):
    """Attempt one DNS resolution + TCP connect + TLS handshake to *domain*.

    Returns ``(ok, latency_ms_or_None, reason)``; any exception becomes a
    failed result with the stringified error as the reason.
    """
    start = time.perf_counter()
    timeout_sec = max(0.2, timeout_ms / 1000.0)  # never wait less than 200 ms
    try:
        infos = socket.getaddrinfo(domain, port, proto=socket.IPPROTO_TCP)
        if not infos:
            return False, None, "dns_empty"
        af, socktype, proto, _, sockaddr = infos[0]
        with socket.socket(af, socktype, proto) as sock:
            sock.settimeout(timeout_sec)
            sock.connect(sockaddr)
            ctx = ssl.create_default_context()
            if not tls_verify:
                # Fixed: original used the private ssl._create_unverified_context();
                # this is the documented public way to disable verification.
                ctx.check_hostname = False
                ctx.verify_mode = ssl.CERT_NONE
            with ctx.wrap_socket(sock, server_hostname=domain) as ssock:
                ssock.do_handshake()
                elapsed = int((time.perf_counter() - start) * 1000)
                return True, elapsed, "ok"
    except Exception as e:
        return False, None, str(e)


def check_domains(domains, hc_cfg):
    """TLS health-check every domain *attempts* times and return results
    sorted by success ratio desc, then average latency asc, then name."""
    attempts = int(hc_cfg.get("attempts", 2))
    timeout_ms = int(hc_cfg.get("timeout_ms", 1800))
    port = int(hc_cfg.get("port", 443))
    tls_verify = bool(hc_cfg.get("tls_verify", True))
    results = []
    for d in domains:
        ok_count = 0
        latencies = []
        errors = []
        for _ in range(attempts):
            ok, latency, err = single_tls_check(d, timeout_ms, port, tls_verify=tls_verify)
            if ok:
                ok_count += 1
                latencies.append(latency)
            else:
                errors.append(err)
        success_ratio = ok_count / attempts if attempts else 0.0
        # 999999 pushes never-successful domains behind any measured latency.
        avg_latency = int(sum(latencies) / len(latencies)) if latencies else 999999
        results.append(
            {
                "domain": d,
                "success_ratio": success_ratio,
                "avg_latency_ms": avg_latency,
                "ok_count": ok_count,
                "attempts": attempts,
                "errors": errors[:3],  # keep the report short
            }
        )
    results.sort(key=lambda x: (-x["success_ratio"], x["avg_latency_ms"], x["domain"]))
    return results


def render_v2ray(template_file, output_file, token, domain):
    """Render the v2ray template by substituting *token* with *domain*.

    Returns True on success, False when paths are unset or the template is
    missing (best-effort: rendering is optional).
    """
    if not template_file or not output_file:
        return False
    if not os.path.exists(template_file):
        return False
    with open(template_file, "r", encoding="utf-8") as f:
        tpl = f.read()
    rendered = tpl.replace(token, domain)
    out_dir = os.path.dirname(output_file)
    if out_dir:
        # Fixed: original called os.makedirs("") when output_file had no
        # directory component, which raises FileNotFoundError.
        os.makedirs(out_dir, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(rendered)
    return True


def run_notify(cmd, domain, status):
    """Run the user-configured notify shell command (best-effort) with the
    selected domain and status exposed via environment variables.

    NOTE: shell=True is intentional here — *cmd* comes from the operator's
    own config file, not from untrusted input.
    """
    if not cmd:
        return
    env = os.environ.copy()
    env["AUTODOMAIN"] = domain
    env["AUTODOMAIN_STATUS"] = status
    subprocess.run(cmd, shell=True, check=False, env=env)


def choose_domain(filtered_domains, check_results, top_n, ranked_scored):
    """Pick the winning domain and the top-N candidate report.

    Preference order: best-scored domain that also passed a health check,
    then score-only ranking, then health-check ranking, then the first
    filtered domain, then ``(None, [])``.
    """
    if ranked_scored:
        domains_by_score = [x["domain"] for x in ranked_scored]
        if check_results:
            check_map = {x["domain"]: x for x in check_results}
            top = []
            for d in domains_by_score:
                if d in check_map and check_map[d]["success_ratio"] > 0:
                    top.append(check_map[d])
                    if len(top) >= top_n:
                        break
            if top:
                return top[0]["domain"], top
            # Scored domains exist but none is healthy: fall back to scores.
            score_only = [
                {
                    "domain": x["domain"],
                    "score_value": x.get("score_value"),
                    "scores": x.get("scores", []),
                    "created_raw": x.get("created_raw"),
                }
                for x in ranked_scored[:top_n]
            ]
            return score_only[0]["domain"], score_only
        top_scored = [
            {
                "domain": x["domain"],
                "score_value": x.get("score_value"),
                "scores": x.get("scores", []),
                "created_raw": x.get("created_raw"),
            }
            for x in ranked_scored[:top_n]
        ]
        if top_scored:
            return top_scored[0]["domain"], top_scored
    if check_results:
        top = [x for x in check_results if x["success_ratio"] > 0][:top_n]
        if top:
            return top[0]["domain"], top
        return None, check_results[:top_n]
    if filtered_domains:
        return filtered_domains[0], [{"domain": x} for x in filtered_domains[:top_n]]
    return None, []


def main():
    """CLI entry point: run the whole pipeline once and persist results.

    On success writes current_domain.{txt,json}, substore vars, state and
    prints the selection JSON.  On failure falls back to the last known
    good domain when one is stored in state, otherwise exits non-zero.
    """
    ap = argparse.ArgumentParser(description="Auto select VMess preferred domain")
    ap.add_argument("--config", default="config.json", help="Path to config JSON")
    args = ap.parse_args()
    config_path_abs = os.path.abspath(args.config)
    if not os.path.exists(config_path_abs):
        print(
            json.dumps(
                {"status": "error", "error": f"config file not found: {config_path_abs}"},
                ensure_ascii=True,
            ),
            file=sys.stderr,
        )
        sys.exit(1)
    cfg = read_json_file(config_path_abs)
    try:
        validate_config(cfg)
    except Exception as e:
        print(
            json.dumps({"status": "error", "error": f"invalid config: {e}"}, ensure_ascii=True),
            file=sys.stderr,
        )
        sys.exit(1)
    output_cfg = cfg.get("output", {})
    runtime_dir_cfg = output_cfg.get("runtime_dir", "./runtime")
    if os.path.isabs(runtime_dir_cfg):
        runtime_dir = runtime_dir_cfg
    else:
        # Relative runtime dirs are resolved against the config file, not CWD.
        runtime_dir = os.path.normpath(
            os.path.join(os.path.dirname(config_path_abs), runtime_dir_cfg)
        )
    v2_cfg = cfg.get("v2ray", {})
    notify_cfg = cfg.get("notify", {})
    current_domain_file = os.path.join(
        runtime_dir, output_cfg.get("current_domain_file", "current_domain.txt")
    )
    current_domain_json = os.path.join(
        runtime_dir, output_cfg.get("current_domain_json", "current_domain.json")
    )
    state_file = os.path.join(runtime_dir, output_cfg.get("state_file", "state.json"))
    substore_vars_file = os.path.join(
        runtime_dir, output_cfg.get("substore_vars_file", "substore_vars.json")
    )
    state = read_json_file(state_file, default={})
    last_good = state.get("last_good_domain", "")
    try:
        payload = fetch_api_json(cfg)
        parsed = parse_domains(payload, cfg.get("parser", {}))
        filtered = apply_filter(parsed, cfg.get("domain_filter", {}))
        record_mapping_cfg = cfg.get("record_mapping", {})
        field_map = record_mapping_cfg.get("field_map", {})
        records = extract_records(payload, record_mapping_cfg)
        record_filter_cfg = cfg.get("record_filter", {})
        blocked_domains = collect_excluded_domains(records, field_map, record_filter_cfg)
        if blocked_domains:
            filtered = [d for d in filtered if d not in blocked_domains]
        scoring_cfg = cfg.get("scoring", {})
        scored_records = parse_scored_records(
            records, field_map, record_mapping_cfg, scoring_cfg
        )
        # Only score domains that survived the domain-level filters.
        filtered_set = set(filtered)
        scored_records = [r for r in scored_records if r["domain"] in filtered_set]
        ranked_scored = rank_scored_records(scored_records, scoring_cfg)
        check_results = []
        if cfg.get("healthcheck", {}).get("enabled", True):
            check_results = check_domains(filtered, cfg.get("healthcheck", {}))
        top_n = int(cfg.get("selection", {}).get("top_n", 3))
        selected, top_candidates = choose_domain(
            filtered, check_results, top_n, ranked_scored
        )
        status = "ok"
        if not selected and last_good:
            selected = last_good
            status = "fallback_last_good"
        if not selected:
            raise RuntimeError("No valid domain available from API and no fallback in state")
        write_text_file(current_domain_file, selected + "\n")
        current_json = {
            "domain": selected,
            "updated_at": utc_now_iso(),
            "status": status,
            "source_count": len(parsed),
            "checked_count": len(check_results),
            "top_candidates": top_candidates,
        }
        write_json_file(current_domain_json, current_json)
        write_json_file(
            substore_vars_file,
            {
                "AUTO_DOMAIN": selected,
                "UPDATED_AT": current_json["updated_at"],
                "STATUS": status,
            },
        )
        rendered = render_v2ray(
            template_file=v2_cfg.get("template_file", ""),
            output_file=v2_cfg.get("output_file", ""),
            token=v2_cfg.get("replace_token", "__AUTO_DOMAIN__"),
            domain=selected,
        )
        new_state = {
            "updated_at": current_json["updated_at"],
            "last_good_domain": selected,
            "status": status,
            "source_count": len(parsed),
            "checked_count": len(check_results),
            "rendered_v2ray": rendered,
        }
        write_json_file(state_file, new_state)
        run_notify(notify_cfg.get("command", ""), selected, status)
        print(json.dumps(current_json, ensure_ascii=True))
    except Exception as e:
        # Any pipeline failure: record the error and reuse the last good
        # domain when we have one, so consumers keep a working value.
        now = utc_now_iso()
        err_state = {
            "updated_at": now,
            "status": "error",
            "error": str(e),
            "last_good_domain": last_good,
        }
        write_json_file(state_file, err_state)
        if last_good:
            write_text_file(current_domain_file, last_good + "\n")
            write_json_file(
                current_domain_json,
                {
                    "domain": last_good,
                    "updated_at": now,
                    "status": "error_use_last_good",
                    "error": str(e),
                },
            )
            run_notify(notify_cfg.get("command", ""), last_good, "error_use_last_good")
            print(
                json.dumps(
                    {"status": "error_use_last_good", "error": str(e)}, ensure_ascii=True
                )
            )
            return
        print(json.dumps({"status": "error", "error": str(e)}, ensure_ascii=True), file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()