engine: complete decoder

This commit is contained in:
Oliver Gaskell 2025-11-02 10:33:19 +00:00
parent fe49a88700
commit b03175f9d9
No known key found for this signature in database
GPG key ID: F971A08925FCC0AD
2 changed files with 32 additions and 28 deletions

View file

@ -5,6 +5,7 @@ import os
import overpy import overpy
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import scipy
from pathlib import Path from pathlib import Path
@ -66,7 +67,11 @@ def fetch_data(brand: str, cache: bool = True) -> list[tuple[float, float]]:
def spherical_dist(pos1, pos2, r=6378137): def spherical_dist(pos1, pos2, r=6378137):
"""Calculate sperical distances between two arrays of coordinates.""" """Calculate sperical distances between two arrays of coordinates.
Return value is the same unit as `r`.
`r` defaults to the radius of the earth, in meters.
"""
pos1 = pos1 * np.pi / 180 pos1 = pos1 * np.pi / 180
pos2 = pos2 * np.pi / 180 pos2 = pos2 * np.pi / 180
@ -77,6 +82,7 @@ def spherical_dist(pos1, pos2, r=6378137):
return r * np.arccos(cos_lat_d - cos_lat1 * cos_lat2 * (1 - cos_lon_d)) return r * np.arccos(cos_lat_d - cos_lat1 * cos_lat2 * (1 - cos_lon_d))
# (lat, lon), dist
StationT = tuple[tuple[float, float], float] StationT = tuple[tuple[float, float], float]
@ -93,9 +99,7 @@ def trilaterate(stations: list[StationT]) -> tuple[float, float]:
Each station is of the format ((lat, lon), distance). Each station is of the format ((lat, lon), distance).
""" """
# TODO scipy optimise using trilat_error return scipy.optimize.fmin(lambda pos: trilat_error(stations, pos), (0., 0.))
return (0., 0.)
def encode(location: tuple[float, float]) -> EncodedLocation: def encode(location: tuple[float, float]) -> EncodedLocation:
@ -113,8 +117,6 @@ def encode(location: tuple[float, float]) -> EncodedLocation:
distances = distances.sort_values() distances = distances.sort_values()
closest = distances.head(LOCS_COUNT) closest = distances.head(LOCS_COUNT)
closest_dist = list(closest.values)
closest_ind = list(closest.index)
result: EncodedLocation = [] result: EncodedLocation = []
for v, i in zip(closest.values, closest.index): for v, i in zip(closest.values, closest.index):
@ -130,7 +132,8 @@ def decode(location: EncodedLocation) -> tuple[float, float]:
"""Decode into a location.""" """Decode into a location."""
# form the distances matrix # form the distances matrix
greggs = np.array(fetch_data("greggs")) greggs_raw = fetch_data("greggs")
greggs = np.array(greggs_raw)
repeat_rows = np.tile(greggs, (len(greggs), 1, 1)) repeat_rows = np.tile(greggs, (len(greggs), 1, 1))
repeat_cols = np.transpose(repeat_rows, (1, 0, 2)) repeat_cols = np.transpose(repeat_rows, (1, 0, 2))
dist_matrix = spherical_dist(repeat_rows, repeat_cols) dist_matrix = spherical_dist(repeat_rows, repeat_cols)
@ -142,25 +145,20 @@ def decode(location: EncodedLocation) -> tuple[float, float]:
# part 1: find the ID of each gregg's # part 1: find the ID of each gregg's
closest_greggs = [] closest_greggs = []
for i in range(len(location)): for loc in location:
dists = location[i][1] dists = loc[1]
errors = [sum((j - dists) ** 2) for j in dist_series_list]
errors = []
for j in dist_series_list:
errors.append(sum((j-dists)**2))
minerr = min(errors) minerr = min(errors)
if minerr > 1: if minerr > 1:
print(f"warning: high error value of {minerr}") print(f"warning: high error value of {minerr}")
closest_greggs = [errors.index(min(errors))] closest_greggs.append(errors.index(min(errors)))
# part 2: trilaterate # part 2: trilaterate
stations: list[StationT] = [(greggs_raw[g], location[i][0]) for i, g in enumerate(closest_greggs)]
# Stub return trilaterate(stations)
return (0.091659, 52.210796)
def format_dist(dist: float) -> str: def format_dist(dist: float) -> str:
@ -188,12 +186,17 @@ def parse_location(location: str) -> EncodedLocation:
def main(): def main():
"""Testing.""" """Testing."""
#print("Running query...") coords = (52.210796, 0.091659)
#greggs = fetch_data("greggs") print("Original:", coords)
#print(f"Query done - got {len(greggs)} Greggs!")
outcome = encode((52.210796, 0.091659)) outcome = encode(coords)
decode(outcome) print("Encoded:", outcome)
decoded = decode(outcome)
print("Decoded:", decoded)
error = spherical_dist(np.array(coords), np.array(decoded))
print(f"Error: {error:.10f}m")
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -2,3 +2,4 @@ overpy
django django
numpy numpy
pandas pandas
scipy