engine: implement encode, fetch_data caching, format improvements

This commit is contained in:
Oliver Gaskell 2025-11-01 18:36:42 +00:00
parent 78228cb747
commit 3255e2e30f
No known key found for this signature in database
GPG key ID: F971A08925FCC0AD

View file

@ -1,23 +1,40 @@
#!/usr/bin/env python #!/usr/bin/env python
import json
import os
import overpy import overpy
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from tqdm import tqdm from pathlib import Path
# brandname : overpass query filters # brandname : overpass query filters
BRANDS: dict[str, str] = { BRANDS: dict[str, str] = {
"greggs": "[\"brand:wikidata\"=\"Q3403981\"]", "greggs": "[\"brand:wikidata\"=\"Q3403981\"]",
"tesco": "[\"brand:wikidata\"~\"^(Q487494|Q98456772|Q25172225|Q65954217)$\"]", # Includes Tesco Express, Tesco Extra, and One Stop
} }
DATA_FOLDER = "" CACHE_FOLDER = Path(".cache")
LOCS_COUNT = 3
DISTS_COUNT = 100
FORMAT_FACTOR = 1e6 # μm
EncodedLocation = list[tuple[float, list[float]]] EncodedLocation = list[tuple[float, list[float]]]
def fetch_data(brand: str) -> list[tuple[float, float]]: def fetch_data(brand: str, cache: bool = True) -> list[tuple[float, float]]:
"""Fetch a list of locations from OSM.""" """Fetch a list of locations from OSM."""
cache_loc = (CACHE_FOLDER / f"{brand}.json")
# Try load from cache
if cache and cache_loc.exists():
with open(cache_loc, "r") as f:
data = json.load(f)
return data
api = overpy.Overpass() api = overpy.Overpass()
filters = BRANDS[brand] filters = BRANDS[brand]
@ -35,6 +52,16 @@ def fetch_data(brand: str) -> list[tuple[float, float]]:
if (lat is None) or (lon is None): if (lat is None) or (lon is None):
raise ValueError("Item missing coords!") raise ValueError("Item missing coords!")
# Save to cache
if cache:
if not CACHE_FOLDER.exists():
os.makedirs(CACHE_FOLDER)
with open(cache_loc, "w") as f:
json.dump(result, f)
print(f"Got {len(result)} {brand}s")
return result return result
@ -59,28 +86,23 @@ def encode(location: tuple[float, float]) -> EncodedLocation:
repeat_cols = np.transpose(repeat_rows, (1, 0, 2)) repeat_cols = np.transpose(repeat_rows, (1, 0, 2))
dist_matrix = spherical_dist(repeat_rows, repeat_cols) dist_matrix = spherical_dist(repeat_rows, repeat_cols)
#find closest greggs
# distances = pd.Series(np.zeros(len(greggs)))
# for i in range(len(greggs)):
# current = greggs[i]
# #distances[i] = np.sqrt((current[0]-location[0])**2 + (current[1]-location[1])**2
# distances[i] = distance.distance(distance.lonlat(*current), distance.lonlat(*location)).km*1000
repeated = np.tile(location, (len(greggs), 1)) repeated = np.tile(location, (len(greggs), 1))
distances = spherical_dist(repeated, greggs) distances = spherical_dist(repeated, greggs)
distances = pd.Series(distances) distances = pd.Series(distances)
print(distances)
distances = distances.sort_values() distances = distances.sort_values()
top3 = distances.head(3)
print(top3)
closest = distances.head(LOCS_COUNT)
closest_dist = list(closest.values)
closest_ind = list(closest.index)
result: EncodedLocation = []
for v, i in zip(closest.values, closest.index):
greggs_distances = np.sort(dist_matrix[i])[1:DISTS_COUNT+1]
result.append((v, list(map(float, greggs_distances))))
# Stub # Stub
return [ return result
(5., [1., 2., 3.]),
(6., [4., 5., 6.]),
]
def decode(location: EncodedLocation) -> tuple[float, float]: def decode(location: EncodedLocation) -> tuple[float, float]:
@ -90,9 +112,17 @@ def decode(location: EncodedLocation) -> tuple[float, float]:
return (0.091659, 52.210796) return (0.091659, 52.210796)
def format_dist(dist: float) -> str:
return f"{int(round(dist * FORMAT_FACTOR))}"
def parse_dist(dist: str) -> float:
return float(dist) / FORMAT_FACTOR
def format_location(location: EncodedLocation) -> str: def format_location(location: EncodedLocation) -> str:
"""Format an encoded location as a string.""" """Format an encoded location as a string."""
return ";".join([f"{a}:{','.join(map(str, b))}" for (a, b) in location]) return ";\n".join([f"{format_dist(a)}:{','.join(map(format_dist, b))}" for (a, b) in location])
def parse_location(location: str) -> EncodedLocation: def parse_location(location: str) -> EncodedLocation:
@ -107,11 +137,11 @@ def parse_location(location: str) -> EncodedLocation:
def main(): def main():
"""Testing.""" """Testing."""
print("Running query...") #print("Running query...")
#greggs = fetch_data("greggs") #greggs = fetch_data("greggs")
#print(f"Query done - got {len(greggs)} Greggs!") #print(f"Query done - got {len(greggs)} Greggs!")
print(format_location(encode((52.210796, 0.091659)))) # print(format_location(encode((52.210796, 0.091659))))
if __name__ == "__main__": if __name__ == "__main__":