Path: blob/main/transformers_doc/ko/tensorflow/quicktour.ipynb
5906 views
๋๋ฌ๋ณด๊ธฐ [[quick-tour]]
๐ค Transformers๋ฅผ ์์ํด๋ณด์ธ์! ๊ฐ๋ฐํด๋ณธ ์ ์ด ์๋๋ผ๋ ์ฝ๊ฒ ์ฝ์ ์ ์๋๋ก ์ฐ์ธ ์ด ๊ธ์ pipeline
์ ์ฌ์ฉํ์ฌ ์ถ๋ก ํ๊ณ , ์ฌ์ ํ์ต๋ ๋ชจ๋ธ๊ณผ ์ ์ฒ๋ฆฌ๊ธฐ๋ฅผ AutoClass๋ก ๋ก๋ํ๊ณ , PyTorch ๋๋ TensorFlow๋ก ๋ชจ๋ธ์ ๋น ๋ฅด๊ฒ ํ์ต์ํค๋ ๋ฐฉ๋ฒ์ ์๊ฐํด ๋๋ฆด ๊ฒ์
๋๋ค. ๋ณธ ๊ฐ์ด๋์์ ์๊ฐ๋๋ ๊ฐ๋
์ (ํนํ ์ด๋ณด์์ ๊ด์ ์ผ๋ก) ๋ ์น์ ํ๊ฒ ์ ํ๊ณ ์ถ๋ค๋ฉด, ํํ ๋ฆฌ์ผ์ด๋ ์ฝ์ค๋ฅผ ์ฐธ์กฐํ๊ธฐ๋ฅผ ๊ถ์ฅํฉ๋๋ค.
์์ํ๊ธฐ ์ ์ ํ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ๋ชจ๋ ์ค์น๋์ด ์๋์ง ํ์ธํ์ธ์:
๋ํ ์ ํธํ๋ ๋จธ์ ๋ฌ๋ ํ๋ ์์ํฌ๋ฅผ ์ค์นํด์ผ ํฉ๋๋ค:
ํ์ดํ๋ผ์ธ [[pipeline]]
pipeline
์ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ๋ก ์ถ๋ก ํ๊ธฐ์ ๊ฐ์ฅ ์ฝ๊ณ ๋น ๋ฅธ ๋ฐฉ๋ฒ์
๋๋ค. pipeline()
์ ์ฌ๋ฌ ๋ชจ๋ฌ๋ฆฌํฐ์์ ๋ค์ํ ๊ณผ์
์ ์ฝ๊ฒ ์ฒ๋ฆฌํ ์ ์์ผ๋ฉฐ, ์๋ ํ์ ํ์๋ ๋ช ๊ฐ์ง ๊ณผ์
์ ๊ธฐ๋ณธ์ ์ผ๋ก ์ง์ํฉ๋๋ค:
์ฌ์ฉ ๊ฐ๋ฅํ ์์ ์ ์ ์ฒด ๋ชฉ๋ก์ Pipelines API ์ฐธ์กฐ๋ฅผ ํ์ธํ์ธ์.
ํ์คํฌ | ์ค๋ช | ๋ชจ๋ฌ๋ฆฌํฐ | ํ์ดํ๋ผ์ธ ID |
---|---|---|---|
ํ ์คํธ ๋ถ๋ฅ | ํ ์คํธ์ ์๋ง์ ๋ ์ด๋ธ ๋ถ์ด๊ธฐ | ์์ฐ์ด ์ฒ๋ฆฌ(NLP) | pipeline(task="sentiment-analysis") |
ํ ์คํธ ์์ฑ | ์ฃผ์ด์ง ๋ฌธ์์ด ์ ๋ ฅ๊ณผ ์ด์ด์ง๋ ํ ์คํธ ์์ฑํ๊ธฐ | ์์ฐ์ด ์ฒ๋ฆฌ(NLP) | pipeline(task="text-generation") |
๊ฐ์ฒด๋ช ์ธ์ | ๋ฌธ์์ด์ ๊ฐ ํ ํฐ๋ง๋ค ์๋ง์ ๋ ์ด๋ธ ๋ถ์ด๊ธฐ (์ธ๋ฌผ, ์กฐ์ง, ์ฅ์ ๋ฑ๋ฑ) | ์์ฐ์ด ์ฒ๋ฆฌ(NLP) | pipeline(task="ner") |
์ง์์๋ต | ์ฃผ์ด์ง ๋ฌธ๋งฅ๊ณผ ์ง๋ฌธ์ ๋ฐ๋ผ ์ฌ๋ฐ๋ฅธ ๋๋ตํ๊ธฐ | ์์ฐ์ด ์ฒ๋ฆฌ(NLP) | pipeline(task="question-answering") |
๋น์นธ ์ฑ์ฐ๊ธฐ | ๋ฌธ์์ด์ ๋น์นธ์ ์๋ง์ ํ ํฐ ๋ง์ถ๊ธฐ | ์์ฐ์ด ์ฒ๋ฆฌ(NLP) | pipeline(task="fill-mask") |
์์ฝ | ํ ์คํธ๋ ๋ฌธ์๋ฅผ ์์ฝํ๊ธฐ | ์์ฐ์ด ์ฒ๋ฆฌ(NLP) | pipeline(task="summarization") |
๋ฒ์ญ | ํ ์คํธ๋ฅผ ํ ์ธ์ด์์ ๋ค๋ฅธ ์ธ์ด๋ก ๋ฒ์ญํ๊ธฐ | ์์ฐ์ด ์ฒ๋ฆฌ(NLP) | pipeline(task="translation") |
์ด๋ฏธ์ง ๋ถ๋ฅ | ์ด๋ฏธ์ง์ ์๋ง์ ๋ ์ด๋ธ ๋ถ์ด๊ธฐ | ์ปดํจํฐ ๋น์ (CV) | pipeline(task="image-classification") |
์ด๋ฏธ์ง ๋ถํ | ์ด๋ฏธ์ง์ ํฝ์ ๋ง๋ค ๋ ์ด๋ธ ๋ถ์ด๊ธฐ(์๋งจํฑ, ํ๋ํฑ ๋ฐ ์ธ์คํด์ค ๋ถํ ํฌํจ) | ์ปดํจํฐ ๋น์ (CV) | pipeline(task="image-segmentation") |
๊ฐ์ฒด ํ์ง | ์ด๋ฏธ์ง ์ ๊ฐ์ฒด์ ๊ฒฝ๊ณ ์์๋ฅผ ๊ทธ๋ฆฌ๊ณ ํด๋์ค๋ฅผ ์์ธกํ๊ธฐ | ์ปดํจํฐ ๋น์ (CV) | pipeline(task="object-detection") |
์ค๋์ค ๋ถ๋ฅ | ์ค๋์ค ํ์ผ์ ์๋ง์ ๋ ์ด๋ธ ๋ถ์ด๊ธฐ | ์ค๋์ค | pipeline(task="audio-classification") |
์๋ ์์ฑ ์ธ์ | ์ค๋์ค ํ์ผ ์ ์์ฑ์ ํ ์คํธ๋ก ๋ฐ๊พธ๊ธฐ | ์ค๋์ค | pipeline(task="automatic-speech-recognition") |
์๊ฐ ์ง์์๋ต | ์ฃผ์ด์ง ์ด๋ฏธ์ง์ ์ง๋ฌธ์ ๋ํด ์ฌ๋ฐ๋ฅด๊ฒ ๋๋ตํ๊ธฐ | ๋ฉํฐ๋ชจ๋ฌ | pipeline(task="vqa") |
๋ฌธ์ ์ง์์๋ต | ์ฃผ์ด์ง ๋ฌธ์์ ์ง๋ฌธ์ ๋ํด ์ฌ๋ฐ๋ฅด๊ฒ ๋๋ตํ๊ธฐ | ๋ฉํฐ๋ชจ๋ฌ | pipeline(task="document-question-answering") |
์ด๋ฏธ์ง ์บก์ ๋ฌ๊ธฐ | ์ฃผ์ด์ง ์ด๋ฏธ์ง์ ์บก์ ์์ฑํ๊ธฐ | ๋ฉํฐ๋ชจ๋ฌ | pipeline(task="image-to-text") |
๋จผ์ pipeline()
์ ์ธ์คํด์ค๋ฅผ ์์ฑํ๊ณ ์ฌ์ฉํ ์์
์ ์ง์ ํฉ๋๋ค. ์ด ๊ฐ์ด๋์์๋ ๊ฐ์ ๋ถ์์ ์ํด pipeline()
์ ์ฌ์ฉํ๋ ์์ ๋ฅผ ๋ณด์ฌ๋๋ฆฌ๊ฒ ์ต๋๋ค:
pipeline()
์ ๊ฐ์ ๋ถ์์ ์ํ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ๋ฅผ ์๋์ผ๋ก ๋ค์ด๋ก๋ํ๊ณ ์บ์ํฉ๋๋ค. ์ด์ classifier
๋ฅผ ๋์ ํ
์คํธ์ ์ฌ์ฉํ ์ ์์ต๋๋ค:
๋ง์ฝ ์
๋ ฅ์ด ์ฌ๋ฌ ๊ฐ ์๋ ๊ฒฝ์ฐ, ์
๋ ฅ์ ๋ฆฌ์คํธ๋ก pipeline()
์ ์ ๋ฌํ์ฌ, ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ์ถ๋ ฅ์ ๋์
๋๋ฆฌ๋ก ์ด๋ฃจ์ด์ง ๋ฆฌ์คํธ ํํ๋ก ๋ฐ์ ์ ์์ต๋๋ค:
pipeline()
์ ์ฃผ์ด์ง ๊ณผ์
์ ๊ด๊ณ์์ด ๋ฐ์ดํฐ์
์ ๋ถ๋ฅผ ์ํํ ์๋ ์์ต๋๋ค. ์ด ์์ ์์๋ ์๋ ์์ฑ ์ธ์์ ๊ณผ์
์ผ๋ก ์ ํํด ๋ณด๊ฒ ์ต๋๋ค:
๋ฐ์ดํฐ์ ์ ๋ก๋ํ ์ฐจ๋ก์ ๋๋ค. (์์ธํ ๋ด์ฉ์ ๐ค Datasets ์์ํ๊ธฐ์ ์ฐธ์กฐํ์ธ์) ์ฌ๊ธฐ์์๋ MInDS-14 ๋ฐ์ดํฐ์ ์ ๋ก๋ํ๊ฒ ์ต๋๋ค:
๋ฐ์ดํฐ์
์ ์ํ๋ง ๋ ์ดํธ๊ฐ ๊ธฐ์กด ๋ชจ๋ธ์ธ facebook/wav2vec2-base-960h
์ ํ๋ จ ๋น์ ์ํ๋ง ๋ ์ดํธ์ ์ผ์นํ๋์ง ํ์ธํด์ผ ํฉ๋๋ค:
"audio"
์ด์ ํธ์ถํ๋ฉด ์๋์ผ๋ก ์ค๋์ค ํ์ผ์ ๊ฐ์ ธ์์ ๋ฆฌ์ํ๋งํฉ๋๋ค. ์ฒซ 4๊ฐ ์ํ์์ ์์ ์จ์ด๋ธํผ ๋ฐฐ์ด์ ์ถ์ถํ๊ณ ํ์ดํ๋ผ์ธ์ ๋ฆฌ์คํธ๋ก ์ ๋ฌํ์ธ์:
์์ฑ์ด๋ ๋น์ ๊ณผ ๊ฐ์ด ์ ๋ ฅ์ด ํฐ ๋๊ท๋ชจ ๋ฐ์ดํฐ์ ์ ๊ฒฝ์ฐ, ๋ชจ๋ ์ ๋ ฅ์ ๋ฉ๋ชจ๋ฆฌ์ ๋ก๋ํ๋ ค๋ฉด ๋ฆฌ์คํธ ๋์ ์ ๋๋ ์ดํฐ ํํ๋ก ์ ๋ฌํด์ผ ํฉ๋๋ค. ์์ธํ ๋ด์ฉ์ Pipelines API ์ฐธ์กฐ๋ฅผ ํ์ธํ์ธ์.
ํ์ดํ๋ผ์ธ์์ ๋ค๋ฅธ ๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ์ฌ์ฉํ๊ธฐ [[use-another-model-and-tokenizer-in-the-pipeline]]
pipeline()
์ Hub์ ๋ชจ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ ์ ์๊ธฐ ๋๋ฌธ์, pipeline()
์ ๋ค๋ฅธ ์ฉ๋์ ๋ง๊ฒ ์ฝ๊ฒ ์์ ํ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ํ๋์ค์ด ํ
์คํธ๋ฅผ ์ฒ๋ฆฌํ ์ ์๋ ๋ชจ๋ธ์ ์ฌ์ฉํ๊ธฐ ์ํด์ Hub์ ํ๊ทธ๋ฅผ ์ฌ์ฉํ์ฌ ์ ์ ํ ๋ชจ๋ธ์ ํํฐ๋งํ๋ฉด ๋ฉ๋๋ค. ํํฐ๋ง๋ ๊ฒฐ๊ณผ์ ์์ ํญ๋ชฉ์ผ๋ก๋ ํ๋์ค์ด ํ
์คํธ์ ์ฌ์ฉํ ์ ์๋ ๋ค๊ตญ์ด BERT ๋ชจ๋ธ์ด ๋ฐํ๋ฉ๋๋ค:
AutoModelForSequenceClassification๊ณผ AutoTokenizer๋ฅผ ์ฌ์ฉํ์ฌ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ๊ณผ ๊ด๋ จ๋ ํ ํฌ๋์ด์ ๋ฅผ ๋ก๋ํ์ธ์ (๋ค์ ์น์
์์ AutoClass
์ ๋ํด ๋ ์์ธํ ์์๋ณด๊ฒ ์ต๋๋ค):
pipeline()
์์ ๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ๋ฅผ ์ง์ ํ๋ฉด, ์ด์ classifier
๋ฅผ ํ๋์ค์ด ํ
์คํธ์ ์ ์ฉํ ์ ์์ต๋๋ค:
๋ง๋ ํ ๋ชจ๋ธ์ ์ฐพ์ ์ ์๋ ๊ฒฝ์ฐ ๋ฐ์ดํฐ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ๋ฏธ์ธ์กฐ์ ํด์ผ ํฉ๋๋ค. ๋ฏธ์ธ์กฐ์ ๋ฐฉ๋ฒ์ ๋ํ ์์ธํ ๋ด์ฉ์ ๋ฏธ์ธ์กฐ์ ํํ ๋ฆฌ์ผ์ ์ฐธ์กฐํ์ธ์. ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ๋ฏธ์ธ์กฐ์ ํ ํ์๋ ๋ชจ๋ธ์ Hub์ ์ปค๋ฎค๋ํฐ์ ๊ณต์ ํ์ฌ ๋จธ์ ๋ฌ๋ ๋ฏผ์ฃผํ์ ๊ธฐ์ฌํด์ฃผ์ธ์! ๐ค
AutoClass [[autoclass]]
AutoModelForSequenceClassification๊ณผ AutoTokenizer ํด๋์ค๋ ์์์ ๋ค๋ฃฌ pipeline()
์ ๊ธฐ๋ฅ์ ๊ตฌํํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค. AutoClass๋ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ์ํคํ
์ฒ๋ฅผ ์ด๋ฆ์ด๋ ๊ฒฝ๋ก์์ ์๋์ผ๋ก ๊ฐ์ ธ์ค๋ '๋ฐ๋ก๊ฐ๊ธฐ'์
๋๋ค. ๊ณผ์
์ ์ ํฉํ AutoClass
๋ฅผ ์ ํํ๊ณ ํด๋น ์ ์ฒ๋ฆฌ ํด๋์ค๋ฅผ ์ ํํ๊ธฐ๋ง ํ๋ฉด ๋ฉ๋๋ค.
์ด์ ์น์
์ ์์ ๋ก ๋์๊ฐ์ pipeline()
์ ๊ฒฐ๊ณผ๋ฅผ AutoClass
๋ฅผ ํ์ฉํด ๋ณต์ ํ๋ ๋ฐฉ๋ฒ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
AutoTokenizer [[autotokenizer]]
ํ ํฌ๋์ด์ ๋ ํ ์คํธ๋ฅผ ๋ชจ๋ธ์ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉํ๊ธฐ ์ํด ์ซ์ ๋ฐฐ์ด ํํ๋ก ์ ์ฒ๋ฆฌํ๋ ์ญํ ์ ๋ด๋นํฉ๋๋ค. ํ ํฐํ ๊ณผ์ ์๋ ๋จ์ด๋ฅผ ์ด๋์์ ๋์์ง, ์ด๋ ์์ค๊น์ง ๋๋์ง์ ๊ฐ์ ์ฌ๋ฌ ๊ท์น๋ค์ด ์์ต๋๋ค (ํ ํฐํ์ ๋ํ ์์ธํ ๋ด์ฉ์ ํ ํฌ๋์ด์ ์์ฝ์ ์ฐธ์กฐํ์ธ์). ๊ฐ์ฅ ์ค์ํ ์ ์ ๋ชจ๋ธ์ด ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ๊ณผ ๋์ผํ ํ ํฐํ ๊ท์น์ ์ฌ์ฉํ๋๋ก ๋์ผํ ๋ชจ๋ธ ์ด๋ฆ์ผ๋ก ํ ํฌ๋์ด์ ๋ฅผ ์ธ์คํด์คํํด์ผ ํ๋ค๋ ๊ฒ์ ๋๋ค.
AutoTokenizer๋ก ํ ํฌ๋์ด์ ๋ฅผ ๋ก๋ํ์ธ์:
ํ ์คํธ๋ฅผ ํ ํฌ๋์ด์ ์ ์ ๋ฌํ์ธ์:
ํ ํฌ๋์ด์ ๋ ๋ค์์ ํฌํจํ ๋์ ๋๋ฆฌ๋ฅผ ๋ฐํํฉ๋๋ค:
input_ids: ํ ํฐ์ ์ซ์ ํํ.
attention_mask: ์ด๋ค ํ ํฐ์ ์ฃผ์๋ฅผ ๊ธฐ์ธ์ฌ์ผ ํ๋์ง๋ฅผ ๋ํ๋ ๋๋ค.
ํ ํฌ๋์ด์ ๋ ์ ๋ ฅ์ ๋ฆฌ์คํธ ํํ๋ก๋ ๋ฐ์ ์ ์์ผ๋ฉฐ, ํ ์คํธ๋ฅผ ํจ๋ฉํ๊ณ ์๋ผ๋ด์ด ์ผ์ ํ ๊ธธ์ด์ ๋ฌถ์์ ๋ฐํํ ์๋ ์์ต๋๋ค:
์ ์ฒ๋ฆฌ ํํ ๋ฆฌ์ผ์ ์ฐธ์กฐํ์๋ฉด ํ ํฐํ์ ๋ํ ์์ธํ ์ค๋ช ๊ณผ ํจ๊ป ์ด๋ฏธ์ง, ์ค๋์ค์ ๋ฉํฐ๋ชจ๋ฌ ์ ๋ ฅ์ ์ ์ฒ๋ฆฌํ๊ธฐ ์ํ AutoImageProcessor์ AutoFeatureExtractor, AutoProcessor์ ์ฌ์ฉ๋ฐฉ๋ฒ๋ ์ ์ ์์ต๋๋ค.
AutoModel [[automodel]]
๐ค Transformers๋ ์ฌ์ ํ๋ จ๋ ์ธ์คํด์ค๋ฅผ ๊ฐ๋จํ๊ณ ํตํฉ๋ ๋ฐฉ๋ฒ์ผ๋ก ๋ก๋ํ ์ ์์ต๋๋ค. ์ฆ, AutoTokenizer์ฒ๋ผ AutoModel์ ๋ก๋ํ ์ ์์ต๋๋ค. ์ ์ผํ ์ฐจ์ด์ ์ ๊ณผ์ ์ ์๋ง์ AutoModel์ ์ ํํด์ผ ํ๋ค๋ ์ ์ ๋๋ค. ํ ์คํธ (๋๋ ์ํ์ค) ๋ถ๋ฅ์ ๊ฒฝ์ฐ AutoModelForSequenceClassification์ ๋ก๋ํด์ผ ํฉ๋๋ค:
AutoModel ํด๋์ค์์ ์ง์ํ๋ ๊ณผ์ ์ ๋ํด์๋ ๊ณผ์ ์์ฝ์ ์ฐธ์กฐํ์ธ์.
์ด์ ์ ์ฒ๋ฆฌ๋ ์
๋ ฅ ๋ฌถ์์ ์ง์ ๋ชจ๋ธ์ ์ ๋ฌํด์ผ ํฉ๋๋ค. ์๋์ฒ๋ผ **
๋ฅผ ์์ ๋ถ์ฌ ๋์
๋๋ฆฌ๋ฅผ ํ์ด์ฃผ๋ฉด ๋ฉ๋๋ค:
๋ชจ๋ธ์ ์ต์ข
ํ์ฑํ ํจ์ ์ถ๋ ฅ์ logits
์์ฑ์ ๋ด๊ฒจ์์ต๋๋ค. logits
์ softmax ํจ์๋ฅผ ์ ์ฉํ์ฌ ํ๋ฅ ์ ์ป์ ์ ์์ต๋๋ค:
๋ชจ๋ ๐ค Transformers ๋ชจ๋ธ(PyTorch ๋๋ TensorFlow)์ (softmax์ ๊ฐ์) ์ต์ข ํ์ฑํ ํจ์ ์ด์ ์ ํ ์๋ฅผ ์ถ๋ ฅํฉ๋๋ค. ์๋ํ๋ฉด ์ต์ข ํ์ฑํ ํจ์์ ์ถ๋ ฅ์ ์ข ์ข ์์ค ํจ์ ์ถ๋ ฅ๊ณผ ๊ฒฐํฉ๋๊ธฐ ๋๋ฌธ์ ๋๋ค. ๋ชจ๋ธ ์ถ๋ ฅ์ ํน์ํ ๋ฐ์ดํฐ ํด๋์ค์ด๋ฏ๋ก IDE์์ ์๋ ์์ฑ๋ฉ๋๋ค. ๋ชจ๋ธ ์ถ๋ ฅ์ ํํ์ด๋ ๋์ ๋๋ฆฌ์ฒ๋ผ ๋์ํ๋ฉฐ (์ ์, ์ฌ๋ผ์ด์ค ๋๋ ๋ฌธ์์ด๋ก ์ธ๋ฑ์ฑ ๊ฐ๋ฅ), None์ธ ์์ฑ์ ๋ฌด์๋ฉ๋๋ค.
๋ชจ๋ธ ์ ์ฅํ๊ธฐ [[save-a-model]]
๋ฏธ์ธ์กฐ์ ๋ ๋ชจ๋ธ์ ํ ํฌ๋์ด์ ์ ํจ๊ป ์ ์ฅํ๋ ค๋ฉด PreTrainedModel.save_pretrained()๋ฅผ ์ฌ์ฉํ์ธ์:
๋ชจ๋ธ์ ๋ค์ ์ฌ์ฉํ๋ ค๋ฉด PreTrainedModel.from_pretrained()๋ก ๋ชจ๋ธ์ ๋ค์ ๋ก๋ํ์ธ์:
๐ค Transformers์ ๋ฉ์ง ๊ธฐ๋ฅ ์ค ํ๋๋ ๋ชจ๋ธ์ PyTorch ๋๋ TensorFlow ๋ชจ๋ธ๋ก ์ ์ฅํด๋๋ค๊ฐ ๋ค๋ฅธ ํ๋ ์์ํฌ๋ก ๋ค์ ๋ก๋ํ ์ ์๋ ์ ์
๋๋ค. from_pt
๋๋ from_tf
๋งค๊ฐ๋ณ์๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ํ ํ๋ ์์ํฌ์์ ๋ค๋ฅธ ํ๋ ์์ํฌ๋ก ๋ณํํ ์ ์์ต๋๋ค:
์ปค์คํ ๋ชจ๋ธ ๊ตฌ์ถํ๊ธฐ [[custom-model-builds]]
๋ชจ๋ธ์ ๊ตฌ์ฑ ํด๋์ค๋ฅผ ์์ ํ์ฌ ๋ชจ๋ธ์ ๊ตฌ์กฐ๋ฅผ ๋ฐ๊ฟ ์ ์์ต๋๋ค. (์๋์ธต์ด๋ ์ดํ ์ ํค๋์ ์์ ๊ฐ์) ๋ชจ๋ธ์ ์์ฑ์ ๊ตฌ์ฑ์์ ์ง์ ๋๊ธฐ ๋๋ฌธ์ ๋๋ค. ์ปค์คํ ๊ตฌ์ฑ ํด๋์ค๋ก ๋ชจ๋ธ์ ๋ง๋ค๋ฉด ์ฒ์๋ถํฐ ์์ํด์ผ ํฉ๋๋ค. ๋ชจ๋ธ ์์ฑ์ ๋ฌด์์๋ก ์ด๊ธฐํ๋๋ฏ๋ก ์๋ฏธ ์๋ ๊ฒฐ๊ณผ๋ฅผ ์ป์ผ๋ ค๋ฉด ๋จผ์ ๋ชจ๋ธ์ ํ๋ จ์์ผ์ผ ํฉ๋๋ค.
๋จผ์ AutoConfig๋ฅผ ๊ฐ์ ธ์ค๊ณ ์์ ํ๊ณ ์ถ์ ์ฌ์ ํ์ต๋ ๋ชจ๋ธ์ ๋ก๋ํ์ธ์. AutoConfig.from_pretrained() ๋ด๋ถ์์ (์ดํ ์ ํค๋ ์์ ๊ฐ์ด) ๋ณ๊ฒฝํ๋ ค๋ ์์ฑ๋ฅผ ์ง์ ํ ์ ์์ต๋๋ค:
AutoModel.from_config()๋ฅผ ์ฌ์ฉํ์ฌ ๋ฐ๊พผ ๊ตฌ์ฑ๋๋ก ๋ชจ๋ธ์ ์์ฑํ์ธ์:
์ปค์คํ ๊ตฌ์ฑ์ ๋ํ ์์ธํ ๋ด์ฉ์ ์ปค์คํ ์ํคํ ์ฒ ๋ง๋ค๊ธฐ ๊ฐ์ด๋๋ฅผ ํ์ธํ์ธ์.
Trainer - PyTorch์ ์ต์ ํ๋ ํ๋ จ ๋ฃจํ [[trainer-a-pytorch-optimized-training-loop]]
๋ชจ๋ ๋ชจ๋ธ์ torch.nn.Module
์ด๋ฏ๋ก ์ผ๋ฐ์ ์ธ ํ๋ จ ๋ฃจํ์์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ง์ ํ๋ จ ๋ฃจํ๋ฅผ ์์ฑํ ์๋ ์์ง๋ง, ๐ค Transformers๋ PyTorch๋ฅผ ์ํ Trainer ํด๋์ค๋ฅผ ์ ๊ณตํฉ๋๋ค. ์ด ํด๋์ค์๋ ๊ธฐ๋ณธ ํ๋ จ ๋ฃจํ๊ฐ ํฌํจ๋์ด ์์ผ๋ฉฐ ๋ถ์ฐ ํ๋ จ, ํผํฉ ์ ๋ฐ๋ ๋ฑ๊ณผ ๊ฐ์ ๊ธฐ๋ฅ์ ์ถ๊ฐ๋ก ์ ๊ณตํฉ๋๋ค.
๊ณผ์ ์ ๋ฐ๋ผ ๋ค๋ฅด์ง๋ง ์ผ๋ฐ์ ์ผ๋ก Trainer์ ๋ค์ ๋งค๊ฐ๋ณ์๋ฅผ ์ ๋ฌํฉ๋๋ค:
PreTrainedModel ๋๋
torch.nn.Module
๋ก ์์ํฉ๋๋ค:TrainingArguments๋ ํ์ต๋ฅ , ๋ฐฐ์น ํฌ๊ธฐ, ํ๋ จํ ์ํฌํฌ ์์ ๊ฐ์ ๋ชจ๋ธ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ํฌํจํฉ๋๋ค. ํ๋ จ ์ธ์๋ฅผ ์ง์ ํ์ง ์์ผ๋ฉด ๊ธฐ๋ณธ๊ฐ์ด ์ฌ์ฉ๋ฉ๋๋ค:
ํ ํฌ๋์ด์ , ์ด๋ฏธ์ง ํ๋ก์ธ์, ํน์ง ์ถ์ถ๊ธฐ(feature extractor) ๋๋ ํ๋ก์ธ์์ ์ ์ฒ๋ฆฌ ํด๋์ค๋ฅผ ๋ก๋ํ์ธ์:
๋ฐ์ดํฐ์ ์ ๋ก๋ํ์ธ์:
๋ฐ์ดํฐ์ ์ ํ ํฐํํ๋ ํจ์๋ฅผ ์์ฑํ์ธ์:
๊ทธ๋ฆฌ๊ณ
map
๋ก ๋ฐ์ดํฐ์ ์ ์ฒด์ ์ ์ฉํ์ธ์:DataCollatorWithPadding์ ์ฌ์ฉํ์ฌ ๋ฐ์ดํฐ์ ์ ํ๋ณธ ๋ฌถ์์ ๋ง๋์ธ์:
์ด์ ์์ ๋ชจ๋ ํด๋์ค๋ฅผ Trainer๋ก ๋ชจ์ผ์ธ์:
์ค๋น๊ฐ ๋์์ผ๋ฉด train()์ ํธ์ถํ์ฌ ํ๋ จ์ ์์ํ์ธ์:
๋ฒ์ญ์ด๋ ์์ฝ๊ณผ ๊ฐ์ด ์ํ์ค-์ํ์ค ๋ชจ๋ธ์ ์ฌ์ฉํ๋ ๊ณผ์ ์๋ Seq2SeqTrainer ๋ฐ Seq2SeqTrainingArguments ํด๋์ค๋ฅผ ์ฌ์ฉํ์ธ์.
Trainer ๋ด์ ๋ฉ์๋๋ฅผ ์๋ธํด๋์คํํ์ฌ ํ๋ จ ๋ฃจํ๋ฅผ ๋ฐ๊ฟ ์๋ ์์ต๋๋ค. ์ด๋ฌ๋ฉด ์์ค ํจ์, ์ตํฐ๋ง์ด์ , ์ค์ผ์ค๋ฌ์ ๊ฐ์ ๊ธฐ๋ฅ ๋ํ ๋ฐ๊ฟ ์ ์๊ฒ ๋ฉ๋๋ค. ๋ณ๊ฒฝ ๊ฐ๋ฅํ ๋ฉ์๋์ ๋ํด์๋ Trainer ๋ฌธ์๋ฅผ ์ฐธ๊ณ ํ์ธ์.
ํ๋ จ ๋ฃจํ๋ฅผ ์์ ํ๋ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ Callbacks๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ ๋๋ค. Callbacks๋ก ๋ค๋ฅธ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ํตํฉํ๊ณ , ํ๋ จ ๋ฃจํ๋ฅผ ์ฒดํฌํ์ฌ ์งํ ์ํฉ์ ๋ณด๊ณ ๋ฐ๊ฑฐ๋, ํ๋ จ์ ์กฐ๊ธฐ์ ์ค๋จํ ์ ์์ต๋๋ค. Callbacks์ ํ๋ จ ๋ฃจํ ์์ฒด๋ฅผ ๋ฐ๊พธ์ง๋ ์์ต๋๋ค. ์์ค ํจ์์ ๊ฐ์ ๊ฒ์ ๋ฐ๊พธ๋ ค๋ฉด Trainer๋ฅผ ์๋ธํด๋์คํํด์ผ ํฉ๋๋ค.
TensorFlow๋ก ํ๋ จ์ํค๊ธฐ [[train-with-tensorflow]]
๋ชจ๋ ๋ชจ๋ธ์ tf.keras.Model
์ด๋ฏ๋ก Keras API๋ฅผ ํตํด TensorFlow์์ ํ๋ จ์ํฌ ์ ์์ต๋๋ค. ๐ค Transformers๋ ๋ฐ์ดํฐ์
์ ์ฝ๊ฒ tf.data.Dataset
ํํ๋ก ์ฝ๊ฒ ๋ก๋ํ ์ ์๋ ~TFPreTrainedModel.prepare_tf_dataset
๋ฉ์๋๋ฅผ ์ ๊ณตํ๊ธฐ ๋๋ฌธ์, Keras์ compile
๋ฐ fit
๋ฉ์๋๋ก ๋ฐ๋ก ํ๋ จ์ ์์ํ ์ ์์ต๋๋ค.
TFPreTrainedModel
๋๋tf.keras.Model
๋ก ์์ํฉ๋๋ค:ํ ํฌ๋์ด์ , ์ด๋ฏธ์ง ํ๋ก์ธ์, ํน์ง ์ถ์ถ๊ธฐ(feature extractor) ๋๋ ํ๋ก์ธ์์ ๊ฐ์ ์ ์ฒ๋ฆฌ ํด๋์ค๋ฅผ ๋ก๋ํ์ธ์:
๋ฐ์ดํฐ์ ์ ํ ํฐํํ๋ ํจ์๋ฅผ ์์ฑํ์ธ์:
map
์ ์ฌ์ฉํ์ฌ ์ ์ฒด ๋ฐ์ดํฐ์ ์ ํ ํฐํ ํจ์๋ฅผ ์ ์ฉํ๊ณ , ๋ฐ์ดํฐ์ ๊ณผ ํ ํฌ๋์ด์ ๋ฅผ~TFPreTrainedModel.prepare_tf_dataset
์ ์ ๋ฌํ์ธ์. ๋ฐฐ์น ํฌ๊ธฐ๋ฅผ ๋ณ๊ฒฝํ๊ฑฐ๋ ๋ฐ์ดํฐ์ ์ ์์ ์๋ ์์ต๋๋ค:์ค๋น๋์์ผ๋ฉด
compile
๋ฐfit
๋ฅผ ํธ์ถํ์ฌ ํ๋ จ์ ์์ํ์ธ์. ๐ค Transformers์ ๋ชจ๋ ๋ชจ๋ธ์ ๊ณผ์ ๊ณผ ๊ด๋ จ๋ ๊ธฐ๋ณธ ์์ค ํจ์๋ฅผ ๊ฐ์ง๊ณ ์์ผ๋ฏ๋ก ๋ช ์์ ์ผ๋ก ์ง์ ํ์ง ์์๋ ๋ฉ๋๋ค:
๋ค์ ๋จ๊ณ๋ ๋ฌด์์ธ๊ฐ์? [[whats-next]]
๐ค Transformers ๋๋ฌ๋ณด๊ธฐ๋ฅผ ๋ชจ๋ ์ฝ์ผ์ จ๋ค๋ฉด, ๊ฐ์ด๋๋ฅผ ์ดํด๋ณด๊ณ ๋ ๊ตฌ์ฒด์ ์ธ ๊ฒ์ ์ํํ๋ ๋ฐฉ๋ฒ์ ์์๋ณด์ธ์. ์ด๋ฅผํ ๋ฉด ์ปค์คํ ๋ชจ๋ธ ๊ตฌ์ถํ๋ ๋ฐฉ๋ฒ, ๊ณผ์ ์ ์๋ง๊ฒ ๋ชจ๋ธ์ ๋ฏธ์ธ์กฐ์ ํ๋ ๋ฐฉ๋ฒ, ์คํฌ๋ฆฝํธ๋ก ๋ชจ๋ธ ํ๋ จํ๋ ๋ฐฉ๋ฒ ๋ฑ์ด ์์ต๋๋ค. ๐ค Transformers ํต์ฌ ๊ฐ๋ ์ ๋ํด ๋ ์์๋ณด๋ ค๋ฉด ์ปคํผ ํ ์ ๋ค๊ณ ๊ฐ๋ ๊ฐ์ด๋๋ฅผ ์ดํด๋ณด์ธ์!