From aa9334d37281faec3d2723a8cbec95ea3c4ff536 Mon Sep 17 00:00:00 2001
From: Oliver Gaskell
Date: Sun, 2 Nov 2025 12:40:31 +0000
Subject: [PATCH] engine: refactoring, caching

encode() and decode() each rebuilt the full pairwise distance matrix on
every call.  Move that computation into get_dist_matrix() and the
per-row sorted series into get_dist_series_list(); both memoise their
result in module-level dicts keyed by brand (the series cache holds one
n_locs per brand and is rebuilt when a different n_locs is requested).

Extract the per-location work into encode_greggs() and decode_greggs()
so encode() and decode() reduce to comprehensions over them.

trilaterate() now uses scipy.optimize.minimize (Nelder-Mead), starts
the search at the first station's position instead of (0., 0.), and
raises ValueError when the optimiser reports failure.  The `disp` flag
is forwarded through `options` so callers passing disp=True still get
the optimiser's convergence output.

get_dist_series_list() looks up the matrix for the brand it was asked
about rather than a hard-coded "greggs".
---
Note: the post-image blob id on the "index" line below predates the
`brand`, `disp` and docstring adjustments; it is informational only.

 engine.py | 117 ++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 75 insertions(+), 42 deletions(-)

diff --git a/engine.py b/engine.py
index d843e26..4dea71a 100755
--- a/engine.py
+++ b/engine.py
@@ -25,6 +25,9 @@ FIRST_SEP = ':'
 OTHER_SEP = ','
 LOC_SEP = ';'
 
+cached_dists: dict[str, np.ndarray] = {}
+cached_series: dict[str, tuple[int, list[pd.Series]]] = {}
+
 EncodedLocation = list[tuple[float, list[float]]]
 
 
@@ -85,6 +88,39 @@ def spherical_dist(pos1, pos2, r=6378137):
     return r * np.arccos(cos_lat_d - cos_lat1 * cos_lat2 * (1 - cos_lon_d))
 
 
+def get_dist_matrix(brand: str):
+    """Return the pairwise distance matrix for a brand, computing it on first use."""
+    try:
+        return cached_dists[brand]
+    except KeyError:
+        points = np.array(fetch_data(brand))
+
+        repeat_rows = np.tile(points, (len(points), 1, 1))
+        repeat_cols = np.transpose(repeat_rows, (1, 0, 2))
+        dist_matrix = spherical_dist(repeat_rows, repeat_cols)
+
+        cached_dists[brand] = dist_matrix
+
+        return dist_matrix
+
+
+def get_dist_series_list(brand: str, n_locs: int):
+    """Return each location's n_locs smallest distances, skipping the first (self) entry."""
+    if brand in cached_series and cached_series[brand][0] == n_locs:
+        return cached_series[brand][1]
+
+    dist_matrix = get_dist_matrix(brand)
+
+    # split the distances matrix into a list of series, which allows us to sort each row
+    dist_series_list = []
+    for i in dist_matrix:
+        dist_series_list.append(pd.Series(i).sort_values().head(n_locs+1)[1:])
+
+    cached_series[brand] = (n_locs, dist_series_list)
+
+    return dist_series_list
+
+
 # (lat, lon), dist
 StationT = tuple[tuple[float, float], float]
 
@@ -102,7 +138,40 @@ def trilaterate(stations: list[StationT], disp: bool = False) -> tuple[float, fl
     Each station is of the format ((lat, lon), distance).
     """
 
-    return scipy.optimize.fmin(lambda pos: trilat_error(stations, pos), (0., 0.), disp=disp)
+    res = scipy.optimize.minimize(
+        lambda pos: trilat_error(stations, pos),
+        stations[0][0],
+        method='Nelder-Mead',
+        options={'disp': disp},
+    )
+
+    if res.success:
+        return res.x
+    else:
+        raise ValueError("Optimisation failed.")
+
+
+def encode_greggs(loc: int) -> list[float]:
+    """Given the id of a greggs, encode as a list of distances."""
+
+    dist_matrix = get_dist_matrix("greggs")
+    greggs_distances = np.sort(dist_matrix[loc])[1:DISTS_COUNT+1]
+
+    return list(map(float, greggs_distances))
+
+
+def decode_greggs(distances: list[float]) -> int:
+    """Get the id of a greggs given a list of distances."""
+
+    dist_series_list = get_dist_series_list("greggs", len(distances))
+
+    errors = [sum((pd.Series(j) - distances) ** 2) for j in dist_series_list]
+
+    minerr = min(errors)
+    if minerr > 1:
+        print(f"warning: high error value of {minerr}")
+
+    return errors.index(minerr)
 
 
 def encode(location: tuple[float, float]) -> EncodedLocation:
@@ -110,57 +179,21 @@ def encode(location: tuple[float, float]) -> EncodedLocation:
 
     greggs = np.array(fetch_data("greggs"))
 
-    repeat_rows = np.tile(greggs, (len(greggs), 1, 1))
-    repeat_cols = np.transpose(repeat_rows, (1, 0, 2))
-    dist_matrix = spherical_dist(repeat_rows, repeat_cols)
-
     repeated = np.tile(location, (len(greggs), 1))
-    distances = spherical_dist(repeated, greggs)
 
-    distances = pd.Series(distances)
-    distances = distances.sort_values()
-
+    distances = pd.Series(spherical_dist(repeated, greggs)).sort_values()
     closest = distances.head(LOCS_COUNT)
 
-    result: EncodedLocation = []
-    for v, i in zip(closest.values, closest.index):
-        greggs_distances = np.sort(dist_matrix[i])[1:DISTS_COUNT+1]
-
-        result.append((v, list(map(float, greggs_distances))))
-
-    # Stub
-    return result
+    return [(v, encode_greggs(i)) for v, i in zip(closest.values, closest.index)]
 
 
 def decode(location: EncodedLocation, disp: bool = False) -> tuple[float, float]:
     """Decode into a location."""
-
-    # form the distances matrix
-    greggs_raw = fetch_data("greggs")
-    greggs = np.array(greggs_raw)
-    repeat_rows = np.tile(greggs, (len(greggs), 1, 1))
-    repeat_cols = np.transpose(repeat_rows, (1, 0, 2))
-    dist_matrix = spherical_dist(repeat_rows, repeat_cols)
-
-    # split the distances matrix into a list of series, which allows us to sort each row
-    dist_series_list = []
-    for i in dist_matrix:
-        dist_series_list.append(pd.Series(i).sort_values().head(len(location[0][1])+1)[1:])
-
     # part 1: find the ID of each gregg's
-    closest_greggs = []
-    for loc in location:
-        dists = loc[1]
-
-        errors = [sum((j - dists) ** 2) for j in dist_series_list]
-
-        minerr = min(errors)
-        if minerr > 1:
-            print(f"warning: high error value of {minerr}")
-
-        closest_greggs.append(errors.index(min(errors)))
+    closest_greggs = [decode_greggs(dists) for _, dists in location]
 
     # part 2: trilaterate
-    stations: list[StationT] = [(greggs_raw[g], location[i][0]) for i, g in enumerate(closest_greggs)]
+    greggs = fetch_data("greggs")
+    stations: list[StationT] = [(greggs[g], location[i][0]) for i, g in enumerate(closest_greggs)]
     return trilaterate(stations, disp=disp)
 