# Source code for Munin.PriceList.SolutionCube

import xarray as xr
import numpy as np
import pandas as pd
import hashlib
import json
import time
from datetime import datetime, timezone
from multiprocessing import Pool, cpu_count
from typing import Dict, Any, Tuple, Type, Optional
from functools import partial
from tqdm import tqdm

# Import your project's classes
from Munin.Timber.SweTimber import SweTimber
from Munin.PriceList.PriceList import Pricelist, create_pricelist_from_data
from Munin.Taper.Taper import Taper
from Munin.TimberBucking.Nasberg_1985 import Nasberg_1985_BranchBound, BuckingConfig

def _hash_pricelist(price_data: Dict[str, Any]) -> str:
    """Return the SHA-256 hex digest that identifies a pricelist dictionary.

    The dictionary is serialised to JSON with sorted keys, so two
    dictionaries with the same content produce the same digest regardless
    of key insertion order.
    """
    canonical_text = json.dumps(price_data, sort_keys=True)
    return hashlib.sha256(canonical_text.encode()).hexdigest()

def _worker_buck_one_tree(tree_params: Tuple[str, int, int], pricelist_data: Dict, taper_model_class: Type[Taper]):
    """
    Optimise the bucking of one tree and return a flat result record.

    Defined at module level so it can be handed to a multiprocessing pool.

    Args:
        tree_params: ``(species, dbh_cm, height_dm)``; the height is given in
            decimetres and converted to metres here.
        pricelist_data: Raw pricelist data passed to
            ``create_pricelist_from_data`` together with the species.
        taper_model_class: Taper model class forwarded to the optimizer.

    Returns:
        A dict with the keys ``species``, ``dbh``, ``height`` (metres),
        ``total_value`` and ``solution_sections`` (a JSON string). If
        anything fails, the error is printed and the record carries
        ``total_value = NaN`` and ``solution_sections = '[]'``.
    """
    species, dbh_cm, height_dm = tree_params
    height_m = height_dm / 10.0

    def _jsonable(obj):
        # NumPy scalars become plain Python scalars; anything else json
        # cannot encode is stored as its string representation.
        if isinstance(obj, np.generic):
            return obj.item()
        return str(obj)

    try:
        timber = SweTimber(species=species, diameter_cm=dbh_cm, height_m=height_m)
        pricelist = create_pricelist_from_data(pricelist_data, species)
        optimizer = Nasberg_1985_BranchBound(timber, pricelist, taper_model_class)

        # save_sections=True makes the optimizer keep the individual sections.
        bucking_config = BuckingConfig(save_sections=True)
        result = optimizer.calculate_tree_value(min_diam_dead_wood=99, config=bucking_config)

        sections_data = (
            [] if result.sections is None
            else [s.__dict__ for s in result.sections]
        )
        sections_json = json.dumps(sections_data, default=_jsonable)

        return {
            "species": species,
            "dbh": dbh_cm,
            "height": height_m,
            "total_value": result.total_value,
            "solution_sections": sections_json,
        }
    except Exception as e:
        # Best effort: a single failing tree must not abort the whole run.
        print(f"Error processing {species} DBH={dbh_cm} H={height_m}: {e}")
        return {
            "species": species,
            "dbh": dbh_cm,
            "height": height_m,
            "total_value": np.nan,
            "solution_sections": '[]',
        }


