forked from isaaccorley/torchrs
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6271e9b
commit 46a649d
Showing
9 changed files
with
152 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,34 @@ | ||
import os | ||
|
||
import tifffile | ||
import torchvision.transforms as T | ||
from torchvision.datasets import ImageFolder | ||
|
||
from torchrs.transforms import ToTensor | ||
|
||
|
||
class EuroSATRGB(ImageFolder): | ||
|
||
def __init__( | ||
self, | ||
root: str, | ||
transforms: T.Compose | ||
root: str = ".data/eurosat-rgb", | ||
transform: T.Compose = T.Compose([T.ToTensor()]) | ||
): | ||
super().__init__( | ||
root=os.path.join(root, "2750"), | ||
transform=transforms | ||
transform=transform | ||
) | ||
|
||
|
||
class EuroSATMS(ImageFolder): | ||
|
||
def __init__( | ||
self, | ||
root: str, | ||
transforms: T.Compose | ||
root: str = ".data/eurosat-ms", | ||
transform: T.Compose = T.Compose([ToTensor()]) | ||
): | ||
super().__init__( | ||
root=os.path.join(root, "ds/images/remote_sensing/otherDatasets/sentinel_2/tif"), | ||
transform=transforms | ||
transform=transform, | ||
loader=tifffile.imread | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,38 @@ | ||
import os | ||
import json | ||
from typing import List, Dict, Optional | ||
from typing import List, Dict | ||
|
||
import torch | ||
import torchvision.transforms as T | ||
from PIL import Image | ||
|
||
from torchrs.transforms import ToTensor | ||
|
||
|
||
class RSICD(torch.utils.data.Dataset): | ||
|
||
def __init__( | ||
self, | ||
root: str = ".data/rscid", | ||
annotations_path: Optional[str] = None, | ||
root: str = ".data/rsicd", | ||
split: str = "train", | ||
transforms: T.Compose = T.Compose([ToTensor()]) | ||
transform: T.Compose = T.Compose([T.ToTensor()]) | ||
): | ||
assert split in ["train", "val", "test"] | ||
self.root = root | ||
self.transforms = transforms | ||
|
||
|
||
self.annotations = self.load_annotations(annotations_path, split) | ||
print(f"RSICD {split} dataset loaded with {len(self.annotations)} annotations") | ||
self.transform = transform | ||
self.captions = self.load_captions(os.path.join(root, "dataset_rsicd.json"), split) | ||
|
||
def load_annotations(self, path: str, split: str) -> List[Dict]: | ||
@staticmethod | ||
def load_captions(path: str, split: str) -> List[Dict]: | ||
with open(path) as f: | ||
annotations = json.load(f)["images"] | ||
|
||
return [a for a in annotations if a["split"] == split] | ||
captions = json.load(f)["images"] | ||
return [c for c in captions if c["split"] == split] | ||
|
||
def __len__(self) -> int: | ||
return len(self.annotations) | ||
return len(self.captions) | ||
|
||
def __getitem__(self, idx: int) -> Dict: | ||
annotation = self.annotations[idx] | ||
path = os.path.join(self.root, annotation["filename"]) | ||
captions = self.captions[idx] | ||
path = os.path.join(self.root, "RSICD_images", captions["filename"]) | ||
x = Image.open(path).convert("RGB") | ||
x = self.transforms(x) | ||
captions = [sentence["raw"] for sentence in annotation["sentences"]] | ||
return dict(x=x, captions=captions) | ||
x = self.transform(x) | ||
sentences = [sentence["raw"] for sentence in captions["sentences"]] | ||
return dict(x=x, captions=sentences) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
from .rams import RAMS | ||
|
||
__all__ = ["RAMS"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters