# domain_updater.py
#!/usr/bin/env python3
"""Fetch candidate domains from a configured API, filter/score/health-check
them, select a preferred domain, and publish it to runtime files (plus an
optional v2ray template render and notify command)."""
import argparse
import datetime as dt
import functools
import json
import math
import os
import re
import socket
import ssl
import subprocess
import sys
import time
import urllib.parse
import urllib.request
# Syntactic domain-name check: total length 1-253, dot-separated labels of
# 1-63 alphanumeric/hyphen characters that neither start nor end with a hyphen.
DOMAIN_RE = re.compile(r"^(?=.{1,253}$)(?!-)[A-Za-z0-9-]{1,63}(?<!-)(\.(?!-)[A-Za-z0-9-]{1,63}(?<!-))+$")
# Dotted-quad IPv4 address, each octet 0-255 (no extra leading zeros).
IPV4_RE = re.compile(r"^(?:25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)(?:\.(?:25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)){3}$")
  18. def utc_now_iso():
  19. return dt.datetime.now(dt.timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
  20. def read_json_file(path, default=None):
  21. if default is None:
  22. default = {}
  23. if not os.path.exists(path):
  24. return default
  25. with open(path, "r", encoding="utf-8") as f:
  26. return json.load(f)
  27. def write_json_file(path, data):
  28. os.makedirs(os.path.dirname(path), exist_ok=True)
  29. with open(path, "w", encoding="utf-8") as f:
  30. json.dump(data, f, ensure_ascii=True, indent=2)
  31. def write_text_file(path, data):
  32. os.makedirs(os.path.dirname(path), exist_ok=True)
  33. with open(path, "w", encoding="utf-8") as f:
  34. f.write(data)
  35. def build_url(base_url, params):
  36. if not params:
  37. return base_url
  38. parsed = urllib.parse.urlparse(base_url)
  39. current = urllib.parse.parse_qs(parsed.query)
  40. for k, v in params.items():
  41. current[k] = [str(v)]
  42. query = urllib.parse.urlencode(current, doseq=True)
  43. return urllib.parse.urlunparse(parsed._replace(query=query))
  44. def fetch_api_json(cfg):
  45. api = cfg["api"]
  46. url = build_url(api["url"], api.get("params", {}))
  47. method = api.get("method", "GET").upper()
  48. headers = api.get("headers", {})
  49. timeout = int(api.get("timeout_sec", 10))
  50. body_obj = api.get("body")
  51. body = None
  52. if body_obj is not None:
  53. body = json.dumps(body_obj).encode("utf-8")
  54. headers = {**headers, "Content-Type": "application/json"}
  55. req = urllib.request.Request(url=url, data=body, headers=headers, method=method)
  56. with urllib.request.urlopen(req, timeout=timeout) as resp:
  57. raw = resp.read().decode("utf-8", errors="replace")
  58. return json.loads(raw)
  59. def flatten_values(value):
  60. out = []
  61. if isinstance(value, str):
  62. out.append(value)
  63. elif isinstance(value, list):
  64. for item in value:
  65. out.extend(flatten_values(item))
  66. elif isinstance(value, dict):
  67. for item in value.values():
  68. out.extend(flatten_values(item))
  69. return out
  70. def get_by_json_path(data, path):
  71. cur = data
  72. for part in path.split("."):
  73. if isinstance(cur, dict) and part in cur:
  74. cur = cur[part]
  75. else:
  76. return None
  77. return cur
  78. def get_values_by_path(data, path):
  79. parts = path.split(".")
  80. def walk(cur, idx):
  81. if idx >= len(parts):
  82. return [cur]
  83. part = parts[idx]
  84. if part.endswith("[]"):
  85. key = part[:-2]
  86. if isinstance(cur, dict):
  87. arr = cur.get(key)
  88. else:
  89. arr = None
  90. if not isinstance(arr, list):
  91. return []
  92. out = []
  93. for item in arr:
  94. out.extend(walk(item, idx + 1))
  95. return out
  96. if isinstance(cur, dict) and part in cur:
  97. return walk(cur[part], idx + 1)
  98. return []
  99. return walk(data, 0)
  100. def parse_domains(payload, parser_cfg):
  101. domains = []
  102. for p in parser_cfg.get("field_paths", []):
  103. values = get_values_by_path(payload, p)
  104. domains.extend(flatten_values(values))
  105. for p in parser_cfg.get("json_paths", []):
  106. v = get_by_json_path(payload, p)
  107. if v is not None:
  108. domains.extend(flatten_values(v))
  109. if not domains:
  110. regex_s = parser_cfg.get("regex", r"[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
  111. text = json.dumps(payload, ensure_ascii=True)
  112. domains.extend(re.findall(regex_s, text))
  113. clean = []
  114. seen = set()
  115. for d in domains:
  116. d = str(d).strip().lower().rstrip(".")
  117. if (DOMAIN_RE.match(d) or IPV4_RE.match(d)) and d not in seen:
  118. seen.add(d)
  119. clean.append(d)
  120. return clean
  121. def parse_timezone(tz_raw):
  122. if tz_raw is None:
  123. return dt.timezone.utc
  124. s = str(tz_raw).strip().upper()
  125. if s in {"", "UTC", "Z", "+00:00", "+0000"}:
  126. return dt.timezone.utc
  127. m = re.match(r"^([+-])(\d{2}):?(\d{2})$", s)
  128. if not m:
  129. raise ValueError(f"invalid created_time_timezone: {tz_raw}")
  130. sign = 1 if m.group(1) == "+" else -1
  131. hh = int(m.group(2))
  132. mm = int(m.group(3))
  133. if hh > 23 or mm > 59:
  134. raise ValueError(f"invalid created_time_timezone offset: {tz_raw}")
  135. return dt.timezone(sign * dt.timedelta(hours=hh, minutes=mm))
  136. def parse_created_time(value, formats, timezone):
  137. if value is None:
  138. return None
  139. s = str(value).strip()
  140. if not s:
  141. return None
  142. for fmt in formats:
  143. try:
  144. parsed = dt.datetime.strptime(s, fmt)
  145. if parsed.tzinfo is None:
  146. parsed = parsed.replace(tzinfo=timezone)
  147. return parsed.astimezone(dt.timezone.utc)
  148. except Exception:
  149. continue
  150. try:
  151. iso_text = s.replace("Z", "+00:00")
  152. parsed = dt.datetime.fromisoformat(iso_text)
  153. if parsed.tzinfo is None:
  154. parsed = parsed.replace(tzinfo=timezone)
  155. return parsed.astimezone(dt.timezone.utc)
  156. except Exception:
  157. return None
  158. def normalize_domain(value):
  159. if value is None:
  160. return ""
  161. return str(value).strip().lower().rstrip(".")
  162. def to_float_or_none(value):
  163. try:
  164. f = float(value)
  165. if not math.isfinite(f):
  166. return None
  167. return f
  168. except Exception:
  169. return None
  170. def resolve_field(record, field_name, field_map):
  171. path = field_map.get(field_name)
  172. if not path:
  173. raise ValueError(f"field '{field_name}' is not registered in record_mapping.field_map")
  174. if not isinstance(record, dict):
  175. return None
  176. return get_by_json_path(record, path)
  177. def extract_records(payload, record_mapping):
  178. records_path = str(record_mapping.get("records_path", "")).strip()
  179. raw = get_values_by_path(payload, records_path)
  180. return [x for x in raw if isinstance(x, dict)]
  181. def validate_config(cfg):
  182. record_mapping = cfg.get("record_mapping")
  183. if not isinstance(record_mapping, dict):
  184. raise ValueError("record_mapping is required and must be an object")
  185. records_path = str(record_mapping.get("records_path", "")).strip()
  186. if not records_path:
  187. raise ValueError("record_mapping.records_path is required")
  188. field_map = record_mapping.get("field_map")
  189. if not isinstance(field_map, dict) or not field_map:
  190. raise ValueError("record_mapping.field_map is required and must be a non-empty object")
  191. for key, path in field_map.items():
  192. if not str(key).strip() or not str(path).strip():
  193. raise ValueError("record_mapping.field_map contains empty field name or path")
  194. for required in ["domain", "created_at"]:
  195. if required not in field_map:
  196. raise ValueError(f"record_mapping.field_map.{required} is required")
  197. created_time_formats = record_mapping.get("created_time_formats")
  198. if not isinstance(created_time_formats, list) or not created_time_formats:
  199. raise ValueError("record_mapping.created_time_formats is required and must be a non-empty array")
  200. for fmt in created_time_formats:
  201. if not str(fmt).strip():
  202. raise ValueError("record_mapping.created_time_formats contains empty format")
  203. parse_timezone(record_mapping.get("created_time_timezone", "UTC"))
  204. def ensure_field_registered(field_name, where):
  205. if field_name not in field_map:
  206. raise ValueError(f"{where}: field '{field_name}' is not in record_mapping.field_map")
  207. record_filter = cfg.get("record_filter", {})
  208. if record_filter.get("enabled", False):
  209. rules = record_filter.get("exclude_if_any", [])
  210. if not isinstance(rules, list):
  211. raise ValueError("record_filter.exclude_if_any must be an array")
  212. for i, rule in enumerate(rules):
  213. if not isinstance(rule, dict):
  214. raise ValueError(f"record_filter.exclude_if_any[{i}] must be an object")
  215. field_name = str(rule.get("field", "")).strip()
  216. if not field_name:
  217. raise ValueError(f"record_filter.exclude_if_any[{i}].field is required")
  218. ensure_field_registered(field_name, f"record_filter.exclude_if_any[{i}]")
  219. has_matcher = any(k in rule for k in ["contains", "equals", "regex"])
  220. if not has_matcher:
  221. raise ValueError(f"record_filter.exclude_if_any[{i}] must include one of contains/equals/regex")
  222. scoring = cfg.get("scoring", {})
  223. if scoring.get("enabled", False):
  224. strategy = str(scoring.get("strategy", "")).strip()
  225. if strategy not in {"weighted_average", "lexicographic"}:
  226. raise ValueError("scoring.strategy must be 'weighted_average' or 'lexicographic'")
  227. within_hours = to_float_or_none(scoring.get("within_hours", 24))
  228. if within_hours is None or within_hours <= 0:
  229. raise ValueError("scoring.within_hours must be a positive number")
  230. if strategy == "weighted_average":
  231. weighted_fields = scoring.get("weighted_fields")
  232. if not isinstance(weighted_fields, list) or not weighted_fields:
  233. raise ValueError("scoring.weighted_fields is required for weighted_average strategy")
  234. for i, item in enumerate(weighted_fields):
  235. if not isinstance(item, dict):
  236. raise ValueError(f"scoring.weighted_fields[{i}] must be an object")
  237. field_name = str(item.get("field", "")).strip()
  238. if not field_name:
  239. raise ValueError(f"scoring.weighted_fields[{i}].field is required")
  240. ensure_field_registered(field_name, f"scoring.weighted_fields[{i}]")
  241. weight = to_float_or_none(item.get("weight"))
  242. if weight is None or weight <= 0:
  243. raise ValueError(f"scoring.weighted_fields[{i}].weight must be > 0")
  244. if strategy == "lexicographic":
  245. lex_fields = scoring.get("lexicographic_fields")
  246. if not isinstance(lex_fields, list) or not lex_fields:
  247. raise ValueError("scoring.lexicographic_fields is required for lexicographic strategy")
  248. for i, item in enumerate(lex_fields):
  249. if isinstance(item, str):
  250. field_name = item.strip()
  251. order = ""
  252. elif isinstance(item, dict):
  253. field_name = str(item.get("field", "")).strip()
  254. order = str(item.get("order", "")).strip().lower()
  255. else:
  256. raise ValueError(f"scoring.lexicographic_fields[{i}] must be string or object")
  257. if not field_name:
  258. raise ValueError(f"scoring.lexicographic_fields[{i}] field is required")
  259. ensure_field_registered(field_name, f"scoring.lexicographic_fields[{i}]")
  260. if order and order not in {"asc", "desc"}:
  261. raise ValueError(f"scoring.lexicographic_fields[{i}].order must be asc or desc")
  262. tie_breakers = scoring.get("tie_breakers", [])
  263. if tie_breakers is not None:
  264. if not isinstance(tie_breakers, list):
  265. raise ValueError("scoring.tie_breakers must be an array")
  266. for i, item in enumerate(tie_breakers):
  267. if not isinstance(item, dict):
  268. raise ValueError(f"scoring.tie_breakers[{i}] must be an object")
  269. field_name = str(item.get("field", "")).strip()
  270. order = str(item.get("order", "")).strip().lower()
  271. if not field_name:
  272. raise ValueError(f"scoring.tie_breakers[{i}].field is required")
  273. if order not in {"asc", "desc"}:
  274. raise ValueError(f"scoring.tie_breakers[{i}].order must be asc or desc")
  275. ensure_field_registered(field_name, f"scoring.tie_breakers[{i}]")
  276. def rule_matches(value, rule):
  277. if value is None or not isinstance(rule, dict):
  278. return False
  279. values = flatten_values(value)
  280. if not values:
  281. values = [value]
  282. case_sensitive = bool(rule.get("case_sensitive", False))
  283. if "contains" in rule:
  284. needle = str(rule.get("contains", ""))
  285. if not needle:
  286. return False
  287. for item in values:
  288. hay = str(item)
  289. if case_sensitive:
  290. if needle in hay:
  291. return True
  292. else:
  293. if needle.lower() in hay.lower():
  294. return True
  295. return False
  296. if "equals" in rule:
  297. target = str(rule.get("equals", ""))
  298. for item in values:
  299. item_s = str(item)
  300. if case_sensitive:
  301. if item_s == target:
  302. return True
  303. else:
  304. if item_s.lower() == target.lower():
  305. return True
  306. return False
  307. if "regex" in rule:
  308. pattern = str(rule.get("regex", ""))
  309. if not pattern:
  310. return False
  311. flags = 0 if case_sensitive else re.IGNORECASE
  312. try:
  313. rx = re.compile(pattern, flags)
  314. except Exception:
  315. return False
  316. for item in values:
  317. if rx.search(str(item)):
  318. return True
  319. return False
  320. return False
  321. def collect_excluded_domains(records, field_map, record_filter_cfg):
  322. if not record_filter_cfg.get("enabled", False):
  323. return set()
  324. rules = record_filter_cfg.get("exclude_if_any", [])
  325. if not rules:
  326. return set()
  327. blocked = set()
  328. for record in records:
  329. domain = normalize_domain(resolve_field(record, "domain", field_map))
  330. if not domain:
  331. continue
  332. for rule in rules:
  333. field_name = str(rule.get("field", "")).strip()
  334. if not field_name:
  335. continue
  336. value = resolve_field(record, field_name, field_map)
  337. if rule_matches(value, rule):
  338. blocked.add(domain)
  339. break
  340. return blocked
  341. def build_lexicographic_descriptors(scoring_cfg, prefer_lower):
  342. out = []
  343. for item in scoring_cfg.get("lexicographic_fields", []):
  344. if isinstance(item, str):
  345. field_name = item.strip()
  346. order = "asc" if prefer_lower else "desc"
  347. else:
  348. field_name = str(item.get("field", "")).strip()
  349. order = str(item.get("order", "")).strip().lower()
  350. if not order:
  351. order = "asc" if prefer_lower else "desc"
  352. out.append({"field": field_name, "order": order})
  353. return out
  354. def parse_scored_records(records, field_map, record_mapping_cfg, scoring_cfg):
  355. if not scoring_cfg.get("enabled", False):
  356. return []
  357. strategy = str(scoring_cfg.get("strategy", "weighted_average")).strip()
  358. prefer_lower = bool(scoring_cfg.get("prefer_lower", False))
  359. timezone = parse_timezone(record_mapping_cfg.get("created_time_timezone", "UTC"))
  360. time_formats = [str(x) for x in record_mapping_cfg.get("created_time_formats", [])]
  361. weighted_fields = scoring_cfg.get("weighted_fields", []) if strategy == "weighted_average" else []
  362. lex_descriptors = build_lexicographic_descriptors(scoring_cfg, prefer_lower) if strategy == "lexicographic" else []
  363. needed_fields = set()
  364. for item in weighted_fields:
  365. needed_fields.add(str(item.get("field", "")).strip())
  366. for item in lex_descriptors:
  367. needed_fields.add(str(item.get("field", "")).strip())
  368. for item in scoring_cfg.get("tie_breakers", []):
  369. needed_fields.add(str(item.get("field", "")).strip())
  370. needed_fields.discard("domain")
  371. needed_fields.discard("created_at")
  372. out = []
  373. for record in records:
  374. domain = normalize_domain(resolve_field(record, "domain", field_map))
  375. if not domain:
  376. continue
  377. created_raw = resolve_field(record, "created_at", field_map)
  378. created_at = parse_created_time(created_raw, time_formats, timezone)
  379. field_values = {}
  380. for field_name in needed_fields:
  381. field_values[field_name] = resolve_field(record, field_name, field_map)
  382. score_value = None
  383. scores = []
  384. lex_values = []
  385. if strategy == "weighted_average":
  386. total = 0.0
  387. total_weight = 0.0
  388. missing = False
  389. for item in weighted_fields:
  390. field_name = str(item.get("field", "")).strip()
  391. weight = float(item.get("weight"))
  392. raw_v = resolve_field(record, field_name, field_map)
  393. val = to_float_or_none(raw_v)
  394. scores.append(val)
  395. if val is None:
  396. missing = True
  397. continue
  398. total += val * weight
  399. total_weight += weight
  400. if not missing and total_weight > 0:
  401. score_value = total / total_weight
  402. if strategy == "lexicographic":
  403. for item in lex_descriptors:
  404. field_name = item["field"]
  405. order = item["order"]
  406. raw_v = resolve_field(record, field_name, field_map)
  407. num_v = to_float_or_none(raw_v)
  408. v = num_v if num_v is not None else raw_v
  409. lex_values.append({"field": field_name, "value": v, "order": order})
  410. scores.append(v)
  411. out.append(
  412. {
  413. "domain": domain,
  414. "created_at": created_at,
  415. "created_raw": created_raw,
  416. "scores": scores,
  417. "score_value": score_value,
  418. "lex_values": lex_values,
  419. "field_values": field_values,
  420. }
  421. )
  422. return out
  423. def cmp_scalar(a, b, order):
  424. a_none = a is None
  425. b_none = b is None
  426. if a_none and b_none:
  427. return 0
  428. if a_none:
  429. return 1
  430. if b_none:
  431. return -1
  432. if isinstance(a, dt.datetime):
  433. a = a.timestamp()
  434. if isinstance(b, dt.datetime):
  435. b = b.timestamp()
  436. a_num = to_float_or_none(a)
  437. b_num = to_float_or_none(b)
  438. if a_num is not None and b_num is not None:
  439. if a_num < b_num:
  440. base = -1
  441. elif a_num > b_num:
  442. base = 1
  443. else:
  444. base = 0
  445. else:
  446. a_s = str(a).lower()
  447. b_s = str(b).lower()
  448. if a_s < b_s:
  449. base = -1
  450. elif a_s > b_s:
  451. base = 1
  452. else:
  453. base = 0
  454. return base if order == "asc" else -base
  455. def get_sort_field_value(record, field_name):
  456. if field_name == "domain":
  457. return record.get("domain")
  458. if field_name == "created_at":
  459. return record.get("created_at")
  460. return record.get("field_values", {}).get(field_name)
  461. def rank_scored_records(records, scoring_cfg):
  462. if not records:
  463. return []
  464. within_hours = float(scoring_cfg.get("within_hours", 24))
  465. strategy = str(scoring_cfg.get("strategy", "weighted_average")).strip()
  466. prefer_lower = bool(scoring_cfg.get("prefer_lower", False))
  467. tie_breakers = scoring_cfg.get("tie_breakers", [])
  468. now = dt.datetime.now(dt.timezone.utc)
  469. cutoff = now - dt.timedelta(hours=within_hours)
  470. recent = [r for r in records if r.get("created_at") is not None and r["created_at"] >= cutoff]
  471. candidates = recent if recent else records
  472. default_lex_order = "asc" if prefer_lower else "desc"
  473. def compare(a, b):
  474. if strategy == "weighted_average":
  475. order = "asc" if prefer_lower else "desc"
  476. c = cmp_scalar(a.get("score_value"), b.get("score_value"), order)
  477. if c != 0:
  478. return c
  479. elif strategy == "lexicographic":
  480. a_lex = a.get("lex_values", [])
  481. b_lex = b.get("lex_values", [])
  482. n = max(len(a_lex), len(b_lex))
  483. for i in range(n):
  484. av = a_lex[i]["value"] if i < len(a_lex) else None
  485. bv = b_lex[i]["value"] if i < len(b_lex) else None
  486. order = default_lex_order
  487. if i < len(a_lex) and a_lex[i].get("order"):
  488. order = a_lex[i]["order"]
  489. c = cmp_scalar(av, bv, order)
  490. if c != 0:
  491. return c
  492. for item in tie_breakers:
  493. field_name = str(item.get("field", "")).strip()
  494. order = str(item.get("order", "asc")).strip().lower()
  495. av = get_sort_field_value(a, field_name)
  496. bv = get_sort_field_value(b, field_name)
  497. c = cmp_scalar(av, bv, order)
  498. if c != 0:
  499. return c
  500. return cmp_scalar(a.get("domain"), b.get("domain"), "asc")
  501. return sorted(candidates, key=functools.cmp_to_key(compare))
  502. def apply_filter(domains, filter_cfg):
  503. include_suffixes = [s.lower() for s in filter_cfg.get("include_suffixes", []) if s]
  504. exclude_regex = [re.compile(x) for x in filter_cfg.get("exclude_regex", []) if x]
  505. out = []
  506. for d in domains:
  507. if include_suffixes and not any(d.endswith(s) for s in include_suffixes):
  508. continue
  509. if any(rx.search(d) for rx in exclude_regex):
  510. continue
  511. out.append(d)
  512. return out
  513. def single_tls_check(domain, timeout_ms, port, tls_verify=True):
  514. start = time.perf_counter()
  515. timeout_sec = max(0.2, timeout_ms / 1000.0)
  516. try:
  517. infos = socket.getaddrinfo(domain, port, proto=socket.IPPROTO_TCP)
  518. if not infos:
  519. return False, None, "dns_empty"
  520. af, socktype, proto, _, sockaddr = infos[0]
  521. with socket.socket(af, socktype, proto) as sock:
  522. sock.settimeout(timeout_sec)
  523. sock.connect(sockaddr)
  524. if tls_verify:
  525. ctx = ssl.create_default_context()
  526. else:
  527. ctx = ssl._create_unverified_context()
  528. with ctx.wrap_socket(sock, server_hostname=domain) as ssock:
  529. ssock.do_handshake()
  530. elapsed = int((time.perf_counter() - start) * 1000)
  531. return True, elapsed, "ok"
  532. except Exception as e:
  533. return False, None, str(e)
  534. def check_domains(domains, hc_cfg):
  535. attempts = int(hc_cfg.get("attempts", 2))
  536. timeout_ms = int(hc_cfg.get("timeout_ms", 1800))
  537. port = int(hc_cfg.get("port", 443))
  538. tls_verify = bool(hc_cfg.get("tls_verify", True))
  539. results = []
  540. for d in domains:
  541. ok_count = 0
  542. latencies = []
  543. errors = []
  544. for _ in range(attempts):
  545. ok, latency, err = single_tls_check(d, timeout_ms, port, tls_verify=tls_verify)
  546. if ok:
  547. ok_count += 1
  548. latencies.append(latency)
  549. else:
  550. errors.append(err)
  551. success_ratio = ok_count / attempts if attempts else 0.0
  552. avg_latency = int(sum(latencies) / len(latencies)) if latencies else 999999
  553. results.append(
  554. {
  555. "domain": d,
  556. "success_ratio": success_ratio,
  557. "avg_latency_ms": avg_latency,
  558. "ok_count": ok_count,
  559. "attempts": attempts,
  560. "errors": errors[:3],
  561. }
  562. )
  563. results.sort(key=lambda x: (-x["success_ratio"], x["avg_latency_ms"], x["domain"]))
  564. return results
  565. def render_v2ray(template_file, output_file, token, domain):
  566. if not template_file or not output_file:
  567. return False
  568. if not os.path.exists(template_file):
  569. return False
  570. with open(template_file, "r", encoding="utf-8") as f:
  571. tpl = f.read()
  572. rendered = tpl.replace(token, domain)
  573. os.makedirs(os.path.dirname(output_file), exist_ok=True)
  574. with open(output_file, "w", encoding="utf-8") as f:
  575. f.write(rendered)
  576. return True
  577. def run_notify(cmd, domain, status):
  578. if not cmd:
  579. return
  580. env = os.environ.copy()
  581. env["AUTODOMAIN"] = domain
  582. env["AUTODOMAIN_STATUS"] = status
  583. subprocess.run(cmd, shell=True, check=False, env=env)
  584. def choose_domain(filtered_domains, check_results, top_n, ranked_scored):
  585. if ranked_scored:
  586. domains_by_score = [x["domain"] for x in ranked_scored]
  587. if check_results:
  588. check_map = {x["domain"]: x for x in check_results}
  589. top = []
  590. for d in domains_by_score:
  591. if d in check_map and check_map[d]["success_ratio"] > 0:
  592. top.append(check_map[d])
  593. if len(top) >= top_n:
  594. break
  595. if top:
  596. return top[0]["domain"], top
  597. score_only = [
  598. {
  599. "domain": x["domain"],
  600. "score_value": x.get("score_value"),
  601. "scores": x.get("scores", []),
  602. "created_raw": x.get("created_raw"),
  603. }
  604. for x in ranked_scored[:top_n]
  605. ]
  606. return score_only[0]["domain"], score_only
  607. top_scored = [
  608. {
  609. "domain": x["domain"],
  610. "score_value": x.get("score_value"),
  611. "scores": x.get("scores", []),
  612. "created_raw": x.get("created_raw"),
  613. }
  614. for x in ranked_scored[:top_n]
  615. ]
  616. if top_scored:
  617. return top_scored[0]["domain"], top_scored
  618. if check_results:
  619. top = [x for x in check_results if x["success_ratio"] > 0][:top_n]
  620. if top:
  621. return top[0]["domain"], top
  622. return None, check_results[:top_n]
  623. if filtered_domains:
  624. return filtered_domains[0], [{"domain": x} for x in filtered_domains[:top_n]]
  625. return None, []
  626. def main():
  627. ap = argparse.ArgumentParser(description="Auto select VMess preferred domain")
  628. ap.add_argument("--config", default="config.json", help="Path to config JSON")
  629. args = ap.parse_args()
  630. config_path_abs = os.path.abspath(args.config)
  631. if not os.path.exists(config_path_abs):
  632. print(json.dumps({"status": "error", "error": f"config file not found: {config_path_abs}"}, ensure_ascii=True), file=sys.stderr)
  633. sys.exit(1)
  634. cfg = read_json_file(config_path_abs)
  635. try:
  636. validate_config(cfg)
  637. except Exception as e:
  638. print(json.dumps({"status": "error", "error": f"invalid config: {e}"}, ensure_ascii=True), file=sys.stderr)
  639. sys.exit(1)
  640. output_cfg = cfg.get("output", {})
  641. runtime_dir_cfg = output_cfg.get("runtime_dir", "./runtime")
  642. if os.path.isabs(runtime_dir_cfg):
  643. runtime_dir = runtime_dir_cfg
  644. else:
  645. runtime_dir = os.path.normpath(os.path.join(os.path.dirname(config_path_abs), runtime_dir_cfg))
  646. v2_cfg = cfg.get("v2ray", {})
  647. notify_cfg = cfg.get("notify", {})
  648. current_domain_file = os.path.join(runtime_dir, output_cfg.get("current_domain_file", "current_domain.txt"))
  649. current_domain_json = os.path.join(runtime_dir, output_cfg.get("current_domain_json", "current_domain.json"))
  650. state_file = os.path.join(runtime_dir, output_cfg.get("state_file", "state.json"))
  651. substore_vars_file = os.path.join(runtime_dir, output_cfg.get("substore_vars_file", "substore_vars.json"))
  652. state = read_json_file(state_file, default={})
  653. last_good = state.get("last_good_domain", "")
  654. try:
  655. payload = fetch_api_json(cfg)
  656. parsed = parse_domains(payload, cfg.get("parser", {}))
  657. filtered = apply_filter(parsed, cfg.get("domain_filter", {}))
  658. record_mapping_cfg = cfg.get("record_mapping", {})
  659. field_map = record_mapping_cfg.get("field_map", {})
  660. records = extract_records(payload, record_mapping_cfg)
  661. record_filter_cfg = cfg.get("record_filter", {})
  662. blocked_domains = collect_excluded_domains(records, field_map, record_filter_cfg)
  663. if blocked_domains:
  664. filtered = [d for d in filtered if d not in blocked_domains]
  665. scoring_cfg = cfg.get("scoring", {})
  666. scored_records = parse_scored_records(records, field_map, record_mapping_cfg, scoring_cfg)
  667. filtered_set = set(filtered)
  668. scored_records = [r for r in scored_records if r["domain"] in filtered_set]
  669. ranked_scored = rank_scored_records(scored_records, scoring_cfg)
  670. check_results = []
  671. if cfg.get("healthcheck", {}).get("enabled", True):
  672. check_results = check_domains(filtered, cfg.get("healthcheck", {}))
  673. top_n = int(cfg.get("selection", {}).get("top_n", 3))
  674. selected, top_candidates = choose_domain(filtered, check_results, top_n, ranked_scored)
  675. status = "ok"
  676. if not selected and last_good:
  677. selected = last_good
  678. status = "fallback_last_good"
  679. if not selected:
  680. raise RuntimeError("No valid domain available from API and no fallback in state")
  681. write_text_file(current_domain_file, selected + "\n")
  682. current_json = {
  683. "domain": selected,
  684. "updated_at": utc_now_iso(),
  685. "status": status,
  686. "source_count": len(parsed),
  687. "checked_count": len(check_results),
  688. "top_candidates": top_candidates,
  689. }
  690. write_json_file(current_domain_json, current_json)
  691. write_json_file(
  692. substore_vars_file,
  693. {
  694. "AUTO_DOMAIN": selected,
  695. "UPDATED_AT": current_json["updated_at"],
  696. "STATUS": status,
  697. },
  698. )
  699. rendered = render_v2ray(
  700. template_file=v2_cfg.get("template_file", ""),
  701. output_file=v2_cfg.get("output_file", ""),
  702. token=v2_cfg.get("replace_token", "__AUTO_DOMAIN__"),
  703. domain=selected,
  704. )
  705. new_state = {
  706. "updated_at": current_json["updated_at"],
  707. "last_good_domain": selected,
  708. "status": status,
  709. "source_count": len(parsed),
  710. "checked_count": len(check_results),
  711. "rendered_v2ray": rendered,
  712. }
  713. write_json_file(state_file, new_state)
  714. run_notify(notify_cfg.get("command", ""), selected, status)
  715. print(json.dumps(current_json, ensure_ascii=True))
  716. except Exception as e:
  717. now = utc_now_iso()
  718. err_state = {
  719. "updated_at": now,
  720. "status": "error",
  721. "error": str(e),
  722. "last_good_domain": last_good,
  723. }
  724. write_json_file(state_file, err_state)
  725. if last_good:
  726. write_text_file(current_domain_file, last_good + "\n")
  727. write_json_file(
  728. current_domain_json,
  729. {
  730. "domain": last_good,
  731. "updated_at": now,
  732. "status": "error_use_last_good",
  733. "error": str(e),
  734. },
  735. )
  736. run_notify(notify_cfg.get("command", ""), last_good, "error_use_last_good")
  737. print(json.dumps({"status": "error_use_last_good", "error": str(e)}, ensure_ascii=True))
  738. return
  739. print(json.dumps({"status": "error", "error": str(e)}, ensure_ascii=True), file=sys.stderr)
  740. sys.exit(1)
# Script entry point.
if __name__ == "__main__":
    main()