Ml pipeline #11
…ons and a readme along with requirement.txt.
This is a nice PR - your scripts look easy to use to an outsider. I would just suggest that you look back at the readme again - did ChatGPT write it? 😉
ml_pipeline/requirememts.txt
@@ -0,0 +1,7 @@
python==3.10.0
The Python version doesn't actually live in this file - this is just for Python libraries. The reason is that you can't `pip install python==3.10.0` but you can pip install everything else 😄
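Since the interpreter version cannot be pinned in a requirements file, one option is to check it at runtime. The following is a minimal sketch, not something taken from this PR: the helper name `python_is_supported` is hypothetical, and the minimum of 3.10 is an assumption based on the `python==3.10.0` line above.

```python
import sys

# Assumption: the pipeline targets Python 3.10 or newer, based on the
# python==3.10.0 line that was removed from the requirements file.
REQUIRED_PYTHON = (3, 10)


def python_is_supported(version_info=sys.version_info, required=REQUIRED_PYTHON):
    """Return True if the (major, minor) version meets the required minimum."""
    return tuple(version_info[:2]) >= tuple(required)
```

A script could call this at startup and exit with a clear message if it returns False. The readme is another reasonable place to state the expected Python version.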
ml_pipeline/readme.md
Usage:
```bash
python predict_with_model.py --config_file /path/to/config.json
```
Same here. Your `prediction.py` file (note the different name) hard-codes a `config.py` path.
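One way to avoid a hard-coded config path is to read it from the command line, matching the `--config_file` flag shown in the readme. This is a minimal sketch that assumes a JSON config file, as the readme usage line suggests. The function names `parse_args`, `load_config`, and `main` are illustrative and are not taken from the PR.

```python
import argparse
import json


def parse_args(argv=None):
    """Parse the command line; argv=None means 'use sys.argv'."""
    parser = argparse.ArgumentParser(
        description="Run prediction with a pretrained model."
    )
    parser.add_argument(
        "--config_file",
        required=True,
        help="Path to the JSON configuration file.",
    )
    return parser.parse_args(argv)


def load_config(path):
    """Load the JSON config from the given path instead of a fixed location."""
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)


def main(argv=None):
    args = parse_args(argv)
    config = load_config(args.config_file)
    return config
```

With this shape, the documented command and the script's actual behavior stay in sync, because the path only ever comes from the flag.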
Fixed as well
Thanks for making the changes Satyam. With a couple of exceptions that I pointed out inline, the readme is a lot better now.
Looking through it though, I am a bit confused about where a couple of functions are defined. Is this actually the version of the code that you used to get your results? If you've refactored it since running it, could you rerun it please, to make sure it all works? Then we can merge it!
### 1. Generating Training Points

The script `generate_training_points.py` takes a raster dataset, randomly samples specific classes, and creates a GeoDataFrame containing the sampled points. The sampled points serve as training data for classification.
Should this be `generate_training_sample.py`?
Usage:
```bash
python generate_training_points.py --raster_path /path/to/raster/file.tif --num_samples 100 --target_classes 1 2 3 --export_path /path/to/export.shp
```
Same here
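For readers unfamiliar with the sampling step the readme describes, the idea can be sketched in plain Python. This is not the PR's implementation: the real script works on a raster file and produces a GeoDataFrame, whereas this sketch uses a nested list as a stand-in raster and returns plain dictionaries. The function name `sample_class_pixels` is hypothetical, and treating `num_samples` as a per-class count is an assumption.

```python
import random


def sample_class_pixels(grid, target_classes, num_samples, seed=None):
    """Randomly pick up to num_samples pixel positions for each target class.

    grid is a list of rows, each a list of integer class labels
    (a stand-in for a single-band classified raster).
    """
    rng = random.Random(seed)
    samples = []
    for cls in target_classes:
        # Collect every (row, col) position whose label matches this class.
        coords = [
            (r, c)
            for r, row in enumerate(grid)
            for c, value in enumerate(row)
            if value == cls
        ]
        # Never ask for more points than the class actually has.
        k = min(num_samples, len(coords))
        for r, c in rng.sample(coords, k):
            samples.append({"row": r, "col": c, "class": cls})
    return samples
```

In a real raster workflow the row and column indices would still need to be converted to map coordinates before being written out as point geometries.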
### 3. Model Training using Training Points

The script `train_model.py` loads the GeoDataFrame generated in the first step, processes the data, sets up a PyCaret experiment, creates a classification model, and saves the trained model along with evaluation plots and reports.
Should this be `training.py`?
Usage:
```bash
python train_model.py /path/to/config.json
```
Same here
### 4. Prediction using Pretrained Model

The script `predict_with_model.py` uses a pretrained classification model to make predictions on input raster tiles. It saves the binary and probability prediction outputs.
Should this be `prediction.py`?
Usage:
```bash
python predict_with_model.py /path/to/config.json
```
Same here
gdf = gpd.read_file(config['Paths']['shapefile_path'])
new_df1 = gdf.drop(columns=config['ColumnsToDrop'])

new_df1 = new_df1.rename(columns=lambda col: extract_number(col))
Is `extract_number` defined in `pycaret.classification`? Or elsewhere? I can't see it defined in your code, or in pycaret 🤔:
https://github.com/pycaret/pycaret/blob/master/pycaret/classification/__init__.py
https://pycaret.readthedocs.io/en/stable/api/classification.html
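A quick way to confirm whether a module exposes a given name is to import it and check for the attribute. The helper below is a small sketch using only the standard library, and the name `module_defines` is made up for illustration.

```python
import importlib


def module_defines(module_name, attr_name):
    """Return True if module_name can be imported and exposes attr_name."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # The module is not installed or cannot be imported at all.
        return False
    return hasattr(module, attr_name)
```

For the question above, the call would be `module_defines("pycaret.classification", "extract_number")` in an environment where pycaret is installed. Note that a False result does not distinguish a missing module from a missing name.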
new_df1 = new_df1.rename(columns=lambda col: extract_number(col))

# Load global statistics from cache
global_stats_dict = {band_name: cache_global_stats(band_name, config['Paths']['pickle_dir']) for band_name in band_names}
Here, I think that `cache_global_stats` is defined in `prediction.py` but is used here (without being imported). Since it's only used in `training.py`, I think that this would be the best file to define it in.
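To illustrate the suggestion, here is what a locally defined `cache_global_stats` could look like in `training.py`. Only the call signature `(band_name, pickle_dir)` comes from the diff. The body is a guess: it assumes one pickle file per band named `<band_name>.pkl`, and that the function only loads existing statistics. The actual implementation in `prediction.py` may differ.

```python
import os
import pickle


def cache_global_stats(band_name, pickle_dir):
    """Load previously cached statistics for one band.

    Assumption (not confirmed by the PR): the stats for each band are
    stored as a pickle at <pickle_dir>/<band_name>.pkl.
    """
    path = os.path.join(pickle_dir, f"{band_name}.pkl")
    with open(path, "rb") as fh:
        return pickle.load(fh)
```

If the function turns out to be needed by both scripts after all, moving it to a small shared module and importing it from each would avoid duplicating it.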
global_stats_dict = {band_name: cache_global_stats(band_name, config['Paths']['pickle_dir']) for band_name in band_names}

# Process bands in the dataframe
new_df2 = process_bands_in_dataframe(new_df1, band_names, global_stats_dict)
Similarly, where is `process_bands_in_dataframe` defined?
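Questions like this can be answered mechanically by scanning the repository for the function's definition. The sketch below uses the standard-library `ast` module, and the helper name `find_function_definitions` is hypothetical. It reports files that define a function with the given name, not files that merely call it.

```python
import ast
from pathlib import Path


def find_function_definitions(func_name, root="."):
    """Return (file path, line number) pairs where func_name is defined."""
    matches = []
    for path in sorted(Path(root).rglob("*.py")):
        try:
            tree = ast.parse(path.read_text(encoding="utf-8"))
        except (SyntaxError, UnicodeDecodeError):
            # Skip files that cannot be read or parsed as Python source.
            continue
        for node in ast.walk(tree):
            is_def = isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
            if is_def and node.name == func_name:
                matches.append((str(path), node.lineno))
    return matches
```

Running it from the `ml_pipeline` directory with `"process_bands_in_dataframe"` would show whether the function exists anywhere in the submitted code. An empty list means it is missing or comes from an external package.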
I have added the necessary scripts for the machine learning pipeline. Please review and let me know if any updates or changes are required.