From cc02dd122bb5503edebdb2957d22ba70f670923b Mon Sep 17 00:00:00 2001 From: Oliver Gaskell Date: Sun, 2 Nov 2025 11:11:29 +0000 Subject: [PATCH] engine_test: create testing script --- engine_test.py | 42 ++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 1 + 2 files changed, 43 insertions(+) create mode 100755 engine_test.py diff --git a/engine_test.py b/engine_test.py new file mode 100755 index 0000000..544d75b --- /dev/null +++ b/engine_test.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python + +from engine import encode, decode, format_location, parse_location, spherical_dist, fetch_data +from tqdm import tqdm +import numpy as np + +DIST_THRESH = 0.1 + + +def gen_locations() -> list[tuple[float, float]]: + """Generate a list of locations for testing.""" + # return [(52.210796, 0.091659)] + return fetch_data("tesco") + + +def test_encode_decode(): + locations = gen_locations() + passed = 0 + fails = 0 + dists = [] + + try: + for loc in tqdm(locations): + encoded = format_location(encode(loc)) + decoded = decode(parse_location(encoded)) + dist = spherical_dist(np.array(loc), np.array(decoded)) + dists.append(dist) + + if dist > DIST_THRESH: + tqdm.write(f"FAIL:\n\tloc: {loc}\n\tdecoded: {decoded}\n\tdist: {dist}") + fails += 1 + else: + passed += 1 + except KeyboardInterrupt: + print("KeyboardInterrupt - halting...") + + avg_dist = sum(dists) / len(dists) + print(f"\nDONE. Passed: {passed}. Failed: {fails}. Average error: {avg_dist:.3f}m") + + +if __name__ == "__main__": + test_encode_decode() diff --git a/requirements.txt b/requirements.txt index b6e53de..11fd1e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ django numpy pandas scipy +tqdm