class SolutionCube:
    """Precomputed bucking solutions stored in an xarray Dataset.

    The dataset is indexed by ``species``, ``height`` and ``dbh`` and holds
    the variables ``total_value`` and ``solution_sections`` (a JSON string
    per cell), plus metadata such as the pricelist hash in its attributes.
    """

    def __init__(self, dataset: xr.Dataset):
        """
        Initializes the SolutionCube with a loaded xarray Dataset.
        It's recommended to use the `load` classmethod to create instances.
        """
        self.dataset = dataset
        # Both attributes are None when the dataset carries no such metadata.
        self.pricelist_hash = dataset.attrs.get('pricelist_hash')
        self.taper_model = dataset.attrs.get('taper_model')

    @staticmethod
    def _grid_ticks(lower: float, upper: float, step: float, ticks_per_unit: int) -> list[int]:
        """Return the grid from ``lower`` to ``upper`` as unique integer ticks.

        Values are ``lower + k * step`` for ``k = 0, 1, ...`` until ``upper``
        is reached or passed, multiplied by ``ticks_per_unit`` and rounded to
        the nearest integer. Rounding (instead of truncating) keeps e.g.
        12.4 m from becoming 123 dm through floating-point error, and the
        explicit step count avoids the spurious extra element a float-step
        ``np.arange`` can produce.

        Raises:
            ValueError: If ``step`` is not positive or ``upper < lower``.
        """
        if step <= 0:
            raise ValueError(f"Grid step must be positive, got {step}.")
        if upper < lower:
            raise ValueError(
                f"Invalid grid range: upper bound {upper} is below lower bound {lower}."
            )
        # The small epsilon absorbs float noise in the division so an exact
        # multiple of the step does not gain an extra grid point.
        n_steps = max(int(np.ceil((upper - lower) / step - 1e-9)), 0)
        values = lower + step * np.arange(n_steps + 1)
        ticks = np.rint(values * ticks_per_unit).astype(int).tolist()
        # Steps finer than one tick would round to duplicates; keep each once.
        return list(dict.fromkeys(ticks))

    @classmethod
    def generate(
        cls,
        pricelist_data: Dict[str, Any],
        taper_model: Type[Taper],
        species_list: list[str],
        dbh_range: Tuple[float, float],
        height_range: Tuple[float, float],
        dbh_step: int = 2,
        height_step: float = 0.2,
        workers: int = -1
    ):
        """
        Generates the solution cube by running the optimizer in parallel.

        Args:
            pricelist_data: Raw pricelist data; hashed for later verification
                and passed to every worker.
            taper_model: Taper model class passed to every worker.
            species_list: Species to include (duplicates are ignored).
            dbh_range: ``(min, max)`` diameter in cm; the grid is resolved to
                whole centimetres.
            height_range: ``(min, max)`` height in m; the grid is resolved to
                whole decimetres.
            dbh_step: Diameter grid spacing in cm.
            height_step: Height grid spacing in m.
            workers: Number of worker processes; ``-1`` uses ``cpu_count()``.

        Returns:
            A new ``SolutionCube`` wrapping the generated dataset.

        Raises:
            ValueError: If a step is not positive, a range is reversed, or
                there is nothing to compute (empty ``species_list``).
        """
        if workers == -1:
            workers = cpu_count()

        print(f"Generating Solution Cube using {workers} parallel processes...")

        pricelist_hash = _hash_pricelist(pricelist_data)
        print(f"Pricelist hash: {pricelist_hash}")

        # Build the grid directly in the integer units the worker expects
        # (cm for dbh, dm for height) so no float truncation can occur.
        dbh_ticks = cls._grid_ticks(dbh_range[0], dbh_range[1], dbh_step, 1)
        height_ticks = cls._grid_ticks(height_range[0], height_range[1], height_step, 10)
        unique_species = list(dict.fromkeys(species_list))

        tasks = [
            (sp, dbh_cm, height_dm)
            for sp in unique_species
            for dbh_cm in dbh_ticks
            for height_dm in height_ticks
        ]
        if not tasks:
            raise ValueError("No trees to process: species_list is empty.")

        print(f"Total trees to process: {len(tasks)}")

        # Use a partial function to pass the static pricelist and taper model to the worker
        worker_func = partial(
            _worker_buck_one_tree,
            pricelist_data=pricelist_data,
            taper_model_class=taper_model,
        )

        # Run the optimizations in parallel
        start_time = time.time()
        with Pool(processes=workers) as pool:
            # tqdm wraps the lazy iterator (not a finished list) so the bar
            # advances as results arrive; list() drains it inside the pool
            # context, before the pool is torn down.
            results = list(
                tqdm(
                    pool.imap_unordered(worker_func, tasks, chunksize=10),
                    total=len(tasks),
                    desc='Generating Solution Cube',
                )
            )
        end_time = time.time()
        print(f"\nFinished parallel computation in {end_time - start_time:.2f} seconds.")

        # --- Structure the results into an xarray Dataset ---
        # Convert flat list of dicts to a DataFrame for easier manipulation
        df = pd.DataFrame(results)
        df = df.set_index(['species', 'height', 'dbh'])

        # Convert to an xarray Dataset
        ds = xr.Dataset.from_dataframe(df)

        # Add metadata as attributes
        ds.attrs['pricelist_hash'] = pricelist_hash
        ds.attrs['taper_model'] = taper_model.__name__
        ds.attrs['creation_date_utc'] = datetime.now(timezone.utc).isoformat()
        ds.attrs['dbh_range'] = f"{dbh_range[0]}-{dbh_range[1]} cm"
        ds.attrs['height_range'] = f"{height_range[0]}-{height_range[1]} m"

        print("Successfully created xarray Dataset.")
        return cls(ds)

    def save(self, path: str):
        """Saves the dataset to a netCDF file."""
        print(f"Saving solution cube to {path}...")
        self.dataset.to_netcdf(path)
        print("Save complete.")

    @classmethod
    def load(cls, path: str, pricelist_to_verify: Optional[Dict] = None):
        """Loads a solution cube from a netCDF file.

        Args:
            path: Path of the netCDF file.
            pricelist_to_verify: If given (even when empty), its hash must
                equal the ``pricelist_hash`` stored in the file.

        Returns:
            A ``SolutionCube`` wrapping the loaded dataset.

        Raises:
            ValueError: If the pricelist hash does not match.
        """
        print(f"Loading solution cube from {path}...")
        # load_dataset reads the data into memory and closes the file, so no
        # handle stays open for the lifetime of the cube.
        ds = xr.load_dataset(path)

        # "is not None" so that an empty pricelist is still verified.
        if pricelist_to_verify is not None:
            new_hash = _hash_pricelist(pricelist_to_verify)
            if ds.attrs.get('pricelist_hash') != new_hash:
                raise ValueError(
                    "Pricelist hash mismatch! The loaded cube was not generated with the provided pricelist."
                )
            print("Pricelist hash verified.")

        print("Cube loaded successfully.")
        return cls(ds)

    def lookup(self, species: str, dbh: float, height: float) -> Tuple[float, list]:
        """
        Performs a fast lookup for a given tree's properties.

        The species must match exactly; ``dbh`` and ``height`` are matched to
        the nearest grid point.

        Returns:
            ``(total_value, sections)`` where ``sections`` is the decoded
            ``solution_sections`` JSON. ``(0.0, [])`` is returned, after
            printing a message, when the species is unknown or the lookup
            fails for any other reason.
        """
        # Species is a categorical label: select it exactly so an unknown
        # species is reported instead of being "nearest"-matched.
        try:
            species_slice = self.dataset.sel(species=species)
        except KeyError:
            print(f"Warning: Species '{species}' not found in the solution cube.")
            return 0.0, []

        try:
            # 'nearest' is applied only to the numeric coordinates.
            solution = species_slice.sel(dbh=dbh, height=height, method='nearest')
            total_value = float(solution['total_value'].values)
            sections_json = str(solution['solution_sections'].values)
            sections = json.loads(sections_json)
            return total_value, sections
        except Exception as e:
            print(f"An error occurred during lookup: {e}")
            return 0.0, []