diff --git a/pages/Image_prediction.py b/pages/Image_prediction.py index c6ae6e8..00e375a 100644 --- a/pages/Image_prediction.py +++ b/pages/Image_prediction.py @@ -13,92 +13,92 @@ import gdown import tempfile -def show_image_prediction(): - # Initialize session state attributes - if 'saved_predictions' not in st.session_state: - st.session_state.saved_predictions = [] - if 'predictions' not in st.session_state: - st.session_state.predictions = [] - if 'uploaded_images' not in st.session_state: - st.session_state.uploaded_images = [] - if 'model_temp_file' not in st.session_state: - st.session_state.model_temp_file = None - - # Define the model links - model_links = { - 'CNN': { - 'url': 'https://drive.google.com/uc?id=1mtDtPtM-E7y20LlFlEPn1UI20fykFL2P', - 'target_size': (128, 128) - }, - 'ResNet50': { - 'url': 'https://drive.google.com/uc?id=1F47j2nr5JSa09mBWtUsWOkQYl5pLixqG', - 'target_size': (260, 260) - }, - 'EfficientNet': { - 'url': 'https://drive.google.com/uc?id=10mXZQWQ1RyGx6BqEsaJdcv8NcSKv-v8A', - 'target_size': (260, 260) - }, - 'DenseNet': { - 'url': 'https://drive.google.com/uc?id=14-7XGitYTJTYAksSI-LPS-b_lW7Dn8aC', - 'target_size': (224, 224) - }, - 'VGG19': { - 'url': 'https://drive.google.com/uc?id=19o-JaeGBDXpITObAVkII2sNFYp3qmGoH', - 'target_size': (224, 224) - }, - } - - def load_existing_predictions(): - if os.path.exists('prediction_history.json'): - with open('prediction_history.json', 'r') as f: - return json.load(f) - return [] - - # Load existing predictions into session state - st.session_state.saved_predictions = load_existing_predictions() - - def save_predictions_to_history(uploaded_files, predictions, model_name): - prediction_data = [] - for i, uploaded_file in enumerate(uploaded_files): - actual = 'Cancer' if predictions[i][0] == 0 else 'Non Cancer' - prediction_data.append({ - 'file_name': uploaded_file.name, - 'model_used': model_name, - 'prediction': actual - }) - - st.session_state.saved_predictions.extend(prediction_data) - - with open('prediction_history.json', 'w') as f: - json.dump(st.session_state.saved_predictions, f, indent=4) - st.success("Predictions saved to history successfully.") - - cancer_warning_messages = [ - "Please consult a doctor immediately.", - "We recommend scheduling a medical check-up soon.", - "It's crucial to seek medical advice right away.", - "Contact your healthcare provider for further examination.", - "This result may be concerning. Please consult a specialist." - ] - - def download_and_load_model(model_url): - """Downloads and loads the model from the provided Google Drive URL.""" - if st.session_state.model_temp_file is None: - with tempfile.NamedTemporaryFile(suffix='.keras', delete=False) as tmp: - st.session_state.model_temp_file = tmp.name - # Show toast message for downloading model - st.toast("📥 Downloading model... Please wait.") - # Spinner for downloading the model - with st.spinner("Downloading the model..."): - gdown.download(model_url, st.session_state.model_temp_file, quiet=False) - - # Show toast message for download completion - st.toast("✅ Model download completed!") - - # Load the model from the temp file - model = load_model(st.session_state.model_temp_file) - return model +# Define the model links +model_links = { + 'CNN': { + 'url': 'https://drive.google.com/uc?id=1mtDtPtM-E7y20LlFlEPn1UI20fykFL2P', + 'target_size': (128, 128) + }, + 'ResNet50': { + 'url': 'https://drive.google.com/uc?id=1F47j2nr5JSa09mBWtUsWOkQYl5pLixqG', + 'target_size': (260, 260) + }, + 'EfficientNet': { + 'url': 'https://drive.google.com/uc?id=10mXZQWQ1RyGx6BqEsaJdcv8NcSKv-v8A', + 'target_size': (260, 260) + }, + 'DenseNet': { + 'url': 'https://drive.google.com/uc?id=14-7XGitYTJTYAksSI-LPS-b_lW7Dn8aC', + 'target_size': (224, 224) + }, + 'VGG19': { + 'url': 'https://drive.google.com/uc?id=19o-JaeGBDXpITObAVkII2sNFYp3qmGoH', + 'target_size': (224, 224) + }, +} + +# Initialize session state +if 'saved_predictions' not in st.session_state: + st.session_state.saved_predictions = [] +if 'predictions' not in st.session_state: + st.session_state.predictions = [] +if 'uploaded_images' not in st.session_state: + st.session_state.uploaded_images = [] +if 'model_temp_file' not in st.session_state: + st.session_state.model_temp_file = None + +def load_existing_predictions(): + if os.path.exists('prediction_history.json'): + with open('prediction_history.json', 'r') as f: + return json.load(f) + return [] + +# Load existing predictions into session state +st.session_state.saved_predictions = load_existing_predictions() + +def save_predictions_to_history(uploaded_files, predictions, model_name): + prediction_data = [] + for i, uploaded_file in enumerate(uploaded_files): + actual = 'Cancer' if predictions[i][0] == 0 else 'Non Cancer' + prediction_data.append({ + 'file_name': uploaded_file.name, + 'model_used': model_name, + 'prediction': actual + }) + + st.session_state.saved_predictions.extend(prediction_data) + + with open('prediction_history.json', 'w') as f: + json.dump(st.session_state.saved_predictions, f, indent=4) + st.success("Predictions saved to history successfully.") + +cancer_warning_messages = [ + "Please consult a doctor immediately.", + "We recommend scheduling a medical check-up soon.", + "It's crucial to seek medical advice right away.", + "Contact your healthcare provider for further examination.", + "This result may be concerning. Please consult a specialist." +] + +def download_and_load_model(model_url): + """Downloads and loads the model from the provided Google Drive URL.""" + if st.session_state.model_temp_file is None: + with tempfile.NamedTemporaryFile(suffix='.keras', delete=False) as tmp: + st.session_state.model_temp_file = tmp.name + # Show toast message for downloading model + st.toast("📥 Downloading model... Please wait.") + # Spinner for downloading the model + with st.spinner("Downloading the model..."): + gdown.download(model_url, st.session_state.model_temp_file, quiet=False) + + # Show toast message for download completion + st.toast("✅ Model download completed!") + + # Load the model from the temp file + model = load_model(st.session_state.model_temp_file) + return model +def show_image_prediction(): # Streamlit UI st.title('Oral Cancer Detection Model Evaluation')