Add GPU-aware launch and testing docs
This commit is contained in:
@@ -7,6 +7,7 @@ import logging
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import os
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
@@ -25,6 +26,49 @@ DATE_FORMATS = (
|
||||
"%B %d, %Y",
|
||||
)
|
||||
|
||||
GPU_ACCEL_ENV = "ENABLE_GPU"
|
||||
|
||||
|
||||
def parse_env_flag(value, default=False):
|
||||
if value is None:
|
||||
return default
|
||||
return str(value).strip().lower() in ("1", "true", "yes", "on")
|
||||
|
||||
|
||||
def detect_gpu_available():
|
||||
env_value = os.getenv(GPU_ACCEL_ENV)
|
||||
if env_value is not None:
|
||||
return parse_env_flag(env_value, default=False)
|
||||
|
||||
nvidia_visible = os.getenv("NVIDIA_VISIBLE_DEVICES")
|
||||
if nvidia_visible and nvidia_visible.lower() not in ("none", "void", "off"):
|
||||
return True
|
||||
|
||||
if os.path.exists("/dev/nvidia0"):
|
||||
return True
|
||||
|
||||
if os.path.exists("/dev/dri/renderD128") or os.path.exists("/dev/dri/card0"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def chromium_launch_args():
|
||||
if not detect_gpu_available():
|
||||
return []
|
||||
|
||||
if os.name == "nt":
|
||||
return ["--enable-gpu"]
|
||||
|
||||
return [
|
||||
"--enable-gpu",
|
||||
"--ignore-gpu-blocklist",
|
||||
"--disable-software-rasterizer",
|
||||
"--use-gl=egl",
|
||||
"--enable-zero-copy",
|
||||
"--enable-gpu-rasterization",
|
||||
]
|
||||
|
||||
|
||||
def parse_date(value):
|
||||
for fmt in DATE_FORMATS:
|
||||
@@ -396,7 +440,12 @@ def scrape_yahoo_options(symbol, expiration=None, strike_limit=25):
|
||||
fallback_to_base = False
|
||||
|
||||
with sync_playwright() as p:
|
||||
browser = p.chromium.launch(headless=True)
|
||||
launch_args = chromium_launch_args()
|
||||
if launch_args:
|
||||
app.logger.info("GPU acceleration enabled")
|
||||
else:
|
||||
app.logger.info("GPU acceleration disabled")
|
||||
browser = p.chromium.launch(headless=True, args=launch_args)
|
||||
page = browser.new_page()
|
||||
page.set_extra_http_headers(
|
||||
{
|
||||
|
||||
199
scripts/test_cycles.py
Normal file
199
scripts/test_cycles.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
|
||||
DEFAULT_STOCKS = ["AAPL", "AMZN", "MSFT", "TSLA"]
|
||||
DEFAULT_CYCLES = [None, 5, 10, 25, 50, 75, 100, 150, 200, 500]
|
||||
|
||||
|
||||
def http_get(base_url, params, timeout):
|
||||
query = urllib.parse.urlencode(params)
|
||||
url = f"{base_url}?{query}"
|
||||
with urllib.request.urlopen(url, timeout=timeout) as resp:
|
||||
return json.loads(resp.read().decode("utf-8"))
|
||||
|
||||
|
||||
def expected_code_from_epoch(epoch):
|
||||
return datetime.datetime.utcfromtimestamp(epoch).strftime("%y%m%d")
|
||||
|
||||
|
||||
def all_contracts_match(opts, expected_code):
|
||||
for opt in opts:
|
||||
name = opt.get("Contract Name") or ""
|
||||
if expected_code not in name:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def parse_list(value, default):
|
||||
if not value:
|
||||
return default
|
||||
return [item.strip() for item in value.split(",") if item.strip()]
|
||||
|
||||
|
||||
def parse_cycles(value):
|
||||
if not value:
|
||||
return DEFAULT_CYCLES
|
||||
cycles = []
|
||||
for item in value.split(","):
|
||||
token = item.strip().lower()
|
||||
if not token or token in ("default", "none"):
|
||||
cycles.append(None)
|
||||
continue
|
||||
try:
|
||||
cycles.append(int(token))
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid strikeLimit value: {item}")
|
||||
return cycles
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Yahoo options scraper test cycles")
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
default="http://127.0.0.1:9777/scrape_sync",
|
||||
help="Base URL for /scrape_sync",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stocks",
|
||||
default=",".join(DEFAULT_STOCKS),
|
||||
help="Comma-separated stock symbols",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--strike-limits",
|
||||
default="default,5,10,25,50,75,100,150,200,500",
|
||||
help="Comma-separated strike limits (use 'default' for the API default)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--baseline-limit",
|
||||
type=int,
|
||||
default=5000,
|
||||
help="Large strikeLimit used to capture all available strikes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=180,
|
||||
help="Request timeout in seconds",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sleep",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="Sleep between requests",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
stocks = parse_list(args.stocks, DEFAULT_STOCKS)
|
||||
cycles = parse_cycles(args.strike_limits)
|
||||
|
||||
print("Fetching expiration lists...")
|
||||
expirations = {}
|
||||
for stock in stocks:
|
||||
data = http_get(args.base_url, {"stock": stock, "expiration": "invalid"}, args.timeout)
|
||||
if "available_expirations" not in data:
|
||||
print(f"ERROR: missing available_expirations for {stock}: {data}")
|
||||
sys.exit(1)
|
||||
values = [opt.get("value") for opt in data["available_expirations"] if opt.get("value")]
|
||||
if len(values) < 4:
|
||||
print(f"ERROR: not enough expirations for {stock}: {values}")
|
||||
sys.exit(1)
|
||||
expirations[stock] = values[:4]
|
||||
print(f" {stock}: {expirations[stock]}")
|
||||
time.sleep(args.sleep)
|
||||
|
||||
print("\nBuilding baseline counts (strikeLimit=%d)..." % args.baseline_limit)
|
||||
baseline_counts = {}
|
||||
for stock, exp_list in expirations.items():
|
||||
for exp in exp_list:
|
||||
data = http_get(
|
||||
args.base_url,
|
||||
{"stock": stock, "expiration": exp, "strikeLimit": args.baseline_limit},
|
||||
args.timeout,
|
||||
)
|
||||
if "error" in data:
|
||||
print(f"ERROR: baseline error for {stock} {exp}: {data}")
|
||||
sys.exit(1)
|
||||
calls_count = data.get("total_calls")
|
||||
puts_count = data.get("total_puts")
|
||||
if calls_count is None or puts_count is None:
|
||||
print(f"ERROR: baseline missing counts for {stock} {exp}: {data}")
|
||||
sys.exit(1)
|
||||
expected_code = expected_code_from_epoch(exp)
|
||||
if not all_contracts_match(data.get("calls", []), expected_code):
|
||||
print(f"ERROR: baseline calls mismatch for {stock} {exp}")
|
||||
sys.exit(1)
|
||||
if not all_contracts_match(data.get("puts", []), expected_code):
|
||||
print(f"ERROR: baseline puts mismatch for {stock} {exp}")
|
||||
sys.exit(1)
|
||||
baseline_counts[(stock, exp)] = (calls_count, puts_count)
|
||||
print(f" {stock} {exp}: calls={calls_count} puts={puts_count}")
|
||||
time.sleep(args.sleep)
|
||||
|
||||
print("\nRunning %d cycles of API tests..." % len(cycles))
|
||||
for idx, strike_limit in enumerate(cycles, start=1):
|
||||
print(f"Cycle {idx}/{len(cycles)} (strikeLimit={strike_limit})")
|
||||
for stock, exp_list in expirations.items():
|
||||
for exp in exp_list:
|
||||
params = {"stock": stock, "expiration": exp}
|
||||
if strike_limit is not None:
|
||||
params["strikeLimit"] = strike_limit
|
||||
data = http_get(args.base_url, params, args.timeout)
|
||||
if "error" in data:
|
||||
print(f"ERROR: {stock} {exp} -> {data}")
|
||||
sys.exit(1)
|
||||
selected_val = data.get("selected_expiration", {}).get("value")
|
||||
if selected_val != exp:
|
||||
print(
|
||||
f"ERROR: selected expiration mismatch for {stock} {exp}: {selected_val}"
|
||||
)
|
||||
sys.exit(1)
|
||||
expected_code = expected_code_from_epoch(exp)
|
||||
if not all_contracts_match(data.get("calls", []), expected_code):
|
||||
print(f"ERROR: calls expiry mismatch for {stock} {exp}")
|
||||
sys.exit(1)
|
||||
if not all_contracts_match(data.get("puts", []), expected_code):
|
||||
print(f"ERROR: puts expiry mismatch for {stock} {exp}")
|
||||
sys.exit(1)
|
||||
available_calls, available_puts = baseline_counts[(stock, exp)]
|
||||
expected_limit = strike_limit if strike_limit is not None else 25
|
||||
expected_calls = min(expected_limit, available_calls)
|
||||
expected_puts = min(expected_limit, available_puts)
|
||||
if data.get("total_calls") != expected_calls:
|
||||
print(
|
||||
f"ERROR: call count mismatch for {stock} {exp}: "
|
||||
f"got {data.get('total_calls')} expected {expected_calls}"
|
||||
)
|
||||
sys.exit(1)
|
||||
if data.get("total_puts") != expected_puts:
|
||||
print(
|
||||
f"ERROR: put count mismatch for {stock} {exp}: "
|
||||
f"got {data.get('total_puts')} expected {expected_puts}"
|
||||
)
|
||||
sys.exit(1)
|
||||
expected_pruned_calls = max(0, available_calls - expected_calls)
|
||||
expected_pruned_puts = max(0, available_puts - expected_puts)
|
||||
if data.get("pruned_calls_count") != expected_pruned_calls:
|
||||
print(
|
||||
f"ERROR: pruned calls mismatch for {stock} {exp}: "
|
||||
f"got {data.get('pruned_calls_count')} expected {expected_pruned_calls}"
|
||||
)
|
||||
sys.exit(1)
|
||||
if data.get("pruned_puts_count") != expected_pruned_puts:
|
||||
print(
|
||||
f"ERROR: pruned puts mismatch for {stock} {exp}: "
|
||||
f"got {data.get('pruned_puts_count')} expected {expected_pruned_puts}"
|
||||
)
|
||||
sys.exit(1)
|
||||
time.sleep(args.sleep)
|
||||
print(f"Cycle {idx} OK")
|
||||
|
||||
print("\nAll cycles completed successfully.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user