# domain_updater.py
#!/usr/bin/env python3
import argparse
import datetime as dt
import json
import os
import re
import socket
import ssl
import subprocess
import sys
import time
import urllib.parse
import urllib.request
# Hostname validator: total length 1-253 chars, labels of 1-63 alnum/hyphen
# chars that neither start nor end with '-', and at least two labels (>= one dot).
DOMAIN_RE = re.compile(r"^(?=.{1,253}$)(?!-)[A-Za-z0-9-]{1,63}(?<!-)(\.(?!-)[A-Za-z0-9-]{1,63}(?<!-))+$")
# Dotted-quad IPv4 with each octet constrained to 0-255.
IPV4_RE = re.compile(r"^(?:25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)(?:\.(?:25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)){3}$")
  16. def utc_now_iso():
  17. return dt.datetime.now(dt.timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
  18. def read_json_file(path, default=None):
  19. if default is None:
  20. default = {}
  21. if not os.path.exists(path):
  22. return default
  23. with open(path, "r", encoding="utf-8") as f:
  24. return json.load(f)
  25. def write_json_file(path, data):
  26. os.makedirs(os.path.dirname(path), exist_ok=True)
  27. with open(path, "w", encoding="utf-8") as f:
  28. json.dump(data, f, ensure_ascii=True, indent=2)
  29. def write_text_file(path, data):
  30. os.makedirs(os.path.dirname(path), exist_ok=True)
  31. with open(path, "w", encoding="utf-8") as f:
  32. f.write(data)
  33. def build_url(base_url, params):
  34. if not params:
  35. return base_url
  36. parsed = urllib.parse.urlparse(base_url)
  37. current = urllib.parse.parse_qs(parsed.query)
  38. for k, v in params.items():
  39. current[k] = [str(v)]
  40. query = urllib.parse.urlencode(current, doseq=True)
  41. return urllib.parse.urlunparse(parsed._replace(query=query))
  42. def fetch_api_json(cfg):
  43. api = cfg["api"]
  44. url = build_url(api["url"], api.get("params", {}))
  45. method = api.get("method", "GET").upper()
  46. headers = api.get("headers", {})
  47. timeout = int(api.get("timeout_sec", 10))
  48. body_obj = api.get("body")
  49. body = None
  50. if body_obj is not None:
  51. body = json.dumps(body_obj).encode("utf-8")
  52. headers = {**headers, "Content-Type": "application/json"}
  53. req = urllib.request.Request(url=url, data=body, headers=headers, method=method)
  54. with urllib.request.urlopen(req, timeout=timeout) as resp:
  55. raw = resp.read().decode("utf-8", errors="replace")
  56. return json.loads(raw)
  57. def flatten_values(value):
  58. out = []
  59. if isinstance(value, str):
  60. out.append(value)
  61. elif isinstance(value, list):
  62. for item in value:
  63. out.extend(flatten_values(item))
  64. elif isinstance(value, dict):
  65. for item in value.values():
  66. out.extend(flatten_values(item))
  67. return out
  68. def get_by_json_path(data, path):
  69. cur = data
  70. for part in path.split("."):
  71. if isinstance(cur, dict) and part in cur:
  72. cur = cur[part]
  73. else:
  74. return None
  75. return cur
  76. def get_values_by_path(data, path):
  77. parts = path.split(".")
  78. def walk(cur, idx):
  79. if idx >= len(parts):
  80. return [cur]
  81. part = parts[idx]
  82. if part.endswith("[]"):
  83. key = part[:-2]
  84. if isinstance(cur, dict):
  85. arr = cur.get(key)
  86. else:
  87. arr = None
  88. if not isinstance(arr, list):
  89. return []
  90. out = []
  91. for item in arr:
  92. out.extend(walk(item, idx + 1))
  93. return out
  94. if isinstance(cur, dict) and part in cur:
  95. return walk(cur[part], idx + 1)
  96. return []
  97. return walk(data, 0)
  98. def parse_domains(payload, parser_cfg):
  99. domains = []
  100. for p in parser_cfg.get("field_paths", []):
  101. values = get_values_by_path(payload, p)
  102. domains.extend(flatten_values(values))
  103. for p in parser_cfg.get("json_paths", []):
  104. v = get_by_json_path(payload, p)
  105. if v is not None:
  106. domains.extend(flatten_values(v))
  107. if not domains:
  108. regex_s = parser_cfg.get("regex", r"[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
  109. text = json.dumps(payload, ensure_ascii=True)
  110. domains.extend(re.findall(regex_s, text))
  111. clean = []
  112. seen = set()
  113. for d in domains:
  114. d = d.strip().lower().rstrip(".")
  115. if (DOMAIN_RE.match(d) or IPV4_RE.match(d)) and d not in seen:
  116. seen.add(d)
  117. clean.append(d)
  118. return clean
  119. def parse_created_time(s):
  120. if not s:
  121. return None
  122. try:
  123. return dt.datetime.strptime(str(s).strip(), "%Y-%m-%d %H:%M:%S").replace(tzinfo=dt.timezone.utc)
  124. except Exception:
  125. return None
  126. def record_field_value(record, field_path):
  127. if not isinstance(record, dict) or not field_path:
  128. return None
  129. return get_by_json_path(record, field_path)
def rule_matches(value, rule):
    """Return True when *value* satisfies *rule*.

    Exactly one rule kind is evaluated, with precedence contains > equals >
    regex — only the first of those keys present in *rule* counts. Matching
    is case-insensitive unless the rule sets "case_sensitive": true. The
    candidate value is flattened to its contained strings first; scalars
    that flatten to nothing (numbers, bools) are compared via str().
    """
    if value is None or not isinstance(rule, dict):
        return False
    values = flatten_values(value)
    if not values:
        # Non-string scalar: compare it directly rather than dropping it.
        values = [value]
    case_sensitive = bool(rule.get("case_sensitive", False))
    if "contains" in rule:
        needle = str(rule.get("contains", ""))
        if not needle:
            # An empty needle would match everything; treat as no match.
            return False
        for item in values:
            hay = str(item)
            if case_sensitive:
                if needle in hay:
                    return True
            else:
                if needle.lower() in hay.lower():
                    return True
        return False
    if "equals" in rule:
        target = str(rule.get("equals", ""))
        for item in values:
            item_s = str(item)
            if case_sensitive:
                if item_s == target:
                    return True
            else:
                if item_s.lower() == target.lower():
                    return True
        return False
    if "regex" in rule:
        pattern = str(rule.get("regex", ""))
        if not pattern:
            return False
        flags = 0 if case_sensitive else re.IGNORECASE
        try:
            rx = re.compile(pattern, flags)
        except Exception:
            # Invalid user-supplied pattern: treat as no match rather than crash.
            return False
        for item in values:
            if rx.search(str(item)):
                return True
        return False
    return False
def collect_excluded_domains(payload, record_filter_cfg, scoring_cfg):
    """Return the set of domains whose API record matches any exclusion rule.

    Returns an empty set when the record filter is disabled or has no rules.
    The records path and domain field fall back to the scoring config's
    values so both features can share one record layout.
    """
    if not record_filter_cfg.get("enabled", False):
        return set()
    rules = record_filter_cfg.get("exclude_if_any", [])
    if not rules:
        return set()
    records_path = record_filter_cfg.get("records_path", scoring_cfg.get("records_path", "data.good[]"))
    domain_field = record_filter_cfg.get("domain_field", scoring_cfg.get("ip_field", "ip"))
    blocked = set()
    for record in get_values_by_path(payload, records_path):
        if not isinstance(record, dict):
            continue
        domain_raw = record_field_value(record, domain_field)
        # Normalize the same way parse_domains does so set membership lines up.
        domain = str(domain_raw or "").strip().lower().rstrip(".")
        if not domain:
            continue
        for rule in rules:
            field_path = str(rule.get("field_path", "")).strip()
            if not field_path:
                continue
            value = record_field_value(record, field_path)
            if rule_matches(value, rule):
                blocked.add(domain)
                break  # one matching rule is enough to exclude this record
    return blocked
def parse_scored_records(payload, scoring_cfg):
    """Extract per-domain score records from the payload (empty list if disabled).

    Each returned dict carries: "domain" (normalized), "created_at" (aware
    UTC datetime or None), "created_raw" (original timestamp value), and
    "scores" (floats in score_fields order; missing or unparsable score
    fields become +inf so they rank worst when lower is better).
    """
    if not scoring_cfg.get("enabled", False):
        return []
    records_path = scoring_cfg.get("records_path", "data.good[]")
    ip_field = scoring_cfg.get("ip_field", "ip")
    created_time_field = scoring_cfg.get("created_time_field", "createdTime")
    score_fields = scoring_cfg.get("score_fields", ["avgScore", "ydScore", "dxScore", "ltScore"])
    raw_records = get_values_by_path(payload, records_path)
    out = []
    for r in raw_records:
        if not isinstance(r, dict):
            continue
        domain = str(r.get(ip_field, "")).strip().lower().rstrip(".")
        if not domain:
            continue
        created = parse_created_time(r.get(created_time_field))
        scores = []
        for f in score_fields:
            v = r.get(f)
            try:
                scores.append(float(v))
            except Exception:
                scores.append(float("inf"))  # unparsable score ranks worst
        out.append(
            {
                "domain": domain,
                "created_at": created,
                "created_raw": r.get(created_time_field),
                "scores": scores,
            }
        )
    return out
  232. def rank_scored_records(records, scoring_cfg):
  233. if not records:
  234. return []
  235. within_hours = float(scoring_cfg.get("within_hours", 24))
  236. prefer_lower = bool(scoring_cfg.get("prefer_lower", True))
  237. use_api_order = bool(scoring_cfg.get("use_api_order", False))
  238. now = dt.datetime.now(dt.timezone.utc)
  239. cutoff = now - dt.timedelta(hours=within_hours)
  240. recent = [r for r in records if r.get("created_at") is not None and r["created_at"] >= cutoff]
  241. candidates = recent if recent else records
  242. if use_api_order:
  243. seen = set()
  244. ordered = []
  245. for r in candidates:
  246. d = r["domain"]
  247. if d in seen:
  248. continue
  249. seen.add(d)
  250. ordered.append(r)
  251. return ordered
  252. def key_lower(r):
  253. return tuple(r["scores"] + [r["domain"]])
  254. def key_higher(r):
  255. return tuple([-x if x != float("inf") else float("inf") for x in r["scores"]] + [r["domain"]])
  256. ranked = sorted(candidates, key=key_lower if prefer_lower else key_higher)
  257. return ranked
  258. def apply_filter(domains, filter_cfg):
  259. include_suffixes = [s.lower() for s in filter_cfg.get("include_suffixes", []) if s]
  260. exclude_regex = [re.compile(x) for x in filter_cfg.get("exclude_regex", []) if x]
  261. out = []
  262. for d in domains:
  263. if include_suffixes and not any(d.endswith(s) for s in include_suffixes):
  264. continue
  265. if any(rx.search(d) for rx in exclude_regex):
  266. continue
  267. out.append(d)
  268. return out
  269. def single_tls_check(domain, timeout_ms, port, tls_verify=True):
  270. start = time.perf_counter()
  271. timeout_sec = max(0.2, timeout_ms / 1000.0)
  272. try:
  273. infos = socket.getaddrinfo(domain, port, proto=socket.IPPROTO_TCP)
  274. if not infos:
  275. return False, None, "dns_empty"
  276. af, socktype, proto, _, sockaddr = infos[0]
  277. with socket.socket(af, socktype, proto) as sock:
  278. sock.settimeout(timeout_sec)
  279. sock.connect(sockaddr)
  280. if tls_verify:
  281. ctx = ssl.create_default_context()
  282. else:
  283. ctx = ssl._create_unverified_context()
  284. with ctx.wrap_socket(sock, server_hostname=domain) as ssock:
  285. ssock.do_handshake()
  286. elapsed = int((time.perf_counter() - start) * 1000)
  287. return True, elapsed, "ok"
  288. except Exception as e:
  289. return False, None, str(e)
  290. def check_domains(domains, hc_cfg):
  291. attempts = int(hc_cfg.get("attempts", 2))
  292. timeout_ms = int(hc_cfg.get("timeout_ms", 1800))
  293. port = int(hc_cfg.get("port", 443))
  294. tls_verify = bool(hc_cfg.get("tls_verify", True))
  295. results = []
  296. for d in domains:
  297. ok_count = 0
  298. latencies = []
  299. errors = []
  300. for _ in range(attempts):
  301. ok, latency, err = single_tls_check(d, timeout_ms, port, tls_verify=tls_verify)
  302. if ok:
  303. ok_count += 1
  304. latencies.append(latency)
  305. else:
  306. errors.append(err)
  307. success_ratio = ok_count / attempts if attempts else 0.0
  308. avg_latency = int(sum(latencies) / len(latencies)) if latencies else 999999
  309. results.append(
  310. {
  311. "domain": d,
  312. "success_ratio": success_ratio,
  313. "avg_latency_ms": avg_latency,
  314. "ok_count": ok_count,
  315. "attempts": attempts,
  316. "errors": errors[:3],
  317. }
  318. )
  319. results.sort(key=lambda x: (-x["success_ratio"], x["avg_latency_ms"], x["domain"]))
  320. return results
  321. def render_v2ray(template_file, output_file, token, domain):
  322. if not template_file or not output_file:
  323. return False
  324. if not os.path.exists(template_file):
  325. return False
  326. with open(template_file, "r", encoding="utf-8") as f:
  327. tpl = f.read()
  328. rendered = tpl.replace(token, domain)
  329. os.makedirs(os.path.dirname(output_file), exist_ok=True)
  330. with open(output_file, "w", encoding="utf-8") as f:
  331. f.write(rendered)
  332. return True
def run_notify(cmd, domain, status):
    """Run the operator-configured notify command (best-effort; no-op if empty).

    The selected domain and run status are exported to the child process via
    the AUTODOMAIN and AUTODOMAIN_STATUS environment variables; the command's
    exit code is ignored (check=False).

    NOTE(review): shell=True executes a user-supplied shell string. Acceptable
    for an operator-owned config file, but never feed untrusted input here.
    """
    if not cmd:
        return
    env = os.environ.copy()
    env["AUTODOMAIN"] = domain
    env["AUTODOMAIN_STATUS"] = status
    subprocess.run(cmd, shell=True, check=False, env=env)
  340. def choose_domain(filtered_domains, check_results, top_n, ranked_scored):
  341. if ranked_scored:
  342. domains_by_score = [x["domain"] for x in ranked_scored]
  343. if check_results:
  344. check_map = {x["domain"]: x for x in check_results}
  345. top = []
  346. for d in domains_by_score:
  347. if d in check_map and check_map[d]["success_ratio"] > 0:
  348. top.append(check_map[d])
  349. if len(top) >= top_n:
  350. break
  351. if top:
  352. return top[0]["domain"], top
  353. score_only = [{"domain": x["domain"], "scores": x["scores"], "created_raw": x["created_raw"]} for x in ranked_scored[:top_n]]
  354. return score_only[0]["domain"], score_only
  355. top_scored = [{"domain": x["domain"], "scores": x["scores"], "created_raw": x["created_raw"]} for x in ranked_scored[:top_n]]
  356. if top_scored:
  357. return top_scored[0]["domain"], top_scored
  358. if check_results:
  359. top = [x for x in check_results if x["success_ratio"] > 0][:top_n]
  360. if top:
  361. return top[0]["domain"], top
  362. return None, check_results[:top_n]
  363. if filtered_domains:
  364. return filtered_domains[0], [{"domain": x} for x in filtered_domains[:top_n]]
  365. return None, []
def main():
    """Fetch candidate domains, pick the best one, and persist all outputs.

    Pipeline: fetch API payload -> parse + filter domains -> drop rule-excluded
    records -> rank by score -> optional TLS health check -> choose domain ->
    write runtime files (current domain text/JSON, sub-store vars, state),
    render the v2ray config, and fire the notify hook. On any failure the
    last known-good domain from state is reused when available; otherwise the
    error is reported on stderr and the process exits non-zero.
    """
    # CLI: only a --config path; everything else is driven by the JSON config.
    ap = argparse.ArgumentParser(description="Auto select VMess preferred domain")
    ap.add_argument("--config", default="config.json", help="Path to config JSON")
    args = ap.parse_args()
    config_path_abs = os.path.abspath(args.config)
    if not os.path.exists(config_path_abs):
        print(json.dumps({"status": "error", "error": f"config file not found: {config_path_abs}"}, ensure_ascii=True), file=sys.stderr)
        sys.exit(1)
    cfg = read_json_file(config_path_abs)
    output_cfg = cfg.get("output", {})
    runtime_dir_cfg = output_cfg.get("runtime_dir", "./runtime")
    if os.path.isabs(runtime_dir_cfg):
        runtime_dir = runtime_dir_cfg
    else:
        # A relative runtime dir is resolved against the config file's directory,
        # not the process CWD, so cron invocations behave predictably.
        runtime_dir = os.path.normpath(os.path.join(os.path.dirname(config_path_abs), runtime_dir_cfg))
    v2_cfg = cfg.get("v2ray", {})
    notify_cfg = cfg.get("notify", {})
    current_domain_file = os.path.join(runtime_dir, output_cfg.get("current_domain_file", "current_domain.txt"))
    current_domain_json = os.path.join(runtime_dir, output_cfg.get("current_domain_json", "current_domain.json"))
    state_file = os.path.join(runtime_dir, output_cfg.get("state_file", "state.json"))
    substore_vars_file = os.path.join(runtime_dir, output_cfg.get("substore_vars_file", "substore_vars.json"))
    state = read_json_file(state_file, default={})
    # Fallback domain from the previous successful run (may be "").
    last_good = state.get("last_good_domain", "")
    try:
        payload = fetch_api_json(cfg)
        parsed = parse_domains(payload, cfg.get("parser", {}))
        filtered = apply_filter(parsed, cfg.get("domain_filter", {}))
        record_filter_cfg = cfg.get("record_filter", {})
        blocked_domains = collect_excluded_domains(payload, record_filter_cfg, cfg.get("scoring", {}))
        if blocked_domains:
            filtered = [d for d in filtered if d not in blocked_domains]
        scored_records = parse_scored_records(payload, cfg.get("scoring", {}))
        # Keep only scored records whose domain survived filtering/exclusion.
        scored_records = [r for r in scored_records if r["domain"] in set(filtered)]
        ranked_scored = rank_scored_records(scored_records, cfg.get("scoring", {}))
        check_results = []
        if cfg.get("healthcheck", {}).get("enabled", True):
            check_results = check_domains(filtered, cfg.get("healthcheck", {}))
        top_n = int(cfg.get("selection", {}).get("top_n", 3))
        selected, top_candidates = choose_domain(filtered, check_results, top_n, ranked_scored)
        status = "ok"
        if not selected and last_good:
            selected = last_good
            status = "fallback_last_good"
        if not selected:
            raise RuntimeError("No valid domain available from API and no fallback in state")
        write_text_file(current_domain_file, selected + "\n")
        current_json = {
            "domain": selected,
            "updated_at": utc_now_iso(),
            "status": status,
            "source_count": len(parsed),
            "checked_count": len(check_results),
            "top_candidates": top_candidates,
        }
        write_json_file(current_domain_json, current_json)
        # Variables consumed by the sub-store integration.
        write_json_file(
            substore_vars_file,
            {
                "AUTO_DOMAIN": selected,
                "UPDATED_AT": current_json["updated_at"],
                "STATUS": status,
            },
        )
        rendered = render_v2ray(
            template_file=v2_cfg.get("template_file", ""),
            output_file=v2_cfg.get("output_file", ""),
            token=v2_cfg.get("replace_token", "__AUTO_DOMAIN__"),
            domain=selected,
        )
        new_state = {
            "updated_at": current_json["updated_at"],
            "last_good_domain": selected,
            "status": status,
            "source_count": len(parsed),
            "checked_count": len(check_results),
            "rendered_v2ray": rendered,
        }
        write_json_file(state_file, new_state)
        run_notify(notify_cfg.get("command", ""), selected, status)
        print(json.dumps(current_json, ensure_ascii=True))
    except Exception as e:
        # Any failure lands here: record the error, then degrade gracefully
        # to the last known-good domain when one exists.
        now = utc_now_iso()
        err_state = {
            "updated_at": now,
            "status": "error",
            "error": str(e),
            "last_good_domain": last_good,
        }
        write_json_file(state_file, err_state)
        if last_good:
            write_text_file(current_domain_file, last_good + "\n")
            write_json_file(
                current_domain_json,
                {
                    "domain": last_good,
                    "updated_at": now,
                    "status": "error_use_last_good",
                    "error": str(e),
                },
            )
            run_notify(notify_cfg.get("command", ""), last_good, "error_use_last_good")
            print(json.dumps({"status": "error_use_last_good", "error": str(e)}, ensure_ascii=True))
            return
        print(json.dumps({"status": "error", "error": str(e)}, ensure_ascii=True), file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()